diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK
index 9d9897a2819..57ed846dd86 100644
--- a/examples/models/llama/BUCK
+++ b/examples/models/llama/BUCK
@@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test,
     ],
 )
 
+fbcode_target(_kind = runtime.python_test,
+    name = "attention_sink_ring_buffer_test",
+    srcs = [
+        "source_transformation/test_attention_sink_ring_buffer.py",
+    ],
+    supports_static_listing = False,
+    deps = [
+        "//caffe2:torch",
+        ":export_library",
+    ],
+)
+
 fbcode_target(_kind = runtime.python_test,
     name = "quantized_sdpa_source_transform_test",
     srcs = [
diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml
new file mode 100644
index 00000000000..1d859035d74
--- /dev/null
+++ b/examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
+base:
+  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+  use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
+  use_kv_cache: True
+  dtype_override: fp32
+  enable_dynamic_shape: True
+  # Attention Sink: "sink_size,window_size,eviction_batch_size"
+  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
+  # window_size=124: size of the sliding window
+  # eviction_batch_size=1: evict 1 token at a time
+  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
+  use_attention_sink: "4,124,1"
+
+export:
+  # max_context_length controls the RoPE frequency table size.
+  # It must be >= sink_size + window_size (128), but larger values are
+  # recommended to support generation beyond the sliding window.
+  # The model default (e.g., 8192 or 131072) is typically used if not specified.
+  # For testing, we use the model's default by not setting this explicitly.
+
+quantization:
+  qmode: 8da4w
+  group_size: 128
+  embedding_quantize: 4,32
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 03f8f5cd759..e13a5299e61 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     Evaluate the model's perplexity when AttentionSink is enabled.
 
     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+
+    Updated for the ring-buffer based attention sink implementation.
""" # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig @@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert llm_config.export.max_seq_length == sink_size + window_size + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 1ec85936f7a..3be00d78711 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -218,7 +218,22 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.llm_config.export.max_context_length == sink_size + window_size + # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies + # A larger max_context_length is allowed (and recommended) to support generation beyond + # the sliding window size. + assert self.llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({self.llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) + + # IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation + # positions, not just the cache size. Override the model's max_context_len to use + # a larger value that supports extended generation. + # We use model_args.max_context_len which was set from export.max_context_length + # but for RoPE we need the full generation length capability. + # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger. + default_rope_length = max(131072, model_args.max_context_len) + model_args.max_context_len = default_rope_length self.model_ = enable_attention_sink( module=self.model_, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 22bd8a3e228..3a69d85e61b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -7,12 +7,21 @@ # Components for supporting Attention Sink. See # https://arxiv.org/abs/2309.17453 for more details about Attention Sink. +# This implementation is torch.export compatible using a ring buffer approach +# for the sliding window portion while preserving the sink tokens. 
+ import types -from typing import Optional +from typing import Optional, Tuple import torch - -from executorch.examples.models.llama.attention import AttentionMHA, KVCache +import torch.nn as nn +from executorch.examples.models.llama.attention import ( + _create_causal_mask_for_ring_buffer, + AttentionMHA, + CachePositionsManager, + KVCache, + RingKVCache, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, @@ -27,6 +36,13 @@ class RopeWithAttentionSink(Rope): Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. + + For torch.export compatibility, this just passes through the position - the + actual position adjustment is handled by the cache update logic. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( @@ -41,28 +57,22 @@ def __init__( self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - self.max_context_length = window_size + sink_size - assert self.max_context_length == self.params.max_context_len + # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) + self.kv_cache_size = sink_size + window_size * 2 + self.window_size = window_size + self.sink_size = sink_size + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get rotary embedding frequencies. + For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. + """ assert input_pos is not None - - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return super().get_freqs(input_pos + self.position_shift, seq_len) + return super().get_freqs(input_pos, seq_len) def rerotate_k( self, @@ -71,15 +81,8 @@ def rerotate_k( new_position: int, ): """ - Rerotate k from original_position to new_position. This is done by rerotating - k with (new_position * theta - original_position * theta) with the following matrix: - (cos(delta), -sin(delta) - sin(delta), cos(delta)) - where delta = new_position * theta - original_position * theta - - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - - Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + Rerotate k from original_position to new_position. 
+ The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) @@ -96,15 +99,113 @@ def rerotate_k( return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + # Use <= to include the boundary token. For window_size=124, we want to attend + # to the last 124 tokens BEFORE the current position (delta 1 to 124), plus + # position 4 (first non-sink token) which has delta exactly = window_size. + # This ensures sink_size + window_size tokens are visible when cache is full. + is_in_window = delta <= window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + + For sink_size=0: behaves exactly like original CachePositionsManager (simple ring buffer). + For sink_size>0: sink tokens (indices 0 to sink_size-1) are NEVER overwritten. + Ring buffer only cycles through indices sink_size to cache_size-1. + + IMPORTANT: cache_size should be the actual cache dimension size (sink_size + 2*window_size). + """ + + def __init__(self, cache_size: int, sink_size: int = 0): + super().__init__() + self.max_context_length = cache_size + self.sink_size = sink_size + # Ring buffer size = cache_size - sink_size + self.ring_size = cache_size - sink_size + # Initialize to -1 to mark unwritten positions + # The mask uses (cache_positions >= 0) to check if a position is valid + self.register_buffer( + "cache_positions", + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + ) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + Index calculation: + - Position < sink_size: index = position (sink tokens at fixed indices) + - Position >= sink_size: index = sink_size + (position - sink_size) % ring_size + + This ensures sink tokens (indices 0 to sink_size-1) are NEVER overwritten. 
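+
+        Example (numbers taken from the unit tests, purely illustrative):
+        with sink_size=4 and cache_size=32, ring_size = 28, so position 3
+        maps to index 3 (sink), position 30 maps to 4 + (30 - 4) % 28 = 30,
+        and position 32 wraps to 4 + (32 - 4) % 28 = 4 instead of
+        overwriting the sink slots at indices 0-3.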
+ """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + # Original positions for the sequence + orig_positions = torch.arange(seq_len, dtype=torch.long) + start_pos + + if self.sink_size == 0: + # Simple ring buffer: just mod by cache size + indices = orig_positions % self.max_context_length + else: + # Shifted ring buffer: sink tokens at fixed indices, rest in ring buffer + # For position >= sink_size: index = sink_size + (position - sink_size) % ring_size + shifted = orig_positions - self.sink_size + ring_indices = self.sink_size + (shifted % self.ring_size) + # For position < sink_size: use position directly + indices = torch.where(orig_positions < self.sink_size, orig_positions, ring_indices) + + # Update cache_positions to track what position is at each index + # Only update the indices we're writing to + self.cache_positions.index_copy_(0, indices, orig_positions) + + return indices + + class KVCacheWithAttentionSink(KVCache): """ - KV cache that supports attention sink. It keeps the initial few tokens as attention sink. - For other tokens, it uses a sliding window to keep the most recent tokens. + KV cache that supports attention sink with torch.export compatibility. + + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. - Parameters: - window_size: the size of the sliding window - sink_size: the number of initial tokens to keep as attention sink - eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( @@ -119,9 +220,11 @@ def __init__( max_batch_size: int = 1, dtype=torch.float32, ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, - max_context_length=window_size + sink_size, + max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, @@ -131,78 +234,65 @@ def __init__( self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.is_ring_buffer = True - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: - """ - Evict old tokens from the cache to make rooms for new tokens. - - Parameters: - input_pos: the start position of the incoming token in the actual sequence - seq_len: the length of the incoming sequence - rope: the rope object to use for rerotating k + # Cache positions manager for determining write locations + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size, sink_size) - Returns: - the number of tokens to evict from the cache which is also the number of - positions to shift for incoming tokens + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): """ - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. 
- num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - num_to_keep = ( - input_pos_item + self.position_shift - self.sink_size - num_to_evict - ) - num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 - k_to_keep = self.k_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ) - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] - new_position=self.sink_size, - ).transpose(1, 2) - self.k_cache = torch.cat( - [ - self.k_cache.narrow(dim_to_slice, 0, self.sink_size), - k_to_keep, - torch.zeros_like( - self.k_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. + """ + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) - self.v_cache = torch.cat( - [ - self.v_cache.narrow(dim_to_slice, 0, self.sink_size), - self.v_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ), - torch.zeros_like( - self.v_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return self.position_shift + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. + """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + + return self.k_cache, self.v_cache + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 def attention_sink_forward( @@ -212,6 +302,10 @@ def attention_sink_forward( freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None, ): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. 
+    """
     assert self.use_kv_cache
     assert input_pos is not None
 
@@ -219,19 +313,31 @@
     # QKV
     q, k, v = self.wq(x), self.wk(x), self.wv(x)
-    # We need view_copy elimination
     q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
     k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
     v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
-    # Prepare for space in KV cache and get position shift
-    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
-
-    # RoPE relative positional embeddings with shifted position in KV cache
+    # RoPE relative positional embeddings
    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
-    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
-    return self.wo(output)
+    # Transpose for attention: [B, H, S, D]
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+
+    # Update KV cache
+    k, v = self.kv_cache.update(input_pos, k, v)
+
+    # Use ring buffer mask since we have is_ring_buffer=True
+    start_pos = input_pos[0].item()
+    torch._check_is_size(start_pos)
+    attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)
+
+    # SDPA
+    output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
+
+    # Return a tuple like the original AttentionMHA.forward
+    return self.wo(output), None
 
 
 def _replace_rope(
@@ -265,6 +371,8 @@ def _replace_attention(
 
         if isinstance(child_module, AttentionMHA):
             kv_cache = child_module.kv_cache
+            # Always use KVCacheWithAttentionSink, even for sink_size=0.
+            # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True.
             kv_cache_with_attention_sink = KVCacheWithAttentionSink(
                 n_heads=kv_cache.n_heads,
                 head_dim=kv_cache.head_dim,
@@ -277,9 +385,14 @@
                 dtype=kv_cache.k_cache.dtype,
             )
             child_module.kv_cache = kv_cache_with_attention_sink
-            child_module.forward = types.MethodType(  # pyre-ignore
-                attention_sink_forward, child_module
-            )
+
+            # If using SDPACustom (fused SDPA op), enable attention mask support
+            # so it uses our ring buffer / attention sink mask instead of a simple causal mask
+            if "SDPACustom" in child_module.SDPA.__class__.__name__:
+                child_module.SDPA.use_attention_mask = True
+
+            # Don't replace forward - let the original AttentionMHA.forward handle it;
+            # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask
 
 
 def enable_attention_sink(
@@ -293,7 +406,8 @@
     Transform the model to be able to run inference with Attention Sink. There mainly three steps:
     - Replace Rope with RopeWithAttentionSink
-    - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
+    - Replace Attention's KVCache with KVCacheWithAttentionSink
+    - Enable use_attention_mask on SDPACustom so the original AttentionMHA.forward applies the attention sink mask
     """
     rope_with_attention_sink = RopeWithAttentionSink(
         params=params,
diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 8d4d37e0e93..f8a268183b5 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -372,7 +372,16 @@ def replace_kv_cache_with_custom_kv_cache(module):
 
 def _replace_kv_cache_with_custom_kv_cache(module):
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        # Skip KVCacheWithAttentionSink as it has special evict_tokens logic
+        # that is not compatible with CustomKVCache.
+ # Check by class name because the class might come from different module paths + # (e.g., 'examples.models...' vs 'executorch.examples.models...') + child_class_name = type(child).__name__ + if child_class_name == "KVCacheWithAttentionSink": + logging.info(f"Skipping KVCacheWithAttentionSink at {name}") + _replace_kv_cache_with_custom_kv_cache(child) + elif isinstance(child, KVCache): + logging.info(f"Replacing KVCache at {name} (type={child_class_name})") cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, max_context_length, head_dim = cache_shape diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - 
k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - 
zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - 
self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - 
k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py new file mode 100644 index 00000000000..953033f5cd8 --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the ring-buffer based attention sink implementation. + +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + # Default: no sink (simple ring buffer) + self.manager = CachePositionsManagerWithSink(self.cache_size, sink_size=0) + + def test_initial_positions_are_minus_one(self): + """Cache positions should start as -1 (unwritten).""" + expected = torch.full((self.cache_size,), -1, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound_no_sink(self): + """Test ring buffer wraparound with sink_size=0.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_wraparound_with_sink(self): + """Test ring buffer 
wraparound with sink_size > 0.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 + input_pos = torch.tensor([30], dtype=torch.long) + indices = manager.calculate_positions_and_update_indices(input_pos, 5) + + # Ring size = 32 - 4 = 28 + # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 + # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4, not 0!) + # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 + # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 + expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions_no_sink(self): + """Cache positions should track which original position is at each index (no sink).""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + def test_cache_positions_track_original_positions_with_sink(self): + """Cache positions should track positions, and sink tokens are never overwritten.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 32) + + # Indices 0-3 should have pos 0-3 (Sink tokens) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) + + # Now add position 32. + # (32-4)%28 = 0. So index = 4 + 0 = 4. 
+ input_pos = torch.tensor([32], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 4 should now contain original position 32 + self.assertEqual(manager.cache_positions[4].item(), 32) + + # Index 0-3 (sink) should STILL contain positions 0-3 (unchanged) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = 
self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask 
with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + 
seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size 
= 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + non_inf_count = (mask != float('-inf')).sum().item() + self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") + + # For positions >= sink_size, sinks should always be visible + if pos >= sink_size: + for i in range(sink_size): + cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() + if cache_pos < sink_size: # This is actually a sink token + self.assertEqual(mask[0, i].item(), 0.0, + f"Sink at idx {i} should be visible at pos {pos}") + + +if __name__ == '__main__': + unittest.main() diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 373f57d3a8e..11d77f321bf 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -129,24 +129,42 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - // Reduce max_context_len by pos_ - int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_; + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); + + // For models with sliding window (Ring Buffer / Attention Sink), + // we allow pos_ to exceed max_context_len. The model handles this + // internally via ring buffer indexing or token eviction. + // We only check that a single prefill chunk doesn't exceed max_seq_len. ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens < max_context_len, + num_prompt_tokens <= max_seq_len, InvalidArgument, - "num_prompt_tokens %d >= max_context_len %" PRId64 - ", Max seq length exceeded - please increase max seq len value in your export script", + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, - max_context_len); + max_seq_len); - // Determine max_new_tokens using the GenerationConfig's resolve method, - // then subtract pos_ for max_new_tokens. + // Determine max_new_tokens using the GenerationConfig's resolve method. + // For sliding window models, we use max_context_len directly (not reduced by pos_) + // because the model handles position wrapping internally via ring buffer. 
int max_new_tokens =
      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+
+  // TEMPORARY: For sliding window / infinite context testing,
+  // override max_new_tokens to allow unlimited generation
+  if (max_new_tokens <= 0) {
+    max_new_tokens = 1000000; // Effectively unlimited
+  }
+  // If the user specified seq_len, use that instead
+  if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
+    max_new_tokens = config.seq_len;
+  }
 
   ET_LOG(
       Info,
diff --git a/test_attention_sink.md b/test_attention_sink.md
new file mode 100644
index 00000000000..0e70723cc20
--- /dev/null
+++ b/test_attention_sink.md
@@ -0,0 +1,37 @@
+# Build the runtime
+
+Do this after modifying the runtime (C++) code:
+```sh
+cmake --workflow llm-debug
+pushd examples/models/llama
+cmake --workflow --preset llama-debug
+popd
+```
+
+# Export model
+Take a look at examples/models/llama/README.md.
+
+The checkpoint is in ~/executorch/.
+
+Make sure you are in the conda executorch environment.
+
+# No quantization
+```sh
+# Set these paths to point to the downloaded files
+LLAMA_CHECKPOINT=path/to/consolidated.00.pth
+LLAMA_PARAMS=path/to/params.json
+
+python -m extension.llm.export.export_llm \
+    --config examples/models/llama/config/llama_bf16.yaml \
+    +base.model_class="llama3_2" \
+    +base.checkpoint="consolidated.00.pth" \
+    +base.params="params.json"
+```
+
+# Run
+
+Please also take a look at examples/models/llama/runner to make sure it can emit many tokens, exceeding the context size.
+
+Please check whether the output makes sense:
+```
+cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt=
+```
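+
+# Export with attention sink (sketch)
+
+The commands below are a sketch rather than a verified recipe: they assume the
+`llama_attention_sink.yaml` config added in this change and the same checkpoint
+and params files as above.
+
+```sh
+python -m extension.llm.export.export_llm \
+    --config examples/models/llama/config/llama_attention_sink.yaml \
+    +base.model_class="llama3_2" \
+    +base.checkpoint="consolidated.00.pth" \
+    +base.params="params.json"
+```
+
+# Run the Python unit tests
+
+```sh
+python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v
+```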