12 changes: 12 additions & 0 deletions examples/models/llama/BUCK
@@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test,
    ],
)

fbcode_target(_kind = runtime.python_test,
    name = "attention_sink_ring_buffer_test",
    srcs = [
        "source_transformation/test_attention_sink_ring_buffer.py",
    ],
    supports_static_listing = False,
    deps = [
        "//caffe2:torch",
        ":export_library",
    ],
)

fbcode_target(_kind = runtime.python_test,
    name = "quantized_sdpa_source_transform_test",
    srcs = [
23 changes: 16 additions & 7 deletions examples/models/llama/attention.py
@@ -278,7 +278,7 @@ def __init__(
[0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have
[8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the
current step still has access to [pos - sliding_window_size, pos] tokens.

To make sure we don't over-attend, i.e. we don't have pos = 5
attend to pos = 1, the mask calculation has to account for the sliding window
size.
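As a reading aid for the rule described in this docstring, here is a minimal sketch of a sliding-window causal mask. The helper name, the position tensors, and window_size = 4 are assumptions for illustration only and are not part of this diff.

import torch

# Illustrative only: a query at q_pos may attend to kv_pos iff kv_pos <= q_pos
# and kv_pos lies within the last `window_size` positions.
def sliding_window_mask(
    q_positions: torch.Tensor, kv_positions: torch.Tensor, window_size: int
) -> torch.Tensor:
    q = q_positions.unsqueeze(1)       # [num_queries, 1]
    kv = kv_positions.unsqueeze(0)     # [1, num_keys]
    causal = kv <= q                   # never attend to the future
    in_window = kv > q - window_size   # never attend beyond the sliding window
    return causal & in_window          # True where attention is allowed

# With a hypothetical window_size = 4, pos = 5 may attend to pos 2..5 but not pos 1.
allowed = sliding_window_mask(torch.tensor([5]), torch.arange(9), window_size=4)
assert not allowed[0, 1] and allowed[0, 2]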
@@ -486,21 +486,30 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
if self.enable_dynamic_shape:
is_ring = getattr(self.kv_cache, "is_ring_buffer", False)
if is_ring:
# Ring buffer models: positions can exceed max_context_len.
# The ring buffer handles wrapping via modular arithmetic.
# The causal mask is computed dynamically from cache_positions,
# so we don't use the pre-computed self.mask here.
k, v = self.kv_cache.update(input_pos, k, v)
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
start_pos, seqlen
)
elif self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = self.mask.narrow(0, start_pos, seq_length)
k, v = self.kv_cache.update(input_pos, k, v)
else:
# mask is always 2D
attn_mask = self.mask[input_pos]
k, v = self.kv_cache.update(input_pos, k, v)
if getattr(self.kv_cache, "is_ring_buffer", False):
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
return self.wo(output), None
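The ring-buffer branch above defers mask construction to create_causal_mask_for_ring_buffer. As a reading aid, a minimal sketch of how such a mask could be derived from the logical position stored in each cache slot; the cache_positions tensor and the additive-mask convention are assumptions for illustration and may not match the actual KVCacheWithAttentionSink implementation.

import torch

def ring_buffer_causal_mask_sketch(
    cache_positions: torch.Tensor,  # logical position held in each cache slot, -1 for empty slots
    start_pos: int,                 # logical position of the first query token in this step
    seqlen: int,                    # number of query tokens
) -> torch.Tensor:
    # A slot may be attended to iff it holds a real token (not an empty slot)
    # whose logical position does not lie in the query's future.
    q_pos = torch.arange(start_pos, start_pos + seqlen).unsqueeze(1)  # [seqlen, 1]
    attend = (cache_positions >= 0) & (cache_positions.unsqueeze(0) <= q_pos)
    # Return an additive float mask (0 = attend, -inf = masked), as SDPA expects.
    mask = torch.full(attend.shape, float("-inf"))
    mask[attend] = 0.0
    return mask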

31 changes: 31 additions & 0 deletions examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  # Attention Sink: "sink_size,window_size,eviction_batch_size"
  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
  # window_size=124: sliding window size
  # eviction_batch_size=1: evict 1 token at a time
  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
  use_attention_sink: "4,124,1"

export:
  # max_context_length controls the RoPE frequency table size.
  # It must be >= sink_size + window_size (128), but larger values are
  # recommended to support generation beyond the sliding window.
  # The model default (e.g., 8192 or 131072) is typically used if not specified.
  # For testing, we use the model's default by not setting this explicitly.

quantization:
  qmode: 8da4w
  group_size: 128
  embedding_quantize: 4,32

backend:
  xnnpack:
    enabled: True
    extended_ops: True
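The comments in this config encode the cache-size arithmetic for the use_attention_sink string. A small sketch of that parsing and arithmetic follows; the helper name is hypothetical, and the real export pipeline may compute this elsewhere.

# Hypothetical helper mirroring the config comments above.
def attention_sink_cache_size(use_attention_sink: str) -> int:
    sink_size, window_size, eviction_batch_size = (
        int(v) for v in use_attention_sink.split(",")
    )
    # Per the config comments: KV cache size = sink_size + window_size * 2
    return sink_size + window_size * 2

assert attention_sink_cache_size("4,124,1") == 252  # 4 + 124 * 2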
29 changes: 29 additions & 0 deletions examples/models/llama/config/llama_attention_sink_noxnn.yaml
@@ -0,0 +1,29 @@
base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  # Attention Sink: "sink_size,window_size,eviction_batch_size"
  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
  # window_size=124: sliding window size
  # eviction_batch_size=1: evict 1 token at a time
  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
  use_attention_sink: "4,124,1"

export:
  max_seq_length: 252
  max_context_length: 512

# Quantization enabled for this test
quantization:
  qmode: 8da4w
  group_size: 128
  embedding_quantize: 4,32

# No XNNPACK for this test
backend:
  xnnpack:
    enabled: False
25 changes: 25 additions & 0 deletions examples/models/llama/config/llama_attention_sink_xnnpack.yaml
@@ -0,0 +1,25 @@
base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  use_attention_sink: "4,124,1"

export:
  max_seq_length: 252
  max_context_length: 512

# No quantization
# quantization:
#   qmode: 8da4w
#   group_size: 128
#   embedding_quantize: 4,32

# XNNPACK enabled
backend:
  xnnpack:
    enabled: True
    extended_ops: True
41 changes: 41 additions & 0 deletions examples/models/llama/config/llama_runner_attention_sink.yaml
@@ -0,0 +1,41 @@
##
## Runner-side Attention Sink configuration
##
## This uses KVCacheWithAttentionSink (model-side) together with
## the runner's AttentionSinkIOManager for position bookkeeping.
##
## Key behavior:
## - Model has KVCacheWithAttentionSink which preserves sink tokens and
##   uses a ring buffer for the sliding window (is_ring_buffer=True)
## - Runner's AttentionSinkIOManager tracks logical position and allows
##   generation to continue past max_context_len
## - KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
##

base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  # Attention Sink: "sink_size,window_size,eviction_batch_size"
  # sink_size=4, window_size=124, eviction_batch_size=1
  # Max Context (Buffer) = 4 + 1 * 124 = 128
  use_attention_sink: "4,124,1"

export:
  # max_seq_length for single prefill chunk
  max_context_length: 128
  max_seq_length: 128

quantization:
  qmode: 8da4w
  group_size: 128
  embedding_quantize: 4,32

backend:
  xnnpack:
    enabled: True
    extended_ops: True
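As a reading aid for the behavior described in the header comments of this config, a minimal sketch of how a logical token position could map to a physical cache slot when sink tokens stay pinned and the remaining slots form a ring buffer. The function and its 248-slot ring assumption (252 total slots minus 4 sink slots) are illustrative only and are not the AttentionSinkIOManager API.

def physical_slot(logical_pos: int, sink_size: int, ring_len: int) -> int:
    # Sink tokens keep their original slots; every later token wraps around
    # the remaining `ring_len` slots via modular arithmetic.
    if logical_pos < sink_size:
        return logical_pos
    return sink_size + (logical_pos - sink_size) % ring_len

# With sink_size=4 and a 248-slot ring, logical position 300 lands in
# slot 4 + (300 - 4) % 248 = 52, while positions 0..3 stay pinned.
assert physical_slot(300, 4, 248) == 52
assert physical_slot(2, 4, 248) == 2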
10 changes: 9 additions & 1 deletion examples/models/llama/eval_llama_lib.py
@@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py

Updated for the ring-buffer-based attention sink implementation.
"""
# Convert args to LlmConfig
from executorch.extension.llm.export.config.llm_config import LlmConfig
@@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert llm_config.export.max_seq_length == sink_size + window_size
# For the ring buffer implementation, the cache size is sink_size + window_size * 2
# max_context_length should be >= sink_size + window_size (for RoPE frequencies)
# but can be larger to support extended generation
assert llm_config.export.max_context_length >= sink_size + window_size, (
f"max_context_length ({llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)