Zero TVM Documentation
A complete LLM inference engine in the browser — no WebLLM, no TVM, no ONNX, no WASM runtime. 10 hand-written WGSL kernel roles (27 files, 3,078 LOC), a BPE tokenizer, and raw WebGPU. ~40 tok/s on an M2 Pro, 22% behind WebLLM's 85 TVM-autotuned kernels on identical weights.
This project implements the full Phi-3-mini-4k-instruct transformer forward pass using only WebGPU compute shaders written by hand in WGSL. Weights are loaded directly from HuggingFace in MLC Q4F16_1 format (or from the browser Cache API if WebLLM was used previously). The tokenizer is implemented in pure TypeScript — no SentencePiece WASM.
WebLLM and transformers.js both work, but they're black boxes — TVM compiles kernels you never see. This project exists to show that you can understand every single step of a modern transformer inference pipeline, written at the GPU level, in a browser tab.
Quick start
No install. Just open the chat page. Weights load from your browser cache (if you've used WebLLM before) or download fresh from HuggingFace (~2.1 GB).
Open the chat
Navigate to zero-tvm.html. WebGPU initializes automatically. Requires Chrome 113+, Edge 113+, or Safari 18+.
Wait for weights
First load downloads ~2.1 GB. Subsequent loads are instant from browser Cache API. Progress shown per shard in the log panel.
Chat
Phi-3-mini runs at ~40 tok/s on M2 Pro (vs ~51 tok/s WebLLM on identical weights). Your conversation never leaves the browser. No API calls during inference.
(Optional) Cache weights locally
Run node scripts/download-weights.mjs once to save all 83 shards to public/weights/. Subsequent loads are served from localhost at full disk speed.
Requires a GPU with shader-f16 WebGPU feature (f16 arithmetic). Most M-series Macs and recent NVIDIA/AMD GPUs support this. Intel integrated graphics may not.
How it works
The engine has three phases per generated token:
- Prefill — process each prompt token sequentially, building up the KV cache
- First decode — the last prefill step produces the first generated token
- Decode loop — each step takes the previous token as input, runs the full forward pass with the KV cache providing attention context, produces the next token
Each forward pass runs 10 kernel roles through 32 transformer layers — 228 dispatches per token in the default f16-KV path (260 with the ?int8kv=1 flag) — then reads one i32 token ID back from the GPU. For comparison, WebLLM's TVM-generated decode path fires 342 dispatches per token using 85 distinct shaders.
Architecture overview
Four source files make up the engine:
| File | Role |
|---|---|
| src/zero-tvm/chat.ts | Main decode engine + UI. Allocates buffers, builds bind groups, runs the decode loop. |
| src/zero-tvm/weight-loader.ts | Fetches ndarray-cache.json, downloads shards, uploads to GPU buffers. |
| src/zero-tvm/tokenizer.ts | BPE tokenizer: encode text → token IDs, decode IDs → text. |
| src/compiler/compiler.ts | Compiles all 27 WGSL files (10 kernel roles + tiled/subgroup variants) into GPUComputePipeline objects. |
Weight loader
Weights are stored in MLC's ndarray-cache.json format — an index file listing every parameter, which shard binary it lives in, its byte offset, and byte size. The loader reads this index then fetches each referenced shard.
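The index-then-slice flow can be sketched as follows. This is a minimal sketch, not the real `weight-loader.ts`: the record fields mirror MLC's published ndarray-cache.json layout, and everything downstream of the slicing (GPU upload, dequant metadata) is omitted.

```typescript
// Shape of MLC's ndarray-cache.json index (simplified).
interface NdarrayRecord {
  name: string;       // e.g. "transformer.h.0.mixer.qkv_proj.q_weight"
  shape: number[];
  dtype: string;      // "uint32" for q_weight, "float16" for q_scale
  byteOffset: number; // offset within the shard binary
  nbytes: number;
}
interface ShardRecord {
  dataPath: string;   // e.g. "params_shard_0.bin"
  nbytes: number;
  records: NdarrayRecord[];
}

// Slice every parameter out of one downloaded shard binary.
function sliceShard(shard: ShardRecord, bytes: ArrayBuffer): Map<string, ArrayBuffer> {
  const out = new Map<string, ArrayBuffer>();
  for (const rec of shard.records) {
    // The offsets in the index MUST match this exact shard binary —
    // a stale shard silently yields garbage weight slices.
    out.set(rec.name, bytes.slice(rec.byteOffset, rec.byteOffset + rec.nbytes));
  }
  return out;
}
```

This is also where the version-mismatch failure mode lives: `byteOffset` is trusted blindly, so an index from one model revision slicing shards from another produces wrong bytes with no error.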
Fetch priority
- Browser Cache API — if a prior WebLLM session cached the shards, they load instantly
- HuggingFace — direct HTTPS fetch from huggingface.co/mlc-ai/Phi-3-mini-4k-instruct-q4f16_1-MLC
Do not mix a locally-downloaded ndarray-cache.json with shards from the browser cache.
The byte offsets in the index must match the shards exactly. If you download the index fresh but use
cached shards from an older model version, all weight slices will be wrong → zero logits → <unk> output.
Parameter naming (MLC format)
MLC uses non-standard parameter names. The actual names in the cache:
| MLC name | Role |
|---|---|
| transformer.embd.q_weight | Embedding weights (uint32 packed int4) |
| transformer.embd.q_scale | Embedding scales (f16) |
| transformer.norm.weight | Final RMSNorm gamma (after all layers) |
| lm_head.q_weight | LM head weights |
| transformer.h.N.ln.weight | Layer N input_layernorm (normGamma1) |
| transformer.h.N.post_attention_layernorm.weight | Layer N post-attention norm (normGamma2) |
| transformer.h.N.mixer.qkv_proj.q_weight | Layer N QKV projection weights |
| transformer.h.N.mixer.out_proj.q_weight | Layer N output projection weights |
| transformer.h.N.mlp.gate_up_proj.q_weight | Layer N FFN gate+up weights |
| transformer.h.N.mlp.down_proj.q_weight | Layer N FFN down weights |
Tokenizer
A hand-written BPE tokenizer in TypeScript. No SentencePiece WASM, no HuggingFace tokenizers bundle.
Reads tokenizer.json directly.
Key steps
- Pre-tokenization — Metaspace: spaces become `▁`, words are split on whitespace
- BPE encoding — merge pairs by rank from the merge table in `tokenizer.json`
- Special tokens — `<|system|>`, `<|user|>`, `<|assistant|>`, `<|end|>`
- Chat template — Phi-3 format applied by `buildChatPrompt()`
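The BPE encoding step is a lowest-rank-first merge loop. A minimal sketch (the real tokenizer loads ranks from `tokenizer.json` and works on token IDs, not strings):

```typescript
// Greedy BPE: repeatedly merge the adjacent pair with the lowest rank.
// `ranks` maps "left right" pair keys to merge priority (lower merges first).
function bpe(tokens: string[], ranks: Map<string, number>): string[] {
  while (tokens.length > 1) {
    let best = -1, bestRank = Infinity;
    for (let i = 0; i < tokens.length - 1; i++) {
      const r = ranks.get(tokens[i] + " " + tokens[i + 1]);
      if (r !== undefined && r < bestRank) { bestRank = r; best = i; }
    }
    if (best < 0) break; // no mergeable pair left
    tokens = [
      ...tokens.slice(0, best),
      tokens[best] + tokens[best + 1], // merge the winning pair
      ...tokens.slice(best + 2),
    ];
  }
  return tokens;
}
```

For example, with ranks `h e → 0`, `l l → 1`, `he ll → 2`, the input `["h","e","l","l","o"]` merges down to `["hell","o"]`.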
Phi-3 chat template
<|system|>
You are a helpful assistant.<|end|>
<|user|>
What is the capital of Australia?<|end|>
<|assistant|>
Stop tokens: 2 (EOS), 32000 (<|end|>), 32007 (<|endoftext|>).
KV cache
Uses a paged KV cache — the same approach as vLLM and PagedAttention. Memory is divided into fixed-size pages (16 slots each). A page table maps logical positions to physical pages.
| Parameter | Value | Notes |
|---|---|---|
| PAGE_SIZE | 16 | slots per page |
| MAX_PAGES | 257 | ≈ 4096 context tokens |
| Bytes per page | 196,608 | 32 heads × 16 slots × 96 dims × 2 (K+V) × 2 bytes |
| Total KV buffer | ~50 MB per layer | 32 layers = ~1.6 GB |
Each layer has its own GPUBuffer for KV pages. The page table is a simple identity mapping (page i → physical page i) for single-sequence inference.
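The addressing math can be sketched like this. The exact element order inside a page (K half then V half, `[head][slot][dim]`) is an assumption for illustration, not the literal shader layout:

```typescript
// Paged-KV addressing sketch. Constants match the table above.
const PAGE_SIZE = 16, HEADS = 32, HEAD_DIM = 96;
const ELEMS_PER_PAGE = HEADS * PAGE_SIZE * HEAD_DIM * 2; // K + V halves

function kvElemOffset(pageTable: number[], position: number,
                      head: number, dim: number, isV: boolean): number {
  const page = pageTable[Math.floor(position / PAGE_SIZE)]; // logical → physical
  const slot = position % PAGE_SIZE;                        // slot within the page
  const half = isV ? HEADS * PAGE_SIZE * HEAD_DIM : 0;      // V stored after K (assumed)
  return page * ELEMS_PER_PAGE + half
       + head * PAGE_SIZE * HEAD_DIM + slot * HEAD_DIM + dim;
}
```

At 2 bytes per f16 element, `ELEMS_PER_PAGE` works out to the 196,608 bytes per page in the table; with the identity page table, position 17 lands in page 1, slot 1.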
Decode loop
Each call to decodeToken(tokenId, position) submits one command encoder with the full forward pass:
// Per-token GPU state written via writeBuffer
B.inputIds ← [tokenId] // i32
B.posMap ← [position] // i32
B.pageIndptr ← [0, nnzPages] // page range
B.lengthInfo ← [position+1, 0, 0] // seq length
// Forward pass (one command encoder)
embedding(B.residual) // token → hidden state
rmsNorm(B.hidden1, B.residual) // initial norm
for L in 0..32:
  // QKV matmul + RoPE + KV-append in ONE dispatch (M4 fusion)
  qkvFused(B.qOut, kvPages[L], B.hidden1)
  attention(B.attnOut, B.qOut, kvPages[L])
  int4Matmul(B.hidden2, B.attnOut)              // O projection
  addNorm(B.hidden2, resIn → B.hidden1, resOut) // residual + RMSNorm, ping-pong
  // Gate + Up + SiLU + mul + Down in ONE dispatch
  fusedFfn(B.hidden2, B.hidden1)
  addNorm(B.hidden2, resIn → B.hidden1, resOut) // residual + RMSNorm, ping-pong
int4Matmul(B.logits, B.hidden1) // LM head
argmax(B.tokenOut, B.logits) // → next token ID
Ping-pong residual buffers
WebGPU's validation rules forbid binding the same buffer as both read and read_write
in the same dispatch. The add_norm shader needs to read the old residual and write the new one.
Solution: two residual buffers that alternate each dispatch.
let resIn = B.residual // ping (starts with embedding)
let resOut = B.residual2 // pong (uninitialized)
// Each add_norm:
dispatch(addNorm, [delta, resIn, gamma, hidden1, resOut])
[resIn, resOut] = [resOut, resIn] // swap — O(1), no GPU copy
The swap is just two JavaScript variable reassignments — no GPU buffer copy. Both buffers always exist on the GPU; we just change which one we tell the bind group to read vs write.
WGSL Kernel Roles
All shaders live in src/compiler/shaders/ — 27 .wgsl files covering 10 distinct kernel roles: one baseline shader per role, plus tiled and subgroup variants of the same roles, selectable at runtime via URL flags. The compiler compiles them all at startup into GPUComputePipeline objects.
Binding convention: @group(0) always. Binding indices are zero-based and match the order you pass buffers to bg(device, pipeline, [...bufs]).
The decode loop uses the fused qkv_fused kernel (QKV matmul + RoPE + KV-append all in one dispatch). Prefill still uses separate int4_matmul + rope + kv_append dispatches because prefill processes many tokens at once, and the fusion win only lands for ntoken=1.
🔤 1 · Embedding embedding.wgsl
Token ID lookup with Q4F16 dequantization. Each output element is dequantized from a packed int4 value: (nibble - 7) × scale.
| Binding | Type | Role |
|---|---|---|
@0 | read_write f16[] | output hidden state |
@1 | read i32[] | input token IDs |
@2 | read f16[] | scales (group_size=32) |
@3 | read u32[] | packed weights (8 int4 per u32) |
@4 | uniform | { seq_len, packGridDimX } |
Dispatch: 12 workgroups × 256 threads = 3072 output elements (D=3072)
📐 2 · RMSNorm rms_norm.wgsl
Root mean square layer normalization. Computes x / sqrt(mean(x²) + ε) × gamma. Uses 256-thread tree reduction in workgroup shared memory.
| Binding | Type | Role |
|---|---|---|
@0 | read_write f16[] | normalized output |
@1 | read f16[] | input |
@2 | read f16[] | gamma weights |
@3 | uniform | { packGridDimX } |
Dispatch: 1 workgroup (one token, D=3072)
⚡ 3 · QKV + RoPE + KV-append (fused, decode) qkv_fused.wgsl
The big M4 fusion. One dispatch replaces three on the decode path: the int4 QKV matmul, the RoPE rotation of Q and K, and the write of K/V into the paged KV cache. Each workgroup computes two output rows that form a RoPE pair (dim and dim+48 within the same head), rotates the pair in registers, and writes K/V straight into kv_pages — the intermediate qkv / k_out / v_out buffers from the pre-fusion path are skipped entirely.
| Binding | Type | Role |
|---|---|---|
@0 | read_write f16[] | q_out [3072] |
@1 | read_write f16[] | kv_pages (paged KV cache) |
@2 | read f16[] | hidden [3072] |
@3 | read f16[] | scales [9216 × 96] |
@4 | read u32[] | packed weights [9216 × 384] |
@5 | read i32[] | position map |
@6 | uniform | { position_map_elem_offset, pages_elem_offset, packGridDimX } |
Dispatch: 4,608 workgroups (down from 9,216 matmul + 36 RoPE + 12 KV-append = 9,264 in the pre-fusion path). Decode-only; prefill still uses the 3-dispatch path (see shaders 8 and 9).
👁️ 4 · Paged Attention attention.wgsl
Multi-head attention over the paged KV cache. Reads K and V from pages, computes scaled dot-product attention with an online-softmax reduction in shared memory. Each workgroup handles one attention head.
| Binding | Type | Role |
|---|---|---|
@0 | read f16[] | Q [3072] |
@1 | read i32[] | page indptr |
@2 | read i32[] | page values (page table) |
@3 | read f16[] | KV pages |
@4 | read i32[] | length info |
@5 | read_write f16[] | attn output [3072] |
@6 | uniform | attention config (scale, pages) |
Dispatch: 1 × HEADS workgroups (1 × 32). An attention_int8.wgsl variant reads an int8-quantized KV cache; enable via ?int8kv=1.
✖️ 5 · int4 Matmul (output projection + LM head) int4_matmul.wgsl
General-purpose dequantize-on-the-fly int4 × f16 matmul. Used for the attention output projection (3072 → 3072) and the LM head (3072 → 32064). Weights are Q4F16_1: N output rows × 384 u32 columns (each u32 = 8 int4 values = 32 elements, group_size=32). The int4_matmul_tiled.wgsl / _tiled8 / _sg variants in the same directory add shared-memory tiling and subgroup reductions; the runtime picks one via the ?matmul= URL flag.
| Uniform field | Value |
|---|---|
K_groups | 384 (= input_dim / 8) |
scale_stride | 96 (= input_dim / group_size) |
N | 3072 (o-proj) or 32064 (lm_head) |
Dispatch: N workgroups — 3,072 for o-proj, 32,064 for lm_head.
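The arithmetic one output row performs can be written as a CPU reference. This is a sketch of the Q4F16_1 dequant-and-accumulate, not the WGSL itself (the shader spreads the same loop across a workgroup):

```typescript
// CPU reference for one output row of the Q4F16_1 int4 × f16 matmul.
// packed: N rows × (K/8) u32 per row; scales: N rows × (K/32) per row.
function int4MatmulRow(row: number, input: Float32Array,
                       packed: Uint32Array, scales: Float32Array): number {
  const K = input.length;
  const kWords = K / 8;       // 8 int4 values per u32
  const scaleStride = K / 32; // group_size = 32
  let acc = 0;
  for (let w = 0; w < kWords; w++) {
    const word = packed[row * kWords + w];
    for (let i = 0; i < 8; i++) {
      const k = w * 8 + i;
      const nibble = (word >>> (i * 4)) & 0xf;
      const scale = scales[row * scaleStride + Math.floor(k / 32)];
      acc += (nibble - 7) * scale * input[k]; // dequantize on the fly
    }
  }
  return acc;
}
```

For the o-proj, `K = 3072` gives `kWords = 384` and `scaleStride = 96`, matching the uniform table above.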
🔀 6 · Fused FFN (Gate · Up · SiLU · Mul · Down) fused_ffn.wgsl
Full SwiGLU FFN block in a single dispatch: the gate and up projections (both int4 matmuls sharing the 16,384-row gate_up_proj weight matrix), the SiLU activation, the elementwise multiply, and the final int4 down projection back to 3,072 dims. Intermediate 8,192-dim hidden state is kept in workgroup-shared memory across the two matmul stages.
| Binding | Type | Role |
|---|---|---|
@0 | read_write f16[] | output [3072] |
@1 | read f16[] | input [3072] |
@2 | read f16[] | gate_up scales |
@3 | read u32[] | gate_up packed weights (16,384 × 384) |
@4 | read f16[] | down_proj scales |
@5 | read u32[] | down_proj packed weights (3,072 × 1,024) |
@6 | uniform | FFN config |
Dispatch: 3,072 workgroups — one per output row of the down projection. A fused_ffn_tiled_sg.wgsl variant uses subgroup reductions; selectable via URL flag.
➕ 7 · Fused Add + RMSNorm add_norm.wgsl
Residual add + RMSNorm in one pass. Computes residual_out = A + B, then output = RMSNorm(residual_out) × gamma. Used twice per layer (post-attention and post-FFN). Mirrors TVM's fuse_add_norm_decode.
| Binding | Type | Role |
|---|---|---|
@0 | read f16[] | A — the new contribution (O-proj or FFN-down output) |
@1 | read f16[] | B — the running residual (resIn) |
@2 | read f16[] | gamma — normalization weights |
@3 | read_write f16[] | normalized output (B.hidden1) |
@4 | read_write f16[] | new residual (resOut — ping-pong) |
@5 | uniform | { packGridDimX } |
Dispatch: 1 workgroup · 256 threads · 12 elements each = 3,072.
💾 8 · KV Append (prefill path) kv_append.wgsl
Writes K and V vectors into the paged KV cache at the correct slot for each position. On the decode path this work is folded into qkv_fused; on prefill it runs as a separate dispatch because prefill processes many tokens at once and the per-token fusion no longer pays off.
| Binding | Type | Role |
|---|---|---|
@0 | read f16[] | k_out [3072] |
@1 | read f16[] | v_out [3072] |
@2 | read_write f16[] | KV pages buffer |
@3 | read i32[] | position map |
@4 | uniform | page config |
Dispatch: 12 workgroups per token (HEADS=32, HEAD_DIM=96).
🌀 9 · RoPE (prefill path) rope.wgsl
Rotary position embeddings applied to Q and K. Prefill-only — the decode path folds RoPE into qkv_fused. Splits the concatenated 9,216-dim QKV buffer into Q / K / V, rotates Q and K in place based on position, and copies V unchanged.
| Binding | Type | Role |
|---|---|---|
@0 | read_write f16[] | q_out [3072] |
@1 | read_write f16[] | k_out [3072] |
@2 | read_write f16[] | v_out [3072] |
@3 | read f16[] | qkv input [9216] |
@4 | read i32[] | position map |
@5 | uniform | RoPE config |
The binding order must be [q_out, k_out, v_out, qkv, posMap, uniform]. Swapping these caused a garbage-output bug during development.
Dispatch: 36 workgroups × 256 threads = 9,216 = 3 × 3,072.
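The rotation of one (dim, dim+48) pair can be sketched on CPU. The 10000 frequency base matches Phi-3's standard RoPE configuration but is an assumption here, not read from the shader:

```typescript
// RoPE rotation of one (dim, dim + 48) pair within a 96-dim head — a sketch.
const HEAD_DIM = 96;
function ropePair(x: number, y: number, pos: number, dim: number): [number, number] {
  // dim indexes the first half of the head: 0 <= dim < HEAD_DIM / 2
  const theta = pos / Math.pow(10000, (2 * dim) / HEAD_DIM);
  const c = Math.cos(theta), s = Math.sin(theta);
  return [x * c - y * s, x * s + y * c]; // 2D rotation by theta
}
```

At position 0 every pair rotates by zero, which is why prefill of the first token is position-independent.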
🎯 10 · Argmax Sampler argmax.wgsl
Parallel-reduction argmax over the 32,064-entry logit buffer produced by the LM-head int4_matmul. Replaces TVM's ~20-dispatch sampling chain (penalty → softmax → cumsum → argsort → gather → …) with a single dispatch. Greedy decoding only; top-k / top-p not wired up yet.
| Binding | Type | Role |
|---|---|---|
@0 | read f16[] | logits [32064] |
@1 | read_write i32[] | output token id [1] |
Dispatch: 1 workgroup (tree reduction over 32,064 logits). An argmax_sg.wgsl subgroup variant is available.
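The tree-reduction pattern the shader uses can be sketched sequentially. A sketch of the reduction shape only; the shader runs each stride level in parallel across threads with barriers in between:

```typescript
// Tree-reduction argmax: pairwise max with doubling stride,
// carrying the winning index alongside the winning value.
function argmaxTree(logits: number[]): number {
  const idx = logits.map((_, i) => i);
  const val = logits.slice();
  for (let stride = 1; stride < val.length; stride *= 2) {
    for (let i = 0; i + stride < val.length; i += stride * 2) {
      if (val[i + stride] > val[i]) {
        val[i] = val[i + stride];
        idx[i] = idx[i + stride];
      }
    }
  }
  return idx[0]; // winner bubbles to slot 0
}
```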
Phi-3 model constants
export const PHI3 = {
  D: 3072,        // hidden dimension
  HEADS: 32,      // attention heads
  HEAD_DIM: 96,   // D / HEADS
  LAYERS: 32,     // transformer layers
  FFN: 8192,      // FFN intermediate dimension
  VOCAB: 32064,   // vocabulary size
  PAGE_SIZE: 16,  // KV cache slots per page
  MAX_PAGES: 257, // max pages (≈ 4096 context)
}
Q4F16 quantization format
MLC's Q4F16_1 format packs 8 int4 values into each uint32.
Scales are stored as float16 with group_size=32 — one scale per 32 weights.
// Extract nibble for element i within a u32
let nibble = (packed_u32 >> (i * 4)) & 0xF;
// Dequantize: center around 0, multiply by scale
let value = f16(i32(nibble) - 7) * scale;
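The same dequant, run on CPU for a whole u32 (a TypeScript transliteration of the WGSL above, using f64 in place of f16):

```typescript
// Unpack 8 int4 values from one u32 and dequantize with a shared scale.
function dequantU32(packed: number, scale: number): number[] {
  const out: number[] = [];
  for (let i = 0; i < 8; i++) {
    const nibble = (packed >>> (i * 4)) & 0xf; // extract nibble i
    out.push((nibble - 7) * scale);            // center around 0, then scale
  }
  return out;
}
```

Note the asymmetric zero point: nibble 7 maps to 0, so the representable range is [-7 × scale, +8 × scale], and 0x77777777 dequantizes to eight zeros.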
Weight shapes in Q4F16 (for Phi-3-mini):
| Parameter | q_weight shape (u32) | q_scale shape (f16) |
|---|---|---|
| Embedding | [32064, 384] | [32064, 96] |
| QKV proj (per layer) | [9216, 384] | [9216, 96] |
| O proj (per layer) | [3072, 384] | [3072, 96] |
| Gate+Up FFN (per layer) | [16384, 384] | [16384, 96] |
| Down FFN (per layer) | [3072, 1024] | [3072, 256] |
| LM head | [32064, 384] | [32064, 96] |
Port to Phi-4-mini or Qwen3
Both are available as MLC Q4F16 packages. The steps to port:
- Update `PHI3` constants in `compiler.ts` — D, HEADS, HEAD_DIM, LAYERS, FFN, VOCAB
- Check parameter names — fetch `ndarray-cache.json` and log all keys. Update `weight-loader.ts` candidates to match
- Check for GQA — if KV heads ≠ Q heads (grouped-query attention), the attention shader needs a small change to repeat KV heads
- Update chat template — each model has its own special tokens and prompt format
- Update HuggingFace base URL — change `PHI3_MODEL_BASE` in `weight-loader.ts`
Phi-4-mini is the same family as Phi-3. An MLC package is already available at mlc-ai/Phi-4-mini-instruct-q4f16_1-MLC, and parameter naming is likely identical or very similar.
Local weight serving
Run the download script once to save all shards locally. Subsequent page loads are instant (served from localhost, no network).
node scripts/download-weights.mjs
# Downloads to: public/weights/Phi-3-mini-4k-instruct-q4f16_1-MLC/
# Served at: /weights/Phi-3-mini-4k-instruct-q4f16_1-MLC/
# Size: ~2.1 GB
Always download everything together. Never mix a freshly-downloaded ndarray-cache.json with old cached shards — the byte offsets will not match and all weights will be corrupted.
Debugging tips
All output is <unk>
- Weight version mismatch — `ndarray-cache.json` offsets don't match shard content
- Buffer aliasing — `add_norm` dispatched with the same buffer as both `@1` and `@4`
- Wrong rope binding order — check `@0=q_out, @1=k_out, @2=v_out, @3=qkv, @4=posMap`
Garbage / repetitive output
- Rope bindings are in the wrong order (this was our bug — garbage like `-,unlintzegesenma`)
- Wrong uniform values for a shader (K_groups, N, etc.)
WebGPU validation error about aliasing
- Same buffer bound as `read_write` and `read` in one dispatch
- Fix: use ping-pong buffers. Never bind `B.residual` as both `@1` and `@4` to `add_norm`
Model not loading (Weight not found)
- Log all available keys: the weight loader prints them to console on load
- MLC names differ from HuggingFace standard names (`transformer.h.N.mixer.*`, not `model.layers.N.self_attn.*`)
Bugs we fixed (and how)
| Bug | Symptom | Fix |
|---|---|---|
| Wrong MLC param names | Weight not found error on load | Logged all 325 param names from console, updated candidates to transformer.h.* prefix |
| Buffer aliasing in add_norm | <unk> × 500 at 314 tok/s + WebGPU validation error | Added B.residual2 (pong buffer), ping-pong with JS variable swap |
| Wrong rope binding order | Garbage: -,unlintzegesenma\dOCĆalloqueIAL repeated | Read rope.wgsl — bindings are @0=q_out @1=k_out @2=v_out @3=qkv @4=posMap |
| Mixed ndarray-cache.json version | <unk> after downloading index locally but using old cached shards | Always fetch index and shards from the same source atomically |
vs WebLLM
Head-to-head on Phi-3-mini-4k-instruct Q4F16_1, same weights, same browser (Chrome with WebGPU), Apple M2 Pro:
| | Zero-TVM | WebLLM |
|---|---|---|
| Decode speed | ~40 tok/s | ~51 tok/s |
| Gap | Zero-TVM is 22% behind WebLLM | |
| Dispatches / token | 228 (260 with ?int8kv=1) | 342 |
| Distinct shaders | 10 kernel roles · 27 .wgsl files · 3,078 LOC | 85 TVM-generated shaders · 12,962 LOC |
| Shipped JS bundle | 157 kB / 33 kB gz (zero-tvm.html) | 5.9 MB / 2.1 MB gz (compiler-chat.html) |
| Bandwidth utilization | ~36% of 111 tok/s ceiling | ~46% of same ceiling |
| Paged attention | ✓ Hand-written | ✓ TVM compiled |
| Readable kernels? | ✓ Yes — all 27 .wgsl files in repo | ✗ No — emitted by the TVM compiler |
Where the 22% gap comes from
At ~40 tok/s Zero-TVM already wins on dispatch overhead (228 vs 342 submissions per token). The remaining gap is inside the matmul and attention kernels themselves:
- Matmul tiling. WebLLM's `fused_dequantize_NT_matmul` kernels are TVM-autotuned with aggressive shared-memory tiling and 2D thread blocks. Our `int4_matmul_tiled` / `_tiled8` / `_sg` variants close some of this, but the default `int4_matmul` loses ~15% here.
- Attention reduction. TVM's `batch_decode_paged_kv_kernel` splits the softmax reduction across more workgroups with a follow-up `attention_merge_state` pass. Our single-workgroup-per-head design is simpler but leaves M2 Pro's 20 SMs partly idle.
- Subgroup ops. Recent WebGPU supports subgroup reductions; we have `_sg` variants but they aren't default. WebLLM compiles with subgroup ops turned on.
What is not in the gap: dispatch submission overhead, command-encoder build time, or framework marshalling. Those all favor Zero-TVM and have already been measured and ruled out.
The memory-bandwidth ceiling
Phi-3-mini Q4F16_1 touches ~1.8 GB of weights per decode token. On M2 Pro's 200 GB/s memory bus that's 9 ms/token, or ~111 tok/s theoretical max. Zero-TVM at ~40 tok/s hits 36% of that ceiling; WebLLM at ~51 tok/s hits 46%. Neither engine can exceed the ceiling without changing weight layout or quantization (int8 KV shaves a bit, which is why it's exposed behind a flag).
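The roofline arithmetic from this paragraph, spelled out:

```typescript
// Memory-bandwidth ceiling for Q4F16_1 decode on M2 Pro.
const bytesPerToken = 1.8e9;  // ~1.8 GB of weights touched per decode token
const busBytesPerSec = 200e9; // M2 Pro memory bandwidth

const msPerToken = (bytesPerToken / busBytesPerSec) * 1000; // 9 ms/token
const ceilingTokS = 1000 / msPerToken;                      // ~111 tok/s max

const zeroTvmUtil = 40 / ceilingTokS; // Zero-TVM: ~36% of ceiling
const webllmUtil = 51 / ceilingTokS;  // WebLLM:   ~46% of ceiling
```

Both engines sit well under half the ceiling, which is why kernel-internal wins (tiling, subgroups) still matter more than dispatch counts at this point.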