
Building Tensorbit Core: A C++/CUDA Transformer Pruning Library

Large language models keep getting bigger. Every new release pushes the parameter count higher, and with it the compute and memory requirements for inference. The standard response is to buy more GPUs. The better response is to prune.

Tensorbit Core is a pruning library I built as part of Tensorbit Labs that applies structured N:M sparsity to transformer architectures using Hessian-aware pruning and a custom binary container format. This writeup covers the full pipeline — from reading HuggingFace .safetensors weights to deploying a pruned model on an A100 GPU — and the engineering lessons learned along the way.

Why Prune?

A 7-billion-parameter model like Mistral 7B consumes about 28 GB of RAM in FP32 (14 GB in BF16). The compute cost is dominated by the linear layers: each of the 32 layers has 4 attention projections (Q, K, V, O) and 3 MLP projections (Gate, Up, Down), each operating on 4096- or 14336-dimensional vectors. That's 224 matrix-vector multiplies per token, each costing between about 4 million multiply-accumulate operations (the K and V projections, which use grouped-query attention) and about 59 million (each MLP projection).
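
As a rough check on those numbers, the sketch below counts multiply-accumulates per token in the linear layers only. It assumes Mistral 7B's published configuration (hidden size 4096, MLP width 14336, 32 layers, and grouped-query attention with 8 KV heads of dimension 128, so K and V project to 1024 dimensions); embeddings, norms, and attention-score math are left out, and `ModelDims` is a hypothetical helper type.

```cpp
#include <cassert>
#include <cstdint>

// Assumed Mistral 7B dimensions (from its published config); hypothetical helper type.
struct ModelDims {
    std::int64_t hidden = 4096;         // model width
    std::int64_t intermediate = 14336;  // MLP width
    std::int64_t kv_dim = 8 * 128;      // 8 KV heads x 128 head dim (grouped-query attention)
    std::int64_t layers = 32;
};

// Multiply-accumulates for one token through every linear layer.
// A weight of shape [out x in] applied to one vector costs out * in MACs.
std::int64_t linear_macs_per_token(const ModelDims& d) {
    assert(d.layers > 0);
    const std::int64_t attn = d.hidden * d.hidden    // Q
                            + d.kv_dim * d.hidden    // K
                            + d.kv_dim * d.hidden    // V
                            + d.hidden * d.hidden;   // O
    const std::int64_t mlp = 3 * d.intermediate * d.hidden;  // Gate, Up, Down
    return d.layers * (attn + mlp);
}
```

Each linear weight is used exactly once per token, so the count (6,979,321,856 MACs) equals the number of linear-layer weights. That is why removing half of those weights can, given sparse kernels, remove half of the dominant work.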

N:M structured sparsity forces exactly N non‑zero elements out of every M-element group. The 2:4 pattern is especially attractive because Ampere GPUs have mma.sp tensor core instructions that accelerate 2:4‑sparse matrix multiplication natively. In theory, this gives a 2× compute speedup and close to a 2× memory reduction (half the values are dropped, but each kept value carries a 2-bit position index), with less than 1% accuracy loss on most benchmarks.

Existing open-source tools (SparseGPT, Wanda) are research‑focused: they are PyTorch research code built to evaluate pruned models, and they don't produce deployable artifacts. Tensorbit Core was built from the ground up as a production pruning pipeline that takes real model weights and outputs a self‑contained container file ready for inference.

The P-D-Q Pipeline: Architecture Overview

The Tensorbit ecosystem follows a four-stage pipeline:

.safetensors ──→ core (prune) ──→ .tbm ──→ distill ──→ .tbm ──→ quant ──→ .tbm ──→ run
  1. tensorbit-core — Loads .safetensors model weights, runs Hessian‑aware pruning (EHAP) + structured N:M masking (CORING), and serializes to .tbm.
  2. tensorbit-distill — (future) Reads .tbm, writes .tbm with student model weights.
  3. tensorbit-quant — Reads .tbm, quantizes FP32→INT4/INT8, writes .tbm with scale metadata. ~3.6 GB output vs 29 GB input.
  4. tensorbit-run — Reads .tbm, runs autoregressive inference via CPU or CUDA backends with sparse N:M GEMM.

Each stage communicates via a single file: the .tbm container. This is the core design insight — no database, no intermediate formats, no configuration catalogs. One file, one stage.

The .tbm Container Format

I designed the .tbm format to be simple enough to parse with 100 lines of C while supporting zero‑copy weight loading on GPU. The on‑disk layout is:

┌─────────────────────────────┐
│ Tensor 0 (.tb blob)         │
│ ┌───────────┬───────┬────┐  │
│ │ TBHeader  │ FP32  │Mask│  │
│ │ (4096 B)  │ wt's  │    │  │
│ └───────────┴───────┴────┘  │
├─────────────────────────────┤
│ Tensor 1                    │
│ ...                         │
├─────────────────────────────┤
│ Tensor K                    │
├─────────────────────────────┤
│ JSON Index (UTF-8)          │
│ {"name":"...","offset":0,   │
│  "shape":[4096,4096],       │
│  "nm_n":2,"nm_m":4,         │
│  "dtype":"fp32",...}        │
├─────────────────────────────┤
│ 4-byte LE uint32            │ ← JSON byte length
└─────────────────────────────┘

The TBHeader is a packed 4096-byte struct with fields for magic (0x31304254 = "TB01"), version, N:M parameters, weight/mask byte counts, offsets, and a precision byte (0=FP32, 4=INT8, 5=INT4). The JSON index is read from the tail (seek backwards 4 bytes → read JSON length → seek back that many bytes → parse JSON). This means inference engines can mmap the entire file and access any tensor's data by offset with zero parsing overhead — the JSON index is only read once at load time.
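
The tail-first lookup is simple enough to sketch with standard C++ streams. This assumes only what is described above (a little-endian uint32 JSON length in the last 4 bytes, with the JSON immediately before it), plus the assumption that the magic occupies the first 4 bytes of each TBHeader. The function names are hypothetical, and JSON parsing is left to whatever library the engine already uses.

```cpp
#include <cassert>
#include <cstdint>
#include <fstream>
#include <string>

// Assumes the magic is the first 4 bytes of a TBHeader.
// 0x31304254 stored little-endian is the ASCII bytes "TB01".
bool has_tb01_magic(const unsigned char* header) {
    assert(header != nullptr);
    const std::uint32_t magic = std::uint32_t(header[0])
                              | std::uint32_t(header[1]) << 8
                              | std::uint32_t(header[2]) << 16
                              | std::uint32_t(header[3]) << 24;
    return magic == 0x31304254u;
}

// Reads the JSON index from the tail of a .tbm file.
// Returns an empty string on any error.
std::string read_tbm_index(const std::string& path) {
    std::ifstream in(path, std::ios::binary);
    if (!in) return {};
    in.seekg(0, std::ios::end);
    const std::streamoff file_size = in.tellg();
    if (file_size < 4) return {};

    // Step 1: seek back 4 bytes and decode the little-endian JSON length.
    unsigned char len_bytes[4];
    in.seekg(file_size - 4);
    if (!in.read(reinterpret_cast<char*>(len_bytes), 4)) return {};
    const std::uint32_t json_len = std::uint32_t(len_bytes[0])
                                 | std::uint32_t(len_bytes[1]) << 8
                                 | std::uint32_t(len_bytes[2]) << 16
                                 | std::uint32_t(len_bytes[3]) << 24;
    if (json_len > file_size - 4) return {};  // corrupt length field

    // Step 2: seek back over the JSON itself and read it.
    std::string json(json_len, '\0');
    in.seekg(file_size - 4 - std::streamoff(json_len));
    if (json_len > 0 && !in.read(json.data(), json_len)) return {};
    return json;
}
```

Everything before the JSON is then addressed purely by the offsets it records, which is what lets an mmap-based loader skip any further parsing.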

EHAP: Efficient Hessian-Aware Pruning

The core pruning algorithm is Efficient Hessian-Aware Pruning. Instead of just looking at weight magnitudes (which throws away 90% of the available information), EHAP uses second‑order information to decide which weights to prune.

The Fisher information matrix measures how sensitive the loss is to each weight. For a diagonal approximation, the Fisher information for weight $w_i$ is:

$$F_{ii} = \mathbb{E}\left[\left(\frac{\partial L}{\partial w_i}\right)^2\right]$$

This diagonal captures the curvature of the loss landscape. A weight with high Fisher information sits in a steep valley — zeroing it causes a large loss increase. A weight with low Fisher information sits on flat terrain — removing it has minimal impact.

EHAP accumulates the Fisher diagonal during a calibration pass and supports three importance modes:

Mode Formula Behavior
OBD w²·(F+λ) Optimal Brain Damage — product with Fisher
OBS w²/(F+λ) Optimal Brain Surgeon — ratio, compensates other weights
Normalized w²(F+λ)/(1+w²) Bounded version of OBD, stable for outliers
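
As a sketch, the three modes map onto a small scoring function. The formulas are taken directly from the table; the enum and function names are hypothetical, and λ is assumed to be a small positive damping constant that keeps the OBS ratio finite when F is zero.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class ImportanceMode { OBD, OBS, Normalized };

// Per-weight importance following the OBD / OBS / Normalized formulas.
// fisher holds the diagonal Fisher estimate F; lambda is the damping term.
std::vector<float> importance_scores(const std::vector<float>& w,
                                     const std::vector<float>& fisher,
                                     float lambda, ImportanceMode mode) {
    assert(w.size() == fisher.size());
    std::vector<float> s(w.size(), 0.0f);
    for (std::size_t i = 0; i < w.size(); ++i) {
        const float w2 = w[i] * w[i];
        const float h = fisher[i] + lambda;  // damped curvature F + lambda
        switch (mode) {
            case ImportanceMode::OBD:        s[i] = w2 * h; break;
            case ImportanceMode::OBS:        s[i] = w2 / h; break;
            case ImportanceMode::Normalized: s[i] = w2 * h / (1.0f + w2); break;
        }
    }
    return s;
}
```

The Normalized mode divides by 1 + w², so a single very large weight cannot dominate the ranking in the way it can under plain OBD.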

The Fisher diagonal uses EMA decay (exponential moving average with configurable alpha) to smooth the estimate across calibration steps. A fisher_beta_decay_kernel on GPU applies the decay in‑place without a host round‑trip.
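
A minimal CPU version of the accumulate-and-decay step is below. It assumes the common EMA form F ← α·F + (1 − α)·g², where g is one calibration gradient; how Tensorbit Core actually parameterizes alpha is not specified here, and the function name is hypothetical. On the GPU, the loop body would become one grid-stride iteration per element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// EMA update of the diagonal Fisher estimate from one calibration gradient.
// Assumed form: F_i <- alpha * F_i + (1 - alpha) * g_i^2.
void fisher_ema_update(std::vector<float>& fisher,
                       const std::vector<float>& grad, float alpha) {
    assert(fisher.size() == grad.size());
    assert(alpha >= 0.0f && alpha <= 1.0f);
    for (std::size_t i = 0; i < fisher.size(); ++i) {
        const float g2 = grad[i] * grad[i];  // one sample of (dL/dw_i)^2
        fisher[i] = alpha * fisher[i] + (1.0f - alpha) * g2;
    }
}
```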

BlockOBS: Greedy Pruning with Woodbury Updates

For the highest accuracy, EHAP includes BlockOBS — a block‑wise OBS implementation that constructs a 128×128 Hessian submatrix for each block of weights and performs greedy removal with Sherman‑Morrison rank‑1 updates. The Woodbury identity keeps the per‑block Hessian inversion tractable:

H⁻¹ = D⁻¹ − D⁻¹ U (I + Uᵀ D⁻¹ U)⁻¹ Uᵀ D⁻¹

Here D is the diagonal of the Fisher matrix and U is a low‑rank factor built from gradient history. For each output row, BlockOBS runs a greedy loop: find the weight with the smallest saliency w² / [H⁻¹]ᵢᵢ, compute the optimal update to the remaining weights, deflate the inverse with a Sherman‑Morrison rank‑1 update, and repeat until the sparsity target is met.
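
The greedy loop can be sketched on a single dense block with plain vectors instead of Eigen. This follows the textbook OBS elimination step (saliency, compensation along the corresponding column of H⁻¹, then a rank-1 downdate that removes that row and column from H⁻¹); building the initial H⁻¹ via Woodbury is omitted, and the function name is hypothetical rather than BlockOBS's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Greedy OBS on one row of a block: prune k weights from w, given the
// block's inverse Hessian hinv (n x n, symmetric positive definite).
void greedy_obs_prune(std::vector<double>& w, Matrix hinv, std::size_t k) {
    const std::size_t n = w.size();
    assert(hinv.size() == n && k <= n);
    std::vector<bool> pruned(n, false);
    for (std::size_t step = 0; step < k; ++step) {
        // 1. Pick the surviving weight with the smallest saliency w_i^2 / [H^-1]_ii.
        std::size_t q = n;
        double best = std::numeric_limits<double>::infinity();
        for (std::size_t i = 0; i < n; ++i) {
            if (pruned[i]) continue;
            const double sal = w[i] * w[i] / hinv[i][i];
            if (sal < best) { best = sal; q = i; }
        }
        assert(q < n);
        // 2. Optimal compensation: w <- w - (w_q / [H^-1]_qq) * H^-1[:, q].
        const double d = hinv[q][q];
        const double coef = w[q] / d;
        for (std::size_t j = 0; j < n; ++j)
            if (!pruned[j]) w[j] -= coef * hinv[j][q];
        w[q] = 0.0;  // force an exact zero (the update already drives it to ~0)
        pruned[q] = true;
        // 3. Rank-1 downdate: H^-1 <- H^-1 - H^-1[:, q] H^-1[q, :] / [H^-1]_qq.
        std::vector<double> col(n);
        for (std::size_t j = 0; j < n; ++j) col[j] = hinv[j][q];
        for (std::size_t a = 0; a < n; ++a)
            for (std::size_t b = 0; b < n; ++b)
                hinv[a][b] -= col[a] * col[b] / d;
    }
}
```

In a 2×2 example with H⁻¹ = [[2, 1], [1, 2]] and w = (1, 3), the first weight has saliency 0.5 and is removed, and the compensation moves the survivor from 3 to 2.5. That adjustment of the remaining weights is what OBS adds over plain magnitude pruning.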

Trade‑off: BlockOBS is slow. Each 128×128 block runs a full greedy loop of rank‑1 matrix updates, all on the CPU via Eigen (the prune() method requires host‑resident weights). For Mistral 7B, BlockOBS takes 8–12 hours on an A100 instance, even though the greedy loop itself runs on the host CPU. The Iterative strategy (multi‑round OBD with a cubic sparsity schedule and GPU‑accelerated Fisher recompute) runs in 30–45 minutes.
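
The text does not spell out the cubic schedule, so the sketch below assumes the widely used gradual-pruning form s(t) = s_f + (s_i − s_f)(1 − t/T)³, which prunes aggressively in early rounds and tapers off near the target. The function name and signature are hypothetical.

```cpp
#include <cassert>

// Assumed cubic sparsity schedule:
// s(t) = s_final + (s_init - s_final) * (1 - t/T)^3, for t in [0, T].
double cubic_sparsity(int step, int total_steps, double s_init, double s_final) {
    assert(total_steps > 0 && step >= 0 && step <= total_steps);
    const double remaining = 1.0 - static_cast<double>(step) / total_steps;
    return s_final + (s_init - s_final) * remaining * remaining * remaining;
}
```

For a 2:4 target, s_f = 0.5; with T = 4 rounds starting from a dense model, round 2 would aim for 43.75% sparsity, leaving only the last two rounds to close the remaining gap.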

CORING: N:M Structured Sparsity

Once EHAP computes importance scores, CORING generates N:M masks and redistributes weight magnitude. CORING stands for Coarse‑to‑fine N:M mask generation — it operates on groups of M weights, ensuring exactly N survive in each group.

Importance scores          2:4 mask
[8.1, 0.3, 4.5, 1.2]   →   [1, 0, 1, 0]
[2.0, 6.3, 0.7, 5.1]   →   [0, 1, 0, 1]
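
A plain CPU reference for the masking rule is handy for checking GPU kernels against. This sketch keeps the N highest-scoring weights in each group of M (ties go to the lower index) and is not necessarily identical to any of CORING's three strategies; the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Generic N:M mask: in each consecutive group of m scores, keep the n largest.
// Returns one 0/1 byte per weight for readability.
std::vector<std::uint8_t> nm_mask(const std::vector<float>& scores, int n, int m) {
    assert(n > 0 && n <= m && scores.size() % m == 0);
    std::vector<std::uint8_t> mask(scores.size(), 0);
    std::vector<int> idx(m);
    for (std::size_t g = 0; g < scores.size(); g += m) {
        std::iota(idx.begin(), idx.end(), 0);
        // Order positions by descending score; stable_sort keeps lower indices first on ties.
        std::stable_sort(idx.begin(), idx.end(), [&](int a, int b) {
            return scores[g + a] > scores[g + b];
        });
        for (int k = 0; k < n; ++k) mask[g + idx[k]] = 1;
    }
    return mask;
}
```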

Three mask strategies are implemented:

After mask generation, CORING redistributes the pruned weight magnitude to the survivors. It supports two modes: proportional, where each kept weight receives a share of the pruned magnitude in proportion to its own magnitude, and uniform, where the pruned magnitude is split evenly among the kept weights. This compensation step reduces the post‑pruning accuracy drop by 1‑3 percentage points.
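
Here is one way to read the two redistribution modes, as a sketch under stated assumptions: magnitude is measured as absolute value, redistribution happens within each group of M, and each survivor keeps its sign. None of these details are confirmed by the text, and the names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

enum class Redistribute { Proportional, Uniform };

// Zeroes pruned weights and moves their total |w| onto the kept weights of
// the same group of m, preserving each survivor's sign. mask holds 0/1.
void redistribute(std::vector<float>& w, const std::vector<std::uint8_t>& mask,
                  std::size_t m, Redistribute mode) {
    assert(w.size() == mask.size() && m > 0 && w.size() % m == 0);
    for (std::size_t g = 0; g < w.size(); g += m) {
        float pruned_mag = 0.0f, kept_mag = 0.0f;
        std::size_t kept = 0;
        for (std::size_t i = g; i < g + m; ++i) {
            if (mask[i]) { kept_mag += std::fabs(w[i]); ++kept; }
            else         { pruned_mag += std::fabs(w[i]); w[i] = 0.0f; }
        }
        if (kept == 0 || pruned_mag == 0.0f) continue;
        for (std::size_t i = g; i < g + m; ++i) {
            if (!mask[i]) continue;
            const bool uniform = (mode == Redistribute::Uniform || kept_mag == 0.0f);
            const float add = uniform ? pruned_mag / static_cast<float>(kept)
                                      : pruned_mag * std::fabs(w[i]) / kept_mag;
            w[i] += std::copysign(add, w[i]);  // grow the magnitude, keep the sign
        }
    }
}
```

With either mode, the group's total absolute magnitude is unchanged: in the 2:4 group (4, −1, 2, 0.5) with mask (1, 0, 1, 0), proportional mode yields (5, 0, 2.5, 0) and uniform mode yields (4.75, 0, 2.75, 0).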

CUDA Kernels

Both EHAP and CORING have GPU‑accelerated paths via CUDA kernels. Each kernel uses 256 threads per block with grid‑stride loops to handle arbitrary tensor sizes. The kernels are compiled for SM80/SM90 by default, behind an if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) guard so users can override the target for consumer GPUs (e.g., -DCMAKE_CUDA_ARCHITECTURES="86" for an RTX 3090). The kernels are:

  - fisher_accumulate: accumulates the square of the gradient into the Fisher diagonal
  - fisher_beta_decay: EMA decay of Fisher values (eliminates a host round‑trip)
  - ehap_importance: computes OBD/OBS/Normalized scores
  - nm_mask_2_4: 2:4 mask, register‑only, 0 shared memory
  - nm_mask_generic: general N:M mask with cooperative shared‑memory ranking
  - apply_mask: applies the binary mask to a weight tensor
  - fisher_beta_decay_kernel: GPU in‑place beta decay

The 2:4 mask kernel deserves special mention. Ampere's mma.sp tensor core instruction applies the 2:4 pattern along the reduction (K) dimension of the GEMM and describes each group of 4 by the positions of its 2 kept values. The mask kernel emits one byte per group of 4 consecutive weights along K, with bits marking the kept positions, so each mask byte corresponds directly to one of the groups that mma.sp and cuSPARSELt operate on.
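
For context on the hardware side: for 16-bit operands, the metadata that mma.sp and cuSPARSELt consume records each group's two kept positions as 2-bit indices, so one group of 4 needs 4 bits. The sketch below converts a keep-bitmask byte into that nibble and gathers the kept values. The packing order (lower index in the low bits) is an illustrative assumption; the real interleaved metadata layout is more involved and is best produced by cuSPARSELt's own compression routine. Function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Bit i of keep_bits set means position i (0..3) of the group is kept.
// Returns the two kept indices packed as (idx0 | idx1 << 2), idx0 < idx1.
std::uint8_t bitmask_to_meta(std::uint8_t keep_bits) {
    std::uint8_t idx[2] = {0, 0};
    int found = 0;
    for (int i = 0; i < 4; ++i) {
        if (keep_bits & (1u << i)) {
            if (found < 2) idx[found] = static_cast<std::uint8_t>(i);
            ++found;
        }
    }
    assert(found == 2 && "a 2:4 group must keep exactly two positions");
    return static_cast<std::uint8_t>(idx[0] | (idx[1] << 2));
}

// Compresses one 2:4 row: two values and one metadata nibble per group of 4.
std::pair<std::vector<float>, std::vector<std::uint8_t>>
compress_2_4(const std::vector<float>& row, const std::vector<std::uint8_t>& keep_bits) {
    assert(row.size() == keep_bits.size() * 4);
    std::vector<float> values;
    std::vector<std::uint8_t> meta;
    for (std::size_t g = 0; g < keep_bits.size(); ++g) {
        const std::uint8_t nib = bitmask_to_meta(keep_bits[g]);
        values.push_back(row[4 * g + (nib & 0x3)]);  // first kept value
        values.push_back(row[4 * g + (nib >> 2)]);   // second kept value
        meta.push_back(nib);
    }
    return {values, meta};
}
```

For 16-bit values that is 2 × 16 + 4 = 36 bits per group instead of 64, roughly a 1.78× reduction rather than a full 2×.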

Multi‑Shard Model Support

Models like Mistral 7B are distributed as multiple .safetensors shards on HuggingFace (e.g., model‑00001‑of‑00002.safetensors and model‑00002‑of‑00002.safetensors). Tensorbit Core handles this by pruning each shard independently, producing per‑shard .tb directories, then running merge_tbm.py to concatenate the per‑tensor blobs into a single .tbm with a unified JSON index.

The merge_tbm.py script reads the per‑shard model.tbm JSON indexes to extract exact tensor shapes (since the per‑shard files have correct metadata from tb‑prune). If the per‑shard metadata isn't available, it falls back to inferring shapes from naming conventions: q_proj is detected as a square [hidden, hidden] matrix by checking that sqrt(num_weights) is an integer ≥ 512, and all other projection shapes are derived from that hidden_size.
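
The fallback heuristic is easy to sketch. The actual script is Python; this C++ version only illustrates the logic, and it additionally assumes HuggingFace's [out, in] weight layout, with down_proj mapping the MLP width back to hidden_size and every other projection taking hidden_size as its input. All names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>

// Exact integer square root, or std::nullopt if n is not a perfect square.
std::optional<std::int64_t> exact_sqrt(std::int64_t n) {
    assert(n >= 0);
    auto r = static_cast<std::int64_t>(std::sqrt(static_cast<double>(n)));
    while (r * r > n) --r;               // undo floating-point overshoot
    while ((r + 1) * (r + 1) <= n) ++r;  // undo floating-point undershoot
    if (r * r != n) return std::nullopt;
    return r;
}

// Heuristic from the text: a q_proj is square [hidden, hidden], so a
// perfect-square element count with sqrt >= 512 is taken as hidden_size.
std::optional<std::int64_t> hidden_from_q_proj(std::int64_t numel) {
    const auto r = exact_sqrt(numel);
    if (!r || *r < 512) return std::nullopt;
    return r;
}

// Assumed [out, in] layout: down_proj maps (numel / hidden) -> hidden,
// every other projection maps hidden -> (numel / hidden).
std::pair<std::int64_t, std::int64_t> infer_shape(const std::string& name,
                                                  std::int64_t numel,
                                                  std::int64_t hidden) {
    assert(hidden > 0 && numel % hidden == 0);
    const std::int64_t other = numel / hidden;
    if (name.find("down_proj") != std::string::npos) return {hidden, other};
    return {other, hidden};
}
```

One caveat the sketch makes visible: with grouped-query attention, a Mistral k_proj holds 4096 × 1024 = 4,194,304 = 2048² weights, which also passes the perfect-square test, so the check is only safe because it is gated on the q_proj name.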

Inference Engine: tensorbit-run

The inference side is a separate C engine (tensorbit-run) with a C++20 wrapper that provides RAII‑managed tensors, an autoregressive Transformer runner, and a backend registry for CPU/CUDA dispatch. The engine is organized as:

INT4 Weight Dequantization

When the .tbm contains INT4 weights (produced by tensorbit-quant), the runner's weight_fp32() helper dequantizes each tensor on‑the‑fly before passing it to the linear kernels. The dequantization unpacks the nibble‑packed INT4 values, multiplies by per‑group FP32 scales, and produces a temporary FP32 buffer that is consumed by the existing FP32 kernels. This temporary buffer is freed by RAII at the end of each operation, keeping peak memory at ~4.1 GB for a 3.6 GB .tbm (one tensor's worth of FP32 dequant buffer + the mmap'd file).
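
A minimal sketch of the per-tensor dequantization follows. It assumes signed two's-complement nibbles in [−8, 7], the low nibble holding the even-indexed weight, one scale per consecutive group of `group_size` weights, and no zero point; the real packing order and zero-point handling are defined by tensorbit-quant, and the function name is hypothetical. Returning a `std::vector` gives the RAII behavior described above: the FP32 buffer is freed as soon as it goes out of scope.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Dequantizes nibble-packed INT4 weights with one FP32 scale per group.
// Assumptions: signed nibbles, low nibble = even index, no zero point.
std::vector<float> dequant_int4(const std::vector<std::uint8_t>& packed,
                                const std::vector<float>& scales,
                                std::size_t numel, std::size_t group_size) {
    assert(group_size > 0 && packed.size() * 2 >= numel);
    assert(scales.size() * group_size >= numel);
    std::vector<float> out(numel);  // temporary FP32 buffer for the linear kernel
    for (std::size_t i = 0; i < numel; ++i) {
        const std::uint8_t byte = packed[i / 2];
        const int nib = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
        const int q = nib >= 8 ? nib - 16 : nib;  // sign-extend 4 bits to [-8, 7]
        out[i] = static_cast<float>(q) * scales[i / group_size];
    }
    return out;
}
```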

The Lambda Cloud Test

The real test was pruning Mistral 7B v0.1 on a Lambda A100 40GB GPU. The full workflow:

  1. Launch a Lambda GPU instance (A100 40GB, $1.99/hr)
  2. Setup: sudo ./scripts/setup_cloud.sh installs GCC 13, CUDA 12.6, Eigen3, Python, and all build tools
  3. Download: python scripts/download_model.py --repo mistralai/Mistral-7B-v0.1
  4. Build: cmake .. -DCMAKE_BUILD_TYPE=Release -GNinja && ninja
  5. Prune Shard 1: ~40 minutes with Iterative strategy (BlockOBS would take 8+ hours)
  6. Prune Shard 2: ~40 minutes
  7. Merge: python merge_tbm.py --input ./pruned/1/ ./pruned/2/ --output model.tbm
  8. Download: scp -r ubuntu@<ip>:.../pruned/ . (30 GB, ~30 minutes)

The resulting model.tbm is 30.78 GB containing 291 tensor entries, each with its FP32 weights, N:M masks, and metadata. A JSON index at the tail maps tensor names to byte offsets with full shape information.

Engineering: Hard‑Earned Bugs

Over the course of building this pipeline, I ran into several bugs that cost days of debugging. Here are the ones that left a mark:

"Specialization after instantiation"

GCC 13 eagerly instantiates templates with inline method bodies. When I used extern template class EHAPPruner<float>, the compiler instantiated the template at the declaration site and then rejected the explicit instantiation in ehap.cpp. The fix was to remove all extern template declarations and let the linker handle duplicates. This is why every header in tensorbit-core has its implementation fully inline — .cpp files contain only empty explicit instantiations.
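
The resulting header/source split looks roughly like this, collapsed into one translation unit for illustration. The class body is a hypothetical stand-in for EHAPPruner; the point is that the template is fully defined in the header and the .cpp holds only an explicit instantiation definition, with no `extern template` declaration anywhere.

```cpp
#include <cassert>

// --- ehap.hpp (illustrative): the full definition is visible to every includer ---
template <typename T>
class EHAPPruner {
public:
    // Hypothetical member: OBD-style score w^2 * (F + lambda).
    T score(T w, T fisher, T lambda) const {
        assert(lambda >= T(0));
        return w * w * (fisher + lambda);
    }
};

// --- ehap.cpp (illustrative): explicit instantiation definition only ---
// Emits every member of EHAPPruner<float> here. Members defined in the class
// body are implicitly inline, so copies from implicit instantiation in other
// translation units are merged by the linker.
template class EHAPPruner<float>;
```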

std::make_format_args requires lvalue references

The project uses a custom thread‑safe Logger that formats messages with std::vformat. The Logger's log() method takes std::string_view (not std::format_string, which is C++23). Formatting happens at the macro call site via std::make_format_args, which only accepts lvalue references. Every log argument must be a named variable:

// WRONG — literals, ternaries, and arithmetic all fail:
TENSORBIT_LOG_INFO("x={}", 42);
TENSORBIT_LOG_INFO("pct={:.1f}", 100.0 * a / b);

// RIGHT — named variable:
auto x = 42;
TENSORBIT_LOG_INFO("x={}", x);

BF16 conversion is just a bit shift

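HuggingFace stores Mistral 7B weights in BF16 format, a 16‑bit float format that A100 tensor cores support natively. BF16→FP32 conversion is a single shift: widen the uint16_t to 32 bits and shift it left by 16. BF16 is simply the upper half of an FP32 value (same sign bit, same 8-bit exponent, mantissa truncated to 7 bits), so the exponent and mantissa bits are already in the right positions; only the width changes. F16 (IEEE 754 half‑precision) has a 5-bit exponent and a 10-bit mantissa, so it requires proper exponent rebiasing and mantissa extraction, including subnormals. Getting both correct was essential for the safetensors loader.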
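
Both conversions fit in a few lines. This is a sketch (function names hypothetical) using `std::memcpy` for the bit reinterpretation, the pre-C++20 equivalent of `std::bit_cast`; the F16 path handles normals, signed zero, subnormals, infinities, and NaNs.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Reinterpret 32 bits as a float.
inline float bits_to_f32(std::uint32_t bits) {
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit float");
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// BF16 is the top half of an FP32: widen and shift left by 16.
inline float bf16_to_f32(std::uint16_t h) {
    return bits_to_f32(static_cast<std::uint32_t>(h) << 16);
}

// IEEE 754 half: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
inline float f16_to_f32(std::uint16_t h) {
    const std::uint32_t sign = static_cast<std::uint32_t>(h & 0x8000u) << 16;
    const std::uint32_t exp = (h >> 10) & 0x1Fu;
    std::uint32_t mant = h & 0x3FFu;
    std::uint32_t bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);           // Inf or NaN
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);  // normal: rebias 15 -> 127
    } else if (mant == 0) {
        bits = sign;                                        // signed zero
    } else {
        // Subnormal half: shift until the implicit leading 1 appears.
        std::uint32_t shift = 0;
        while ((mant & 0x400u) == 0) { mant <<= 1; ++shift; }
        bits = sign | ((113u - shift) << 23) | ((mant & 0x3FFu) << 13);
    }
    return bits_to_f32(bits);
}
```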

WSL2 vs. DrvFs: mmap and large files

On WSL2 (Windows Subsystem for Linux), the DrvFs filesystem (which mounts /mnt/d/) uses the 9p protocol to bridge the Linux and Windows file systems. The 9p driver cannot mmap files larger than ~1 GB: the kernel tries to pre‑allocate a buffer equal to the mapping size, causing ENOMEM for a 30 GB model file. The fix was a fallback in tb_file_open_read(): when mmap fails, malloc the full file size, then read() in a loop. Unlike the mmap path, where pages fault in only when accessed, this copies all 30 GB into the process, so every page of the buffer is touched and must be backed by physical RAM or swap. On WSL1, mmap works fine through the Windows syscall translation layer, but the per‑process virtual memory cap of 7.7 GB prevents loading 30 GB models entirely.
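
The fallback can be sketched with plain POSIX calls. The `LoadedFile` type, function names, and the `try_mmap` switch (which exists only so the fallback path can be exercised) are hypothetical and not the signature of tb_file_open_read(). The read() loop matters for large files because Linux transfers at most about 2 GiB per read() call.

```cpp
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdlib>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Either an mmap'd view or a malloc'd copy of the whole file.
struct LoadedFile {
    void* data = nullptr;
    std::size_t size = 0;
    bool mapped = false;
};

// Try mmap first; if it fails (e.g. ENOMEM on a 9p mount), fall back to
// malloc + a read() loop.
LoadedFile load_file(const char* path, bool try_mmap = true) {
    assert(path != nullptr);
    LoadedFile f;
    const int fd = open(path, O_RDONLY);
    if (fd < 0) return f;
    struct stat st;
    if (fstat(fd, &st) != 0 || st.st_size <= 0) { close(fd); return f; }
    const std::size_t size = static_cast<std::size_t>(st.st_size);

    if (try_mmap) {
        void* p = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0);
        if (p != MAP_FAILED) {
            close(fd);  // the mapping stays valid after close
            f.data = p; f.size = size; f.mapped = true;
            return f;
        }
    }

    // Fallback: every byte gets written, so the whole buffer becomes resident.
    auto* buf = static_cast<char*>(std::malloc(size));
    if (!buf) { close(fd); return f; }
    std::size_t done = 0;
    while (done < size) {
        const ssize_t n = read(fd, buf + done, size - done);
        if (n < 0 && errno == EINTR) continue;  // interrupted by a signal: retry
        if (n <= 0) break;                      // error or unexpected EOF
        done += static_cast<std::size_t>(n);
    }
    close(fd);
    if (done != size) { std::free(buf); return f; }
    f.data = buf; f.size = size;
    return f;
}

void release_file(LoadedFile& f) {
    if (f.data) {
        if (f.mapped) munmap(f.data, f.size);
        else std::free(f.data);
    }
    f = LoadedFile{};
}
```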

Quantization: From 30 GB to 4 GB

After pruning, the model is still 30 GB in FP32. To run inference on consumer hardware, I built tensorbit-quant — an INT4/INT8 quantization tool that reads an FP32 .tbm and writes a quantized .tbm with per‑group scale factors. The quantization scheme is:

The quantized .tbm stores scales between the packed INT4 weights and the N:M masks, with the JSON index recording scale_count, group_size, and zp_count per tensor. This allows the inference engine to locate scales without any out‑of‑band data — everything needed to load and dequantize is in the file itself.
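
Since the scheme itself is not spelled out here, the following sketch assumes a common choice: symmetric per-group quantization with scale = max|w| / 7, rounding to the nearest integer, two weights per byte with the low nibble first, and no zero point (so zp_count would be 0). The struct and function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Int4Tensor {
    std::vector<std::uint8_t> packed;  // two weights per byte, low nibble first
    std::vector<float> scales;         // one FP32 scale per group
};

// Assumed scheme: symmetric per-group INT4, scale = max|w| / 7, q = round(w / scale).
Int4Tensor quantize_int4(const std::vector<float>& w, std::size_t group_size) {
    assert(group_size > 0 && group_size % 2 == 0 && w.size() % group_size == 0);
    Int4Tensor t;
    t.packed.assign(w.size() / 2, 0);
    for (std::size_t g = 0; g < w.size(); g += group_size) {
        float absmax = 0.0f;
        for (std::size_t i = g; i < g + group_size; ++i)
            absmax = std::max(absmax, std::fabs(w[i]));
        const float scale = absmax / 7.0f;
        t.scales.push_back(scale);
        for (std::size_t i = g; i < g + group_size; ++i) {
            long q = (scale == 0.0f) ? 0 : std::lround(w[i] / scale);
            q = std::clamp(q, -8L, 7L);
            const auto nib = static_cast<std::uint8_t>(q & 0x0F);  // two's-complement nibble
            t.packed[i / 2] |= (i % 2 == 0) ? nib : static_cast<std::uint8_t>(nib << 4);
        }
    }
    return t;
}
```

Dequantizing multiplies each nibble back by its group's scale, so under this scheme the worst-case error per weight is half a scale step.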

Current Status & Future Work

As of May 2026, the full pipeline is operational. The core pruning (EHAP + CORING + BlockOBS) is complete with 30/30 tests passing. The inference engine has 88/88 tests passing across 6 test suites. The quantizer produces valid INT4/INT8 output with verified dequantization error bounds.

The main remaining gaps are:

Key Lessons

The full source code is available on GitHub: github.com/Tensorbit-Labs.
