From c954f27e543f57d1bdd547b28d9de0872d0effaf Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Feb 2026 01:41:34 -0800
Subject: [PATCH 01/17] attention sink

---
 .../llama/config/llama_attention_sink.yaml    |  31 ++
 .../source_transformation/attention_sink.py   | 316 ++++++++++++------
 .../source_transformation/custom_kv_cache.py  |  11 +-
 3 files changed, 247 insertions(+), 111 deletions(-)
 create mode 100644 examples/models/llama/config/llama_attention_sink.yaml

diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml
new file mode 100644
index 00000000000..8d3fc01b8f7
--- /dev/null
+++ b/examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
+base:
+  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+  use_sdpa_with_kv_cache: False # attention_sink requires the standard SDPA path
+  use_kv_cache: True
+  dtype_override: fp32
+  enable_dynamic_shape: True
+  # Attention Sink: "sink_size,window_size,eviction_batch_size"
+  # sink_size=4: keep the first 4 tokens (BOS + system prompt)
+  # window_size=126: sliding window size
+  # eviction_batch_size=1: evict 1 token at a time
+  # Effective cache size = sink_size + window_size = 4 + 126 = 130,
+  # but the ring buffer needs 2x the window, so the allocation is 4 + 126*2 = 256
+  use_attention_sink: "4,126,1"
+
+export:
+  # max_context_length = sink_size + window_size = 4 + 126 = 130,
+  # but the ring buffer internally allocates a cache of sink_size + window_size * 2 = 256
+  max_context_length: 130
+  max_seq_length: 130
+
+quantization:
+  qmode: 8da4w
+  group_size: 128
+  embedding_quantize: 4,32
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 22bd8a3e228..126bbcf919c 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,12 +7,19 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

+# This implementation is torch.export compatible using a ring buffer approach
+# for the sliding window portion while preserving the sink tokens.
+
 import types
-from typing import Optional
+from typing import Optional, Tuple

 import torch
-
-from executorch.examples.models.llama.attention import AttentionMHA, KVCache
+import torch.nn as nn
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    AttentionMHA,
+    KVCache,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
@@ -27,6 +34,9 @@ class RopeWithAttentionSink(Rope):
     Rope that helps adjust position encoding when tokens are shifted in KVCache.
     For AttentionSink, when tokens are shifted in KVCache, we need to use
     positions in KVCache instead of positions in the actual text.
+
+    For torch.export compatibility, this just passes through the position - the
+    actual position adjustment is handled by the cache update logic.
     """

     def __init__(
@@ -42,27 +52,19 @@ def __init__(
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
         self.max_context_length = window_size + sink_size
+        self.window_size = window_size
+        self.sink_size = sink_size
         assert self.max_context_length == self.params.max_context_len
         self.eviction_batch_size = eviction_batch_size
-        self.position_shift = 0

     def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get rotary embedding frequencies.
+ For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. + """ assert input_pos is not None - - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return super().get_freqs(input_pos + self.position_shift, seq_len) + return super().get_freqs(input_pos, seq_len) def rerotate_k( self, @@ -71,15 +73,8 @@ def rerotate_k( new_position: int, ): """ - Rerotate k from original_position to new_position. This is done by rerotating - k with (new_position * theta - original_position * theta) with the following matrix: - (cos(delta), -sin(delta) - sin(delta), cos(delta)) - where delta = new_position * theta - original_position * theta - - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - - Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + Rerotate k from original_position to new_position. + The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) @@ -96,15 +91,114 @@ def rerotate_k( return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + is_in_window = delta < window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + Similar to CachePositionsManager but handles sink tokens separately. 
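As a quick illustration of the masking rule implemented by `_create_causal_mask_for_attention_sink` above, here is a minimal standalone sketch. It re-implements the same arithmetic so it runs without an ExecuTorch checkout; `toy_sink_mask` and the toy sizes are illustrative, not part of the patch.

```python
import torch

def toy_sink_mask(cache_positions, window_size, sink_size, start_pos, seq_len):
    # Same rule as _create_causal_mask_for_attention_sink: a slot is visible if it is
    # filled, causal, and either a sink token or inside the sliding window.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    is_valid = (cache_positions >= 0) & (delta >= 0)
    is_sink = cache_positions < sink_size
    is_in_window = delta < window_size
    return torch.where(is_valid & (is_sink | is_in_window), 0.0, float("-inf"))

# 12 cache slots holding original positions 0..11; one query token at position 11.
mask = toy_sink_mask(torch.arange(12), window_size=4, sink_size=2, start_pos=11, seq_len=1)
# Sinks (positions 0-1) and the window (positions 8-11) get 0, everything else -inf.
print(mask)
```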
+ + Layout: [sink_tokens (fixed)] [ring_buffer_window (rotating)] + + For sink_size=4 and window_size=8: + - Positions 0-3 in the sequence go to cache indices 0-3 (fixed) + - Positions 4+ go to cache indices 4-19 using ring buffer (window_size * 2) + """ + + def __init__(self, window_size: int, sink_size: int): + super().__init__() + # Total cache size is sink + window * 2 (ring buffer needs 2x for proper masking) + self.max_context_length = sink_size + window_size * 2 + self.sink_size = sink_size + self.window_size = window_size + self.register_buffer( + "cache_positions", + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + ) + # Initialize sink positions (these are fixed) + if sink_size > 0: + self.cache_positions[:sink_size] = torch.arange(sink_size) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + For positions < sink_size: index = position (fixed) + For positions >= sink_size: index = sink_size + (pos - sink_size) % (window_size * 2) + """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Calculate cache indices based on whether position is sink or window + sink_part = torch.minimum(orig_indices, torch.tensor(self.sink_size)) + window_part = torch.maximum( + orig_indices - self.sink_size, torch.tensor(0) + ) % (self.window_size * 2) + is_sink = orig_indices < self.sink_size + indices = torch.where(is_sink, sink_part, self.sink_size + window_part) + + # Update cache_positions: clear old positions and set new ones + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + # Keep sink positions (0 to sink_size-1) and clear window positions that will be overwritten + cache_positions = torch.where( + arange_tensor < self.sink_size, self.cache_positions, full_t + ) + # For non-sink positions, check if they should be cleared + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, cache_positions + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + class KVCacheWithAttentionSink(KVCache): """ - KV cache that supports attention sink. It keeps the initial few tokens as attention sink. - For other tokens, it uses a sliding window to keep the most recent tokens. + KV cache that supports attention sink with torch.export compatibility. - Parameters: - window_size: the size of the sliding window - sink_size: the number of initial tokens to keep as attention sink - eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. 
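To make the fixed-sink-plus-ring layout described above concrete, here is a small hedged sketch of the same index arithmetic (`toy_indices` is a hypothetical helper, not the class method itself):

```python
import torch

def toy_indices(start_pos, seq_len, sink_size, window_size):
    # Positions < sink_size land in fixed slots; later positions wrap inside a
    # ring region of window_size * 2 slots that starts at index sink_size.
    orig = torch.arange(seq_len, dtype=torch.long) + start_pos
    sink_part = torch.minimum(orig, torch.tensor(sink_size))
    window_part = torch.maximum(orig - sink_size, torch.tensor(0)) % (window_size * 2)
    return torch.where(orig < sink_size, sink_part, sink_size + window_part)

# sink_size=4, window_size=8 -> ring region occupies indices 4..19.
print(toy_indices(start_pos=0, seq_len=6, sink_size=4, window_size=8))   # tensor([0, 1, 2, 3, 4, 5])
print(toy_indices(start_pos=20, seq_len=4, sink_size=4, window_size=8))  # tensor([4, 5, 6, 7])
```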
+ + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( @@ -119,9 +213,11 @@ def __init__( max_batch_size: int = 1, dtype=torch.float32, ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, - max_context_length=window_size + sink_size, + max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, @@ -131,78 +227,61 @@ def __init__( self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.is_ring_buffer = True - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + # Cache positions manager for determining write locations + self.cache_positions_manager = CachePositionsManagerWithSink( + window_size=window_size, + sink_size=sink_size, + ) + + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): + """ + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. """ - Evict old tokens from the cache to make rooms for new tokens. + cache_positions = self.cache_positions_manager.cache_positions + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. + """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) - Parameters: - input_pos: the start position of the incoming token in the actual sequence - seq_len: the length of the incoming sequence - rope: the rope object to use for rerotating k + return self.k_cache, self.v_cache - Returns: - the number of tokens to evict from the cache which is also the number of - positions to shift for incoming tokens + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. 
- num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - num_to_keep = ( - input_pos_item + self.position_shift - self.sink_size - num_to_evict - ) - num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 - k_to_keep = self.k_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ) - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] - new_position=self.sink_size, - ).transpose(1, 2) - self.k_cache = torch.cat( - [ - self.k_cache.narrow(dim_to_slice, 0, self.sink_size), - k_to_keep, - torch.zeros_like( - self.k_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, - ) - self.v_cache = torch.cat( - [ - self.v_cache.narrow(dim_to_slice, 0, self.sink_size), - self.v_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ), - torch.zeros_like( - self.v_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return self.position_shift + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 def attention_sink_forward( @@ -212,6 +291,10 @@ def attention_sink_forward( freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None, ): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. + """ assert self.use_kv_cache assert input_pos is not None @@ -219,19 +302,31 @@ def attention_sink_forward( # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - # Prepare for space in KV cache and get position shift - position_shift = self.kv_cache.evict_tokens(input_pos, seqlen) - - # RoPE relative positional embeddings with shifted position in KV cache + # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) + # Transpose for attention: [B, H, S, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update KV cache + k, v = self.kv_cache.update(input_pos, k, v) + + # Use ring buffer mask since we have is_ring_buffer=True + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) + + # SDPA + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + + # Return tuple like original AttentionMHA.forward + return self.wo(output), None def _replace_rope( @@ -293,7 +388,8 @@ def enable_attention_sink( Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward + - Replace Attention's KVCache with KVCacheWithAttentionSink + - Replace Attention's forward with attention_sink_forward """ rope_with_attention_sink = RopeWithAttentionSink( params=params, diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 8d4d37e0e93..373f4fcc4a0 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -371,8 +371,17 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): + # Import here to avoid circular imports + from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, + ) + for name, child in module.named_children(): - if isinstance(child, KVCache): + # Skip KVCacheWithAttentionSink as it has special evict_tokens logic + # that is not compatible with CustomKVCache + if isinstance(child, KVCacheWithAttentionSink): + _replace_kv_cache_with_custom_kv_cache(child) + elif isinstance(child, KVCache): cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, max_context_length, head_dim = cache_shape From 34a937f4eb5c111cfa9c89af4a7be557eb0d0835 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 11:05:46 -0800 Subject: [PATCH 02/17] Test this --- .../llama/config/llama_attention_sink.yaml | 18 +-- examples/models/llama/model.py | 8 +- .../source_transformation/attention_sink.py | 124 +++++++++--------- extension/llm/runner/text_llm_runner.cpp | 34 +++-- 4 files changed, 107 insertions(+), 77 deletions(-) diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml index 8d3fc01b8f7..2e8cfb18cde 100644 --- a/examples/models/llama/config/llama_attention_sink.yaml +++ b/examples/models/llama/config/llama_attention_sink.yaml @@ -7,18 +7,18 @@ model: dtype_override: fp32 enable_dynamic_shape: True # Attention Sink: "sink_size,window_size,eviction_batch_size" - # sink_size=4: 保留前 4 个 token (BOS + system prompt) - # window_size=126: 滑动窗口大小 + # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) + # window_size=124: 滑动窗口大小 # eviction_batch_size=1: 每次驱逐 1 个 token - # 实际 cache 大小 = sink_size + window_size = 4 + 126 = 130 - # 但 ring buffer 需要 2x window, 所以实际是 4 + 126*2 = 256 - use_attention_sink: "4,126,1" + # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + use_attention_sink: "4,124,1" export: - # max_context_length = sink_size + window_size = 4 + 126 = 130 - # 但 ring buffer 内部会创建 sink_size + window_size * 2 = 256 的缓存 - max_context_length: 130 - max_seq_length: 130 + # max_context_length controls the RoPE frequency table size. + # It must be >= sink_size + window_size (128), but larger values are + # recommended to support generation beyond the sliding window. + # The model default (e.g., 8192 or 131072) is typically used if not specified. + # For testing, we use the model's default by not setting this explicitly. 
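For reference, the cache sizing implied by this config can be checked with a few lines of Python (an illustrative sanity check, not part of the patch):

```python
# use_attention_sink: "4,124,1"
sink_size, window_size, eviction_batch_size = (int(v) for v in "4,124,1".split(","))

kv_cache_size = sink_size + window_size * 2   # 4 + 124 * 2 = 252 slots actually allocated
min_rope_context = sink_size + window_size    # 128, lower bound for max_context_length

assert (kv_cache_size, min_rope_context) == (252, 128)
```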
quantization: qmode: 8da4w diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 1ec85936f7a..a2c8d73207b 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -218,7 +218,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.llm_config.export.max_context_length == sink_size + window_size + # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies + # A larger max_context_length is allowed (and recommended) to support generation beyond + # the sliding window size. + assert self.llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({self.llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) self.model_ = enable_attention_sink( module=self.model_, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 126bbcf919c..068ac7bdc68 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -18,7 +18,9 @@ from executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, + CachePositionsManager, KVCache, + RingKVCache, ) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( @@ -37,6 +39,10 @@ class RopeWithAttentionSink(Rope): For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( @@ -51,10 +57,12 @@ def __init__( self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - self.max_context_length = window_size + sink_size + # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) + self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size - assert self.max_context_length == self.params.max_context_len + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): @@ -129,61 +137,44 @@ def _create_causal_mask_for_attention_sink( class CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. - Similar to CachePositionsManager but handles sink tokens separately. - - Layout: [sink_tokens (fixed)] [ring_buffer_window (rotating)] - - For sink_size=4 and window_size=8: - - Positions 0-3 in the sequence go to cache indices 0-3 (fixed) - - Positions 4+ go to cache indices 4-19 using ring buffer (window_size * 2) + + For sink_size=0: behaves exactly like original CachePositionsManager. + For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + + IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). 
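The rewritten manager above falls back to plain ring-buffer bookkeeping; the sketch below shows that behaviour in isolation (simplified: it omits the stale-slot clearing that the real `calculate_positions_and_update_indices` performs):

```python
import torch

cache_size = 8
cache_positions = torch.zeros(cache_size, dtype=torch.long)

def write(start_pos, seq_len):
    orig = torch.arange(seq_len, dtype=torch.long) + start_pos
    idx = orig % cache_size                    # ring-buffer slot for each incoming token
    cache_positions.index_copy_(0, idx, orig)  # remember which position each slot now holds
    return idx

write(0, 8)             # fills slots 0..7 with positions 0..7
print(write(8, 2))      # tensor([0, 1]) -- positions 8 and 9 overwrite the oldest slots
print(cache_positions)  # tensor([8, 9, 2, 3, 4, 5, 6, 7])
```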
""" - def __init__(self, window_size: int, sink_size: int): + def __init__(self, cache_size: int): super().__init__() - # Total cache size is sink + window * 2 (ring buffer needs 2x for proper masking) - self.max_context_length = sink_size + window_size * 2 - self.sink_size = sink_size - self.window_size = window_size + # cache_size is the actual size of the kv cache dimension + self.max_context_length = cache_size + # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", - torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - # Initialize sink positions (these are fixed) - if sink_size > 0: - self.cache_positions[:sink_size] = torch.arange(sink_size) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - - For positions < sink_size: index = position (fixed) - For positions >= sink_size: index = sink_size + (pos - sink_size) % (window_size * 2) + + This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length - # Calculate cache indices based on whether position is sink or window - sink_part = torch.minimum(orig_indices, torch.tensor(self.sink_size)) - window_part = torch.maximum( - orig_indices - self.sink_size, torch.tensor(0) - ) % (self.window_size * 2) - is_sink = orig_indices < self.sink_size - indices = torch.where(is_sink, sink_part, self.sink_size + window_part) - - # Update cache_positions: clear old positions and set new ones + # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - # Keep sink positions (0 to sink_size-1) and clear window positions that will be overwritten - cache_positions = torch.where( - arange_tensor < self.sink_size, self.cache_positions, full_t - ) - # For non-sink positions, check if they should be cleared cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, cache_positions + arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) @@ -230,10 +221,8 @@ def __init__( self.is_ring_buffer = True # Cache positions manager for determining write locations - self.cache_positions_manager = CachePositionsManagerWithSink( - window_size=window_size, - sink_size=sink_size, - ) + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int @@ -243,10 +232,16 @@ def create_causal_mask_for_ring_buffer( Sink tokens are ALWAYS visible, plus recent tokens in the window. 
""" cache_positions = self.cache_positions_manager.cache_positions - # Use attention sink mask that always allows attending to sink tokens - return _create_causal_mask_for_attention_sink( - cache_positions, self.window_size, self.sink_size, start_pos, seq_len - ) + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) def update( self, @@ -360,21 +355,32 @@ def _replace_attention( if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink - child_module.forward = types.MethodType( # pyre-ignore - attention_sink_forward, child_module - ) + if sink_size == 0: + # For sink_size=0, use the exact same RingKVCache that works + # This is a test to ensure parity with the working implementation + child_module.kv_cache = RingKVCache( + kv_cache.max_batch_size, + window_size, # RingKVCache expects user-provided window size + kv_cache.n_heads, + kv_cache.head_dim, + kv_cache.enable_dynamic_shape, + kv_cache.k_cache.dtype, + ) + else: + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + # Don't replace forward - let the original AttentionMHA.forward handle it + # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask def enable_attention_sink( diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 373f57d3a8e..11d77f321bf 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -129,24 +129,42 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - // Reduce max_context_len by pos_ - int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_; + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); + + // For models with sliding window (Ring Buffer / Attention Sink), + // we allow pos_ to exceed max_context_len. The model handles this + // internally via ring buffer indexing or token eviction. + // We only check that a single prefill chunk doesn't exceed max_seq_len. 
ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens < max_context_len, + num_prompt_tokens <= max_seq_len, InvalidArgument, - "num_prompt_tokens %d >= max_context_len %" PRId64 - ", Max seq length exceeded - please increase max seq len value in your export script", + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, - max_context_len); + max_seq_len); - // Determine max_new_tokens using the GenerationConfig's resolve method, - // then subtract pos_ for max_new_tokens. + // Determine max_new_tokens using the GenerationConfig's resolve method. + // For sliding window models, we use max_context_len directly (not reduced by pos_) + // because the model handles position wrapping internally via ring buffer. int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); + + // TEMPORARY: For sliding window / infinite context testing, + // override max_new_tokens to allow unlimited generation + if (max_new_tokens <= 0) { + max_new_tokens = 1000000; // Effectively unlimited + } + // If user specified seq_len, use that instead + if (config.seq_len > 0 && config.seq_len > max_new_tokens) { + max_new_tokens = config.seq_len; + } ET_LOG( Info, From 83f437accd9259dad175087d8af6bbdd5ff798aa Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 12:12:03 -0800 Subject: [PATCH 03/17] kv cache sdpa --- .../llama/config/llama_attention_sink.yaml | 2 +- examples/models/llama/model.py | 9 ++++ .../source_transformation/attention_sink.py | 44 +++++++++---------- .../source_transformation/custom_kv_cache.py | 14 +++--- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml index 2e8cfb18cde..1d859035d74 100644 --- a/examples/models/llama/config/llama_attention_sink.yaml +++ b/examples/models/llama/config/llama_attention_sink.yaml @@ -2,7 +2,7 @@ base: metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' model: - use_sdpa_with_kv_cache: False # attention_sink 需要用标准 SDPA + use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom use_kv_cache: True dtype_override: fp32 enable_dynamic_shape: True diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index a2c8d73207b..3be00d78711 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -226,6 +226,15 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): f"sink_size + window_size ({sink_size + window_size})" ) + # IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation + # positions, not just the cache size. Override the model's max_context_len to use + # a larger value that supports extended generation. + # We use model_args.max_context_len which was set from export.max_context_length + # but for RoPE we need the full generation length capability. + # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger. 
+ default_rope_length = max(131072, model_args.max_context_len) + model_args.max_context_len = default_rope_length + self.model_ = enable_attention_sink( module=self.model_, params=model_args, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 068ac7bdc68..c2e0f6606f5 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -355,30 +355,26 @@ def _replace_attention( if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache - if sink_size == 0: - # For sink_size=0, use the exact same RingKVCache that works - # This is a test to ensure parity with the working implementation - child_module.kv_cache = RingKVCache( - kv_cache.max_batch_size, - window_size, # RingKVCache expects user-provided window size - kv_cache.n_heads, - kv_cache.head_dim, - kv_cache.enable_dynamic_shape, - kv_cache.k_cache.dtype, - ) - else: - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink + # Always use KVCacheWithAttentionSink, even for sink_size=0 + # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + + # If using SDPACustom (fused SDPA op), enable attention mask support + # so it uses our ring buffer / attention sink mask instead of simple causal mask + if "SDPACustom" in child_module.SDPA.__class__.__name__: + child_module.SDPA.use_attention_mask = True + # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 373f4fcc4a0..f8a268183b5 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -371,17 +371,17 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): - # Import here to avoid circular imports - from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - ) - for name, child in module.named_children(): # Skip KVCacheWithAttentionSink as it has special evict_tokens logic - # that is not compatible with CustomKVCache - if isinstance(child, KVCacheWithAttentionSink): + # that is not compatible with CustomKVCache. + # Check by class name because the class might come from different module paths + # (e.g., 'examples.models...' 
vs 'executorch.examples.models...') + child_class_name = type(child).__name__ + if child_class_name == "KVCacheWithAttentionSink": + logging.info(f"Skipping KVCacheWithAttentionSink at {name}") _replace_kv_cache_with_custom_kv_cache(child) elif isinstance(child, KVCache): + logging.info(f"Replacing KVCache at {name} (type={child_class_name})") cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, max_context_length, head_dim = cache_shape From 97f9910eabf5a118d27d7c6c3adb6804739bc6d2 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 12:28:57 -0800 Subject: [PATCH 04/17] Test --- examples/models/llama/BUCK | 12 + examples/models/llama/eval_llama_lib.py | 10 +- .../test_attention_sink_ring_buffer.py | 594 ++++++++++++++++++ 3 files changed, 615 insertions(+), 1 deletion(-) create mode 100644 examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK index 9d9897a2819..57ed846dd86 100644 --- a/examples/models/llama/BUCK +++ b/examples/models/llama/BUCK @@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test, ], ) +fbcode_target(_kind = runtime.python_test, + name = "attention_sink_ring_buffer_test", + srcs = [ + "source_transformation/test_attention_sink_ring_buffer.py", + ], + supports_static_listing = False, + deps = [ + "//caffe2:torch", + ":export_library", + ], +) + fbcode_target(_kind = runtime.python_test, name = "quantized_sdpa_source_transform_test", srcs = [ diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 03f8f5cd759..e13a5299e61 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + + Updated for the ring-buffer based attention sink implementation. """ # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig @@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert llm_config.export.max_seq_length == sink_size + window_size + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py new file mode 100644 index 00000000000..7109129caca --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -0,0 +1,594 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the ring-buffer based attention sink implementation. + +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + self.manager = CachePositionsManagerWithSink(self.cache_size) + + def test_initial_positions_are_zero(self): + """Cache positions should start as zeros.""" + expected = torch.zeros(self.cache_size, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound(self): + """Test ring buffer wraparound.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions(self): + """Cache positions should track which original position is at each index.""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = 
torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + 
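As a usage note for the tests above: the float mask produced by `create_causal_mask_for_ring_buffer` is meant to be added to the attention logits, for example through `scaled_dot_product_attention`. A minimal hedged sketch (shapes and the hand-built mask are illustrative only):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 1, 2, 12, 16
q = torch.randn(B, H, 1, D)                     # one decode-step query
k = torch.randn(B, H, S, D)                     # full cache of keys
v = torch.randn(B, H, S, D)
mask = torch.zeros(1, S)                        # 0 = visible, -inf = blocked
mask[:, 2:8] = float("-inf")                    # pretend these slots left the window
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)                                # torch.Size([1, 2, 1, 16])
```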
def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query 
at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more 
tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size = 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], 
dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + non_inf_count = (mask != float('-inf')).sum().item() + self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") + + # For positions >= sink_size, sinks should always be visible + if pos >= sink_size: + for i in range(sink_size): + cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() + if cache_pos < sink_size: # This is actually a sink token + self.assertEqual(mask[0, i].item(), 0.0, + f"Sink at idx {i} should be visible at pos {pos}") + + +if __name__ == '__main__': + unittest.main() From 075d521b2960945aedd4f4a908b3f69107c5f18d Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 12:31:38 -0800 Subject: [PATCH 05/17] Remove trailing whitespace Co-authored-by: Claude --- examples/models/llama/eval_llama_lib.py | 404 +----------- .../source_transformation/attention_sink.py | 411 +----------- .../test_attention_sink_ring_buffer.py | 595 +----------------- extension/llm/runner/text_llm_runner.cpp | 311 +-------- 4 files changed, 4 insertions(+), 1717 deletions(-) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index e13a5299e61..d8e107f74ef 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -1,403 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import argparse - -from typing import Optional, Union - -import torch - -from datasets import load_dataset -from executorch.examples.models.llama.export_llama_lib import ( - get_quantizer_and_quant_params, -) - -from executorch.extension.llm.export.builder import LLMEdgeManager -from lm_eval.evaluator import simple_evaluate -from pytorch_tokenizers import get_tokenizer -from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken -from torch.nn import CrossEntropyLoss -from tqdm import tqdm - -from .evaluate.eager_eval import EagerEvalWrapper - -from .export_llama_lib import ( - _prepare_for_llama_export, - build_args_parser as _build_args_parser, -) - - -class GraphModuleEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch py-binded integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: torch.fx.GraphModule, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - max_seq_length: Optional[int] = None, - use_kv_cache: bool = False, - generate_full_logits: bool = False, - enable_dynamic_shape: bool = True, - ): - super().__init__( - model=model, tokenizer=tokenizer, max_seq_length=max_seq_length - ) - self._model = model.to(self.device) - self._use_kv_cache = use_kv_cache - self._generate_full_logits = generate_full_logits - self._enable_dynamic_shape = enable_dynamic_shape - - def _model_call(self, inps): - if self._use_kv_cache: - if not self._enable_dynamic_shape: - # graph module exported without dynamic shape won't work with a different shape. - # And we have to do single token prefill here. 
- result_logits = [] - for pos in range(inps.shape[-1]): - pos_tensor = torch.tensor([pos], dtype=torch.int64) - logits = self._model( - inps[:, pos : pos + 1], {"input_pos": pos_tensor} - ) - result_logits.append(logits) - if self._generate_full_logits: - return torch.cat(result_logits, dim=1) - else: - return torch.stack(result_logits, dim=1) - else: - pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) - # Batch process the whole sequence. - logits = self._model( - inps[:, : self._max_seq_length], {"input_pos": pos_tensor} - ) - return logits - - else: - return self._model(inps) - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - - -class ETPybindEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch py-binded integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: str, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - max_seq_length: Optional[int] = None, - ): - super().__init__(None, tokenizer, max_seq_length) # pyre-ignore - self._model = model # Expects model to be path to a .pte file - - from executorch.extension.pybindings.portable_lib import _load_for_executorch - - # Load custom ops and quantized ops. - from executorch.extension.pybindings import portable_lib # noqa # usort: skip - - # Note: import this after portable_lib - from executorch.extension.llm.custom_ops import ( # noqa - custom_ops, # usort: skip - ) - from executorch.kernels import quantized # noqa - - self._et_model = _load_for_executorch(self._model) - self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore - - def _model_call(self, inps): - # Given inps (tokens), return the logits from a single forward call - # inps: Tensor of shape (1, max_seq_len - 1) - # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) - result = [] - if self._use_kv_cache: - pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) - result = self._et_model.forward( - (inps[:, : self._max_seq_length], pos_tensor) - ) - else: - result = self._et_model.forward((inps,)) - if result[0].dim() != 3: - raise ValueError( - f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." - ) - return result[0] - - -class ETRunnerEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch Runtime integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: str, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - tokenizer_bin: str, - max_seq_length: Optional[int] = None, - ): - super().__init__(None, tokenizer, max_seq_length) # pyre-ignore - self._model = model - self._tokenizer_bin = tokenizer_bin - - def _model_call(self, inps): - # Given inps (tokens), return the logits from a single - # forward call - - # Example: - # inps: Tensor of shape (1, N) - # logits: Tensor of shape (1, N, vocab_size) - pass - - -def gen_eval_wrapper( - model_name: str, - args: argparse.ArgumentParser, - llm_config=None, -): - """ - Generates a wrapper interface around the provided model and tokenizer for - the lm-evaluation-harness library. - - Returns: - eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. 
- """ - # If llm_config is not provided, convert args to llm_config - if llm_config is None: - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - tokenizer = get_tokenizer(llm_config.base.tokenizer_path) - - # ExecuTorch Binary Evaluation - if (model := args.pte) is not None: # pyre-ignore - if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore - # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime - return ETRunnerEvalWrapper( - model=model, - tokenizer=tokenizer, - tokenizer_bin=tokenizer_bin, - max_seq_length=llm_config.export.max_seq_length, - ) - - # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings - return ETPybindEvalWrapper( - model=model, - tokenizer=tokenizer, - # Exported model takes at most (max_seq_length - 1) tokens. - # Note that the eager model takes at most max_seq_length tokens. - max_seq_length=llm_config.export.max_seq_length - 1, - ) - - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( - llm_config - ) - # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) - - if len(quantizers) != 0: - manager = manager.export().pt2e_quantize(quantizers) - model = ( - manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore - if torch.cuda.is_available() - else manager.pre_autograd_graph_module.to(device="cpu") - ) - return GraphModuleEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=llm_config.export.max_seq_length, - use_kv_cache=llm_config.model.use_kv_cache, - enable_dynamic_shape=llm_config.model.enable_dynamic_shape, - ) - else: - # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch - # for quantizers. Currently export only works with --kv_cache, but - # fails without the kv_cache mode - model = ( - manager.model.eval().to(device="cuda") - if torch.cuda.is_available() - else manager.model.eval().to(device="cpu") - ) - - # Save the checkpoint after the eager model preparation is done. - # The reason for this option is that the checkpoint can be used - # to do evaluations in other evaluation platforms, or with data - # that is not available in this eval_llama. We save the checkpoint - # here for consistency with eval_llama. The accuracy results we - # get from eval_llama can be used as a reference to other evaluations. - if args.output_eager_checkpoint_file is not None: # pyre-ignore - torch.save(model, args.output_eager_checkpoint_file) - - return EagerEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=llm_config.export.max_seq_length, - use_kv_cache=llm_config.model.use_kv_cache, - ) - - -def build_args_parser() -> argparse.ArgumentParser: - # Start with arg parser from export_llama_lib - parser = _build_args_parser() - - # Add additional args specific to eval - parser.add_argument( - "--tasks", - nargs="+", - type=str, - default=["wikitext"], - help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", - ) - parser.add_argument( - "--limit", - type=int, - default=None, - help="number of samples to evalulate. 
If not set, evaluate all samples", - ) - parser.add_argument( - "-f", - "--num_fewshot", - type=int, - default=None, - metavar="N", - help="Number of examples in few-shot context", - ) - # Add additional args specific to eval via an ET Runner - # Note: For initial integration, the tokenizer.model is also required - parser.add_argument( - "--pte", - type=str, - default=None, - help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", - ) - parser.add_argument( - "--tokenizer_bin", - type=str, - default=None, - help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", - ) - parser.add_argument( - "--output_eager_checkpoint_file", - type=str, - default=None, - help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", - ) - - # Set of parameters secpific to AttentionSink. - parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) - - return parser - - -def eval_llama( - model_name: str, - args: argparse.ArgumentParser, -) -> None: - # Convert args to LlmConfig - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - # Generate the eval wrapper - eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) - - # Needed for loading mmlu dataset. - # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files - if args.tasks and "mmlu" in args.tasks: - import datasets - - datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True - - # Evaluate the model - with torch.no_grad(): - eval_results = simple_evaluate( - model=eval_wrapper, - tasks=args.tasks, - num_fewshot=args.num_fewshot, - limit=args.limit, - ) - - for task, res in eval_results["results"].items(): - print(f"{task}: {res}") - - -def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): - """ - Evaluate the model's perplexity when AttentionSink is enabled. - - This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py - - Updated for the ring-buffer based attention sink implementation. 
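# Annotation: a minimal sketch of the perplexity reduction used by this eval path,
# assuming `nlls` is the list of per-token negative log-likelihoods accumulated
# with CrossEntropyLoss(reduction="none") as in the loop below (values are stand-ins).
import torch

nlls = [torch.tensor([2.1, 1.7]), torch.tensor([2.4])]
ppl = torch.exp(torch.cat(nlls).mean())  # perplexity = exp(mean token NLL)
print(round(ppl.item(), 3))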
- """ - # Convert args to LlmConfig - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - assert llm_config.model.use_attention_sink is not None - assert args.attention_sink_eval_tokens > 0 - attention_sink_params = llm_config.model.use_attention_sink.split(",") - assert len(attention_sink_params) == 3 - sink_size = int(attention_sink_params[0]) - window_size = int(attention_sink_params[1]) - - # For the ring buffer implementation, the cache size is sink_size + window_size * 2 - # max_context_length should be >= sink_size + window_size (for RoPE frequencies) - # but can be larger to support extended generation - assert llm_config.export.max_context_length >= sink_size + window_size, ( - f"max_context_length ({llm_config.export.max_context_length}) must be >= " - f"sink_size + window_size ({sink_size + window_size})" - ) - - device = "cuda" if torch.cuda.is_available() else "cpu" - manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) - model = manager.model.eval().to(device=device) - tokenizer = get_tokenizer(llm_config.base.tokenizer_path) - - eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - - nlls = [] - loss_fn = CrossEntropyLoss(reduction="none") - progress_bar = tqdm(total=args.attention_sink_eval_tokens) - input_pos = 0 - while input_pos < args.attention_sink_eval_tokens: - for text in eval_data["text"]: - tokens = tokenizer.encode(text, bos=False, eos=False) - if len(tokens) <= 0: - continue - with torch.no_grad(): - num_tokens = min( - len(tokens) - 1, args.attention_sink_eval_tokens - input_pos - ) - logits = model( - torch.tensor( - [tokens[:num_tokens]], dtype=torch.int64, device=device - ), - torch.tensor([input_pos], dtype=torch.int64, device=device), - ).squeeze(dim=0) - neg_log_likelihood = loss_fn( - logits, - torch.tensor( - [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device - ).view(-1), - ) - nlls.append(neg_log_likelihood) - input_pos += num_tokens - progress_bar.update(num_tokens) - if input_pos >= args.attention_sink_eval_tokens: - break - ppl = torch.exp(torch.cat(nlls).mean()) - print(f"Perplexity: {ppl.item()}") - return ppl.item() +# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.import argparsefrom typing import Optional, Unionimport torchfrom datasets import load_datasetfrom executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params,)from executorch.extension.llm.export.builder import LLMEdgeManagerfrom lm_eval.evaluator import simple_evaluatefrom pytorch_tokenizers import get_tokenizerfrom pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizerfrom pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktokenfrom torch.nn import CrossEntropyLossfrom tqdm import tqdmfrom .evaluate.eager_eval import EagerEvalWrapperfrom .export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser,)class GraphModuleEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. 
""" def __init__( self, model: torch.fx.GraphModule, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( model=model, tokenizer=tokenizer, max_seq_length=max_seq_length ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): if self._use_kv_cache: if not self._enable_dynamic_shape: # graph module exported without dynamic shape won't work with a different shape. # And we have to do single token prefill here. result_logits = [] for pos in range(inps.shape[-1]): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model( inps[:, pos : pos + 1], {"input_pos": pos_tensor} ) result_logits.append(logits) if self._generate_full_logits: return torch.cat(result_logits, dim=1) else: return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. logits = self._model( inps[:, : self._max_seq_length], {"input_pos": pos_tensor} ) return logits else: return self._model(inps) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented")class ETPybindEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model # Expects model to be path to a .pte file from executorch.extension.pybindings.portable_lib import _load_for_executorch # Load custom ops and quantized ops. from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib from executorch.extension.llm.custom_ops import ( # noqa custom_ops, # usort: skip ) from executorch.kernels import quantized # noqa self._et_model = _load_for_executorch(self._model) self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore def _model_call(self, inps): # Given inps (tokens), return the logits from a single forward call # inps: Tensor of shape (1, max_seq_len - 1) # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) result = [] if self._use_kv_cache: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) result = self._et_model.forward( (inps[:, : self._max_seq_length], pos_tensor) ) else: result = self._et_model.forward((inps,)) if result[0].dim() != 3: raise ValueError( f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." ) return result[0]class ETRunnerEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch Runtime integration with the lm-evaluation-harness library. 
""" def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], tokenizer_bin: str, max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model self._tokenizer_bin = tokenizer_bin def _model_call(self, inps): # Given inps (tokens), return the logits from a single # forward call # Example: # inps: Tensor of shape (1, N) # logits: Tensor of shape (1, N, vocab_size) passdef gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, llm_config=None,): """ Generates a wrapper interface around the provided model and tokenizer for the lm-evaluation-harness library. Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. """ # If llm_config is not provided, convert args to llm_config if llm_config is None: from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime return ETRunnerEvalWrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings return ETPybindEvalWrapper( model=model, tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. max_seq_length=llm_config.export.max_seq_length - 1, ) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( llm_config ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) model = ( manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore if torch.cuda.is_available() else manager.pre_autograd_graph_module.to(device="cpu") ) return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch # for quantizers. Currently export only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") if torch.cuda.is_available() else manager.model.eval().to(device="cpu") ) # Save the checkpoint after the eager model preparation is done. # The reason for this option is that the checkpoint can be used # to do evaluations in other evaluation platforms, or with data # that is not available in this eval_llama. We save the checkpoint # here for consistency with eval_llama. The accuracy results we # get from eval_llama can be used as a reference to other evaluations. 
if args.output_eager_checkpoint_file is not None: # pyre-ignore torch.save(model, args.output_eager_checkpoint_file) return EagerEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, )def build_args_parser() -> argparse.ArgumentParser: # Start with arg parser from export_llama_lib parser = _build_args_parser() # Add additional args specific to eval parser.add_argument( "--tasks", nargs="+", type=str, default=["wikitext"], help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( "--limit", type=int, default=None, help="number of samples to evalulate. If not set, evaluate all samples", ) parser.add_argument( "-f", "--num_fewshot", type=int, default=None, metavar="N", help="Number of examples in few-shot context", ) # Add additional args specific to eval via an ET Runner # Note: For initial integration, the tokenizer.model is also required parser.add_argument( "--pte", type=str, default=None, help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", ) parser.add_argument( "--tokenizer_bin", type=str, default=None, help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", ) parser.add_argument( "--output_eager_checkpoint_file", type=str, default=None, help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) # Set of parameters secpific to AttentionSink. parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) return parserdef eval_llama( model_name: str, args: argparse.ArgumentParser,) -> None: # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) # Generate the eval wrapper eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files if args.tasks and "mmlu" in args.tasks: import datasets datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True # Evaluate the model with torch.no_grad(): eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, num_fewshot=args.num_fewshot, limit=args.limit, ) for task, res in eval_results["results"].items(): print(f"{task}: {res}")def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): """ Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py Updated for the ring-buffer based attention sink implementation. 
""" # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) assert llm_config.model.use_attention_sink is not None assert args.attention_sink_eval_tokens > 0 attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) # For the ring buffer implementation, the cache size is sink_size + window_size * 2 # max_context_length should be >= sink_size + window_size (for RoPE frequencies) # but can be larger to support extended generation assert llm_config.export.max_context_length >= sink_size + window_size, ( f"max_context_length ({llm_config.export.max_context_length}) must be >= " f"sink_size + window_size ({sink_size + window_size})" ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") nlls = [] loss_fn = CrossEntropyLoss(reduction="none") progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue with torch.no_grad(): num_tokens = min( len(tokens) - 1, args.attention_sink_eval_tokens - input_pos ) logits = model( torch.tensor( [tokens[:num_tokens]], dtype=torch.int64, device=device ), torch.tensor([input_pos], dtype=torch.int64, device=device), ).squeeze(dim=0) neg_log_likelihood = loss_fn( logits, torch.tensor( [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device ).view(-1), ) nlls.append(neg_log_likelihood) input_pos += num_tokens progress_bar.update(num_tokens) if input_pos >= args.attention_sink_eval_tokens: break ppl = torch.exp(torch.cat(nlls).mean()) print(f"Perplexity: {ppl.item()}") return ppl.item() \ No newline at end of file diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index c2e0f6606f5..99b0231f2c3 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -1,410 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Components for supporting Attention Sink. See -# https://arxiv.org/abs/2309.17453 for more details about Attention Sink. - -# This implementation is torch.export compatible using a ring buffer approach -# for the sliding window portion while preserving the sink tokens. 
- -import types -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from executorch.examples.models.llama.attention import ( - _create_causal_mask_for_ring_buffer, - AttentionMHA, - CachePositionsManager, - KVCache, - RingKVCache, -) -from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.rope import ( - apply_rotary_emb_to_k, - hf_apply_rotary_emb_to_k, - Rope, -) -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - - -class RopeWithAttentionSink(Rope): - """ - Rope that helps adjust position encoding when tokens are shifted in KVCache. - For AttentionSink, when tokens are shifted in KVCache, we need to use positions - in KVCache instead of positions in the actual text. - - For torch.export compatibility, this just passes through the position - the - actual position adjustment is handled by the cache update logic. - - Note: This class uses the model's max_context_len (params.max_context_len) for - RoPE frequency table size, which should be large enough to support generation - beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. - """ - - def __init__( - self, - params: ModelArgs, - window_size: int, - sink_size: int, - eviction_batch_size: int, - ): - super().__init__(params) - if self.params.use_hf_rope: - self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k - else: - self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) - self.kv_cache_size = sink_size + window_size * 2 - self.window_size = window_size - self.sink_size = sink_size - # max_context_len from params is used for RoPE frequencies (should be large) - self.max_context_length = self.params.max_context_len - self.eviction_batch_size = eviction_batch_size - - def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): - """ - Get rotary embedding frequencies. - For attention sink, we use the original position - the sliding window - is handled by the cache index management, not by position shifting. - """ - assert input_pos is not None - return super().get_freqs(input_pos, seq_len) - - def rerotate_k( - self, - k: torch.Tensor, - original_position: int, - new_position: int, - ): - """ - Rerotate k from original_position to new_position. - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - """ - seq_len = k.shape[1] - original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) - original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) - new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) - new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) - rerotation_cos = ( - new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin - ) - rerotation_sin = ( - new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin - ) - - return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) - - -def _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len -): - """ - Create causal mask for attention sink. - - Unlike regular ring buffer mask, this mask: - 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) - 2. 
Uses sliding window for other tokens - - Args: - cache_positions: Tensor of actual positions stored at each cache index - window_size: Size of the sliding window - sink_size: Number of sink tokens to always attend to - start_pos: Starting position of the current query - seq_len: Length of the current query sequence - """ - pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) - delta = pos_q - cache_positions - - # Valid if position is filled (>= 0) and causal (delta >= 0) - is_valid = (cache_positions >= 0) & (delta >= 0) - - # Sink tokens (original positions 0 to sink_size-1) are always visible - is_sink = cache_positions < sink_size - - # Window tokens must be within sliding window - is_in_window = delta < window_size - - # Final mask: valid AND (is_sink OR is_in_window) - attn_mask = is_valid & (is_sink | is_in_window) - attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 - return attn_mask - - -class CachePositionsManagerWithSink(nn.Module): - """ - Manages cache positions for attention sink + sliding window. - - For sink_size=0: behaves exactly like original CachePositionsManager. - For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. - - IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). - """ - - def __init__(self, cache_size: int): - super().__init__() - # cache_size is the actual size of the kv cache dimension - self.max_context_length = cache_size - # Use zeros like original CachePositionsManager - self.register_buffer( - "cache_positions", - torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), - ) - - def calculate_positions_and_update_indices( - self, input_pos: torch.Tensor, seq_len: int - ) -> torch.Tensor: - """ - Calculate indices into k_cache, v_cache for placing k_val, v_val. - - This is identical to the original CachePositionsManager logic. - """ - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - - orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos - - # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length - - # Update cache_positions exactly like original CachePositionsManager - full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) - arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, full_t - ) - self.cache_positions.copy_(cache_positions) - self.cache_positions.index_copy_(0, indices, orig_indices) - - return indices - - -class KVCacheWithAttentionSink(KVCache): - """ - KV cache that supports attention sink with torch.export compatibility. - - Uses a ring buffer approach for the sliding window portion while keeping - the first sink_size tokens fixed. This avoids dynamic shape operations. 
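# Annotation: an illustrative sanity check of the sink-aware mask defined above,
# using the same import path as the tests in this series.
import torch
from executorch.examples.models.llama.source_transformation.attention_sink import (
    _create_causal_mask_for_attention_sink,
)

cache_positions = torch.arange(12, dtype=torch.long)  # slot i currently holds position i
mask = _create_causal_mask_for_attention_sink(
    cache_positions, window_size=4, sink_size=2, start_pos=9, seq_len=1
)
# Sinks (positions 0-1) and the window (positions 6-9) are visible (0.0);
# positions 2-5 fall outside the window and 10-11 are in the future (-inf).
print(mask)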
- - Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] - """ - - def __init__( - self, - n_heads: int, - head_dim: int, - enable_dynamic_shape: bool, - rope: RopeWithAttentionSink, - window_size: int, - sink_size: int, - eviction_batch_size: int, - max_batch_size: int = 1, - dtype=torch.float32, - ): - # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) - total_cache_size = sink_size + window_size * 2 - super().__init__( - max_batch_size=max_batch_size, - max_context_length=total_cache_size, - n_heads=n_heads, - head_dim=head_dim, - enable_dynamic_shape=enable_dynamic_shape, - dtype=dtype, - ) - self.rope = rope - self.window_size = window_size - self.sink_size = sink_size - self.eviction_batch_size = eviction_batch_size - self.is_ring_buffer = True - - # Cache positions manager for determining write locations - # Pass the total cache size (same as self.max_context_length after super().__init__) - self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) - - def create_causal_mask_for_ring_buffer( - self, start_pos: torch.Tensor, seq_len: int - ): - """ - Create causal mask for the attention with attention sink. - Sink tokens are ALWAYS visible, plus recent tokens in the window. - """ - cache_positions = self.cache_positions_manager.cache_positions - if self.sink_size > 0: - # Use attention sink mask that always allows attending to sink tokens - return _create_causal_mask_for_attention_sink( - cache_positions, self.window_size, self.sink_size, start_pos, seq_len - ) - else: - # Pure ring buffer mode - use original mask with window_size = actual window - return _create_causal_mask_for_ring_buffer( - cache_positions, self.window_size, start_pos, seq_len - ) - - def update( - self, - input_pos: torch.Tensor, - k_val: torch.Tensor, - v_val: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV cache with new key-value pairs. - Uses ring buffer indexing for positions >= sink_size. - """ - seq_len = k_val.size(2) - assert seq_len <= self.k_cache.size( - 2 - ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" - - # Calculate write indices - indices = self.cache_positions_manager.calculate_positions_and_update_indices( - input_pos, seq_len - ) - - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - self.k_cache.index_copy_(2, indices, k_val) - self.v_cache.index_copy_(2, indices, v_val) - - return self.k_cache, self.v_cache - - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: - """ - For ring buffer implementation, no explicit eviction is needed. - The ring buffer automatically overwrites old values. - Returns 0 to indicate no position shift is needed. - """ - return 0 - - -def attention_sink_forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, -): - """ - Forward function for attention with attention sink KV cache. - Uses ring buffer masking for proper attention patterns. 
- """ - assert self.use_kv_cache - assert input_pos is not None - - bsz, seqlen, _ = x.shape - - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - # RoPE relative positional embeddings - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - - # Transpose for attention: [B, H, S, D] - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Update KV cache - k, v = self.kv_cache.update(input_pos, k, v) - - # Use ring buffer mask since we have is_ring_buffer=True - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) - - # SDPA - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) - - # Return tuple like original AttentionMHA.forward - return self.wo(output), None - - -def _replace_rope( - module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - return isinstance(child, Rope) - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - return rope_with_attention_sink - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def _replace_attention( - module: torch.nn.Module, - rope_with_attention_sink: RopeWithAttentionSink, - sink_size: int, - window_size: int, - eviction_batch_size: int, -): - for _, child_module in module._modules.items(): - if len(list(child_module.children())) > 0: # pyre-ignore [16] - _replace_attention( - module=child_module, # pyre-ignore [6] - rope_with_attention_sink=rope_with_attention_sink, - sink_size=sink_size, - window_size=window_size, - eviction_batch_size=eviction_batch_size, - ) - - if isinstance(child_module, AttentionMHA): - kv_cache = child_module.kv_cache - # Always use KVCacheWithAttentionSink, even for sink_size=0 - # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink - - # If using SDPACustom (fused SDPA op), enable attention mask support - # so it uses our ring buffer / attention sink mask instead of simple causal mask - if "SDPACustom" in child_module.SDPA.__class__.__name__: - child_module.SDPA.use_attention_mask = True - - # Don't replace forward - let the original AttentionMHA.forward handle it - # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask - - -def enable_attention_sink( - module: torch.nn.Module, - params: ModelArgs, - sink_size: int, - window_size: int, - eviction_batch_size: int, -) -> torch.nn.Module: - """ - Transform the model to be able to run inference with Attention Sink. 
- There mainly three steps: - - Replace Rope with RopeWithAttentionSink - - Replace Attention's KVCache with KVCacheWithAttentionSink - - Replace Attention's forward with attention_sink_forward - """ - rope_with_attention_sink = RopeWithAttentionSink( - params=params, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - _replace_rope(module, rope_with_attention_sink) - _replace_attention( - module=module, - rope_with_attention_sink=rope_with_attention_sink, - sink_size=sink_size, - window_size=window_size, - eviction_batch_size=eviction_batch_size, - ) - return module +# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.# Components for supporting Attention Sink. See# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.# This implementation is torch.export compatible using a ring buffer approach# for the sliding window portion while preserving the sink tokens.import typesfrom typing import Optional, Tupleimport torchimport torch.nn as nnfrom executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, CachePositionsManager, KVCache, RingKVCache,)from executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, Rope,)from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filterclass RopeWithAttentionSink(Rope): """ Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. Note: This class uses the model's max_context_len (params.max_context_len) for RoPE frequency table size, which should be large enough to support generation beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( self, params: ModelArgs, window_size: int, sink_size: int, eviction_batch_size: int, ): super().__init__(params) if self.params.use_hf_rope: self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size # max_context_len from params is used for RoPE frequencies (should be large) self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): """ Get rotary embedding frequencies. For attention sink, we use the original position - the sliding window is handled by the cache index management, not by position shifting. """ assert input_pos is not None return super().get_freqs(input_pos, seq_len) def rerotate_k( self, k: torch.Tensor, original_position: int, new_position: int, ): """ Rerotate k from original_position to new_position. 
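# Annotation: the rerotation coefficients computed below are the angle-difference
# identities applied to the cached frequency tables. With a = new_position * theta
# and b = original_position * theta:
#   cos(a - b) = cos(a)cos(b) + sin(a)sin(b)   -> rerotation_cos
#   sin(a - b) = sin(a)cos(b) - cos(a)sin(b)   -> rerotation_sin
# so re-rotating k by (new_position - original_position) is a single extra RoPE application.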
The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) rerotation_cos = ( new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin ) rerotation_sin = ( new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)def _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len): """ Create causal mask for attention sink. Unlike regular ring buffer mask, this mask: 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) 2. Uses sliding window for other tokens Args: cache_positions: Tensor of actual positions stored at each cache index window_size: Size of the sliding window sink_size: Number of sink tokens to always attend to start_pos: Starting position of the current query seq_len: Length of the current query sequence """ pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) delta = pos_q - cache_positions # Valid if position is filled (>= 0) and causal (delta >= 0) is_valid = (cache_positions >= 0) & (delta >= 0) # Sink tokens (original positions 0 to sink_size-1) are always visible is_sink = cache_positions < sink_size # Window tokens must be within sliding window is_in_window = delta < window_size # Final mask: valid AND (is_sink OR is_in_window) attn_mask = is_valid & (is_sink | is_in_window) attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 return attn_maskclass CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. For sink_size=0: behaves exactly like original CachePositionsManager. For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). """ def __init__(self, cache_size: int): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos # Simple ring buffer: just mod by cache size indices = orig_indices % self.max_context_length # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) cache_positions = torch.where( arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) return indicesclass KVCacheWithAttentionSink(KVCache): """ KV cache that supports attention sink with torch.export compatibility. 
Uses a ring buffer approach for the sliding window portion while keeping the first sink_size tokens fixed. This avoids dynamic shape operations. Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( self, n_heads: int, head_dim: int, enable_dynamic_shape: bool, rope: RopeWithAttentionSink, window_size: int, sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, dtype=torch.float32, ): # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, dtype=dtype, ) self.rope = rope self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size self.is_ring_buffer = True # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int ): """ Create causal mask for the attention with attention sink. Sink tokens are ALWAYS visible, plus recent tokens in the window. """ cache_positions = self.cache_positions_manager.cache_positions if self.sink_size > 0: # Use attention sink mask that always allows attending to sink tokens return _create_causal_mask_for_attention_sink( cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) else: # Pure ring buffer mode - use original mask with window_size = actual window return _create_causal_mask_for_ring_buffer( cache_positions, self.window_size, start_pos, seq_len ) def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update KV cache with new key-value pairs. Uses ring buffer indexing for positions >= sink_size. """ seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 2 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" # Calculate write indices indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) start_pos = input_pos[0].item() torch._check_is_size(start_pos) self.k_cache.index_copy_(2, indices, k_val) self.v_cache.index_copy_(2, indices, v_val) return self.k_cache, self.v_cache def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ For ring buffer implementation, no explicit eviction is needed. The ring buffer automatically overwrites old values. Returns 0 to indicate no position shift is needed. """ return 0def attention_sink_forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None,): """ Forward function for attention with attention sink KV cache. Uses ring buffer masking for proper attention patterns. 
""" assert self.use_kv_cache assert input_pos is not None bsz, seqlen, _ = x.shape # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Transpose for attention: [B, H, S, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Update KV cache k, v = self.kv_cache.update(input_pos, k, v) # Use ring buffer mask since we have is_ring_buffer=True start_pos = input_pos[0].item() torch._check_is_size(start_pos) attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) # SDPA output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) # Return tuple like original AttentionMHA.forward return self.wo(output), Nonedef _replace_rope( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink): def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: return isinstance(child, Rope) def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: return rope_with_attention_sink _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)def _replace_attention( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink, sink_size: int, window_size: int, eviction_batch_size: int,): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] _replace_attention( module=child_module, # pyre-ignore [6] rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache # Always use KVCacheWithAttentionSink, even for sink_size=0 # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, enable_dynamic_shape=kv_cache.enable_dynamic_shape, rope=rope_with_attention_sink, max_batch_size=kv_cache.max_batch_size, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink # If using SDPACustom (fused SDPA op), enable attention mask support # so it uses our ring buffer / attention sink mask instead of simple causal mask if "SDPACustom" in child_module.SDPA.__class__.__name__: child_module.SDPA.use_attention_mask = True # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer maskdef enable_attention_sink( module: torch.nn.Module, params: ModelArgs, sink_size: int, window_size: int, eviction_batch_size: int,) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - Replace Attention's KVCache with KVCacheWithAttentionSink - Replace Attention's forward with attention_sink_forward """ rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, ) _replace_rope(module, rope_with_attention_sink) _replace_attention( module=module, rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) return module \ No newline at end of file diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 7109129caca..7ccad6aadbf 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -1,594 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for the ring-buffer based attention sink implementation. - -This tests the torch.export-compatible implementation that uses a ring buffer -for the sliding window rather than explicit token eviction. - -Usage: - # Run with pytest - python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v - - # Or run directly - python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -""" - -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - CachePositionsManagerWithSink, - KVCacheWithAttentionSink, - RopeWithAttentionSink, - _create_causal_mask_for_attention_sink, -) - - -class CachePositionsManagerWithSinkTest(unittest.TestCase): - """Test the cache positions manager for ring buffer indexing.""" - - def setUp(self): - self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) - self.manager = CachePositionsManagerWithSink(self.cache_size) - - def test_initial_positions_are_zero(self): - """Cache positions should start as zeros.""" - expected = torch.zeros(self.cache_size, dtype=torch.long) - torch.testing.assert_close(self.manager.cache_positions, expected) - - def test_simple_update(self): - """Test simple sequential update without wraparound.""" - input_pos = torch.tensor([0], dtype=torch.long) - seq_len = 5 - indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) - - # Should return indices 0, 1, 2, 3, 4 - expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) - torch.testing.assert_close(indices, expected_indices) - - # Cache positions at those indices should be the original positions - for i in range(5): - self.assertEqual(self.manager.cache_positions[i].item(), i) - - def test_wraparound(self): - """Test ring buffer wraparound.""" - # Fill cache to position 30 - input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 30) - - # Add 5 more tokens at position 30 - should wrap around - input_pos = torch.tensor([30], dtype=torch.long) - indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) - - # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 - 
expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) - torch.testing.assert_close(indices, expected_indices) - - def test_cache_positions_track_original_positions(self): - """Cache positions should track which original position is at each index.""" - # Fill with positions 0-31 - input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 32) - - # Now add position 32 which wraps to index 0 - input_pos = torch.tensor([32], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 1) - - # Index 0 should now contain original position 32 - self.assertEqual(self.manager.cache_positions[0].item(), 32) - - -class CausalMaskTest(unittest.TestCase): - """Test the causal mask generation for attention sink.""" - - def test_mask_allows_sink_tokens(self): - """Sink tokens should always be visible (mask = 0).""" - cache_size = 32 - sink_size = 4 - window_size = 14 # cache_size = sink_size + window_size * 2 - - # Create cache positions where positions 0-3 are sink tokens - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 20 # Query at position 20 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 - for i in range(sink_size): - self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") - - def test_mask_blocks_future_tokens(self): - """Future tokens should be masked (-inf).""" - cache_size = 32 - sink_size = 4 - window_size = 14 - - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 10 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Future tokens (positions > 10) should have mask = -inf - for i in range(11, cache_size): - self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") - - def test_mask_respects_window(self): - """Tokens outside the window should be masked.""" - cache_size = 32 - sink_size = 4 - window_size = 5 # Only allow 5 recent tokens - - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 20 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions 16-20 should be visible (within window of 5) - for pos in range(16, 21): - self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") - - # Position 15 should be masked (outside window, not a sink) - self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - """Test the KV cache with attention sink.""" - - def setUp(self): - torch.manual_seed(42) - self.window_size = 14 - self.sink_size = 4 - self.n_heads = 8 - self.head_dim = 64 - self.max_batch_size = 1 - - # Create model args with enough context for RoPE - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, # Large enough for RoPE - n_heads=self.n_heads, - n_kv_heads=self.n_heads, - dim=self.n_heads * self.head_dim, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - ) - - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.n_heads, - 
head_dim=self.head_dim, - enable_dynamic_shape=True, - rope=self.rope, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - def test_cache_size(self): - """Cache should be sink_size + window_size * 2.""" - expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 - self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) - self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) - - def test_is_ring_buffer(self): - """Cache should be marked as ring buffer.""" - self.assertTrue(self.kv_cache.is_ring_buffer) - - def test_update_stores_kv(self): - """Update should store key-value pairs.""" - k = torch.randn(1, self.n_heads, 5, self.head_dim) - v = torch.randn(1, self.n_heads, 5, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - k_out, v_out = self.kv_cache.update(input_pos, k, v) - - # First 5 positions should contain our values - torch.testing.assert_close(k_out[:, :, :5, :], k) - torch.testing.assert_close(v_out[:, :, :5, :], v) - - def test_evict_tokens_returns_zero(self): - """Ring buffer implementation doesn't shift, so evict returns 0.""" - input_pos = torch.tensor([100], dtype=torch.long) - shift = self.kv_cache.evict_tokens(input_pos, 10) - self.assertEqual(shift, 0) - - def test_extended_generation(self): - """Test that cache works for positions beyond cache size.""" - cache_size = self.kv_cache.k_cache.size(2) - - # Fill cache with initial tokens - for pos in range(cache_size + 50): - k = torch.randn(1, self.n_heads, 1, self.head_dim) - v = torch.randn(1, self.n_heads, 1, self.head_dim) - input_pos = torch.tensor([pos], dtype=torch.long) - - k_out, v_out = self.kv_cache.update(input_pos, k, v) - - # Should not raise any errors - self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) - self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) - - -class RopeWithAttentionSinkTest(unittest.TestCase): - """Test RoPE for attention sink.""" - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=8, - dim=512, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=100, - sink_size=4, - eviction_batch_size=1, - ) - - def test_get_freqs_uses_original_position(self): - """RoPE frequencies should use the original position.""" - input_pos = torch.tensor([50], dtype=torch.long) - seq_len = 5 - - freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) - - # Should get frequencies for positions 50-54 - expected_cos = self.rope.freqs_cos[50:55] - expected_sin = self.rope.freqs_sin[50:55] - - torch.testing.assert_close(freqs_cos, expected_cos) - torch.testing.assert_close(freqs_sin, expected_sin) - - def test_rerotate_k(self): - """Test re-rotation of k from one position to another.""" - batch_size = 1 - seq_len = 8 - n_heads = self.params.n_heads - head_dim = self.params.dim // n_heads - - k = torch.randn(batch_size, seq_len, n_heads, head_dim) - q = torch.randn(batch_size, seq_len, n_heads, head_dim) - - # Rotate k at position 100 - original_pos = 100 - freqs_cos, freqs_sin = self.rope.get_freqs( - torch.tensor([original_pos], dtype=torch.long), seq_len - ) - _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) - - # Re-rotate to position 50 - new_pos = 50 - rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) - - # This should be equivalent to directly rotating k at position 50 - freqs_cos_new, 
freqs_sin_new = self.rope.get_freqs( - torch.tensor([new_pos], dtype=torch.long), seq_len - ) - _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) - - torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) - - -class CausalMaskWithWraparoundTest(unittest.TestCase): - """Test causal mask with ring buffer wraparound.""" - - def test_mask_after_wraparound(self): - """Test mask after cache has wrapped around.""" - cache_size = 16 - sink_size = 4 - window_size = 6 # cache_size = sink_size + window_size * 2 - - # Simulate cache after generating beyond cache_size: - # The ring buffer wraps, so indices 0-15 contain positions that wrap - # At position 50, with cache_size=16, the cache contains: - # positions 50-15=35 to 49 at various indices - cache_positions = torch.zeros(cache_size, dtype=torch.long) - # Fill with positions that would exist after generating 50 tokens - # idx = pos % cache_size, so: - # pos 34-49 occupy indices 2-15 and 0-1 - for pos in range(34, 50): - idx = pos % cache_size - cache_positions[idx] = pos - - start_pos = 49 # Query at position 49 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions within window (49-6+1=44 to 49) should be visible - visible_count = 0 - for i in range(cache_size): - pos = cache_positions[i].item() - if pos >= 44 and pos <= 49: # In window - self.assertEqual(mask[0, i].item(), 0.0, - f"Position {pos} at idx {i} should be visible (in window)") - visible_count += 1 - - # Should have some visible tokens - self.assertGreater(visible_count, 0, "Should have visible tokens in window") - - def test_mask_with_sink_size_zero(self): - """Test pure sliding window (sink_size=0).""" - cache_size = 16 - sink_size = 0 - window_size = 8 - - cache_positions = torch.arange(cache_size, dtype=torch.long) - start_pos = 10 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions 3-10 should be visible (within window of 8) - for pos in range(3, 11): - self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") - - # Positions 0-2 should be masked (outside window) - for pos in range(0, 3): - self.assertEqual(mask[0, pos].item(), float('-inf'), - f"Position {pos} should be masked (outside window)") - - -class PrefillTest(unittest.TestCase): - """Test prefill scenarios.""" - - def setUp(self): - torch.manual_seed(42) - self.window_size = 14 - self.sink_size = 4 - self.n_heads = 8 - self.head_dim = 64 - - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=self.n_heads, - n_kv_heads=self.n_heads, - dim=self.n_heads * self.head_dim, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - ) - - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.n_heads, - head_dim=self.head_dim, - enable_dynamic_shape=True, - rope=self.rope, - max_batch_size=1, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - def test_prefill_entire_cache(self): - """Test prefill that fills entire cache.""" - cache_size = self.kv_cache.k_cache.size(2) - - k = torch.randn(1, self.n_heads, cache_size, self.head_dim) - v = torch.randn(1, self.n_heads, cache_size, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - k_out, v_out = 
self.kv_cache.update(input_pos, k, v) - - # All positions should be filled - torch.testing.assert_close(k_out, k) - torch.testing.assert_close(v_out, v) - - def test_prefill_larger_than_cache_raises_error(self): - """Test that prefill larger than cache size raises an assertion error.""" - cache_size = self.kv_cache.k_cache.size(2) - seq_len = cache_size + 10 - - k = torch.randn(1, self.n_heads, seq_len, self.head_dim) - v = torch.randn(1, self.n_heads, seq_len, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - # This should raise an assertion error since seq_len > cache_size - with self.assertRaises(AssertionError): - self.kv_cache.update(input_pos, k, v) - - def test_prefill_followed_by_decode(self): - """Test prefill followed by decode steps.""" - cache_size = self.kv_cache.k_cache.size(2) - - # Prefill with 20 tokens - k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) - v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - self.kv_cache.update(input_pos, k_prefill, v_prefill) - - # Decode 5 more tokens - for i in range(5): - k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) - v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) - input_pos = torch.tensor([20 + i], dtype=torch.long) - k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) - - # Verify cache positions are updated - expected_pos = 20 + i - cache_idx = expected_pos % cache_size - self.assertEqual( - self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), - expected_pos - ) - - -class EnableAttentionSinkTest(unittest.TestCase): - """Test the enable_attention_sink transformation.""" - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=8, - n_kv_heads=8, - dim=512, - n_layers=2, - vocab_size=100, - ) - - def test_enable_attention_sink_transforms_model(self): - """Test that enable_attention_sink properly transforms the model.""" - from executorch.examples.models.llama.llama_transformer import construct_transformer - from executorch.examples.models.llama.source_transformation.attention_sink import ( - enable_attention_sink, - ) - - # Create a simple transformer - with torch.device("meta"): - model = construct_transformer(self.params) - model.to_empty(device="cpu") - - # Apply attention sink transformation - model = enable_attention_sink( - module=model, - params=self.params, - sink_size=4, - window_size=100, - eviction_batch_size=1, - ) - - # Check that KV caches are replaced - for layer in model.layers: - kv_cache = layer.attention.kv_cache - self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) - self.assertEqual(kv_cache.sink_size, 4) - self.assertEqual(kv_cache.window_size, 100) - self.assertTrue(kv_cache.is_ring_buffer) - - def test_enable_attention_sink_replaces_rope(self): - """Test that RoPE is replaced with RopeWithAttentionSink.""" - from executorch.examples.models.llama.llama_transformer import construct_transformer - from executorch.examples.models.llama.source_transformation.attention_sink import ( - enable_attention_sink, - ) - - with torch.device("meta"): - model = construct_transformer(self.params) - model.to_empty(device="cpu") - - model = enable_attention_sink( - module=model, - params=self.params, - sink_size=4, - window_size=100, - eviction_batch_size=1, - ) - - # Check that rope is replaced - for layer in model.layers: - rope = layer.attention.rope - 
self.assertIsInstance(rope, RopeWithAttentionSink) - - -class IntegrationTest(unittest.TestCase): - """Integration tests for end-to-end scenarios.""" - - def setUp(self): - torch.manual_seed(42) - - def test_cache_positions_consistency(self): - """Test that cache positions remain consistent during generation.""" - cache_size = 32 - sink_size = 4 - window_size = 14 - n_heads = 8 - head_dim = 64 - - params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=n_heads, - n_kv_heads=n_heads, - dim=n_heads * head_dim, - ) - - rope = RopeWithAttentionSink( - params=params, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=1, - ) - - kv_cache = KVCacheWithAttentionSink( - n_heads=n_heads, - head_dim=head_dim, - enable_dynamic_shape=True, - rope=rope, - max_batch_size=1, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - # Generate 100 tokens - for pos in range(100): - k = torch.randn(1, n_heads, 1, head_dim) - v = torch.randn(1, n_heads, 1, head_dim) - input_pos = torch.tensor([pos], dtype=torch.long) - - kv_cache.update(input_pos, k, v) - - # Create mask and verify it's valid - mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) - - # Mask should not be all -inf (would mean no tokens to attend to) - non_inf_count = (mask != float('-inf')).sum().item() - self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") - - # For positions >= sink_size, sinks should always be visible - if pos >= sink_size: - for i in range(sink_size): - cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() - if cache_pos < sink_size: # This is actually a sink token - self.assertEqual(mask[0, i].item(), 0.0, - f"Sink at idx {i} should be visible at pos {pos}") - - -if __name__ == '__main__': - unittest.main() +# Copyright (c) Meta Platforms, Inc. 
and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree."""Tests for the ring-buffer based attention sink implementation.This tests the torch.export-compatible implementation that uses a ring bufferfor the sliding window rather than explicit token eviction.Usage: # Run with pytest python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v # Or run directly python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py"""import unittestimport torchfrom executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.source_transformation.attention_sink import ( CachePositionsManagerWithSink, KVCacheWithAttentionSink, RopeWithAttentionSink, _create_causal_mask_for_attention_sink,)class CachePositionsManagerWithSinkTest(unittest.TestCase): """Test the cache positions manager for ring buffer indexing.""" def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) self.manager = CachePositionsManagerWithSink(self.cache_size) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" expected = torch.zeros(self.cache_size, dtype=torch.long) torch.testing.assert_close(self.manager.cache_positions, expected) def test_simple_update(self): """Test simple sequential update without wraparound.""" input_pos = torch.tensor([0], dtype=torch.long) seq_len = 5 indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) # Should return indices 0, 1, 2, 3, 4 expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) # Cache positions at those indices should be the original positions for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) def test_wraparound(self): """Test ring buffer wraparound.""" # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 30) # Add 5 more tokens at position 30 - should wrap around input_pos = torch.tensor([30], dtype=torch.long) indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 32) # Now add position 32 which wraps to index 0 input_pos = torch.tensor([32], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 1) # Index 0 should now contain original position 32 self.assertEqual(self.manager.cache_positions[0].item(), 32)class CausalMaskTest(unittest.TestCase): """Test the causal mask generation for attention sink.""" def test_mask_allows_sink_tokens(self): """Sink tokens should always be visible (mask = 0).""" cache_size = 32 sink_size = 4 window_size = 14 # cache_size = sink_size + window_size * 2 # Create cache positions where positions 0-3 are sink tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 # Query at position 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, 
window_size, sink_size, start_pos, seq_len ) # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 for i in range(sink_size): self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") def test_mask_blocks_future_tokens(self): """Future tokens should be masked (-inf).""" cache_size = 32 sink_size = 4 window_size = 14 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Future tokens (positions > 10) should have mask = -inf for i in range(11, cache_size): self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") def test_mask_respects_window(self): """Tokens outside the window should be masked.""" cache_size = 32 sink_size = 4 window_size = 5 # Only allow 5 recent tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 16-20 should be visible (within window of 5) for pos in range(16, 21): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") # Position 15 should be masked (outside window, not a sink) self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)")class KVCacheWithAttentionSinkTest(unittest.TestCase): """Test the KV cache with attention sink.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.max_batch_size = 1 # Create model args with enough context for RoPE self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, # Large enough for RoPE n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=self.max_batch_size, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_cache_size(self): """Cache should be sink_size + window_size * 2.""" expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) def test_is_ring_buffer(self): """Cache should be marked as ring buffer.""" self.assertTrue(self.kv_cache.is_ring_buffer) def test_update_stores_kv(self): """Update should store key-value pairs.""" k = torch.randn(1, self.n_heads, 5, self.head_dim) v = torch.randn(1, self.n_heads, 5, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # First 5 positions should contain our values torch.testing.assert_close(k_out[:, :, :5, :], k) torch.testing.assert_close(v_out[:, :, :5, :], v) def test_evict_tokens_returns_zero(self): """Ring buffer implementation doesn't shift, so evict returns 0.""" input_pos = torch.tensor([100], dtype=torch.long) shift = self.kv_cache.evict_tokens(input_pos, 10) self.assertEqual(shift, 0) def test_extended_generation(self): """Test that cache works for positions beyond cache size.""" cache_size = self.kv_cache.k_cache.size(2) # Fill cache 
with initial tokens for pos in range(cache_size + 50): k = torch.randn(1, self.n_heads, 1, self.head_dim) v = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([pos], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # Should not raise any errors self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape)class RopeWithAttentionSinkTest(unittest.TestCase): """Test RoPE for attention sink.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, dim=512, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=100, sink_size=4, eviction_batch_size=1, ) def test_get_freqs_uses_original_position(self): """RoPE frequencies should use the original position.""" input_pos = torch.tensor([50], dtype=torch.long) seq_len = 5 freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) # Should get frequencies for positions 50-54 expected_cos = self.rope.freqs_cos[50:55] expected_sin = self.rope.freqs_sin[50:55] torch.testing.assert_close(freqs_cos, expected_cos) torch.testing.assert_close(freqs_sin, expected_sin) def test_rerotate_k(self): """Test re-rotation of k from one position to another.""" batch_size = 1 seq_len = 8 n_heads = self.params.n_heads head_dim = self.params.dim // n_heads k = torch.randn(batch_size, seq_len, n_heads, head_dim) q = torch.randn(batch_size, seq_len, n_heads, head_dim) # Rotate k at position 100 original_pos = 100 freqs_cos, freqs_sin = self.rope.get_freqs( torch.tensor([original_pos], dtype=torch.long), seq_len ) _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Re-rotate to position 50 new_pos = 50 rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) # This should be equivalent to directly rotating k at position 50 freqs_cos_new, freqs_sin_new = self.rope.get_freqs( torch.tensor([new_pos], dtype=torch.long), seq_len ) _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4)class CausalMaskWithWraparoundTest(unittest.TestCase): """Test causal mask with ring buffer wraparound.""" def test_mask_after_wraparound(self): """Test mask after cache has wrapped around.""" cache_size = 16 sink_size = 4 window_size = 6 # cache_size = sink_size + window_size * 2 # Simulate cache after generating beyond cache_size: # The ring buffer wraps, so indices 0-15 contain positions that wrap # At position 50, with cache_size=16, the cache contains: # positions 50-15=35 to 49 at various indices cache_positions = torch.zeros(cache_size, dtype=torch.long) # Fill with positions that would exist after generating 50 tokens # idx = pos % cache_size, so: # pos 34-49 occupy indices 2-15 and 0-1 for pos in range(34, 50): idx = pos % cache_size cache_positions[idx] = pos start_pos = 49 # Query at position 49 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions within window (49-6+1=44 to 49) should be visible visible_count = 0 for i in range(cache_size): pos = cache_positions[i].item() if pos >= 44 and pos <= 49: # In window self.assertEqual(mask[0, i].item(), 0.0, f"Position {pos} at idx {i} should be visible (in window)") visible_count += 1 # Should have some visible tokens self.assertGreater(visible_count, 0, "Should have visible tokens in window") def test_mask_with_sink_size_zero(self): """Test pure 
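The equivalence that test_rerotate_k relies on is composition of rotations: rotating by the angle for (new_position - original_position) after rotating for original_position gives the same result as rotating for new_position directly. A small numeric sketch of that identity using a generic 2-D rotation (not the module's rerotate_k itself):

import math
import torch

def rotate_pairs(x: torch.Tensor, angle: float) -> torch.Tensor:
    # Rotate each (even, odd) feature pair by `angle`, as RoPE does per frequency.
    x0, x1 = x[..., 0::2], x[..., 1::2]
    cos, sin = math.cos(angle), math.sin(angle)
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out

theta = 0.01          # one illustrative RoPE frequency
p, q = 100, 50        # original and new positions, as in the test
k = torch.randn(4)

# Rotating for position p, then re-rotating by (q - p), equals rotating for q directly.
rerotated = rotate_pairs(rotate_pairs(k, p * theta), (q - p) * theta)
expected = rotate_pairs(k, q * theta)
torch.testing.assert_close(rerotated, expected)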
sliding window (sink_size=0).""" cache_size = 16 sink_size = 0 window_size = 8 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 3-10 should be visible (within window of 8) for pos in range(3, 11): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") # Positions 0-2 should be masked (outside window) for pos in range(0, 3): self.assertEqual(mask[0, pos].item(), float('-inf'), f"Position {pos} should be masked (outside window)")class PrefillTest(unittest.TestCase): """Test prefill scenarios.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=1, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_prefill_entire_cache(self): """Test prefill that fills entire cache.""" cache_size = self.kv_cache.k_cache.size(2) k = torch.randn(1, self.n_heads, cache_size, self.head_dim) v = torch.randn(1, self.n_heads, cache_size, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # All positions should be filled torch.testing.assert_close(k_out, k) torch.testing.assert_close(v_out, v) def test_prefill_larger_than_cache_raises_error(self): """Test that prefill larger than cache size raises an assertion error.""" cache_size = self.kv_cache.k_cache.size(2) seq_len = cache_size + 10 k = torch.randn(1, self.n_heads, seq_len, self.head_dim) v = torch.randn(1, self.n_heads, seq_len, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) # This should raise an assertion error since seq_len > cache_size with self.assertRaises(AssertionError): self.kv_cache.update(input_pos, k, v) def test_prefill_followed_by_decode(self): """Test prefill followed by decode steps.""" cache_size = self.kv_cache.k_cache.size(2) # Prefill with 20 tokens k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) self.kv_cache.update(input_pos, k_prefill, v_prefill) # Decode 5 more tokens for i in range(5): k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([20 + i], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) # Verify cache positions are updated expected_pos = 20 + i cache_idx = expected_pos % cache_size self.assertEqual( self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), expected_pos )class EnableAttentionSinkTest(unittest.TestCase): """Test the enable_attention_sink transformation.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, n_kv_heads=8, dim=512, n_layers=2, vocab_size=100, ) def test_enable_attention_sink_transforms_model(self): """Test that 
enable_attention_sink properly transforms the model.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) # Create a simple transformer with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") # Apply attention sink transformation model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that KV caches are replaced for layer in model.layers: kv_cache = layer.attention.kv_cache self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) self.assertEqual(kv_cache.sink_size, 4) self.assertEqual(kv_cache.window_size, 100) self.assertTrue(kv_cache.is_ring_buffer) def test_enable_attention_sink_replaces_rope(self): """Test that RoPE is replaced with RopeWithAttentionSink.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that rope is replaced for layer in model.layers: rope = layer.attention.rope self.assertIsInstance(rope, RopeWithAttentionSink)class IntegrationTest(unittest.TestCase): """Integration tests for end-to-end scenarios.""" def setUp(self): torch.manual_seed(42) def test_cache_positions_consistency(self): """Test that cache positions remain consistent during generation.""" cache_size = 32 sink_size = 4 window_size = 14 n_heads = 8 head_dim = 64 params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=n_heads, n_kv_heads=n_heads, dim=n_heads * head_dim, ) rope = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, ) kv_cache = KVCacheWithAttentionSink( n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=True, rope=rope, max_batch_size=1, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, dtype=torch.float32, ) # Generate 100 tokens for pos in range(100): k = torch.randn(1, n_heads, 1, head_dim) v = torch.randn(1, n_heads, 1, head_dim) input_pos = torch.tensor([pos], dtype=torch.long) kv_cache.update(input_pos, k, v) # Create mask and verify it's valid mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) # Mask should not be all -inf (would mean no tokens to attend to) non_inf_count = (mask != float('-inf')).sum().item() self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") # For positions >= sink_size, sinks should always be visible if pos >= sink_size: for i in range(sink_size): cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() if cache_pos < sink_size: # This is actually a sink token self.assertEqual(mask[0, i].item(), 0.0, f"Sink at idx {i} should be visible at pos {pos}")if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 11d77f321bf..7924d24082f 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -1,310 +1 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated - */ - -// A simple llama2 runner that includes preprocessing and post processing logic. -// The module takes in a string as input and emits a string as output. - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace executorch::extension::llm { - -using ::executorch::extension::Module; -using ::executorch::runtime::Error; -using ::executorch::runtime::Result; - -TextLLMRunner::TextLLMRunner( - std::unordered_map metadata, - std::unique_ptr<::tokenizers::Tokenizer> tokenizer, - std::unique_ptr<::executorch::extension::Module> module, - std::unique_ptr text_decoder_runner, - std::unique_ptr text_prefiller, - std::unique_ptr io_manager, - std::unique_ptr text_token_generator, - std::unique_ptr stats, - float temperature) - : tokenizer_(std::move(tokenizer)), - metadata_(std::move(metadata)), - module_(std::move(module)), - text_decoder_runner_(std::move(text_decoder_runner)), - text_prefiller_(std::move(text_prefiller)), - io_manager_(std::move(io_manager)), - text_token_generator_(std::move(text_token_generator)), - stats_(std::move(stats)), - temperature_(temperature), - pos_(0) { - // Note: This constructor assumes that text_prefiller and text_token_generator - // already have references to the Module and TextDecoderRunner they need -} - -bool TextLLMRunner::is_loaded() const { - return text_prefiller_->is_loaded() && text_token_generator_->is_loaded(); -} - -Error TextLLMRunner::load() { - if (is_loaded()) { - return Error::Ok; - } - ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); - ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); - ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); - return Error::Ok; -} - -// Don't print with the same priority during warmup -#define RUNNER_ET_LOG(warmup, format, ...) \ - if (warmup) { \ - ET_LOG(Debug, format, __VA_ARGS__); \ - } else { \ - ET_LOG(Info, format, __VA_ARGS__); \ - } - -Error TextLLMRunner::generate( - const std::string& prompt, - const GenerationConfig& config, - std::function token_callback, - std::function stats_callback) { - // Prepare the inputs. - // Use ones-initialized inputs. - ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); - if (!is_loaded()) { - stats_->model_load_start_ms = time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(load()); - stats_->model_load_end_ms = time_in_ms(); - } - - if (config.warming) { - ET_LOG(Info, "Doing a warmup run..."); - } - - RUNNER_ET_LOG( - config.warming, - "RSS after loading model: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - // Wrap the token_callback with print function - std::function wrapped_callback = - [token_callback, config](const std::string& piece) { - if (!config.warming) { - llm::safe_printf(piece.c_str()); - fflush(stdout); - } - if (token_callback) { - token_callback(piece); - } - }; - // First token time only measures the time it takes to encode the prompt and - // return a response token. - - stats_->inference_start_ms = time_in_ms(); - shouldStop_ = false; - - ::tokenizers::Result> encode_res = tokenizer_->encode( - prompt, - /*bos=*/config.num_bos, - /*eos=*/config.num_eos); - - if (!encode_res.ok()) { - ET_LOG( - Error, - "Failed to encode prompt %s. 
Tokenizers error code %d", - prompt.c_str(), - static_cast(encode_res.error())); - return Error::InvalidArgument; - } - - // encode the (string) prompt into tokens sequence - std::vector prompt_tokens = encode_res.get(); - int num_prompt_tokens = prompt_tokens.size(); - - // Get max_seq_len for single prefill chunk limit - int64_t max_seq_len = metadata_.at(kMaxSeqLen); - int64_t max_context_len = metadata_.at(kMaxContextLen); - - ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens >= 1, - InvalidArgument, - "Expected at least 1 prompt token"); - - // For models with sliding window (Ring Buffer / Attention Sink), - // we allow pos_ to exceed max_context_len. The model handles this - // internally via ring buffer indexing or token eviction. - // We only check that a single prefill chunk doesn't exceed max_seq_len. - ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens <= max_seq_len, - InvalidArgument, - "num_prompt_tokens %d > max_seq_len %" PRId64 - ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", - num_prompt_tokens, - max_seq_len); - - // Determine max_new_tokens using the GenerationConfig's resolve method. - // For sliding window models, we use max_context_len directly (not reduced by pos_) - // because the model handles position wrapping internally via ring buffer. - int max_new_tokens = - config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); - - // TEMPORARY: For sliding window / infinite context testing, - // override max_new_tokens to allow unlimited generation - if (max_new_tokens <= 0) { - max_new_tokens = 1000000; // Effectively unlimited - } - // If user specified seq_len, use that instead - if (config.seq_len > 0 && config.seq_len > max_new_tokens) { - max_new_tokens = config.seq_len; - } - - ET_LOG( - Info, - "Max new tokens resolved: %d, given pos_ %" PRId64 - ", num_prompt_tokens %zu, max_context_len %" PRId64, - max_new_tokens, - pos_, - prompt_tokens.size(), - max_context_len); - ET_CHECK_OR_RETURN_ERROR( - max_new_tokens > 0, - InvalidArgument, - "Max new tokens %d is less than or equal to 0", - max_new_tokens); - // Prefill first - // Here feed all tokens to the model and get the next predicted token - // after the prompt. After that we will enter generate loop. - - // print prompts - if (config.echo) { - wrapped_callback(prompt); - } - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - uint64_t cur_token = prefill_res.get(); - stats_->first_token_ms = time_in_ms(); - stats_->prompt_eval_end_ms = time_in_ms(); - - // print the first token from prefill. No prev_token so use cur_token for it. - auto decode_result = tokenizer_->decode(cur_token, cur_token); - if (!decode_result.ok()) { - ET_LOG( - Error, - "Tokenizers error code %d", - static_cast(decode_result.error())); - return ::executorch::runtime::Error::InvalidArgument; - } - wrapped_callback(std::move(*decode_result)); - RUNNER_ET_LOG( - config.warming, - "RSS after prompt prefill: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - // start the main loop - prompt_tokens.push_back(cur_token); - - // Set ignore_eos based on config - text_token_generator_->set_ignore_eos(config.ignore_eos); - - // Generate max_new_tokens - 1 because prefill already generated 1 token. - auto generate_result = text_token_generator_->generate( - prompt_tokens, - pos_, - max_new_tokens - 1, - temperature_ == -1.0f ? 
config.temperature : temperature_, - wrapped_callback); - if (!generate_result.ok()) { - return generate_result.error(); - } - int64_t num_generated_tokens = generate_result.get(); - - pos_ += num_generated_tokens; - - stats_->inference_end_ms = time_in_ms(); - if (!config.warming) { - printf("\n"); - } - RUNNER_ET_LOG( - config.warming, - "RSS after finishing text generation: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - if (num_generated_tokens == max_new_tokens) { - RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); - } - - stats_->num_prompt_tokens = num_prompt_tokens; - stats_->num_generated_tokens = num_generated_tokens; - - if (config.warming) { - ET_LOG(Info, "Warmup run finished!"); - } else { - // Do not print report during warmup - print_report(*stats_); - } - if (stats_callback) { - stats_callback(*stats_); - } - - return Error::Ok; -} - -Error TextLLMRunner::prefill( - const std::string& prompt, - const GenerationConfig& config) { - if (!is_loaded()) { - ET_CHECK_OK_OR_RETURN_ERROR(load()); - } - - ::tokenizers::Result> encode_res = tokenizer_->encode( - prompt, - /*bos=*/config.num_bos, - /*eos=*/config.num_eos); - - ET_CHECK_TK_OK_OR_RETURN_ERROR( - encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); - - // encode the (string) prompt into tokens sequence - std::vector prompt_tokens = encode_res.get(); - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - return Error::Ok; -} - -Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { - // Create a GenerationConfig for warmup - GenerationConfig config; - config.echo = false; - config.max_new_tokens = max_new_tokens; - config.warming = true; - - // Call generate with the warmup config - Error err = generate(prompt, config); - - // Reset stats after warmup, not resetting the std::unique_ptr! - reset(); - return err; -} - -void TextLLMRunner::stop() { - if (is_loaded()) { - text_token_generator_->stop(); - } else { - ET_LOG(Error, "Token generator is not loaded, cannot stop"); - } -} - -void TextLLMRunner::reset() { - stats_->reset(); - pos_ = 0; -} - -} // namespace executorch::extension::llm +/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
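The budget logic in generate() above is worth spelling out: with a ring-buffer / attention-sink cache the position counter pos_ is allowed to run past max_context_len, so only a single prefill chunk is bounded by max_seq_len, and the new-token budget is resolved against max_context_len rather than against the remaining cache space. A schematic of that control flow as a hypothetical Python helper (not the runner's API; resolved_max_new_tokens stands in for the GenerationConfig resolution step):

def resolve_token_budget(num_prompt_tokens: int,
                         max_seq_len: int,
                         resolved_max_new_tokens: int,
                         requested_seq_len: int = 0) -> int:
    # A single prefill chunk must still fit the exported sequence length.
    if num_prompt_tokens > max_seq_len:
        raise ValueError("prompt chunk larger than max_seq_len; split or shorten the prompt")
    max_new_tokens = resolved_max_new_tokens
    # Temporary behaviour in the runner: a non-positive budget means
    # effectively unlimited generation for sliding-window testing.
    if max_new_tokens <= 0:
        max_new_tokens = 1_000_000
    # An explicitly requested seq_len larger than the resolved budget wins.
    if requested_seq_len > 0 and requested_seq_len > max_new_tokens:
        max_new_tokens = requested_seq_len
    return max_new_tokens

# e.g. a 20-token prompt with max_seq_len=128 and a resolved budget of 100 new tokens:
print(resolve_token_budget(20, 128, resolved_max_new_tokens=100))  # -> 100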
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */// A simple llama2 runner that includes preprocessing and post processing logic.// The module takes in a string as input and emits a string as output.#include #include #include #include #include #include #include #include namespace executorch::extension::llm {using ::executorch::extension::Module;using ::executorch::runtime::Error;using ::executorch::runtime::Result;TextLLMRunner::TextLLMRunner( std::unordered_map metadata, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::unique_ptr<::executorch::extension::Module> module, std::unique_ptr text_decoder_runner, std::unique_ptr text_prefiller, std::unique_ptr io_manager, std::unique_ptr text_token_generator, std::unique_ptr stats, float temperature) : tokenizer_(std::move(tokenizer)), metadata_(std::move(metadata)), module_(std::move(module)), text_decoder_runner_(std::move(text_decoder_runner)), text_prefiller_(std::move(text_prefiller)), io_manager_(std::move(io_manager)), text_token_generator_(std::move(text_token_generator)), stats_(std::move(stats)), temperature_(temperature), pos_(0) { // Note: This constructor assumes that text_prefiller and text_token_generator // already have references to the Module and TextDecoderRunner they need}bool TextLLMRunner::is_loaded() const { return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();}Error TextLLMRunner::load() { if (is_loaded()) { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok;}// Don't print with the same priority during warmup#define RUNNER_ET_LOG(warmup, format, ...) \ if (warmup) { \ ET_LOG(Debug, format, __VA_ARGS__); \ } else { \ ET_LOG(Info, format, __VA_ARGS__); \ }Error TextLLMRunner::generate( const std::string& prompt, const GenerationConfig& config, std::function token_callback, std::function stats_callback) { // Prepare the inputs. // Use ones-initialized inputs. ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); if (!is_loaded()) { stats_->model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_->model_load_end_ms = time_in_ms(); } if (config.warming) { ET_LOG(Info, "Doing a warmup run..."); } RUNNER_ET_LOG( config.warming, "RSS after loading model: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // Wrap the token_callback with print function std::function wrapped_callback = [token_callback, config](const std::string& piece) { if (!config.warming) { llm::safe_printf(piece.c_str()); fflush(stdout); } if (token_callback) { token_callback(piece); } }; // First token time only measures the time it takes to encode the prompt and // return a response token. stats_->inference_start_ms = time_in_ms(); shouldStop_ = false; ::tokenizers::Result> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); if (!encode_res.ok()) { ET_LOG( Error, "Failed to encode prompt %s. 
Tokenizers error code %d", prompt.c_str(), static_cast(encode_res.error())); return Error::InvalidArgument; } // encode the (string) prompt into tokens sequence std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); // Get max_seq_len for single prefill chunk limit int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); // For models with sliding window (Ring Buffer / Attention Sink), // we allow pos_ to exceed max_context_len. The model handles this // internally via ring buffer indexing or token eviction. // We only check that a single prefill chunk doesn't exceed max_seq_len. ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens <= max_seq_len, InvalidArgument, "num_prompt_tokens %d > max_seq_len %" PRId64 ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, max_seq_len); // Determine max_new_tokens using the GenerationConfig's resolve method. // For sliding window models, we use max_context_len directly (not reduced by pos_) // because the model handles position wrapping internally via ring buffer. int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); // TEMPORARY: For sliding window / infinite context testing, // override max_new_tokens to allow unlimited generation if (max_new_tokens <= 0) { max_new_tokens = 1000000; // Effectively unlimited } // If user specified seq_len, use that instead if (config.seq_len > 0 && config.seq_len > max_new_tokens) { max_new_tokens = config.seq_len; } ET_LOG( Info, "Max new tokens resolved: %d, given pos_ %" PRId64 ", num_prompt_tokens %zu, max_context_len %" PRId64, max_new_tokens, pos_, prompt_tokens.size(), max_context_len); ET_CHECK_OR_RETURN_ERROR( max_new_tokens > 0, InvalidArgument, "Max new tokens %d is less than or equal to 0", max_new_tokens); // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. // print prompts if (config.echo) { wrapped_callback(prompt); } auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); stats_->first_token_ms = time_in_ms(); stats_->prompt_eval_end_ms = time_in_ms(); // print the first token from prefill. No prev_token so use cur_token for it. auto decode_result = tokenizer_->decode(cur_token, cur_token); if (!decode_result.ok()) { ET_LOG( Error, "Tokenizers error code %d", static_cast(decode_result.error())); return ::executorch::runtime::Error::InvalidArgument; } wrapped_callback(std::move(*decode_result)); RUNNER_ET_LOG( config.warming, "RSS after prompt prefill: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // start the main loop prompt_tokens.push_back(cur_token); // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, temperature_ == -1.0f ? 
config.temperature : temperature_, wrapped_callback); if (!generate_result.ok()) { return generate_result.error(); } int64_t num_generated_tokens = generate_result.get(); pos_ += num_generated_tokens; stats_->inference_end_ms = time_in_ms(); if (!config.warming) { printf("\n"); } RUNNER_ET_LOG( config.warming, "RSS after finishing text generation: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); if (num_generated_tokens == max_new_tokens) { RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } stats_->num_prompt_tokens = num_prompt_tokens; stats_->num_generated_tokens = num_generated_tokens; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); } else { // Do not print report during warmup print_report(*stats_); } if (stats_callback) { stats_callback(*stats_); } return Error::Ok;}Error TextLLMRunner::prefill( const std::string& prompt, const GenerationConfig& config) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } ::tokenizers::Result> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); ET_CHECK_TK_OK_OR_RETURN_ERROR( encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); // encode the (string) prompt into tokens sequence std::vector prompt_tokens = encode_res.get(); auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); return Error::Ok;}Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { // Create a GenerationConfig for warmup GenerationConfig config; config.echo = false; config.max_new_tokens = max_new_tokens; config.warming = true; // Call generate with the warmup config Error err = generate(prompt, config); // Reset stats after warmup, not resetting the std::unique_ptr! reset(); return err;}void TextLLMRunner::stop() { if (is_loaded()) { text_token_generator_->stop(); } else { ET_LOG(Error, "Token generator is not loaded, cannot stop"); }}void TextLLMRunner::reset() { stats_->reset(); pos_ = 0;}} // namespace executorch::extension::llm \ No newline at end of file From fe66e7443b0778c4b5bff3a531601262a796b6b7 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:11:02 -0800 Subject: [PATCH 06/17] Revert "Remove trailing whitespace" This reverts commit 075d521b2960945aedd4f4a908b3f69107c5f18d. --- examples/models/llama/eval_llama_lib.py | 404 +++++++++++- .../source_transformation/attention_sink.py | 411 +++++++++++- .../test_attention_sink_ring_buffer.py | 595 +++++++++++++++++- extension/llm/runner/text_llm_runner.cpp | 311 ++++++++- 4 files changed, 1717 insertions(+), 4 deletions(-) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index d8e107f74ef..e13a5299e61 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -1 +1,403 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.import argparsefrom typing import Optional, Unionimport torchfrom datasets import load_datasetfrom executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params,)from executorch.extension.llm.export.builder import LLMEdgeManagerfrom lm_eval.evaluator import simple_evaluatefrom pytorch_tokenizers import get_tokenizerfrom pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizerfrom pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktokenfrom torch.nn import CrossEntropyLossfrom tqdm import tqdmfrom .evaluate.eager_eval import EagerEvalWrapperfrom .export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser,)class GraphModuleEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: torch.fx.GraphModule, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( model=model, tokenizer=tokenizer, max_seq_length=max_seq_length ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): if self._use_kv_cache: if not self._enable_dynamic_shape: # graph module exported without dynamic shape won't work with a different shape. # And we have to do single token prefill here. result_logits = [] for pos in range(inps.shape[-1]): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model( inps[:, pos : pos + 1], {"input_pos": pos_tensor} ) result_logits.append(logits) if self._generate_full_logits: return torch.cat(result_logits, dim=1) else: return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. logits = self._model( inps[:, : self._max_seq_length], {"input_pos": pos_tensor} ) return logits else: return self._model(inps) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented")class ETPybindEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model # Expects model to be path to a .pte file from executorch.extension.pybindings.portable_lib import _load_for_executorch # Load custom ops and quantized ops. 
from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib from executorch.extension.llm.custom_ops import ( # noqa custom_ops, # usort: skip ) from executorch.kernels import quantized # noqa self._et_model = _load_for_executorch(self._model) self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore def _model_call(self, inps): # Given inps (tokens), return the logits from a single forward call # inps: Tensor of shape (1, max_seq_len - 1) # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) result = [] if self._use_kv_cache: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) result = self._et_model.forward( (inps[:, : self._max_seq_length], pos_tensor) ) else: result = self._et_model.forward((inps,)) if result[0].dim() != 3: raise ValueError( f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." ) return result[0]class ETRunnerEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch Runtime integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], tokenizer_bin: str, max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model self._tokenizer_bin = tokenizer_bin def _model_call(self, inps): # Given inps (tokens), return the logits from a single # forward call # Example: # inps: Tensor of shape (1, N) # logits: Tensor of shape (1, N, vocab_size) passdef gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, llm_config=None,): """ Generates a wrapper interface around the provided model and tokenizer for the lm-evaluation-harness library. Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. """ # If llm_config is not provided, convert args to llm_config if llm_config is None: from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime return ETRunnerEvalWrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings return ETPybindEvalWrapper( model=model, tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. 
max_seq_length=llm_config.export.max_seq_length - 1, ) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( llm_config ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) model = ( manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore if torch.cuda.is_available() else manager.pre_autograd_graph_module.to(device="cpu") ) return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch # for quantizers. Currently export only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") if torch.cuda.is_available() else manager.model.eval().to(device="cpu") ) # Save the checkpoint after the eager model preparation is done. # The reason for this option is that the checkpoint can be used # to do evaluations in other evaluation platforms, or with data # that is not available in this eval_llama. We save the checkpoint # here for consistency with eval_llama. The accuracy results we # get from eval_llama can be used as a reference to other evaluations. if args.output_eager_checkpoint_file is not None: # pyre-ignore torch.save(model, args.output_eager_checkpoint_file) return EagerEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, )def build_args_parser() -> argparse.ArgumentParser: # Start with arg parser from export_llama_lib parser = _build_args_parser() # Add additional args specific to eval parser.add_argument( "--tasks", nargs="+", type=str, default=["wikitext"], help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( "--limit", type=int, default=None, help="number of samples to evalulate. If not set, evaluate all samples", ) parser.add_argument( "-f", "--num_fewshot", type=int, default=None, metavar="N", help="Number of examples in few-shot context", ) # Add additional args specific to eval via an ET Runner # Note: For initial integration, the tokenizer.model is also required parser.add_argument( "--pte", type=str, default=None, help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", ) parser.add_argument( "--tokenizer_bin", type=str, default=None, help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", ) parser.add_argument( "--output_eager_checkpoint_file", type=str, default=None, help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) # Set of parameters secpific to AttentionSink. parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) return parserdef eval_llama( model_name: str, args: argparse.ArgumentParser,) -> None: # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) # Generate the eval wrapper eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. 
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files if args.tasks and "mmlu" in args.tasks: import datasets datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True # Evaluate the model with torch.no_grad(): eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, num_fewshot=args.num_fewshot, limit=args.limit, ) for task, res in eval_results["results"].items(): print(f"{task}: {res}")def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): """ Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py Updated for the ring-buffer based attention sink implementation. """ # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) assert llm_config.model.use_attention_sink is not None assert args.attention_sink_eval_tokens > 0 attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) # For the ring buffer implementation, the cache size is sink_size + window_size * 2 # max_context_length should be >= sink_size + window_size (for RoPE frequencies) # but can be larger to support extended generation assert llm_config.export.max_context_length >= sink_size + window_size, ( f"max_context_length ({llm_config.export.max_context_length}) must be >= " f"sink_size + window_size ({sink_size + window_size})" ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") nlls = [] loss_fn = CrossEntropyLoss(reduction="none") progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue with torch.no_grad(): num_tokens = min( len(tokens) - 1, args.attention_sink_eval_tokens - input_pos ) logits = model( torch.tensor( [tokens[:num_tokens]], dtype=torch.int64, device=device ), torch.tensor([input_pos], dtype=torch.int64, device=device), ).squeeze(dim=0) neg_log_likelihood = loss_fn( logits, torch.tensor( [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device ).view(-1), ) nlls.append(neg_log_likelihood) input_pos += num_tokens progress_bar.update(num_tokens) if input_pos >= args.attention_sink_eval_tokens: break ppl = torch.exp(torch.cat(nlls).mean()) print(f"Perplexity: {ppl.item()}") return ppl.item() \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
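+
+# Evaluation entry points for eager and exported Llama models built on the
+# lm-evaluation-harness, including eval_llama_with_attention_sink, which streams
+# wikitext through the ring-buffer attention sink KV cache to measure
+# long-context perplexity.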
+ + +import argparse + +from typing import Optional, Union + +import torch + +from datasets import load_dataset +from executorch.examples.models.llama.export_llama_lib import ( + get_quantizer_and_quant_params, +) + +from executorch.extension.llm.export.builder import LLMEdgeManager +from lm_eval.evaluator import simple_evaluate +from pytorch_tokenizers import get_tokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer +from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken +from torch.nn import CrossEntropyLoss +from tqdm import tqdm + +from .evaluate.eager_eval import EagerEvalWrapper + +from .export_llama_lib import ( + _prepare_for_llama_export, + build_args_parser as _build_args_parser, +) + + +class GraphModuleEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch py-binded integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: torch.fx.GraphModule, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + max_seq_length: Optional[int] = None, + use_kv_cache: bool = False, + generate_full_logits: bool = False, + enable_dynamic_shape: bool = True, + ): + super().__init__( + model=model, tokenizer=tokenizer, max_seq_length=max_seq_length + ) + self._model = model.to(self.device) + self._use_kv_cache = use_kv_cache + self._generate_full_logits = generate_full_logits + self._enable_dynamic_shape = enable_dynamic_shape + + def _model_call(self, inps): + if self._use_kv_cache: + if not self._enable_dynamic_shape: + # graph module exported without dynamic shape won't work with a different shape. + # And we have to do single token prefill here. + result_logits = [] + for pos in range(inps.shape[-1]): + pos_tensor = torch.tensor([pos], dtype=torch.int64) + logits = self._model( + inps[:, pos : pos + 1], {"input_pos": pos_tensor} + ) + result_logits.append(logits) + if self._generate_full_logits: + return torch.cat(result_logits, dim=1) + else: + return torch.stack(result_logits, dim=1) + else: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + # Batch process the whole sequence. + logits = self._model( + inps[:, : self._max_seq_length], {"input_pos": pos_tensor} + ) + return logits + + else: + return self._model(inps) + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + + +class ETPybindEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch py-binded integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: str, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + max_seq_length: Optional[int] = None, + ): + super().__init__(None, tokenizer, max_seq_length) # pyre-ignore + self._model = model # Expects model to be path to a .pte file + + from executorch.extension.pybindings.portable_lib import _load_for_executorch + + # Load custom ops and quantized ops. 
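+        # Importing these modules registers the custom LLM ops (e.g. the fused
+        # SDPA op) and the quantized kernels with the runtime, so that
+        # _load_for_executorch can resolve them when the .pte file uses them.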
+ from executorch.extension.pybindings import portable_lib # noqa # usort: skip + + # Note: import this after portable_lib + from executorch.extension.llm.custom_ops import ( # noqa + custom_ops, # usort: skip + ) + from executorch.kernels import quantized # noqa + + self._et_model = _load_for_executorch(self._model) + self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore + + def _model_call(self, inps): + # Given inps (tokens), return the logits from a single forward call + # inps: Tensor of shape (1, max_seq_len - 1) + # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) + result = [] + if self._use_kv_cache: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + result = self._et_model.forward( + (inps[:, : self._max_seq_length], pos_tensor) + ) + else: + result = self._et_model.forward((inps,)) + if result[0].dim() != 3: + raise ValueError( + f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." + ) + return result[0] + + +class ETRunnerEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch Runtime integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: str, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + tokenizer_bin: str, + max_seq_length: Optional[int] = None, + ): + super().__init__(None, tokenizer, max_seq_length) # pyre-ignore + self._model = model + self._tokenizer_bin = tokenizer_bin + + def _model_call(self, inps): + # Given inps (tokens), return the logits from a single + # forward call + + # Example: + # inps: Tensor of shape (1, N) + # logits: Tensor of shape (1, N, vocab_size) + pass + + +def gen_eval_wrapper( + model_name: str, + args: argparse.ArgumentParser, + llm_config=None, +): + """ + Generates a wrapper interface around the provided model and tokenizer for + the lm-evaluation-harness library. + + Returns: + eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. + """ + # If llm_config is not provided, convert args to llm_config + if llm_config is None: + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) + + # ExecuTorch Binary Evaluation + if (model := args.pte) is not None: # pyre-ignore + if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore + # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime + return ETRunnerEvalWrapper( + model=model, + tokenizer=tokenizer, + tokenizer_bin=tokenizer_bin, + max_seq_length=llm_config.export.max_seq_length, + ) + + # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings + return ETPybindEvalWrapper( + model=model, + tokenizer=tokenizer, + # Exported model takes at most (max_seq_length - 1) tokens. + # Note that the eager model takes at most max_seq_length tokens. 
+            max_seq_length=llm_config.export.max_seq_length - 1,
+        )
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
+    # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
+
+    if len(quantizers) != 0:
+        manager = manager.export().pt2e_quantize(quantizers)
+        model = (
+            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
+            if torch.cuda.is_available()
+            else manager.pre_autograd_graph_module.to(device="cpu")
+        )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        )
+    else:
+        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
+        # for quantizers. Currently export only works with --kv_cache, but
+        # fails without the kv_cache mode
+        model = (
+            manager.model.eval().to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.model.eval().to(device="cpu")
+        )
+
+        # Save the checkpoint after the eager model preparation is done.
+        # The reason for this option is that the checkpoint can be used
+        # to do evaluations in other evaluation platforms, or with data
+        # that is not available in this eval_llama. We save the checkpoint
+        # here for consistency with eval_llama. The accuracy results we
+        # get from eval_llama can be used as a reference for other evaluations.
+        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, args.output_eager_checkpoint_file)
+
+        return EagerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+        )
+
+
+def build_args_parser() -> argparse.ArgumentParser:
+    # Start with arg parser from export_llama_lib
+    parser = _build_args_parser()
+
+    # Add additional args specific to eval
+    parser.add_argument(
+        "--tasks",
+        nargs="+",
+        type=str,
+        default=["wikitext"],
+        help="list of lm-eval tasks to evaluate. Usage: --tasks task1 task2",
+    )
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=None,
+        help="number of samples to evaluate. If not set, evaluate all samples",
+    )
+    parser.add_argument(
+        "-f",
+        "--num_fewshot",
+        type=int,
+        default=None,
+        metavar="N",
+        help="Number of examples in few-shot context",
+    )
+    # Add additional args specific to eval via an ET Runner
+    # Note: For initial integration, the tokenizer.model is also required
+    parser.add_argument(
+        "--pte",
+        type=str,
+        default=None,
+        help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
+    )
+    parser.add_argument(
+        "--tokenizer_bin",
+        type=str,
+        default=None,
+        help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
+    )
+    parser.add_argument(
+        "--output_eager_checkpoint_file",
+        type=str,
+        default=None,
+        help="Save the checkpoint after source transformations, for other evaluation platforms to run the same checkpoint.",
+    )
+
+    # Set of parameters specific to AttentionSink.
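+    # --attention_sink_eval_tokens is the number of wikitext tokens streamed
+    # through the model by eval_llama_with_attention_sink when measuring
+    # long-context perplexity.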
+ parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) + + return parser + + +def eval_llama( + model_name: str, + args: argparse.ArgumentParser, +) -> None: + # Convert args to LlmConfig + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + # Generate the eval wrapper + eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) + + # Needed for loading mmlu dataset. + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files + if args.tasks and "mmlu" in args.tasks: + import datasets + + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=args.tasks, + num_fewshot=args.num_fewshot, + limit=args.limit, + ) + + for task, res in eval_results["results"].items(): + print(f"{task}: {res}") + + +def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): + """ + Evaluate the model's perplexity when AttentionSink is enabled. + + This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + + Updated for the ring-buffer based attention sink implementation. + """ + # Convert args to LlmConfig + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + assert llm_config.model.use_attention_sink is not None + assert args.attention_sink_eval_tokens > 0 + attention_sink_params = llm_config.model.use_attention_sink.split(",") + assert len(attention_sink_params) == 3 + sink_size = int(attention_sink_params[0]) + window_size = int(attention_sink_params[1]) + + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) + model = manager.model.eval().to(device=device) + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) + + eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + nlls = [] + loss_fn = CrossEntropyLoss(reduction="none") + progress_bar = tqdm(total=args.attention_sink_eval_tokens) + input_pos = 0 + while input_pos < args.attention_sink_eval_tokens: + for text in eval_data["text"]: + tokens = tokenizer.encode(text, bos=False, eos=False) + if len(tokens) <= 0: + continue + with torch.no_grad(): + num_tokens = min( + len(tokens) - 1, args.attention_sink_eval_tokens - input_pos + ) + logits = model( + torch.tensor( + [tokens[:num_tokens]], dtype=torch.int64, device=device + ), + torch.tensor([input_pos], dtype=torch.int64, device=device), + ).squeeze(dim=0) + neg_log_likelihood = loss_fn( + logits, + torch.tensor( + [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device + ).view(-1), + ) + nlls.append(neg_log_likelihood) + input_pos += num_tokens + progress_bar.update(num_tokens) + if input_pos >= args.attention_sink_eval_tokens: + break + ppl = torch.exp(torch.cat(nlls).mean()) + print(f"Perplexity: {ppl.item()}") + return ppl.item() diff --git a/examples/models/llama/source_transformation/attention_sink.py 
b/examples/models/llama/source_transformation/attention_sink.py index 99b0231f2c3..c2e0f6606f5 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -1 +1,410 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.# Components for supporting Attention Sink. See# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.# This implementation is torch.export compatible using a ring buffer approach# for the sliding window portion while preserving the sink tokens.import typesfrom typing import Optional, Tupleimport torchimport torch.nn as nnfrom executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, CachePositionsManager, KVCache, RingKVCache,)from executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, Rope,)from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filterclass RopeWithAttentionSink(Rope): """ Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. Note: This class uses the model's max_context_len (params.max_context_len) for RoPE frequency table size, which should be large enough to support generation beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( self, params: ModelArgs, window_size: int, sink_size: int, eviction_batch_size: int, ): super().__init__(params) if self.params.use_hf_rope: self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size # max_context_len from params is used for RoPE frequencies (should be large) self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): """ Get rotary embedding frequencies. For attention sink, we use the original position - the sliding window is handled by the cache index management, not by position shifting. """ assert input_pos is not None return super().get_freqs(input_pos, seq_len) def rerotate_k( self, k: torch.Tensor, original_position: int, new_position: int, ): """ Rerotate k from original_position to new_position. 
The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) rerotation_cos = ( new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin ) rerotation_sin = ( new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)def _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len): """ Create causal mask for attention sink. Unlike regular ring buffer mask, this mask: 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) 2. Uses sliding window for other tokens Args: cache_positions: Tensor of actual positions stored at each cache index window_size: Size of the sliding window sink_size: Number of sink tokens to always attend to start_pos: Starting position of the current query seq_len: Length of the current query sequence """ pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) delta = pos_q - cache_positions # Valid if position is filled (>= 0) and causal (delta >= 0) is_valid = (cache_positions >= 0) & (delta >= 0) # Sink tokens (original positions 0 to sink_size-1) are always visible is_sink = cache_positions < sink_size # Window tokens must be within sliding window is_in_window = delta < window_size # Final mask: valid AND (is_sink OR is_in_window) attn_mask = is_valid & (is_sink | is_in_window) attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 return attn_maskclass CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. For sink_size=0: behaves exactly like original CachePositionsManager. For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). """ def __init__(self, cache_size: int): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos # Simple ring buffer: just mod by cache size indices = orig_indices % self.max_context_length # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) cache_positions = torch.where( arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) return indicesclass KVCacheWithAttentionSink(KVCache): """ KV cache that supports attention sink with torch.export compatibility. 
Uses a ring buffer approach for the sliding window portion while keeping the first sink_size tokens fixed. This avoids dynamic shape operations. Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( self, n_heads: int, head_dim: int, enable_dynamic_shape: bool, rope: RopeWithAttentionSink, window_size: int, sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, dtype=torch.float32, ): # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, dtype=dtype, ) self.rope = rope self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size self.is_ring_buffer = True # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int ): """ Create causal mask for the attention with attention sink. Sink tokens are ALWAYS visible, plus recent tokens in the window. """ cache_positions = self.cache_positions_manager.cache_positions if self.sink_size > 0: # Use attention sink mask that always allows attending to sink tokens return _create_causal_mask_for_attention_sink( cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) else: # Pure ring buffer mode - use original mask with window_size = actual window return _create_causal_mask_for_ring_buffer( cache_positions, self.window_size, start_pos, seq_len ) def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update KV cache with new key-value pairs. Uses ring buffer indexing for positions >= sink_size. """ seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 2 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" # Calculate write indices indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) start_pos = input_pos[0].item() torch._check_is_size(start_pos) self.k_cache.index_copy_(2, indices, k_val) self.v_cache.index_copy_(2, indices, v_val) return self.k_cache, self.v_cache def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ For ring buffer implementation, no explicit eviction is needed. The ring buffer automatically overwrites old values. Returns 0 to indicate no position shift is needed. """ return 0def attention_sink_forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None,): """ Forward function for attention with attention sink KV cache. Uses ring buffer masking for proper attention patterns. 
""" assert self.use_kv_cache assert input_pos is not None bsz, seqlen, _ = x.shape # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Transpose for attention: [B, H, S, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Update KV cache k, v = self.kv_cache.update(input_pos, k, v) # Use ring buffer mask since we have is_ring_buffer=True start_pos = input_pos[0].item() torch._check_is_size(start_pos) attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) # SDPA output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) # Return tuple like original AttentionMHA.forward return self.wo(output), Nonedef _replace_rope( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink): def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: return isinstance(child, Rope) def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: return rope_with_attention_sink _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)def _replace_attention( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink, sink_size: int, window_size: int, eviction_batch_size: int,): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] _replace_attention( module=child_module, # pyre-ignore [6] rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache # Always use KVCacheWithAttentionSink, even for sink_size=0 # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, enable_dynamic_shape=kv_cache.enable_dynamic_shape, rope=rope_with_attention_sink, max_batch_size=kv_cache.max_batch_size, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink # If using SDPACustom (fused SDPA op), enable attention mask support # so it uses our ring buffer / attention sink mask instead of simple causal mask if "SDPACustom" in child_module.SDPA.__class__.__name__: child_module.SDPA.use_attention_mask = True # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer maskdef enable_attention_sink( module: torch.nn.Module, params: ModelArgs, sink_size: int, window_size: int, eviction_batch_size: int,) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - Replace Attention's KVCache with KVCacheWithAttentionSink - Replace Attention's forward with attention_sink_forward """ rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, ) _replace_rope(module, rope_with_attention_sink) _replace_attention( module=module, rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) return module \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Components for supporting Attention Sink. See +# https://arxiv.org/abs/2309.17453 for more details about Attention Sink. + +# This implementation is torch.export compatible using a ring buffer approach +# for the sliding window portion while preserving the sink tokens. + +import types +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from executorch.examples.models.llama.attention import ( + _create_causal_mask_for_ring_buffer, + AttentionMHA, + CachePositionsManager, + KVCache, + RingKVCache, +) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import ( + apply_rotary_emb_to_k, + hf_apply_rotary_emb_to_k, + Rope, +) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + +class RopeWithAttentionSink(Rope): + """ + Rope that helps adjust position encoding when tokens are shifted in KVCache. + For AttentionSink, when tokens are shifted in KVCache, we need to use positions + in KVCache instead of positions in the actual text. + + For torch.export compatibility, this just passes through the position - the + actual position adjustment is handled by the cache update logic. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. + """ + + def __init__( + self, + params: ModelArgs, + window_size: int, + sink_size: int, + eviction_batch_size: int, + ): + super().__init__(params) + if self.params.use_hf_rope: + self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k + else: + self.apply_rotary_emb_to_k = apply_rotary_emb_to_k + # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) + self.kv_cache_size = sink_size + window_size * 2 + self.window_size = window_size + self.sink_size = sink_size + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len + self.eviction_batch_size = eviction_batch_size + + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get rotary embedding frequencies. + For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. + """ + assert input_pos is not None + return super().get_freqs(input_pos, seq_len) + + def rerotate_k( + self, + k: torch.Tensor, + original_position: int, + new_position: int, + ): + """ + Rerotate k from original_position to new_position. 
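+        Retained for eager experimentation and the unit tests; the ring-buffer
+        KVCacheWithAttentionSink in this file does not call it, since no tokens
+        are shifted in the cache.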
+ The shape of k is (batch_size, seq_len, n_local_heads, head_dim) + """ + seq_len = k.shape[1] + original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) + original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) + new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) + new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) + rerotation_cos = ( + new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin + ) + rerotation_sin = ( + new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin + ) + + return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) + + +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + is_in_window = delta < window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + + For sink_size=0: behaves exactly like original CachePositionsManager. + For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + + IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). + """ + + def __init__(self, cache_size: int): + super().__init__() + # cache_size is the actual size of the kv cache dimension + self.max_context_length = cache_size + # Use zeros like original CachePositionsManager + self.register_buffer( + "cache_positions", + torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), + ) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + This is identical to the original CachePositionsManager logic. 
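+
+        Illustrative example (sizes chosen here for clarity, not taken from a test):
+        with cache_size=8, input_pos=[6] and seq_len=4, orig_indices is [6, 7, 8, 9]
+        and indices = orig_indices % 8 = [6, 7, 0, 1]; cache_positions at those
+        indices then record the original positions [6, 7, 8, 9].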
+ """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length + + # Update cache_positions exactly like original CachePositionsManager + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, full_t + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + +class KVCacheWithAttentionSink(KVCache): + """ + KV cache that supports attention sink with torch.export compatibility. + + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. + + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] + """ + + def __init__( + self, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + rope: RopeWithAttentionSink, + window_size: int, + sink_size: int, + eviction_batch_size: int, + max_batch_size: int = 1, + dtype=torch.float32, + ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 + super().__init__( + max_batch_size=max_batch_size, + max_context_length=total_cache_size, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + dtype=dtype, + ) + self.rope = rope + self.window_size = window_size + self.sink_size = sink_size + self.eviction_batch_size = eviction_batch_size + self.is_ring_buffer = True + + # Cache positions manager for determining write locations + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) + + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): + """ + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. + """ + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. 
+ """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + + return self.k_cache, self.v_cache + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 + + +def attention_sink_forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, +): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. + """ + assert self.use_kv_cache + assert input_pos is not None + + bsz, seqlen, _ = x.shape + + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # RoPE relative positional embeddings + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Transpose for attention: [B, H, S, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update KV cache + k, v = self.kv_cache.update(input_pos, k, v) + + # Use ring buffer mask since we have is_ring_buffer=True + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) + + # SDPA + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + + # Return tuple like original AttentionMHA.forward + return self.wo(output), None + + +def _replace_rope( + module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + return isinstance(child, Rope) + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + return rope_with_attention_sink + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def _replace_attention( + module: torch.nn.Module, + rope_with_attention_sink: RopeWithAttentionSink, + sink_size: int, + window_size: int, + eviction_batch_size: int, +): + for _, child_module in module._modules.items(): + if len(list(child_module.children())) > 0: # pyre-ignore [16] + _replace_attention( + module=child_module, # pyre-ignore [6] + rope_with_attention_sink=rope_with_attention_sink, + sink_size=sink_size, + window_size=window_size, + eviction_batch_size=eviction_batch_size, + ) + + if isinstance(child_module, AttentionMHA): + kv_cache = child_module.kv_cache + # Always use KVCacheWithAttentionSink, even for sink_size=0 + # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + 
+                dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+
+            # If using SDPACustom (fused SDPA op), enable attention mask support
+            # so it uses our ring buffer / attention sink mask instead of the
+            # simple causal mask.
+            if "SDPACustom" in child_module.SDPA.__class__.__name__:
+                child_module.SDPA.use_attention_mask = True
+
+            # Don't replace forward - let the original AttentionMHA.forward handle it;
+            # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask.
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace each Attention's KVCache with KVCacheWithAttentionSink
+    - Enable attention-mask support on SDPACustom (when present) so the original
+      AttentionMHA.forward uses the ring-buffer / attention sink mask
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module
diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
index 7ccad6aadbf..7109129caca 100644
--- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
+++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
@@ -1 +1,594 @@
-# Copyright (c) Meta Platforms, Inc.
and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree."""Tests for the ring-buffer based attention sink implementation.This tests the torch.export-compatible implementation that uses a ring bufferfor the sliding window rather than explicit token eviction.Usage: # Run with pytest python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v # Or run directly python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py"""import unittestimport torchfrom executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.source_transformation.attention_sink import ( CachePositionsManagerWithSink, KVCacheWithAttentionSink, RopeWithAttentionSink, _create_causal_mask_for_attention_sink,)class CachePositionsManagerWithSinkTest(unittest.TestCase): """Test the cache positions manager for ring buffer indexing.""" def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) self.manager = CachePositionsManagerWithSink(self.cache_size) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" expected = torch.zeros(self.cache_size, dtype=torch.long) torch.testing.assert_close(self.manager.cache_positions, expected) def test_simple_update(self): """Test simple sequential update without wraparound.""" input_pos = torch.tensor([0], dtype=torch.long) seq_len = 5 indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) # Should return indices 0, 1, 2, 3, 4 expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) # Cache positions at those indices should be the original positions for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) def test_wraparound(self): """Test ring buffer wraparound.""" # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 30) # Add 5 more tokens at position 30 - should wrap around input_pos = torch.tensor([30], dtype=torch.long) indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 32) # Now add position 32 which wraps to index 0 input_pos = torch.tensor([32], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 1) # Index 0 should now contain original position 32 self.assertEqual(self.manager.cache_positions[0].item(), 32)class CausalMaskTest(unittest.TestCase): """Test the causal mask generation for attention sink.""" def test_mask_allows_sink_tokens(self): """Sink tokens should always be visible (mask = 0).""" cache_size = 32 sink_size = 4 window_size = 14 # cache_size = sink_size + window_size * 2 # Create cache positions where positions 0-3 are sink tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 # Query at position 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, 
window_size, sink_size, start_pos, seq_len ) # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 for i in range(sink_size): self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") def test_mask_blocks_future_tokens(self): """Future tokens should be masked (-inf).""" cache_size = 32 sink_size = 4 window_size = 14 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Future tokens (positions > 10) should have mask = -inf for i in range(11, cache_size): self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") def test_mask_respects_window(self): """Tokens outside the window should be masked.""" cache_size = 32 sink_size = 4 window_size = 5 # Only allow 5 recent tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 16-20 should be visible (within window of 5) for pos in range(16, 21): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") # Position 15 should be masked (outside window, not a sink) self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)")class KVCacheWithAttentionSinkTest(unittest.TestCase): """Test the KV cache with attention sink.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.max_batch_size = 1 # Create model args with enough context for RoPE self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, # Large enough for RoPE n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=self.max_batch_size, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_cache_size(self): """Cache should be sink_size + window_size * 2.""" expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) def test_is_ring_buffer(self): """Cache should be marked as ring buffer.""" self.assertTrue(self.kv_cache.is_ring_buffer) def test_update_stores_kv(self): """Update should store key-value pairs.""" k = torch.randn(1, self.n_heads, 5, self.head_dim) v = torch.randn(1, self.n_heads, 5, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # First 5 positions should contain our values torch.testing.assert_close(k_out[:, :, :5, :], k) torch.testing.assert_close(v_out[:, :, :5, :], v) def test_evict_tokens_returns_zero(self): """Ring buffer implementation doesn't shift, so evict returns 0.""" input_pos = torch.tensor([100], dtype=torch.long) shift = self.kv_cache.evict_tokens(input_pos, 10) self.assertEqual(shift, 0) def test_extended_generation(self): """Test that cache works for positions beyond cache size.""" cache_size = self.kv_cache.k_cache.size(2) # Fill cache 
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Tests for the ring-buffer based attention sink implementation.
+ +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + self.manager = CachePositionsManagerWithSink(self.cache_size) + + def test_initial_positions_are_zero(self): + """Cache positions should start as zeros.""" + expected = torch.zeros(self.cache_size, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound(self): + """Test ring buffer wraparound.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions(self): + """Cache positions should track which original position is at each index.""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + 
# Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + 
self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + 
visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out 
= self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size = 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + 
non_inf_count = (mask != float('-inf')).sum().item()
+            self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens")
+
+            # For positions >= sink_size, sinks should always be visible
+            if pos >= sink_size:
+                for i in range(sink_size):
+                    cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item()
+                    if cache_pos < sink_size:  # This is actually a sink token
+                        self.assertEqual(mask[0, i].item(), 0.0,
+                                         f"Sink at idx {i} should be visible at pos {pos}")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 7924d24082f..11d77f321bf 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -1 +1,310 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+// A simple llama2 runner that includes preprocessing and post processing logic.
+// The module takes in a string as input and emits a string as output.
+ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +TextLLMRunner::TextLLMRunner( + std::unordered_map metadata, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unique_ptr<::executorch::extension::Module> module, + std::unique_ptr text_decoder_runner, + std::unique_ptr text_prefiller, + std::unique_ptr io_manager, + std::unique_ptr text_token_generator, + std::unique_ptr stats, + float temperature) + : tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + module_(std::move(module)), + text_decoder_runner_(std::move(text_decoder_runner)), + text_prefiller_(std::move(text_prefiller)), + io_manager_(std::move(io_manager)), + text_token_generator_(std::move(text_token_generator)), + stats_(std::move(stats)), + temperature_(temperature), + pos_(0) { + // Note: This constructor assumes that text_prefiller and text_token_generator + // already have references to the Module and TextDecoderRunner they need +} + +bool TextLLMRunner::is_loaded() const { + return text_prefiller_->is_loaded() && text_token_generator_->is_loaded(); +} + +Error TextLLMRunner::load() { + if (is_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); + return Error::Ok; +} + +// Don't print with the same priority during warmup +#define RUNNER_ET_LOG(warmup, format, ...) \ + if (warmup) { \ + ET_LOG(Debug, format, __VA_ARGS__); \ + } else { \ + ET_LOG(Info, format, __VA_ARGS__); \ + } + +Error TextLLMRunner::generate( + const std::string& prompt, + const GenerationConfig& config, + std::function token_callback, + std::function stats_callback) { + // Prepare the inputs. + // Use ones-initialized inputs. + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + if (!is_loaded()) { + stats_->model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + stats_->model_load_end_ms = time_in_ms(); + } + + if (config.warming) { + ET_LOG(Info, "Doing a warmup run..."); + } + + RUNNER_ET_LOG( + config.warming, + "RSS after loading model: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + // Wrap the token_callback with print function + std::function wrapped_callback = + [token_callback, config](const std::string& piece) { + if (!config.warming) { + llm::safe_printf(piece.c_str()); + fflush(stdout); + } + if (token_callback) { + token_callback(piece); + } + }; + // First token time only measures the time it takes to encode the prompt and + // return a response token. + + stats_->inference_start_ms = time_in_ms(); + shouldStop_ = false; + + ::tokenizers::Result> encode_res = tokenizer_->encode( + prompt, + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); + + if (!encode_res.ok()) { + ET_LOG( + Error, + "Failed to encode prompt %s. 
Tokenizers error code %d", + prompt.c_str(), + static_cast(encode_res.error())); + return Error::InvalidArgument; + } + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens >= 1, + InvalidArgument, + "Expected at least 1 prompt token"); + + // For models with sliding window (Ring Buffer / Attention Sink), + // we allow pos_ to exceed max_context_len. The model handles this + // internally via ring buffer indexing or token eviction. + // We only check that a single prefill chunk doesn't exceed max_seq_len. + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens <= max_seq_len, + InvalidArgument, + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", + num_prompt_tokens, + max_seq_len); + + // Determine max_new_tokens using the GenerationConfig's resolve method. + // For sliding window models, we use max_context_len directly (not reduced by pos_) + // because the model handles position wrapping internally via ring buffer. + int max_new_tokens = + config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); + + // TEMPORARY: For sliding window / infinite context testing, + // override max_new_tokens to allow unlimited generation + if (max_new_tokens <= 0) { + max_new_tokens = 1000000; // Effectively unlimited + } + // If user specified seq_len, use that instead + if (config.seq_len > 0 && config.seq_len > max_new_tokens) { + max_new_tokens = config.seq_len; + } + + ET_LOG( + Info, + "Max new tokens resolved: %d, given pos_ %" PRId64 + ", num_prompt_tokens %zu, max_context_len %" PRId64, + max_new_tokens, + pos_, + prompt_tokens.size(), + max_context_len); + ET_CHECK_OR_RETURN_ERROR( + max_new_tokens > 0, + InvalidArgument, + "Max new tokens %d is less than or equal to 0", + max_new_tokens); + // Prefill first + // Here feed all tokens to the model and get the next predicted token + // after the prompt. After that we will enter generate loop. + + // print prompts + if (config.echo) { + wrapped_callback(prompt); + } + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + uint64_t cur_token = prefill_res.get(); + stats_->first_token_ms = time_in_ms(); + stats_->prompt_eval_end_ms = time_in_ms(); + + // print the first token from prefill. No prev_token so use cur_token for it. + auto decode_result = tokenizer_->decode(cur_token, cur_token); + if (!decode_result.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(decode_result.error())); + return ::executorch::runtime::Error::InvalidArgument; + } + wrapped_callback(std::move(*decode_result)); + RUNNER_ET_LOG( + config.warming, + "RSS after prompt prefill: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + // start the main loop + prompt_tokens.push_back(cur_token); + + // Set ignore_eos based on config + text_token_generator_->set_ignore_eos(config.ignore_eos); + + // Generate max_new_tokens - 1 because prefill already generated 1 token. + auto generate_result = text_token_generator_->generate( + prompt_tokens, + pos_, + max_new_tokens - 1, + temperature_ == -1.0f ? 
config.temperature : temperature_, + wrapped_callback); + if (!generate_result.ok()) { + return generate_result.error(); + } + int64_t num_generated_tokens = generate_result.get(); + + pos_ += num_generated_tokens; + + stats_->inference_end_ms = time_in_ms(); + if (!config.warming) { + printf("\n"); + } + RUNNER_ET_LOG( + config.warming, + "RSS after finishing text generation: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + if (num_generated_tokens == max_new_tokens) { + RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); + } + + stats_->num_prompt_tokens = num_prompt_tokens; + stats_->num_generated_tokens = num_generated_tokens; + + if (config.warming) { + ET_LOG(Info, "Warmup run finished!"); + } else { + // Do not print report during warmup + print_report(*stats_); + } + if (stats_callback) { + stats_callback(*stats_); + } + + return Error::Ok; +} + +Error TextLLMRunner::prefill( + const std::string& prompt, + const GenerationConfig& config) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + + ::tokenizers::Result> encode_res = tokenizer_->encode( + prompt, + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); + + ET_CHECK_TK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + return Error::Ok; +} + +Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { + // Create a GenerationConfig for warmup + GenerationConfig config; + config.echo = false; + config.max_new_tokens = max_new_tokens; + config.warming = true; + + // Call generate with the warmup config + Error err = generate(prompt, config); + + // Reset stats after warmup, not resetting the std::unique_ptr! + reset(); + return err; +} + +void TextLLMRunner::stop() { + if (is_loaded()) { + text_token_generator_->stop(); + } else { + ET_LOG(Error, "Token generator is not loaded, cannot stop"); + } +} + +void TextLLMRunner::reset() { + stats_->reset(); + pos_ = 0; +} + +} // namespace executorch::extension::llm From 04fadb5df35e0626e89fbe7176296e7d1260decb Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:41:41 -0800 Subject: [PATCH 07/17] Fix attention sink index calculation to preserve sink tokens --- .../source_transformation/attention_sink.py | 37 +- .../test_attention_sink.py | 514 ------------------ .../test_attention_sink_ring_buffer.py | 46 +- 3 files changed, 66 insertions(+), 531 deletions(-) delete mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index c2e0f6606f5..d8a6919f7d1 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -144,31 +144,56 @@ class CachePositionsManagerWithSink(nn.Module): IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). 
""" - def __init__(self, cache_size: int): + def __init__(self, cache_size: int, sink_size: int = 0): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size + self.sink_size = sink_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - + def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - This is identical to the original CachePositionsManager logic. + Indices logic: + - If pos < sink_size: index = pos + - If pos >= sink_size: index = sink_size + (pos - sink_size) % (cache_size - sink_size) """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos - # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length + if self.sink_size == 0: + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length + else: + # Shifted ring buffer logic + ring_size = self.max_context_length - self.sink_size + + # Calculate indices based on sink vs ring buffer logic + # Logic: + # 1. Calculate potential ring buffer index: sink_size + (pos - sink_size) % ring_size + # 2. If pos < sink_size, use pos. Else use ring buffer index. + + # Note: (pos - sink_size) % ring_size works correctly even if pos < sink_size + # in Python, but we want to be explicit. + # However, for pure torch.export compatibility without conditionals on tensors, + # we can use where or math. + + # Vectorized calculation: + shifted_pos = orig_indices - self.sink_size + ring_indices = self.sink_size + (shifted_pos % ring_size) + + # If position is within sink (0..sink_size-1), use original position + # Else use ring index + indices = torch.where(orig_indices < self.sink_size, orig_indices, ring_indices) # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) @@ -222,7 +247,7 @@ def __init__( # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) - self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size, sink_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - 
enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - 
self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = 
self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 7109129caca..485a5b13bdd 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -59,30 +59,54 @@ def test_simple_update(self): def test_wraparound(self): """Test ring buffer wraparound.""" + # sink_size=0 (default) in setUp, so it wraps to 0 + # Let's test non-zero sink size + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, 
sink_size) + # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 30) + manager.calculate_positions_and_update_indices(input_pos, 30) - # Add 5 more tokens at position 30 - should wrap around + # Add 5 more tokens at position 30 input_pos = torch.tensor([30], dtype=torch.long) - indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) - - # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 - expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + indices = manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices logic: + # Ring size = 32 - 4 = 28 + # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 + # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4) + # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 + # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 + expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 32) + manager.calculate_positions_and_update_indices(input_pos, 32) + + # Indices 0-3 should have pos 0-3 (Sink) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) - # Now add position 32 which wraps to index 0 + # Now add position 32. + # (32-4)%28 = 0. So index = 4 + 0 = 4. input_pos = torch.tensor([32], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 1) + manager.calculate_positions_and_update_indices(input_pos, 1) - # Index 0 should now contain original position 32 - self.assertEqual(self.manager.cache_positions[0].item(), 32) + # Index 4 should now contain original position 32 + self.assertEqual(manager.cache_positions[4].item(), 32) + + # Index 0 (sink) should STILL contain position 0 + self.assertEqual(manager.cache_positions[0].item(), 0) class CausalMaskTest(unittest.TestCase): From 97ba715656fc75eac3fb9537b89ff413479ce7c9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:56:10 -0800 Subject: [PATCH 08/17] Update --- .../source_transformation/attention_sink.py | 64 +-- .../test_attention_sink.py | 514 ++++++++++++++++++ .../test_attention_sink_ring_buffer.py | 56 +- 3 files changed, 580 insertions(+), 54 deletions(-) create mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index d8a6919f7d1..ba8b5298d66 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -138,71 +138,57 @@ class CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. - For sink_size=0: behaves exactly like original CachePositionsManager. - For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + For sink_size=0: behaves exactly like original CachePositionsManager (simple ring buffer). + For sink_size>0: sink tokens (indices 0 to sink_size-1) are NEVER overwritten. 
+ Ring buffer only cycles through indices sink_size to cache_size-1. - IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). + IMPORTANT: cache_size should be the actual cache dimension size (sink_size + 2*window_size). """ def __init__(self, cache_size: int, sink_size: int = 0): super().__init__() - # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size self.sink_size = sink_size + # Ring buffer size = cache_size - sink_size + self.ring_size = cache_size - sink_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - + def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - Indices logic: - - If pos < sink_size: index = pos - - If pos >= sink_size: index = sink_size + (pos - sink_size) % (cache_size - sink_size) + Index calculation: + - Position < sink_size: index = position (sink tokens at fixed indices) + - Position >= sink_size: index = sink_size + (position - sink_size) % ring_size + + This ensures sink tokens (indices 0 to sink_size-1) are NEVER overwritten. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) - orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + # Original positions for the sequence + orig_positions = torch.arange(seq_len, dtype=torch.long) + start_pos if self.sink_size == 0: # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length + indices = orig_positions % self.max_context_length else: - # Shifted ring buffer logic - ring_size = self.max_context_length - self.sink_size - - # Calculate indices based on sink vs ring buffer logic - # Logic: - # 1. Calculate potential ring buffer index: sink_size + (pos - sink_size) % ring_size - # 2. If pos < sink_size, use pos. Else use ring buffer index. - - # Note: (pos - sink_size) % ring_size works correctly even if pos < sink_size - # in Python, but we want to be explicit. - # However, for pure torch.export compatibility without conditionals on tensors, - # we can use where or math. 
- - # Vectorized calculation: - shifted_pos = orig_indices - self.sink_size - ring_indices = self.sink_size + (shifted_pos % ring_size) - - # If position is within sink (0..sink_size-1), use original position - # Else use ring index - indices = torch.where(orig_indices < self.sink_size, orig_indices, ring_indices) - - # Update cache_positions exactly like original CachePositionsManager - full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) - arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, full_t - ) - self.cache_positions.copy_(cache_positions) - self.cache_positions.index_copy_(0, indices, orig_indices) + # Shifted ring buffer: sink tokens at fixed indices, rest in ring buffer + # For position >= sink_size: index = sink_size + (position - sink_size) % ring_size + shifted = orig_positions - self.sink_size + ring_indices = self.sink_size + (shifted % self.ring_size) + # For position < sink_size: use position directly + indices = torch.where(orig_positions < self.sink_size, orig_positions, ring_indices) + + # Update cache_positions to track what position is at each index + # Only update the indices we're writing to + self.cache_positions.index_copy_(0, indices, orig_positions) return indices diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py new file mode 100644 index 00000000000..fc882ebf4ab --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -0,0 +1,514 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, + RopeWithAttentionSink, +) +from parameterized import parameterized + + +class RopeWithAttentionSinkTest(unittest.TestCase): + + def _init_rope(self, params: ModelArgs, eviction_batch_size: int): + return RopeWithAttentionSink( + params=params, + window_size=252, + sink_size=4, + eviction_batch_size=eviction_batch_size, + ) + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 + ) + self.rope_with_attention_sink = self._init_rope( + params=self.params, eviction_batch_size=1 + ) + + @parameterized.expand( + [ + [0, 10, 1, 0], # No shift + [250, 10, 1, 246], # Some shift + [256, 10, 1, 246], # All shift + [0, 10, 30, 0], # No shift with batch eviction + [250, 10, 30, 220], # Some shift with batch eviction + [256, 10, 30, 226], # All shift with batch eviction + ] + ) + def test_get_freqs( + self, input_pos, seq_len, eviction_batch_size, expected_result_pos + ): + self.rope_with_attention_sink = self._init_rope( + params=self.params, eviction_batch_size=eviction_batch_size + ) + + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([input_pos], dtype=torch.int32), + seq_len=seq_len, + ) + + torch.testing.assert_close( + freqs_cos, + self.rope_with_attention_sink.freqs_cos.narrow( + 0, expected_result_pos, seq_len + ), + ) + torch.testing.assert_close( + freqs_sin, + self.rope_with_attention_sink.freqs_sin.narrow( + 0, expected_result_pos, seq_len + ), + ) + + @parameterized.expand( + [ + [128, 127], # Rotate left + [128, 128], # No rotation + [128, 129], # Rotate right + ] + ) + def test_rotate(self, original_position, new_position): + seq_len = 32 + + size = (1, seq_len, self.params.n_heads, self.params.head_dim) + q = torch.rand(*size, dtype=torch.float32) + k = torch.rand( + *size, + dtype=torch.float32, + ) + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([original_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, pre_rotated_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + rerotated_k = self.rope_with_attention_sink.rerotate_k( + k=pre_rotated_k, + original_position=original_position, + new_position=new_position, + ) + + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([new_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, expected_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + torch.testing.assert_close(rerotated_k, expected_k) + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + + _single_evict_test_cases = [ + [4, 1], + ] + + _batch_evict_test_cases = [ + [4, 8], + ] + + _sliding_window_test_cases = [ + [0, 1], + ] + + def _init_cache(self, sink_size, eviction_batch_size): + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=self.window_size + sink_size, + ) + self.rope_with_attention_sink = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + ) + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.params.n_heads, + head_dim=self.params.head_dim, + 
enable_dynamic_shape=self.params.enable_dynamic_shape, + rope=self.rope_with_attention_sink, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=self.dtype, + ) + + def _rand_kv_with_length(self, seq_len): + size = ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + return k, v + + def _zero_kv_with_length(self, seq_len): + size = ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + return k, v + + def _get_dim_to_slice(self): + return 2 + + def _get_expected_rotated_k(self, k, original_position, new_position): + return self.rope_with_attention_sink.rerotate_k( + k=k.transpose(1, 2), + original_position=original_position, + new_position=new_position, + ).transpose(1, 2) + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.dtype = torch.float32 + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_empty_cache(self, sink_size, eviction_batch_size): + self._init_cache(sink_size, eviction_batch_size) + + # KV cache is empty, evict does nothing + input_pos = torch.tensor([0], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_without_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = 2 + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has enough spaces for new tokens, no shift + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(10) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) + + expected_k = torch.cat( + [ + k, + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_some_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 24) == -2 + + zero_k, zero_v = self._zero_kv_with_length(24) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + 
self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 1, 4), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_all_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(27) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([32], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(6) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 5, 22), 10, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 5, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_some_shift_for_sliding_window( + self, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 20) == -2 + + zero_k, zero_v = self._zero_kv_with_length(20) + expected_k = torch.cat( + [ + self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), + self._get_expected_rotated_k(k1, 5, 3), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 2, 3), + v1, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_all_shift_for_sliding_window( + self, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(23) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([28], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = 
self._zero_kv_with_length(6) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 1, 22), 6, 0 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v1.narrow(dimension_to_slice, 1, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 12) == -10 + + zero_k, zero_v = self._zero_kv_with_length(12) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 9, 16), 14, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 9, 16), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -8 + + zero_k, zero_v = self._zero_kv_with_length(10) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 7, 18), 12, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 7, 18), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 485a5b13bdd..92f54b8e468 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -36,7 +36,8 @@ class CachePositionsManagerWithSinkTest(unittest.TestCase): def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) - self.manager = CachePositionsManagerWithSink(self.cache_size) + # Default: no sink (simple ring buffer) + self.manager = 
CachePositionsManagerWithSink(self.cache_size, sink_size=0) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" @@ -57,10 +58,22 @@ def test_simple_update(self): for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) - def test_wraparound(self): - """Test ring buffer wraparound.""" - # sink_size=0 (default) in setUp, so it wraps to 0 - # Let's test non-zero sink size + def test_wraparound_no_sink(self): + """Test ring buffer wraparound with sink_size=0.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_wraparound_with_sink(self): + """Test ring buffer wraparound with sink_size > 0.""" sink_size = 4 cache_size = 32 manager = CachePositionsManagerWithSink(cache_size, sink_size) @@ -73,18 +86,30 @@ def test_wraparound(self): input_pos = torch.tensor([30], dtype=torch.long) indices = manager.calculate_positions_and_update_indices(input_pos, 5) - # Indices logic: # Ring size = 32 - 4 = 28 # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 - # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4) + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4, not 0!) # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) - def test_cache_positions_track_original_positions(self): - """Cache positions should track which original position is at each index.""" + def test_cache_positions_track_original_positions_no_sink(self): + """Cache positions should track which original position is at each index (no sink).""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + def test_cache_positions_track_original_positions_with_sink(self): + """Cache positions should track positions, and sink tokens are never overwritten.""" sink_size = 4 cache_size = 32 manager = CachePositionsManagerWithSink(cache_size, sink_size) @@ -93,20 +118,21 @@ def test_cache_positions_track_original_positions(self): input_pos = torch.tensor([0], dtype=torch.long) manager.calculate_positions_and_update_indices(input_pos, 32) - # Indices 0-3 should have pos 0-3 (Sink) + # Indices 0-3 should have pos 0-3 (Sink tokens) for i in range(4): self.assertEqual(manager.cache_positions[i].item(), i) - - # Now add position 32. + + # Now add position 32. # (32-4)%28 = 0. So index = 4 + 0 = 4. 
input_pos = torch.tensor([32], dtype=torch.long) manager.calculate_positions_and_update_indices(input_pos, 1) - + # Index 4 should now contain original position 32 self.assertEqual(manager.cache_positions[4].item(), 32) - # Index 0 (sink) should STILL contain position 0 - self.assertEqual(manager.cache_positions[0].item(), 0) + # Index 0-3 (sink) should STILL contain positions 0-3 (unchanged) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) class CausalMaskTest(unittest.TestCase): From e4ccab473630315f821080fa50ae2495a65063f9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 15:14:30 -0800 Subject: [PATCH 09/17] Remove obsolete test_attention_sink.py --- .../test_attention_sink.py | 514 ------------------ 1 file changed, 514 deletions(-) delete mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - 
input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - 
- self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - 
input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = 
torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) From bbca983c34c4b1d074008db133e65fc9c782cb03 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 15:36:56 -0800 Subject: [PATCH 10/17] Border --- .../llama/source_transformation/attention_sink.py | 11 ++++++++--- .../test_attention_sink_ring_buffer.py | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index ba8b5298d66..3a69d85e61b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -126,7 +126,11 @@ def _create_causal_mask_for_attention_sink( is_sink = cache_positions < sink_size # Window tokens must be within sliding window - is_in_window = delta < window_size + # Use <= to include the boundary token. For window_size=124, we want to attend + # to the last 124 tokens BEFORE the current position (delta 1 to 124), plus + # position 4 (first non-sink token) which has delta exactly = window_size. + # This ensures sink_size + window_size tokens are visible when cache is full. 
+ is_in_window = delta <= window_size # Final mask: valid AND (is_sink OR is_in_window) attn_mask = is_valid & (is_sink | is_in_window) @@ -151,10 +155,11 @@ def __init__(self, cache_size: int, sink_size: int = 0): self.sink_size = sink_size # Ring buffer size = cache_size - sink_size self.ring_size = cache_size - sink_size - # Use zeros like original CachePositionsManager + # Initialize to -1 to mark unwritten positions + # The mask uses (cache_positions >= 0) to check if a position is valid self.register_buffer( "cache_positions", - torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), ) def calculate_positions_and_update_indices( diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 92f54b8e468..953033f5cd8 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -39,9 +39,9 @@ def setUp(self): # Default: no sink (simple ring buffer) self.manager = CachePositionsManagerWithSink(self.cache_size, sink_size=0) - def test_initial_positions_are_zero(self): - """Cache positions should start as zeros.""" - expected = torch.zeros(self.cache_size, dtype=torch.long) + def test_initial_positions_are_minus_one(self): + """Cache positions should start as -1 (unwritten).""" + expected = torch.full((self.cache_size,), -1, dtype=torch.long) torch.testing.assert_close(self.manager.cache_positions, expected) def test_simple_update(self): From fec5eb9505d9517ca6ca89fe60df2c2b2b7d6002 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 16:06:50 -0800 Subject: [PATCH 11/17] Test commands --- test_attention_sink.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 test_attention_sink.md diff --git a/test_attention_sink.md b/test_attention_sink.md new file mode 100644 index 00000000000..0e70723cc20 --- /dev/null +++ b/test_attention_sink.md @@ -0,0 +1,37 @@ +# Build the runtime + +Do this after modifying runtime code (cpp) +```sh +cmake --workflow llm-debug +pushd examples/models/llama +cmake --workflow --preset llama-debug +popd +``` + +# Export model +Take a look at examples/models/llama/README.md + +Check point is in ~/executorch/ + +Make sure you are in conda executorch env + +# No quantization +# Set these paths to point to the downloaded files +LLAMA_CHECKPOINT=path/to/consolidated.00.pth +LLAMA_PARAMS=path/to/params.json + +python -m extension.llm.export.export_llm \ + --config examples/models/llama/config/llama_bf16.yaml \ + +base.model_class="llama3_2" \ + +base.checkpoint="consolidated.00.pth" \ + +base.params="params.json" +``` + +# Run + +Please also take a look at examples/models/llama/runner to make sure it can emit many tokens, exceeding context size. 
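As a rough illustration of what "exceeding context size" means for an attention-sink export, the sketch below approximates which original positions the `delta <= window_size` mask above keeps visible once generation runs far past the cache capacity. sink_size=4 and window_size=124 are assumed from the example configs, and `visible_positions` is a hypothetical helper, not part of the runner:

```python
# Rough sketch of which original token positions stay attendable at a given
# decode position, per the sink + sliding-window mask. Assumed example config:
# sink_size=4, window_size=124. `visible_positions` is a hypothetical helper.
def visible_positions(pos: int, sink_size: int = 4, window_size: int = 124):
    sinks = set(range(sink_size))
    window = set(range(max(sink_size, pos - window_size), pos + 1))
    return sinks | window

vis = visible_positions(1000)  # far past the 252-entry cache
assert 0 in vis and 999 in vis  # a sink token and a recent token are visible
assert 500 not in vis           # an evicted middle token is not
assert len(vis) == 4 + 124 + 1  # sinks + previous window + current token
```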
+ +Please check whether the output makes sense or not +``` +cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= +``` From f3a048781f1baeeff6f62c4a64a2dec8ec3aa5fd Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 20:31:20 -0800 Subject: [PATCH 12/17] Runner side --- examples/models/llama/attention.py | 23 ++- .../config/llama_attention_sink_noxnn.yaml | 29 ++++ .../config/llama_attention_sink_xnnpack.yaml | 25 +++ .../config/llama_runner_attention_sink.yaml | 42 +++++ examples/models/llama/export_llama_lib.py | 19 ++- .../llama/source_transformation/sdpa.py | 7 +- extension/llm/runner/constants.h | 5 + .../io_manager/attention_sink_io_manager.cpp | 113 +++++++++++++ .../io_manager/attention_sink_io_manager.h | 152 ++++++++++++++++++ extension/llm/runner/io_manager/targets.bzl | 17 ++ extension/llm/runner/llm_runner_helper.cpp | 53 +++++- extension/llm/runner/targets.bzl | 1 + extension/llm/runner/text_llm_runner.cpp | 27 +--- 13 files changed, 481 insertions(+), 32 deletions(-) create mode 100644 examples/models/llama/config/llama_attention_sink_noxnn.yaml create mode 100644 examples/models/llama/config/llama_attention_sink_xnnpack.yaml create mode 100644 examples/models/llama/config/llama_runner_attention_sink.yaml create mode 100644 extension/llm/runner/io_manager/attention_sink_io_manager.cpp create mode 100644 extension/llm/runner/io_manager/attention_sink_io_manager.h diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 7b7691a2304..5be41830ea8 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -278,7 +278,7 @@ def __init__( [0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the current step still has access to [pos - sliding_window_size, pos] tokens. - + To make sure we dont over attend, i.e. we dont have pos = 5 to attend to pos = 1, mask calculaton has to account for the sliding window size. @@ -486,21 +486,30 @@ def forward( if self.use_kv_cache: assert input_pos is not None - if self.enable_dynamic_shape: + is_ring = getattr(self.kv_cache, "is_ring_buffer", False) + if is_ring: + # Ring buffer models: positions can exceed max_context_len. + # The ring buffer handles wrapping via modular arithmetic. + # The causal mask is computed dynamically from cache_positions, + # so we don't use the pre-computed self.mask here. 
+ k, v = self.kv_cache.update(input_pos, k, v)
+ start_pos = input_pos[0].item()
+ torch._check_is_size(start_pos)
+ attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+ start_pos, seqlen
+ )
+ elif self.enable_dynamic_shape:
 start_pos = input_pos[-1].item()
 torch._check_is_size(start_pos)
 torch._check(start_pos < self.max_context_len)
 seq_length = q.size(2)
 # pyre-ignore: Incompatible parameter type [6]
 attn_mask = self.mask.narrow(0, start_pos, seq_length)
+ k, v = self.kv_cache.update(input_pos, k, v)
 else:
 # mask is always 2D
 attn_mask = self.mask[input_pos]
- k, v = self.kv_cache.update(input_pos, k, v)
- if getattr(self.kv_cache, "is_ring_buffer", False):
- attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
- input_pos[0].item(), seqlen
- )
+ k, v = self.kv_cache.update(input_pos, k, v)
 output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
 return self.wo(output), None
diff --git a/examples/models/llama/config/llama_attention_sink_noxnn.yaml b/examples/models/llama/config/llama_attention_sink_noxnn.yaml
new file mode 100644
index 00000000000..17ae69df4d8
--- /dev/null
+++ b/examples/models/llama/config/llama_attention_sink_noxnn.yaml
@@ -0,0 +1,29 @@
+base:
+ metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+ use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
+ use_kv_cache: True
+ dtype_override: fp32
+ enable_dynamic_shape: True
+ # Attention Sink: "sink_size,window_size,eviction_batch_size"
+ # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
+ # window_size=124: sliding window size
+ # eviction_batch_size=1: evict 1 token at a time
+ # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
+ use_attention_sink: "4,124,1"
+
+export:
+ max_seq_length: 252
+ max_context_length: 512
+
+# Quantization enabled for this test
+quantization:
+ qmode: 8da4w
+ group_size: 128
+ embedding_quantize: 4,32
+
+# No XNNPACK for this test
+backend:
+ xnnpack:
+ enabled: False
diff --git a/examples/models/llama/config/llama_attention_sink_xnnpack.yaml b/examples/models/llama/config/llama_attention_sink_xnnpack.yaml
new file mode 100644
index 00000000000..6a63f3a914a
--- /dev/null
+++ b/examples/models/llama/config/llama_attention_sink_xnnpack.yaml
@@ -0,0 +1,25 @@
+base:
+ metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+ use_sdpa_with_kv_cache: True
+ use_kv_cache: True
+ dtype_override: fp32
+ enable_dynamic_shape: True
+ use_attention_sink: "4,124,1"
+
+export:
+ max_seq_length: 252
+ max_context_length: 512
+
+# No quantization
+# quantization:
+# qmode: 8da4w
+# group_size: 128
+# embedding_quantize: 4,32
+
+# XNNPACK enabled
+backend:
+ xnnpack:
+ enabled: True
+ extended_ops: True
diff --git a/examples/models/llama/config/llama_runner_attention_sink.yaml b/examples/models/llama/config/llama_runner_attention_sink.yaml
new file mode 100644
index 00000000000..e6945335563
--- /dev/null
+++ b/examples/models/llama/config/llama_runner_attention_sink.yaml
@@ -0,0 +1,42 @@
+##
+## Runner-side Attention Sink configuration
+##
+## This uses KVCacheWithAttentionSink (model-side) together with
+## the runner's AttentionSinkIOManager for position bookkeeping.
+## +## Key behavior: +## - Model has KVCacheWithAttentionSink which preserves sink tokens and +## uses a ring buffer for the sliding window (is_ring_buffer=True) +## - Runner's AttentionSinkIOManager tracks logical position and allows +## generation to continue past max_context_len +## - KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 +## + +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + # Attention Sink: "sink_size,window_size,eviction_batch_size" + # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) + # window_size=124: Sliding window size + # eviction_batch_size=1: Evict 1 token at a time + # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + use_attention_sink: "4,124,1" + +export: + # max_seq_length for single prefill chunk + max_seq_length: 128 + +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 219cc71ded1..aca875acdab 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -728,9 +728,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: calibration_limit=llm_config.quantization.calibration_limit, calibration_seq_length=llm_config.quantization.calibration_seq_length, expand_rope_table=llm_config.model.expand_rope_table, + # Attention sink models need attention mask for custom SDPA because: + # 1. The ring buffer creates a dynamic mask based on cache_positions + # 2. Without mask, custom_sdpa uses is_causal=True with start_pos, which + # fails when start_pos exceeds the cache size (positions keep growing) + # 3. With mask, custom_sdpa uses is_causal=False and the mask handles + # all masking logic including sliding window and attention sink use_custom_sdpa_with_attention_mask=getattr( llm_config.model, "use_custom_sdpa_with_attention_mask", False - ), + ) + or bool(llm_config.model.use_attention_sink), use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, quantize_kv_cache=llm_config.model.quantize_kv_cache, use_kv_cache=llm_config.model.use_kv_cache, @@ -1118,6 +1125,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + # For attention sink models, the cache_positions buffer must be initialized + # to -1 (sentinel for "empty slot"). Without this pass, ExecuTorch only + # serializes shape+dtype for mutated buffers, leaving them uninitialized + # at runtime, which corrupts the attention mask computation. 
+ if llm_config.model.use_attention_sink: + additional_passes.append( + InitializedMutableBufferPass(["cache_positions"]) + ) + # export_to_edge builder_exported = _prepare_for_llama_export(llm_config).export() builder_exported.run_canonical_optimizations() @@ -1470,7 +1486,6 @@ def _get_source_transforms( # noqa transforms.append(materialze_broadcast_of_rope_freq_cis) use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask - if use_sdpa_with_kv_cache: transforms.append(replace_kv_cache_with_custom_kv_cache) # todo: do this optionally diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 59823b533a3..2885f9c12ea 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -50,11 +50,16 @@ def forward( v = v.to(dtype=torch.float) if self.use_attention_mask: + # When using attention mask, pass 0 as start_pos since: + # 1. The mask handles all masking logic (including ring buffer / attention sink) + # 2. is_causal=False so start_pos is not used for causal masking + # 3. This avoids validation errors when logical position >= cache size + # (e.g., with ring buffer where position 252 exceeds cache_size 252) output = torch.ops.llama.custom_sdpa( q, k, v, - input_pos[0].item(), + 0, # start_pos: not used when mask is provided mask, # Attention mask 0, # dropout probability. Ignored by the code False, # is_causal diff --git a/extension/llm/runner/constants.h b/extension/llm/runner/constants.h index d7b36077757..5e06b05cb40 100644 --- a/extension/llm/runner/constants.h +++ b/extension/llm/runner/constants.h @@ -19,6 +19,11 @@ inline constexpr auto kVocabSize = "get_vocab_size"; inline constexpr auto kUseKVCache = "use_kv_cache"; inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; +// Attention sink configuration metadata keys +inline constexpr auto kUseAttentionSink = "use_attention_sink"; +inline constexpr auto kAttentionSinkSize = "attention_sink_size"; +inline constexpr auto kAttentionSinkWindowSize = "attention_sink_window_size"; + // Multimodal method name conventions inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp new file mode 100644 index 00000000000..b71080e9522 --- /dev/null +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch { +namespace extension { +namespace llm { + +AttentionSinkIOManager::AttentionSinkIOManager( + ET_MODULE_NAMESPACE::Module& module, + int64_t max_cache_size, + AttentionSinkConfig config) + : IOManager(module), + max_cache_size_(max_cache_size), + config_(config), + logical_pos_(0) { + ET_CHECK_MSG( + config_.sink_size >= 0, + "sink_size must be non-negative, got %" PRId64, + config_.sink_size); + ET_CHECK_MSG( + config_.window_size > 0, + "window_size must be positive, got %" PRId64, + config_.window_size); +} + +runtime::Error AttentionSinkIOManager::load( + const std::string& prefill_method, + const std::string& decode_method) { + (void)prefill_method; + (void)decode_method; + + ET_LOG( + Info, + "AttentionSinkIOManager loaded: sink_size=%" PRId64 + ", window_size=%" PRId64 ", max_cache_size=%" PRId64, + config_.sink_size, + config_.window_size, + max_cache_size_); + + return runtime::Error::Ok; +} + +runtime::Error AttentionSinkIOManager::reset( + const std::string& prefill_method, + const std::string& decode_method) { + (void)prefill_method; + (void)decode_method; + + logical_pos_ = 0; + + ET_LOG(Debug, "AttentionSinkIOManager reset"); + return runtime::Error::Ok; +} + +runtime::Result> +AttentionSinkIOManager::prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& prefill_method) { + int64_t logical_start = start_pos->data_ptr()[0]; + int64_t seq_len = input->numel(); + + logical_pos_ = logical_start + seq_len; + + ET_LOG( + Debug, + "AttentionSinkIOManager::prepare_prefill: logical_start=%" PRId64 + ", seq_len=%" PRId64 ", logical_pos_after=%" PRId64 + ", cache_full=%s", + logical_start, + seq_len, + logical_pos_, + is_cache_full() ? "true" : "false"); + + // Pass through to model as-is. The model's KVCacheWithAttentionSink + // (or RingKVCache) handles position-to-index mapping and mask creation. + return std::vector{input, start_pos}; +} + +runtime::Result> +AttentionSinkIOManager::prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& decode_method) { + int64_t logical_start = start_pos->data_ptr()[0]; + int64_t seq_len = input->numel(); + + logical_pos_ = logical_start + seq_len; + + ET_LOG( + Debug, + "AttentionSinkIOManager::prepare_decode: logical_start=%" PRId64 + ", logical_pos_after=%" PRId64 + ", cache_full=%s", + logical_start, + logical_pos_, + is_cache_full() ? "true" : "false"); + + // Pass through to model as-is. The model's KVCacheWithAttentionSink + // (or RingKVCache) handles position-to-index mapping and mask creation. + return std::vector{input, start_pos}; +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.h b/extension/llm/runner/io_manager/attention_sink_io_manager.h new file mode 100644 index 00000000000..afee1db8354 --- /dev/null +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +/** + * @brief Configuration for attention sink behavior. + */ +struct AttentionSinkConfig { + /// Number of initial tokens to always keep (sink tokens). 
+ int64_t sink_size = 4; + + /// Size of the sliding window for non-sink tokens. + int64_t window_size = 508; + + /// When the cache is full, evict this many tokens at once. + int64_t eviction_batch_size = 256; +}; + +/** + * @brief IOManager that supports attention sink models for infinite context. + * + * This IOManager is designed to work with models that have been exported with + * attention sink support (KVCacheWithAttentionSink). The model internally + * manages: + * - Cache write indices (sink tokens at fixed positions, rest in ring buffer) + * - Attention mask creation (sink tokens always visible + sliding window) + * - Position-based RoPE embeddings + * + * The IOManager's role is to: + * 1. Pass through input_pos as-is (the model handles position mapping) + * 2. Track logical position for runner bookkeeping + * 3. Allow generation to continue past max_cache_size without errors + * + * This works with models exported using: + * - KVCacheWithAttentionSink (model-side attention sink) + * - RingKVCache (standard ring buffer - provides sliding window without sink) + * + * Cache layout (managed by model, not runner): + * [sink tokens: 0 to sink_size-1] [ring buffer: sink_size to cache_size-1] + * + * Usage: + * 1. Export model with attention sink config (use_ring_kv_cache=True, etc.) + * 2. Runner detects attention sink metadata and creates this IOManager + * 3. IOManager passes positions through; model handles cache management + */ +class ET_EXPERIMENTAL AttentionSinkIOManager : public IOManager { + public: + /** + * @brief Construct an AttentionSinkIOManager. + * + * @param module The Module used for querying method metadata and execution. + * @param max_cache_size The maximum size of the KV cache in the model. + * @param config Configuration for attention sink behavior. + */ + AttentionSinkIOManager( + ET_MODULE_NAMESPACE::Module& module, + int64_t max_cache_size, + AttentionSinkConfig config = AttentionSinkConfig()); + + /** + * @brief Load the IO manager with method metadata. + */ + ET_NODISCARD runtime::Error load( + const std::string& prefill_method, + const std::string& decode_method) override; + + /** + * @brief Reset the IO manager state. + * + * Resets the logical position counter. + */ + ET_NODISCARD runtime::Error reset( + const std::string& prefill_method, + const std::string& decode_method) override; + + /** + * @brief Prepare inputs for the prefill phase. + * + * Passes through input_pos to the model. The model's internal + * KVCacheWithAttentionSink handles position-to-index mapping and masking. + */ + runtime::Result> prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& prefill_method) override; + + /** + * @brief Prepare inputs for the decode phase. + * + * Passes through input_pos to the model. The model's internal + * KVCacheWithAttentionSink handles position-to-index mapping and masking. + */ + runtime::Result> prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& decode_method) override; + + /** + * @brief Get the current logical position. + * + * This is the position in the full context, which may exceed the cache size. + * The model handles wrapping internally via ring buffer. + */ + int64_t logical_position() const { + return logical_pos_; + } + + /** + * @brief Get the attention sink configuration. + */ + const AttentionSinkConfig& config() const { + return config_; + } + + /** + * @brief Check if the cache is in the "infinite context" regime. 
+ * + * Returns true when the logical position exceeds the effective cache + * capacity, meaning the ring buffer has wrapped and old tokens are being + * overwritten. + */ + bool is_cache_full() const { + return logical_pos_ >= max_cache_size_; + } + + private: + /// Maximum size of the KV cache in the model + int64_t max_cache_size_; + + /// Attention sink configuration + AttentionSinkConfig config_; + + /// Current logical position (may exceed max_cache_size) + int64_t logical_pos_ = 0; +}; + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/runner/io_manager/targets.bzl b/extension/llm/runner/io_manager/targets.bzl index fe7fe9d56ae..4ea579423fd 100644 --- a/extension/llm/runner/io_manager/targets.bzl +++ b/extension/llm/runner/io_manager/targets.bzl @@ -17,3 +17,20 @@ def define_common_targets(): ], visibility = ["PUBLIC"], ) + + # Attention Sink IOManager for runner-side infinite context + runtime.cxx_library( + name = "attention_sink_io_manager" + aten_suffix, + srcs = [ + "attention_sink_io_manager.cpp", + ], + exported_headers = [ + "attention_sink_io_manager.h", + ], + exported_deps = [ + ":io_manager" + aten_suffix, + "//executorch/extension/tensor:tensor" + aten_suffix, + "//executorch/extension/module:module" + aten_suffix, + ], + visibility = ["PUBLIC"], + ) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 13f8d7a9db5..7f98b9099c8 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -9,6 +9,7 @@ // Implementation of helper utilities for creating and configuring LLM runners #include +#include #include #include #include @@ -233,8 +234,56 @@ std::unique_ptr create_text_llm_runner( auto eos_ids = std::make_unique>( llm::get_eos_ids(tokenizer.get(), module.get())); - // Create IOManager - std::unique_ptr io_manager = std::make_unique(*module); + // Create IOManager - use AttentionSinkIOManager if attention sink is enabled + std::unique_ptr io_manager; + + // Get method names to check for attention sink metadata + auto method_names_result = module->method_names(); + if (method_names_result.error() != Error::Ok) { + ET_LOG(Error, "Failed reading method names for IOManager selection"); + return nullptr; + } + const auto& method_names = method_names_result.get(); + + // Check if attention sink is enabled via metadata + bool use_attention_sink = false; + int64_t sink_size = 4; // Default values + int64_t window_size = 124; + + if (method_names.count(kUseAttentionSink)) { + auto get_result = module->get(kUseAttentionSink); + use_attention_sink = get_result.get().toScalar().to(); + } + + if (use_attention_sink) { + // Get attention sink configuration from metadata + if (method_names.count(kAttentionSinkSize)) { + auto get_result = module->get(kAttentionSinkSize); + sink_size = get_result.get().toScalar().to(); + } + if (method_names.count(kAttentionSinkWindowSize)) { + auto get_result = module->get(kAttentionSinkWindowSize); + window_size = get_result.get().toScalar().to(); + } + + AttentionSinkConfig config; + config.sink_size = sink_size; + config.window_size = window_size; + + int64_t max_cache_size = metadata.at(kMaxContextLen); + ET_LOG( + Info, + "Creating AttentionSinkIOManager with sink_size=%" PRId64 + ", window_size=%" PRId64 ", max_cache_size=%" PRId64, + sink_size, + window_size, + max_cache_size); + + io_manager = std::make_unique( + *module, max_cache_size, config); + } else { + io_manager = 
std::make_unique(*module); + } // Create text_decoder_runner. Use a shared_ptr so that it can be shared with // TextPrefiller and TextTokenGenerator diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2c9000d0137..8150647a733 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -132,6 +132,7 @@ def define_common_targets(): ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, + "//executorch/extension/llm/runner/io_manager:attention_sink_io_manager" + aten_suffix, "//pytorch/tokenizers:hf_tokenizer", "//pytorch/tokenizers:llama2c_tokenizer", "//pytorch/tokenizers:sentencepiece", diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 11d77f321bf..82e80b0a0ce 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -132,39 +132,26 @@ Error TextLLMRunner::generate( // Get max_seq_len for single prefill chunk limit int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); - + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); - - // For models with sliding window (Ring Buffer / Attention Sink), - // we allow pos_ to exceed max_context_len. The model handles this - // internally via ring buffer indexing or token eviction. - // We only check that a single prefill chunk doesn't exceed max_seq_len. + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens <= max_seq_len, InvalidArgument, "num_prompt_tokens %d > max_seq_len %" PRId64 - ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", + ", Single prefill chunk too large", num_prompt_tokens, max_seq_len); - // Determine max_new_tokens using the GenerationConfig's resolve method. - // For sliding window models, we use max_context_len directly (not reduced by pos_) - // because the model handles position wrapping internally via ring buffer. + // Determine max_new_tokens from GenerationConfig. + // For ring buffer / attention sink models, the model handles position + // wrapping internally so generation can continue past max_context_len. + // If user specified seq_len explicitly, use that as the overall token limit. 
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); - - // TEMPORARY: For sliding window / infinite context testing, - // override max_new_tokens to allow unlimited generation - if (max_new_tokens <= 0) { - max_new_tokens = 1000000; // Effectively unlimited - } - // If user specified seq_len, use that instead - if (config.seq_len > 0 && config.seq_len > max_new_tokens) { - max_new_tokens = config.seq_len; - } ET_LOG( Info, From 7bf2250b4cb2e60db3199a186038e800a9015447 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Feb 2026 17:45:38 -0800 Subject: [PATCH 13/17] Update --- .../io_manager/attention_sink_io_manager.cpp | 79 +++++++++++++++++-- .../io_manager/attention_sink_io_manager.h | 24 +++++- extension/llm/runner/llm_runner_helper.cpp | 11 ++- 3 files changed, 100 insertions(+), 14 deletions(-) diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp index b71080e9522..a6b781f2e9f 100644 --- a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp @@ -14,10 +14,10 @@ namespace llm { AttentionSinkIOManager::AttentionSinkIOManager( ET_MODULE_NAMESPACE::Module& module, - int64_t max_cache_size, + int64_t max_context_len, AttentionSinkConfig config) : IOManager(module), - max_cache_size_(max_cache_size), + max_context_len_(max_context_len), config_(config), logical_pos_(0) { ET_CHECK_MSG( @@ -39,10 +39,10 @@ runtime::Error AttentionSinkIOManager::load( ET_LOG( Info, "AttentionSinkIOManager loaded: sink_size=%" PRId64 - ", window_size=%" PRId64 ", max_cache_size=%" PRId64, + ", window_size=%" PRId64 ", max_context_len=%" PRId64, config_.sink_size, config_.window_size, - max_cache_size_); + max_context_len_); return runtime::Error::Ok; } @@ -79,8 +79,14 @@ AttentionSinkIOManager::prepare_prefill( logical_pos_, is_cache_full() ? "true" : "false"); - // Pass through to model as-is. The model's KVCacheWithAttentionSink - // (or RingKVCache) handles position-to-index mapping and mask creation. + // Check if we need to provide cache_indices (3rd input) + auto method_meta = module_.method_meta(prefill_method); + if (method_meta.ok() && method_meta->num_inputs() == 3) { + update_indices_tensor(logical_start, seq_len); + return std::vector{input, start_pos, *indices_tensor_}; + } + + // Pass through to model as-is. return std::vector{input, start_pos}; } @@ -103,11 +109,68 @@ AttentionSinkIOManager::prepare_decode( logical_pos_, is_cache_full() ? "true" : "false"); - // Pass through to model as-is. The model's KVCacheWithAttentionSink - // (or RingKVCache) handles position-to-index mapping and mask creation. + // Check if we need to provide cache_indices (3rd input) + auto method_meta = module_.method_meta(decode_method); + if (method_meta.ok() && method_meta->num_inputs() == 3) { + update_indices_tensor(logical_start, seq_len); + return std::vector{input, start_pos, *indices_tensor_}; + } + + // Pass through to model as-is. 
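+  // With only two inputs, the exported KVCacheWithAttentionSink (or
+  // RingKVCache) derives write indices and the attention mask from
+  // start_pos on its own, so no cache_indices tensor is needed here.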
return std::vector{input, start_pos}; } +void AttentionSinkIOManager::update_indices_tensor( + int64_t logical_start, + int64_t seq_len) { + int64_t ring_size = config_.window_size * 2; + indices_buffer_.resize(seq_len); + for (int64_t i = 0; i < seq_len; ++i) { + int64_t pos = logical_start + i; + if (pos < config_.sink_size) { + indices_buffer_[i] = pos; + } else { + indices_buffer_[i] = + config_.sink_size + (pos - config_.sink_size) % ring_size; + } + } + + // Wrap in tensor + if (!indices_tensor_impl_ || indices_tensor_impl_->size(0) != seq_len) { + sizes_vec_ = {static_cast(seq_len)}; + dim_order_vec_ = {0}; + strides_vec_ = {1}; + + indices_tensor_impl_ = std::make_unique( + exec_aten::ScalarType::Long, + 1, + sizes_vec_.data(), + static_cast(indices_buffer_.data()), + dim_order_vec_.data(), + strides_vec_.data(), + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + indices_tensor_ = + std::make_unique(indices_tensor_impl_.get()); + } else { + // Update logic if buffer moved (vector resize might reallocate) + // Just re-create to be safe as data ptr is used + sizes_vec_ = {static_cast(seq_len)}; + dim_order_vec_ = {0}; + strides_vec_ = {1}; + + indices_tensor_impl_ = std::make_unique( + exec_aten::ScalarType::Long, + 1, + sizes_vec_.data(), + static_cast(indices_buffer_.data()), + dim_order_vec_.data(), + strides_vec_.data(), + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + indices_tensor_ = + std::make_unique(indices_tensor_impl_.get()); + } +} + } // namespace llm } // namespace extension } // namespace executorch diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.h b/extension/llm/runner/io_manager/attention_sink_io_manager.h index afee1db8354..a8bafa6feb4 100644 --- a/extension/llm/runner/io_manager/attention_sink_io_manager.h +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.h @@ -14,6 +14,7 @@ namespace executorch { namespace extension { namespace llm { +namespace exec_aten = ::executorch::aten; /** * @brief Configuration for attention sink behavior. @@ -67,7 +68,7 @@ class ET_EXPERIMENTAL AttentionSinkIOManager : public IOManager { */ AttentionSinkIOManager( ET_MODULE_NAMESPACE::Module& module, - int64_t max_cache_size, + int64_t max_context_len, AttentionSinkConfig config = AttentionSinkConfig()); /** @@ -133,18 +134,35 @@ class ET_EXPERIMENTAL AttentionSinkIOManager : public IOManager { * overwritten. */ bool is_cache_full() const { - return logical_pos_ >= max_cache_size_; + return logical_pos_ >= max_context_len_; } private: /// Maximum size of the KV cache in the model - int64_t max_cache_size_; + int64_t max_context_len_; /// Attention sink configuration AttentionSinkConfig config_; /// Current logical position (may exceed max_cache_size) int64_t logical_pos_ = 0; + + /** + * @brief Update the internal indices buffer and tensor for a given position and length. 
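+   *
+   * Sink positions map to themselves; every later position wraps into the
+   * ring region as sink_size + (pos - sink_size) % ring_size. With
+   * illustrative values sink_size = 4 and ring_size = 248, logical position
+   * 2 stays at cache index 2, position 10 maps to 4 + (10 - 4) % 248 = 10,
+   * and position 300 maps to 4 + (300 - 4) % 248 = 52.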
+ */ + void update_indices_tensor(int64_t logical_start, int64_t seq_len); + + // Buffer for cache indices + std::vector indices_buffer_; + + // Tensor wrapper for indices + std::unique_ptr indices_tensor_impl_; + std::unique_ptr indices_tensor_; + + // Metadata storage for TensorImpl + std::vector sizes_vec_; + std::vector dim_order_vec_; + std::vector strides_vec_; }; } // namespace llm diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 7f98b9099c8..7ce2c4f7caa 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -248,7 +248,7 @@ std::unique_ptr create_text_llm_runner( // Check if attention sink is enabled via metadata bool use_attention_sink = false; int64_t sink_size = 4; // Default values - int64_t window_size = 124; + int64_t window_size = -1; if (method_names.count(kUseAttentionSink)) { auto get_result = module->get(kUseAttentionSink); @@ -265,12 +265,17 @@ std::unique_ptr create_text_llm_runner( auto get_result = module->get(kAttentionSinkWindowSize); window_size = get_result.get().toScalar().to(); } + + int64_t max_cache_size = metadata.at(kMaxContextLen); + + // If window_size is not found in metadata, calculate from max_context_len + if (window_size == -1) { + window_size = max_cache_size - sink_size; + } AttentionSinkConfig config; config.sink_size = sink_size; config.window_size = window_size; - - int64_t max_cache_size = metadata.at(kMaxContextLen); ET_LOG( Info, "Creating AttentionSinkIOManager with sink_size=%" PRId64 From f04f211aa62803ead1a7a90ef45dc41f71e680cc Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Feb 2026 18:29:44 -0800 Subject: [PATCH 14/17] Update --- .../config/llama_runner_attention_sink.yaml | 7 +- .../source_transformation/attention_sink.py | 112 ++++++++++++++++-- 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/examples/models/llama/config/llama_runner_attention_sink.yaml b/examples/models/llama/config/llama_runner_attention_sink.yaml index e6945335563..4838851d79f 100644 --- a/examples/models/llama/config/llama_runner_attention_sink.yaml +++ b/examples/models/llama/config/llama_runner_attention_sink.yaml @@ -21,14 +21,13 @@ model: dtype_override: fp32 enable_dynamic_shape: True # Attention Sink: "sink_size,window_size,eviction_batch_size" - # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) - # window_size=124: Sliding window size - # eviction_batch_size=1: Evict 1 token at a time - # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + # sink_size=4, window_size=124, eviction_batch_size=1 + # Max Context (Buffer) = 4 + 1 * 124 = 128 use_attention_sink: "4,124,1" export: # max_seq_length for single prefill chunk + max_context_len: 128 max_seq_length: 128 quantization: diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 3a69d85e61b..0f63b685f7b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -218,10 +218,15 @@ def __init__( sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, + max_context_len: Optional[int] = None, dtype=torch.float32, ): - # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) - total_cache_size = sink_size + window_size * 2 + # Total cache size is max_context_len if provided, else sink_size + window_size * 2 + if max_context_len is None: + total_cache_size 
= sink_size + window_size * 2 + else: + total_cache_size = max_context_len + super().__init__( max_batch_size=max_batch_size, max_context_length=total_cache_size, @@ -264,6 +269,7 @@ def update( input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update KV cache with new key-value pairs. @@ -274,10 +280,11 @@ def update( 2 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" - # Calculate write indices - indices = self.cache_positions_manager.calculate_positions_and_update_indices( - input_pos, seq_len - ) + if indices is None: + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) start_pos = input_pos[0].item() torch._check_is_size(start_pos) @@ -300,14 +307,19 @@ def attention_sink_forward( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, + **kwargs, ): """ Forward function for attention with attention sink KV cache. Uses ring buffer masking for proper attention patterns. """ assert self.use_kv_cache + + input_pos = kwargs.get("input_pos") assert input_pos is not None + + # Extract cache_indices if provided (injected by Transformer forward) + cache_indices = kwargs.get("cache_indices") bsz, seqlen, _ = x.shape @@ -326,7 +338,7 @@ def attention_sink_forward( v = v.transpose(1, 2) # Update KV cache - k, v = self.kv_cache.update(input_pos, k, v) + k, v = self.kv_cache.update(input_pos, k, v, cache_indices) # Use ring buffer mask since we have is_ring_buffer=True start_pos = input_pos[0].item() @@ -358,6 +370,7 @@ def _replace_attention( sink_size: int, window_size: int, eviction_batch_size: int, + max_context_len: int, ): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] @@ -367,6 +380,7 @@ def _replace_attention( sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, ) if isinstance(child_module, AttentionMHA): @@ -382,6 +396,7 @@ def _replace_attention( window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink @@ -391,8 +406,10 @@ def _replace_attention( if "SDPACustom" in child_module.SDPA.__class__.__name__: child_module.SDPA.use_attention_mask = True - # Don't replace forward - let the original AttentionMHA.forward handle it - # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask + # Replace forward with our custom forward that handles cache_indices + child_module.forward = types.MethodType( + attention_sink_forward, child_module + ) def enable_attention_sink( @@ -401,6 +418,7 @@ def enable_attention_sink( sink_size: int, window_size: int, eviction_batch_size: int, + max_context_len: Optional[int] = None, ) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
@@ -409,6 +427,13 @@ def enable_attention_sink( - Replace Attention's KVCache with KVCacheWithAttentionSink - Replace Attention's forward with attention_sink_forward """ + if max_context_len is None: + max_context_len = sink_size + window_size * 2 + + # We update params.max_context_len to reflect the actual buffer size + # This ensures export captures the correct cache size in metadata + params.max_context_len = max_context_len + rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, @@ -422,5 +447,72 @@ def enable_attention_sink( sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, ) + + # Add metadata methods for IOManager + def get_sink_size(self): + return sink_size + + def get_window_size(self): + return window_size + + # Bind methods to module + # Note: For torch.export, we might need these to be part of the class or properly bound. + # Monkey patching instance methods works for some cases but let's verify. + # Ideally we subclass or mixin. But here we modify in place. + module.get_sink_size = types.MethodType(get_sink_size, module) + module.get_window_size = types.MethodType(get_window_size, module) + + # Monkey patch Transformer methods to handle cache_indices input + + # 1. New get_example_inputs that includes cache_indices + def get_example_inputs_with_sink(self): + # Create inputs manually to avoid relying on Llama2Model helper methods + # that might not be available on the Transformer module. + + # Use a small sequence length for example + seq_len = 3 + # Use simple tokens + tokens = torch.tensor([[2, 3, 4]], dtype=torch.long) + # Use corresponding input_pos matching seq_len + input_pos = torch.arange(seq_len, dtype=torch.long) + + # cache_indices matches input_pos/tokens length logic + # For export example, we can use simple indices + cache_indices = torch.arange(seq_len, dtype=torch.long) + + # Note: The original generic get_example_inputs usually returns ({tokens}, {input_pos}) + # input_pos shape depends on dynamic shape setting. + # But for export purposes, providing valid tensors is key. + # If dynamic shapes are enabled, input_pos might be expected to be 1D. + + return (tokens, {"input_pos": input_pos}, cache_indices) + + module.get_example_inputs = types.MethodType(get_example_inputs_with_sink, module) + + # 2. New forward that accepts cache_indices and passes it in attn_options + original_forward = module.forward + + def forward_with_sink( + self, + tokens: Optional[torch.LongTensor] = None, + attn_options: Optional[dict] = None, + cache_indices: Optional[torch.LongTensor] = None, + ): + # Allow cache_indices to be passed as positional or kwarg + # Note: top level export might pass inputs positionally if we aren't careful? + # LlamaTransformer.forward has (tokens, attn_options, h) + # We replace it. 
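+        # With this patched signature the exported program is expected to be
+        # called as forward(tokens, {"input_pos": input_pos}, cache_indices),
+        # matching get_example_inputs_with_sink above.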
+ + if attn_options is None: + attn_options = {} + + if cache_indices is not None: + attn_options["cache_indices"] = cache_indices + + return original_forward(tokens=tokens, attn_options=attn_options) + + module.forward = types.MethodType(forward_with_sink, module) + return module From dc2695ae602c4b22cdd81962ef7d7c22a577e069 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Feb 2026 18:30:04 -0800 Subject: [PATCH 15/17] Update --- examples/models/llama/export_llama_lib.py | 108 +++++++++++++++--- .../io_manager/attention_sink_io_manager.cpp | 8 +- 2 files changed, 96 insertions(+), 20 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index aca875acdab..e287e16c553 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -40,6 +40,7 @@ get_vulkan_partitioner, get_xnnpack_partitioner, ) +from executorch.examples.models.llama.model_args import ModelArgs from executorch.extension.llm.export.quantizer_lib import ( get_coreml_quantizer, get_ov_quantizer, @@ -57,6 +58,7 @@ get_model_with_r1_r2, ) from .source_transformation.attention import replace_attention_to_attention_sha +from .source_transformation.attention_sink import enable_attention_sink from .source_transformation.custom_kv_cache import ( replace_kv_cache_with_custom_kv_cache, replace_kv_cache_with_quantized_kv_cache, @@ -757,13 +759,48 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: preq_embedding_quantize=llm_config.base.preq_embedding_quantize, local_global_attention=llm_config.model.local_global_attention, use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear, + use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding, + use_attention_sink=llm_config.model.use_attention_sink, + params_path=llm_config.base.params, ) ) + if llm_config.model.use_attention_sink: + print("Refreshing example inputs for Attention Sink...") + if hasattr(edge_manager.model, "get_example_inputs"): + # The model is now patched to return (tokens, attn_options, cache_indices) + new_inputs = edge_manager.model.get_example_inputs() + # We assume these are all positional arguments + edge_manager.example_inputs = new_inputs + # Clear kwargs since we provide everything positionally + edge_manager.example_kwarg_inputs = {} + print(f"Updated inputs: {len(new_inputs)} items") + + # Update dynamic shapes if enabled + if edge_manager.enable_dynamic_shape: + existing_shapes = edge_manager.dynamic_shapes + if existing_shapes and len(existing_shapes) == 2: + # Extract the Dim object from the first input (tokens) + # tokens shape dict is {1: Dim(...)} + token_dim = existing_shapes[0][1] + + # cache_indices is 1D tensor of size seq_len + # Spec should be {0: token_dim} + indices_spec = {0: token_dim} + + # Relieve static constraint on input_pos + # input_pos spec in existing_shapes[1] is {"input_pos": {0: 1}} + # We change it to {"input_pos": {0: token_dim}} + input_pos_spec = {"input_pos": {0: token_dim}} + + edge_manager.dynamic_shapes = (existing_shapes[0], input_pos_spec, indices_spec) + print("Updated dynamic_shapes for Attention Sink (patched input_pos)") + return edge_manager + def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( ( @@ -1298,6 +1335,28 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": model_class_name, llm_config=llm_config, ) + + # Add attention sink metadata if enabled + metadata 
= _load_llama_model_metadata( + llm_config.model.use_kv_cache, + llm_config.model.use_sdpa_with_kv_cache, + llm_config.model.enable_dynamic_shape, + model.max_seq_len, + model.max_context_len, + model.n_layers, + model.vocab_size, + llm_config.base.metadata, + ) + + # Add attention sink metadata if enabled + if llm_config.model.use_attention_sink: + # Format: sink_size,window_size,eviction_batch_size + sink_params = [int(x) for x in llm_config.model.use_attention_sink.split(",")] + # IOManager expects these methods to exist returning int. + # By adding them to metadata, export_to_edge will generate constant methods. + metadata["get_sink_size"] = sink_params[0] + metadata["get_window_size"] = sink_params[1] + # Convert dtype override string to actual type. dtype_override = DType[llm_config.model.dtype_override.value] @@ -1312,31 +1371,14 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": example_kwarg_inputs=example_kwarg_inputs, dynamic_shapes=dynamic_shapes, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + save_exported_program=llm_config.export.export_only, calibration_tasks=llm_config.quantization.calibration_tasks, calibration_limit=llm_config.quantization.calibration_limit, calibration_seq_length=llm_config.quantization.calibration_seq_length, calibration_data=llm_config.quantization.calibration_data, tokenizer_path=llm_config.base.tokenizer_path, - save_exported_program=llm_config.export.export_only, verbose=llm_config.debug.verbose, - metadata=_load_llama_model_metadata( - llm_config.model.use_kv_cache, - llm_config.model.use_sdpa_with_kv_cache, - llm_config.model.enable_dynamic_shape, - # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got - # `Union[Tensor, Module]`. - model.max_seq_len, - # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got - # `Union[Tensor, Module]`. - model.max_context_len, - # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor, - # Module]`. - model.n_layers, - # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor, - # Module]`. - model.vocab_size, - llm_config.base.metadata, - ), + metadata=metadata, ) @@ -1375,6 +1417,8 @@ def _get_source_transforms( # noqa use_torchao_kernels_linear: bool = False, use_torchao_kernels_tied_embedding: bool = False, quantize_with_hqq: bool = True, + use_attention_sink: Optional[str] = None, + params_path: Optional[str] = None, ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: """ Return a list of functions that transform a graph. 
@@ -1561,6 +1605,32 @@ def _get_source_transforms( # noqa ) ) + if use_attention_sink: + sink_params = [int(x) for x in use_attention_sink.split(",")] + + # Load ModelArgs for attention sink + if not params_path: + raise ValueError("params_path is required for attention sink") + with open(params_path, "r") as f: + params_dict = json.load(f) + + # Ensure use_kv_cache is propagated from config + params_dict["use_kv_cache"] = True # Attention Sink requires KV Cache + # ModelArgs might expect other fields usually handled by Llama2Model init + # We try to pass minimal set needed for Rope/Attention + + model_args = ModelArgs(**params_dict) + + transforms.append( + partial( + enable_attention_sink, + params=model_args, + sink_size=sink_params[0], + window_size=sink_params[1], + eviction_batch_size=sink_params[2], + ) + ) + return transforms diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp index a6b781f2e9f..30d84887c7a 100644 --- a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp @@ -123,7 +123,13 @@ AttentionSinkIOManager::prepare_decode( void AttentionSinkIOManager::update_indices_tensor( int64_t logical_start, int64_t seq_len) { - int64_t ring_size = config_.window_size * 2; + int64_t ring_size = max_context_len_ - config_.sink_size; + ET_CHECK_MSG(ring_size > 0, "ring_size must be positive, got %" PRId64, ring_size); + ET_CHECK_MSG( + ring_size >= config_.window_size, + "ring_size (%" PRId64 ") must be >= window_size (%" PRId64 ")", + ring_size, + config_.window_size); indices_buffer_.resize(seq_len); for (int64_t i = 0; i < seq_len; ++i) { int64_t pos = logical_start + i; From 5ae4922830180c2395087568c0fd2812b5fd18c0 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Feb 2026 18:35:57 -0800 Subject: [PATCH 16/17] update --- examples/models/llama/export_llama_lib.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e287e16c553..23d26e1865f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -763,6 +763,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding, use_attention_sink=llm_config.model.use_attention_sink, params_path=llm_config.base.params, + max_context_len=llm_config.export.max_context_length, ) ) @@ -1419,6 +1420,7 @@ def _get_source_transforms( # noqa quantize_with_hqq: bool = True, use_attention_sink: Optional[str] = None, params_path: Optional[str] = None, + max_context_len: Optional[int] = None, ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: """ Return a list of functions that transform a graph. 
@@ -1628,6 +1630,7 @@ def _get_source_transforms( # noqa sink_size=sink_params[0], window_size=sink_params[1], eviction_batch_size=sink_params[2], + max_context_len=max_context_len, ) ) From 42e5e7a216eddb53eb08f8aefc2f66d3a509fcb9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 6 Feb 2026 18:42:15 -0800 Subject: [PATCH 17/17] Update --- examples/models/llama/config/llama_runner_attention_sink.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/llama/config/llama_runner_attention_sink.yaml b/examples/models/llama/config/llama_runner_attention_sink.yaml index 4838851d79f..d23180cdd0c 100644 --- a/examples/models/llama/config/llama_runner_attention_sink.yaml +++ b/examples/models/llama/config/llama_runner_attention_sink.yaml @@ -27,7 +27,7 @@ model: export: # max_seq_length for single prefill chunk - max_context_len: 128 + max_context_length: 128 max_seq_length: 128 quantization: