Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17271
Note: Links to docs will display an error until the docs builds have been completed.
❌ 40 New Failures, 1 Pending, 3 Unrelated Failures as of commit 42e5e7a with merge base 1eb3f9d.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@larryliu0820 I think it's not totally adhd after 1000+ tokens
Pull request overview
This PR adds a ring-buffer–based Attention Sink implementation for Llama exports and attempts to wire up runner-side support so generation can continue beyond the KV-cache capacity while maintaining correct masking/positioning.
Changes:
- Implement ring-buffer Attention Sink KV-cache + masking in Llama source transformation and update attention execution to use dynamic masks for ring-buffer caches (a conceptual sketch of the ring-buffer indexing follows this list).
- Add a new runner IOManager (AttentionSinkIOManager) and plumb it into runner build targets, plus update prompt-length handling in the text runner.
- Add export updates and new configs/tests for Attention Sink (including mutable buffer initialization for cache_positions).
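
To make the ring-buffer scheme concrete, here is a minimal, self-contained sketch of the indexing and masking idea described above. It assumes a `cache_positions` buffer that records which logical position occupies each physical slot (with -1 marking empty slots); the function names and shapes are illustrative, not the PR's actual API.

    # Illustrative sketch only, not the PR's API. Assumes a 1-D `cache_positions`
    # tensor of length `cache_size`, initialized to -1 (empty slots).
    import torch

    def write_slots(start_pos: int, seq_len: int, sink_size: int, cache_size: int) -> torch.Tensor:
        """Map logical positions to physical slots: the first `sink_size` positions
        stay pinned; later positions wrap around the remaining ring region."""
        pos = start_pos + torch.arange(seq_len)
        ring = sink_size + (pos - sink_size) % (cache_size - sink_size)
        return torch.where(pos < sink_size, pos, ring)

    def dynamic_mask(cache_positions: torch.Tensor, query_pos: int) -> torch.Tensor:
        """A slot is attendable if it holds a real token (>= 0) at or before the
        query's logical position (causal); everything else is masked out."""
        valid = (cache_positions >= 0) & (cache_positions <= query_pos)
        zeros = torch.zeros_like(cache_positions, dtype=torch.float)
        return torch.where(valid, zeros, torch.full_like(zeros, float("-inf")))

For example, with sink_size=4 and cache_size=128, logical position 130 maps to slot 4 + (130 - 4) % 124 = 6, while slots 0-3 keep the original sink tokens.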
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| test_attention_sink.md | Adds local build/export/run notes for attention sink testing. |
| extension/llm/runner/text_llm_runner.cpp | Changes prompt validation to use max_seq_len and adjusts max-new-token resolution behavior. |
| extension/llm/runner/targets.bzl | Links runner to the new attention-sink IOManager target. |
| extension/llm/runner/llm_runner_helper.cpp | Selects AttentionSinkIOManager based on module metadata keys. |
| extension/llm/runner/io_manager/targets.bzl | Adds Bazel target for attention_sink_io_manager. |
| extension/llm/runner/io_manager/attention_sink_io_manager.h | Introduces IOManager subclass for attention-sink / “infinite context” bookkeeping. |
| extension/llm/runner/io_manager/attention_sink_io_manager.cpp | Implements pass-through prefill/decode input preparation while tracking logical position. |
| extension/llm/runner/constants.h | Adds metadata keys for attention sink configuration. |
| examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py | Adds comprehensive ring-buffer attention sink unit tests. |
| examples/models/llama/source_transformation/test_attention_sink.py | Removes the previous attention sink test suite. |
| examples/models/llama/source_transformation/sdpa.py | Forces custom_sdpa to use start_pos=0 when an attention mask is provided. |
| examples/models/llama/source_transformation/custom_kv_cache.py | Prevents custom KV-cache replacement from clobbering KVCacheWithAttentionSink. |
| examples/models/llama/source_transformation/attention_sink.py | Reworks attention sink to a torch.export-compatible ring-buffer approach (masking, cache position tracking, cache update). |
| examples/models/llama/model.py | Loosens constraints and increases RoPE table length for attention sink generation. |
| examples/models/llama/export_llama_lib.py | Enables attention masks for custom SDPA when attention sink is enabled; initializes cache_positions mutable buffer. |
| examples/models/llama/eval_llama_lib.py | Updates attention-sink eval assumptions for ring-buffer/cache-size vs RoPE-length. |
| examples/models/llama/config/*.yaml | Adds new attention sink configs (runner-side + xnnpack/noxnn variants). |
| examples/models/llama/attention.py | Uses dynamic ring-buffer masking when KV cache exposes is_ring_buffer. |
| examples/models/llama/BUCK | Registers the new attention sink ring-buffer test. |

    )
    or bool(llm_config.model.use_attention_sink),
    use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
    quantize_kv_cache=llm_config.model.quantize_kv_cache,
This block enables attention-mask mode for custom SDPA when attention sink is on, but attention-sink configs also enable KV-cache quantization in some cases. In the quantized path, SDPACustom is replaced by QuantizedSDPA, which (per sdpa.py) still doesn’t support attention masks and continues to use start_pos for causal masking. That means attention sink + quantized KV cache may still hit the same start_pos >= cache_size validation failure. Consider propagating the use_attention_mask flag into QuantizedSDPA and making its mask path ignore start_pos similarly (or preventing this combination).
Suggested change:

    - quantize_kv_cache=llm_config.model.quantize_kv_cache,
    + # Quantized KV cache currently does not support attention-mask-based
    + # custom SDPA (QuantizedSDPA still relies on start_pos for masking).
    + # To avoid incompatible behavior, disable KV-cache quantization when
    + # attention sink is enabled.
    + quantize_kv_cache=(
    +     False
    +     if getattr(llm_config.model, "use_attention_sink", False)
    +     else llm_config.model.quantize_kv_cache
    + ),

    # Export model
    Take a look at examples/models/llama/README.md

    Check point is in ~/executorch/
Spelling: “Check point” should be “Checkpoint”.
Suggested change:

    - Check point is in ~/executorch/
    + Checkpoint is in ~/executorch/

      // Get max_seq_len for single prefill chunk limit
      int64_t max_seq_len = metadata_.at(kMaxSeqLen);
      int64_t max_context_len = metadata_.at(kMaxContextLen);

      ET_CHECK_OR_RETURN_ERROR(
          num_prompt_tokens >= 1,
          InvalidArgument,
          "Expected at least 1 prompt token");

      ET_CHECK_OR_RETURN_ERROR(
    -     num_prompt_tokens < max_context_len,
    +     num_prompt_tokens <= max_seq_len,
          InvalidArgument,
    -     "num_prompt_tokens %d >= max_context_len %" PRId64
    -     ", Max seq length exceeded - please increase max seq len value in your export script",
    +     "num_prompt_tokens %d > max_seq_len %" PRId64
    +     ", Single prefill chunk too large",
          num_prompt_tokens,
    -     max_context_len);
    +     max_seq_len);
The new prompt-length validation only checks num_prompt_tokens <= max_seq_len and no longer accounts for pos_ / remaining context. If generate() is called multiple times without reset(), pos_ can be non-zero and text_prefiller_->prefill(prompt_tokens, pos_) may exceed the model’s usable context for non-ring-buffer models. Consider keeping the remaining-context check for non-ring-buffer models (e.g., num_prompt_tokens < (max_context_len - pos_)) while retaining the per-prefill-chunk max_seq_len check for ring/attention-sink exports.

    int max_new_tokens =
        config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
resolve_max_new_tokens(max_context_len, num_prompt_tokens) currently ignores the existing pos_, so for non-ring-buffer models it can allow generating past the remaining context budget when pos_ > 0. Previously this was handled by reducing max_context_len by pos_. Consider passing max_context_len - pos_ for non-ring-buffer models (or extending resolve_max_new_tokens to take start_pos).
Suggested change:

    - int max_new_tokens =
    -     config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
    + // For non-ring-buffer models starting at a non-zero position, only the
    + // remaining context window (max_context_len - pos_) is available.
    + int64_t remaining_context_len = max_context_len - pos_;
    + if (remaining_context_len < 0) {
    +   remaining_context_len = 0;
    + }
    + int max_new_tokens =
    +     config.resolve_max_new_tokens(remaining_context_len, num_prompt_tokens);

      // Get method names to check for attention sink metadata
      auto method_names_result = module->method_names();
      if (method_names_result.error() != Error::Ok) {
        ET_LOG(Error, "Failed reading method names for IOManager selection");
        return nullptr;
      }
      const auto& method_names = method_names_result.get();

      // Check if attention sink is enabled via metadata
      bool use_attention_sink = false;
      int64_t sink_size = 4; // Default values
      int64_t window_size = 124;

      if (method_names.count(kUseAttentionSink)) {
        auto get_result = module->get(kUseAttentionSink);
        use_attention_sink = get_result.get().toScalar().to<bool>();
      }
AttentionSinkIOManager selection relies on module methods named use_attention_sink, attention_sink_size, and attention_sink_window_size, but the Llama export metadata loader currently only emits get_max_seq_len, get_max_context_len, etc. (no attention-sink keys). As a result, this branch will never trigger and AttentionSinkIOManager won’t be used. Suggest exporting these metadata methods when llm_config.model.use_attention_sink is set (or using an existing reliable signal) so runner-side selection works.
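
One possible shape of that export-side fix, sketched as a standalone helper rather than the PR's actual exporter code; the function name `add_attention_sink_metadata` is hypothetical, and the key strings mirror the runner constants discussed in this comment. It assumes `metadata` is the dict that export_to_edge turns into constant methods and that the config value is the "sink_size,window_size,eviction_batch_size" string used by these configs.

    # Hypothetical sketch, not the PR's exporter code.
    def add_attention_sink_metadata(metadata: dict, use_attention_sink: str) -> None:
        """Emit the metadata keys the runner-side IOManager selection reads.

        `use_attention_sink` is the config string "sink_size,window_size,eviction_batch_size"
        (e.g. "4,124,1"), or empty when the feature is off.
        """
        if not use_attention_sink:
            return
        sink_size, window_size, _ = (int(x) for x in use_attention_sink.split(","))
        metadata["use_attention_sink"] = True
        metadata["attention_sink_size"] = sink_size
        metadata["attention_sink_window_size"] = window_size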
| """ | ||
| Transform the model to be able to run inference with Attention Sink. | ||
| There mainly three steps: | ||
| - Replace Rope with RopeWithAttentionSink | ||
| - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward | ||
| - Replace Attention's KVCache with KVCacheWithAttentionSink | ||
| - Replace Attention's forward with attention_sink_forward |
The docstring for enable_attention_sink says it replaces Attention’s forward with attention_sink_forward, but _replace_attention explicitly does not replace forward anymore. Update the docstring (and/or remove attention_sink_forward if it’s intentionally unused) to match actual behavior.
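
One illustrative way the docstring could be reworded to match the described behavior (wording is a suggestion, not the PR's text):

    # Illustrative docstring wording only.
    """
    Transform the model to run inference with Attention Sink. Two main steps:
    - Replace Rope with RopeWithAttentionSink
    - Replace Attention's KVCache with KVCacheWithAttentionSink (ring buffer);
      Attention.forward is left unchanged.
    """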

    k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
    v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
    input_pos = torch.tensor([20 + i], dtype=torch.long)
    k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
Variable k_out is not used.
Suggested change:

    - k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
    + self.kv_cache.update(input_pos, k_decode, v_decode)

    k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
    v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
    input_pos = torch.tensor([20 + i], dtype=torch.long)
    k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
Variable v_out is not used.
Suggested change:

    - k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
    + _, _ = self.kv_cache.update(input_pos, k_decode, v_decode)

    def test_cache_positions_consistency(self):
        """Test that cache positions remain consistent during generation."""
        cache_size = 32
Variable cache_size is not used.

    CachePositionsManager,
    KVCache,
    RingKVCache,
Import of 'CachePositionsManager' is not used.
Import of 'RingKVCache' is not used.
Suggested change:

    - CachePositionsManager,
    - KVCache,
    - RingKVCache,
    + KVCache,

        logical_pos_,
        is_cache_full() ? "true" : "false");

    // Pass through to model as-is. The model's KVCacheWithAttentionSink

    k,
    v,
    input_pos[0].item(),
    0,  # start_pos: not used when mask is provided
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 6 comments.

    # IOManager expects these methods to exist returning int.
    # By adding them to metadata, export_to_edge will generate constant methods.
    metadata["get_sink_size"] = sink_params[0]
    metadata["get_window_size"] = sink_params[1]
The metadata added for attention sink uses keys get_sink_size/get_window_size, but the runtime runner expects attention sink metadata under use_attention_sink/attention_sink_size/attention_sink_window_size (see extension/llm/runner/constants.h and llm_runner_helper.cpp). This mismatch prevents the runner from detecting attention sink models. Update the metadata keys (and add an explicit enable flag) to match what the runner reads, or adjust the runner to match these exported keys.
Suggested change:

    - # IOManager expects these methods to exist returning int.
    - # By adding them to metadata, export_to_edge will generate constant methods.
    - metadata["get_sink_size"] = sink_params[0]
    - metadata["get_window_size"] = sink_params[1]
    + # Runtime runner expects these metadata keys:
    + # - "use_attention_sink": bool flag to enable attention sink
    + # - "attention_sink_size": sink size (int)
    + # - "attention_sink_window_size": window size (int)
    + metadata["use_attention_sink"] = True
    + metadata["attention_sink_size"] = sink_params[0]
    + metadata["attention_sink_window_size"] = sink_params[1]

    # Max Context (Buffer) = 4 + 1 * 124 = 128
    use_attention_sink: "4,124,1"

    export:
      # max_seq_length for single prefill chunk
      max_context_len: 128
Export config uses max_context_len, but LlmConfig’s ExportConfig field is max_context_length (see extension/llm/export/config/llm_config.py). As written, this config likely won’t set the intended context length. Also the comment on line 25 (“Max Context (Buffer) = 4 + 1 * 124 = 128”) doesn’t match the earlier cache-size note (sink_size + 2*window_size = 252) and is confusing—please correct it to reflect the actual meaning (RoPE table vs KV cache size).
Suggested change:

    - # Max Context (Buffer) = 4 + 1 * 124 = 128
    - use_attention_sink: "4,124,1"
    - export:
    -   # max_seq_length for single prefill chunk
    -   max_context_len: 128
    + # RoPE/logical max context per step = sink_size + 1 * window_size = 4 + 1 * 124 = 128 tokens
    + use_attention_sink: "4,124,1"
    + export:
    +   # max_seq_length for single prefill chunk
    +   max_context_length: 128

    if indices is None:
        # Calculate write indices
        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
            input_pos, seq_len
        )
KVCacheWithAttentionSink.update() only updates cache_positions when it computes indices internally. When indices are provided (e.g., runner-supplied cache_indices), cache_positions is never updated, so attention masks derived from cache_positions can be stale/incorrect (and may remain at the -1 sentinel). Update cache_positions_manager.cache_positions for the provided indices as well (e.g., index_copy with orig positions derived from input_pos and seq_len).
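
One way that fix could look, sketched as a standalone helper rather than the actual KVCacheWithAttentionSink.update(); the helper name and exact signatures are assumptions, but the idea mirrors the internal path: record which logical positions now occupy the caller-supplied slots so masks derived from cache_positions stay in sync.

    # Hypothetical helper, not the PR's code: refresh cache_positions when write
    # indices are supplied externally (e.g. runner-provided cache_indices).
    import torch

    def refresh_cache_positions(
        cache_positions: torch.Tensor,  # 1-D long tensor of length cache_size
        indices: torch.Tensor,          # physical slots being written, length seq_len
        input_pos: torch.Tensor,        # 1-element tensor: starting logical position
        seq_len: int,
    ) -> None:
        """Record the logical positions that now occupy `indices`, mirroring what
        calculate_positions_and_update_indices does on the internal path."""
        logical = input_pos[0].item() + torch.arange(
            seq_len, dtype=cache_positions.dtype, device=cache_positions.device
        )
        cache_positions.index_copy_(0, indices.view(-1).to(torch.long), logical)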

    max_context_len = sink_size + window_size * 2

    # We update params.max_context_len to reflect the actual buffer size
    # This ensures export captures the correct cache size in metadata
    params.max_context_len = max_context_len
enable_attention_sink() overwrites params.max_context_len with the KV cache buffer size (sink_size + 2*window_size). Rope.get_freqs() enforces input_pos < params.max_context_len under dynamic shape, so this change will cap generation to the cache size and break the intended “logical position can exceed cache size” behavior (and also contradict the earlier logic in model.py that tries to enlarge max_context_len for RoPE). Consider keeping params.max_context_len as the RoPE table length and passing the cache size via a separate field/argument (or only adjusting metadata without shrinking RoPE capacity).
Suggested change:

    - max_context_len = sink_size + window_size * 2
    - # We update params.max_context_len to reflect the actual buffer size
    - # This ensures export captures the correct cache size in metadata
    - params.max_context_len = max_context_len
    + # Default KV cache buffer size: sink tokens + sliding window on both sides
    + max_context_len = sink_size + window_size * 2

      if (method_names.count(kUseAttentionSink)) {
        auto get_result = module->get(kUseAttentionSink);
        use_attention_sink = get_result.get().toScalar().to<bool>();
      }

      if (use_attention_sink) {
        // Get attention sink configuration from metadata
        if (method_names.count(kAttentionSinkSize)) {
          auto get_result = module->get(kAttentionSinkSize);
          sink_size = get_result.get().toScalar().to<int64_t>();
        }
        if (method_names.count(kAttentionSinkWindowSize)) {
          auto get_result = module->get(kAttentionSinkWindowSize);
          window_size = get_result.get().toScalar().to<int64_t>();
        }
IOManager selection is keyed off module methods named use_attention_sink / attention_sink_size / attention_sink_window_size, but the exporter code in this PR adds attention sink metadata as get_sink_size/get_window_size (and does not add use_attention_sink). As a result, use_attention_sink will remain false here and AttentionSinkIOManager will never be constructed. Align the exporter’s metadata keys with the runner constants (or update the runner to look for the exported keys), and include a clear boolean enable flag plus the numeric parameters.

    def test_cache_positions_consistency(self):
        """Test that cache positions remain consistent during generation."""
        cache_size = 32
Variable cache_size is not used.
Suggested change:

    - cache_size = 32