[TEST] Attention sink #17252
base: main
Changes from all commits: c954f27, 34a937f, 83f437a, 97f9910, 075d521, fe66e74, 04fadb5, 97ba715, e4ccab4, bbca983, fec5eb9
New config file (@@ -0,0 +1,31 @@):

```yaml
base:
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

model:
  use_sdpa_with_kv_cache: True  # Now supported! We set use_attention_mask=True on SDPACustom
  use_kv_cache: True
  dtype_override: fp32
  enable_dynamic_shape: True
  # Attention Sink: "sink_size,window_size,eviction_batch_size"
  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
  # window_size=124: 滑动窗口大小
  # eviction_batch_size=1: 每次驱逐 1 个 token
```
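For reference, the `metadata` value is a JSON string of tokenizer control tokens that the runtime reads back at load time. A minimal sketch of how a consumer might parse it (the parsing code is illustrative; only the field names and token ids come from the config above):

```python
import json

# 128000 is Llama 3's BOS token id; 128009 and 128001 are its stop token ids.
metadata = json.loads('{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}')
bos_id = metadata["get_bos_id"]         # prepended to every prompt
eos_ids = set(metadata["get_eos_ids"])  # generation stops on any of these

assert bos_id == 128000 and eos_ids == {128009, 128001}
```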
Comment on lines +11 to +12

Suggested change:

```diff
-  # window_size=124: 滑动窗口大小
-  # eviction_batch_size=1: 每次驱逐 1 个 token
+  # window_size=124: Sliding window size
+  # eviction_batch_size=1: Evict 1 token at a time
```
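To make the three parameters concrete: the first `sink_size` tokens are pinned in the KV cache, the remaining slots form a sliding window, and once the window overflows the oldest `eviction_batch_size` tokens are dropped at a time. A minimal sketch of that policy (illustrative only, not the PR's actual KV-cache implementation):

```python
# StreamingLLM-style eviction: pin the sink, slide the window.
def evict(cached, sink_size=4, window_size=124, eviction_batch_size=1):
    sink, window = cached[:sink_size], cached[sink_size:]
    while len(window) > window_size:           # window overflowed
        window = window[eviction_batch_size:]  # drop the oldest batch
    return sink + window

# sink_size + window_size = 128, so the cache never holds more than 128
# entries, and the first 4 tokens (BOS + system prompt) are never evicted.
cache = evict(list(range(200)))
assert len(cache) == 128 and cache[:4] == [0, 1, 2, 3]
```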
The accompanying Python change:

```diff
@@ -218,7 +218,22 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         window_size = int(attention_sink_params[1])
         eviction_batch_size = int(attention_sink_params[2])

-        assert self.llm_config.export.max_context_length == sink_size + window_size
+        # max_context_length must be >= sink_size + window_size to have enough RoPE
+        # frequencies. A larger max_context_length is allowed (and recommended) to
+        # support generation beyond the sliding window size.
+        assert self.llm_config.export.max_context_length >= sink_size + window_size, (
+            f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
+            f"sink_size + window_size ({sink_size + window_size})"
+        )
+
+        # IMPORTANT: For attention sink, we need RoPE frequencies for all possible
+        # generation positions, not just the cache size. Override the model's
+        # max_context_len to use a larger value that supports extended generation.
+        # We use model_args.max_context_len, which was set from export.max_context_length,
+        # but for RoPE we need the full generation length capability.
+        # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
+        default_rope_length = max(131072, model_args.max_context_len)
+        model_args.max_context_len = default_rope_length
```
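Why the RoPE override matters: rotary position embeddings are precomputed as a table with one row per absolute position, so the table must cover every position the model may reach during generation, not just the `sink_size + window_size` entries held in the KV cache. A sketch of the standard precomputation (textbook rotary formula; the PR's actual helper may differ):

```python
import torch

# One (cos, sin) row per absolute position; indexing past the table fails.
def precompute_rope_freqs(head_dim: int, max_positions: int, theta: float = 10000.0):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_positions).float(), inv_freq)
    return torch.cos(angles), torch.sin(angles)

# The cache only ever holds 4 + 124 = 128 tokens, but generating a 4096-token
# stream needs 4096 rows of frequencies, hence the larger max_context_len.
cos, sin = precompute_rope_freqs(head_dim=64, max_positions=4096)
assert cos.shape == (4096, 32)
```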
Comment on lines +234 to +235

Suggested change:

```diff
-        # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
-        default_rope_length = max(131072, model_args.max_context_len)
+        # Derive the RoPE max context length from configuration rather than using a fixed
+        # default. Use the larger of export.max_context_length and the original
+        # model_args.max_context_len.
+        default_rope_length = max(
+            self.llm_config.export.max_context_length,
+            model_args.max_context_len,
+        )
```
Copilot AI · Feb 5, 2026
Hardcoding a default RoPE length of 131072 (128k) at line 235 could be problematic for models that don't support such a large context or when memory is constrained. This value appears to be arbitrary and could cause out-of-memory issues during RoPE frequency table initialization. Consider either: (1) using a more reasonable default that aligns with common model capabilities, (2) making this configurable, or (3) deriving it from the model's actual capabilities. The comment mentions "Llama 3.2 models" specifically, but this code affects all models using attention sink.
Suggested change:

```diff
-        # positions, not just the cache size. Override the model's max_context_len to use
-        # a larger value that supports extended generation.
-        # We use model_args.max_context_len, which was set from export.max_context_length,
-        # but for RoPE we need the full generation length capability.
-        # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
-        default_rope_length = max(131072, model_args.max_context_len)
-        model_args.max_context_len = default_rope_length
+        # positions, not just the cache size. We rely on model_args.max_context_len, which is
+        # set from export.max_context_length and already validated above to be large enough.
+        # Ensure it is at least sink_size + window_size without forcing an arbitrary larger
+        # default that could cause excessive memory usage on models with smaller contexts.
+        min_required_context = sink_size + window_size
+        if model_args.max_context_len < min_required_context:
+            model_args.max_context_len = min_required_context
```
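A quick back-of-envelope supports the memory concern. Assuming fp32 cos and sin tables and head_dim=128 (the Llama 3.2 3B head dimension; smaller variants use 64), the hardcoded 128k default costs an order of magnitude more than a config-derived table:

```python
# Rough size of the precomputed RoPE tables (cos + sin), fp32, head_dim=128.
HEAD_DIM, BYTES = 128, 4

def rope_table_bytes(max_positions: int) -> int:
    return 2 * max_positions * (HEAD_DIM // 2) * BYTES  # cos and sin tables

print(rope_table_bytes(131072) / 2**20)  # 64.0 MiB with the hardcoded 128k default
print(rope_table_bytes(8192) / 2**20)    #  4.0 MiB when derived from an 8k config
```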
The comment contains Chinese characters ("滑动窗口大小" meaning "sliding window size" and "每次驱逐 1 个 token" meaning "evict 1 token each time"). Comments should be in English for consistency with the rest of the codebase and to ensure all team members can understand the documentation.