12 changes: 12 additions & 0 deletions examples/models/llama/BUCK
@@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test,
],
)

fbcode_target(_kind = runtime.python_test,
name = "attention_sink_ring_buffer_test",
srcs = [
"source_transformation/test_attention_sink_ring_buffer.py",
],
supports_static_listing = False,
deps = [
"//caffe2:torch",
":export_library",
],
)

fbcode_target(_kind = runtime.python_test,
name = "quantized_sdpa_source_transform_test",
srcs = [
31 changes: 31 additions & 0 deletions examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
base:
metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
use_kv_cache: True
dtype_override: fp32
enable_dynamic_shape: True
# Attention Sink: "sink_size,window_size,eviction_batch_size"
# sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
# window_size=124: 滑动窗口大小 (sliding window size)
# eviction_batch_size=1: 每次驱逐 1 个 token (evict 1 token at a time)
Comment on lines +11 to +12
Copilot AI Feb 5, 2026

The comment contains Chinese characters ("滑动窗口大小" meaning "sliding window size" and "每次驱逐 1 个 token" meaning "evict 1 token each time"). Comments should be in English for consistency with the rest of the codebase and to ensure all team members can understand the documentation.

Suggested change:
- # window_size=124: 滑动窗口大小
- # eviction_batch_size=1: 每次驱逐 1 个 token
+ # window_size=124: sliding window size
+ # eviction_batch_size=1: evict 1 token at a time

Comment on lines +11 to +12
Copilot AI Feb 5, 2026

Chinese characters are used in comments. Comments should be in English for consistency with the rest of the codebase. Please translate:

  • Line 11: "# window_size=124: 滑动窗口大小" should be "# window_size=124: Sliding window size"
  • Line 12: "# eviction_batch_size=1: 每次驱逐 1 个 token" should be "# eviction_batch_size=1: Evict 1 token at a time"
Suggested change:
- # window_size=124: 滑动窗口大小
- # eviction_batch_size=1: 每次驱逐 1 个 token
+ # window_size=124: Sliding window size
+ # eviction_batch_size=1: Evict 1 token at a time

# KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
use_attention_sink: "4,124,1"

export:
# max_context_length controls the RoPE frequency table size.
# It must be >= sink_size + window_size (128), but larger values are
# recommended to support generation beyond the sliding window.
# The model default (e.g., 8192 or 131072) is typically used if not specified.
# For testing, we use the model's default by not setting this explicitly.

quantization:
qmode: 8da4w
group_size: 128
embedding_quantize: 4,32

backend:
xnnpack:
enabled: True
extended_ops: True
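
The config comments above spell out the cache-size arithmetic for the "4,124,1" setting. As a quick sanity check, here is a small self-contained Python sketch (editor's illustration, not part of the ExecuTorch codebase; the helper names are made up) that parses the comma-separated use_attention_sink string and reproduces the numbers quoted in the config:

# Illustrative only: parse "sink_size,window_size,eviction_batch_size" (e.g. "4,124,1")
# and derive the quantities described in the config comments above.
from typing import NamedTuple

class AttentionSinkParams(NamedTuple):
    sink_size: int
    window_size: int
    eviction_batch_size: int

    @property
    def kv_cache_size(self) -> int:
        # Ring-buffer KV cache size = sink_size + window_size * 2 (4 + 124*2 = 252 here).
        return self.sink_size + self.window_size * 2

    @property
    def min_context_length(self) -> int:
        # max_context_length must be at least sink_size + window_size (128 here).
        return self.sink_size + self.window_size

def parse_attention_sink(spec: str) -> AttentionSinkParams:
    sink_size, window_size, eviction_batch_size = (int(v) for v in spec.split(","))
    return AttentionSinkParams(sink_size, window_size, eviction_batch_size)

params = parse_attention_sink("4,124,1")
assert params.kv_cache_size == 252
assert params.min_context_length == 128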
10 changes: 9 additions & 1 deletion examples/models/llama/eval_llama_lib.py
@@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
Evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py

Updated for the ring-buffer based attention sink implementation.
"""
# Convert args to LlmConfig
from executorch.extension.llm.export.config.llm_config import LlmConfig
@@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert llm_config.export.max_seq_length == sink_size + window_size
# For the ring buffer implementation, the cache size is sink_size + window_size * 2
# max_context_length should be >= sink_size + window_size (for RoPE frequencies)
# but can be larger to support extended generation
assert llm_config.export.max_context_length >= sink_size + window_size, (
f"max_context_length ({llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
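
To make the eviction behavior referenced in this hunk concrete, the following toy sketch (editor's illustration only; the actual implementation uses a ring buffer of size sink_size + window_size * 2 rather than a Python list) shows which token positions remain in the cache under a keep-the-sinks-plus-recent-window policy:

# Toy model of attention-sink eviction (illustrative; not the ring-buffer code in this PR).
def cached_positions(num_generated: int, sink_size: int = 4,
                     window_size: int = 124, eviction_batch_size: int = 1) -> list[int]:
    """Return the token positions still held in the cache after generating
    num_generated tokens, keeping the sink tokens plus the most recent window."""
    cache: list[int] = []
    for pos in range(num_generated):
        cache.append(pos)
        # Evict the oldest non-sink tokens in batches once the window overflows.
        while len(cache) > sink_size + window_size:
            del cache[sink_size:sink_size + eviction_batch_size]
    return cache

positions = cached_positions(1000)
assert positions[:4] == [0, 1, 2, 3]            # sink tokens are never evicted
assert positions[4:] == list(range(876, 1000))  # plus the most recent 124 tokens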
17 changes: 16 additions & 1 deletion examples/models/llama/model.py
@@ -218,7 +218,22 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.llm_config.export.max_context_length == sink_size + window_size
# max_context_length must be >= sink_size + window_size to have enough RoPE frequencies
# A larger max_context_length is allowed (and recommended) to support generation beyond
# the sliding window size.
assert self.llm_config.export.max_context_length >= sink_size + window_size, (
f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
f"sink_size + window_size ({sink_size + window_size})"
)

# IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation
# positions, not just the cache size. Override the model's max_context_len to use
# a larger value that supports extended generation.
# We use model_args.max_context_len which was set from export.max_context_length
# but for RoPE we need the full generation length capability.
# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
default_rope_length = max(131072, model_args.max_context_len)
Comment on lines +234 to +235
Copilot AI Feb 5, 2026

The hardcoded value of 131072 (128k) seems arbitrary and may not be appropriate for all model configurations. This could cause excessive memory usage for RoPE frequency tables in models that don't need such large context. Consider making this configurable through the llm_config or deriving it from model-specific parameters rather than using a hardcoded default.

Suggested change:
- # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
- default_rope_length = max(131072, model_args.max_context_len)
+ # Derive the RoPE max context length from configuration rather than using a fixed
+ # default. Use the larger of export.max_context_length and the original
+ # model_args.max_context_len.
+ default_rope_length = max(
+     self.llm_config.export.max_context_length,
+     model_args.max_context_len,
+ )

model_args.max_context_len = default_rope_length
Comment on lines +230 to +236
Copilot AI Feb 5, 2026

Hardcoding a default RoPE length of 131072 (128k) at line 235 could be problematic for models that don't support such a large context or when memory is constrained. This value appears to be arbitrary and could cause out-of-memory issues during RoPE frequency table initialization. Consider either: (1) using a more reasonable default that aligns with common model capabilities, (2) making this configurable, or (3) deriving it from the model's actual capabilities. The comment mentions "Llama 3.2 models" specifically, but this code affects all models using attention sink.

Suggested change:
- # positions, not just the cache size. Override the model's max_context_len to use
- # a larger value that supports extended generation.
- # We use model_args.max_context_len which was set from export.max_context_length
- # but for RoPE we need the full generation length capability.
- # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
- default_rope_length = max(131072, model_args.max_context_len)
- model_args.max_context_len = default_rope_length
+ # positions, not just the cache size. We rely on model_args.max_context_len, which is
+ # set from export.max_context_length and already validated above to be large enough.
+ # Ensure it is at least sink_size + window_size without forcing an arbitrary larger default
+ # that could cause excessive memory usage on models with smaller contexts.
+ min_required_context = sink_size + window_size
+ if model_args.max_context_len < min_required_context:
+     model_args.max_context_len = min_required_context


self.model_ = enable_attention_sink(
module=self.model_,
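
The two review comments above debate how to size the RoPE frequency table. The hedged sketch below (the function name and parameters are hypothetical, not part of model.py) simply restates the two strategies side by side: the PR's 131072 floor versus the review suggestion of deriving the length purely from configuration, with sink_size + window_size as the hard minimum in both cases:

# Hypothetical helper contrasting the two RoPE-length strategies discussed above.
# Neither name exists in model.py; this only restates logic quoted in the diff and review.
def choose_rope_context_len(
    export_max_context_length: int,
    model_max_context_len: int,
    sink_size: int,
    window_size: int,
    use_128k_floor: bool = True,
) -> int:
    min_required = sink_size + window_size  # must cover the sink plus the sliding window
    if use_128k_floor:
        # PR behavior: at least 131072, or the model's original max if that is larger.
        chosen = max(131072, model_max_context_len)
    else:
        # Review suggestion: derive from configuration instead of a hardcoded default.
        chosen = max(export_max_context_length, model_max_context_len)
    return max(chosen, min_required)

# With the config above (sink_size=4, window_size=124):
assert choose_rope_context_len(8192, 8192, 4, 124) == 131072
assert choose_rope_context_len(8192, 8192, 4, 124, use_128k_floor=False) == 8192

Whether a 128k RoPE table is an acceptable memory cost is the open question in the review; the sketch only makes the trade-off explicit.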