
WebGPU Fusion Architecture

Phi-3 Mini 4K (3.8B, Q4) running 100% in-browser via WebGPU. Reverse-engineered from TVM.

Our Dispatches (per token): 279
TVM Dispatches (KV+Attn): 63
Custom WGSL Shaders: 10
TVM Shaders (captured): 85

Three Inference Approaches

HYBRID

Engine v3 (Capture-Replay)

Msg 1: TVM generates + captures full dispatch graph. Msg 2+: replays captured pipelines with 6 per-token value updates. ~100 dispatches vs TVM's 342.

engine-v3.ts chat-v3.ts v3.html
COMPILER

Compiler Engine (228/342 Own)

Replaces 228 of 342 dispatches with our 11 hand-written WGSL shaders (post-M4 QKV+RoPE+KV-append fusion). Only attention + LM-head reduction fall back to TVM-style paths. Full prefill + decode.

compiler/ compiler-chat.html
FULL

Phi3Engine (Named Pipelines)

Maps each step to named TVM pipelines (qkv, rope, attention...) with per-layer bind groups. Fused FFN enabled. Supports own prefill + decode via chat().

phi3.ts fast-chat.ts

Data Flow: Model Load to Chat

WebLLM (@mlc-ai/web-llm) → TVM Runtime (loads Q4 weights) → capture.ts (intercepts WebGPU) → CaptureResult (shaders, dispatches, writes, buffers) → Our Engine (builds pipelines + bind groups) → WebGPU (zero TVM in hot path)

capture.ts — WebGPU Interceptor

Monkey-patches GPUDevice to record everything TVM does:

State machine: loading → prefill_done → done

Separates prefill vs decode dispatches at the mapAsync boundary.
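The state transitions can be sketched as a small pure function. This is a minimal reading of the description above, assuming each `mapAsync` readback advances the state exactly once; the actual trigger logic in `capture.ts` may differ.

```typescript
// Minimal sketch of the capture state machine. State names come from the doc;
// the assumption is that each mapAsync readback boundary advances the state.
type CaptureState = "loading" | "prefill_done" | "done";

function onMapAsync(state: CaptureState): CaptureState {
  switch (state) {
    case "loading":      return "prefill_done"; // first readback: prefill finished
    case "prefill_done": return "done";         // second readback: decode captured
    case "done":         return "done";         // capture complete, no-op
  }
}
```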

CaptureResult

| Field | Count | Description |
|---|---|---|
| shaders | 85 | WGSL source code + module |
| pipelines | 85 | Entry point + pipeline object |
| dispatches | 342 | Decode: pipeline + bind group entries + workgroups |
| prefillDispatches | 343 | Prefill: same structure |
| writes | 357 | Decode buffer writes (buffer + offset + data) |
| copy | 1 | Token readback: src buffer + offset |
| weights | ~200 | Large buffer writes (>1KB) during load |

Phi-3 Mini Architecture (Per Token)

Decode Pipeline: 342 Dispatches (32 Layers x 10 + 2 Preamble + 20 Tail)

Our shader (279)
TVM kernel (63)
Fused (saves 1 dispatch/layer)
Skipped (replaced)

Preamble (2 dispatches)

embed: Embedding Lookup → RMS Norm

Per Layer x 32 (10 dispatches each = 320 total)

attention: QKV Matmul → RoPE → KV Append → Paged Attention → O Proj Matmul → Add + RMSNorm
FFN: Fused Gate+Up+SiLU → SiLU (skipped) → FFN Down Matmul → Add + RMSNorm

Tail (20 dispatches)

sampling: LM Head Matmul (f32) → Argmax; 18x TVM sampling replaced by argmax

Compiler engine: 279 of our dispatches + 63 TVM (KV append + attention, 2 per layer; the 19 sampling dispatches are skipped). Fused FFN saves 32 dispatches/token.
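The dispatch budget in the section heading can be sanity-checked with simple arithmetic:

```typescript
// Decode dispatch budget: 2 preamble + 32 layers x 10 + 20 tail = 342.
const LAYERS = 32;
const PER_LAYER = 10;
const PREAMBLE = 2;
const TAIL = 20;
const totalDispatches = PREAMBLE + LAYERS * PER_LAYER + TAIL; // 342
```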


Per-Layer Dispatch Map (Layer N, base = 2 + N*10)

| Step | Index | Kernel | Owner | Workgroups | Description |
|---|---|---|---|---|---|
| 0 | base+0 | int4_matmul | Ours | 9216 | QKV projection: D=3072 → 3*D=9216 (Q,K,V concatenated) |
| 1 | base+1 | rope_kernel | Ours | 36 | Rotary position embedding on Q,K |
| 2 | base+2 | kv_cache_transpose_append | TVM | varies | Append K,V to paged cache (transpose layout) |
| 3 | base+3 | batch_decode_paged_kv | TVM | varies | Paged KV attention with sliding window |
| 4 | base+4 | int4_matmul | Ours | 3072 | Output projection: head_dim*heads → D |
| 5 | base+5 | add_norm | Ours | 1 | Residual add + RMSNorm (pre-FFN) |
| 6 | base+6 | fused_ffn_kernel | Fused | 8192 | Gate+Up int4 matmul + SiLU in ONE dispatch |
| 7 | base+7 | split_silu_multiply | Skip | - | Replaced by fused FFN above |
| 8 | base+8 | int4_matmul | Ours | 3072 | FFN down projection: 8192 → 3072 |
| 9 | base+9 | add_norm | Ours | 1 | Residual add + RMSNorm (post-FFN) |
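The `base = 2 + N*10` indexing above translates directly into a flat dispatch index:

```typescript
// Flat dispatch index for layer `layer`, step `step` (0..9), following
// the table's base = 2 + N*10 (2 preamble dispatches come first).
function dispatchIndex(layer: number, step: number): number {
  return 2 + layer * 10 + step;
}
// Layer 0's QKV matmul is dispatch 2; layer 31's post-FFN add_norm is 321,
// leaving indices 322..341 for the 20 tail dispatches.
```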

Our 10 Custom WGSL Shaders

int4_matmul.wgsl

Dequantize int4 weights + matrix multiply. 64 threads, 6 chunks, tree reduction. Zero-point 7, 32-group scales. f16 accumulation.

D_PACKED = D/8 packed u32s per row
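The dequantization step can be illustrated on a single packed word. Zero-point 7 is stated above; the nibble ordering (low nibble = first element) is an assumption for illustration, not necessarily the shader's layout.

```typescript
// Reference dequantization of one packed u32 holding 8 int4 values,
// zero-point 7. Low-nibble-first ordering is an assumption.
function dequant8(packed: number, scale: number): number[] {
  const out: number[] = [];
  for (let i = 0; i < 8; i++) {
    const q = (packed >>> (4 * i)) & 0xf; // 4-bit value, 0..15
    out.push((q - 7) * scale);            // subtract zero-point, apply group scale
  }
  return out;
}
```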

int4_matmul_f32.wgsl

Same as int4_matmul but accumulates in f32. Used for LM head (32064 outputs) where f16 precision causes sampling errors.

fused_ffn.wgsl

Fuses gate+up matmul + SiLU into ONE dispatch. Dual dot product (gate row i, up row i+8192), shared memory input cache (3072 f16 = 6KB), f32 sigmoid.

Replaces 2 TVM dispatches → saves 32 dispatches/token
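Per output element, the fused kernel computes the standard gated-FFN form; a scalar sketch (assuming the usual silu(gate) * up gating, with the sigmoid in f32 as noted above):

```typescript
// Elementwise math of the fused gate+up+SiLU kernel, per output column:
// silu(gate_dot) * up_dot. SiLU = x * sigmoid(x), evaluated in f32.
function silu(x: number): number {
  return x / (1 + Math.exp(-x));
}

function fusedFfnElement(gateDot: number, upDot: number): number {
  return silu(gateDot) * upDot;
}
```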

rms_norm.wgsl

RMSNorm: x * rsqrt(mean(x^2) + eps) * weight. Single workgroup, D=3072 elements. Used for initial norm.
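A scalar reference of the formula above (the eps value is an assumption; Phi-3 configs commonly use 1e-5):

```typescript
// Reference RMSNorm: x * rsqrt(mean(x^2) + eps) * weight.
function rmsNorm(x: number[], weight: number[], eps = 1e-5): number[] {
  const meanSq = x.reduce((acc, v) => acc + v * v, 0) / x.length;
  const inv = 1 / Math.sqrt(meanSq + eps); // rsqrt
  return x.map((v, i) => v * inv * weight[i]);
}
```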

add_norm.wgsl

Fused residual add + RMSNorm. y = norm(residual + x) * weight. Saves a dispatch vs separate add+norm.

rope.wgsl

Rotary Position Embedding. Applies cos/sin rotation to Q,K heads based on position. 36 workgroups for 32+4 heads.
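The per-pair rotation can be written out directly. The base frequency 10000 and the pairing layout are the conventional choices and are assumptions here, not read from the shader:

```typescript
// One RoPE rotation of the pair (x0, x1) at sequence position `pos`,
// pair index `j` within a head. base = 10000 is assumed.
function ropePair(
  x0: number, x1: number, pos: number, j: number,
  headDim = 96, base = 10000
): [number, number] {
  const theta = pos * Math.pow(base, (-2 * j) / headDim);
  const c = Math.cos(theta);
  const s = Math.sin(theta);
  return [x0 * c - x1 * s, x0 * s + x1 * c];
}
// At pos = 0 the rotation is the identity.
```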

embedding.wgsl

Token embedding lookup from dequantized int4 embedding table. Maps token ID → D=3072 hidden state.

argmax.wgsl

Simple argmax over 32064 logits. Replaces TVM's 19-dispatch sampling pipeline (top-p, temperature, etc.) with greedy decoding.
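Greedy decoding reduces to a single scan over the logits; a CPU reference of what the shader computes:

```typescript
// Greedy decoding: index of the maximum logit (ties keep the lower index).
function argmax(logits: Float32Array): number {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}
```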

attention.wgsl

Multi-head attention with paged KV cache. (Compiler engine only — main engines use TVM's optimized kernel.)

kv_append.wgsl

Append K,V to paged cache. (Compiler engine only — main engines use TVM's transpose-append kernel.)


Buffer Landscape (Decode Token)

6 Per-Token Writes

| Write | Buffer | Value |
|---|---|---|
| w[0] | token_id | Current token ID (u32) |
| w[4] | position_map | Position for RoPE (u32) |
| w[8] | seq_counter | Sequence index (u32) |
| w[11] | q_rope_position | RoPE position (u32) |
| w[12] | length_info | Sequence length (u32) |
| w[349] | seed | Sampling RNG seed (f32) |

Plus: nnz_pages at offset 16 in 32 attention uniform buffers
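The six updates in the table above can be sketched as a pure function. The exact values are partly assumptions (e.g. `seq_counter` is assumed to track the position); the write indices come from the table:

```typescript
// Sketch of the six per-token uniform updates replayed each decode step.
// seq_counter's value is an assumption (taken equal to the position here).
interface TokenWrite { index: number; name: string; value: number; }

function perTokenWrites(
  tokenId: number, pos: number, seqLen: number, seed: number
): TokenWrite[] {
  return [
    { index: 0,   name: "token_id",        value: tokenId },
    { index: 4,   name: "position_map",    value: pos },
    { index: 8,   name: "seq_counter",     value: pos },    // assumption
    { index: 11,  name: "q_rope_position", value: pos },
    { index: 12,  name: "length_info",     value: seqLen },
    { index: 349, name: "seed",            value: seed },
  ];
}
```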

Activation Ping-Pong

3 scratch buffers cycle through layers:

| Buffer | Size | Role |
|---|---|---|
| BUF#730 | 6KB | Hidden state (3072 f16) |
| BUF#731 | 32KB | FFN intermediate (16384 f16) |
| BUF#732 | 18KB | QKV output (9216 f16) |

Same buffers reused across all 32 layers. Weight buffers are per-layer (~50MB total for Q4).
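The scratch footprint stays constant regardless of layer count, since the buffers are rebound rather than reallocated; the arithmetic from the table:

```typescript
// Total scratch for activations: three f16 buffers, reused by all 32 layers.
const scratchElems = { hidden: 3072, ffnIntermediate: 16384, qkv: 9216 };
const f16Bytes = (n: number) => n * 2; // f16 = 2 bytes per element
const scratchKB =
  (f16Bytes(scratchElems.hidden) +
   f16Bytes(scratchElems.ffnIntermediate) +
   f16Bytes(scratchElems.qkv)) / 1024; // 6 + 32 + 18 = 56 KB
```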


File Structure

Core Engine

src/
capture.ts — WebGPU interceptor (state machine)
engine.ts — WebLLM wrapper (model loading, tokenizer)
phi3.ts — Phi3Engine (named pipelines, fused FFN)
engine-v3.ts — Minimal replay engine (~170 lines)
ui.ts — Chat UI component
fast-chat.ts — Phi3Engine chat page
chat-v3.ts — Engine v3 chat page
main.ts — Default TVM chat page
dump-tvm.ts — Architecture analysis tool
dump-shaders.ts — Shader extraction tool
shaders/ — Fused shaders for phi3 engine
fused-ffn.wgsl
argmax.wgsl
fused-norm-matmul.wgsl

Compiler Engine

src/compiler/
compiler.ts — Pipeline + bind group builder
runtime.ts — Full transformer forward pass
chat-v2.ts — Chat with 279/342 own shaders
chat.ts — Earlier chat version
model.ts — Model config + weight loading
test-harness.ts — Shader correctness tests
test-chain.ts — Full chain verification
shaders/ — 10 custom WGSL kernels
int4_matmul.wgsl
fused_ffn.wgsl
rms_norm.wgsl
add_norm.wgsl
rope.wgsl
attention.wgsl
kv_append.wgsl
embedding.wgsl
int4_matmul_f32.wgsl
argmax.wgsl

HTML Entry Points

| File | Engine | Description |
|---|---|---|
| index.html | Landing | Project overview, stats, shader catalog, compare table |
| zero-tvm.html | Zero-TVM | Chat demo on 10 hand-written WGSL kernels (no TVM) |
| validate.html | Zero-TVM | Multi-prompt smoke test driving engine-core.ts |
| webllm-bench.html | WebLLM | Head-to-head harness against identical weights |
| compiler-chat.html | Compiler | TVM capture → replay via own shaders (279/342) |
| demo.html | Demo | Interactive dispatch-timeline visualization |
| dump.html | Analysis | Full TVM architecture dump |
| shaders.html | Analysis | Browse all 85 captured TVM shader sources |
| docs.html | Docs | Shader catalog, URL flags, benchmarks |

Model Constants

Hidden dim (D): 3072
Attention heads: 32
Head dim (D/heads): 96
Transformer layers: 32
FFN intermediate dim: 8192
Vocabulary size: 32064
KV cache page size: 16
Max sequence length: 4096
webgpu-fusion-webllm — Reverse-engineered WebGPU LLM inference