Commit 442c2d7

unamedkr and claude committed
refactor(batched): vector-accumulator matmul + deeper drift analysis
bm_q4_worker now uses NEON vector accumulators (float32x4_t sumv[N] with vmlaq_n_f32 + vaddvq_f32 reduce) to match matmul_q4_rows' FP rounding. This brings it architecturally in line with baseline's per-token quantized matmul.

However, integration-level drift persists. Even TQ_BATCHED_SERIAL=1 (which forces bit-for-bit identical per-token matmul via the same tq_matmul_q4_preq call baseline uses) still produces wrong output. The bug is therefore NOT in the matmul accumulator but in the surrounding tq_forward_batch orchestration.

Divergence is highly specific: Layer 3 tok1 (pos=1) diverges at indices 1, 5, 6, 7 but matches at 0, 2, 3, 4 — a pattern-based drift, not random noise.

Updated handoff doc with concrete next-session experiments:
- Dump Layer 3 tok0 wo-matmul output byte-for-byte
- Dump Layer 3 tok1 attention scores att[0], att[1]
- If scores differ: trace back to K-cache at layer 3 pos=0
- If K-cache differs: trace back to WK matmul output for tok0

11/11 STRICT tests still pass (batched still TQ_BATCH_PREFILL-gated).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bd063e0 commit 442c2d7

File tree

2 files changed: +73 −16 lines changed

docs/dev/batched_prefill_handoff.md
Lines changed: 43 additions & 0 deletions

````diff
@@ -27,6 +27,49 @@ DYLD_LIBRARY_PATH=build \
 # prints: " I'm so excited" ← baseline (correct)
 ```
 
+## Session 2 findings (2026-04-16 early AM)
+
+After extensive layer-by-layer diff between batched and per-token paths:
+
+**What's bit-identical**:
+- tok0 (pos=0) through Layer 15 — every sub-op, every layer
+- tok1 (pos=1) through Layer 2 final Xres
+- Layer 3 attention-residual value at indices 0, 2, 3, 4 (partial match)
+
+**What diverges**:
+- Layer 3 tok1 attention-residual at indices 1, 5, 6, 7 — exactly 1 ULP off
+- This 1-ULP drift compounds ~1%/layer → wrong token by Layer 15
+
+**Surprising**: Setting `TQ_BATCHED_SERIAL=1` (which replaces my bm_q4_worker
+with literal per-token `tq_matmul_q4_preq` calls — the SAME function baseline
+uses) STILL produces the divergence. So the bug is not in the batched matmul
+accumulator order; it's somewhere in the broader orchestration of
+tq_forward_batch when processing multi-token.
+
+**Fixed along the way** (each moved Layer 0 closer to bit-identical):
+- Q8 quantization: `roundf` → `tq_quantize_row_q8` (NEON RNE)
+- FP16 conversion (write): inline → `f32_to_fp16_vec`
+- FP16 conversion (read): inline → NEON `vcvt_f32_f16`
+- Attention score accumulation: scalar → `vfmaq_f32` NEON
+- bm_q4_worker: scalar accumulator → NEON `float32x4_t sumv[]` + `vaddvq_f32`
+
+**Remaining hypothesis** (to test next session):
+The drift is at specific indices, not random. Index 1 of Layer 3 tok1 diverges
+but indices 0, 2, 3 don't. This is consistent with a SPECIFIC memory location
+being read slightly off. Possibilities:
+- Aliasing: my X buffer might be accidentally read before fully written in
+  some multi-token iteration (out-of-order thread effect)
+- FP16 round-trip on a specific value whose LSB happens to sit on a boundary
+- The `tq_forward` final call (after batched) reads K-cache positions [0..pos-1]
+  written by batched; if ANY of those are 1 ULP off for any layer, the final
+  attention sees slightly wrong history. Could be a compounding effect.
+
+**Concrete next-session experiments**:
+1. Dump Layer 3 tok0 wo matmul output (OB→X) byte-for-byte vs baseline
+2. Dump Layer 3 tok1 attention scores (att[0], att[1]) vs baseline
+3. If scores differ, dump K-cache at layer 3 pos=0 vs baseline
+4. If K-cache differs, dump the WK matmul output for tok0 at layer 3
+
 ## Latest session findings (2026-04-15 evening)
 
 - **SANITY mode confirms orchestration is correct**. Setting
````
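The dump-and-compare experiments above need a way to quantify "exactly 1 ULP off". A minimal sketch of an integer-ordinal ULP distance in portable C — the helper names `ulp_dist` and `first_drift` are hypothetical, not part of the codebase:

```c
#include <stdint.h>
#include <string.h>

/* Map a finite float's bit pattern onto a monotonically ordered integer
 * scale, so adjacent representable floats differ by exactly 1. Negative
 * floats (sign bit set) map to the negated magnitude; ±0 both map to 0. */
static int64_t f32_ordinal(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof u);
    return (u & 0x80000000u) ? -(int64_t)(u & 0x7fffffffu) : (int64_t)u;
}

/* ULP distance between two finite floats. */
int64_t ulp_dist(float a, float b) {
    int64_t d = f32_ordinal(a) - f32_ordinal(b);
    return d < 0 ? -d : d;
}

/* Index of the first element pair more than max_ulp apart, or -1 if the
 * buffers match to that tolerance — usable on dumped activation buffers. */
long first_drift(const float *a, const float *b, long n, int64_t max_ulp) {
    for (long i = 0; i < n; i++)
        if (ulp_dist(a[i], b[i]) > max_ulp) return i;
    return -1;
}
```

Running `first_drift` with `max_ulp = 0` on the Layer 3 tok1 attention-residual dumps would report index 1 first; with `max_ulp = 1` it would report a clean -1, confirming the drift is exactly one ULP.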

src/engine/tq_ops.c
Lines changed: 30 additions & 16 deletions

````diff
@@ -1081,26 +1081,31 @@ static void* bm_q4_worker(void* arg) {
 #ifdef __ARM_NEON
     const uint8x16_t mask_0f = vdupq_n_u8(0x0F);
     const uint8x16_t v8 = vdupq_n_u8(8);
+#endif
+    /* Per-row, per-token NEON vector accumulators. To match matmul_q4_rows'
+     * FP rounding bit-for-bit, we must use vmlaq_n_f32 into a 4-lane
+     * float32x4_t accumulator and reduce with vaddvq_f32 at the end. A
+     * scalar `acc[n] += ...` path produces 1-ULP drift per block which
+     * compounds 1% per transformer layer and flips output tokens. */
+#ifdef __ARM_NEON
+    if (N > 64) { /* safety limit for stack-alloc'd sumv array */
+        /* Fallback: serial per-token via tq_matmul_q4_preq would be needed
+         * here. For now reject large batches to keep the hot path simple. */
+        return NULL;
+    }
+    float32x4_t sumv[64];
 #endif
     for (int i = t->start_row; i < t->end_row; i++) {
         const uint8_t* wi = t->w_qs + (size_t)i * n_blocks * 16;
         const float* si = t->w_scales + (size_t)i * n_blocks;
 
-        /* Per-row N-element accumulator (FP32, on stack — N usually small). */
-        /* For very large N callers will need a different design (chunk N). */
-        float acc[256];
-        if (N > 256) { /* shouldn't happen at sane batch sizes */ continue; }
-        memset(acc, 0, sizeof(float) * N);
+#ifdef __ARM_NEON
+        for (int n = 0; n < N; n++) sumv[n] = vdupq_n_f32(0.0f);
 
         for (int b = 0; b < n_blocks; b++) {
-#ifdef __ARM_NEON
-            /* Unpack 16 packed bytes → 32 signed int8 nibbles, range [-8, 7]. */
             uint8x16_t pk = vld1q_u8(wi + b * 16);
             int8x16_t lo = vreinterpretq_s8_u8(vsubq_u8(vandq_u8(pk, mask_0f), v8));
             int8x16_t hi = vreinterpretq_s8_u8(vsubq_u8(vshrq_n_u8(pk, 4), v8));
-            /* The packed layout interleaves (lo,hi) pairs. Use vld2q_s8 on
-             * x_q to deinterleave to the same scheme: x_q[0,2,4,...] vs
-             * x_q[1,3,5,...]. matmul_q4_rows uses this; we match it. */
 
             const float wd = si[b];
             for (int n = 0; n < N; n++) {
@@ -1115,12 +1120,23 @@ static void* bm_q4_worker(void* arg) {
                 a0 = vaddq_s32(a0, vpaddlq_s16(vmull_s8(vget_low_s8(hi), vget_low_s8(xd.val[1]))));
                 a0 = vaddq_s32(a0, vpaddlq_s16(vmull_s8(vget_high_s8(hi), vget_high_s8(xd.val[1]))));
 #endif
-                int32_t s = vaddvq_s32(a0);
                 float xd_n = t->X_d[(size_t)n * n_blocks + b];
-                acc[n] += wd * xd_n * (float)s;
+                /* Match matmul_q4_rows exactly: vmlaq_n_f32 with combined scale.
+                 * vcvtq_f32_s32(a0) on the int32 accumulator, scalar scale =
+                 * wd * xd_n, accumulate into sumv[n]. */
+                sumv[n] = vmlaq_n_f32(sumv[n], vcvtq_f32_s32(a0), wd * xd_n);
             }
+        }
+
+        for (int n = 0; n < N; n++) {
+            t->out[(size_t)n * n_rows + i] = vaddvq_f32(sumv[n]);
+        }
 #else
-            /* Scalar fallback */
+        /* Scalar fallback (x86 / no NEON). */
+        float acc[256];
+        if (N > 256) continue;
+        memset(acc, 0, sizeof(float) * N);
+        for (int b = 0; b < n_blocks; b++) {
             const float wd = si[b];
             int8_t lo[32], hi[32];
             for (int j = 0; j < 16; j++) {
@@ -1134,13 +1150,11 @@ static void* bm_q4_worker(void* arg) {
                 float xd_n = t->X_d[(size_t)n * n_blocks + b];
                 acc[n] += wd * xd_n * (float)s;
             }
-#endif
         }
-
-        /* Scatter accumulator into output row */
         for (int n = 0; n < N; n++) {
             t->out[(size_t)n * n_rows + i] = acc[n];
         }
+#endif
     }
     return NULL;
 }
````