@@ -3013,3 +3013,308 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
    }
    return s->logits;
}

/* ============================================================
 * Batched prefill — process N consecutive tokens in one call.
 *
 * Strategy: walk the layers exactly like tq_forward does, but with the
 * matmul calls replaced by tq_batched_matmul_q4 over an [N, D] activation
 * matrix. Per-token operations (RoPE, attention against history) are
 * still sequential — this is fine because attention is cheap at short
 * context lengths; the matmul gains dominate. KV cache writes are
 * per-token but go to consecutive positions [pos_start..pos_start+N).
 *
 * Currently supported: standard Llama family with load-time Q4 conversion
 * (use_q4_weights=1, separate Q/K/V/O + gate/up/down). Returns -1 for
 * unsupported architectures so the caller can fall back to a per-token
 * loop.
 *
 * Output: KV cache populated for [pos_start..pos_start+N). Logits are NOT
 * computed (that would require an extra batched lm_head matmul); the
 * caller can still call tq_forward(tokens[N-1], pos_start+N-1) for the
 * final token.
 * ============================================================ */
int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                     const int* tokens, int N, int pos_start) {
    if (N <= 0) return pos_start;
    tq_model_config_t* c = &model->config;

    /* Architectural gating: only standard Llama for now. */
    int dbg = (getenv("TQ_DEBUG_PREFILL") != NULL);
    if (!model->use_q4_weights) { if (dbg) fprintf(stderr, "[batch] bail: !use_q4_weights\n"); return -1; }
    if (c->is_moe || c->is_gemma4) { if (dbg) fprintf(stderr, "[batch] bail: moe/gemma4\n"); return -1; }
    if (c->has_fused_qkv || c->has_fused_up_gate) { if (dbg) fprintf(stderr, "[batch] bail: fused qkv/up\n"); return -1; }
    if (c->n_kv_shared_layers > 0) { if (dbg) fprintf(stderr, "[batch] bail: kv_shared\n"); return -1; }
    /* DeltaNet check */
    for (int l = 0; l < c->n_layers; l++) {
        if (model->layers[l].delta_a_log) { if (dbg) fprintf(stderr, "[batch] bail: deltanet l=%d\n", l); return -1; }
    }

    int dim = c->hidden_dim;
    int q_dim = c->n_heads * c->head_dim;
    int kv_dim = c->n_kv_heads * c->head_dim;
    int inter = c->intermediate_dim;

    /* Allocate batch scratch (N×{dim, q_dim, kv_dim, inter} FP32).
     * For Phi-3.5 at N=32: 9 buffers × 32 × ≤8192 × 4 ≈ 9 MB. Trivial. */
    size_t bytes_x = (size_t)N * dim * sizeof(float);
    size_t bytes_q = (size_t)N * q_dim * sizeof(float);
    size_t bytes_kv = (size_t)N * kv_dim * sizeof(float);
    size_t bytes_h = (size_t)N * inter * sizeof(float);

    float* X = (float*)malloc(bytes_x);
    float* Xres = (float*)malloc(bytes_x); /* residual stream */
    float* XBN = (float*)malloc(bytes_x);  /* normed for matmul */
    float* QB = (float*)malloc(bytes_q);
    float* KB = (float*)malloc(bytes_kv);
    float* VB = (float*)malloc(bytes_kv);
    float* OB = (float*)malloc(bytes_x);
    float* GB = (float*)malloc(bytes_h);
    float* UB = (float*)malloc(bytes_h);
    if (!X || !Xres || !XBN || !QB || !KB || !VB || !OB || !GB || !UB) {
        free(X); free(Xres); free(XBN); free(QB); free(KB); free(VB);
        free(OB); free(GB); free(UB);
        return -1;
    }

    /* Step 1: token embeddings (per-token). Mirror tq_forward's lookup.
     * Three sources: BF16 mmap, GGUF on-demand dequant, or FP32 table. */
    for (int n = 0; n < N; n++) {
        int tok = tokens[n];
        float* dst = Xres + (size_t)n * dim;
        if (model->embed_bf16) {
            const uint16_t* src = model->embed_bf16 + (size_t)tok * dim;
            for (int i = 0; i < dim; i++) {
                uint32_t bits = ((uint32_t)src[i]) << 16;
                memcpy(&dst[i], &bits, 4);
            }
        } else if (model->embed_gguf && !model->token_embedding) {
            int block_elems = tq_ggml_type_blck(model->embed_gguf_type);
            int block_bytes = (int)tq_ggml_type_size(model->embed_gguf_type);
            int n_blocks = dim / block_elems;
            size_t row_bytes = (size_t)n_blocks * block_bytes;
            const uint8_t* row_ptr = (const uint8_t*)model->embed_gguf + (size_t)tok * row_bytes;
            tq_dequant_row_gguf(model->embed_gguf_type, row_ptr, dst, dim);
        } else if (model->token_embedding) {
            memcpy(dst, model->token_embedding + (size_t)tok * dim,
                   (size_t)dim * sizeof(float));
        } else {
            if (dbg) fprintf(stderr, "[batch] bail: no embed source (n=%d tok=%d)\n", n, tok);
            free(X); free(Xres); free(XBN); free(QB); free(KB); free(VB);
            free(OB); free(GB); free(UB);
            return -1;
        }
    }

    /* Per-layer KV cache stride, in elements (same for the FP32 K cache
     * and the FP32 or FP16 V cache). */
    size_t kv_layer_stride = (size_t)c->max_seq_len * (size_t)kv_dim;

    for (int l = 0; l < c->n_layers; l++) {
        tq_layer_weights_t* layer = &model->layers[l];

        /* Required Q4 weights for this fast path. */
        if (!layer->wq_q4 || !layer->wk_q4 || !layer->wv_q4 || !layer->wo_q4 ||
            !layer->w_gate_q4 || !layer->w_up_q4 || !layer->w_down_q4) {
            if (dbg) fprintf(stderr, "[batch] bail: layer %d missing q4 weights (wq=%p wk=%p wv=%p wo=%p g=%p u=%p d=%p)\n",
                            l, (void*)layer->wq_q4, (void*)layer->wk_q4, (void*)layer->wv_q4,
                            (void*)layer->wo_q4, (void*)layer->w_gate_q4, (void*)layer->w_up_q4, (void*)layer->w_down_q4);
            free(X); free(Xres); free(XBN); free(QB); free(KB); free(VB);
            free(OB); free(GB); free(UB);
            return -1;
        }

        /* 1. attn RMSNorm (per-row) */
        for (int n = 0; n < N; n++) {
            tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
                       layer->attn_norm, dim, c->rms_norm_eps);
        }

        /* 2. Q, K, V batched matmul */
        tq_batched_matmul_q4(QB, layer->wq_q4, layer->wq_q4s, XBN, q_dim, dim, N, NULL);
        tq_batched_matmul_q4(KB, layer->wk_q4, layer->wk_q4s, XBN, kv_dim, dim, N, NULL);
        tq_batched_matmul_q4(VB, layer->wv_q4, layer->wv_q4s, XBN, kv_dim, dim, N, NULL);

        /* 3. RoPE + KV cache write (per-token).
         * Mirror tq_forward's RoPE selection: if model->rope_freqs is set
         * (Llama 3.x learned RoPE scaling, 64 freq factors), apply the per-pair
         * factor, i.e. theta_i = pos / (rope_freq_base^(2i/head_dim) * rope_freqs[i]);
         * otherwise plain interleaved RoPE. */
        for (int n = 0; n < N; n++) {
            float* qn = QB + (size_t)n * q_dim;
            float* kn = KB + (size_t)n * kv_dim;
            int pos = pos_start + n;
            if (model->rope_freqs && model->rope_freqs_len > 0) {
                int rope_pairs = c->head_dim / 2;
                if (rope_pairs > model->rope_freqs_len) rope_pairs = model->rope_freqs_len;
                /* Llama 3 uses interleaved layout (a=2i, b=2i+1) */
                for (int h = 0; h < c->n_heads; h++) {
                    float* qh = qn + h * c->head_dim;
                    for (int i = 0; i < rope_pairs; i++) {
                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
                        float freq = base / model->rope_freqs[i];
                        float theta = pos * freq;
                        float ct = cosf(theta), st = sinf(theta);
                        float q0 = qh[2 * i], q1 = qh[2 * i + 1];
                        qh[2 * i] = q0 * ct - q1 * st;
                        qh[2 * i + 1] = q0 * st + q1 * ct;
                    }
                }
                for (int h = 0; h < c->n_kv_heads; h++) {
                    float* kh = kn + h * c->head_dim;
                    for (int i = 0; i < rope_pairs; i++) {
                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
                        float freq = base / model->rope_freqs[i];
                        float theta = pos * freq;
                        float ct = cosf(theta), st = sinf(theta);
                        float k0 = kh[2 * i], k1 = kh[2 * i + 1];
                        kh[2 * i] = k0 * ct - k1 * st;
                        kh[2 * i + 1] = k0 * st + k1 * ct;
                    }
                }
            } else {
                tq_rope(qn, kn, pos, c->head_dim, c->n_heads, c->n_kv_heads,
                        c->rope_freq_base);
            }
            /* Write to cache */
            memcpy(s->key_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
                   kn, (size_t)kv_dim * sizeof(float));
            if (s->value_cache) {
                memcpy(s->value_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
                       VB + (size_t)n * kv_dim, (size_t)kv_dim * sizeof(float));
            } else if (s->value_cache_fp16) {
                /* FP32 → FP16 conversion for storage. */
                uint16_t* dst = s->value_cache_fp16
                    + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim;
                const float* src = VB + (size_t)n * kv_dim;
                for (int i = 0; i < kv_dim; i++) {
                    /* Use round-to-nearest IEEE 754 binary16 conversion via union */
                    union { float f; uint32_t u; } v = { .f = src[i] };
                    uint32_t b = v.u;
                    uint16_t sign = (b >> 16) & 0x8000;
                    int32_t e = (int32_t)((b >> 23) & 0xff) - 127 + 15;
                    uint32_t m = b & 0x7fffff;
                    uint16_t out;
                    if (e <= 0) {
                        if (e < -10) out = sign;
                        else {
                            m = (m | 0x800000) >> (1 - e);
                            if (m & 0x1000) m += 0x2000;
                            out = sign | (uint16_t)(m >> 13);
                        }
                    } else if (e >= 31) {
                        out = sign | 0x7c00 | (m ? (uint16_t)(m >> 13) : 0);
                    } else {
                        if (m & 0x1000) {
                            m += 0x2000;
                            if (m & 0x800000) { m = 0; e++; }
                        }
                        out = sign | ((uint16_t)e << 10) | (uint16_t)(m >> 13);
                    }
                    dst[i] = out;
                }
            } else {
                if (dbg) fprintf(stderr, "[batch] bail: no FP32/FP16 V cache\n");
                free(X); free(Xres); free(XBN); free(QB); free(KB); free(VB);
                free(OB); free(GB); free(UB);
                return -1;
            }
        }

        /* 4. Attention (per-token, sequential — needs all preceding KV). */
        int n_kv_heads = c->n_kv_heads;
        int head_dim = c->head_dim;
        int n_heads = c->n_heads;
        int kv_mul = n_heads / n_kv_heads;
        float* K_layer = s->key_cache + (size_t)l * kv_layer_stride;
        float* V_layer = s->value_cache ? (s->value_cache + (size_t)l * kv_layer_stride) : NULL;
        uint16_t* V_layer_fp16 = s->value_cache_fp16
            ? (s->value_cache_fp16 + (size_t)l * kv_layer_stride) : NULL;

        for (int n = 0; n < N; n++) {
            int pos = pos_start + n;
            float* qn = QB + (size_t)n * q_dim;
            float* on = OB + (size_t)n * dim;
            for (int h = 0; h < n_heads; h++) {
                int kvh = h / kv_mul;
                float* qh = qn + h * head_dim;
                float* att = s->att + (size_t)h * c->max_seq_len;
                float scale = 1.0f / sqrtf((float)head_dim);
                for (int t = 0; t <= pos; t++) {
                    float* kh = K_layer + (size_t)t * kv_dim + kvh * head_dim;
                    float score = 0.0f;
                    for (int i = 0; i < head_dim; i++) score += qh[i] * kh[i];
                    att[t] = score * scale;
                }
                tq_softmax(att, pos + 1);
                float* oh = on + h * head_dim;
                memset(oh, 0, (size_t)head_dim * sizeof(float));
                if (V_layer) {
                    for (int t = 0; t <= pos; t++) {
                        float* vh = V_layer + (size_t)t * kv_dim + kvh * head_dim;
                        float w = att[t];
                        for (int i = 0; i < head_dim; i++) oh[i] += w * vh[i];
                    }
                } else {
                    /* FP16 V cache: dequant per element via shift. */
                    for (int t = 0; t <= pos; t++) {
                        uint16_t* vh = V_layer_fp16 + (size_t)t * kv_dim + kvh * head_dim;
                        float w = att[t];
                        for (int i = 0; i < head_dim; i++) {
                            uint16_t h16 = vh[i];
                            uint32_t sign = (uint32_t)(h16 >> 15) << 31;
                            uint32_t exp = (h16 >> 10) & 0x1f;
                            uint32_t mant = h16 & 0x3ff;
                            uint32_t bits;
                            if (exp == 0) {
                                if (mant == 0) bits = sign;
                                else {
                                    /* subnormal */
                                    while (!(mant & 0x400)) { mant <<= 1; exp--; }
                                    mant &= 0x3ff;
                                    bits = sign | ((exp + 127 - 15 + 1) << 23) | (mant << 13);
                                }
                            } else if (exp == 31) {
                                bits = sign | 0x7f800000u | (mant << 13);
                            } else {
                                bits = sign | ((exp + 127 - 15) << 23) | (mant << 13);
                            }
                            float vf;
                            memcpy(&vf, &bits, 4);
                            oh[i] += w * vf;
                        }
                    }
                }
            }
        }

        /* 5. O matmul batched */
        tq_batched_matmul_q4(X, layer->wo_q4, layer->wo_q4s, OB, dim, q_dim, N, NULL);

        /* 6. Residual: Xres += X */
        for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];

        /* 7. ffn_norm */
        for (int n = 0; n < N; n++) {
            tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
                       layer->ffn_norm, dim, c->rms_norm_eps);
        }

        /* 8. gate, up batched matmul */
        tq_batched_matmul_q4(GB, layer->w_gate_q4, layer->w_gate_q4s, XBN, inter, dim, N, NULL);
        tq_batched_matmul_q4(UB, layer->w_up_q4, layer->w_up_q4s, XBN, inter, dim, N, NULL);

        /* 9. SiLU(gate) * up (per-element) */
        for (size_t i = 0; i < (size_t)N * inter; i++) {
            float g = GB[i];
            float silu = g / (1.0f + expf(-g));
            GB[i] = silu * UB[i];
        }

        /* 10. down matmul batched (output back into X) */
        tq_batched_matmul_q4(X, layer->w_down_q4, layer->w_down_q4s, GB, dim, inter, N, NULL);

        /* 11. Residual: Xres += X */
        for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
    }

    free(X); free(XBN); free(QB); free(KB); free(VB); free(OB); free(GB); free(UB);
    free(Xres);
    return pos_start + N;
}
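
/* ------------------------------------------------------------
 * Illustrative caller-side sketch: one possible way to use the prefill
 * contract documented above. The wrapper name and the `prompt`/`n_prompt`
 * variables are hypothetical (assumes n_prompt >= 1); only tq_forward_batch
 * and tq_forward are real entry points. tq_forward_batch fills the KV cache
 * but not the logits, so the caller either falls back to a per-token loop
 * (return value -1) or re-runs the last prompt token through tq_forward to
 * obtain logits for sampling.
 * ------------------------------------------------------------ */
static float* example_prefill(tq_model_t* model, tq_state_t* s,
                              const int* prompt, int n_prompt) {
    float* logits = NULL;
    int next_pos = tq_forward_batch(model, s, prompt, n_prompt, 0);
    if (next_pos < 0) {
        /* Unsupported architecture: plain per-token prefill. */
        for (int p = 0; p < n_prompt; p++) {
            logits = tq_forward(model, s, prompt[p], p);
        }
    } else {
        /* Batched path populated KV for [0..n_prompt); a single tq_forward
         * call on the last prompt token yields the logits for the next
         * position. */
        logits = tq_forward(model, s, prompt[n_prompt - 1], n_prompt - 1);
    }
    return logits;
}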