
[TEST] Attention sink #17252

Open

kirklandsign wants to merge 11 commits into main from attention-sink

Conversation

@kirklandsign
Contributor

Summary

Test plan

@pytorch-bot

pytorch-bot bot commented Feb 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17252

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 27 Cancelled Jobs, 1 Pending, 4 Unrelated Failures

As of commit bbca983 with merge base 1eb3f9d:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Feb 5, 2026
@github-actions

github-actions bot commented Feb 5, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e., would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track of changes and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

kirklandsign and others added 3 commits February 5, 2026 12:12
Co-authored-by: Claude <noreply@anthropic.com>
@kirklandsign kirklandsign marked this pull request as ready for review February 5, 2026 21:51
Copilot AI review requested due to automatic review settings February 5, 2026 21:51
Contributor

Copilot AI left a comment


Pull request overview

This pull request implements a ring-buffer-based attention sink mechanism for LLM inference, enabling models to generate beyond their fixed context window by maintaining a sliding window of recent tokens plus a fixed set of "sink" tokens (typically the initial prompt tokens).
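To make the mechanism concrete, here is a condensed Python sketch of the masking rule, distilled from the _create_causal_mask_for_attention_sink logic quoted in the diff further down (a simplified standalone version for illustration, not the exact export-traced code): sink tokens stay visible for the whole generation, while every other cached token is only visible while it remains inside the sliding window.

```python
# Illustration only: condensed from the mask logic in the quoted attention_sink.py diff.
import torch

def attention_sink_mask(cache_positions: torch.Tensor,
                        window_size: int,
                        sink_size: int,
                        start_pos: int,
                        seq_len: int) -> torch.Tensor:
    # Original token position of each query row.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions                    # query-to-cached-token distance
    is_valid = (cache_positions >= 0) & (delta >= 0)   # slot is filled and causal
    is_sink = cache_positions < sink_size              # first sink_size tokens: always visible
    is_in_window = delta < window_size                 # recent tokens inside the window
    keep = is_valid & (is_sink | is_in_window)
    return torch.where(keep, 0.0, float("-inf"))

# Example: 32-slot cache, 4 sink tokens, window of 14, one query at position 20.
mask = attention_sink_mask(torch.arange(32), window_size=14, sink_size=4,
                           start_pos=20, seq_len=1)
```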

Changes:

  • Rewrites attention_sink.py to use a torch.export-compatible ring buffer approach instead of explicit token eviction
  • Updates C++ text runner to support sliding window validation and extended generation
  • Adds comprehensive unit tests for the ring buffer implementation
  • Updates evaluation code to support the new ring buffer approach
  • Adds configuration for attention sink models

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.

Show a summary per file:

  • examples/models/llama/source_transformation/attention_sink.py: Complete rewrite to ring-buffer implementation with torch.export compatibility
  • examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py: New comprehensive test suite for ring buffer attention sink
  • extension/llm/runner/text_llm_runner.cpp: Updates validation logic for sliding window and adds TEMPORARY unlimited generation workaround
  • examples/models/llama/eval_llama_lib.py: Updates assertions for ring buffer compatibility
  • examples/models/llama/source_transformation/custom_kv_cache.py: Adds skip logic for KVCacheWithAttentionSink to prevent incorrect replacement
  • examples/models/llama/model.py: Updates max_context_len validation and sets default RoPE length to 131072
  • examples/models/llama/config/llama_attention_sink.yaml: New configuration file for attention sink models
  • examples/models/llama/BUCK: Adds test target for attention sink ring buffer tests
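To make the ring-buffer wording in the file summary concrete, here is a minimal sketch of the wrap-around write indexing used by CachePositionsManagerWithSink in the quoted diff (simplified; the real class also preserves previously written positions via torch.where and index_copy_):

```python
# Illustration only: ring-buffer write indexing, simplified from the quoted diff.
import torch

cache_size = 32  # sink_size + 2 * window_size in this PR's cache layout
cache_positions = torch.full((cache_size,), -1, dtype=torch.long)

def write_indices(start_pos: int, seq_len: int) -> torch.Tensor:
    orig = torch.arange(seq_len, dtype=torch.long) + start_pos
    idx = orig % cache_size        # ring buffer: old slots get overwritten
    cache_positions[idx] = orig    # remember which original position sits in each slot
    return idx

print(write_indices(0, 30))   # indices 0..29, no wraparound yet
print(write_indices(30, 5))   # tensor([30, 31,  0,  1,  2]) -- wraps around
```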


}

} // namespace executorch::extension::llm
/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */// A simple llama2 runner that includes preprocessing and post processing logic.// The module takes in a string as input and emits a string as output.#include <executorch/extension/llm/runner/io_manager/io_manager.h>#include <executorch/extension/llm/runner/text_llm_runner.h>#include <executorch/extension/llm/runner/util.h>#include <executorch/runtime/platform/runtime.h>#include <pytorch/tokenizers/hf_tokenizer.h>#include <pytorch/tokenizers/llama2c_tokenizer.h>#include <pytorch/tokenizers/sentencepiece.h>#include <pytorch/tokenizers/tiktoken.h>namespace executorch::extension::llm {using ::executorch::extension::Module;using ::executorch::runtime::Error;using ::executorch::runtime::Result;TextLLMRunner::TextLLMRunner( std::unordered_map<std::string, int64_t> metadata, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::unique_ptr<::executorch::extension::Module> module, std::unique_ptr<TextDecoderRunner> text_decoder_runner, std::unique_ptr<TextPrefiller> text_prefiller, std::unique_ptr<IOManager> io_manager, std::unique_ptr<TextTokenGenerator> text_token_generator, std::unique_ptr<Stats> stats, float temperature) : tokenizer_(std::move(tokenizer)), metadata_(std::move(metadata)), module_(std::move(module)), text_decoder_runner_(std::move(text_decoder_runner)), text_prefiller_(std::move(text_prefiller)), io_manager_(std::move(io_manager)), text_token_generator_(std::move(text_token_generator)), stats_(std::move(stats)), temperature_(temperature), pos_(0) { // Note: This constructor assumes that text_prefiller and text_token_generator // already have references to the Module and TextDecoderRunner they need}bool TextLLMRunner::is_loaded() const { return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();}Error TextLLMRunner::load() { if (is_loaded()) { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok;}// Don't print with the same priority during warmup#define RUNNER_ET_LOG(warmup, format, ...) \ if (warmup) { \ ET_LOG(Debug, format, __VA_ARGS__); \ } else { \ ET_LOG(Info, format, __VA_ARGS__); \ }Error TextLLMRunner::generate( const std::string& prompt, const GenerationConfig& config, std::function<void(const std::string&)> token_callback, std::function<void(const Stats&)> stats_callback) { // Prepare the inputs. // Use ones-initialized inputs. ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); if (!is_loaded()) { stats_->model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_->model_load_end_ms = time_in_ms(); } if (config.warming) { ET_LOG(Info, "Doing a warmup run..."); } RUNNER_ET_LOG( config.warming, "RSS after loading model: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // Wrap the token_callback with print function std::function<void(const std::string&)> wrapped_callback = [token_callback, config](const std::string& piece) { if (!config.warming) { llm::safe_printf(piece.c_str()); fflush(stdout); } if (token_callback) { token_callback(piece); } }; // First token time only measures the time it takes to encode the prompt and // return a response token. 
stats_->inference_start_ms = time_in_ms(); shouldStop_ = false; ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); if (!encode_res.ok()) { ET_LOG( Error, "Failed to encode prompt %s. Tokenizers error code %d", prompt.c_str(), static_cast<uint32_t>(encode_res.error())); return Error::InvalidArgument; } // encode the (string) prompt into tokens sequence std::vector<uint64_t> prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); // Get max_seq_len for single prefill chunk limit int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); // For models with sliding window (Ring Buffer / Attention Sink), // we allow pos_ to exceed max_context_len. The model handles this // internally via ring buffer indexing or token eviction. // We only check that a single prefill chunk doesn't exceed max_seq_len. ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens <= max_seq_len, InvalidArgument, "num_prompt_tokens %d > max_seq_len %" PRId64 ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, max_seq_len); // Determine max_new_tokens using the GenerationConfig's resolve method. // For sliding window models, we use max_context_len directly (not reduced by pos_) // because the model handles position wrapping internally via ring buffer. int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); // TEMPORARY: For sliding window / infinite context testing, // override max_new_tokens to allow unlimited generation if (max_new_tokens <= 0) { max_new_tokens = 1000000; // Effectively unlimited } // If user specified seq_len, use that instead if (config.seq_len > 0 && config.seq_len > max_new_tokens) { max_new_tokens = config.seq_len; } ET_LOG( Info, "Max new tokens resolved: %d, given pos_ %" PRId64 ", num_prompt_tokens %zu, max_context_len %" PRId64, max_new_tokens, pos_, prompt_tokens.size(), max_context_len); ET_CHECK_OR_RETURN_ERROR( max_new_tokens > 0, InvalidArgument, "Max new tokens %d is less than or equal to 0", max_new_tokens); // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. // print prompts if (config.echo) { wrapped_callback(prompt); } auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); stats_->first_token_ms = time_in_ms(); stats_->prompt_eval_end_ms = time_in_ms(); // print the first token from prefill. No prev_token so use cur_token for it. auto decode_result = tokenizer_->decode(cur_token, cur_token); if (!decode_result.ok()) { ET_LOG( Error, "Tokenizers error code %d", static_cast<uint32_t>(decode_result.error())); return ::executorch::runtime::Error::InvalidArgument; } wrapped_callback(std::move(*decode_result)); RUNNER_ET_LOG( config.warming, "RSS after prompt prefill: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // start the main loop prompt_tokens.push_back(cur_token); // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, temperature_ == -1.0f ? 
config.temperature : temperature_, wrapped_callback); if (!generate_result.ok()) { return generate_result.error(); } int64_t num_generated_tokens = generate_result.get(); pos_ += num_generated_tokens; stats_->inference_end_ms = time_in_ms(); if (!config.warming) { printf("\n"); } RUNNER_ET_LOG( config.warming, "RSS after finishing text generation: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); if (num_generated_tokens == max_new_tokens) { RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } stats_->num_prompt_tokens = num_prompt_tokens; stats_->num_generated_tokens = num_generated_tokens; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); } else { // Do not print report during warmup print_report(*stats_); } if (stats_callback) { stats_callback(*stats_); } return Error::Ok;}Error TextLLMRunner::prefill( const std::string& prompt, const GenerationConfig& config) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); ET_CHECK_TK_OK_OR_RETURN_ERROR( encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); // encode the (string) prompt into tokens sequence std::vector<uint64_t> prompt_tokens = encode_res.get(); auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); return Error::Ok;}Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { // Create a GenerationConfig for warmup GenerationConfig config; config.echo = false; config.max_new_tokens = max_new_tokens; config.warming = true; // Call generate with the warmup config Error err = generate(prompt, config); // Reset stats after warmup, not resetting the std::unique_ptr! reset(); return err;}void TextLLMRunner::stop() { if (is_loaded()) { text_token_generator_->stop(); } else { ET_LOG(Error, "Token generator is not loaded, cannot stop"); }}void TextLLMRunner::reset() { stats_->reset(); pos_ = 0;}} // namespace executorch::extension::llm No newline at end of file

Copilot AI Feb 5, 2026


The entire file content has been compressed onto a single line, removing all newlines. This makes the code completely unreadable and unmaintainable. The file should be properly formatted with appropriate line breaks, indentation, and spacing according to C++ coding standards.

Comment on lines +11 to +12
# window_size=124: 滑动窗口大小
# eviction_batch_size=1: 每次驱逐 1 个 token

Copilot AI Feb 5, 2026


The comment contains Chinese characters ("滑动窗口大小" meaning "sliding window size" and "每次驱逐 1 个 token" meaning "evict 1 token each time"). Comments should be in English for consistency with the rest of the codebase and to ensure all team members can understand the documentation.

Suggested change
# window_size=124: 滑动窗口大小
# eviction_batch_size=1: 每次驱逐 1 个 token
# window_size=124: sliding window size
# eviction_batch_size=1: evict 1 token at a time


Copilot AI Feb 5, 2026


The TEMPORARY comment indicates this is test code that should not be merged to production. Setting max_new_tokens to 1000000 when it's <= 0 bypasses important validation logic and could lead to excessive memory usage or infinite generation loops. This workaround should either be removed or properly documented with a plan for resolution before merging.
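For clarity, here is the control flow the comment refers to, rendered in Python purely for readability (the actual change is in text_llm_runner.cpp; names follow the quoted diff, and this is not the runner's API):

```python
# Readability-only rendering of the max_new_tokens handling in the quoted C++ diff.
def resolve_generation_budget(resolved_from_config: int, config_seq_len: int) -> int:
    max_new_tokens = resolved_from_config  # result of config.resolve_max_new_tokens(...)
    if max_new_tokens <= 0:
        # TEMPORARY override flagged by the review: effectively unlimited generation.
        max_new_tokens = 1_000_000
    if config_seq_len > 0 and config_seq_len > max_new_tokens:
        # A user-specified seq_len can raise the budget further.
        max_new_tokens = config_seq_len
    return max_new_tokens
```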

eviction_batch_size=eviction_batch_size,
)
return module
# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.# Components for supporting Attention Sink. See# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.# This implementation is torch.export compatible using a ring buffer approach# for the sliding window portion while preserving the sink tokens.import typesfrom typing import Optional, Tupleimport torchimport torch.nn as nnfrom executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, CachePositionsManager, KVCache, RingKVCache,)from executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, Rope,)from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filterclass RopeWithAttentionSink(Rope): """ Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. Note: This class uses the model's max_context_len (params.max_context_len) for RoPE frequency table size, which should be large enough to support generation beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( self, params: ModelArgs, window_size: int, sink_size: int, eviction_batch_size: int, ): super().__init__(params) if self.params.use_hf_rope: self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size # max_context_len from params is used for RoPE frequencies (should be large) self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): """ Get rotary embedding frequencies. For attention sink, we use the original position - the sliding window is handled by the cache index management, not by position shifting. """ assert input_pos is not None return super().get_freqs(input_pos, seq_len) def rerotate_k( self, k: torch.Tensor, original_position: int, new_position: int, ): """ Rerotate k from original_position to new_position. The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) rerotation_cos = ( new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin ) rerotation_sin = ( new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)def _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len): """ Create causal mask for attention sink. Unlike regular ring buffer mask, this mask: 1. 
ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) 2. Uses sliding window for other tokens Args: cache_positions: Tensor of actual positions stored at each cache index window_size: Size of the sliding window sink_size: Number of sink tokens to always attend to start_pos: Starting position of the current query seq_len: Length of the current query sequence """ pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) delta = pos_q - cache_positions # Valid if position is filled (>= 0) and causal (delta >= 0) is_valid = (cache_positions >= 0) & (delta >= 0) # Sink tokens (original positions 0 to sink_size-1) are always visible is_sink = cache_positions < sink_size # Window tokens must be within sliding window is_in_window = delta < window_size # Final mask: valid AND (is_sink OR is_in_window) attn_mask = is_valid & (is_sink | is_in_window) attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 return attn_maskclass CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. For sink_size=0: behaves exactly like original CachePositionsManager. For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). """ def __init__(self, cache_size: int): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos # Simple ring buffer: just mod by cache size indices = orig_indices % self.max_context_length # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) cache_positions = torch.where( arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) return indicesclass KVCacheWithAttentionSink(KVCache): """ KV cache that supports attention sink with torch.export compatibility. Uses a ring buffer approach for the sliding window portion while keeping the first sink_size tokens fixed. This avoids dynamic shape operations. 
Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( self, n_heads: int, head_dim: int, enable_dynamic_shape: bool, rope: RopeWithAttentionSink, window_size: int, sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, dtype=torch.float32, ): # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, dtype=dtype, ) self.rope = rope self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size self.is_ring_buffer = True # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int ): """ Create causal mask for the attention with attention sink. Sink tokens are ALWAYS visible, plus recent tokens in the window. """ cache_positions = self.cache_positions_manager.cache_positions if self.sink_size > 0: # Use attention sink mask that always allows attending to sink tokens return _create_causal_mask_for_attention_sink( cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) else: # Pure ring buffer mode - use original mask with window_size = actual window return _create_causal_mask_for_ring_buffer( cache_positions, self.window_size, start_pos, seq_len ) def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update KV cache with new key-value pairs. Uses ring buffer indexing for positions >= sink_size. """ seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 2 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" # Calculate write indices indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) start_pos = input_pos[0].item() torch._check_is_size(start_pos) self.k_cache.index_copy_(2, indices, k_val) self.v_cache.index_copy_(2, indices, v_val) return self.k_cache, self.v_cache def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ For ring buffer implementation, no explicit eviction is needed. The ring buffer automatically overwrites old values. Returns 0 to indicate no position shift is needed. """ return 0def attention_sink_forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None,): """ Forward function for attention with attention sink KV cache. Uses ring buffer masking for proper attention patterns. 
""" assert self.use_kv_cache assert input_pos is not None bsz, seqlen, _ = x.shape # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Transpose for attention: [B, H, S, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Update KV cache k, v = self.kv_cache.update(input_pos, k, v) # Use ring buffer mask since we have is_ring_buffer=True start_pos = input_pos[0].item() torch._check_is_size(start_pos) attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) # SDPA output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) # Return tuple like original AttentionMHA.forward return self.wo(output), Nonedef _replace_rope( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink): def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: return isinstance(child, Rope) def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: return rope_with_attention_sink _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)def _replace_attention( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink, sink_size: int, window_size: int, eviction_batch_size: int,): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] _replace_attention( module=child_module, # pyre-ignore [6] rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache # Always use KVCacheWithAttentionSink, even for sink_size=0 # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, enable_dynamic_shape=kv_cache.enable_dynamic_shape, rope=rope_with_attention_sink, max_batch_size=kv_cache.max_batch_size, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink # If using SDPACustom (fused SDPA op), enable attention mask support # so it uses our ring buffer / attention sink mask instead of simple causal mask if "SDPACustom" in child_module.SDPA.__class__.__name__: child_module.SDPA.use_attention_mask = True # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer maskdef enable_attention_sink( module: torch.nn.Module, params: ModelArgs, sink_size: int, window_size: int, eviction_batch_size: int,) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - Replace Attention's KVCache with KVCacheWithAttentionSink - Replace Attention's forward with attention_sink_forward """ rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, ) _replace_rope(module, rope_with_attention_sink) _replace_attention( module=module, rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) return module No newline at end of file

Copilot AI Feb 5, 2026


The PR title "[TEST] Attention sink" and empty PR description suggest this is experimental or work-in-progress code. The presence of TEMPORARY workarounds in the C++ code and significant implementation changes without proper documentation indicate this PR may not be ready for production merge. Consider adding a proper description explaining the purpose, approach, testing strategy, and any known limitations before merging.

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree."""Tests for the ring-buffer based attention sink implementation.This tests the torch.export-compatible implementation that uses a ring bufferfor the sliding window rather than explicit token eviction.Usage: # Run with pytest python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v # Or run directly python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py"""import unittestimport torchfrom executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.source_transformation.attention_sink import ( CachePositionsManagerWithSink, KVCacheWithAttentionSink, RopeWithAttentionSink, _create_causal_mask_for_attention_sink,)class CachePositionsManagerWithSinkTest(unittest.TestCase): """Test the cache positions manager for ring buffer indexing.""" def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) self.manager = CachePositionsManagerWithSink(self.cache_size) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" expected = torch.zeros(self.cache_size, dtype=torch.long) torch.testing.assert_close(self.manager.cache_positions, expected) def test_simple_update(self): """Test simple sequential update without wraparound.""" input_pos = torch.tensor([0], dtype=torch.long) seq_len = 5 indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) # Should return indices 0, 1, 2, 3, 4 expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) # Cache positions at those indices should be the original positions for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) def test_wraparound(self): """Test ring buffer wraparound.""" # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 30) # Add 5 more tokens at position 30 - should wrap around input_pos = torch.tensor([30], dtype=torch.long) indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 32) # Now add position 32 which wraps to index 0 input_pos = torch.tensor([32], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 1) # Index 0 should now contain original position 32 self.assertEqual(self.manager.cache_positions[0].item(), 32)class CausalMaskTest(unittest.TestCase): """Test the causal mask generation for attention sink.""" def test_mask_allows_sink_tokens(self): """Sink tokens should always be visible (mask = 0).""" cache_size = 32 sink_size = 4 window_size = 14 # cache_size = sink_size + window_size * 2 # Create cache positions where positions 0-3 are sink tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 # Query at position 20 seq_len = 1 mask = 
_create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 for i in range(sink_size): self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") def test_mask_blocks_future_tokens(self): """Future tokens should be masked (-inf).""" cache_size = 32 sink_size = 4 window_size = 14 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Future tokens (positions > 10) should have mask = -inf for i in range(11, cache_size): self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") def test_mask_respects_window(self): """Tokens outside the window should be masked.""" cache_size = 32 sink_size = 4 window_size = 5 # Only allow 5 recent tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 16-20 should be visible (within window of 5) for pos in range(16, 21): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") # Position 15 should be masked (outside window, not a sink) self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)")class KVCacheWithAttentionSinkTest(unittest.TestCase): """Test the KV cache with attention sink.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.max_batch_size = 1 # Create model args with enough context for RoPE self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, # Large enough for RoPE n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=self.max_batch_size, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_cache_size(self): """Cache should be sink_size + window_size * 2.""" expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) def test_is_ring_buffer(self): """Cache should be marked as ring buffer.""" self.assertTrue(self.kv_cache.is_ring_buffer) def test_update_stores_kv(self): """Update should store key-value pairs.""" k = torch.randn(1, self.n_heads, 5, self.head_dim) v = torch.randn(1, self.n_heads, 5, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # First 5 positions should contain our values torch.testing.assert_close(k_out[:, :, :5, :], k) torch.testing.assert_close(v_out[:, :, :5, :], v) def test_evict_tokens_returns_zero(self): """Ring buffer implementation doesn't shift, so evict returns 0.""" input_pos = torch.tensor([100], dtype=torch.long) shift = self.kv_cache.evict_tokens(input_pos, 10) self.assertEqual(shift, 0) def test_extended_generation(self): """Test that cache works for positions beyond cache size.""" 
cache_size = self.kv_cache.k_cache.size(2) # Fill cache with initial tokens for pos in range(cache_size + 50): k = torch.randn(1, self.n_heads, 1, self.head_dim) v = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([pos], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # Should not raise any errors self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape)class RopeWithAttentionSinkTest(unittest.TestCase): """Test RoPE for attention sink.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, dim=512, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=100, sink_size=4, eviction_batch_size=1, ) def test_get_freqs_uses_original_position(self): """RoPE frequencies should use the original position.""" input_pos = torch.tensor([50], dtype=torch.long) seq_len = 5 freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) # Should get frequencies for positions 50-54 expected_cos = self.rope.freqs_cos[50:55] expected_sin = self.rope.freqs_sin[50:55] torch.testing.assert_close(freqs_cos, expected_cos) torch.testing.assert_close(freqs_sin, expected_sin) def test_rerotate_k(self): """Test re-rotation of k from one position to another.""" batch_size = 1 seq_len = 8 n_heads = self.params.n_heads head_dim = self.params.dim // n_heads k = torch.randn(batch_size, seq_len, n_heads, head_dim) q = torch.randn(batch_size, seq_len, n_heads, head_dim) # Rotate k at position 100 original_pos = 100 freqs_cos, freqs_sin = self.rope.get_freqs( torch.tensor([original_pos], dtype=torch.long), seq_len ) _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Re-rotate to position 50 new_pos = 50 rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) # This should be equivalent to directly rotating k at position 50 freqs_cos_new, freqs_sin_new = self.rope.get_freqs( torch.tensor([new_pos], dtype=torch.long), seq_len ) _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4)class CausalMaskWithWraparoundTest(unittest.TestCase): """Test causal mask with ring buffer wraparound.""" def test_mask_after_wraparound(self): """Test mask after cache has wrapped around.""" cache_size = 16 sink_size = 4 window_size = 6 # cache_size = sink_size + window_size * 2 # Simulate cache after generating beyond cache_size: # The ring buffer wraps, so indices 0-15 contain positions that wrap # At position 50, with cache_size=16, the cache contains: # positions 50-15=35 to 49 at various indices cache_positions = torch.zeros(cache_size, dtype=torch.long) # Fill with positions that would exist after generating 50 tokens # idx = pos % cache_size, so: # pos 34-49 occupy indices 2-15 and 0-1 for pos in range(34, 50): idx = pos % cache_size cache_positions[idx] = pos start_pos = 49 # Query at position 49 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions within window (49-6+1=44 to 49) should be visible visible_count = 0 for i in range(cache_size): pos = cache_positions[i].item() if pos >= 44 and pos <= 49: # In window self.assertEqual(mask[0, i].item(), 0.0, f"Position {pos} at idx {i} should be visible (in window)") visible_count += 1 # Should have some visible tokens self.assertGreater(visible_count, 0, "Should have visible tokens in window") 
def test_mask_with_sink_size_zero(self): """Test pure sliding window (sink_size=0).""" cache_size = 16 sink_size = 0 window_size = 8 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 3-10 should be visible (within window of 8) for pos in range(3, 11): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") # Positions 0-2 should be masked (outside window) for pos in range(0, 3): self.assertEqual(mask[0, pos].item(), float('-inf'), f"Position {pos} should be masked (outside window)")class PrefillTest(unittest.TestCase): """Test prefill scenarios.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=1, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_prefill_entire_cache(self): """Test prefill that fills entire cache.""" cache_size = self.kv_cache.k_cache.size(2) k = torch.randn(1, self.n_heads, cache_size, self.head_dim) v = torch.randn(1, self.n_heads, cache_size, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # All positions should be filled torch.testing.assert_close(k_out, k) torch.testing.assert_close(v_out, v) def test_prefill_larger_than_cache_raises_error(self): """Test that prefill larger than cache size raises an assertion error.""" cache_size = self.kv_cache.k_cache.size(2) seq_len = cache_size + 10 k = torch.randn(1, self.n_heads, seq_len, self.head_dim) v = torch.randn(1, self.n_heads, seq_len, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) # This should raise an assertion error since seq_len > cache_size with self.assertRaises(AssertionError): self.kv_cache.update(input_pos, k, v) def test_prefill_followed_by_decode(self): """Test prefill followed by decode steps.""" cache_size = self.kv_cache.k_cache.size(2) # Prefill with 20 tokens k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) self.kv_cache.update(input_pos, k_prefill, v_prefill) # Decode 5 more tokens for i in range(5): k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([20 + i], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) # Verify cache positions are updated expected_pos = 20 + i cache_idx = expected_pos % cache_size self.assertEqual( self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), expected_pos )class EnableAttentionSinkTest(unittest.TestCase): """Test the enable_attention_sink transformation.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, n_kv_heads=8, dim=512, n_layers=2, vocab_size=100, ) def 
test_enable_attention_sink_transforms_model(self): """Test that enable_attention_sink properly transforms the model.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) # Create a simple transformer with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") # Apply attention sink transformation model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that KV caches are replaced for layer in model.layers: kv_cache = layer.attention.kv_cache self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) self.assertEqual(kv_cache.sink_size, 4) self.assertEqual(kv_cache.window_size, 100) self.assertTrue(kv_cache.is_ring_buffer) def test_enable_attention_sink_replaces_rope(self): """Test that RoPE is replaced with RopeWithAttentionSink.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that rope is replaced for layer in model.layers: rope = layer.attention.rope self.assertIsInstance(rope, RopeWithAttentionSink)class IntegrationTest(unittest.TestCase): """Integration tests for end-to-end scenarios.""" def setUp(self): torch.manual_seed(42) def test_cache_positions_consistency(self): """Test that cache positions remain consistent during generation.""" cache_size = 32 sink_size = 4 window_size = 14 n_heads = 8 head_dim = 64 params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=n_heads, n_kv_heads=n_heads, dim=n_heads * head_dim, ) rope = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, ) kv_cache = KVCacheWithAttentionSink( n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=True, rope=rope, max_batch_size=1, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, dtype=torch.float32, ) # Generate 100 tokens for pos in range(100): k = torch.randn(1, n_heads, 1, head_dim) v = torch.randn(1, n_heads, 1, head_dim) input_pos = torch.tensor([pos], dtype=torch.long) kv_cache.update(input_pos, k, v) # Create mask and verify it's valid mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) # Mask should not be all -inf (would mean no tokens to attend to) non_inf_count = (mask != float('-inf')).sum().item() self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") # For positions >= sink_size, sinks should always be visible if pos >= sink_size: for i in range(sink_size): cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() if cache_pos < sink_size: # This is actually a sink token self.assertEqual(mask[0, i].item(), 0.0, f"Sink at idx {i} should be visible at pos {pos}")if __name__ == '__main__': unittest.main() No newline at end of file

Copilot AI Feb 5, 2026


The entire file content has been compressed onto a single line, removing all newlines. This makes the code completely unreadable and unmaintainable. The file should be properly formatted with appropriate line breaks, indentation, and spacing according to Python coding standards (PEP 8).

Suggested change
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Tests for the ring-buffer based attention sink implementation.
This tests the torch.export-compatible implementation that uses a ring buffer
for the sliding window rather than explicit token eviction.
Usage:
# Run with pytest
python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v
# Or run directly
python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
"""
import unittest
import torch
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
CachePositionsManagerWithSink,
KVCacheWithAttentionSink,
RopeWithAttentionSink,
_create_causal_mask_for_attention_sink,
)
class CachePositionsManagerWithSinkTest(unittest.TestCase):
"""Test the cache positions manager for ring buffer indexing."""
def setUp(self) -> None:
# Total cache size (e.g., sink_size + window_size * 2)
self.cache_size = 32
self.manager = CachePositionsManagerWithSink(self.cache_size)
def test_initial_positions_are_zero(self) -> None:
"""Cache positions should start as zeros."""
expected = torch.zeros(self.cache_size, dtype=torch.long)
torch.testing.assert_close(self.manager.cache_positions, expected)
def test_simple_update(self) -> None:
"""Test simple sequential update without wraparound."""
input_pos = torch.tensor([0], dtype=torch.long)
seq_len = 5
indices = self.manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
# Should return indices 0, 1, 2, 3, 4
expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
torch.testing.assert_close(indices, expected_indices)
# Cache positions at those indices should be the original positions
for i in range(5):
self.assertEqual(self.manager.cache_positions[i].item(), i)
def test_wraparound(self) -> None:
"""Test ring buffer wraparound."""
# Fill cache to position 30
input_pos = torch.tensor([0], dtype=torch.long)
self.manager.calculate_positions_and_update_indices(input_pos, 30)
# Add 5 more tokens at position 30 - should wrap around
input_pos = torch.tensor([30], dtype=torch.long)
indices = self.manager.calculate_positions_and_update_indices(
input_pos, 5
)
# Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2
expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long)
torch.testing.assert_close(indices, expected_indices)
def test_cache_positions_track_original_positions(self) -> None:
"""Cache positions should track which original position is at each index."""
# Fill with positions 0-31
input_pos = torch.tensor([0], dtype=torch.long)
self.manager.calculate_positions_and_update_indices(input_pos, 32)
# Now add position 32 which wraps to index 0
input_pos = torch.tensor([32], dtype=torch.long)
self.manager.calculate_positions_and_update_indices(input_pos, 1)
# Index 0 should now contain original position 32
self.assertEqual(self.manager.cache_positions[0].item(), 32)
class CausalMaskTest(unittest.TestCase):
"""Test the causal mask generation for attention sink."""
def test_mask_allows_sink_tokens(self) -> None:
"""Sink tokens should always be visible (mask = 0)."""
cache_size = 32
sink_size = 4
# cache_size = sink_size + window_size * 2
window_size = 14
# Create cache positions where positions 0-3 are sink tokens
cache_positions = torch.arange(cache_size, dtype=torch.long)
start_pos = 20 # Query at position 20
seq_len = 1
mask = _create_causal_mask_for_attention_sink(
cache_positions, window_size, sink_size, start_pos, seq_len
)
# Sink tokens (indices 0-3, original positions 0-3) should have mask = 0
for i in range(sink_size):
self.assertEqual(
mask[0, i].item(),
0.0,
f"Sink token at index {i} should be visible",
)
def test_mask_blocks_future_tokens(self) -> None:
"""Future tokens should be masked (-inf)."""
cache_size = 32
sink_size = 4
window_size = 14
cache_positions = torch.arange(cache_size, dtype=torch.long)
start_pos = 10
seq_len = 1
mask = _create_causal_mask_for_attention_sink(
cache_positions, window_size, sink_size, start_pos, seq_len
)
# Future tokens (positions > 10) should have mask = -inf
for i in range(11, cache_size):
self.assertEqual(
mask[0, i].item(),
float("-inf"),
f"Future token at position {i} should be masked",
)
def test_mask_respects_window(self) -> None:
"""Tokens outside the window should be masked."""
cache_size = 32
sink_size = 4
# Only allow 5 recent tokens
window_size = 5
cache_positions = torch.arange(cache_size, dtype=torch.long)
start_pos = 20
seq_len = 1
mask = _create_causal_mask_for_attention_sink(
cache_positions, window_size, sink_size, start_pos, seq_len
)
# Positions 16-20 should be visible (within window of 5)
for pos in range(16, 21):
self.assertEqual(
mask[0, pos].item(),
0.0,
f"Position {pos} should be visible (in window)",
)
# Position 15 should be masked (outside window, not a sink)
self.assertEqual(
mask[0, 15].item(),
float("-inf"),
"Position 15 should be masked (outside window)",
)
class KVCacheWithAttentionSinkTest(unittest.TestCase):
"""Test the KV cache with attention sink."""
def setUp(self) -> None:
torch.manual_seed(42)
self.window_size = 14
self.sink_size = 4
self.n_heads = 8
self.head_dim = 64
self.max_batch_size = 1
# Create model args with enough context for RoPE
self.params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
max_context_len=1024, # Large enough for RoPE
n_heads=self.n_heads,
n_kv_heads=self.n_heads,
dim=self.n_heads * self.head_dim,
)
self.rope = RopeWithAttentionSink(
params=self.params,
window_size=self.window_size,
sink_size=self.sink_size,
eviction_batch_size=1,
)
self.kv_cache = KVCacheWithAttentionSink(
n_heads=self.n_heads,
head_dim=self.head_dim,
enable_dynamic_shape=True,
rope=self.rope,
max_batch_size=self.max_batch_size,
window_size=self.window_size,
sink_size=self.sink_size,
eviction_batch_size=1,
dtype=torch.float32,
)
def test_cache_size(self) -> None:
"""Cache should be sink_size + window_size * 2."""
# 4 + 28 = 32
expected_size = self.sink_size + self.window_size * 2
self.assertEqual(self.kv_cache.k_cache.size(2), expected_size)
self.assertEqual(self.kv_cache.v_cache.size(2), expected_size)
def test_is_ring_buffer(self) -> None:
"""Cache should be marked as ring buffer."""
self.assertTrue(self.kv_cache.is_ring_buffer)
def test_update_stores_kv(self) -> None:
"""Update should store key-value pairs."""
k = torch.randn(1, self.n_heads, 5, self.head_dim)
v = torch.randn(1, self.n_heads, 5, self.head_dim)
input_pos = torch.tensor([0], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k, v)
# First 5 positions should contain our values
torch.testing.assert_close(k_out[:, :, :5, :], k)
torch.testing.assert_close(v_out[:, :, :5, :], v)
def test_evict_tokens_returns_zero(self) -> None:
"""Ring buffer implementation doesn't shift, so evict returns 0."""
input_pos = torch.tensor([100], dtype=torch.long)
shift = self.kv_cache.evict_tokens(input_pos, 10)
self.assertEqual(shift, 0)
def test_extended_generation(self) -> None:
"""Test that cache works for positions beyond cache size."""
cache_size = self.kv_cache.k_cache.size(2)
# Fill cache with initial tokens
for pos in range(cache_size + 50):
k = torch.randn(1, self.n_heads, 1, self.head_dim)
v = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([pos], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k, v)
# Should not raise any errors
self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape)
self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape)
class RopeWithAttentionSinkTest(unittest.TestCase):
"""Test RoPE for attention sink."""
def setUp(self) -> None:
torch.manual_seed(42)
self.params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
max_context_len=1024,
n_heads=8,
dim=512,
)
self.rope = RopeWithAttentionSink(
params=self.params,
window_size=100,
sink_size=4,
eviction_batch_size=1,
)
def test_get_freqs_uses_original_position(self) -> None:
"""RoPE frequencies should use the original position."""
input_pos = torch.tensor([50], dtype=torch.long)
seq_len = 5
freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len)
# Should get frequencies for positions 50-54
expected_cos = self.rope.freqs_cos[50:55]
expected_sin = self.rope.freqs_sin[50:55]
torch.testing.assert_close(freqs_cos, expected_cos)
torch.testing.assert_close(freqs_sin, expected_sin)
def test_rerotate_k(self) -> None:
"""Test re-rotation of k from one position to another."""
batch_size = 1
seq_len = 8
n_heads = self.params.n_heads
head_dim = self.params.dim // n_heads
k = torch.randn(batch_size, seq_len, n_heads, head_dim)
q = torch.randn(batch_size, seq_len, n_heads, head_dim)
# Rotate k at position 100
original_pos = 100
freqs_cos, freqs_sin = self.rope.get_freqs(
torch.tensor([original_pos], dtype=torch.long), seq_len
)
_, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin)
# Re-rotate to position 50
new_pos = 50
rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos)
# This should be equivalent to directly rotating k at position 50
freqs_cos_new, freqs_sin_new = self.rope.get_freqs(
torch.tensor([new_pos], dtype=torch.long), seq_len
)
_, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new)
torch.testing.assert_close(
rerotated_k,
expected_k,
rtol=1e-4,
atol=1e-4,
)
class CausalMaskWithWraparoundTest(unittest.TestCase):
"""Test causal mask with ring buffer wraparound."""
def test_mask_after_wraparound(self) -> None:
"""Test mask after cache has wrapped around."""
cache_size = 16
sink_size = 4
# cache_size = sink_size + window_size * 2
window_size = 6
# Simulate cache after generating beyond cache_size:
# The ring buffer wraps, so indices 0-15 contain positions that wrap
# At position 50, with cache_size=16, the cache contains:
# positions 50-16=34 to 49 at various indices
cache_positions = torch.zeros(cache_size, dtype=torch.long)
# Fill with positions that would exist after generating 50 tokens
# idx = pos % cache_size, so:
# pos 34-49 occupy indices 2-15 and 0-1
for pos in range(34, 50):
idx = pos % cache_size
cache_positions[idx] = pos
start_pos = 49 # Query at position 49
seq_len = 1
mask = _create_causal_mask_for_attention_sink(
cache_positions, window_size, sink_size, start_pos, seq_len
)
# Positions within window (49-6+1=44 to 49) should be visible
visible_count = 0
for i in range(cache_size):
pos = cache_positions[i].item()
if 44 <= pos <= 49: # In window
self.assertEqual(
mask[0, i].item(),
0.0,
f"Position {pos} at idx {i} should be visible (in window)",
)
visible_count += 1
# Should have some visible tokens
self.assertGreater(
visible_count, 0, "Should have visible tokens in window"
)
def test_mask_with_sink_size_zero(self) -> None:
"""Test pure sliding window (sink_size=0)."""
cache_size = 16
sink_size = 0
window_size = 8
cache_positions = torch.arange(cache_size, dtype=torch.long)
start_pos = 10
seq_len = 1
mask = _create_causal_mask_for_attention_sink(
cache_positions, window_size, sink_size, start_pos, seq_len
)
# Positions 3-10 should be visible (within window of 8)
for pos in range(3, 11):
self.assertEqual(
mask[0, pos].item(),
0.0,
f"Position {pos} should be visible",
)
# Positions 0-2 should be masked (outside window)
for pos in range(0, 3):
self.assertEqual(
mask[0, pos].item(),
float("-inf"),
f"Position {pos} should be masked (outside window)",
)
class PrefillTest(unittest.TestCase):
"""Test prefill scenarios."""
def setUp(self) -> None:
torch.manual_seed(42)
self.window_size = 14
self.sink_size = 4
self.n_heads = 8
self.head_dim = 64
self.params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
max_context_len=1024,
n_heads=self.n_heads,
n_kv_heads=self.n_heads,
dim=self.n_heads * self.head_dim,
)
self.rope = RopeWithAttentionSink(
params=self.params,
window_size=self.window_size,
sink_size=self.sink_size,
eviction_batch_size=1,
)
self.kv_cache = KVCacheWithAttentionSink(
n_heads=self.n_heads,
head_dim=self.head_dim,
enable_dynamic_shape=True,
rope=self.rope,
max_batch_size=1,
window_size=self.window_size,
sink_size=self.sink_size,
eviction_batch_size=1,
dtype=torch.float32,
)
def test_prefill_entire_cache(self) -> None:
"""Test prefill that fills entire cache."""
cache_size = self.kv_cache.k_cache.size(2)
k = torch.randn(1, self.n_heads, cache_size, self.head_dim)
v = torch.randn(1, self.n_heads, cache_size, self.head_dim)
input_pos = torch.tensor([0], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k, v)
# All positions should be filled
torch.testing.assert_close(k_out, k)
torch.testing.assert_close(v_out, v)
def test_prefill_larger_than_cache_raises_error(self) -> None:
"""Test that prefill larger than cache size raises an assertion error."""
cache_size = self.kv_cache.k_cache.size(2)
seq_len = cache_size + 10
k = torch.randn(1, self.n_heads, seq_len, self.head_dim)
v = torch.randn(1, self.n_heads, seq_len, self.head_dim)
input_pos = torch.tensor([0], dtype=torch.long)
# This should raise an assertion error since seq_len > cache_size
with self.assertRaises(AssertionError):
self.kv_cache.update(input_pos, k, v)
def test_prefill_followed_by_decode(self) -> None:
"""Test prefill followed by decode steps."""
cache_size = self.kv_cache.k_cache.size(2)
# Prefill with 20 tokens
k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim)
v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim)
input_pos = torch.tensor([0], dtype=torch.long)
self.kv_cache.update(input_pos, k_prefill, v_prefill)
# Decode 5 more tokens
for i in range(5):
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
# Verify cache positions are updated
expected_pos = 20 + i
cache_idx = expected_pos % cache_size
self.assertEqual(
self.kv_cache.cache_positions_manager.cache_positions[
cache_idx
].item(),
expected_pos,
)
class EnableAttentionSinkTest(unittest.TestCase):
"""Test the enable_attention_sink transformation."""
def setUp(self) -> None:
torch.manual_seed(42)
self.params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
max_context_len=1024,
n_heads=8,
n_kv_heads=8,
dim=512,
n_layers=2,
vocab_size=100,
)
def test_enable_attention_sink_transforms_model(self) -> None:
"""Test that enable_attention_sink properly transforms the model."""
from executorch.examples.models.llama.llama_transformer import (
construct_transformer,
)
from executorch.examples.models.llama.source_transformation.attention_sink import (
enable_attention_sink,
)
# Create a simple transformer
with torch.device("meta"):
model = construct_transformer(self.params)
model.to_empty(device="cpu")
# Apply attention sink transformation
model = enable_attention_sink(
module=model,
params=self.params,
sink_size=4,
window_size=100,
eviction_batch_size=1,
)
# Check that KV caches are replaced
for layer in model.layers:
kv_cache = layer.attention.kv_cache
self.assertIsInstance(kv_cache, KVCacheWithAttentionSink)
self.assertEqual(kv_cache.sink_size, 4)
self.assertEqual(kv_cache.window_size, 100)
self.assertTrue(kv_cache.is_ring_buffer)
def test_enable_attention_sink_replaces_rope(self) -> None:
"""Test that RoPE is replaced with RopeWithAttentionSink."""
from executorch.examples.models.llama.llama_transformer import (
construct_transformer,
)
from executorch.examples.models.llama.source_transformation.attention_sink import (
enable_attention_sink,
)
with torch.device("meta"):
model = construct_transformer(self.params)
model.to_empty(device="cpu")
model = enable_attention_sink(
module=model,
params=self.params,
sink_size=4,
window_size=100,
eviction_batch_size=1,
)
# Check that rope is replaced
for layer in model.layers:
rope = layer.attention.rope
self.assertIsInstance(rope, RopeWithAttentionSink)
class IntegrationTest(unittest.TestCase):
"""Integration tests for end-to-end scenarios."""
def setUp(self) -> None:
torch.manual_seed(42)
def test_cache_positions_consistency(self) -> None:
"""Test that cache positions remain consistent during generation."""
cache_size = 32
sink_size = 4
window_size = 14
n_heads = 8
head_dim = 64
params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
max_context_len=1024,
n_heads=n_heads,
n_kv_heads=n_heads,
dim=n_heads * head_dim,
)
rope = RopeWithAttentionSink(
params=params,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=1,
)
kv_cache = KVCacheWithAttentionSink(
n_heads=n_heads,
head_dim=head_dim,
enable_dynamic_shape=True,
rope=rope,
max_batch_size=1,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=1,
dtype=torch.float32,
)
# Generate 100 tokens
for pos in range(100):
k = torch.randn(1, n_heads, 1, head_dim)
v = torch.randn(1, n_heads, 1, head_dim)
input_pos = torch.tensor([pos], dtype=torch.long)
kv_cache.update(input_pos, k, v)
# Create mask and verify it's valid
mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1)
# Mask should not be all -inf (would mean no tokens to attend to)
non_inf_count = (mask != float("-inf")).sum().item()
self.assertGreater(
non_inf_count,
0,
f"At pos {pos}, mask should have visible tokens",
)
# For positions >= sink_size, sinks should always be visible
if pos >= sink_size:
for i in range(sink_size):
cache_pos = (
kv_cache.cache_positions_manager.cache_positions[i].item()
)
# This is actually a sink token
if cache_pos < sink_size:
self.assertEqual(
mask[0, i].item(),
0.0,
f"Sink at idx {i} should be visible at pos {pos}",
)
if __name__ == "__main__":
unittest.main()

Copilot uses AI. Check for mistakes.
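When iterating on a single suite from the suggested test file, pytest's -k filter can narrow the run to one test class; a minimal sketch, assuming pytest is installed in the environment:

# Run only the causal-mask tests from the suggested test file
python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -k CausalMaskTest -v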
eviction_batch_size=eviction_batch_size,
)
return module
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
# This implementation is torch.export compatible using a ring buffer approach
# for the sliding window portion while preserving the sink tokens.

import types
from typing import Optional, Tuple

import torch
import torch.nn as nn

from executorch.examples.models.llama.attention import (
    _create_causal_mask_for_ring_buffer,
    AttentionMHA,
    CachePositionsManager,
    KVCache,
    RingKVCache,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import (
    apply_rotary_emb_to_k,
    hf_apply_rotary_emb_to_k,
    Rope,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class RopeWithAttentionSink(Rope):
    """
    Rope that helps adjust position encoding when tokens are shifted in KVCache.
    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
    in KVCache instead of positions in the actual text.

    For torch.export compatibility, this just passes through the position - the
    actual position adjustment is handled by the cache update logic.

    Note: This class uses the model's max_context_len (params.max_context_len) for
    RoPE frequency table size, which should be large enough to support generation
    beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
    """

    def __init__(
        self,
        params: ModelArgs,
        window_size: int,
        sink_size: int,
        eviction_batch_size: int,
    ):
        super().__init__(params)
        if self.params.use_hf_rope:
            self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
        else:
            self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
        # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x)
        self.kv_cache_size = sink_size + window_size * 2
        self.window_size = window_size
        self.sink_size = sink_size
        # max_context_len from params is used for RoPE frequencies (should be large)
        self.max_context_length = self.params.max_context_len
        self.eviction_batch_size = eviction_batch_size

    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
        """
        Get rotary embedding frequencies.
        For attention sink, we use the original position - the sliding window
        is handled by the cache index management, not by position shifting.
        """
        assert input_pos is not None
        return super().get_freqs(input_pos, seq_len)

    def rerotate_k(
        self,
        k: torch.Tensor,
        original_position: int,
        new_position: int,
    ):
        """
        Rerotate k from original_position to new_position.
        The shape of k is (batch_size, seq_len, n_local_heads, head_dim)
        """
        seq_len = k.shape[1]
        original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len)
        original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len)
        new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len)
        new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len)
        rerotation_cos = (
            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
        )
        rerotation_sin = (
            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
        )
        return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)


def _create_causal_mask_for_attention_sink(
    cache_positions, window_size, sink_size, start_pos, seq_len
):
    """
    Create causal mask for attention sink.

    Unlike regular ring buffer mask, this mask:
    1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1)
    2. Uses sliding window for other tokens

    Args:
        cache_positions: Tensor of actual positions stored at each cache index
        window_size: Size of the sliding window
        sink_size: Number of sink tokens to always attend to
        start_pos: Starting position of the current query
        seq_len: Length of the current query sequence
    """
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    # Valid if position is filled (>= 0) and causal (delta >= 0)
    is_valid = (cache_positions >= 0) & (delta >= 0)
    # Sink tokens (original positions 0 to sink_size-1) are always visible
    is_sink = cache_positions < sink_size
    # Window tokens must be within sliding window
    is_in_window = delta < window_size
    # Final mask: valid AND (is_sink OR is_in_window)
    attn_mask = is_valid & (is_sink | is_in_window)
    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
    return attn_mask


class CachePositionsManagerWithSink(nn.Module):
    """
    Manages cache positions for attention sink + sliding window.

    For sink_size=0: behaves exactly like original CachePositionsManager.
    For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer.

    IMPORTANT: cache_size should be the actual cache dimension size
    (2x window for ring buffer).
    """

    def __init__(self, cache_size: int):
        super().__init__()
        # cache_size is the actual size of the kv cache dimension
        self.max_context_length = cache_size
        # Use zeros like original CachePositionsManager
        self.register_buffer(
            "cache_positions",
            torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"),
        )

    def calculate_positions_and_update_indices(
        self, input_pos: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """
        Calculate indices into k_cache, v_cache for placing k_val, v_val.
        This is identical to the original CachePositionsManager logic.
        """
        start_pos = input_pos[0].item()
        torch._check_is_size(start_pos)
        orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
        # Simple ring buffer: just mod by cache size
        indices = orig_indices % self.max_context_length
        # Update cache_positions exactly like original CachePositionsManager
        full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
        arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
        cache_positions = torch.where(
            arange_tensor < start_pos, self.cache_positions, full_t
        )
        self.cache_positions.copy_(cache_positions)
        self.cache_positions.index_copy_(0, indices, orig_indices)
        return indices


class KVCacheWithAttentionSink(KVCache):
    """
    KV cache that supports attention sink with torch.export compatibility.

    Uses a ring buffer approach for the sliding window portion while keeping
    the first sink_size tokens fixed. This avoids dynamic shape operations.

    Cache layout:
        [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1]
    """

    def __init__(
        self,
        n_heads: int,
        head_dim: int,
        enable_dynamic_shape: bool,
        rope: RopeWithAttentionSink,
        window_size: int,
        sink_size: int,
        eviction_batch_size: int,
        max_batch_size: int = 1,
        dtype=torch.float32,
    ):
        # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x)
        total_cache_size = sink_size + window_size * 2
        super().__init__(
            max_batch_size=max_batch_size,
            max_context_length=total_cache_size,
            n_heads=n_heads,
            head_dim=head_dim,
            enable_dynamic_shape=enable_dynamic_shape,
            dtype=dtype,
        )
        self.rope = rope
        self.window_size = window_size
        self.sink_size = sink_size
        self.eviction_batch_size = eviction_batch_size
        self.is_ring_buffer = True
        # Cache positions manager for determining write locations
        # Pass the total cache size (same as self.max_context_length after super().__init__)
        self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size)

    def create_causal_mask_for_ring_buffer(self, start_pos: torch.Tensor, seq_len: int):
        """
        Create causal mask for the attention with attention sink.
        Sink tokens are ALWAYS visible, plus recent tokens in the window.
        """
        cache_positions = self.cache_positions_manager.cache_positions
        if self.sink_size > 0:
            # Use attention sink mask that always allows attending to sink tokens
            return _create_causal_mask_for_attention_sink(
                cache_positions, self.window_size, self.sink_size, start_pos, seq_len
            )
        else:
            # Pure ring buffer mode - use original mask with window_size = actual window
            return _create_causal_mask_for_ring_buffer(
                cache_positions, self.window_size, start_pos, seq_len
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV cache with new key-value pairs.
        Uses ring buffer indexing for positions >= sink_size.
        """
        seq_len = k_val.size(2)
        assert seq_len <= self.k_cache.size(
            2
        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
        # Calculate write indices
        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
            input_pos, seq_len
        )
        start_pos = input_pos[0].item()
        torch._check_is_size(start_pos)
        self.k_cache.index_copy_(2, indices, k_val)
        self.v_cache.index_copy_(2, indices, v_val)
        return self.k_cache, self.v_cache

    def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
        """
        For ring buffer implementation, no explicit eviction is needed.
        The ring buffer automatically overwrites old values.
        Returns 0 to indicate no position shift is needed.
        """
        return 0


def attention_sink_forward(
    self,
    x: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    input_pos: Optional[torch.Tensor] = None,
):
    """
    Forward function for attention with attention sink KV cache.
    Uses ring buffer masking for proper attention patterns.
    """
    assert self.use_kv_cache
    assert input_pos is not None

    bsz, seqlen, _ = x.shape

    # QKV
    q, k, v = self.wq(x), self.wk(x), self.wv(x)
    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    # RoPE relative positional embeddings
    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

    # Transpose for attention: [B, H, S, D]
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    # Update KV cache
    k, v = self.kv_cache.update(input_pos, k, v)

    # Use ring buffer mask since we have is_ring_buffer=True
    start_pos = input_pos[0].item()
    torch._check_is_size(start_pos)
    attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)

    # SDPA
    output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)

    # Return tuple like original AttentionMHA.forward
    return self.wo(output), None


def _replace_rope(
    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        return isinstance(child, Rope)

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        return rope_with_attention_sink

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def _replace_attention(
    module: torch.nn.Module,
    rope_with_attention_sink: RopeWithAttentionSink,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
):
    for _, child_module in module._modules.items():
        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
            _replace_attention(
                module=child_module,  # pyre-ignore [6]
                rope_with_attention_sink=rope_with_attention_sink,
                sink_size=sink_size,
                window_size=window_size,
                eviction_batch_size=eviction_batch_size,
            )

        if isinstance(child_module, AttentionMHA):
            kv_cache = child_module.kv_cache
            # Always use KVCacheWithAttentionSink, even for sink_size=0
            # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True
            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
                n_heads=kv_cache.n_heads,
                head_dim=kv_cache.head_dim,
                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
                rope=rope_with_attention_sink,
                max_batch_size=kv_cache.max_batch_size,
                window_size=window_size,
                sink_size=sink_size,
                eviction_batch_size=eviction_batch_size,
                dtype=kv_cache.k_cache.dtype,
            )
            child_module.kv_cache = kv_cache_with_attention_sink
            # If using SDPACustom (fused SDPA op), enable attention mask support
            # so it uses our ring buffer / attention sink mask instead of simple causal mask
            if "SDPACustom" in child_module.SDPA.__class__.__name__:
                child_module.SDPA.use_attention_mask = True
            # Don't replace forward - let the original AttentionMHA.forward handle it
            # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask


def enable_attention_sink(
    module: torch.nn.Module,
    params: ModelArgs,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
) -> torch.nn.Module:
    """
    Transform the model to be able to run inference with Attention Sink.
    There are mainly three steps:
    - Replace Rope with RopeWithAttentionSink
    - Replace Attention's KVCache with KVCacheWithAttentionSink
    - Replace Attention's forward with attention_sink_forward
    """
    rope_with_attention_sink = RopeWithAttentionSink(
        params=params,
        window_size=window_size,
        sink_size=sink_size,
        eviction_batch_size=eviction_batch_size,
    )
    _replace_rope(module, rope_with_attention_sink)
    _replace_attention(
        module=module,
        rope_with_attention_sink=rope_with_attention_sink,
        sink_size=sink_size,
        window_size=window_size,
        eviction_batch_size=eviction_batch_size,
    )
    return module
No newline at end of file
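To make the mask semantics of _create_causal_mask_for_attention_sink concrete, here is a minimal sketch with toy sizes; the numbers are illustrative only, and it assumes the module from this PR is importable:

import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    _create_causal_mask_for_attention_sink,
)

# Toy configuration: 4 sink tokens, window of 6, so the cache has 4 + 2 * 6 = 16 slots.
sink_size, window_size = 4, 6
cache_size = sink_size + window_size * 2

# Assume no wraparound yet: cache slot i still holds original position i.
cache_positions = torch.arange(cache_size, dtype=torch.long)

# Single-token query at position 20.
mask = _create_causal_mask_for_attention_sink(
    cache_positions, window_size, sink_size, start_pos=20, seq_len=1
)

# Sink slots (original positions 0-3) stay at 0.0; of the remaining slots only
# position 15 lies within the 6-token window of query position 20, so every
# other entry is -inf.
print(mask)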

Copilot AI Feb 5, 2026


The entire file content has been compressed onto a single line, removing all newlines. This makes the code completely unreadable and unmaintainable. The file should be properly formatted with appropriate line breaks, indentation, and spacing according to Python coding standards (PEP 8).

Copilot uses AI. Check for mistakes.
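The formatting problem the comment describes can also be fixed mechanically; a minimal sketch, assuming black (or whichever formatter the repository's lint setup prescribes) is available:

python -m black examples/models/llama/source_transformation/attention_sink.py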
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")
return ppl.item()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Optional, Union

import torch
from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
    get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.builder import LLMEdgeManager
from lm_eval.evaluator import simple_evaluate
from pytorch_tokenizers import get_tokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from .evaluate.eager_eval import EagerEvalWrapper
from .export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: torch.fx.GraphModule,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
        generate_full_logits: bool = False,
        enable_dynamic_shape: bool = True,
    ):
        super().__init__(
            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
        )
        self._model = model.to(self.device)
        self._use_kv_cache = use_kv_cache
        self._generate_full_logits = generate_full_logits
        self._enable_dynamic_shape = enable_dynamic_shape

    def _model_call(self, inps):
        if self._use_kv_cache:
            if not self._enable_dynamic_shape:
                # graph module exported without dynamic shape won't work with a different shape.
                # And we have to do single token prefill here.
                result_logits = []
                for pos in range(inps.shape[-1]):
                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
                    logits = self._model(
                        inps[:, pos : pos + 1], {"input_pos": pos_tensor}
                    )
                    result_logits.append(logits)
                if self._generate_full_logits:
                    return torch.cat(result_logits, dim=1)
                else:
                    return torch.stack(result_logits, dim=1)
            else:
                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
                # Batch process the whole sequence.
                logits = self._model(
                    inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
                )
                return logits
        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")


class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model  # Expects model to be path to a .pte file
        from executorch.extension.pybindings.portable_lib import _load_for_executorch

        # Load custom ops and quantized ops.
        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

        # Note: import this after portable_lib
        from executorch.extension.llm.custom_ops import (  # noqa
            custom_ops,  # usort: skip
        )
        from executorch.kernels import quantized  # noqa

        self._et_model = _load_for_executorch(self._model)
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call
        # inps: Tensor of shape (1, max_seq_len - 1)
        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
        result = []
        if self._use_kv_cache:
            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
            result = self._et_model.forward(
                (inps[:, : self._max_seq_length], pos_tensor)
            )
        else:
            result = self._et_model.forward((inps,))
        if result[0].dim() != 3:
            raise ValueError(
                f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits."
            )
        return result[0]


class ETRunnerEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch Runtime integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model
        self._tokenizer_bin = tokenizer_bin

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single
        # forward call
        # Example:
        # inps: Tensor of shape (1, N)
        # logits: Tensor of shape (1, N, vocab_size)
        pass


def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
    llm_config=None,
):
    """
    Generates a wrapper interface around the provided model and tokenizer for
    the lm-evaluation-harness library.

    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
    # If llm_config is not provided, convert args to llm_config
    if llm_config is None:
        from executorch.extension.llm.export.config.llm_config import LlmConfig

        llm_config = LlmConfig.from_args(args)

    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

    # ExecuTorch Binary Evaluation
    if (model := args.pte) is not None:  # pyre-ignore
        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
                max_seq_length=llm_config.export.max_seq_length,
            )

        # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
        return ETPybindEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            # Exported model takes at most (max_seq_length - 1) tokens.
            # Note that the eager model takes at most max_seq_length tokens.
            max_seq_length=llm_config.export.max_seq_length - 1,
        )

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
        llm_config
    )

    # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

    if len(quantizers) != 0:
        manager = manager.export().pt2e_quantize(quantizers)
        model = (
            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=llm_config.export.max_seq_length,
            use_kv_cache=llm_config.model.use_kv_cache,
            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
        # for quantizers. Currently export only works with --kv_cache, but
        # fails without the kv_cache mode
        model = (
            manager.model.eval().to(device="cuda")
            if torch.cuda.is_available()
            else manager.model.eval().to(device="cpu")
        )

        # Save the checkpoint after the eager model preparation is done.
        # The reason for this option is that the checkpoint can be used
        # to do evaluations in other evaluation platforms, or with data
        # that is not available in this eval_llama. We save the checkpoint
        # here for consistency with eval_llama. The accuracy results we
        # get from eval_llama can be used as a reference to other evaluations.
        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
            torch.save(model, args.output_eager_checkpoint_file)

        return EagerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=llm_config.export.max_seq_length,
            use_kv_cache=llm_config.model.use_kv_cache,
        )


def build_args_parser() -> argparse.ArgumentParser:
    # Start with arg parser from export_llama_lib
    parser = _build_args_parser()

    # Add additional args specific to eval
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="number of samples to evaluate. If not set, evaluate all samples",
    )
    parser.add_argument(
        "-f",
        "--num_fewshot",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    # Add additional args specific to eval via an ET Runner
    # Note: For initial integration, the tokenizer.model is also required
    parser.add_argument(
        "--pte",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
    )
    parser.add_argument(
        "--tokenizer_bin",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
    )
    parser.add_argument(
        "--output_eager_checkpoint_file",
        type=str,
        default=None,
        help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
    )

    # Set of parameters specific to AttentionSink.
    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)

    return parser


def eval_llama(
    model_name: str,
    args: argparse.ArgumentParser,
) -> None:
    # Convert args to LlmConfig
    from executorch.extension.llm.export.config.llm_config import LlmConfig

    llm_config = LlmConfig.from_args(args)

    # Generate the eval wrapper
    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

    # Needed for loading mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
    if args.tasks and "mmlu" in args.tasks:
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Evaluate the model
    with torch.no_grad():
        eval_results = simple_evaluate(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,
            limit=args.limit,
        )

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")


def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
    """
    Evaluate the model's perplexity when AttentionSink is enabled.

    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
    Updated for the ring-buffer based attention sink implementation.
    """
    # Convert args to LlmConfig
    from executorch.extension.llm.export.config.llm_config import LlmConfig

    llm_config = LlmConfig.from_args(args)

    assert llm_config.model.use_attention_sink is not None
    assert args.attention_sink_eval_tokens > 0

    attention_sink_params = llm_config.model.use_attention_sink.split(",")
    assert len(attention_sink_params) == 3
    sink_size = int(attention_sink_params[0])
    window_size = int(attention_sink_params[1])

    # For the ring buffer implementation, the cache size is sink_size + window_size * 2
    # max_context_length should be >= sink_size + window_size (for RoPE frequencies)
    # but can be larger to support extended generation
    assert llm_config.export.max_context_length >= sink_size + window_size, (
        f"max_context_length ({llm_config.export.max_context_length}) must be >= "
        f"sink_size + window_size ({sink_size + window_size})"
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
    model = manager.model.eval().to(device=device)
    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    nlls = []
    loss_fn = CrossEntropyLoss(reduction="none")
    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
    input_pos = 0
    while input_pos < args.attention_sink_eval_tokens:
        for text in eval_data["text"]:
            tokens = tokenizer.encode(text, bos=False, eos=False)
            if len(tokens) <= 0:
                continue
            with torch.no_grad():
                num_tokens = min(
                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
                )
                logits = model(
                    torch.tensor(
                        [tokens[:num_tokens]], dtype=torch.int64, device=device
                    ),
                    torch.tensor([input_pos], dtype=torch.int64, device=device),
                ).squeeze(dim=0)
                neg_log_likelihood = loss_fn(
                    logits,
                    torch.tensor(
                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
                    ).view(-1),
                )
                nlls.append(neg_log_likelihood)
                input_pos += num_tokens
                progress_bar.update(num_tokens)
            if input_pos >= args.attention_sink_eval_tokens:
                break

    ppl = torch.exp(torch.cat(nlls).mean())
    print(f"Perplexity: {ppl.item()}")
    return ppl.item()
No newline at end of file
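To make the size relationship that eval_llama_with_attention_sink asserts concrete, here is a small sketch with an illustrative setting; the "4,2044,1" string is a hypothetical sink_size,window_size,eviction_batch_size value, not one taken from this PR:

# Hypothetical attention-sink setting: sink_size=4, window_size=2044, eviction_batch_size=1.
use_attention_sink = "4,2044,1"
sink_size, window_size, _ = (int(x) for x in use_attention_sink.split(","))

# The exported KV cache allocates sink_size + window_size * 2 slots (the ring buffer needs 2x) ...
kv_cache_slots = kv_cache_slots = sink_size + window_size * 2 if False else sink_size + window_size * 2  # 4092
# ... while the eval path only requires max_context_length >= sink_size + window_size.
min_max_context_length = sink_size + window_size  # 2048
print(kv_cache_slots, min_max_context_length)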
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entire file content has been compressed onto a single line, removing all newlines. This makes the code completely unreadable and unmaintainable. The file should be properly formatted with appropriate line breaks, indentation, and spacing according to Python coding standards (PEP 8).

Copilot uses AI. Check for mistakes.
Comment on lines +234 to +235
# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
default_rope_length = max(131072, model_args.max_context_len)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded value of 131072 (128k) seems arbitrary and may not be appropriate for all model configurations. This could cause excessive memory usage for RoPE frequency tables in models that don't need such large context. Consider making this configurable through the llm_config or deriving it from model-specific parameters rather than using a hardcoded default.

Suggested change
# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
default_rope_length = max(131072, model_args.max_context_len)
# Derive the RoPE max context length from configuration rather than using a fixed
# default. Use the larger of export.max_context_length and the original
# model_args.max_context_len.
default_rope_length = max(
self.llm_config.export.max_context_length,
model_args.max_context_len,
)

Copilot uses AI. Check for mistakes.
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")
return ppl.item()
# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.import argparsefrom typing import Optional, Unionimport torchfrom datasets import load_datasetfrom executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params,)from executorch.extension.llm.export.builder import LLMEdgeManagerfrom lm_eval.evaluator import simple_evaluatefrom pytorch_tokenizers import get_tokenizerfrom pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizerfrom pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktokenfrom torch.nn import CrossEntropyLossfrom tqdm import tqdmfrom .evaluate.eager_eval import EagerEvalWrapperfrom .export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser,)class GraphModuleEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: torch.fx.GraphModule, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( model=model, tokenizer=tokenizer, max_seq_length=max_seq_length ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): if self._use_kv_cache: if not self._enable_dynamic_shape: # graph module exported without dynamic shape won't work with a different shape. # And we have to do single token prefill here. result_logits = [] for pos in range(inps.shape[-1]): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model( inps[:, pos : pos + 1], {"input_pos": pos_tensor} ) result_logits.append(logits) if self._generate_full_logits: return torch.cat(result_logits, dim=1) else: return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. logits = self._model( inps[:, : self._max_seq_length], {"input_pos": pos_tensor} ) return logits else: return self._model(inps) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented")class ETPybindEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model # Expects model to be path to a .pte file from executorch.extension.pybindings.portable_lib import _load_for_executorch # Load custom ops and quantized ops. 
from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib from executorch.extension.llm.custom_ops import ( # noqa custom_ops, # usort: skip ) from executorch.kernels import quantized # noqa self._et_model = _load_for_executorch(self._model) self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore def _model_call(self, inps): # Given inps (tokens), return the logits from a single forward call # inps: Tensor of shape (1, max_seq_len - 1) # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) result = [] if self._use_kv_cache: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) result = self._et_model.forward( (inps[:, : self._max_seq_length], pos_tensor) ) else: result = self._et_model.forward((inps,)) if result[0].dim() != 3: raise ValueError( f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." ) return result[0]class ETRunnerEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch Runtime integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], tokenizer_bin: str, max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model self._tokenizer_bin = tokenizer_bin def _model_call(self, inps): # Given inps (tokens), return the logits from a single # forward call # Example: # inps: Tensor of shape (1, N) # logits: Tensor of shape (1, N, vocab_size) passdef gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, llm_config=None,): """ Generates a wrapper interface around the provided model and tokenizer for the lm-evaluation-harness library. Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. """ # If llm_config is not provided, convert args to llm_config if llm_config is None: from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime return ETRunnerEvalWrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings return ETPybindEvalWrapper( model=model, tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. 
max_seq_length=llm_config.export.max_seq_length - 1, ) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( llm_config ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) model = ( manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore if torch.cuda.is_available() else manager.pre_autograd_graph_module.to(device="cpu") ) return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch # for quantizers. Currently export only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") if torch.cuda.is_available() else manager.model.eval().to(device="cpu") ) # Save the checkpoint after the eager model preparation is done. # The reason for this option is that the checkpoint can be used # to do evaluations in other evaluation platforms, or with data # that is not available in this eval_llama. We save the checkpoint # here for consistency with eval_llama. The accuracy results we # get from eval_llama can be used as a reference to other evaluations. if args.output_eager_checkpoint_file is not None: # pyre-ignore torch.save(model, args.output_eager_checkpoint_file) return EagerEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, )def build_args_parser() -> argparse.ArgumentParser: # Start with arg parser from export_llama_lib parser = _build_args_parser() # Add additional args specific to eval parser.add_argument( "--tasks", nargs="+", type=str, default=["wikitext"], help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( "--limit", type=int, default=None, help="number of samples to evalulate. If not set, evaluate all samples", ) parser.add_argument( "-f", "--num_fewshot", type=int, default=None, metavar="N", help="Number of examples in few-shot context", ) # Add additional args specific to eval via an ET Runner # Note: For initial integration, the tokenizer.model is also required parser.add_argument( "--pte", type=str, default=None, help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", ) parser.add_argument( "--tokenizer_bin", type=str, default=None, help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", ) parser.add_argument( "--output_eager_checkpoint_file", type=str, default=None, help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) # Set of parameters secpific to AttentionSink. parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) return parserdef eval_llama( model_name: str, args: argparse.ArgumentParser,) -> None: # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) # Generate the eval wrapper eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. 
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files if args.tasks and "mmlu" in args.tasks: import datasets datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True # Evaluate the model with torch.no_grad(): eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, num_fewshot=args.num_fewshot, limit=args.limit, ) for task, res in eval_results["results"].items(): print(f"{task}: {res}")def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): """ Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py Updated for the ring-buffer based attention sink implementation. """ # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) assert llm_config.model.use_attention_sink is not None assert args.attention_sink_eval_tokens > 0 attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) # For the ring buffer implementation, the cache size is sink_size + window_size * 2 # max_context_length should be >= sink_size + window_size (for RoPE frequencies) # but can be larger to support extended generation assert llm_config.export.max_context_length >= sink_size + window_size, ( f"max_context_length ({llm_config.export.max_context_length}) must be >= " f"sink_size + window_size ({sink_size + window_size})" ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") nlls = [] loss_fn = CrossEntropyLoss(reduction="none") progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue with torch.no_grad(): num_tokens = min( len(tokens) - 1, args.attention_sink_eval_tokens - input_pos ) logits = model( torch.tensor( [tokens[:num_tokens]], dtype=torch.int64, device=device ), torch.tensor([input_pos], dtype=torch.int64, device=device), ).squeeze(dim=0) neg_log_likelihood = loss_fn( logits, torch.tensor( [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device ).view(-1), ) nlls.append(neg_log_likelihood) input_pos += num_tokens progress_bar.update(num_tokens) if input_pos >= args.attention_sink_eval_tokens: break ppl = torch.exp(torch.cat(nlls).mean()) print(f"Perplexity: {ppl.item()}") return ppl.item() No newline at end of file
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in the comment: "secpific" should be "specific".

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings February 5, 2026 22:41
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.



Comment on lines +158 to +167

// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}
// If user specified seq_len, use that instead
if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
max_new_tokens = config.seq_len;
}
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TEMPORARY override logic (lines 159-167) should not be merged into production. This code allows "effectively unlimited" generation (max_new_tokens = 1000000), which could cause memory issues or runaway generation. If this is needed for testing, it should be guarded by a flag or removed before merging to production branches.

Suggested change
// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}
// If user specified seq_len, use that instead
if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
max_new_tokens = config.seq_len;
}
// If user specified seq_len, use that instead.
if (config.seq_len > 0) {
max_new_tokens = static_cast<int>(config.seq_len);
}

Copilot uses AI. Check for mistakes.
Comment on lines +11 to +12
# window_size=124: 滑动窗口大小
# eviction_batch_size=1: 每次驱逐 1 个 token
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chinese characters are used in comments. Comments should be in English for consistency with the rest of the codebase. Please translate:

  • Line 11: "# window_size=124: 滑动窗口大小" should be "# window_size=124: Sliding window size"
  • Line 12: "# eviction_batch_size=1: 每次驱逐 1 个 token" should be "# eviction_batch_size=1: Evict 1 token at a time"
Suggested change
# window_size=124: 滑动窗口大小
# eviction_batch_size=1: 每次驱逐 1 个 token
# window_size=124: Sliding window size
# eviction_batch_size=1: Evict 1 token at a time

Copilot uses AI. Check for mistakes.
Comment on lines +158 to +167

// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}
// If user specified seq_len, use that instead
if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
max_new_tokens = config.seq_len;
}
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TEMPORARY override logic has problematic behavior. The condition on line 165 checks if config.seq_len > max_new_tokens, but if max_new_tokens was just set to 1000000 on line 162, this condition will rarely be true unless seq_len is also extremely large. This means the seq_len override may not work as intended when max_new_tokens was initially <= 0.

Consider revising the logic to check config.seq_len before applying the "unlimited" override, or reconsider the order of these checks. For example:

  1. First apply seq_len override if config.seq_len > 0
  2. Then apply the unlimited fallback only if the result is still <= 0
Suggested change
// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}
// If user specified seq_len, use that instead
if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
max_new_tokens = config.seq_len;
}
// If user specified seq_len, use that instead
if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
max_new_tokens = config.seq_len;
}
// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}

Copilot uses AI. Check for mistakes.
Comment on lines 307 to 349
@@ -212,26 +311,42 @@ def attention_sink_forward(
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
"""
Forward function for attention with attention sink KV cache.
Uses ring buffer masking for proper attention patterns.
"""
assert self.use_kv_cache
assert input_pos is not None

bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Prepare for space in KV cache and get position shift
position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

# RoPE relative positional embeddings with shifted position in KV cache
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)
# Transpose for attention: [B, H, S, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Update KV cache
k, v = self.kv_cache.update(input_pos, k, v)

# Use ring buffer mask since we have is_ring_buffer=True
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)

# SDPA
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)

# Return tuple like original AttentionMHA.forward
return self.wo(output), None
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function attention_sink_forward is defined but never used. The comment on lines 403-404 states "Don't replace forward - let the original AttentionMHA.forward handle it", but the docstring on line 419 still says "Replace Attention's forward with attention_sink_forward". This is inconsistent and creates confusion.

Either remove the unused attention_sink_forward function (lines 307-349) if it's truly not needed, or use it if it's meant to replace the forward method. If keeping it for future use, add a comment explaining why it's currently unused.

Copilot uses AI. Check for mistakes.
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is outdated and inconsistent with the actual implementation. Line 419 states "Replace Attention's forward with attention_sink_forward", but the code on lines 403-404 explicitly does NOT replace the forward method. Update the docstring to accurately reflect what the function does.

Suggested change
- Replace Attention's forward with attention_sink_forward
- Configure Attention/SDPA to use attention masks compatible with the ring buffer and attention sink

Copilot uses AI. Check for mistakes.
Comment on lines +21 to +23
CachePositionsManager,
KVCache,
RingKVCache,
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two imports from executorch.examples.models.llama.attention are unused: CachePositionsManager and RingKVCache (lines 21 and 23). These should be removed to keep the imports clean.

Suggested change
CachePositionsManager,
KVCache,
RingKVCache,
KVCache,

Copilot uses AI. Check for mistakes.
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable k_out is not used.

Suggested change
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
self.kv_cache.update(input_pos, k_decode, v_decode)

Copilot uses AI. Check for mistakes.
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)

Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable v_out is not used.

Suggested change
# Verify that the returned k/v match the decode inputs
self.assertEqual(k_out.shape, k_decode.shape)
self.assertEqual(v_out.shape, v_decode.shape)
torch.testing.assert_close(k_out, k_decode)
torch.testing.assert_close(v_out, v_decode)

Copilot uses AI. Check for mistakes.

def test_cache_positions_consistency(self):
"""Test that cache positions remain consistent during generation."""
cache_size = 32
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable cache_size is not used.

Suggested change
cache_size = 32

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings February 5, 2026 23:20
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.



Comment on lines 402 to +405
There mainly three steps:
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring states "There mainly three steps" with a grammatical error. It should be "There are mainly three steps". Additionally, step 3 mentions "Replace Attention's forward with attention_sink_forward", but the implementation at lines 389-390 explicitly does NOT replace the forward method. The docstring should be updated to accurately reflect what the function does.

Copilot uses AI. Check for mistakes.
Comment on lines 154 to 157
# Use zeros like original CachePositionsManager
self.register_buffer(
"cache_positions",
torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"),
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CachePositionsManagerWithSink initializes cache_positions to all zeros, which differs from the original CachePositionsManager that uses sentinel values (-1) for invalid/unfilled positions (see attention.py lines 219-225). This means before the cache is fully populated, positions that haven't been written will have a value of 0, which could be confused with actual position 0. The mask creation function checks cache_positions >= 0 to determine validity, so unfilled positions with value 0 will be incorrectly treated as valid position 0. This should use -1 for unfilled positions, or implement the same sentinel logic as the base CachePositionsManager.

Suggested change
# Use zeros like original CachePositionsManager
self.register_buffer(
"cache_positions",
torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"),
# Use -1 as sentinel for unfilled positions, matching CachePositionsManager
self.register_buffer(
"cache_positions",
torch.full(
(self.max_context_length,),
-1,
dtype=torch.long,
device="cpu",
),

Copilot uses AI. Check for mistakes.
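To make the zero-vs-sentinel concern concrete, here is a small self-contained sketch (not the ExecuTorch code; the `valid_slots` helper and the sizes are made up) showing why zero-initialized bookkeeping makes every slot look like a written position 0 under a `>= 0` validity check, while a -1 sentinel keeps unfilled slots distinguishable:

```python
import torch

def valid_slots(cache_positions: torch.Tensor) -> torch.Tensor:
    # A slot is considered filled iff its stored position is >= 0,
    # mirroring the validity check described in the comment above.
    return cache_positions >= 0

cache_len = 8

# Zero-initialized bookkeeping: every slot already looks like "position 0".
zero_init = torch.zeros(cache_len, dtype=torch.long)
print(valid_slots(zero_init))  # all True, even though nothing has been written

# Sentinel-initialized bookkeeping: unfilled slots are clearly invalid.
sentinel_init = torch.full((cache_len,), -1, dtype=torch.long)
sentinel_init[:3] = torch.arange(3)  # only positions 0..2 have been written
print(valid_slots(sentinel_init))  # True, True, True, then False for unfilled slots
```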
Comment on lines +160 to +193
def calculate_positions_and_update_indices(
self, input_pos: torch.Tensor, seq_len: int
) -> torch.Tensor:
"""
Calculate indices into k_cache, v_cache for placing k_val, v_val.

Index calculation:
- Position < sink_size: index = position (sink tokens at fixed indices)
- Position >= sink_size: index = sink_size + (position - sink_size) % ring_size

This ensures sink tokens (indices 0 to sink_size-1) are NEVER overwritten.
"""
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)

# Original positions for the sequence
orig_positions = torch.arange(seq_len, dtype=torch.long) + start_pos

if self.sink_size == 0:
# Simple ring buffer: just mod by cache size
indices = orig_positions % self.max_context_length
else:
# Shifted ring buffer: sink tokens at fixed indices, rest in ring buffer
# For position >= sink_size: index = sink_size + (position - sink_size) % ring_size
shifted = orig_positions - self.sink_size
ring_indices = self.sink_size + (shifted % self.ring_size)
# For position < sink_size: use position directly
indices = torch.where(orig_positions < self.sink_size, orig_positions, ring_indices)

# Update cache_positions to track what position is at each index
# Only update the indices we're writing to
self.cache_positions.index_copy_(0, indices, orig_positions)

return indices
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculate_positions_and_update_indices method doesn't implement the sentinel value logic that the original CachePositionsManager has (see attention.py lines 219-225). Before the cache wraps around, positions that haven't been written yet should be marked as invalid with -1, not left as 0. The current implementation doesn't clear previously unfilled positions to -1 when updating, which can lead to stale or incorrect values in cache_positions being treated as valid. This should implement the same sentinel logic as the base class where positions that haven't been written yet are set to -1.

Copilot uses AI. Check for mistakes.
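As a quick sanity check of the index formula quoted in the docstring above, a standalone sketch (the `cache_index` helper is hypothetical, and sink_size=4 / ring_size=8 are made-up values) maps logical positions to cache slots:

```python
import torch

sink_size = 4  # assumed value for illustration
ring_size = 8  # assumed number of sliding-window slots after the sink region

def cache_index(position: torch.Tensor) -> torch.Tensor:
    # Sink tokens keep their own index; every later position wraps inside the ring region.
    ring_index = sink_size + (position - sink_size) % ring_size
    return torch.where(position < sink_size, position, ring_index)

positions = torch.arange(0, 20)
print(list(zip(positions.tolist(), cache_index(positions).tolist())))
# Positions 0-3 stay at indices 0-3 and are never overwritten;
# positions 4 and 12 both land on index 4, i.e. the ring wraps every 8 tokens.
```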
Comment on lines +238 to 255
Removed (previous eviction-based implementation):

        input_pos_item = input_pos.item()
        torch._check_is_size(input_pos_item)
        if input_pos_item + self.position_shift + seq_len > self.max_context_length:
            # There are not enough spaces in the cache to store the new tokens.
            # We need to evict some old tokens and shift some recent tokens.
            num_to_evict = max(
                input_pos_item
                + self.position_shift
                - self.max_context_length
                + seq_len,
                self.eviction_batch_size,
            )
            num_to_keep = (
                input_pos_item + self.position_shift - self.sink_size - num_to_evict
            )
            num_empty_space = self.window_size - num_to_keep
            dim_to_slice = 2
            k_to_keep = self.k_cache.narrow(
                dim_to_slice,
                self.sink_size + num_to_evict,  # pyre-ignore [6]
                num_to_keep,  # pyre-ignore [6]
            )
            k_to_keep = self.rope.rerotate_k(
                k=k_to_keep.transpose(1, 2),
                original_position=(self.sink_size + num_to_evict),  # pyre-ignore [6]
                new_position=self.sink_size,
            ).transpose(1, 2)
            self.k_cache = torch.cat(
                [
                    self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
                    k_to_keep,
                    torch.zeros_like(
                        self.k_cache.narrow(
                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
                        )
                    ),
                ],
                dim=dim_to_slice,
            )
            self.v_cache = torch.cat(
                [
                    self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
                    self.v_cache.narrow(
                        dim_to_slice,
                        self.sink_size + num_to_evict,  # pyre-ignore [6]
                        num_to_keep,  # pyre-ignore [6]
                    ),
                    torch.zeros_like(
                        self.v_cache.narrow(
                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
                        )
                    ),
                ],
                dim=dim_to_slice,
            )

Added (new ring-buffer mask implementation):

    def create_causal_mask_for_ring_buffer(
        self, start_pos: torch.Tensor, seq_len: int
    ):
        """
        Create causal mask for the attention with attention sink.
        Sink tokens are ALWAYS visible, plus recent tokens in the window.
        """
        cache_positions = self.cache_positions_manager.cache_positions
        if self.sink_size > 0:
            # Use attention sink mask that always allows attending to sink tokens
            return _create_causal_mask_for_attention_sink(
                cache_positions, self.window_size, self.sink_size, start_pos, seq_len
            )
        else:
            # Pure ring buffer mode - use original mask with window_size = actual window
            return _create_causal_mask_for_ring_buffer(
                cache_positions, self.window_size, start_pos, seq_len
            )
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type mismatch in create_causal_mask_for_ring_buffer: the parameter start_pos is declared as torch.Tensor (line 239), but it's passed directly to _create_causal_mask_for_attention_sink (line 249) which expects an int according to its signature (line 103). This will cause a type error. The start_pos should either be extracted to an int with .item() before passing, or the function signature should be updated to accept torch.Tensor consistently.

Copilot uses AI. Check for mistakes.
Comment on lines 293 to 335
@@ -212,26 +297,42 @@ def attention_sink_forward(
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
"""
Forward function for attention with attention sink KV cache.
Uses ring buffer masking for proper attention patterns.
"""
assert self.use_kv_cache
assert input_pos is not None

bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Prepare for space in KV cache and get position shift
position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

# RoPE relative positional embeddings with shifted position in KV cache
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)
# Transpose for attention: [B, H, S, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Update KV cache
k, v = self.kv_cache.update(input_pos, k, v)

# Use ring buffer mask since we have is_ring_buffer=True
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)

# SDPA
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)

# Return tuple like original AttentionMHA.forward
return self.wo(output), None
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function attention_sink_forward is defined but appears to be unused based on the comment at lines 389-390 which states "Don't replace forward". This creates dead code in the codebase. Either this function should be used to replace the forward method, or it should be removed if it's no longer needed. The docstring at line 405 also still mentions "Replace Attention's forward with attention_sink_forward", which is inconsistent with the actual implementation.

Copilot uses AI. Check for mistakes.
Comment on lines +277 to +280
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
self.k_cache.index_copy_(2, indices, k_val)
self.v_cache.index_copy_(2, indices, v_val)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The update method in KVCacheWithAttentionSink always uses index_copy_ without checking the enable_dynamic_shape flag that was passed during initialization. The base KVCache class (in attention.py lines 88-118) has different behavior depending on this flag. For consistency with the parent class and to support both dynamic and static shape modes, this method should check self.enable_dynamic_shape and implement both code paths like the base class does.

Suggested change
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
self.k_cache.index_copy_(2, indices, k_val)
self.v_cache.index_copy_(2, indices, v_val)
if self.enable_dynamic_shape:
# Dynamic-shape-friendly bulk update using index_copy_
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
self.k_cache.index_copy_(2, indices, k_val)
self.v_cache.index_copy_(2, indices, v_val)
else:
# Static-shape-friendly path: per-position assignment using the same indices.
# This mirrors the base KVCache behavior of avoiding dynamic index_copy_
# in static-shape mode while preserving ring buffer semantics.
for i in range(seq_len):
cache_idx = indices[i].item()
torch._check_is_size(cache_idx)
# k_cache/v_cache shape is [batch, heads, context, head_dim]
# k_val/v_val shape is [batch, heads, seq_len, head_dim]
self.k_cache[:, :, cache_idx, :] = k_val[:, :, i, :]
self.v_cache[:, :, cache_idx, :] = v_val[:, :, i, :]

Copilot uses AI. Check for mistakes.
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable k_out is not used.

Suggested change
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
self.kv_cache.update(input_pos, k_decode, v_decode)

Copilot uses AI. Check for mistakes.
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable v_out is not used.

Suggested change
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
self.kv_cache.update(input_pos, k_decode, v_decode)

Copilot uses AI. Check for mistakes.

def test_cache_positions_consistency(self):
"""Test that cache positions remain consistent during generation."""
cache_size = 32
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable cache_size is not used.

Suggested change
cache_size = 32

Copilot uses AI. Check for mistakes.
Comment on lines +21 to +23
CachePositionsManager,
KVCache,
RingKVCache,
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'CachePositionsManager' is not used.
Import of 'RingKVCache' is not used.

Suggested change
CachePositionsManager,
KVCache,
RingKVCache,
KVCache,

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings February 6, 2026 00:06
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.



Please also take a look at examples/models/llama/runner to make sure it can emit many tokens, exceeding context size.

Please check whether the output makes sense or not
```
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing opening backtick for the code block. The code block should start with triple backticks before the shell command.

Copilot uses AI. Check for mistakes.
Comment on lines +158 to +163

// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TEMPORARY max_new_tokens override with a hardcoded value of 1000000 is concerning for production code. This effectively removes any generation length limits and could cause memory issues or runaway generation. Consider either removing this temporary code if testing is complete, or adding proper configuration for unlimited generation with appropriate safeguards.

Suggested change
// TEMPORARY: For sliding window / infinite context testing,
// override max_new_tokens to allow unlimited generation
if (max_new_tokens <= 0) {
max_new_tokens = 1000000; // Effectively unlimited
}

Copilot uses AI. Check for mistakes.
Comment on lines 273 to 284
fbcode_target(_kind = runtime.python_test,
name = "attention_sink_test",
srcs = [
"source_transformation/test_attention_sink.py",
],
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
":export_library",
],
)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BUCK file still references "source_transformation/test_attention_sink.py" at line 276, but this file is being deleted in this PR. This will cause build failures. The test target "attention_sink_test" should either be removed or updated to reference the correct test file.

Copilot uses AI. Check for mistakes.
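If the target is to be kept rather than removed, one possible fix (a sketch only; the new file name comes from this PR's changed-file list, and the existing deps are assumed to carry over unchanged) is to point it at the ring-buffer test:

```
fbcode_target(_kind = runtime.python_test,
    name = "attention_sink_test",
    srcs = [
        "source_transformation/test_attention_sink_ring_buffer.py",
    ],
    supports_static_listing = False,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        "//caffe2:torch",
        ":export_library",
    ],
)
```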
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment mentions that forward should be replaced with attention_sink_forward (line 410), but the actual code at line 394 explicitly states "Don't replace forward". This is a discrepancy between the docstring and the implementation. The docstring should be updated to reflect that the forward method is NOT being replaced.

Suggested change
- Replace Attention's forward with attention_sink_forward
- Keep Attention's original forward; KVCacheWithAttentionSink and masks provide Attention Sink behavior.

Copilot uses AI. Check for mistakes.
Comment on lines 298 to 340
@@ -212,26 +302,42 @@ def attention_sink_forward(
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
"""
Forward function for attention with attention sink KV cache.
Uses ring buffer masking for proper attention patterns.
"""
assert self.use_kv_cache
assert input_pos is not None

bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Prepare for space in KV cache and get position shift
position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

# RoPE relative positional embeddings with shifted position in KV cache
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)
# Transpose for attention: [B, H, S, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Update KV cache
k, v = self.kv_cache.update(input_pos, k, v)

# Use ring buffer mask since we have is_ring_buffer=True
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)

# SDPA
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)

# Return tuple like original AttentionMHA.forward
return self.wo(output), None
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function attention_sink_forward is defined but never used. According to the comment at line 394, the forward method is not being replaced, and instead the original AttentionMHA.forward handles the ring buffer case. This unused function adds confusion and maintenance burden. Consider either removing this function or adding a comment explaining why it's kept (e.g., for potential future use or documentation purposes).

Copilot uses AI. Check for mistakes.
@@ -293,7 +406,8 @@ def enable_attention_sink(
Transform the model to be able to run inference with Attention Sink.
There mainly three steps:
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: "There mainly three steps" should be "There are mainly three steps".

Suggested change
There mainly three steps:
There are mainly three steps:

Copilot uses AI. Check for mistakes.
+base.model_class="llama3_2" \
+base.checkpoint="consolidated.00.pth" \
+base.params="params.json"
```
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing closing backtick for the code block. The code block starting at line 23 is not properly closed before this line, which will cause the entire section to be rendered incorrectly.

Copilot uses AI. Check for mistakes.