Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions JAXBench/benchmark/11p_Megablox_GMM/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
and multiply with that expert's weight matrix. Core primitive for MoE layers.
From JAX experimental pallas ops (reference_gmm).

Not jit-compatible: uses data-dependent slicing on group_sizes.
"""

import jax
Expand All @@ -21,8 +20,6 @@
'seq_len': 4096,
}

_skip_jit = True


def create_inputs(dtype=jnp.bfloat16):
key = jax.random.key(42)
Expand All @@ -38,28 +35,64 @@ def create_inputs(dtype=jnp.bfloat16):
lhs = lhs.astype(jnp.bfloat16).astype(dtype)
rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit)
rhs = rhs.astype(jnp.bfloat16).astype(dtype)
tokens_per_expert = M // G
group_sizes = jnp.full((G,), tokens_per_expert, dtype=jnp.int32)
return lhs, rhs, group_sizes
max_expert_size = M // G
group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32)
return lhs, rhs, group_sizes, max_expert_size


def workload(lhs, rhs, group_sizes, max_expert_size):
    """Jittable grouped matmul using static shapes and masking.

    For each group ``i`` the rows ``lhs[start_i : start_i + group_sizes[i]]``
    are multiplied with ``rhs[i]``. To stay traceable under ``jax.jit`` every
    group is read with a fixed-size window of ``max_expert_size`` rows and the
    rows that do not belong to the group are masked to zero before being
    written back.

    Args:
        lhs: 2-D array of shape ``(M, K)`` holding the rows of all groups,
            laid out back to back.
        rhs: 3-D array of shape ``(G, K, N)``, one weight matrix per group.
        group_sizes: 1-D integer array of length ``G`` with the number of rows
            of ``lhs`` owned by each group.
        max_expert_size: Python int, an upper bound on every entry of
            ``group_sizes``. It fixes slice shapes, so it must be static
            (e.g. passed via ``static_argnums`` when jitting).

    Returns:
        Array of shape ``(M, N)`` with dtype ``lhs.dtype``. Rows past
        ``sum(group_sizes)`` are zero.
    """
    num_groups = rhs.shape[0]
    num_rows, k_dim = lhs.shape
    n_dim = rhs.shape[2]

    # Exclusive cumulative sum: first row of each group inside lhs.
    group_starts = jnp.cumsum(group_sizes) - group_sizes

    # dynamic_slice clamps its start index so the window stays in bounds.
    # Without padding, a group that starts closer than max_expert_size rows
    # to the end of lhs would silently read rows of earlier groups. Padding
    # lhs the same way the output buffer is padded keeps every window
    # anchored at the group's true start.
    lhs_padded = jnp.pad(lhs, ((0, max_expert_size), (0, 0)))
    out_padded = jnp.zeros((num_rows + max_expert_size, n_dim), dtype=lhs.dtype)

    # Row index inside a window; loop-invariant, so built once.
    window_rows = jax.lax.broadcasted_iota(
        jnp.int32, (max_expert_size, n_dim), 0
    )

    def body_fun(acc, group_idx):
        start = group_starts[group_idx]
        count = group_sizes[group_idx]

        # Fixed-size window starting at this group's first row.
        window_lhs = jax.lax.dynamic_slice(
            lhs_padded, (start, 0), (max_expert_size, k_dim)
        )
        product = jax.lax.dot(
            window_lhs, rhs[group_idx, :, :], preferred_element_type=jnp.float32
        )

        # Rows at or beyond `count` belong to later groups (or padding).
        product = jnp.where(window_rows < count, product, 0.0)

        # Read-modify-write so rows already written by earlier groups that
        # fall inside this window are preserved.
        current = jax.lax.dynamic_slice(
            acc, (start, 0), (max_expert_size, n_dim)
        )
        acc = jax.lax.dynamic_update_slice(
            acc, current + product.astype(acc.dtype), (start, 0)
        )
        return acc, None

    out_padded, _ = jax.lax.scan(body_fun, out_padded, jnp.arange(num_groups))

    return out_padded[:num_rows, :]


def get_flops():
Expand All @@ -73,17 +106,20 @@ def get_flops():


def benchmark(num_warmup=2, num_iters=10):
"""Benchmark eagerly (no JIT — data-dependent control flow)."""
"""Benchmark with JIT."""
import time
inputs = create_inputs()

fn = jax.jit(workload, static_argnums=(3,))

# Warmup
for _ in range(num_warmup):
out = workload(*inputs)
out = fn(*inputs)
out.block_until_ready()
times = []
for _ in range(num_iters):
t0 = time.perf_counter()
out = workload(*inputs)
out = fn(*inputs)
out.block_until_ready()
times.append(time.perf_counter() - t0)
import numpy as np
Expand Down
42 changes: 24 additions & 18 deletions JAXBench/benchmark/7p_Ragged_Paged_Attention/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Processes each sequence independently with variable-length queries and paged KV cache.
From JAX experimental pallas ops (ref_ragged_paged_attention).

Not jit-compatible: uses data-dependent slicing.
"""

