
Zero TVM Documentation

A complete LLM inference engine in the browser — no WebLLM, no TVM, no ONNX, no WASM runtime. 10 hand-written WGSL kernel roles (27 files, 3,078 LOC), a BPE tokenizer, and raw WebGPU. ~40 tok/s on M2 Pro, 22% behind WebLLM's TVM-autotuned 85 kernels on identical weights.

WebGPU · 10 kernel roles · 27 WGSL files · Q4F16_1 · Phi-3-mini 3.8B · Paged KV Cache · 0 dependencies

This project implements the full Phi-3-mini-4k-instruct transformer forward pass using only WebGPU compute shaders written by hand in WGSL. Weights are loaded directly from HuggingFace in MLC Q4F16_1 format (or from the browser Cache API if WebLLM was used previously). The tokenizer is implemented in pure TypeScript — no SentencePiece WASM.

Why?

WebLLM and transformers.js both work, but they're black boxes — TVM compiles kernels you never see. This project exists to show that you can understand every single step of a modern transformer inference pipeline, written at the GPU level, in a browser tab.

Quick start

No install. Just open the chat page. Weights load from your browser cache (if you've used WebLLM before) or download fresh from HuggingFace (~2.1 GB).

  1. Open the chat. Navigate to zero-tvm.html. WebGPU initializes automatically. Requires Chrome 113+, Edge 113+, or Safari 18+.
  2. Wait for weights. First load downloads ~2.1 GB. Subsequent loads are instant from the browser Cache API. Progress is shown per shard in the log panel.
  3. Chat. Phi-3-mini runs at ~40 tok/s on M2 Pro (vs ~51 tok/s for WebLLM on identical weights). Your conversation never leaves the browser. No API calls during inference.
  4. (Optional) Cache weights locally. Run node scripts/download-weights.mjs once to save all 83 shards to public/weights/. Subsequent loads are served from localhost at full disk speed.

GPU requirement

Requires a GPU with shader-f16 WebGPU feature (f16 arithmetic). Most M-series Macs and recent NVIDIA/AMD GPUs support this. Intel integrated graphics may not.

How it works

The engine has three phases per generated token:

  1. Prefill — process each prompt token sequentially, building up the KV cache
  2. First decode — the last prefill step produces the first generated token
  3. Decode loop — each step takes the previous token as input, runs the full forward pass with the KV cache providing attention context, produces the next token
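The three phases above reduce to one control loop. A minimal sketch: `step` stands in for the engine's real GPU forward pass (`decodeToken` in chat.ts), and the stop-token IDs are the ones listed in the Tokenizer section; the helper names here are illustrative, not the engine's actual API.

```typescript
// The three phases as one control loop. `step` is a stand-in for the real
// GPU forward pass so the control flow is runnable on its own.
type Step = (tokenId: number, position: number) => number;

const STOP_TOKENS = new Set([2, 32000, 32007]);

function generate(promptIds: number[], maxNew: number, step: Step): number[] {
  // Phase 1: prefill. Run every prompt token, building the KV cache.
  // Phase 2: the last prefill step yields the first generated token.
  let next = 0;
  for (let p = 0; p < promptIds.length; p++) next = step(promptIds[p], p);

  // Phase 3: decode loop. Feed each token back in until a stop token.
  const out: number[] = [];
  let pos = promptIds.length;
  while (out.length < maxNew && !STOP_TOKENS.has(next)) {
    out.push(next);
    next = step(next, pos++);
  }
  return out;
}
```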

Each forward pass runs 10 kernel roles through 32 transformer layers — 228 dispatches per token in the default f16-KV path (260 with the ?int8kv=1 flag) — then reads one i32 token ID back from the GPU. For comparison, WebLLM's TVM-generated decode path fires 342 dispatches per token using 85 distinct shaders.

Architecture overview

Four source files make up the engine:

| File | Role |
|---|---|
| src/zero-tvm/chat.ts | Main decode engine + UI. Allocates buffers, builds bind groups, runs the decode loop. |
| src/zero-tvm/weight-loader.ts | Fetches ndarray-cache.json, downloads shards, uploads to GPU buffers. |
| src/zero-tvm/tokenizer.ts | BPE tokenizer: encode text → token IDs, decode IDs → text. |
| src/compiler/compiler.ts | Compiles all 27 WGSL files (10 kernel roles + tiled/subgroup variants) into GPUComputePipeline objects. |

Weight loader

Weights are stored in MLC's ndarray-cache.json format — an index file listing every parameter, which shard binary it lives in, its byte offset, and byte size. The loader reads this index then fetches each referenced shard.

Fetch priority

  1. Browser Cache API — if a prior WebLLM session cached the shards, they load instantly
  2. HuggingFace — direct HTTPS fetch from huggingface.co/mlc-ai/Phi-3-mini-4k-instruct-q4f16_1-MLC
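The index-driven slicing can be sketched as below, assuming a simplified shape for the ndarray-cache.json records (the real file carries additional fields per record, e.g. shape and dtype):

```typescript
// Simplified index shape: { records: [{ dataPath, records: [...] }] }.
interface ParamRecord { name: string; byteOffset: number; nbytes: number; }
interface ShardRecord { dataPath: string; records: ParamRecord[]; }
interface NdarrayCache { records: ShardRecord[]; }

function sliceParam(
  index: NdarrayCache,
  shards: Map<string, ArrayBuffer>,  // shard binaries keyed by dataPath
  name: string,
): ArrayBuffer {
  for (const shard of index.records) {
    const rec = shard.records.find(r => r.name === name);
    if (!rec) continue;
    const buf = shards.get(shard.dataPath);
    if (!buf) throw new Error(`shard not loaded: ${shard.dataPath}`);
    // These offsets are only valid against the matching shard version;
    // mixing index and shard versions is exactly the pitfall described below.
    return buf.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
  }
  throw new Error(`Weight not found: ${name}`);
}
```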
Version mismatch pitfall

Do not mix a locally-downloaded ndarray-cache.json with shards from the browser cache. The byte offsets in the index must match the shards exactly. If you download the index fresh but use cached shards from an older model version, all weight slices will be wrong → zero logits → <unk> output.

Parameter naming (MLC format)

MLC uses non-standard parameter names. The actual names in the cache:

| MLC name | Role |
|---|---|
| transformer.embd.q_weight | Embedding weights (uint32-packed int4) |
| transformer.embd.q_scale | Embedding scales (f16) |
| transformer.norm.weight | Final RMSNorm gamma (after all layers) |
| lm_head.q_weight | LM head weights |
| transformer.h.N.ln.weight | Layer N input_layernorm (normGamma1) |
| transformer.h.N.post_attention_layernorm.weight | Layer N post-attention norm (normGamma2) |
| transformer.h.N.mixer.qkv_proj.q_weight | Layer N QKV projection weights |
| transformer.h.N.mixer.out_proj.q_weight | Layer N output projection weights |
| transformer.h.N.mlp.gate_up_proj.q_weight | Layer N FFN gate+up weights |
| transformer.h.N.mlp.down_proj.q_weight | Layer N FFN down weights |

Tokenizer

A hand-written BPE tokenizer in TypeScript. No SentencePiece WASM, no HuggingFace tokenizers bundle. Reads tokenizer.json directly.

Phi-3 chat template

📄 prompt format
<|system|>
You are a helpful assistant.<|end|>
<|user|>
What is the capital of Australia?<|end|>
<|assistant|>

Stop tokens: 2 (EOS), 32000 (<|endoftext|>), 32007 (<|end|>).
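The template above is simple enough to build by hand; a sketch of the builder (an illustrative helper, not the engine's actual API):

```typescript
// Build the Phi-3 chat prompt shown above.
function buildPrompt(system: string, user: string): string {
  return `<|system|>\n${system}<|end|>\n` +
         `<|user|>\n${user}<|end|>\n` +
         `<|assistant|>\n`;   // generation continues from here
}
```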

KV cache

Uses a paged KV cache — the same approach as vLLM's PagedAttention. Memory is divided into fixed-size pages (16 slots each). A page table maps logical positions to physical pages.

| Parameter | Value | Notes |
|---|---|---|
| PAGE_SIZE | 16 | slots per page |
| MAX_PAGES | 257 | ≈ 4096 context tokens |
| Bytes per page | 196,608 | 32 heads × 16 slots × 96 dims × 2 (K+V) × 2 bytes |
| Total KV buffer | ~50 MB per layer | 32 layers = ~1.6 GB |

Each layer has its own GPUBuffer for KV pages. The page table is a simple identity mapping (page i → physical page i) for single-sequence inference.
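The position-to-slot math can be sketched with the constants above (helper name is illustrative):

```typescript
// Position → (page, slot) mapping with PAGE_SIZE = 16. The page table is a
// plain array here; for single-sequence inference it is the identity.
const PAGE_SIZE = 16;

function kvLocation(position: number, pageTable: number[]) {
  const logicalPage = Math.floor(position / PAGE_SIZE);
  return {
    page: pageTable[logicalPage],  // physical page
    slot: position % PAGE_SIZE,    // slot within the page
  };
}
```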

Decode loop

Each call to decodeToken(tokenId, position) submits one command encoder with the full forward pass:

src/zero-tvm/chat.ts · decodeToken()
// Per-token GPU state written via writeBuffer
B.inputIds   ← [tokenId]          // i32
B.posMap     ← [position]         // i32
B.pageIndptr ← [0, nnzPages]      // page range
B.lengthInfo ← [position+1, 0, 0] // seq length

// Forward pass (one command encoder)
embedding(B.residual)             // token → hidden state
rmsNorm(B.hidden1, B.residual)    // initial norm

for L in 0..32:
  // QKV matmul + RoPE + KV-append in ONE dispatch (M4 fusion)
  qkvFused(B.qOut, kvPages[L], B.hidden1)
  attention(B.attnOut, B.qOut, kvPages[L])
  int4Matmul(B.hidden2, B.attnOut)               // O projection
  addNorm(B.hidden2, resIn → B.hidden1, resOut)  // residual + RMSNorm, ping-pong
  // Gate + Up + SiLU + mul + Down in ONE dispatch
  fusedFfn(B.hidden2, B.hidden1)
  addNorm(B.hidden2, resIn → B.hidden1, resOut)  // residual + RMSNorm, ping-pong

int4Matmul(B.logits, B.hidden1)   // LM head
argmax(B.tokenOut, B.logits)      // → next token ID

Ping-pong residual buffers

WebGPU's validation rules forbid binding the same buffer as both read and read_write in the same dispatch. The add_norm shader needs to read the old residual and write the new one.

Solution: two residual buffers that alternate each dispatch.

ping-pong pattern
let resIn  = B.residual   // ping (starts with embedding)
let resOut = B.residual2  // pong (uninitialized)

// Each add_norm:
dispatch(addNorm, [delta, resIn, gamma, hidden1, resOut])
[resIn, resOut] = [resOut, resIn]  // swap — O(1), no GPU copy

The swap is just two JavaScript variable reassignments — no GPU buffer copy. Both buffers always exist on the GPU; we just change which one we tell the bind group to read vs write.

WGSL Kernel Roles

All shaders live in src/compiler/shaders/ — 27 .wgsl files implementing 10 distinct kernel roles. Beyond the baseline shader for each role, the remaining files are tiled and subgroup variants of the same roles, selectable at runtime via URL flags. The compiler compiles them all at startup into GPUComputePipeline objects.

Binding convention: @group(0) always. Binding indices are zero-based and match the order you pass buffers to bg(device, pipeline, [...bufs]).
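A sketch of what a helper like bg looks like under this convention; the minimal structural interfaces below stand in for the real WebGPU types so the sketch is self-contained:

```typescript
// Stand-in structural types for the WebGPU API surface this helper touches.
interface BindGroupLayout {}
interface Buffer {}
interface Pipeline { getBindGroupLayout(index: number): BindGroupLayout; }
interface Device {
  createBindGroup(desc: {
    layout: BindGroupLayout;
    entries: { binding: number; resource: { buffer: Buffer } }[];
  }): unknown;
}

function bg(device: Device, pipeline: Pipeline, bufs: Buffer[]): unknown {
  return device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),  // @group(0) always
    // Binding index = position in the bufs array (zero-based)
    entries: bufs.map((buffer, binding) => ({ binding, resource: { buffer } })),
  });
}
```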

Decode vs prefill paths diverge

The decode loop uses the fused qkv_fused kernel (QKV matmul + RoPE + KV-append all in one dispatch). Prefill still uses separate int4_matmul + rope + kv_append dispatches because prefill processes many tokens at once, and the fusion win only lands for ntoken=1.


🔤 1 · Embedding embedding.wgsl

Token ID lookup with Q4F16 dequantization. Each output element is dequantized from a packed int4 value: (nibble - 7) × scale.

| Binding | Type | Role |
|---|---|---|
| @0 | read_write f16[] | output hidden state |
| @1 | read i32[] | input token IDs |
| @2 | read f16[] | scales (group_size=32) |
| @3 | read u32[] | packed weights (8 int4 per u32) |
| @4 | uniform | { seq_len, packGridDimX } |

Dispatch: 12 workgroups × 256 threads = 3072 output elements (D=3072)

📐 2 · RMSNorm rms_norm.wgsl

Root mean square layer normalization. Computes x / sqrt(mean(x²) + ε) × gamma. Uses 256-thread tree reduction in workgroup shared memory.

| Binding | Type | Role |
|---|---|---|
| @0 | read_write f16[] | normalized output |
| @1 | read f16[] | input |
| @2 | read f16[] | gamma weights |
| @3 | uniform | { packGridDimX } |

Dispatch: 1 workgroup (one token, D=3072)
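The same math as a CPU reference, useful for checking shader output (the epsilon default here is an assumption; the real value lives in the WGSL):

```typescript
// CPU reference: out[i] = x[i] / sqrt(mean(x²) + ε) × gamma[i].
function rmsNormRef(x: Float32Array, gamma: Float32Array, eps = 1e-5): Float32Array {
  let sumSq = 0;
  for (let i = 0; i < x.length; i++) sumSq += x[i] * x[i];
  const inv = 1 / Math.sqrt(sumSq / x.length + eps);   // the shader computes
  const out = new Float32Array(x.length);              // this via tree reduction
  for (let i = 0; i < x.length; i++) out[i] = x[i] * inv * gamma[i];
  return out;
}
```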

⚡ 3 · QKV + RoPE + KV-append (fused, decode) qkv_fused.wgsl

The big M4 fusion. One dispatch replaces three on the decode path: the int4 QKV matmul, the RoPE rotation of Q and K, and the write of K/V into the paged KV cache. Each workgroup computes two output rows that form a RoPE pair (dim and dim+48 within the same head), rotates the pair in registers, and writes K/V straight into kv_pages — the intermediate qkv / k_out / v_out buffers from the pre-fusion path are skipped entirely.

| Binding | Type | Role |
|---|---|---|
| @0 | read_write f16[] | q_out [3072] |
| @1 | read_write f16[] | kv_pages (paged KV cache) |
| @2 | read f16[] | hidden [3072] |
| @3 | read f16[] | scales [9216 × 96] |
| @4 | read u32[] | packed weights [9216 × 384] |
| @5 | read i32[] | position map |
| @6 | uniform | { position_map_elem_offset, pages_elem_offset, packGridDimX } |

Dispatch: 4,608 workgroups (down from 9,216 matmul + 36 RoPE + 12 KV-append = 9,264 in the pre-fusion path). Decode-only; prefill still uses the 3-dispatch path (see shaders 8 and 9).

👁️ 4 · Paged Attention attention.wgsl

Multi-head attention over the paged KV cache. Reads K and V from pages, computes scaled dot-product attention with an online-softmax reduction in shared memory. Each workgroup handles one attention head.

| Binding | Type | Role |
|---|---|---|
| @0 | read f16[] | Q [3072] |
| @1 | read i32[] | page indptr |
| @2 | read i32[] | page values (page table) |
| @3 | read f16[] | KV pages |
| @4 | read i32[] | length info |
| @5 | read_write f16[] | attn output [3072] |
| @6 | uniform | attention config (scale, pages) |

Dispatch: 1 × HEADS workgroups (1 × 32). An attention_int8.wgsl variant reads an int8-quantized KV cache; enable via ?int8kv=1.

✖️ 5 · int4 Matmul (output projection + LM head) int4_matmul.wgsl

General-purpose dequantize-on-the-fly int4 × f16 matmul. Used for the attention output projection (3072 → 3072) and the LM head (3072 → 32064). Weights are Q4F16_1: N output rows × 384 u32 columns (each u32 = 8 int4 values = 32 elements, group_size=32). The int4_matmul_tiled.wgsl / _tiled8 / _sg variants in the same directory add shared-memory tiling and subgroup reductions; the runtime picks one via the ?matmul= URL flag.

| Uniform field | Value |
|---|---|
| K_groups | 384 (= input_dim / 8) |
| scale_stride | 96 (= input_dim / group_size) |
| N | 3072 (o-proj) or 32064 (lm_head) |

Dispatch: N workgroups — 3,072 for o-proj, 32,064 for lm_head.
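A CPU reference for one output row, showing how K_groups and scale_stride index the packed data (function name is illustrative):

```typescript
// One output row of the int4 × f16 matmul in Q4F16_1 layout:
// each u32 packs 8 nibbles; one scale per 32 weights (group_size = 32).
function int4MatmulRow(
  packed: Uint32Array,   // row-major [N, K_groups], K_groups = K / 8
  scales: Float32Array,  // row-major [N, K / 32]
  x: Float32Array,       // input activation, length K
  row: number,
): number {
  const K = x.length;
  const kGroups = K / 8;       // u32 words per row
  const scaleStride = K / 32;  // scales per row
  let acc = 0;
  for (let g = 0; g < kGroups; g++) {
    const word = packed[row * kGroups + g];
    for (let i = 0; i < 8; i++) {
      const k = g * 8 + i;
      const nibble = (word >>> (i * 4)) & 0xf;            // extract int4
      const scale = scales[row * scaleStride + (k >> 5)]; // group k / 32
      acc += (nibble - 7) * scale * x[k];                 // dequant + MAC
    }
  }
  return acc;
}
```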

🔀 6 · Fused FFN (Gate · Up · SiLU · Mul · Down) fused_ffn.wgsl

Full SwiGLU FFN block in a single dispatch: the gate and up projections (both int4 matmuls sharing the 16,384-row gate_up_proj weight matrix), the SiLU activation, the elementwise multiply, and the final int4 down projection back to 3,072 dims. Intermediate 8,192-dim hidden state is kept in workgroup-shared memory across the two matmul stages.

| Binding | Type | Role |
|---|---|---|
| @0 | read_write f16[] | output [3072] |
| @1 | read f16[] | input [3072] |
| @2 | read f16[] | gate_up scales |
| @3 | read u32[] | gate_up packed weights (16,384 × 384) |
| @4 | read f16[] | down_proj scales |
| @5 | read u32[] | down_proj packed weights (3,072 × 1,024) |
| @6 | uniform | FFN config |

Dispatch: 3,072 workgroups — one per output row of the down projection. A fused_ffn_tiled_sg.wgsl variant uses subgroup reductions; selectable via URL flag.

➕ 7 · Fused Add + RMSNorm add_norm.wgsl

Residual add + RMSNorm in one pass. Computes residual_out = A + B, then output = RMSNorm(residual_out) × gamma. Used twice per layer (post-attention and post-FFN). Mirrors TVM's fuse_add_norm_decode.

| Binding | Type | Role |
|---|---|---|
| @0 | read f16[] | A — the new contribution (O-proj or FFN-down output) |
| @1 | read f16[] | B — the running residual (resIn) |
| @2 | read f16[] | gamma — normalization weights |
| @3 | read_write f16[] | normalized output (B.hidden1) |
| @4 | read_write f16[] | new residual (resOut — ping-pong) |
| @5 | uniform | { packGridDimX } |

Dispatch: 1 workgroup · 256 threads · 12 elements each = 3,072.

💾 8 · KV Append (prefill path) kv_append.wgsl

Writes K and V vectors into the paged KV cache at the correct slot for each position. On the decode path this work is folded into qkv_fused; on prefill it runs as a separate dispatch because prefill processes many tokens at once and the per-token fusion no longer pays off.

| Binding | Type | Role |
|---|---|---|
| @0 | read f16[] | k_out [3072] |
| @1 | read f16[] | v_out [3072] |
| @2 | read_write f16[] | KV pages buffer |
| @3 | read i32[] | position map |
| @4 | uniform | page config |

Dispatch: 12 workgroups per token (HEADS=32, HEAD_DIM=96).

🌀 9 · RoPE (prefill path) rope.wgsl

Rotary position embeddings applied to Q and K. Prefill-only — the decode path folds RoPE into qkv_fused. Splits the concatenated 9,216-dim QKV buffer into Q / K / V, rotates Q and K in place based on position, and copies V unchanged.

| Binding | Type | Role |
|---|---|---|
| @0 | read_write f16[] | q_out [3072] |
| @1 | read_write f16[] | k_out [3072] |
| @2 | read_write f16[] | v_out [3072] |
| @3 | read f16[] | qkv input [9216] |
| @4 | read i32[] | position map |
| @5 | uniform | RoPE config |
Critical binding order

The binding order must be [q_out, k_out, v_out, qkv, posMap, uniform]. Swapping these caused a garbage-output bug during development.

Dispatch: 36 workgroups × 256 threads = 9,216 = 3 × 3,072.
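A CPU sketch of the rotation applied to one (dim, dim+48) pair within a head, assuming the standard RoPE frequency base of 10000 (the shader's actual frequency table may differ):

```typescript
// Rotate the (d, d + 48) pair of one 96-dim head by the angle for position pos.
function ropePair(x: Float32Array, d: number, pos: number,
                  headDim = 96, base = 10000): void {
  const half = headDim / 2;                                // 48
  const theta = pos * Math.pow(base, -(2 * d) / headDim);  // angle for pair d
  const a = x[d], b = x[d + half];
  x[d]        = a * Math.cos(theta) - b * Math.sin(theta);
  x[d + half] = a * Math.sin(theta) + b * Math.cos(theta);
}
```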

🎯 10 · Argmax Sampler argmax.wgsl

Parallel-reduction argmax over the 32,064-entry logit buffer produced by the LM-head int4_matmul. Replaces TVM's ~20-dispatch sampling chain (penalty → softmax → cumsum → argsort → gather → …) with a single dispatch. Greedy decoding only; top-k / top-p not wired up yet.

| Binding | Type | Role |
|---|---|---|
| @0 | read f16[] | logits [32064] |
| @1 | read_write i32[] | output token id [1] |

Dispatch: 1 workgroup (tree reduction over 32,064 logits). An argmax_sg.wgsl subgroup variant is available.

Phi-3 model constants

src/compiler/compiler.ts
export const PHI3 = {
  D:         3072,   // hidden dimension
  HEADS:     32,     // attention heads
  HEAD_DIM:  96,     // D / HEADS
  LAYERS:    32,     // transformer layers
  FFN:       8192,   // FFN intermediate dimension
  VOCAB:     32064,  // vocabulary size
  PAGE_SIZE: 16,     // KV cache slots per page
  MAX_PAGES: 257,    // max pages (≈ 4096 context)
}

Q4F16 quantization format

MLC's Q4F16_1 format packs 8 int4 values into each uint32. Scales are stored as float16 with group_size=32 — one scale per 32 weights.

Dequantization formula (from embedding.wgsl)
// Extract nibble for element i within a u32
let nibble = (packed_u32 >> (i * 4)) & 0xF;

// Dequantize: center around 0, multiply by scale
let value = f16(i32(nibble) - 7) * scale;

Weight shapes in Q4F16 (for Phi-3-mini):

| Parameter | q_weight shape (u32) | q_scale shape (f16) |
|---|---|---|
| Embedding | [32064, 384] | [32064, 96] |
| QKV proj (per layer) | [9216, 384] | [9216, 96] |
| O proj (per layer) | [3072, 384] | [3072, 96] |
| Gate+Up FFN (per layer) | [16384, 384] | [16384, 96] |
| Down FFN (per layer) | [3072, 1024] | [3072, 256] |
| LM head | [32064, 384] | [32064, 96] |

Port to Phi-4-mini or Qwen3

Both are available as MLC Q4F16 packages. The steps to port:

  1. Update PHI3 constants in compiler.ts — D, HEADS, HEAD_DIM, LAYERS, FFN, VOCAB
  2. Check parameter names — fetch ndarray-cache.json and log all keys. Update weight-loader.ts candidates to match
  3. Check for GQA — if KV heads ≠ Q heads (grouped-query attention), the attention shader needs a small change to repeat KV heads
  4. Update chat template — each model has its own special tokens and prompt format
  5. Update HuggingFace base URL — change PHI3_MODEL_BASE in weight-loader.ts
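The GQA change in step 3 amounts to a head-index mapping; a sketch (with KV heads = Q heads, as in Phi-3-mini, it degenerates to the identity):

```typescript
// Each query head reads the KV head that owns its group.
function kvHeadFor(qHead: number, numQHeads: number, numKvHeads: number): number {
  const groupSize = numQHeads / numKvHeads;  // query heads per KV head
  return Math.floor(qHead / groupSize);
}
```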
Phi-4-mini is the easiest port

Same family as Phi-3. MLC package is already available at mlc-ai/Phi-4-mini-instruct-q4f16_1-MLC. Parameter naming is likely identical or very similar.

Local weight serving

Run the download script once to save all shards locally. Subsequent page loads are instant (served from localhost, no network).

terminal
node scripts/download-weights.mjs

# Downloads to: public/weights/Phi-3-mini-4k-instruct-q4f16_1-MLC/
# Served at:    /weights/Phi-3-mini-4k-instruct-q4f16_1-MLC/
# Size:         ~2.1 GB
Keep index + shards in sync

Always download everything together. Never mix a freshly-downloaded ndarray-cache.json with old cached shards — the byte offsets will not match and all weights will be corrupted.

Debugging tips

Each of these symptoms maps to a fixed bug in the table below:

  - All output is <unk>: usually a weight-loading problem (mixed ndarray-cache.json / shard versions, or buffer aliasing in add_norm)
  - Garbage / repetitive output: check the rope binding order
  - WebGPU validation error about aliasing: the same buffer is bound as both read and read_write in one dispatch; use the ping-pong residual pattern
  - Model not loading (Weight not found): parameter-name mismatch; log the actual names from ndarray-cache.json

Bugs we fixed (and how)

| Bug | Symptom | Fix |
|---|---|---|
| Wrong MLC param names | "Weight not found" error on load | Logged all 325 param names from the console, updated candidates to the transformer.h.* prefix |
| Buffer aliasing in add_norm | <unk> × 500 at 314 tok/s + WebGPU validation error | Added B.residual2 (pong buffer), ping-pong with a JS variable swap |
| Wrong rope binding order | Garbage: -,unlintzegesenma\dOCĆalloqueIAL repeated | Read rope.wgsl: bindings are @0=q_out @1=k_out @2=v_out @3=qkv @4=posMap |
| Mixed ndarray-cache.json version | <unk> after downloading the index locally but using old cached shards | Always fetch the index and shards from the same source atomically |

vs WebLLM

Head-to-head on Phi-3-mini-4k-instruct Q4F16_1, same weights, same browser (Chrome with WebGPU), Apple M2 Pro:

| Metric | Zero-TVM | WebLLM |
|---|---|---|
| Decode speed | ~40 tok/s | ~51 tok/s |
| Gap | 22% behind WebLLM | (baseline) |
| Dispatches / token | 228 (260 with ?int8kv=1) | 342 |
| Distinct shaders | 10 kernel roles · 27 .wgsl files · 3,078 LOC | 85 TVM-generated shaders · 12,962 LOC |
| Shipped JS bundle | 157 kB / 33 kB gz (zero-tvm.html) | 5.9 MB / 2.1 MB gz (compiler-chat.html) |
| Bandwidth utilization | ~36% of 111 tok/s ceiling | ~46% of same ceiling |
| Paged attention | ✓ Hand-written | ✓ TVM compiled |
| Readable kernels? | ✓ Yes (all 27 .wgsl files in repo) | ✗ No (emitted by the TVM compiler) |

Where the 22% gap comes from

At ~40 tok/s Zero-TVM already wins on dispatch overhead (228 vs 342 submissions per token). The remaining gap is inside the matmul and attention kernels themselves.

What is not in the gap: dispatch submission overhead, command-encoder build time, or framework marshalling. Those all favor Zero-TVM and are already measured out.

The memory-bandwidth ceiling

Phi-3-mini Q4F16_1 touches ~1.8 GB of weights per decode token. On M2 Pro's 200 GB/s memory bus that's 9 ms/token, or ~111 tok/s theoretical max. Zero-TVM at ~40 tok/s hits 36% of that ceiling; WebLLM at ~51 tok/s hits 46%. Neither engine can exceed the ceiling without changing weight layout or quantization (int8 KV shaves a bit, which is why it's exposed behind a flag).
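The ceiling arithmetic, spelled out:

```typescript
// Memory-bandwidth ceiling for decode, from the numbers above.
const weightBytesPerToken = 1.8e9;  // ~1.8 GB of weights read per decode token
const busBytesPerSec = 200e9;       // M2 Pro memory bandwidth
const msPerToken = (weightBytesPerToken / busBytesPerSec) * 1000;  // 9 ms
const ceilingTokS = 1000 / msPerToken;                             // ~111 tok/s
```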