diff --git a/JAXBench/benchmark/11p_Megablox_GMM/baseline.py b/JAXBench/benchmark/11p_Megablox_GMM/baseline.py
index 676c0ff..f0649b3 100644
--- a/JAXBench/benchmark/11p_Megablox_GMM/baseline.py
+++ b/JAXBench/benchmark/11p_Megablox_GMM/baseline.py
@@ -4,7 +4,6 @@
 and multiply with that expert's weight matrix. Core primitive for MoE layers.
 
 From JAX experimental pallas ops (reference_gmm).
-Not jit-compatible: uses data-dependent slicing on group_sizes.
 """
 
 import jax
@@ -21,8 +20,6 @@
     'seq_len': 4096,
 }
 
-_skip_jit = True
-
 
 def create_inputs(dtype=jnp.bfloat16):
     key = jax.random.key(42)
@@ -38,28 +35,64 @@ def create_inputs(dtype=jnp.bfloat16):
     lhs = lhs.astype(jnp.bfloat16).astype(dtype)
     rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit)
     rhs = rhs.astype(jnp.bfloat16).astype(dtype)
-    tokens_per_expert = M // G
-    group_sizes = jnp.full((G,), tokens_per_expert, dtype=jnp.int32)
-    return lhs, rhs, group_sizes
+    max_expert_size = M // G
+    group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32)
+    return lhs, rhs, group_sizes, max_expert_size
 
 
-def workload(lhs, rhs, group_sizes):
-    """Reference grouped matmul from upstream JAX tests.
+def workload(lhs, rhs, group_sizes, max_expert_size):
+    """Jittable grouped matmul using static shapes and masking.
 
-    For each group i, slices lhs[start:start+size] and computes dot with rhs[i].
-    Uses data-dependent slicing so must be run eagerly (not under jax.jit).
+    Computes dot product for each group with static slice sizes to allow JIT.
     """
-    start = 0
-    out = []
-    for i, size in enumerate(group_sizes):
-        result = jax.lax.dot(
-            lhs[start:start + size, :],
-            rhs[i, :, :],
-            preferred_element_type=jnp.float32,
+    G = rhs.shape[0]
+    M, K = lhs.shape
+    N = rhs.shape[2]
+
+    # Compute expert offsets
+    group_ends = jnp.cumsum(group_sizes)
+    group_starts = jnp.concatenate(
+        [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]
+    )
+
+    # Initialize flat result array with padding
+    res_flat = jnp.zeros((M + max_expert_size, N), dtype=lhs.dtype)
+
+    def body_fun(carry_res_flat, i):
+        start = group_starts[i]
+        count = group_sizes[i]
+
+        # Slice with a STATIC size
+        expert_lhs = jax.lax.dynamic_slice(
+            lhs, (start, 0), (max_expert_size, K)
+        )
+        expert_rhs = rhs[i, :, :]
+
+        # Compute GEMM
+        res = jax.lax.dot(
+            expert_lhs, expert_rhs, preferred_element_type=jnp.float32
         )
-        out.append(result)
-        start += group_sizes[i]
-    return jnp.concatenate(out, axis=0)
+
+        # Mask out invalid rows
+        mask = (
+            jax.lax.broadcasted_iota(jnp.int32, (max_expert_size, N), 0) < count
+        )
+        res_masked = jnp.where(mask, res, 0.0)
+
+        # Read-Modify-Write to accumulate results
+        current_slice = jax.lax.dynamic_slice(
+            carry_res_flat, (start, 0), (max_expert_size, N)
+        )
+        updated_slice = current_slice + res_masked.astype(carry_res_flat.dtype)
+        carry_res_flat = jax.lax.dynamic_update_slice(
+            carry_res_flat, updated_slice, (start, 0)
+        )
+
+        return carry_res_flat, None
+
+    res_flat, _ = jax.lax.scan(body_fun, res_flat, jnp.arange(G))
+
+    return res_flat[:M, :]
 
 
 def get_flops():
@@ -73,17 +106,20 @@ def get_flops():
 
 
 def benchmark(num_warmup=2, num_iters=10):
-    """Benchmark eagerly (no JIT — data-dependent control flow)."""
+    """Benchmark with JIT."""
     import time
     inputs = create_inputs()
+
+    fn = jax.jit(workload, static_argnums=(3,))
+
     # Warmup
     for _ in range(num_warmup):
-        out = workload(*inputs)
+        out = fn(*inputs)
         out.block_until_ready()
     times = []
     for _ in range(num_iters):
         t0 = time.perf_counter()
-        out = workload(*inputs)
+        out = fn(*inputs)
         out.block_until_ready()
         times.append(time.perf_counter() - t0)
     import numpy as np
diff --git a/JAXBench/benchmark/7p_Ragged_Paged_Attention/baseline.py b/JAXBench/benchmark/7p_Ragged_Paged_Attention/baseline.py
index 659e136..7566644 100644
--- a/JAXBench/benchmark/7p_Ragged_Paged_Attention/baseline.py
+++ b/JAXBench/benchmark/7p_Ragged_Paged_Attention/baseline.py
@@ -4,7 +4,6 @@
 Processes each sequence independently with variable-length queries and paged KV cache.
 
 From JAX experimental pallas ops (ref_ragged_paged_attention).
-Not jit-compatible: uses data-dependent slicing.
 """
 
 import math
@@ -27,8 +26,6 @@
     'pages_per_seq': 256,
 }
 
-_skip_jit = True
-
 
 def create_inputs(dtype=jnp.bfloat16):
     key = jax.random.key(42)
@@ -58,10 +55,9 @@ def create_inputs(dtype=jnp.bfloat16):
 
 
 def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
-    """Reference ragged paged attention from upstream JAX.
+    """Ragged paged attention using static shapes and masking for JIT compatibility.
 
-    Processes each sequence independently with data-dependent slicing.
-    Must be run eagerly (not under jax.jit).
+    Processes each sequence independently, avoiding data-dependent slicing.
     """
     sm_scale = 1.0 / math.sqrt(CONFIG['head_dim'])
     mask_value = DEFAULT_MASK_VALUE
@@ -70,17 +66,21 @@ def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
     num_q_heads = queries.shape[1]
     num_query_per_kv = num_q_heads // num_kv_heads
 
+    max_seqs = CONFIG['max_num_seqs']
+    tokens_per_seq = CONFIG['max_num_batched_tokens'] // max_seqs
+
     outputs = []
-    for i in range(num_seqs[0]):
+    for i in range(max_seqs):
         q_start = cu_q_lens[i]
-        q_end = cu_q_lens[i + 1]
-        q_len = q_end - q_start
         kv_len = kv_lens[i]
         indices = page_indices[i]
 
-        q = queries[q_start:q_end]
-        k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
-        v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
+        q = jax.lax.dynamic_slice(
+            queries, (q_start, 0, 0), (tokens_per_seq, num_q_heads, head_dim)
+        )
+
+        k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)
+        v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)
         k = jnp.repeat(k, num_query_per_kv, axis=1)
         v = jnp.repeat(v, num_query_per_kv, axis=1)
 
@@ -90,15 +90,20 @@ def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
         )
         attn *= sm_scale
 
-        q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
+        q_span = (kv_len - tokens_per_seq) + jax.lax.broadcasted_iota(
             jnp.int32, attn.shape, 1
         )
         kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-        mask = q_span < kv_span
-        attn += jnp.where(mask, mask_value, 0.0)
+
+        mask = (q_span < kv_span) | (kv_span >= kv_len)
+        attn = jnp.where(mask, mask_value, attn)
         attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
         out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
+
+        is_valid = i < num_seqs[0]
+        out = jnp.where(is_valid, out, 0.0)
+
         outputs.append(out)
 
     return jnp.concatenate(outputs, axis=0)
 
@@ -115,17 +120,18 @@ def get_flops():
 
 
 def benchmark(num_warmup=2, num_iters=10):
-    """Benchmark eagerly (no JIT — data-dependent control flow)."""
+    """Benchmark with JIT."""
     import time
     inputs = create_inputs()
+    fn = jax.jit(workload)
     # Warmup
     for _ in range(num_warmup):
-        out = workload(*inputs)
+        out = fn(*inputs)
         out.block_until_ready()
     times = []
     for _ in range(num_iters):
         t0 = time.perf_counter()
-        out = workload(*inputs)
+        out = fn(*inputs)
         out.block_until_ready()
         times.append(time.perf_counter() - t0)
     import numpy as np
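
A minimal sketch of how the new jittable grouped matmul could be sanity-checked against the eager reference it replaces, assuming the patched 11p_Megablox_GMM/baseline.py is importable as a module named baseline and using the uniform group sizes produced by create_inputs(); the import path, the helper name eager_reference, and the loose tolerance are assumptions for illustration, not part of the patch.

# Sketch: compare the jittable GMM workload with the removed eager reference.
import jax
import jax.numpy as jnp

import baseline  # hypothetical import path for the patched GMM baseline


def eager_reference(lhs, rhs, group_sizes):
    # The pre-patch, data-dependent-slicing reference, run eagerly.
    start = 0
    out = []
    for i, size in enumerate(group_sizes):
        size = int(size)
        out.append(
            jax.lax.dot(
                lhs[start:start + size, :],
                rhs[i, :, :],
                preferred_element_type=jnp.float32,
            )
        )
        start += size
    return jnp.concatenate(out, axis=0)


lhs, rhs, group_sizes, max_expert_size = baseline.create_inputs()
jitted = jax.jit(baseline.workload, static_argnums=(3,))
got = jitted(lhs, rhs, group_sizes, max_expert_size).astype(jnp.float32)
want = eager_reference(lhs, rhs, group_sizes)
# The jittable workload stores results in lhs.dtype (bfloat16 here), so expect
# rounding differences rather than exact equality with the float32 reference.
print("max abs diff:", jnp.max(jnp.abs(got - want)))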