import math
Expand All @@ -27,8 +26,6 @@
'pages_per_seq': 256,
}

_skip_jit = True


def create_inputs(dtype=jnp.bfloat16):
key = jax.random.key(42)
Expand Down Expand Up @@ -58,10 +55,9 @@ def create_inputs(dtype=jnp.bfloat16):


def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
"""Reference ragged paged attention from upstream JAX.
"""Ragged paged attention using static shapes and masking for JIT compatibility.

Processes each sequence independently with data-dependent slicing.
Must be run eagerly (not under jax.jit).
Processes each sequence independently, avoiding data-dependent slicing.
"""
sm_scale = 1.0 / math.sqrt(CONFIG['head_dim'])
mask_value = DEFAULT_MASK_VALUE
Expand All @@ -70,17 +66,21 @@ def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
num_q_heads = queries.shape[1]
num_query_per_kv = num_q_heads // num_kv_heads

max_seqs = CONFIG['max_num_seqs']
tokens_per_seq = CONFIG['max_num_batched_tokens'] // max_seqs

outputs = []
for i in range(num_seqs[0]):
for i in range(max_seqs):
q_start = cu_q_lens[i]
q_end = cu_q_lens[i + 1]
q_len = q_end - q_start
kv_len = kv_lens[i]
indices = page_indices[i]

q = queries[q_start:q_end]
k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
q = jax.lax.dynamic_slice(
queries, (q_start, 0, 0), (tokens_per_seq, num_q_heads, head_dim)
)

k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)
v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)

k = jnp.repeat(k, num_query_per_kv, axis=1)
v = jnp.repeat(v, num_query_per_kv, axis=1)
Expand All @@ -90,15 +90,20 @@ def workload(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs):
)
attn *= sm_scale

q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
q_span = (kv_len - tokens_per_seq) + jax.lax.broadcasted_iota(
jnp.int32, attn.shape, 1
)
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
mask = q_span < kv_span
attn += jnp.where(mask, mask_value, 0.0)

mask = (q_span < kv_span) | (kv_span >= kv_len)
attn = jnp.where(mask, mask_value, attn)

attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)

is_valid = i < num_seqs[0]
out = jnp.where(is_valid, out, 0.0)

outputs.append(out)

return jnp.concatenate(outputs, axis=0)
Expand All @@ -115,17 +120,18 @@ def get_flops():


def benchmark(num_warmup=2, num_iters=10):
"""Benchmark eagerly (no JIT — data-dependent control flow)."""
"""Benchmark with JIT."""
import time
inputs = create_inputs()
fn = jax.jit(workload)
# Warmup
for _ in range(num_warmup):
out = workload(*inputs)
out = fn(*inputs)
out.block_until_ready()
times = []
for _ in range(num_iters):
t0 = time.perf_counter()
out = workload(*inputs)
out = fn(*inputs)
out.block_until_ready()
times.append(time.perf_counter() - t0)
import numpy as np
Expand Down