From c954f27e543f57d1bdd547b28d9de0872d0effaf Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Feb 2026 01:41:34 -0800
Subject: [PATCH 01/11] attention sink

---
 .../llama/config/llama_attention_sink.yaml    |  31 ++
 .../source_transformation/attention_sink.py   | 316 ++++++++++++------
 .../source_transformation/custom_kv_cache.py  |  11 +-
 3 files changed, 247 insertions(+), 111 deletions(-)
 create mode 100644 examples/models/llama/config/llama_attention_sink.yaml

diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml
new file mode 100644
index 00000000000..8d3fc01b8f7
--- /dev/null
+++ b/examples/models/llama/config/llama_attention_sink.yaml
@@ -0,0 +1,31 @@
+base:
+  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+  use_sdpa_with_kv_cache: False # attention_sink requires standard SDPA
+  use_kv_cache: True
+  dtype_override: fp32
+  enable_dynamic_shape: True
+  # Attention Sink: "sink_size,window_size,eviction_batch_size"
+  # sink_size=4: keep the first 4 tokens (BOS + system prompt)
+  # window_size=126: sliding window size
+  # eviction_batch_size=1: evict 1 token at a time
+  # Nominal cache size = sink_size + window_size = 4 + 126 = 130
+  # but the ring buffer needs 2x the window, so the actual size is 4 + 126*2 = 256
+  use_attention_sink: "4,126,1"
+
+export:
+  # max_context_length = sink_size + window_size = 4 + 126 = 130
+  # but the ring buffer internally creates a cache of sink_size + window_size * 2 = 256
+  max_context_length: 130
+  max_seq_length: 130
+
+quantization:
+  qmode: 8da4w
+  group_size: 128
+  embedding_quantize: 4,32
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 22bd8a3e228..126bbcf919c 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,12 +7,19 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+# This implementation is torch.export compatible using a ring buffer approach
+# for the sliding window portion while preserving the sink tokens.
+
 import types
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
-
-from executorch.examples.models.llama.attention import AttentionMHA, KVCache
+import torch.nn as nn
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    AttentionMHA,
+    KVCache,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
@@ -27,6 +34,9 @@ class RopeWithAttentionSink(Rope):
     Rope that helps adjust position encoding when tokens are shifted in KVCache.
     For AttentionSink, when tokens are shifted in KVCache, we need to use positions
     in KVCache instead of positions in the actual text.
+
+    For torch.export compatibility, this just passes through the position - the
+    actual position adjustment is handled by the cache update logic.
     """
 
     def __init__(
@@ -42,27 +52,19 @@ def __init__(
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
         self.max_context_length = window_size + sink_size
+        self.window_size = window_size
+        self.sink_size = sink_size
         assert self.max_context_length == self.params.max_context_len
         self.eviction_batch_size = eviction_batch_size
-        self.position_shift = 0
 
     def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get rotary embedding frequencies.
+ For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. + """ assert input_pos is not None - - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return super().get_freqs(input_pos + self.position_shift, seq_len) + return super().get_freqs(input_pos, seq_len) def rerotate_k( self, @@ -71,15 +73,8 @@ def rerotate_k( new_position: int, ): """ - Rerotate k from original_position to new_position. This is done by rerotating - k with (new_position * theta - original_position * theta) with the following matrix: - (cos(delta), -sin(delta) - sin(delta), cos(delta)) - where delta = new_position * theta - original_position * theta - - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - - Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + Rerotate k from original_position to new_position. + The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) @@ -96,15 +91,114 @@ def rerotate_k( return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + is_in_window = delta < window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + Similar to CachePositionsManager but handles sink tokens separately. 
+ + Layout: [sink_tokens (fixed)] [ring_buffer_window (rotating)] + + For sink_size=4 and window_size=8: + - Positions 0-3 in the sequence go to cache indices 0-3 (fixed) + - Positions 4+ go to cache indices 4-19 using ring buffer (window_size * 2) + """ + + def __init__(self, window_size: int, sink_size: int): + super().__init__() + # Total cache size is sink + window * 2 (ring buffer needs 2x for proper masking) + self.max_context_length = sink_size + window_size * 2 + self.sink_size = sink_size + self.window_size = window_size + self.register_buffer( + "cache_positions", + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + ) + # Initialize sink positions (these are fixed) + if sink_size > 0: + self.cache_positions[:sink_size] = torch.arange(sink_size) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + For positions < sink_size: index = position (fixed) + For positions >= sink_size: index = sink_size + (pos - sink_size) % (window_size * 2) + """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Calculate cache indices based on whether position is sink or window + sink_part = torch.minimum(orig_indices, torch.tensor(self.sink_size)) + window_part = torch.maximum( + orig_indices - self.sink_size, torch.tensor(0) + ) % (self.window_size * 2) + is_sink = orig_indices < self.sink_size + indices = torch.where(is_sink, sink_part, self.sink_size + window_part) + + # Update cache_positions: clear old positions and set new ones + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + # Keep sink positions (0 to sink_size-1) and clear window positions that will be overwritten + cache_positions = torch.where( + arange_tensor < self.sink_size, self.cache_positions, full_t + ) + # For non-sink positions, check if they should be cleared + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, cache_positions + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + class KVCacheWithAttentionSink(KVCache): """ - KV cache that supports attention sink. It keeps the initial few tokens as attention sink. - For other tokens, it uses a sliding window to keep the most recent tokens. + KV cache that supports attention sink with torch.export compatibility. - Parameters: - window_size: the size of the sliding window - sink_size: the number of initial tokens to keep as attention sink - eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. 
+ + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( @@ -119,9 +213,11 @@ def __init__( max_batch_size: int = 1, dtype=torch.float32, ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, - max_context_length=window_size + sink_size, + max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, @@ -131,78 +227,61 @@ def __init__( self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.is_ring_buffer = True - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + # Cache positions manager for determining write locations + self.cache_positions_manager = CachePositionsManagerWithSink( + window_size=window_size, + sink_size=sink_size, + ) + + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): + """ + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. """ - Evict old tokens from the cache to make rooms for new tokens. + cache_positions = self.cache_positions_manager.cache_positions + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. + """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) - Parameters: - input_pos: the start position of the incoming token in the actual sequence - seq_len: the length of the incoming sequence - rope: the rope object to use for rerotating k + return self.k_cache, self.v_cache - Returns: - the number of tokens to evict from the cache which is also the number of - positions to shift for incoming tokens + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. 
- num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - num_to_keep = ( - input_pos_item + self.position_shift - self.sink_size - num_to_evict - ) - num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 - k_to_keep = self.k_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ) - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] - new_position=self.sink_size, - ).transpose(1, 2) - self.k_cache = torch.cat( - [ - self.k_cache.narrow(dim_to_slice, 0, self.sink_size), - k_to_keep, - torch.zeros_like( - self.k_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, - ) - self.v_cache = torch.cat( - [ - self.v_cache.narrow(dim_to_slice, 0, self.sink_size), - self.v_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ), - torch.zeros_like( - self.v_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return self.position_shift + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 def attention_sink_forward( @@ -212,6 +291,10 @@ def attention_sink_forward( freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None, ): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. + """ assert self.use_kv_cache assert input_pos is not None @@ -219,19 +302,31 @@ def attention_sink_forward( # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - # Prepare for space in KV cache and get position shift - position_shift = self.kv_cache.evict_tokens(input_pos, seqlen) - - # RoPE relative positional embeddings with shifted position in KV cache + # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) + # Transpose for attention: [B, H, S, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update KV cache + k, v = self.kv_cache.update(input_pos, k, v) + + # Use ring buffer mask since we have is_ring_buffer=True + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) + + # SDPA + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + + # Return tuple like original AttentionMHA.forward + return self.wo(output), None def _replace_rope( @@ -293,7 +388,8 @@ def enable_attention_sink( Transform the model to be able to run inference with Attention Sink. 
     There are mainly three steps:
     - Replace Rope with RopeWithAttentionSink
-    - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
+    - Replace Attention's KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
     """
     rope_with_attention_sink = RopeWithAttentionSink(
         params=params,
diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 8d4d37e0e93..373f4fcc4a0 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -371,8 +371,17 @@ def replace_kv_cache_with_custom_kv_cache(module):
 
 
 def _replace_kv_cache_with_custom_kv_cache(module):
+    # Import here to avoid circular imports
+    from executorch.examples.models.llama.source_transformation.attention_sink import (
+        KVCacheWithAttentionSink,
+    )
+
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        # Skip KVCacheWithAttentionSink as it has special evict_tokens logic
+        # that is not compatible with CustomKVCache
+        if isinstance(child, KVCacheWithAttentionSink):
+            _replace_kv_cache_with_custom_kv_cache(child)
+        elif isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
             max_batch_size, n_heads, max_context_length, head_dim = cache_shape

From 34a937f4eb5c111cfa9c89af4a7be557eb0d0835 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Feb 2026 11:05:46 -0800
Subject: [PATCH 02/11] Test this

---
 .../llama/config/llama_attention_sink.yaml    |  18 +-
 examples/models/llama/model.py                |   8 +-
 .../source_transformation/attention_sink.py   | 124 +++++++++---------
 extension/llm/runner/text_llm_runner.cpp      |  34 +++--
 4 files changed, 107 insertions(+), 77 deletions(-)

diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml
index 8d3fc01b8f7..2e8cfb18cde 100644
--- a/examples/models/llama/config/llama_attention_sink.yaml
+++ b/examples/models/llama/config/llama_attention_sink.yaml
@@ -7,18 +7,18 @@ model:
   dtype_override: fp32
   enable_dynamic_shape: True
   # Attention Sink: "sink_size,window_size,eviction_batch_size"
-  # sink_size=4: keep the first 4 tokens (BOS + system prompt)
-  # window_size=126: sliding window size
+  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
+  # window_size=124: sliding window size
   # eviction_batch_size=1: evict 1 token at a time
-  # Nominal cache size = sink_size + window_size = 4 + 126 = 130
-  # but the ring buffer needs 2x the window, so the actual size is 4 + 126*2 = 256
-  use_attention_sink: "4,126,1"
+  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
+  use_attention_sink: "4,124,1"
 
 export:
-  # max_context_length = sink_size + window_size = 4 + 126 = 130
-  # but the ring buffer internally creates a cache of sink_size + window_size * 2 = 256
-  max_context_length: 130
-  max_seq_length: 130
+  # max_context_length controls the RoPE frequency table size.
+  # It must be >= sink_size + window_size (128), but larger values are
+  # recommended to support generation beyond the sliding window.
+  # The model default (e.g., 8192 or 131072) is typically used if not specified.
+  # For testing, we use the model's default by not setting this explicitly.
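+  # Worked example for the values above (illustrative only): the exported KV
+  # cache holds sink_size + window_size * 2 = 4 + 124 * 2 = 252 entries per
+  # layer, while RoPE only needs max_context_length >= 4 + 124 = 128; a single
+  # prefill chunk must still fit within max_seq_length (see the
+  # text_llm_runner.cpp change later in this series).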
quantization: qmode: 8da4w diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 1ec85936f7a..a2c8d73207b 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -218,7 +218,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.llm_config.export.max_context_length == sink_size + window_size + # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies + # A larger max_context_length is allowed (and recommended) to support generation beyond + # the sliding window size. + assert self.llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({self.llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) self.model_ = enable_attention_sink( module=self.model_, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 126bbcf919c..068ac7bdc68 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -18,7 +18,9 @@ from executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, + CachePositionsManager, KVCache, + RingKVCache, ) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( @@ -37,6 +39,10 @@ class RopeWithAttentionSink(Rope): For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( @@ -51,10 +57,12 @@ def __init__( self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - self.max_context_length = window_size + sink_size + # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) + self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size - assert self.max_context_length == self.params.max_context_len + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): @@ -129,61 +137,44 @@ def _create_causal_mask_for_attention_sink( class CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. - Similar to CachePositionsManager but handles sink tokens separately. - - Layout: [sink_tokens (fixed)] [ring_buffer_window (rotating)] - - For sink_size=4 and window_size=8: - - Positions 0-3 in the sequence go to cache indices 0-3 (fixed) - - Positions 4+ go to cache indices 4-19 using ring buffer (window_size * 2) + + For sink_size=0: behaves exactly like original CachePositionsManager. + For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + + IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). 
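+
+    Worked example (illustrative numbers, not tied to any particular config):
+    with cache_size=8, the token at original position 10 is written to slot
+    10 % 8 = 2 and cache_positions[2] is set to 10, so later mask construction
+    can recover which original position each slot currently holds.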
""" - def __init__(self, window_size: int, sink_size: int): + def __init__(self, cache_size: int): super().__init__() - # Total cache size is sink + window * 2 (ring buffer needs 2x for proper masking) - self.max_context_length = sink_size + window_size * 2 - self.sink_size = sink_size - self.window_size = window_size + # cache_size is the actual size of the kv cache dimension + self.max_context_length = cache_size + # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", - torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - # Initialize sink positions (these are fixed) - if sink_size > 0: - self.cache_positions[:sink_size] = torch.arange(sink_size) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - - For positions < sink_size: index = position (fixed) - For positions >= sink_size: index = sink_size + (pos - sink_size) % (window_size * 2) + + This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length - # Calculate cache indices based on whether position is sink or window - sink_part = torch.minimum(orig_indices, torch.tensor(self.sink_size)) - window_part = torch.maximum( - orig_indices - self.sink_size, torch.tensor(0) - ) % (self.window_size * 2) - is_sink = orig_indices < self.sink_size - indices = torch.where(is_sink, sink_part, self.sink_size + window_part) - - # Update cache_positions: clear old positions and set new ones + # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - # Keep sink positions (0 to sink_size-1) and clear window positions that will be overwritten - cache_positions = torch.where( - arange_tensor < self.sink_size, self.cache_positions, full_t - ) - # For non-sink positions, check if they should be cleared cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, cache_positions + arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) @@ -230,10 +221,8 @@ def __init__( self.is_ring_buffer = True # Cache positions manager for determining write locations - self.cache_positions_manager = CachePositionsManagerWithSink( - window_size=window_size, - sink_size=sink_size, - ) + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int @@ -243,10 +232,16 @@ def create_causal_mask_for_ring_buffer( Sink tokens are ALWAYS visible, plus recent tokens in the window. 
""" cache_positions = self.cache_positions_manager.cache_positions - # Use attention sink mask that always allows attending to sink tokens - return _create_causal_mask_for_attention_sink( - cache_positions, self.window_size, self.sink_size, start_pos, seq_len - ) + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) def update( self, @@ -360,21 +355,32 @@ def _replace_attention( if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink - child_module.forward = types.MethodType( # pyre-ignore - attention_sink_forward, child_module - ) + if sink_size == 0: + # For sink_size=0, use the exact same RingKVCache that works + # This is a test to ensure parity with the working implementation + child_module.kv_cache = RingKVCache( + kv_cache.max_batch_size, + window_size, # RingKVCache expects user-provided window size + kv_cache.n_heads, + kv_cache.head_dim, + kv_cache.enable_dynamic_shape, + kv_cache.k_cache.dtype, + ) + else: + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + # Don't replace forward - let the original AttentionMHA.forward handle it + # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask def enable_attention_sink( diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 373f57d3a8e..11d77f321bf 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -129,24 +129,42 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - // Reduce max_context_len by pos_ - int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_; + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); + + // For models with sliding window (Ring Buffer / Attention Sink), + // we allow pos_ to exceed max_context_len. The model handles this + // internally via ring buffer indexing or token eviction. + // We only check that a single prefill chunk doesn't exceed max_seq_len. 
  ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens < max_context_len,
+      num_prompt_tokens <= max_seq_len,
       InvalidArgument,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in your export script",
+      "num_prompt_tokens %d > max_seq_len %" PRId64
+      ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
       num_prompt_tokens,
-      max_context_len);
+      max_seq_len);
 
-  // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract pos_ for max_new_tokens.
+  // Determine max_new_tokens using the GenerationConfig's resolve method.
+  // For sliding window models, we use max_context_len directly (not reduced by pos_)
+  // because the model handles position wrapping internally via ring buffer.
   int max_new_tokens =
       config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+
+  // TEMPORARY: For sliding window / infinite context testing,
+  // override max_new_tokens to allow unlimited generation
+  if (max_new_tokens <= 0) {
+    max_new_tokens = 1000000; // Effectively unlimited
+  }
+  // If user specified seq_len, use that instead
+  if (config.seq_len > 0 && config.seq_len > max_new_tokens) {
+    max_new_tokens = config.seq_len;
+  }
 
   ET_LOG(
       Info,

From 83f437accd9259dad175087d8af6bbdd5ff798aa Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Feb 2026 12:12:03 -0800
Subject: [PATCH 03/11] kv cache sdpa

---
 .../llama/config/llama_attention_sink.yaml    |  2 +-
 examples/models/llama/model.py                |  9 ++++
 .../source_transformation/attention_sink.py   | 44 +++++++++----------
 .../source_transformation/custom_kv_cache.py  | 14 +++---
 4 files changed, 37 insertions(+), 32 deletions(-)

diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml
index 2e8cfb18cde..1d859035d74 100644
--- a/examples/models/llama/config/llama_attention_sink.yaml
+++ b/examples/models/llama/config/llama_attention_sink.yaml
@@ -2,7 +2,7 @@ base:
   metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
 
 model:
-  use_sdpa_with_kv_cache: False # attention_sink requires standard SDPA
+  use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom
   use_kv_cache: True
   dtype_override: fp32
   enable_dynamic_shape: True
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index a2c8d73207b..3be00d78711 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -226,6 +226,15 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
                 f"sink_size + window_size ({sink_size + window_size})"
             )
 
+            # IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation
+            # positions, not just the cache size. Override the model's max_context_len to use
+            # a larger value that supports extended generation.
+            # We use model_args.max_context_len which was set from export.max_context_length
+            # but for RoPE we need the full generation length capability.
+            # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
+ default_rope_length = max(131072, model_args.max_context_len) + model_args.max_context_len = default_rope_length + self.model_ = enable_attention_sink( module=self.model_, params=model_args, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 068ac7bdc68..c2e0f6606f5 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -355,30 +355,26 @@ def _replace_attention( if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache - if sink_size == 0: - # For sink_size=0, use the exact same RingKVCache that works - # This is a test to ensure parity with the working implementation - child_module.kv_cache = RingKVCache( - kv_cache.max_batch_size, - window_size, # RingKVCache expects user-provided window size - kv_cache.n_heads, - kv_cache.head_dim, - kv_cache.enable_dynamic_shape, - kv_cache.k_cache.dtype, - ) - else: - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink + # Always use KVCacheWithAttentionSink, even for sink_size=0 + # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=kv_cache.k_cache.dtype, + ) + child_module.kv_cache = kv_cache_with_attention_sink + + # If using SDPACustom (fused SDPA op), enable attention mask support + # so it uses our ring buffer / attention sink mask instead of simple causal mask + if "SDPACustom" in child_module.SDPA.__class__.__name__: + child_module.SDPA.use_attention_mask = True + # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 373f4fcc4a0..f8a268183b5 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -371,17 +371,17 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): - # Import here to avoid circular imports - from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - ) - for name, child in module.named_children(): # Skip KVCacheWithAttentionSink as it has special evict_tokens logic - # that is not compatible with CustomKVCache - if isinstance(child, KVCacheWithAttentionSink): + # that is not compatible with CustomKVCache. + # Check by class name because the class might come from different module paths + # (e.g., 'examples.models...' 
vs 'executorch.examples.models...') + child_class_name = type(child).__name__ + if child_class_name == "KVCacheWithAttentionSink": + logging.info(f"Skipping KVCacheWithAttentionSink at {name}") _replace_kv_cache_with_custom_kv_cache(child) elif isinstance(child, KVCache): + logging.info(f"Replacing KVCache at {name} (type={child_class_name})") cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, max_context_length, head_dim = cache_shape From 97f9910eabf5a118d27d7c6c3adb6804739bc6d2 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 12:28:57 -0800 Subject: [PATCH 04/11] Test --- examples/models/llama/BUCK | 12 + examples/models/llama/eval_llama_lib.py | 10 +- .../test_attention_sink_ring_buffer.py | 594 ++++++++++++++++++ 3 files changed, 615 insertions(+), 1 deletion(-) create mode 100644 examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK index 9d9897a2819..57ed846dd86 100644 --- a/examples/models/llama/BUCK +++ b/examples/models/llama/BUCK @@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test, ], ) +fbcode_target(_kind = runtime.python_test, + name = "attention_sink_ring_buffer_test", + srcs = [ + "source_transformation/test_attention_sink_ring_buffer.py", + ], + supports_static_listing = False, + deps = [ + "//caffe2:torch", + ":export_library", + ], +) + fbcode_target(_kind = runtime.python_test, name = "quantized_sdpa_source_transform_test", srcs = [ diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 03f8f5cd759..e13a5299e61 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + + Updated for the ring-buffer based attention sink implementation. """ # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig @@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert llm_config.export.max_seq_length == sink_size + window_size + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py new file mode 100644 index 00000000000..7109129caca --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -0,0 +1,594 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the ring-buffer based attention sink implementation. + +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + self.manager = CachePositionsManagerWithSink(self.cache_size) + + def test_initial_positions_are_zero(self): + """Cache positions should start as zeros.""" + expected = torch.zeros(self.cache_size, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound(self): + """Test ring buffer wraparound.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions(self): + """Cache positions should track which original position is at each index.""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = 
torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + 
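+
+    # Worked index example for this setUp (cache size 32 = 4 + 14 * 2),
+    # illustrative only: once generation reaches position 40, the write lands
+    # in slot 40 % 32 = 8 and cache_positions[8] becomes 40; the mask logic
+    # then uses these recorded positions to decide visibility.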
def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query 
at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more 
tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size = 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], 
dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + non_inf_count = (mask != float('-inf')).sum().item() + self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") + + # For positions >= sink_size, sinks should always be visible + if pos >= sink_size: + for i in range(sink_size): + cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() + if cache_pos < sink_size: # This is actually a sink token + self.assertEqual(mask[0, i].item(), 0.0, + f"Sink at idx {i} should be visible at pos {pos}") + + +if __name__ == '__main__': + unittest.main() From 075d521b2960945aedd4f4a908b3f69107c5f18d Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 12:31:38 -0800 Subject: [PATCH 05/11] Remove trailing whitespace Co-authored-by: Claude --- examples/models/llama/eval_llama_lib.py | 404 +----------- .../source_transformation/attention_sink.py | 411 +----------- .../test_attention_sink_ring_buffer.py | 595 +----------------- extension/llm/runner/text_llm_runner.cpp | 311 +-------- 4 files changed, 4 insertions(+), 1717 deletions(-) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index e13a5299e61..d8e107f74ef 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -1,403 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import argparse - -from typing import Optional, Union - -import torch - -from datasets import load_dataset -from executorch.examples.models.llama.export_llama_lib import ( - get_quantizer_and_quant_params, -) - -from executorch.extension.llm.export.builder import LLMEdgeManager -from lm_eval.evaluator import simple_evaluate -from pytorch_tokenizers import get_tokenizer -from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken -from torch.nn import CrossEntropyLoss -from tqdm import tqdm - -from .evaluate.eager_eval import EagerEvalWrapper - -from .export_llama_lib import ( - _prepare_for_llama_export, - build_args_parser as _build_args_parser, -) - - -class GraphModuleEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch py-binded integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: torch.fx.GraphModule, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - max_seq_length: Optional[int] = None, - use_kv_cache: bool = False, - generate_full_logits: bool = False, - enable_dynamic_shape: bool = True, - ): - super().__init__( - model=model, tokenizer=tokenizer, max_seq_length=max_seq_length - ) - self._model = model.to(self.device) - self._use_kv_cache = use_kv_cache - self._generate_full_logits = generate_full_logits - self._enable_dynamic_shape = enable_dynamic_shape - - def _model_call(self, inps): - if self._use_kv_cache: - if not self._enable_dynamic_shape: - # graph module exported without dynamic shape won't work with a different shape. - # And we have to do single token prefill here. 
- result_logits = [] - for pos in range(inps.shape[-1]): - pos_tensor = torch.tensor([pos], dtype=torch.int64) - logits = self._model( - inps[:, pos : pos + 1], {"input_pos": pos_tensor} - ) - result_logits.append(logits) - if self._generate_full_logits: - return torch.cat(result_logits, dim=1) - else: - return torch.stack(result_logits, dim=1) - else: - pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) - # Batch process the whole sequence. - logits = self._model( - inps[:, : self._max_seq_length], {"input_pos": pos_tensor} - ) - return logits - - else: - return self._model(inps) - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - - -class ETPybindEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch py-binded integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: str, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - max_seq_length: Optional[int] = None, - ): - super().__init__(None, tokenizer, max_seq_length) # pyre-ignore - self._model = model # Expects model to be path to a .pte file - - from executorch.extension.pybindings.portable_lib import _load_for_executorch - - # Load custom ops and quantized ops. - from executorch.extension.pybindings import portable_lib # noqa # usort: skip - - # Note: import this after portable_lib - from executorch.extension.llm.custom_ops import ( # noqa - custom_ops, # usort: skip - ) - from executorch.kernels import quantized # noqa - - self._et_model = _load_for_executorch(self._model) - self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore - - def _model_call(self, inps): - # Given inps (tokens), return the logits from a single forward call - # inps: Tensor of shape (1, max_seq_len - 1) - # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) - result = [] - if self._use_kv_cache: - pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) - result = self._et_model.forward( - (inps[:, : self._max_seq_length], pos_tensor) - ) - else: - result = self._et_model.forward((inps,)) - if result[0].dim() != 3: - raise ValueError( - f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." - ) - return result[0] - - -class ETRunnerEvalWrapper(EagerEvalWrapper): - """ - A wrapper class for ExecuTorch Runtime integration with the - lm-evaluation-harness library. - """ - - def __init__( - self, - model: str, - tokenizer: Union[SentencePieceTokenizer, Tiktoken], - tokenizer_bin: str, - max_seq_length: Optional[int] = None, - ): - super().__init__(None, tokenizer, max_seq_length) # pyre-ignore - self._model = model - self._tokenizer_bin = tokenizer_bin - - def _model_call(self, inps): - # Given inps (tokens), return the logits from a single - # forward call - - # Example: - # inps: Tensor of shape (1, N) - # logits: Tensor of shape (1, N, vocab_size) - pass - - -def gen_eval_wrapper( - model_name: str, - args: argparse.ArgumentParser, - llm_config=None, -): - """ - Generates a wrapper interface around the provided model and tokenizer for - the lm-evaluation-harness library. - - Returns: - eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. 
- """ - # If llm_config is not provided, convert args to llm_config - if llm_config is None: - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - tokenizer = get_tokenizer(llm_config.base.tokenizer_path) - - # ExecuTorch Binary Evaluation - if (model := args.pte) is not None: # pyre-ignore - if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore - # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime - return ETRunnerEvalWrapper( - model=model, - tokenizer=tokenizer, - tokenizer_bin=tokenizer_bin, - max_seq_length=llm_config.export.max_seq_length, - ) - - # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings - return ETPybindEvalWrapper( - model=model, - tokenizer=tokenizer, - # Exported model takes at most (max_seq_length - 1) tokens. - # Note that the eager model takes at most max_seq_length tokens. - max_seq_length=llm_config.export.max_seq_length - 1, - ) - - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( - llm_config - ) - # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) - - if len(quantizers) != 0: - manager = manager.export().pt2e_quantize(quantizers) - model = ( - manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore - if torch.cuda.is_available() - else manager.pre_autograd_graph_module.to(device="cpu") - ) - return GraphModuleEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=llm_config.export.max_seq_length, - use_kv_cache=llm_config.model.use_kv_cache, - enable_dynamic_shape=llm_config.model.enable_dynamic_shape, - ) - else: - # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch - # for quantizers. Currently export only works with --kv_cache, but - # fails without the kv_cache mode - model = ( - manager.model.eval().to(device="cuda") - if torch.cuda.is_available() - else manager.model.eval().to(device="cpu") - ) - - # Save the checkpoint after the eager model preparation is done. - # The reason for this option is that the checkpoint can be used - # to do evaluations in other evaluation platforms, or with data - # that is not available in this eval_llama. We save the checkpoint - # here for consistency with eval_llama. The accuracy results we - # get from eval_llama can be used as a reference to other evaluations. - if args.output_eager_checkpoint_file is not None: # pyre-ignore - torch.save(model, args.output_eager_checkpoint_file) - - return EagerEvalWrapper( - model=model, - tokenizer=tokenizer, - max_seq_length=llm_config.export.max_seq_length, - use_kv_cache=llm_config.model.use_kv_cache, - ) - - -def build_args_parser() -> argparse.ArgumentParser: - # Start with arg parser from export_llama_lib - parser = _build_args_parser() - - # Add additional args specific to eval - parser.add_argument( - "--tasks", - nargs="+", - type=str, - default=["wikitext"], - help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", - ) - parser.add_argument( - "--limit", - type=int, - default=None, - help="number of samples to evalulate. 
If not set, evaluate all samples", - ) - parser.add_argument( - "-f", - "--num_fewshot", - type=int, - default=None, - metavar="N", - help="Number of examples in few-shot context", - ) - # Add additional args specific to eval via an ET Runner - # Note: For initial integration, the tokenizer.model is also required - parser.add_argument( - "--pte", - type=str, - default=None, - help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", - ) - parser.add_argument( - "--tokenizer_bin", - type=str, - default=None, - help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", - ) - parser.add_argument( - "--output_eager_checkpoint_file", - type=str, - default=None, - help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", - ) - - # Set of parameters secpific to AttentionSink. - parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) - - return parser - - -def eval_llama( - model_name: str, - args: argparse.ArgumentParser, -) -> None: - # Convert args to LlmConfig - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - # Generate the eval wrapper - eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) - - # Needed for loading mmlu dataset. - # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files - if args.tasks and "mmlu" in args.tasks: - import datasets - - datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True - - # Evaluate the model - with torch.no_grad(): - eval_results = simple_evaluate( - model=eval_wrapper, - tasks=args.tasks, - num_fewshot=args.num_fewshot, - limit=args.limit, - ) - - for task, res in eval_results["results"].items(): - print(f"{task}: {res}") - - -def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): - """ - Evaluate the model's perplexity when AttentionSink is enabled. - - This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py - - Updated for the ring-buffer based attention sink implementation. 
- """ - # Convert args to LlmConfig - from executorch.extension.llm.export.config.llm_config import LlmConfig - - llm_config = LlmConfig.from_args(args) - - assert llm_config.model.use_attention_sink is not None - assert args.attention_sink_eval_tokens > 0 - attention_sink_params = llm_config.model.use_attention_sink.split(",") - assert len(attention_sink_params) == 3 - sink_size = int(attention_sink_params[0]) - window_size = int(attention_sink_params[1]) - - # For the ring buffer implementation, the cache size is sink_size + window_size * 2 - # max_context_length should be >= sink_size + window_size (for RoPE frequencies) - # but can be larger to support extended generation - assert llm_config.export.max_context_length >= sink_size + window_size, ( - f"max_context_length ({llm_config.export.max_context_length}) must be >= " - f"sink_size + window_size ({sink_size + window_size})" - ) - - device = "cuda" if torch.cuda.is_available() else "cpu" - manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) - model = manager.model.eval().to(device=device) - tokenizer = get_tokenizer(llm_config.base.tokenizer_path) - - eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - - nlls = [] - loss_fn = CrossEntropyLoss(reduction="none") - progress_bar = tqdm(total=args.attention_sink_eval_tokens) - input_pos = 0 - while input_pos < args.attention_sink_eval_tokens: - for text in eval_data["text"]: - tokens = tokenizer.encode(text, bos=False, eos=False) - if len(tokens) <= 0: - continue - with torch.no_grad(): - num_tokens = min( - len(tokens) - 1, args.attention_sink_eval_tokens - input_pos - ) - logits = model( - torch.tensor( - [tokens[:num_tokens]], dtype=torch.int64, device=device - ), - torch.tensor([input_pos], dtype=torch.int64, device=device), - ).squeeze(dim=0) - neg_log_likelihood = loss_fn( - logits, - torch.tensor( - [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device - ).view(-1), - ) - nlls.append(neg_log_likelihood) - input_pos += num_tokens - progress_bar.update(num_tokens) - if input_pos >= args.attention_sink_eval_tokens: - break - ppl = torch.exp(torch.cat(nlls).mean()) - print(f"Perplexity: {ppl.item()}") - return ppl.item() +# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.import argparsefrom typing import Optional, Unionimport torchfrom datasets import load_datasetfrom executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params,)from executorch.extension.llm.export.builder import LLMEdgeManagerfrom lm_eval.evaluator import simple_evaluatefrom pytorch_tokenizers import get_tokenizerfrom pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizerfrom pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktokenfrom torch.nn import CrossEntropyLossfrom tqdm import tqdmfrom .evaluate.eager_eval import EagerEvalWrapperfrom .export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser,)class GraphModuleEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. 
""" def __init__( self, model: torch.fx.GraphModule, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( model=model, tokenizer=tokenizer, max_seq_length=max_seq_length ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): if self._use_kv_cache: if not self._enable_dynamic_shape: # graph module exported without dynamic shape won't work with a different shape. # And we have to do single token prefill here. result_logits = [] for pos in range(inps.shape[-1]): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model( inps[:, pos : pos + 1], {"input_pos": pos_tensor} ) result_logits.append(logits) if self._generate_full_logits: return torch.cat(result_logits, dim=1) else: return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. logits = self._model( inps[:, : self._max_seq_length], {"input_pos": pos_tensor} ) return logits else: return self._model(inps) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented")class ETPybindEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model # Expects model to be path to a .pte file from executorch.extension.pybindings.portable_lib import _load_for_executorch # Load custom ops and quantized ops. from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib from executorch.extension.llm.custom_ops import ( # noqa custom_ops, # usort: skip ) from executorch.kernels import quantized # noqa self._et_model = _load_for_executorch(self._model) self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore def _model_call(self, inps): # Given inps (tokens), return the logits from a single forward call # inps: Tensor of shape (1, max_seq_len - 1) # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) result = [] if self._use_kv_cache: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) result = self._et_model.forward( (inps[:, : self._max_seq_length], pos_tensor) ) else: result = self._et_model.forward((inps,)) if result[0].dim() != 3: raise ValueError( f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." ) return result[0]class ETRunnerEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch Runtime integration with the lm-evaluation-harness library. 
""" def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], tokenizer_bin: str, max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model self._tokenizer_bin = tokenizer_bin def _model_call(self, inps): # Given inps (tokens), return the logits from a single # forward call # Example: # inps: Tensor of shape (1, N) # logits: Tensor of shape (1, N, vocab_size) passdef gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, llm_config=None,): """ Generates a wrapper interface around the provided model and tokenizer for the lm-evaluation-harness library. Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. """ # If llm_config is not provided, convert args to llm_config if llm_config is None: from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime return ETRunnerEvalWrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings return ETPybindEvalWrapper( model=model, tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. max_seq_length=llm_config.export.max_seq_length - 1, ) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( llm_config ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) model = ( manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore if torch.cuda.is_available() else manager.pre_autograd_graph_module.to(device="cpu") ) return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch # for quantizers. Currently export only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") if torch.cuda.is_available() else manager.model.eval().to(device="cpu") ) # Save the checkpoint after the eager model preparation is done. # The reason for this option is that the checkpoint can be used # to do evaluations in other evaluation platforms, or with data # that is not available in this eval_llama. We save the checkpoint # here for consistency with eval_llama. The accuracy results we # get from eval_llama can be used as a reference to other evaluations. 
if args.output_eager_checkpoint_file is not None: # pyre-ignore torch.save(model, args.output_eager_checkpoint_file) return EagerEvalWrapper( model=model, tokenizer=tokenizer, max_seq_length=llm_config.export.max_seq_length, use_kv_cache=llm_config.model.use_kv_cache, )def build_args_parser() -> argparse.ArgumentParser: # Start with arg parser from export_llama_lib parser = _build_args_parser() # Add additional args specific to eval parser.add_argument( "--tasks", nargs="+", type=str, default=["wikitext"], help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( "--limit", type=int, default=None, help="number of samples to evalulate. If not set, evaluate all samples", ) parser.add_argument( "-f", "--num_fewshot", type=int, default=None, metavar="N", help="Number of examples in few-shot context", ) # Add additional args specific to eval via an ET Runner # Note: For initial integration, the tokenizer.model is also required parser.add_argument( "--pte", type=str, default=None, help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", ) parser.add_argument( "--tokenizer_bin", type=str, default=None, help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", ) parser.add_argument( "--output_eager_checkpoint_file", type=str, default=None, help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", ) # Set of parameters secpific to AttentionSink. parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) return parserdef eval_llama( model_name: str, args: argparse.ArgumentParser,) -> None: # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) # Generate the eval wrapper eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files if args.tasks and "mmlu" in args.tasks: import datasets datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True # Evaluate the model with torch.no_grad(): eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, num_fewshot=args.num_fewshot, limit=args.limit, ) for task, res in eval_results["results"].items(): print(f"{task}: {res}")def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): """ Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py Updated for the ring-buffer based attention sink implementation. 
""" # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig llm_config = LlmConfig.from_args(args) assert llm_config.model.use_attention_sink is not None assert args.attention_sink_eval_tokens > 0 attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) # For the ring buffer implementation, the cache size is sink_size + window_size * 2 # max_context_length should be >= sink_size + window_size (for RoPE frequencies) # but can be larger to support extended generation assert llm_config.export.max_context_length >= sink_size + window_size, ( f"max_context_length ({llm_config.export.max_context_length}) must be >= " f"sink_size + window_size ({sink_size + window_size})" ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") nlls = [] loss_fn = CrossEntropyLoss(reduction="none") progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue with torch.no_grad(): num_tokens = min( len(tokens) - 1, args.attention_sink_eval_tokens - input_pos ) logits = model( torch.tensor( [tokens[:num_tokens]], dtype=torch.int64, device=device ), torch.tensor([input_pos], dtype=torch.int64, device=device), ).squeeze(dim=0) neg_log_likelihood = loss_fn( logits, torch.tensor( [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device ).view(-1), ) nlls.append(neg_log_likelihood) input_pos += num_tokens progress_bar.update(num_tokens) if input_pos >= args.attention_sink_eval_tokens: break ppl = torch.exp(torch.cat(nlls).mean()) print(f"Perplexity: {ppl.item()}") return ppl.item() \ No newline at end of file diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index c2e0f6606f5..99b0231f2c3 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -1,410 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Components for supporting Attention Sink. See -# https://arxiv.org/abs/2309.17453 for more details about Attention Sink. - -# This implementation is torch.export compatible using a ring buffer approach -# for the sliding window portion while preserving the sink tokens. 
- -import types -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from executorch.examples.models.llama.attention import ( - _create_causal_mask_for_ring_buffer, - AttentionMHA, - CachePositionsManager, - KVCache, - RingKVCache, -) -from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.rope import ( - apply_rotary_emb_to_k, - hf_apply_rotary_emb_to_k, - Rope, -) -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - - -class RopeWithAttentionSink(Rope): - """ - Rope that helps adjust position encoding when tokens are shifted in KVCache. - For AttentionSink, when tokens are shifted in KVCache, we need to use positions - in KVCache instead of positions in the actual text. - - For torch.export compatibility, this just passes through the position - the - actual position adjustment is handled by the cache update logic. - - Note: This class uses the model's max_context_len (params.max_context_len) for - RoPE frequency table size, which should be large enough to support generation - beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. - """ - - def __init__( - self, - params: ModelArgs, - window_size: int, - sink_size: int, - eviction_batch_size: int, - ): - super().__init__(params) - if self.params.use_hf_rope: - self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k - else: - self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) - self.kv_cache_size = sink_size + window_size * 2 - self.window_size = window_size - self.sink_size = sink_size - # max_context_len from params is used for RoPE frequencies (should be large) - self.max_context_length = self.params.max_context_len - self.eviction_batch_size = eviction_batch_size - - def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): - """ - Get rotary embedding frequencies. - For attention sink, we use the original position - the sliding window - is handled by the cache index management, not by position shifting. - """ - assert input_pos is not None - return super().get_freqs(input_pos, seq_len) - - def rerotate_k( - self, - k: torch.Tensor, - original_position: int, - new_position: int, - ): - """ - Rerotate k from original_position to new_position. - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - """ - seq_len = k.shape[1] - original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) - original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) - new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) - new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) - rerotation_cos = ( - new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin - ) - rerotation_sin = ( - new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin - ) - - return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) - - -def _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len -): - """ - Create causal mask for attention sink. - - Unlike regular ring buffer mask, this mask: - 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) - 2. 
Uses sliding window for other tokens - - Args: - cache_positions: Tensor of actual positions stored at each cache index - window_size: Size of the sliding window - sink_size: Number of sink tokens to always attend to - start_pos: Starting position of the current query - seq_len: Length of the current query sequence - """ - pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) - delta = pos_q - cache_positions - - # Valid if position is filled (>= 0) and causal (delta >= 0) - is_valid = (cache_positions >= 0) & (delta >= 0) - - # Sink tokens (original positions 0 to sink_size-1) are always visible - is_sink = cache_positions < sink_size - - # Window tokens must be within sliding window - is_in_window = delta < window_size - - # Final mask: valid AND (is_sink OR is_in_window) - attn_mask = is_valid & (is_sink | is_in_window) - attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 - return attn_mask - - -class CachePositionsManagerWithSink(nn.Module): - """ - Manages cache positions for attention sink + sliding window. - - For sink_size=0: behaves exactly like original CachePositionsManager. - For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. - - IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). - """ - - def __init__(self, cache_size: int): - super().__init__() - # cache_size is the actual size of the kv cache dimension - self.max_context_length = cache_size - # Use zeros like original CachePositionsManager - self.register_buffer( - "cache_positions", - torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), - ) - - def calculate_positions_and_update_indices( - self, input_pos: torch.Tensor, seq_len: int - ) -> torch.Tensor: - """ - Calculate indices into k_cache, v_cache for placing k_val, v_val. - - This is identical to the original CachePositionsManager logic. - """ - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - - orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos - - # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length - - # Update cache_positions exactly like original CachePositionsManager - full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) - arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, full_t - ) - self.cache_positions.copy_(cache_positions) - self.cache_positions.index_copy_(0, indices, orig_indices) - - return indices - - -class KVCacheWithAttentionSink(KVCache): - """ - KV cache that supports attention sink with torch.export compatibility. - - Uses a ring buffer approach for the sliding window portion while keeping - the first sink_size tokens fixed. This avoids dynamic shape operations. 
- - Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] - """ - - def __init__( - self, - n_heads: int, - head_dim: int, - enable_dynamic_shape: bool, - rope: RopeWithAttentionSink, - window_size: int, - sink_size: int, - eviction_batch_size: int, - max_batch_size: int = 1, - dtype=torch.float32, - ): - # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) - total_cache_size = sink_size + window_size * 2 - super().__init__( - max_batch_size=max_batch_size, - max_context_length=total_cache_size, - n_heads=n_heads, - head_dim=head_dim, - enable_dynamic_shape=enable_dynamic_shape, - dtype=dtype, - ) - self.rope = rope - self.window_size = window_size - self.sink_size = sink_size - self.eviction_batch_size = eviction_batch_size - self.is_ring_buffer = True - - # Cache positions manager for determining write locations - # Pass the total cache size (same as self.max_context_length after super().__init__) - self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) - - def create_causal_mask_for_ring_buffer( - self, start_pos: torch.Tensor, seq_len: int - ): - """ - Create causal mask for the attention with attention sink. - Sink tokens are ALWAYS visible, plus recent tokens in the window. - """ - cache_positions = self.cache_positions_manager.cache_positions - if self.sink_size > 0: - # Use attention sink mask that always allows attending to sink tokens - return _create_causal_mask_for_attention_sink( - cache_positions, self.window_size, self.sink_size, start_pos, seq_len - ) - else: - # Pure ring buffer mode - use original mask with window_size = actual window - return _create_causal_mask_for_ring_buffer( - cache_positions, self.window_size, start_pos, seq_len - ) - - def update( - self, - input_pos: torch.Tensor, - k_val: torch.Tensor, - v_val: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV cache with new key-value pairs. - Uses ring buffer indexing for positions >= sink_size. - """ - seq_len = k_val.size(2) - assert seq_len <= self.k_cache.size( - 2 - ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" - - # Calculate write indices - indices = self.cache_positions_manager.calculate_positions_and_update_indices( - input_pos, seq_len - ) - - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - self.k_cache.index_copy_(2, indices, k_val) - self.v_cache.index_copy_(2, indices, v_val) - - return self.k_cache, self.v_cache - - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: - """ - For ring buffer implementation, no explicit eviction is needed. - The ring buffer automatically overwrites old values. - Returns 0 to indicate no position shift is needed. - """ - return 0 - - -def attention_sink_forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, -): - """ - Forward function for attention with attention sink KV cache. - Uses ring buffer masking for proper attention patterns. 
- """ - assert self.use_kv_cache - assert input_pos is not None - - bsz, seqlen, _ = x.shape - - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - # RoPE relative positional embeddings - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - - # Transpose for attention: [B, H, S, D] - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Update KV cache - k, v = self.kv_cache.update(input_pos, k, v) - - # Use ring buffer mask since we have is_ring_buffer=True - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) - - # SDPA - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) - - # Return tuple like original AttentionMHA.forward - return self.wo(output), None - - -def _replace_rope( - module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - return isinstance(child, Rope) - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - return rope_with_attention_sink - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def _replace_attention( - module: torch.nn.Module, - rope_with_attention_sink: RopeWithAttentionSink, - sink_size: int, - window_size: int, - eviction_batch_size: int, -): - for _, child_module in module._modules.items(): - if len(list(child_module.children())) > 0: # pyre-ignore [16] - _replace_attention( - module=child_module, # pyre-ignore [6] - rope_with_attention_sink=rope_with_attention_sink, - sink_size=sink_size, - window_size=window_size, - eviction_batch_size=eviction_batch_size, - ) - - if isinstance(child_module, AttentionMHA): - kv_cache = child_module.kv_cache - # Always use KVCacheWithAttentionSink, even for sink_size=0 - # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True - kv_cache_with_attention_sink = KVCacheWithAttentionSink( - n_heads=kv_cache.n_heads, - head_dim=kv_cache.head_dim, - enable_dynamic_shape=kv_cache.enable_dynamic_shape, - rope=rope_with_attention_sink, - max_batch_size=kv_cache.max_batch_size, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=kv_cache.k_cache.dtype, - ) - child_module.kv_cache = kv_cache_with_attention_sink - - # If using SDPACustom (fused SDPA op), enable attention mask support - # so it uses our ring buffer / attention sink mask instead of simple causal mask - if "SDPACustom" in child_module.SDPA.__class__.__name__: - child_module.SDPA.use_attention_mask = True - - # Don't replace forward - let the original AttentionMHA.forward handle it - # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask - - -def enable_attention_sink( - module: torch.nn.Module, - params: ModelArgs, - sink_size: int, - window_size: int, - eviction_batch_size: int, -) -> torch.nn.Module: - """ - Transform the model to be able to run inference with Attention Sink. 
- There mainly three steps: - - Replace Rope with RopeWithAttentionSink - - Replace Attention's KVCache with KVCacheWithAttentionSink - - Replace Attention's forward with attention_sink_forward - """ - rope_with_attention_sink = RopeWithAttentionSink( - params=params, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - _replace_rope(module, rope_with_attention_sink) - _replace_attention( - module=module, - rope_with_attention_sink=rope_with_attention_sink, - sink_size=sink_size, - window_size=window_size, - eviction_batch_size=eviction_batch_size, - ) - return module +# Copyright (c) Meta Platforms, Inc. and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.# Components for supporting Attention Sink. See# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.# This implementation is torch.export compatible using a ring buffer approach# for the sliding window portion while preserving the sink tokens.import typesfrom typing import Optional, Tupleimport torchimport torch.nn as nnfrom executorch.examples.models.llama.attention import ( _create_causal_mask_for_ring_buffer, AttentionMHA, CachePositionsManager, KVCache, RingKVCache,)from executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, Rope,)from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filterclass RopeWithAttentionSink(Rope): """ Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. For torch.export compatibility, this just passes through the position - the actual position adjustment is handled by the cache update logic. Note: This class uses the model's max_context_len (params.max_context_len) for RoPE frequency table size, which should be large enough to support generation beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( self, params: ModelArgs, window_size: int, sink_size: int, eviction_batch_size: int, ): super().__init__(params) if self.params.use_hf_rope: self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) self.kv_cache_size = sink_size + window_size * 2 self.window_size = window_size self.sink_size = sink_size # max_context_len from params is used for RoPE frequencies (should be large) self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): """ Get rotary embedding frequencies. For attention sink, we use the original position - the sliding window is handled by the cache index management, not by position shifting. """ assert input_pos is not None return super().get_freqs(input_pos, seq_len) def rerotate_k( self, k: torch.Tensor, original_position: int, new_position: int, ): """ Rerotate k from original_position to new_position. 
The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) rerotation_cos = ( new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin ) rerotation_sin = ( new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin ) return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)def _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len): """ Create causal mask for attention sink. Unlike regular ring buffer mask, this mask: 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) 2. Uses sliding window for other tokens Args: cache_positions: Tensor of actual positions stored at each cache index window_size: Size of the sliding window sink_size: Number of sink tokens to always attend to start_pos: Starting position of the current query seq_len: Length of the current query sequence """ pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) delta = pos_q - cache_positions # Valid if position is filled (>= 0) and causal (delta >= 0) is_valid = (cache_positions >= 0) & (delta >= 0) # Sink tokens (original positions 0 to sink_size-1) are always visible is_sink = cache_positions < sink_size # Window tokens must be within sliding window is_in_window = delta < window_size # Final mask: valid AND (is_sink OR is_in_window) attn_mask = is_valid & (is_sink | is_in_window) attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 return attn_maskclass CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. For sink_size=0: behaves exactly like original CachePositionsManager. For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). """ def __init__(self, cache_size: int): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. This is identical to the original CachePositionsManager logic. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos # Simple ring buffer: just mod by cache size indices = orig_indices % self.max_context_length # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) cache_positions = torch.where( arange_tensor < start_pos, self.cache_positions, full_t ) self.cache_positions.copy_(cache_positions) self.cache_positions.index_copy_(0, indices, orig_indices) return indicesclass KVCacheWithAttentionSink(KVCache): """ KV cache that supports attention sink with torch.export compatibility. 
Uses a ring buffer approach for the sliding window portion while keeping the first sink_size tokens fixed. This avoids dynamic shape operations. Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( self, n_heads: int, head_dim: int, enable_dynamic_shape: bool, rope: RopeWithAttentionSink, window_size: int, sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, dtype=torch.float32, ): # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) total_cache_size = sink_size + window_size * 2 super().__init__( max_batch_size=max_batch_size, max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, dtype=dtype, ) self.rope = rope self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size self.is_ring_buffer = True # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int ): """ Create causal mask for the attention with attention sink. Sink tokens are ALWAYS visible, plus recent tokens in the window. """ cache_positions = self.cache_positions_manager.cache_positions if self.sink_size > 0: # Use attention sink mask that always allows attending to sink tokens return _create_causal_mask_for_attention_sink( cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) else: # Pure ring buffer mode - use original mask with window_size = actual window return _create_causal_mask_for_ring_buffer( cache_positions, self.window_size, start_pos, seq_len ) def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update KV cache with new key-value pairs. Uses ring buffer indexing for positions >= sink_size. """ seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 2 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" # Calculate write indices indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) start_pos = input_pos[0].item() torch._check_is_size(start_pos) self.k_cache.index_copy_(2, indices, k_val) self.v_cache.index_copy_(2, indices, v_val) return self.k_cache, self.v_cache def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: """ For ring buffer implementation, no explicit eviction is needed. The ring buffer automatically overwrites old values. Returns 0 to indicate no position shift is needed. """ return 0def attention_sink_forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, input_pos: Optional[torch.Tensor] = None,): """ Forward function for attention with attention sink KV cache. Uses ring buffer masking for proper attention patterns. 
""" assert self.use_kv_cache assert input_pos is not None bsz, seqlen, _ = x.shape # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Transpose for attention: [B, H, S, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Update KV cache k, v = self.kv_cache.update(input_pos, k, v) # Use ring buffer mask since we have is_ring_buffer=True start_pos = input_pos[0].item() torch._check_is_size(start_pos) attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) # SDPA output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) # Return tuple like original AttentionMHA.forward return self.wo(output), Nonedef _replace_rope( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink): def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: return isinstance(child, Rope) def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: return rope_with_attention_sink _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)def _replace_attention( module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink, sink_size: int, window_size: int, eviction_batch_size: int,): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] _replace_attention( module=child_module, # pyre-ignore [6] rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache # Always use KVCacheWithAttentionSink, even for sink_size=0 # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, enable_dynamic_shape=kv_cache.enable_dynamic_shape, rope=rope_with_attention_sink, max_batch_size=kv_cache.max_batch_size, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink # If using SDPACustom (fused SDPA op), enable attention mask support # so it uses our ring buffer / attention sink mask instead of simple causal mask if "SDPACustom" in child_module.SDPA.__class__.__name__: child_module.SDPA.use_attention_mask = True # Don't replace forward - let the original AttentionMHA.forward handle it # since our KVCache has is_ring_buffer=True, it will use the ring buffer maskdef enable_attention_sink( module: torch.nn.Module, params: ModelArgs, sink_size: int, window_size: int, eviction_batch_size: int,) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - Replace Attention's KVCache with KVCacheWithAttentionSink - Replace Attention's forward with attention_sink_forward """ rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, ) _replace_rope(module, rope_with_attention_sink) _replace_attention( module=module, rope_with_attention_sink=rope_with_attention_sink, sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, ) return module \ No newline at end of file diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 7109129caca..7ccad6aadbf 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -1,594 +1 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for the ring-buffer based attention sink implementation. - -This tests the torch.export-compatible implementation that uses a ring buffer -for the sliding window rather than explicit token eviction. - -Usage: - # Run with pytest - python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v - - # Or run directly - python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -""" - -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - CachePositionsManagerWithSink, - KVCacheWithAttentionSink, - RopeWithAttentionSink, - _create_causal_mask_for_attention_sink, -) - - -class CachePositionsManagerWithSinkTest(unittest.TestCase): - """Test the cache positions manager for ring buffer indexing.""" - - def setUp(self): - self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) - self.manager = CachePositionsManagerWithSink(self.cache_size) - - def test_initial_positions_are_zero(self): - """Cache positions should start as zeros.""" - expected = torch.zeros(self.cache_size, dtype=torch.long) - torch.testing.assert_close(self.manager.cache_positions, expected) - - def test_simple_update(self): - """Test simple sequential update without wraparound.""" - input_pos = torch.tensor([0], dtype=torch.long) - seq_len = 5 - indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) - - # Should return indices 0, 1, 2, 3, 4 - expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) - torch.testing.assert_close(indices, expected_indices) - - # Cache positions at those indices should be the original positions - for i in range(5): - self.assertEqual(self.manager.cache_positions[i].item(), i) - - def test_wraparound(self): - """Test ring buffer wraparound.""" - # Fill cache to position 30 - input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 30) - - # Add 5 more tokens at position 30 - should wrap around - input_pos = torch.tensor([30], dtype=torch.long) - indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) - - # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 - 
expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) - torch.testing.assert_close(indices, expected_indices) - - def test_cache_positions_track_original_positions(self): - """Cache positions should track which original position is at each index.""" - # Fill with positions 0-31 - input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 32) - - # Now add position 32 which wraps to index 0 - input_pos = torch.tensor([32], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 1) - - # Index 0 should now contain original position 32 - self.assertEqual(self.manager.cache_positions[0].item(), 32) - - -class CausalMaskTest(unittest.TestCase): - """Test the causal mask generation for attention sink.""" - - def test_mask_allows_sink_tokens(self): - """Sink tokens should always be visible (mask = 0).""" - cache_size = 32 - sink_size = 4 - window_size = 14 # cache_size = sink_size + window_size * 2 - - # Create cache positions where positions 0-3 are sink tokens - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 20 # Query at position 20 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 - for i in range(sink_size): - self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") - - def test_mask_blocks_future_tokens(self): - """Future tokens should be masked (-inf).""" - cache_size = 32 - sink_size = 4 - window_size = 14 - - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 10 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Future tokens (positions > 10) should have mask = -inf - for i in range(11, cache_size): - self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") - - def test_mask_respects_window(self): - """Tokens outside the window should be masked.""" - cache_size = 32 - sink_size = 4 - window_size = 5 # Only allow 5 recent tokens - - cache_positions = torch.arange(cache_size, dtype=torch.long) - - start_pos = 20 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions 16-20 should be visible (within window of 5) - for pos in range(16, 21): - self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") - - # Position 15 should be masked (outside window, not a sink) - self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - """Test the KV cache with attention sink.""" - - def setUp(self): - torch.manual_seed(42) - self.window_size = 14 - self.sink_size = 4 - self.n_heads = 8 - self.head_dim = 64 - self.max_batch_size = 1 - - # Create model args with enough context for RoPE - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, # Large enough for RoPE - n_heads=self.n_heads, - n_kv_heads=self.n_heads, - dim=self.n_heads * self.head_dim, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - ) - - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.n_heads, - 
head_dim=self.head_dim, - enable_dynamic_shape=True, - rope=self.rope, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - def test_cache_size(self): - """Cache should be sink_size + window_size * 2.""" - expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 - self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) - self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) - - def test_is_ring_buffer(self): - """Cache should be marked as ring buffer.""" - self.assertTrue(self.kv_cache.is_ring_buffer) - - def test_update_stores_kv(self): - """Update should store key-value pairs.""" - k = torch.randn(1, self.n_heads, 5, self.head_dim) - v = torch.randn(1, self.n_heads, 5, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - k_out, v_out = self.kv_cache.update(input_pos, k, v) - - # First 5 positions should contain our values - torch.testing.assert_close(k_out[:, :, :5, :], k) - torch.testing.assert_close(v_out[:, :, :5, :], v) - - def test_evict_tokens_returns_zero(self): - """Ring buffer implementation doesn't shift, so evict returns 0.""" - input_pos = torch.tensor([100], dtype=torch.long) - shift = self.kv_cache.evict_tokens(input_pos, 10) - self.assertEqual(shift, 0) - - def test_extended_generation(self): - """Test that cache works for positions beyond cache size.""" - cache_size = self.kv_cache.k_cache.size(2) - - # Fill cache with initial tokens - for pos in range(cache_size + 50): - k = torch.randn(1, self.n_heads, 1, self.head_dim) - v = torch.randn(1, self.n_heads, 1, self.head_dim) - input_pos = torch.tensor([pos], dtype=torch.long) - - k_out, v_out = self.kv_cache.update(input_pos, k, v) - - # Should not raise any errors - self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) - self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) - - -class RopeWithAttentionSinkTest(unittest.TestCase): - """Test RoPE for attention sink.""" - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=8, - dim=512, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=100, - sink_size=4, - eviction_batch_size=1, - ) - - def test_get_freqs_uses_original_position(self): - """RoPE frequencies should use the original position.""" - input_pos = torch.tensor([50], dtype=torch.long) - seq_len = 5 - - freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) - - # Should get frequencies for positions 50-54 - expected_cos = self.rope.freqs_cos[50:55] - expected_sin = self.rope.freqs_sin[50:55] - - torch.testing.assert_close(freqs_cos, expected_cos) - torch.testing.assert_close(freqs_sin, expected_sin) - - def test_rerotate_k(self): - """Test re-rotation of k from one position to another.""" - batch_size = 1 - seq_len = 8 - n_heads = self.params.n_heads - head_dim = self.params.dim // n_heads - - k = torch.randn(batch_size, seq_len, n_heads, head_dim) - q = torch.randn(batch_size, seq_len, n_heads, head_dim) - - # Rotate k at position 100 - original_pos = 100 - freqs_cos, freqs_sin = self.rope.get_freqs( - torch.tensor([original_pos], dtype=torch.long), seq_len - ) - _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) - - # Re-rotate to position 50 - new_pos = 50 - rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) - - # This should be equivalent to directly rotating k at position 50 - freqs_cos_new, 
freqs_sin_new = self.rope.get_freqs( - torch.tensor([new_pos], dtype=torch.long), seq_len - ) - _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) - - torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) - - -class CausalMaskWithWraparoundTest(unittest.TestCase): - """Test causal mask with ring buffer wraparound.""" - - def test_mask_after_wraparound(self): - """Test mask after cache has wrapped around.""" - cache_size = 16 - sink_size = 4 - window_size = 6 # cache_size = sink_size + window_size * 2 - - # Simulate cache after generating beyond cache_size: - # The ring buffer wraps, so indices 0-15 contain positions that wrap - # At position 50, with cache_size=16, the cache contains: - # positions 50-15=35 to 49 at various indices - cache_positions = torch.zeros(cache_size, dtype=torch.long) - # Fill with positions that would exist after generating 50 tokens - # idx = pos % cache_size, so: - # pos 34-49 occupy indices 2-15 and 0-1 - for pos in range(34, 50): - idx = pos % cache_size - cache_positions[idx] = pos - - start_pos = 49 # Query at position 49 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions within window (49-6+1=44 to 49) should be visible - visible_count = 0 - for i in range(cache_size): - pos = cache_positions[i].item() - if pos >= 44 and pos <= 49: # In window - self.assertEqual(mask[0, i].item(), 0.0, - f"Position {pos} at idx {i} should be visible (in window)") - visible_count += 1 - - # Should have some visible tokens - self.assertGreater(visible_count, 0, "Should have visible tokens in window") - - def test_mask_with_sink_size_zero(self): - """Test pure sliding window (sink_size=0).""" - cache_size = 16 - sink_size = 0 - window_size = 8 - - cache_positions = torch.arange(cache_size, dtype=torch.long) - start_pos = 10 - seq_len = 1 - - mask = _create_causal_mask_for_attention_sink( - cache_positions, window_size, sink_size, start_pos, seq_len - ) - - # Positions 3-10 should be visible (within window of 8) - for pos in range(3, 11): - self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") - - # Positions 0-2 should be masked (outside window) - for pos in range(0, 3): - self.assertEqual(mask[0, pos].item(), float('-inf'), - f"Position {pos} should be masked (outside window)") - - -class PrefillTest(unittest.TestCase): - """Test prefill scenarios.""" - - def setUp(self): - torch.manual_seed(42) - self.window_size = 14 - self.sink_size = 4 - self.n_heads = 8 - self.head_dim = 64 - - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=self.n_heads, - n_kv_heads=self.n_heads, - dim=self.n_heads * self.head_dim, - ) - - self.rope = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - ) - - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.n_heads, - head_dim=self.head_dim, - enable_dynamic_shape=True, - rope=self.rope, - max_batch_size=1, - window_size=self.window_size, - sink_size=self.sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - def test_prefill_entire_cache(self): - """Test prefill that fills entire cache.""" - cache_size = self.kv_cache.k_cache.size(2) - - k = torch.randn(1, self.n_heads, cache_size, self.head_dim) - v = torch.randn(1, self.n_heads, cache_size, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - k_out, v_out = 
self.kv_cache.update(input_pos, k, v) - - # All positions should be filled - torch.testing.assert_close(k_out, k) - torch.testing.assert_close(v_out, v) - - def test_prefill_larger_than_cache_raises_error(self): - """Test that prefill larger than cache size raises an assertion error.""" - cache_size = self.kv_cache.k_cache.size(2) - seq_len = cache_size + 10 - - k = torch.randn(1, self.n_heads, seq_len, self.head_dim) - v = torch.randn(1, self.n_heads, seq_len, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - - # This should raise an assertion error since seq_len > cache_size - with self.assertRaises(AssertionError): - self.kv_cache.update(input_pos, k, v) - - def test_prefill_followed_by_decode(self): - """Test prefill followed by decode steps.""" - cache_size = self.kv_cache.k_cache.size(2) - - # Prefill with 20 tokens - k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) - v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) - input_pos = torch.tensor([0], dtype=torch.long) - self.kv_cache.update(input_pos, k_prefill, v_prefill) - - # Decode 5 more tokens - for i in range(5): - k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) - v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) - input_pos = torch.tensor([20 + i], dtype=torch.long) - k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) - - # Verify cache positions are updated - expected_pos = 20 + i - cache_idx = expected_pos % cache_size - self.assertEqual( - self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), - expected_pos - ) - - -class EnableAttentionSinkTest(unittest.TestCase): - """Test the enable_attention_sink transformation.""" - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=8, - n_kv_heads=8, - dim=512, - n_layers=2, - vocab_size=100, - ) - - def test_enable_attention_sink_transforms_model(self): - """Test that enable_attention_sink properly transforms the model.""" - from executorch.examples.models.llama.llama_transformer import construct_transformer - from executorch.examples.models.llama.source_transformation.attention_sink import ( - enable_attention_sink, - ) - - # Create a simple transformer - with torch.device("meta"): - model = construct_transformer(self.params) - model.to_empty(device="cpu") - - # Apply attention sink transformation - model = enable_attention_sink( - module=model, - params=self.params, - sink_size=4, - window_size=100, - eviction_batch_size=1, - ) - - # Check that KV caches are replaced - for layer in model.layers: - kv_cache = layer.attention.kv_cache - self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) - self.assertEqual(kv_cache.sink_size, 4) - self.assertEqual(kv_cache.window_size, 100) - self.assertTrue(kv_cache.is_ring_buffer) - - def test_enable_attention_sink_replaces_rope(self): - """Test that RoPE is replaced with RopeWithAttentionSink.""" - from executorch.examples.models.llama.llama_transformer import construct_transformer - from executorch.examples.models.llama.source_transformation.attention_sink import ( - enable_attention_sink, - ) - - with torch.device("meta"): - model = construct_transformer(self.params) - model.to_empty(device="cpu") - - model = enable_attention_sink( - module=model, - params=self.params, - sink_size=4, - window_size=100, - eviction_batch_size=1, - ) - - # Check that rope is replaced - for layer in model.layers: - rope = layer.attention.rope - 
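
(Aside for readers of this hunk: the EnableAttentionSinkTest cases here exercise the transformation entry point. The sketch below strips that down to the bare call pattern, mirroring what the tests do; the ModelArgs values are the tests' toy settings, not a recommended configuration.)

# Minimal sketch of applying the attention-sink transformation, mirroring the
# tests in this hunk. All numeric values are illustrative placeholders.
import torch

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)

params = ModelArgs(
    use_kv_cache=True,
    enable_dynamic_shape=True,
    max_context_len=1024,
    n_heads=8,
    n_kv_heads=8,
    dim=512,
    n_layers=2,
    vocab_size=100,
)

with torch.device("meta"):
    model = construct_transformer(params)
model.to_empty(device="cpu")

# Replace Rope, each layer's KVCache and the attention forward so that the
# first 4 positions are pinned as sinks and a 100-token sliding window is kept.
model = enable_attention_sink(
    module=model,
    params=params,
    sink_size=4,
    window_size=100,
    eviction_batch_size=1,
)

After the call, each layer.attention.kv_cache is a KVCacheWithAttentionSink ring buffer and layer.attention.rope is a RopeWithAttentionSink, which is what the surrounding assertions verify.
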
self.assertIsInstance(rope, RopeWithAttentionSink) - - -class IntegrationTest(unittest.TestCase): - """Integration tests for end-to-end scenarios.""" - - def setUp(self): - torch.manual_seed(42) - - def test_cache_positions_consistency(self): - """Test that cache positions remain consistent during generation.""" - cache_size = 32 - sink_size = 4 - window_size = 14 - n_heads = 8 - head_dim = 64 - - params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=1024, - n_heads=n_heads, - n_kv_heads=n_heads, - dim=n_heads * head_dim, - ) - - rope = RopeWithAttentionSink( - params=params, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=1, - ) - - kv_cache = KVCacheWithAttentionSink( - n_heads=n_heads, - head_dim=head_dim, - enable_dynamic_shape=True, - rope=rope, - max_batch_size=1, - window_size=window_size, - sink_size=sink_size, - eviction_batch_size=1, - dtype=torch.float32, - ) - - # Generate 100 tokens - for pos in range(100): - k = torch.randn(1, n_heads, 1, head_dim) - v = torch.randn(1, n_heads, 1, head_dim) - input_pos = torch.tensor([pos], dtype=torch.long) - - kv_cache.update(input_pos, k, v) - - # Create mask and verify it's valid - mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) - - # Mask should not be all -inf (would mean no tokens to attend to) - non_inf_count = (mask != float('-inf')).sum().item() - self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") - - # For positions >= sink_size, sinks should always be visible - if pos >= sink_size: - for i in range(sink_size): - cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() - if cache_pos < sink_size: # This is actually a sink token - self.assertEqual(mask[0, i].item(), 0.0, - f"Sink at idx {i} should be visible at pos {pos}") - - -if __name__ == '__main__': - unittest.main() +# Copyright (c) Meta Platforms, Inc. 
and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree."""Tests for the ring-buffer based attention sink implementation.This tests the torch.export-compatible implementation that uses a ring bufferfor the sliding window rather than explicit token eviction.Usage: # Run with pytest python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v # Or run directly python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py"""import unittestimport torchfrom executorch.examples.models.llama.model_args import ModelArgsfrom executorch.examples.models.llama.source_transformation.attention_sink import ( CachePositionsManagerWithSink, KVCacheWithAttentionSink, RopeWithAttentionSink, _create_causal_mask_for_attention_sink,)class CachePositionsManagerWithSinkTest(unittest.TestCase): """Test the cache positions manager for ring buffer indexing.""" def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) self.manager = CachePositionsManagerWithSink(self.cache_size) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" expected = torch.zeros(self.cache_size, dtype=torch.long) torch.testing.assert_close(self.manager.cache_positions, expected) def test_simple_update(self): """Test simple sequential update without wraparound.""" input_pos = torch.tensor([0], dtype=torch.long) seq_len = 5 indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) # Should return indices 0, 1, 2, 3, 4 expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) # Cache positions at those indices should be the original positions for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) def test_wraparound(self): """Test ring buffer wraparound.""" # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 30) # Add 5 more tokens at position 30 - should wrap around input_pos = torch.tensor([30], dtype=torch.long) indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 32) # Now add position 32 which wraps to index 0 input_pos = torch.tensor([32], dtype=torch.long) self.manager.calculate_positions_and_update_indices(input_pos, 1) # Index 0 should now contain original position 32 self.assertEqual(self.manager.cache_positions[0].item(), 32)class CausalMaskTest(unittest.TestCase): """Test the causal mask generation for attention sink.""" def test_mask_allows_sink_tokens(self): """Sink tokens should always be visible (mask = 0).""" cache_size = 32 sink_size = 4 window_size = 14 # cache_size = sink_size + window_size * 2 # Create cache positions where positions 0-3 are sink tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 # Query at position 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, 
window_size, sink_size, start_pos, seq_len ) # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 for i in range(sink_size): self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") def test_mask_blocks_future_tokens(self): """Future tokens should be masked (-inf).""" cache_size = 32 sink_size = 4 window_size = 14 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Future tokens (positions > 10) should have mask = -inf for i in range(11, cache_size): self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") def test_mask_respects_window(self): """Tokens outside the window should be masked.""" cache_size = 32 sink_size = 4 window_size = 5 # Only allow 5 recent tokens cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 20 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 16-20 should be visible (within window of 5) for pos in range(16, 21): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") # Position 15 should be masked (outside window, not a sink) self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)")class KVCacheWithAttentionSinkTest(unittest.TestCase): """Test the KV cache with attention sink.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.max_batch_size = 1 # Create model args with enough context for RoPE self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, # Large enough for RoPE n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=self.max_batch_size, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_cache_size(self): """Cache should be sink_size + window_size * 2.""" expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) def test_is_ring_buffer(self): """Cache should be marked as ring buffer.""" self.assertTrue(self.kv_cache.is_ring_buffer) def test_update_stores_kv(self): """Update should store key-value pairs.""" k = torch.randn(1, self.n_heads, 5, self.head_dim) v = torch.randn(1, self.n_heads, 5, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # First 5 positions should contain our values torch.testing.assert_close(k_out[:, :, :5, :], k) torch.testing.assert_close(v_out[:, :, :5, :], v) def test_evict_tokens_returns_zero(self): """Ring buffer implementation doesn't shift, so evict returns 0.""" input_pos = torch.tensor([100], dtype=torch.long) shift = self.kv_cache.evict_tokens(input_pos, 10) self.assertEqual(shift, 0) def test_extended_generation(self): """Test that cache works for positions beyond cache size.""" cache_size = self.kv_cache.k_cache.size(2) # Fill cache 
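
The mask tests in this file pin down the visibility rule: a cache slot is attendable only if it is causally valid and either holds one of the first sink_size positions or lies within the last window_size positions of the query. A short worked example using the same helper the tests import makes the rule concrete; the numbers mirror test_mask_respects_window, and asserting the full visible set is our addition.

# Worked example of the sink + sliding-window visibility rule, using the same
# helper the tests import. With sink_size=4, window_size=5 and a query at
# position 20, slots holding positions 0-3 (sinks) and 16-20 (window) are
# attendable; 4-15 are outside the window and 21+ are in the future.
import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    _create_causal_mask_for_attention_sink,
)

cache_size, sink_size, window_size = 32, 4, 5
# Identity layout for clarity: slot i currently holds absolute position i.
cache_positions = torch.arange(cache_size, dtype=torch.long)

mask = _create_causal_mask_for_attention_sink(
    cache_positions, window_size, sink_size, 20, 1
)  # shape (seq_len, cache_size) = (1, 32); 0.0 means visible, -inf means masked

visible_slots = (mask[0] == 0).nonzero().flatten().tolist()
assert visible_slots == [0, 1, 2, 3, 16, 17, 18, 19, 20]
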
with initial tokens for pos in range(cache_size + 50): k = torch.randn(1, self.n_heads, 1, self.head_dim) v = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([pos], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # Should not raise any errors self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape)class RopeWithAttentionSinkTest(unittest.TestCase): """Test RoPE for attention sink.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, dim=512, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=100, sink_size=4, eviction_batch_size=1, ) def test_get_freqs_uses_original_position(self): """RoPE frequencies should use the original position.""" input_pos = torch.tensor([50], dtype=torch.long) seq_len = 5 freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) # Should get frequencies for positions 50-54 expected_cos = self.rope.freqs_cos[50:55] expected_sin = self.rope.freqs_sin[50:55] torch.testing.assert_close(freqs_cos, expected_cos) torch.testing.assert_close(freqs_sin, expected_sin) def test_rerotate_k(self): """Test re-rotation of k from one position to another.""" batch_size = 1 seq_len = 8 n_heads = self.params.n_heads head_dim = self.params.dim // n_heads k = torch.randn(batch_size, seq_len, n_heads, head_dim) q = torch.randn(batch_size, seq_len, n_heads, head_dim) # Rotate k at position 100 original_pos = 100 freqs_cos, freqs_sin = self.rope.get_freqs( torch.tensor([original_pos], dtype=torch.long), seq_len ) _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) # Re-rotate to position 50 new_pos = 50 rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) # This should be equivalent to directly rotating k at position 50 freqs_cos_new, freqs_sin_new = self.rope.get_freqs( torch.tensor([new_pos], dtype=torch.long), seq_len ) _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4)class CausalMaskWithWraparoundTest(unittest.TestCase): """Test causal mask with ring buffer wraparound.""" def test_mask_after_wraparound(self): """Test mask after cache has wrapped around.""" cache_size = 16 sink_size = 4 window_size = 6 # cache_size = sink_size + window_size * 2 # Simulate cache after generating beyond cache_size: # The ring buffer wraps, so indices 0-15 contain positions that wrap # At position 50, with cache_size=16, the cache contains: # positions 50-15=35 to 49 at various indices cache_positions = torch.zeros(cache_size, dtype=torch.long) # Fill with positions that would exist after generating 50 tokens # idx = pos % cache_size, so: # pos 34-49 occupy indices 2-15 and 0-1 for pos in range(34, 50): idx = pos % cache_size cache_positions[idx] = pos start_pos = 49 # Query at position 49 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions within window (49-6+1=44 to 49) should be visible visible_count = 0 for i in range(cache_size): pos = cache_positions[i].item() if pos >= 44 and pos <= 49: # In window self.assertEqual(mask[0, i].item(), 0.0, f"Position {pos} at idx {i} should be visible (in window)") visible_count += 1 # Should have some visible tokens self.assertGreater(visible_count, 0, "Should have visible tokens in window") def test_mask_with_sink_size_zero(self): """Test pure 
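
The rerotate_k test nearby leans on the fact that rotary embeddings compose additively in position: rotating a key at p_old and then rotating by the offset (p_new - p_old) gives the same result as rotating the raw key at p_new. The self-contained check below illustrates that identity with the standard complex-pair form of RoPE; head_dim and the theta base are illustrative, and this is not the repo's apply_rotary_emb_to_k implementation.

# Standalone check that RoPE rotations compose additively in position, which is
# the property rerotate_k relies on.
import torch

head_dim = 8
theta = 10000.0 ** (-torch.arange(0, head_dim, 2).float() / head_dim)


def rope(x: torch.Tensor, pos: float) -> torch.Tensor:
    # Treat consecutive value pairs as complex numbers and rotate them by
    # pos * theta, the usual RoPE construction.
    pairs = torch.view_as_complex(x.reshape(-1, 2))
    rotation = torch.polar(torch.ones_like(theta), pos * theta)
    return torch.view_as_real(pairs * rotation).reshape(-1)


k = torch.randn(head_dim)
p_old, p_new = 100, 50

k_at_old = rope(k, float(p_old))
# Rotating by the delta (p_new - p_old) lands on the same values as rotating
# the raw key directly at p_new, so cached keys can be "re-rotated" in place.
k_rerotated = rope(k_at_old, float(p_new - p_old))
torch.testing.assert_close(k_rerotated, rope(k, float(p_new)), rtol=1e-5, atol=1e-5)
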
sliding window (sink_size=0).""" cache_size = 16 sink_size = 0 window_size = 8 cache_positions = torch.arange(cache_size, dtype=torch.long) start_pos = 10 seq_len = 1 mask = _create_causal_mask_for_attention_sink( cache_positions, window_size, sink_size, start_pos, seq_len ) # Positions 3-10 should be visible (within window of 8) for pos in range(3, 11): self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") # Positions 0-2 should be masked (outside window) for pos in range(0, 3): self.assertEqual(mask[0, pos].item(), float('-inf'), f"Position {pos} should be masked (outside window)")class PrefillTest(unittest.TestCase): """Test prefill scenarios.""" def setUp(self): torch.manual_seed(42) self.window_size = 14 self.sink_size = 4 self.n_heads = 8 self.head_dim = 64 self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=self.n_heads, n_kv_heads=self.n_heads, dim=self.n_heads * self.head_dim, ) self.rope = RopeWithAttentionSink( params=self.params, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, ) self.kv_cache = KVCacheWithAttentionSink( n_heads=self.n_heads, head_dim=self.head_dim, enable_dynamic_shape=True, rope=self.rope, max_batch_size=1, window_size=self.window_size, sink_size=self.sink_size, eviction_batch_size=1, dtype=torch.float32, ) def test_prefill_entire_cache(self): """Test prefill that fills entire cache.""" cache_size = self.kv_cache.k_cache.size(2) k = torch.randn(1, self.n_heads, cache_size, self.head_dim) v = torch.randn(1, self.n_heads, cache_size, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k, v) # All positions should be filled torch.testing.assert_close(k_out, k) torch.testing.assert_close(v_out, v) def test_prefill_larger_than_cache_raises_error(self): """Test that prefill larger than cache size raises an assertion error.""" cache_size = self.kv_cache.k_cache.size(2) seq_len = cache_size + 10 k = torch.randn(1, self.n_heads, seq_len, self.head_dim) v = torch.randn(1, self.n_heads, seq_len, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) # This should raise an assertion error since seq_len > cache_size with self.assertRaises(AssertionError): self.kv_cache.update(input_pos, k, v) def test_prefill_followed_by_decode(self): """Test prefill followed by decode steps.""" cache_size = self.kv_cache.k_cache.size(2) # Prefill with 20 tokens k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) input_pos = torch.tensor([0], dtype=torch.long) self.kv_cache.update(input_pos, k_prefill, v_prefill) # Decode 5 more tokens for i in range(5): k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) input_pos = torch.tensor([20 + i], dtype=torch.long) k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) # Verify cache positions are updated expected_pos = 20 + i cache_idx = expected_pos % cache_size self.assertEqual( self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), expected_pos )class EnableAttentionSinkTest(unittest.TestCase): """Test the enable_attention_sink transformation.""" def setUp(self): torch.manual_seed(42) self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=8, n_kv_heads=8, dim=512, n_layers=2, vocab_size=100, ) def test_enable_attention_sink_transforms_model(self): """Test that 
enable_attention_sink properly transforms the model.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) # Create a simple transformer with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") # Apply attention sink transformation model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that KV caches are replaced for layer in model.layers: kv_cache = layer.attention.kv_cache self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) self.assertEqual(kv_cache.sink_size, 4) self.assertEqual(kv_cache.window_size, 100) self.assertTrue(kv_cache.is_ring_buffer) def test_enable_attention_sink_replaces_rope(self): """Test that RoPE is replaced with RopeWithAttentionSink.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that rope is replaced for layer in model.layers: rope = layer.attention.rope self.assertIsInstance(rope, RopeWithAttentionSink)class IntegrationTest(unittest.TestCase): """Integration tests for end-to-end scenarios.""" def setUp(self): torch.manual_seed(42) def test_cache_positions_consistency(self): """Test that cache positions remain consistent during generation.""" cache_size = 32 sink_size = 4 window_size = 14 n_heads = 8 head_dim = 64 params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=n_heads, n_kv_heads=n_heads, dim=n_heads * head_dim, ) rope = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, ) kv_cache = KVCacheWithAttentionSink( n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=True, rope=rope, max_batch_size=1, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, dtype=torch.float32, ) # Generate 100 tokens for pos in range(100): k = torch.randn(1, n_heads, 1, head_dim) v = torch.randn(1, n_heads, 1, head_dim) input_pos = torch.tensor([pos], dtype=torch.long) kv_cache.update(input_pos, k, v) # Create mask and verify it's valid mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) # Mask should not be all -inf (would mean no tokens to attend to) non_inf_count = (mask != float('-inf')).sum().item() self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") # For positions >= sink_size, sinks should always be visible if pos >= sink_size: for i in range(sink_size): cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() if cache_pos < sink_size: # This is actually a sink token self.assertEqual(mask[0, i].item(), 0.0, f"Sink at idx {i} should be visible at pos {pos}")if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 11d77f321bf..7924d24082f 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -1,310 +1 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
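
The integration test above drives the cache well past its physical size, and the CachePositionsManagerWithSink tests earlier in this hunk pin down the bookkeeping that makes this safe: the slot index is the absolute position modulo the cache size, and a per-slot tensor records which absolute position currently occupies each slot. Below is a standalone sketch of that bookkeeping, consistent with those tests but deliberately not the class under test; the class name is ours.

# Sketch of the ring-buffer bookkeeping the tests describe: slots are chosen by
# absolute position modulo the cache size, and cache_positions records which
# absolute position each slot currently holds.
import torch


class RingBufferPositionsSketch:
    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.cache_positions = torch.zeros(cache_size, dtype=torch.long)

    def calculate_positions_and_update_indices(
        self, input_pos: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        positions = input_pos.item() + torch.arange(seq_len, dtype=torch.long)
        indices = positions % self.cache_size
        self.cache_positions[indices] = positions
        return indices


mgr = RingBufferPositionsSketch(32)
mgr.calculate_positions_and_update_indices(torch.tensor([0], dtype=torch.long), 30)
idx = mgr.calculate_positions_and_update_indices(torch.tensor([30], dtype=torch.long), 5)
assert idx.tolist() == [30, 31, 0, 1, 2]    # wraps around, as in test_wraparound
assert mgr.cache_positions[0].item() == 32  # slot 0 now holds absolute position 32
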
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated - */ - -// A simple llama2 runner that includes preprocessing and post processing logic. -// The module takes in a string as input and emits a string as output. - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace executorch::extension::llm { - -using ::executorch::extension::Module; -using ::executorch::runtime::Error; -using ::executorch::runtime::Result; - -TextLLMRunner::TextLLMRunner( - std::unordered_map metadata, - std::unique_ptr<::tokenizers::Tokenizer> tokenizer, - std::unique_ptr<::executorch::extension::Module> module, - std::unique_ptr text_decoder_runner, - std::unique_ptr text_prefiller, - std::unique_ptr io_manager, - std::unique_ptr text_token_generator, - std::unique_ptr stats, - float temperature) - : tokenizer_(std::move(tokenizer)), - metadata_(std::move(metadata)), - module_(std::move(module)), - text_decoder_runner_(std::move(text_decoder_runner)), - text_prefiller_(std::move(text_prefiller)), - io_manager_(std::move(io_manager)), - text_token_generator_(std::move(text_token_generator)), - stats_(std::move(stats)), - temperature_(temperature), - pos_(0) { - // Note: This constructor assumes that text_prefiller and text_token_generator - // already have references to the Module and TextDecoderRunner they need -} - -bool TextLLMRunner::is_loaded() const { - return text_prefiller_->is_loaded() && text_token_generator_->is_loaded(); -} - -Error TextLLMRunner::load() { - if (is_loaded()) { - return Error::Ok; - } - ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); - ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); - ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); - return Error::Ok; -} - -// Don't print with the same priority during warmup -#define RUNNER_ET_LOG(warmup, format, ...) \ - if (warmup) { \ - ET_LOG(Debug, format, __VA_ARGS__); \ - } else { \ - ET_LOG(Info, format, __VA_ARGS__); \ - } - -Error TextLLMRunner::generate( - const std::string& prompt, - const GenerationConfig& config, - std::function token_callback, - std::function stats_callback) { - // Prepare the inputs. - // Use ones-initialized inputs. - ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); - if (!is_loaded()) { - stats_->model_load_start_ms = time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(load()); - stats_->model_load_end_ms = time_in_ms(); - } - - if (config.warming) { - ET_LOG(Info, "Doing a warmup run..."); - } - - RUNNER_ET_LOG( - config.warming, - "RSS after loading model: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - // Wrap the token_callback with print function - std::function wrapped_callback = - [token_callback, config](const std::string& piece) { - if (!config.warming) { - llm::safe_printf(piece.c_str()); - fflush(stdout); - } - if (token_callback) { - token_callback(piece); - } - }; - // First token time only measures the time it takes to encode the prompt and - // return a response token. - - stats_->inference_start_ms = time_in_ms(); - shouldStop_ = false; - - ::tokenizers::Result> encode_res = tokenizer_->encode( - prompt, - /*bos=*/config.num_bos, - /*eos=*/config.num_eos); - - if (!encode_res.ok()) { - ET_LOG( - Error, - "Failed to encode prompt %s. 
Tokenizers error code %d", - prompt.c_str(), - static_cast(encode_res.error())); - return Error::InvalidArgument; - } - - // encode the (string) prompt into tokens sequence - std::vector prompt_tokens = encode_res.get(); - int num_prompt_tokens = prompt_tokens.size(); - - // Get max_seq_len for single prefill chunk limit - int64_t max_seq_len = metadata_.at(kMaxSeqLen); - int64_t max_context_len = metadata_.at(kMaxContextLen); - - ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens >= 1, - InvalidArgument, - "Expected at least 1 prompt token"); - - // For models with sliding window (Ring Buffer / Attention Sink), - // we allow pos_ to exceed max_context_len. The model handles this - // internally via ring buffer indexing or token eviction. - // We only check that a single prefill chunk doesn't exceed max_seq_len. - ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens <= max_seq_len, - InvalidArgument, - "num_prompt_tokens %d > max_seq_len %" PRId64 - ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", - num_prompt_tokens, - max_seq_len); - - // Determine max_new_tokens using the GenerationConfig's resolve method. - // For sliding window models, we use max_context_len directly (not reduced by pos_) - // because the model handles position wrapping internally via ring buffer. - int max_new_tokens = - config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); - - // TEMPORARY: For sliding window / infinite context testing, - // override max_new_tokens to allow unlimited generation - if (max_new_tokens <= 0) { - max_new_tokens = 1000000; // Effectively unlimited - } - // If user specified seq_len, use that instead - if (config.seq_len > 0 && config.seq_len > max_new_tokens) { - max_new_tokens = config.seq_len; - } - - ET_LOG( - Info, - "Max new tokens resolved: %d, given pos_ %" PRId64 - ", num_prompt_tokens %zu, max_context_len %" PRId64, - max_new_tokens, - pos_, - prompt_tokens.size(), - max_context_len); - ET_CHECK_OR_RETURN_ERROR( - max_new_tokens > 0, - InvalidArgument, - "Max new tokens %d is less than or equal to 0", - max_new_tokens); - // Prefill first - // Here feed all tokens to the model and get the next predicted token - // after the prompt. After that we will enter generate loop. - - // print prompts - if (config.echo) { - wrapped_callback(prompt); - } - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - uint64_t cur_token = prefill_res.get(); - stats_->first_token_ms = time_in_ms(); - stats_->prompt_eval_end_ms = time_in_ms(); - - // print the first token from prefill. No prev_token so use cur_token for it. - auto decode_result = tokenizer_->decode(cur_token, cur_token); - if (!decode_result.ok()) { - ET_LOG( - Error, - "Tokenizers error code %d", - static_cast(decode_result.error())); - return ::executorch::runtime::Error::InvalidArgument; - } - wrapped_callback(std::move(*decode_result)); - RUNNER_ET_LOG( - config.warming, - "RSS after prompt prefill: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - // start the main loop - prompt_tokens.push_back(cur_token); - - // Set ignore_eos based on config - text_token_generator_->set_ignore_eos(config.ignore_eos); - - // Generate max_new_tokens - 1 because prefill already generated 1 token. - auto generate_result = text_token_generator_->generate( - prompt_tokens, - pos_, - max_new_tokens - 1, - temperature_ == -1.0f ? 
config.temperature : temperature_, - wrapped_callback); - if (!generate_result.ok()) { - return generate_result.error(); - } - int64_t num_generated_tokens = generate_result.get(); - - pos_ += num_generated_tokens; - - stats_->inference_end_ms = time_in_ms(); - if (!config.warming) { - printf("\n"); - } - RUNNER_ET_LOG( - config.warming, - "RSS after finishing text generation: %f MiB (0 if unsupported)", - get_rss_bytes() / 1024.0 / 1024.0); - - if (num_generated_tokens == max_new_tokens) { - RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); - } - - stats_->num_prompt_tokens = num_prompt_tokens; - stats_->num_generated_tokens = num_generated_tokens; - - if (config.warming) { - ET_LOG(Info, "Warmup run finished!"); - } else { - // Do not print report during warmup - print_report(*stats_); - } - if (stats_callback) { - stats_callback(*stats_); - } - - return Error::Ok; -} - -Error TextLLMRunner::prefill( - const std::string& prompt, - const GenerationConfig& config) { - if (!is_loaded()) { - ET_CHECK_OK_OR_RETURN_ERROR(load()); - } - - ::tokenizers::Result> encode_res = tokenizer_->encode( - prompt, - /*bos=*/config.num_bos, - /*eos=*/config.num_eos); - - ET_CHECK_TK_OK_OR_RETURN_ERROR( - encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); - - // encode the (string) prompt into tokens sequence - std::vector prompt_tokens = encode_res.get(); - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - return Error::Ok; -} - -Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { - // Create a GenerationConfig for warmup - GenerationConfig config; - config.echo = false; - config.max_new_tokens = max_new_tokens; - config.warming = true; - - // Call generate with the warmup config - Error err = generate(prompt, config); - - // Reset stats after warmup, not resetting the std::unique_ptr! - reset(); - return err; -} - -void TextLLMRunner::stop() { - if (is_loaded()) { - text_token_generator_->stop(); - } else { - ET_LOG(Error, "Token generator is not loaded, cannot stop"); - } -} - -void TextLLMRunner::reset() { - stats_->reset(); - pos_ = 0; -} - -} // namespace executorch::extension::llm +/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
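
Summarizing the runner changes in the hunk above: only a single prefill chunk is still bounded by max_seq_len, max_new_tokens is resolved against the full max_context_len rather than the space remaining before it, a non-positive budget is treated as effectively unlimited (flagged TEMPORARY in the code), and an explicit user seq_len can raise the budget. The Python rendering below is a readability aid only; the shipped logic is the C++ above, and resolve_max_new_tokens is passed in as a callable because its internals are not part of this hunk.

# Readability aid: the generation-length policy from the C++ runner hunk,
# re-expressed in Python. resolve_max_new_tokens stands in for
# GenerationConfig::resolve_max_new_tokens, whose internals are not shown here.
from typing import Callable


def resolve_generation_budget(
    num_prompt_tokens: int,
    max_seq_len: int,
    max_context_len: int,
    resolve_max_new_tokens: Callable[[int, int], int],
    user_seq_len: int = 0,
) -> int:
    # A single prefill chunk must still fit within max_seq_len.
    if num_prompt_tokens > max_seq_len:
        raise ValueError(
            "Single prefill chunk too large - reduce prompt size or increase max_seq_len"
        )

    # The budget is resolved against the full max_context_len because the ring
    # buffer / attention sink handles position wraparound internally.
    max_new_tokens = resolve_max_new_tokens(max_context_len, num_prompt_tokens)

    # TEMPORARY behaviour in this patch: a non-positive budget means
    # "effectively unlimited" for the sliding-window experiments.
    if max_new_tokens <= 0:
        max_new_tokens = 1_000_000

    # An explicit user-provided seq_len can raise the budget further.
    if user_seq_len > 0 and user_seq_len > max_new_tokens:
        max_new_tokens = user_seq_len

    return max_new_tokens
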
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */// A simple llama2 runner that includes preprocessing and post processing logic.// The module takes in a string as input and emits a string as output.#include #include #include #include #include #include #include #include namespace executorch::extension::llm {using ::executorch::extension::Module;using ::executorch::runtime::Error;using ::executorch::runtime::Result;TextLLMRunner::TextLLMRunner( std::unordered_map metadata, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::unique_ptr<::executorch::extension::Module> module, std::unique_ptr text_decoder_runner, std::unique_ptr text_prefiller, std::unique_ptr io_manager, std::unique_ptr text_token_generator, std::unique_ptr stats, float temperature) : tokenizer_(std::move(tokenizer)), metadata_(std::move(metadata)), module_(std::move(module)), text_decoder_runner_(std::move(text_decoder_runner)), text_prefiller_(std::move(text_prefiller)), io_manager_(std::move(io_manager)), text_token_generator_(std::move(text_token_generator)), stats_(std::move(stats)), temperature_(temperature), pos_(0) { // Note: This constructor assumes that text_prefiller and text_token_generator // already have references to the Module and TextDecoderRunner they need}bool TextLLMRunner::is_loaded() const { return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();}Error TextLLMRunner::load() { if (is_loaded()) { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok;}// Don't print with the same priority during warmup#define RUNNER_ET_LOG(warmup, format, ...) \ if (warmup) { \ ET_LOG(Debug, format, __VA_ARGS__); \ } else { \ ET_LOG(Info, format, __VA_ARGS__); \ }Error TextLLMRunner::generate( const std::string& prompt, const GenerationConfig& config, std::function token_callback, std::function stats_callback) { // Prepare the inputs. // Use ones-initialized inputs. ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); if (!is_loaded()) { stats_->model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); stats_->model_load_end_ms = time_in_ms(); } if (config.warming) { ET_LOG(Info, "Doing a warmup run..."); } RUNNER_ET_LOG( config.warming, "RSS after loading model: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // Wrap the token_callback with print function std::function wrapped_callback = [token_callback, config](const std::string& piece) { if (!config.warming) { llm::safe_printf(piece.c_str()); fflush(stdout); } if (token_callback) { token_callback(piece); } }; // First token time only measures the time it takes to encode the prompt and // return a response token. stats_->inference_start_ms = time_in_ms(); shouldStop_ = false; ::tokenizers::Result> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); if (!encode_res.ok()) { ET_LOG( Error, "Failed to encode prompt %s. 
Tokenizers error code %d", prompt.c_str(), static_cast(encode_res.error())); return Error::InvalidArgument; } // encode the (string) prompt into tokens sequence std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); // Get max_seq_len for single prefill chunk limit int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); // For models with sliding window (Ring Buffer / Attention Sink), // we allow pos_ to exceed max_context_len. The model handles this // internally via ring buffer indexing or token eviction. // We only check that a single prefill chunk doesn't exceed max_seq_len. ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens <= max_seq_len, InvalidArgument, "num_prompt_tokens %d > max_seq_len %" PRId64 ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, max_seq_len); // Determine max_new_tokens using the GenerationConfig's resolve method. // For sliding window models, we use max_context_len directly (not reduced by pos_) // because the model handles position wrapping internally via ring buffer. int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); // TEMPORARY: For sliding window / infinite context testing, // override max_new_tokens to allow unlimited generation if (max_new_tokens <= 0) { max_new_tokens = 1000000; // Effectively unlimited } // If user specified seq_len, use that instead if (config.seq_len > 0 && config.seq_len > max_new_tokens) { max_new_tokens = config.seq_len; } ET_LOG( Info, "Max new tokens resolved: %d, given pos_ %" PRId64 ", num_prompt_tokens %zu, max_context_len %" PRId64, max_new_tokens, pos_, prompt_tokens.size(), max_context_len); ET_CHECK_OR_RETURN_ERROR( max_new_tokens > 0, InvalidArgument, "Max new tokens %d is less than or equal to 0", max_new_tokens); // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. // print prompts if (config.echo) { wrapped_callback(prompt); } auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); stats_->first_token_ms = time_in_ms(); stats_->prompt_eval_end_ms = time_in_ms(); // print the first token from prefill. No prev_token so use cur_token for it. auto decode_result = tokenizer_->decode(cur_token, cur_token); if (!decode_result.ok()) { ET_LOG( Error, "Tokenizers error code %d", static_cast(decode_result.error())); return ::executorch::runtime::Error::InvalidArgument; } wrapped_callback(std::move(*decode_result)); RUNNER_ET_LOG( config.warming, "RSS after prompt prefill: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); // start the main loop prompt_tokens.push_back(cur_token); // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, temperature_ == -1.0f ? 
config.temperature : temperature_, wrapped_callback); if (!generate_result.ok()) { return generate_result.error(); } int64_t num_generated_tokens = generate_result.get(); pos_ += num_generated_tokens; stats_->inference_end_ms = time_in_ms(); if (!config.warming) { printf("\n"); } RUNNER_ET_LOG( config.warming, "RSS after finishing text generation: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); if (num_generated_tokens == max_new_tokens) { RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } stats_->num_prompt_tokens = num_prompt_tokens; stats_->num_generated_tokens = num_generated_tokens; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); } else { // Do not print report during warmup print_report(*stats_); } if (stats_callback) { stats_callback(*stats_); } return Error::Ok;}Error TextLLMRunner::prefill( const std::string& prompt, const GenerationConfig& config) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } ::tokenizers::Result> encode_res = tokenizer_->encode( prompt, /*bos=*/config.num_bos, /*eos=*/config.num_eos); ET_CHECK_TK_OK_OR_RETURN_ERROR( encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); // encode the (string) prompt into tokens sequence std::vector prompt_tokens = encode_res.get(); auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); return Error::Ok;}Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { // Create a GenerationConfig for warmup GenerationConfig config; config.echo = false; config.max_new_tokens = max_new_tokens; config.warming = true; // Call generate with the warmup config Error err = generate(prompt, config); // Reset stats after warmup, not resetting the std::unique_ptr! reset(); return err;}void TextLLMRunner::stop() { if (is_loaded()) { text_token_generator_->stop(); } else { ET_LOG(Error, "Token generator is not loaded, cannot stop"); }}void TextLLMRunner::reset() { stats_->reset(); pos_ = 0;}} // namespace executorch::extension::llm \ No newline at end of file From fe66e7443b0778c4b5bff3a531601262a796b6b7 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:11:02 -0800 Subject: [PATCH 06/11] Revert "Remove trailing whitespace" This reverts commit 075d521b2960945aedd4f4a908b3f69107c5f18d. --- examples/models/llama/eval_llama_lib.py | 404 +++++++++++- .../source_transformation/attention_sink.py | 411 +++++++++++- .../test_attention_sink_ring_buffer.py | 595 +++++++++++++++++- extension/llm/runner/text_llm_runner.cpp | 311 ++++++++- 4 files changed, 1717 insertions(+), 4 deletions(-) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index d8e107f74ef..e13a5299e61 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -1 +1,403 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.import argparsefrom typing import Optional, Unionimport torchfrom datasets import load_datasetfrom executorch.examples.models.llama.export_llama_lib import ( get_quantizer_and_quant_params,)from executorch.extension.llm.export.builder import LLMEdgeManagerfrom lm_eval.evaluator import simple_evaluatefrom pytorch_tokenizers import get_tokenizerfrom pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizerfrom pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktokenfrom torch.nn import CrossEntropyLossfrom tqdm import tqdmfrom .evaluate.eager_eval import EagerEvalWrapperfrom .export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser,)class GraphModuleEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: torch.fx.GraphModule, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( model=model, tokenizer=tokenizer, max_seq_length=max_seq_length ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): if self._use_kv_cache: if not self._enable_dynamic_shape: # graph module exported without dynamic shape won't work with a different shape. # And we have to do single token prefill here. result_logits = [] for pos in range(inps.shape[-1]): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model( inps[:, pos : pos + 1], {"input_pos": pos_tensor} ) result_logits.append(logits) if self._generate_full_logits: return torch.cat(result_logits, dim=1) else: return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. logits = self._model( inps[:, : self._max_seq_length], {"input_pos": pos_tensor} ) return logits else: return self._model(inps) def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented")class ETPybindEvalWrapper(EagerEvalWrapper): """ A wrapper class for ExecuTorch py-binded integration with the lm-evaluation-harness library. """ def __init__( self, model: str, tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): super().__init__(None, tokenizer, max_seq_length) # pyre-ignore self._model = model # Expects model to be path to a .pte file from executorch.extension.pybindings.portable_lib import _load_for_executorch # Load custom ops and quantized ops. 
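
The ETPybindEvalWrapper in this hunk evaluates an exported .pte through the ExecuTorch Python bindings. Reduced to its essentials, the call pattern it uses looks like the sketch below; the model path and token tensor are placeholders, and, as in the wrapper, the custom and quantized op libraries are imported after portable_lib so their kernels are registered before the program is loaded.

# Call pattern used by ETPybindEvalWrapper, reduced to the essentials.
# "model.pte" and the token tensor are placeholders.
import torch

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Import these after portable_lib so custom and quantized kernels are registered.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
from executorch.kernels import quantized  # noqa: F401

et_model = _load_for_executorch("model.pte")
use_kv_cache = et_model.run_method("use_kv_cache")[0]

tokens = torch.zeros(1, 8, dtype=torch.int64)  # (1, seq_len) token ids, placeholder
if use_kv_cache:
    input_pos = torch.tensor([0], dtype=torch.int64)
    logits = et_model.forward((tokens, input_pos))[0]
else:
    logits = et_model.forward((tokens,))[0]
# The eval wrapper expects logits of shape (1, seq_len, vocab_size), i.e. a
# model exported with --generate_full_logits.
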
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
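+
+# This module wires the exported Llama models (.pte via pybindings or the ET runner)
+# and the eager/graph-module models into lm-evaluation-harness, and adds
+# eval_llama_with_attention_sink, a streaming wikitext perplexity loop for the
+# ring-buffer based attention sink path.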
+ + +import argparse + +from typing import Optional, Union + +import torch + +from datasets import load_dataset +from executorch.examples.models.llama.export_llama_lib import ( + get_quantizer_and_quant_params, +) + +from executorch.extension.llm.export.builder import LLMEdgeManager +from lm_eval.evaluator import simple_evaluate +from pytorch_tokenizers import get_tokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer +from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken +from torch.nn import CrossEntropyLoss +from tqdm import tqdm + +from .evaluate.eager_eval import EagerEvalWrapper + +from .export_llama_lib import ( + _prepare_for_llama_export, + build_args_parser as _build_args_parser, +) + + +class GraphModuleEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch py-binded integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: torch.fx.GraphModule, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + max_seq_length: Optional[int] = None, + use_kv_cache: bool = False, + generate_full_logits: bool = False, + enable_dynamic_shape: bool = True, + ): + super().__init__( + model=model, tokenizer=tokenizer, max_seq_length=max_seq_length + ) + self._model = model.to(self.device) + self._use_kv_cache = use_kv_cache + self._generate_full_logits = generate_full_logits + self._enable_dynamic_shape = enable_dynamic_shape + + def _model_call(self, inps): + if self._use_kv_cache: + if not self._enable_dynamic_shape: + # graph module exported without dynamic shape won't work with a different shape. + # And we have to do single token prefill here. + result_logits = [] + for pos in range(inps.shape[-1]): + pos_tensor = torch.tensor([pos], dtype=torch.int64) + logits = self._model( + inps[:, pos : pos + 1], {"input_pos": pos_tensor} + ) + result_logits.append(logits) + if self._generate_full_logits: + return torch.cat(result_logits, dim=1) + else: + return torch.stack(result_logits, dim=1) + else: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + # Batch process the whole sequence. + logits = self._model( + inps[:, : self._max_seq_length], {"input_pos": pos_tensor} + ) + return logits + + else: + return self._model(inps) + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + + +class ETPybindEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch py-binded integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: str, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + max_seq_length: Optional[int] = None, + ): + super().__init__(None, tokenizer, max_seq_length) # pyre-ignore + self._model = model # Expects model to be path to a .pte file + + from executorch.extension.pybindings.portable_lib import _load_for_executorch + + # Load custom ops and quantized ops. 
+ from executorch.extension.pybindings import portable_lib # noqa # usort: skip + + # Note: import this after portable_lib + from executorch.extension.llm.custom_ops import ( # noqa + custom_ops, # usort: skip + ) + from executorch.kernels import quantized # noqa + + self._et_model = _load_for_executorch(self._model) + self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore + + def _model_call(self, inps): + # Given inps (tokens), return the logits from a single forward call + # inps: Tensor of shape (1, max_seq_len - 1) + # logits: Tensor of shape (1, max_seq_len - 1, vocab_size) + result = [] + if self._use_kv_cache: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + result = self._et_model.forward( + (inps[:, : self._max_seq_length], pos_tensor) + ) + else: + result = self._et_model.forward((inps,)) + if result[0].dim() != 3: + raise ValueError( + f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits." + ) + return result[0] + + +class ETRunnerEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch Runtime integration with the + lm-evaluation-harness library. + """ + + def __init__( + self, + model: str, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + tokenizer_bin: str, + max_seq_length: Optional[int] = None, + ): + super().__init__(None, tokenizer, max_seq_length) # pyre-ignore + self._model = model + self._tokenizer_bin = tokenizer_bin + + def _model_call(self, inps): + # Given inps (tokens), return the logits from a single + # forward call + + # Example: + # inps: Tensor of shape (1, N) + # logits: Tensor of shape (1, N, vocab_size) + pass + + +def gen_eval_wrapper( + model_name: str, + args: argparse.ArgumentParser, + llm_config=None, +): + """ + Generates a wrapper interface around the provided model and tokenizer for + the lm-evaluation-harness library. + + Returns: + eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. + """ + # If llm_config is not provided, convert args to llm_config + if llm_config is None: + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) + + # ExecuTorch Binary Evaluation + if (model := args.pte) is not None: # pyre-ignore + if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore + # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime + return ETRunnerEvalWrapper( + model=model, + tokenizer=tokenizer, + tokenizer_bin=tokenizer_bin, + max_seq_length=llm_config.export.max_seq_length, + ) + + # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings + return ETPybindEvalWrapper( + model=model, + tokenizer=tokenizer, + # Exported model takes at most (max_seq_length - 1) tokens. + # Note that the eager model takes at most max_seq_length tokens. 
+ max_seq_length=llm_config.export.max_seq_length - 1, + ) + + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) + # GPTFastEvalWrapper: Create a wrapper around a pre-exported model + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) + + if len(quantizers) != 0: + manager = manager.export().pt2e_quantize(quantizers) + model = ( + manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore + if torch.cuda.is_available() + else manager.pre_autograd_graph_module.to(device="cpu") + ) + return GraphModuleEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + ) + else: + # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch + # for quantizers. Currently export only works with --kv_cache, but + # fails without the kv_cache mode + model = ( + manager.model.eval().to(device="cuda") + if torch.cuda.is_available() + else manager.model.eval().to(device="cpu") + ) + + # Save the checkpoint after the eager model preparation is done. + # The reason for this option is that the checkpoint can be used + # to do evaluations in other evaluation platforms, or with data + # that is not available in this eval_llama. We save the checkpoint + # here for consistency with eval_llama. The accuracy results we + # get from eval_llama can be used as a reference to other evaluations. + if args.output_eager_checkpoint_file is not None: # pyre-ignore + torch.save(model, args.output_eager_checkpoint_file) + + return EagerEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, + ) + + +def build_args_parser() -> argparse.ArgumentParser: + # Start with arg parser from export_llama_lib + parser = _build_args_parser() + + # Add additional args specific to eval + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="number of samples to evalulate. If not set, evaluate all samples", + ) + parser.add_argument( + "-f", + "--num_fewshot", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", + ) + # Add additional args specific to eval via an ET Runner + # Note: For initial integration, the tokenizer.model is also required + parser.add_argument( + "--pte", + type=str, + default=None, + help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", + ) + parser.add_argument( + "--tokenizer_bin", + type=str, + default=None, + help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime", + ) + parser.add_argument( + "--output_eager_checkpoint_file", + type=str, + default=None, + help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.", + ) + + # Set of parameters secpific to AttentionSink. 
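+    # --attention_sink_eval_tokens is the number of wikitext tokens streamed through
+    # the model by eval_llama_with_attention_sink below; perplexity is reported as
+    # exp of the mean per-token negative log-likelihood over exactly that many tokens.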
+ parser.add_argument("--attention_sink_eval_tokens", type=int, default=0) + + return parser + + +def eval_llama( + model_name: str, + args: argparse.ArgumentParser, +) -> None: + # Convert args to LlmConfig + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + # Generate the eval wrapper + eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) + + # Needed for loading mmlu dataset. + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files + if args.tasks and "mmlu" in args.tasks: + import datasets + + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=args.tasks, + num_fewshot=args.num_fewshot, + limit=args.limit, + ) + + for task, res in eval_results["results"].items(): + print(f"{task}: {res}") + + +def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser): + """ + Evaluate the model's perplexity when AttentionSink is enabled. + + This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + + Updated for the ring-buffer based attention sink implementation. + """ + # Convert args to LlmConfig + from executorch.extension.llm.export.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + assert llm_config.model.use_attention_sink is not None + assert args.attention_sink_eval_tokens > 0 + attention_sink_params = llm_config.model.use_attention_sink.split(",") + assert len(attention_sink_params) == 3 + sink_size = int(attention_sink_params[0]) + window_size = int(attention_sink_params[1]) + + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) + model = manager.model.eval().to(device=device) + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) + + eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + nlls = [] + loss_fn = CrossEntropyLoss(reduction="none") + progress_bar = tqdm(total=args.attention_sink_eval_tokens) + input_pos = 0 + while input_pos < args.attention_sink_eval_tokens: + for text in eval_data["text"]: + tokens = tokenizer.encode(text, bos=False, eos=False) + if len(tokens) <= 0: + continue + with torch.no_grad(): + num_tokens = min( + len(tokens) - 1, args.attention_sink_eval_tokens - input_pos + ) + logits = model( + torch.tensor( + [tokens[:num_tokens]], dtype=torch.int64, device=device + ), + torch.tensor([input_pos], dtype=torch.int64, device=device), + ).squeeze(dim=0) + neg_log_likelihood = loss_fn( + logits, + torch.tensor( + [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device + ).view(-1), + ) + nlls.append(neg_log_likelihood) + input_pos += num_tokens + progress_bar.update(num_tokens) + if input_pos >= args.attention_sink_eval_tokens: + break + ppl = torch.exp(torch.cat(nlls).mean()) + print(f"Perplexity: {ppl.item()}") + return ppl.item() diff --git a/examples/models/llama/source_transformation/attention_sink.py 
b/examples/models/llama/source_transformation/attention_sink.py
index 99b0231f2c3..c2e0f6606f5
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -1 +1,410 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Components for supporting Attention Sink. See
+# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
+
+# This implementation is torch.export compatible using a ring buffer approach
+# for the sliding window portion while preserving the sink tokens.
+
+import types
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    AttentionMHA,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import (
+    apply_rotary_emb_to_k,
+    hf_apply_rotary_emb_to_k,
+    Rope,
+)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+
+class RopeWithAttentionSink(Rope):
+    """
+    Rope that helps adjust position encoding when tokens are shifted in KVCache.
+    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
+    in KVCache instead of positions in the actual text.
+
+    For torch.export compatibility, this just passes through the position - the
+    actual position adjustment is handled by the cache update logic.
+
+    Note: This class uses the model's max_context_len (params.max_context_len) for
+    RoPE frequency table size, which should be large enough to support generation
+    beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
+    """
+
+    def __init__(
+        self,
+        params: ModelArgs,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+    ):
+        super().__init__(params)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
+        else:
+            self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
+        # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x)
+        self.kv_cache_size = sink_size + window_size * 2
+        self.window_size = window_size
+        self.sink_size = sink_size
+        # max_context_len from params is used for RoPE frequencies (should be large)
+        self.max_context_length = self.params.max_context_len
+        self.eviction_batch_size = eviction_batch_size
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get rotary embedding frequencies.
+        For attention sink, we use the original position - the sliding window
+        is handled by the cache index management, not by position shifting.
+        """
+        assert input_pos is not None
+        return super().get_freqs(input_pos, seq_len)
+
+    def rerotate_k(
+        self,
+        k: torch.Tensor,
+        original_position: int,
+        new_position: int,
+    ):
+        """
+        Rerotate k from original_position to new_position.
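+        The delta rotation is built from the cos/sin angle-difference identities,
+        so the result equals rotating the un-rotated k directly at new_position.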
+ The shape of k is (batch_size, seq_len, n_local_heads, head_dim) + """ + seq_len = k.shape[1] + original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) + original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len) + new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len) + new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len) + rerotation_cos = ( + new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin + ) + rerotation_sin = ( + new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin + ) + + return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) + + +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + is_in_window = delta < window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + + For sink_size=0: behaves exactly like original CachePositionsManager. + For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + + IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). + """ + + def __init__(self, cache_size: int): + super().__init__() + # cache_size is the actual size of the kv cache dimension + self.max_context_length = cache_size + # Use zeros like original CachePositionsManager + self.register_buffer( + "cache_positions", + torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), + ) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + This is identical to the original CachePositionsManager logic. 
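+        The slot for original position p is p % cache_size, and cache_positions[i]
+        records which original position currently occupies slot i (-1 if unfilled).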
+ """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length + + # Update cache_positions exactly like original CachePositionsManager + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, full_t + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + +class KVCacheWithAttentionSink(KVCache): + """ + KV cache that supports attention sink with torch.export compatibility. + + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. + + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] + """ + + def __init__( + self, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + rope: RopeWithAttentionSink, + window_size: int, + sink_size: int, + eviction_batch_size: int, + max_batch_size: int = 1, + dtype=torch.float32, + ): + # Total cache size is sink_size + window_size * 2 (ring buffer needs 2x) + total_cache_size = sink_size + window_size * 2 + super().__init__( + max_batch_size=max_batch_size, + max_context_length=total_cache_size, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + dtype=dtype, + ) + self.rope = rope + self.window_size = window_size + self.sink_size = sink_size + self.eviction_batch_size = eviction_batch_size + self.is_ring_buffer = True + + # Cache positions manager for determining write locations + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) + + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): + """ + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. + """ + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. 
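+        The returned tensors are the full caches (sink_size + window_size * 2 slots);
+        which entries a query may attend to is decided by the mask, not by eviction.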
+ """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + + return self.k_cache, self.v_cache + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 + + +def attention_sink_forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, +): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. + """ + assert self.use_kv_cache + assert input_pos is not None + + bsz, seqlen, _ = x.shape + + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # RoPE relative positional embeddings + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Transpose for attention: [B, H, S, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update KV cache + k, v = self.kv_cache.update(input_pos, k, v) + + # Use ring buffer mask since we have is_ring_buffer=True + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) + + # SDPA + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + + # Return tuple like original AttentionMHA.forward + return self.wo(output), None + + +def _replace_rope( + module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + return isinstance(child, Rope) + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + return rope_with_attention_sink + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def _replace_attention( + module: torch.nn.Module, + rope_with_attention_sink: RopeWithAttentionSink, + sink_size: int, + window_size: int, + eviction_batch_size: int, +): + for _, child_module in module._modules.items(): + if len(list(child_module.children())) > 0: # pyre-ignore [16] + _replace_attention( + module=child_module, # pyre-ignore [6] + rope_with_attention_sink=rope_with_attention_sink, + sink_size=sink_size, + window_size=window_size, + eviction_batch_size=eviction_batch_size, + ) + + if isinstance(child_module, AttentionMHA): + kv_cache = child_module.kv_cache + # Always use KVCacheWithAttentionSink, even for sink_size=0 + # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True + kv_cache_with_attention_sink = KVCacheWithAttentionSink( + n_heads=kv_cache.n_heads, + head_dim=kv_cache.head_dim, + enable_dynamic_shape=kv_cache.enable_dynamic_shape, + rope=rope_with_attention_sink, + max_batch_size=kv_cache.max_batch_size, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + 
dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+
+            # If using SDPACustom (fused SDPA op), enable attention mask support
+            # so it uses our ring buffer / attention sink mask instead of simple causal mask
+            if "SDPACustom" in child_module.SDPA.__class__.__name__:
+                child_module.SDPA.use_attention_mask = True
+
+            # Don't replace forward - let the original AttentionMHA.forward handle it
+            # since our KVCache has is_ring_buffer=True, it will use the ring buffer mask
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model so it can run inference with Attention Sink.
+    There are mainly two steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace Attention's KVCache with KVCacheWithAttentionSink
+    The original AttentionMHA.forward is kept; it picks up the attention sink
+    mask through the cache's is_ring_buffer flag.
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module
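The masking rule implemented by `_create_causal_mask_for_attention_sink` above is easiest to see with toy sizes. A minimal sketch (not part of the patch), assuming the executorch package is importable as in the tests that follow:

```python
import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    _create_causal_mask_for_attention_sink,
)

# 10 cache slots holding original positions 0..9 (no wraparound yet),
# 2 sink tokens, a sliding window of 3, and a single query at position 9.
cache_positions = torch.arange(10, dtype=torch.long)
mask = _create_causal_mask_for_attention_sink(
    cache_positions, window_size=3, sink_size=2, start_pos=9, seq_len=1
)

# Entries 0 and 1 (the sinks) and entries 7, 8, 9 (inside the window) are 0.0,
# i.e. visible; every other entry is -inf and is ignored by SDPA.
print(mask)
```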
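The ring-buffer bookkeeping in `CachePositionsManagerWithSink` can likewise be exercised directly. A small sketch (again assuming the import path used by the tests below):

```python
import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    CachePositionsManagerWithSink,
)

# 8 slots: original position p is written to slot p % 8, and cache_positions
# remembers which original position currently lives in each slot.
manager = CachePositionsManagerWithSink(cache_size=8)

idx = manager.calculate_positions_and_update_indices(
    torch.tensor([0], dtype=torch.long), seq_len=6
)
print(idx.tolist())  # [0, 1, 2, 3, 4, 5]

idx = manager.calculate_positions_and_update_indices(
    torch.tensor([6], dtype=torch.long), seq_len=4
)
print(idx.tolist())  # [6, 7, 0, 1] -- positions 8 and 9 wrap onto slots 0 and 1
print(manager.cache_positions.tolist())  # [8, 9, 2, 3, 4, 5, 6, 7]
```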
enable_attention_sink properly transforms the model.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) # Create a simple transformer with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") # Apply attention sink transformation model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that KV caches are replaced for layer in model.layers: kv_cache = layer.attention.kv_cache self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) self.assertEqual(kv_cache.sink_size, 4) self.assertEqual(kv_cache.window_size, 100) self.assertTrue(kv_cache.is_ring_buffer) def test_enable_attention_sink_replaces_rope(self): """Test that RoPE is replaced with RopeWithAttentionSink.""" from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.source_transformation.attention_sink import ( enable_attention_sink, ) with torch.device("meta"): model = construct_transformer(self.params) model.to_empty(device="cpu") model = enable_attention_sink( module=model, params=self.params, sink_size=4, window_size=100, eviction_batch_size=1, ) # Check that rope is replaced for layer in model.layers: rope = layer.attention.rope self.assertIsInstance(rope, RopeWithAttentionSink)class IntegrationTest(unittest.TestCase): """Integration tests for end-to-end scenarios.""" def setUp(self): torch.manual_seed(42) def test_cache_positions_consistency(self): """Test that cache positions remain consistent during generation.""" cache_size = 32 sink_size = 4 window_size = 14 n_heads = 8 head_dim = 64 params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, max_context_len=1024, n_heads=n_heads, n_kv_heads=n_heads, dim=n_heads * head_dim, ) rope = RopeWithAttentionSink( params=params, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, ) kv_cache = KVCacheWithAttentionSink( n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=True, rope=rope, max_batch_size=1, window_size=window_size, sink_size=sink_size, eviction_batch_size=1, dtype=torch.float32, ) # Generate 100 tokens for pos in range(100): k = torch.randn(1, n_heads, 1, head_dim) v = torch.randn(1, n_heads, 1, head_dim) input_pos = torch.tensor([pos], dtype=torch.long) kv_cache.update(input_pos, k, v) # Create mask and verify it's valid mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) # Mask should not be all -inf (would mean no tokens to attend to) non_inf_count = (mask != float('-inf')).sum().item() self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") # For positions >= sink_size, sinks should always be visible if pos >= sink_size: for i in range(sink_size): cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() if cache_pos < sink_size: # This is actually a sink token self.assertEqual(mask[0, i].item(), 0.0, f"Sink at idx {i} should be visible at pos {pos}")if __name__ == '__main__': unittest.main() \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the ring-buffer based attention sink implementation. 
+ +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + self.manager = CachePositionsManagerWithSink(self.cache_size) + + def test_initial_positions_are_zero(self): + """Cache positions should start as zeros.""" + expected = torch.zeros(self.cache_size, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound(self): + """Test ring buffer wraparound.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions(self): + """Cache positions should track which original position is at each index.""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + 
# Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + 
self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + 
visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out 
= self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size = 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + 
non_inf_count = (mask != float('-inf')).sum().item()
+            self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens")
+
+            # For positions >= sink_size, sinks should always be visible
+            if pos >= sink_size:
+                for i in range(sink_size):
+                    cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item()
+                    if cache_pos < sink_size:  # This is actually a sink token
+                        self.assertEqual(mask[0, i].item(), 0.0,
+                                         f"Sink at idx {i} should be visible at pos {pos}")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 7924d24082f..11d77f321bf 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -1 +1,310 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+// A simple llama2 runner that includes preprocessing and post processing logic.
+// The module takes in a string as input and emits a string as output.
+ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +TextLLMRunner::TextLLMRunner( + std::unordered_map metadata, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unique_ptr<::executorch::extension::Module> module, + std::unique_ptr text_decoder_runner, + std::unique_ptr text_prefiller, + std::unique_ptr io_manager, + std::unique_ptr text_token_generator, + std::unique_ptr stats, + float temperature) + : tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + module_(std::move(module)), + text_decoder_runner_(std::move(text_decoder_runner)), + text_prefiller_(std::move(text_prefiller)), + io_manager_(std::move(io_manager)), + text_token_generator_(std::move(text_token_generator)), + stats_(std::move(stats)), + temperature_(temperature), + pos_(0) { + // Note: This constructor assumes that text_prefiller and text_token_generator + // already have references to the Module and TextDecoderRunner they need +} + +bool TextLLMRunner::is_loaded() const { + return text_prefiller_->is_loaded() && text_token_generator_->is_loaded(); +} + +Error TextLLMRunner::load() { + if (is_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); + return Error::Ok; +} + +// Don't print with the same priority during warmup +#define RUNNER_ET_LOG(warmup, format, ...) \ + if (warmup) { \ + ET_LOG(Debug, format, __VA_ARGS__); \ + } else { \ + ET_LOG(Info, format, __VA_ARGS__); \ + } + +Error TextLLMRunner::generate( + const std::string& prompt, + const GenerationConfig& config, + std::function token_callback, + std::function stats_callback) { + // Prepare the inputs. + // Use ones-initialized inputs. + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + if (!is_loaded()) { + stats_->model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + stats_->model_load_end_ms = time_in_ms(); + } + + if (config.warming) { + ET_LOG(Info, "Doing a warmup run..."); + } + + RUNNER_ET_LOG( + config.warming, + "RSS after loading model: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + // Wrap the token_callback with print function + std::function wrapped_callback = + [token_callback, config](const std::string& piece) { + if (!config.warming) { + llm::safe_printf(piece.c_str()); + fflush(stdout); + } + if (token_callback) { + token_callback(piece); + } + }; + // First token time only measures the time it takes to encode the prompt and + // return a response token. + + stats_->inference_start_ms = time_in_ms(); + shouldStop_ = false; + + ::tokenizers::Result> encode_res = tokenizer_->encode( + prompt, + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); + + if (!encode_res.ok()) { + ET_LOG( + Error, + "Failed to encode prompt %s. 
Tokenizers error code %d", + prompt.c_str(), + static_cast(encode_res.error())); + return Error::InvalidArgument; + } + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens >= 1, + InvalidArgument, + "Expected at least 1 prompt token"); + + // For models with sliding window (Ring Buffer / Attention Sink), + // we allow pos_ to exceed max_context_len. The model handles this + // internally via ring buffer indexing or token eviction. + // We only check that a single prefill chunk doesn't exceed max_seq_len. + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens <= max_seq_len, + InvalidArgument, + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", + num_prompt_tokens, + max_seq_len); + + // Determine max_new_tokens using the GenerationConfig's resolve method. + // For sliding window models, we use max_context_len directly (not reduced by pos_) + // because the model handles position wrapping internally via ring buffer. + int max_new_tokens = + config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); + + // TEMPORARY: For sliding window / infinite context testing, + // override max_new_tokens to allow unlimited generation + if (max_new_tokens <= 0) { + max_new_tokens = 1000000; // Effectively unlimited + } + // If user specified seq_len, use that instead + if (config.seq_len > 0 && config.seq_len > max_new_tokens) { + max_new_tokens = config.seq_len; + } + + ET_LOG( + Info, + "Max new tokens resolved: %d, given pos_ %" PRId64 + ", num_prompt_tokens %zu, max_context_len %" PRId64, + max_new_tokens, + pos_, + prompt_tokens.size(), + max_context_len); + ET_CHECK_OR_RETURN_ERROR( + max_new_tokens > 0, + InvalidArgument, + "Max new tokens %d is less than or equal to 0", + max_new_tokens); + // Prefill first + // Here feed all tokens to the model and get the next predicted token + // after the prompt. After that we will enter generate loop. + + // print prompts + if (config.echo) { + wrapped_callback(prompt); + } + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + uint64_t cur_token = prefill_res.get(); + stats_->first_token_ms = time_in_ms(); + stats_->prompt_eval_end_ms = time_in_ms(); + + // print the first token from prefill. No prev_token so use cur_token for it. + auto decode_result = tokenizer_->decode(cur_token, cur_token); + if (!decode_result.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(decode_result.error())); + return ::executorch::runtime::Error::InvalidArgument; + } + wrapped_callback(std::move(*decode_result)); + RUNNER_ET_LOG( + config.warming, + "RSS after prompt prefill: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + // start the main loop + prompt_tokens.push_back(cur_token); + + // Set ignore_eos based on config + text_token_generator_->set_ignore_eos(config.ignore_eos); + + // Generate max_new_tokens - 1 because prefill already generated 1 token. + auto generate_result = text_token_generator_->generate( + prompt_tokens, + pos_, + max_new_tokens - 1, + temperature_ == -1.0f ? 
config.temperature : temperature_, + wrapped_callback); + if (!generate_result.ok()) { + return generate_result.error(); + } + int64_t num_generated_tokens = generate_result.get(); + + pos_ += num_generated_tokens; + + stats_->inference_end_ms = time_in_ms(); + if (!config.warming) { + printf("\n"); + } + RUNNER_ET_LOG( + config.warming, + "RSS after finishing text generation: %f MiB (0 if unsupported)", + get_rss_bytes() / 1024.0 / 1024.0); + + if (num_generated_tokens == max_new_tokens) { + RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); + } + + stats_->num_prompt_tokens = num_prompt_tokens; + stats_->num_generated_tokens = num_generated_tokens; + + if (config.warming) { + ET_LOG(Info, "Warmup run finished!"); + } else { + // Do not print report during warmup + print_report(*stats_); + } + if (stats_callback) { + stats_callback(*stats_); + } + + return Error::Ok; +} + +Error TextLLMRunner::prefill( + const std::string& prompt, + const GenerationConfig& config) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + + ::tokenizers::Result> encode_res = tokenizer_->encode( + prompt, + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); + + ET_CHECK_TK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + return Error::Ok; +} + +Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { + // Create a GenerationConfig for warmup + GenerationConfig config; + config.echo = false; + config.max_new_tokens = max_new_tokens; + config.warming = true; + + // Call generate with the warmup config + Error err = generate(prompt, config); + + // Reset stats after warmup, not resetting the std::unique_ptr! + reset(); + return err; +} + +void TextLLMRunner::stop() { + if (is_loaded()) { + text_token_generator_->stop(); + } else { + ET_LOG(Error, "Token generator is not loaded, cannot stop"); + } +} + +void TextLLMRunner::reset() { + stats_->reset(); + pos_ = 0; +} + +} // namespace executorch::extension::llm From 04fadb5df35e0626e89fbe7176296e7d1260decb Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:41:41 -0800 Subject: [PATCH 07/11] Fix attention sink index calculation to preserve sink tokens --- .../source_transformation/attention_sink.py | 37 +- .../test_attention_sink.py | 514 ------------------ .../test_attention_sink_ring_buffer.py | 46 +- 3 files changed, 66 insertions(+), 531 deletions(-) delete mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index c2e0f6606f5..d8a6919f7d1 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -144,31 +144,56 @@ class CachePositionsManagerWithSink(nn.Module): IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). 
""" - def __init__(self, cache_size: int): + def __init__(self, cache_size: int, sink_size: int = 0): super().__init__() # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size + self.sink_size = sink_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - + def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - This is identical to the original CachePositionsManager logic. + Indices logic: + - If pos < sink_size: index = pos + - If pos >= sink_size: index = sink_size + (pos - sink_size) % (cache_size - sink_size) """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos - # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length + if self.sink_size == 0: + # Simple ring buffer: just mod by cache size + indices = orig_indices % self.max_context_length + else: + # Shifted ring buffer logic + ring_size = self.max_context_length - self.sink_size + + # Calculate indices based on sink vs ring buffer logic + # Logic: + # 1. Calculate potential ring buffer index: sink_size + (pos - sink_size) % ring_size + # 2. If pos < sink_size, use pos. Else use ring buffer index. + + # Note: (pos - sink_size) % ring_size works correctly even if pos < sink_size + # in Python, but we want to be explicit. + # However, for pure torch.export compatibility without conditionals on tensors, + # we can use where or math. + + # Vectorized calculation: + shifted_pos = orig_indices - self.sink_size + ring_indices = self.sink_size + (shifted_pos % ring_size) + + # If position is within sink (0..sink_size-1), use original position + # Else use ring index + indices = torch.where(orig_indices < self.sink_size, orig_indices, ring_indices) # Update cache_positions exactly like original CachePositionsManager full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) @@ -222,7 +247,7 @@ def __init__( # Cache positions manager for determining write locations # Pass the total cache size (same as self.max_context_length after super().__init__) - self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size, sink_size) def create_causal_mask_for_ring_buffer( self, start_pos: torch.Tensor, seq_len: int diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - 
enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - 
self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = 
self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 7109129caca..485a5b13bdd 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -59,30 +59,54 @@ def test_simple_update(self): def test_wraparound(self): """Test ring buffer wraparound.""" + # sink_size=0 (default) in setUp, so it wraps to 0 + # Let's test non-zero sink size + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, 
sink_size) + # Fill cache to position 30 input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 30) + manager.calculate_positions_and_update_indices(input_pos, 30) - # Add 5 more tokens at position 30 - should wrap around + # Add 5 more tokens at position 30 input_pos = torch.tensor([30], dtype=torch.long) - indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) - - # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 - expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + indices = manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices logic: + # Ring size = 32 - 4 = 28 + # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 + # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4) + # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 + # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 + expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) def test_cache_positions_track_original_positions(self): """Cache positions should track which original position is at each index.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + # Fill with positions 0-31 input_pos = torch.tensor([0], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 32) + manager.calculate_positions_and_update_indices(input_pos, 32) + + # Indices 0-3 should have pos 0-3 (Sink) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) - # Now add position 32 which wraps to index 0 + # Now add position 32. + # (32-4)%28 = 0. So index = 4 + 0 = 4. input_pos = torch.tensor([32], dtype=torch.long) - self.manager.calculate_positions_and_update_indices(input_pos, 1) + manager.calculate_positions_and_update_indices(input_pos, 1) - # Index 0 should now contain original position 32 - self.assertEqual(self.manager.cache_positions[0].item(), 32) + # Index 4 should now contain original position 32 + self.assertEqual(manager.cache_positions[4].item(), 32) + + # Index 0 (sink) should STILL contain position 0 + self.assertEqual(manager.cache_positions[0].item(), 0) class CausalMaskTest(unittest.TestCase): From 97ba715656fc75eac3fb9537b89ff413479ce7c9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 14:56:10 -0800 Subject: [PATCH 08/11] Update --- .../source_transformation/attention_sink.py | 64 +-- .../test_attention_sink.py | 514 ++++++++++++++++++ .../test_attention_sink_ring_buffer.py | 56 +- 3 files changed, 580 insertions(+), 54 deletions(-) create mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index d8a6919f7d1..ba8b5298d66 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -138,71 +138,57 @@ class CachePositionsManagerWithSink(nn.Module): """ Manages cache positions for attention sink + sliding window. - For sink_size=0: behaves exactly like original CachePositionsManager. - For sink_size>0: sink tokens go to fixed positions, rest uses ring buffer. + For sink_size=0: behaves exactly like original CachePositionsManager (simple ring buffer). + For sink_size>0: sink tokens (indices 0 to sink_size-1) are NEVER overwritten. 
+ Ring buffer only cycles through indices sink_size to cache_size-1. - IMPORTANT: cache_size should be the actual cache dimension size (2x window for ring buffer). + IMPORTANT: cache_size should be the actual cache dimension size (sink_size + 2*window_size). """ def __init__(self, cache_size: int, sink_size: int = 0): super().__init__() - # cache_size is the actual size of the kv cache dimension self.max_context_length = cache_size self.sink_size = sink_size + # Ring buffer size = cache_size - sink_size + self.ring_size = cache_size - sink_size # Use zeros like original CachePositionsManager self.register_buffer( "cache_positions", torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"), ) - + def calculate_positions_and_update_indices( self, input_pos: torch.Tensor, seq_len: int ) -> torch.Tensor: """ Calculate indices into k_cache, v_cache for placing k_val, v_val. - Indices logic: - - If pos < sink_size: index = pos - - If pos >= sink_size: index = sink_size + (pos - sink_size) % (cache_size - sink_size) + Index calculation: + - Position < sink_size: index = position (sink tokens at fixed indices) + - Position >= sink_size: index = sink_size + (position - sink_size) % ring_size + + This ensures sink tokens (indices 0 to sink_size-1) are NEVER overwritten. """ start_pos = input_pos[0].item() torch._check_is_size(start_pos) - orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + # Original positions for the sequence + orig_positions = torch.arange(seq_len, dtype=torch.long) + start_pos if self.sink_size == 0: # Simple ring buffer: just mod by cache size - indices = orig_indices % self.max_context_length + indices = orig_positions % self.max_context_length else: - # Shifted ring buffer logic - ring_size = self.max_context_length - self.sink_size - - # Calculate indices based on sink vs ring buffer logic - # Logic: - # 1. Calculate potential ring buffer index: sink_size + (pos - sink_size) % ring_size - # 2. If pos < sink_size, use pos. Else use ring buffer index. - - # Note: (pos - sink_size) % ring_size works correctly even if pos < sink_size - # in Python, but we want to be explicit. - # However, for pure torch.export compatibility without conditionals on tensors, - # we can use where or math. 
- - # Vectorized calculation: - shifted_pos = orig_indices - self.sink_size - ring_indices = self.sink_size + (shifted_pos % ring_size) - - # If position is within sink (0..sink_size-1), use original position - # Else use ring index - indices = torch.where(orig_indices < self.sink_size, orig_indices, ring_indices) - - # Update cache_positions exactly like original CachePositionsManager - full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) - arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) - cache_positions = torch.where( - arange_tensor < start_pos, self.cache_positions, full_t - ) - self.cache_positions.copy_(cache_positions) - self.cache_positions.index_copy_(0, indices, orig_indices) + # Shifted ring buffer: sink tokens at fixed indices, rest in ring buffer + # For position >= sink_size: index = sink_size + (position - sink_size) % ring_size + shifted = orig_positions - self.sink_size + ring_indices = self.sink_size + (shifted % self.ring_size) + # For position < sink_size: use position directly + indices = torch.where(orig_positions < self.sink_size, orig_positions, ring_indices) + + # Update cache_positions to track what position is at each index + # Only update the indices we're writing to + self.cache_positions.index_copy_(0, indices, orig_positions) return indices diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py new file mode 100644 index 00000000000..fc882ebf4ab --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -0,0 +1,514 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, + RopeWithAttentionSink, +) +from parameterized import parameterized + + +class RopeWithAttentionSinkTest(unittest.TestCase): + + def _init_rope(self, params: ModelArgs, eviction_batch_size: int): + return RopeWithAttentionSink( + params=params, + window_size=252, + sink_size=4, + eviction_batch_size=eviction_batch_size, + ) + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 + ) + self.rope_with_attention_sink = self._init_rope( + params=self.params, eviction_batch_size=1 + ) + + @parameterized.expand( + [ + [0, 10, 1, 0], # No shift + [250, 10, 1, 246], # Some shift + [256, 10, 1, 246], # All shift + [0, 10, 30, 0], # No shift with batch eviction + [250, 10, 30, 220], # Some shift with batch eviction + [256, 10, 30, 226], # All shift with batch eviction + ] + ) + def test_get_freqs( + self, input_pos, seq_len, eviction_batch_size, expected_result_pos + ): + self.rope_with_attention_sink = self._init_rope( + params=self.params, eviction_batch_size=eviction_batch_size + ) + + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([input_pos], dtype=torch.int32), + seq_len=seq_len, + ) + + torch.testing.assert_close( + freqs_cos, + self.rope_with_attention_sink.freqs_cos.narrow( + 0, expected_result_pos, seq_len + ), + ) + torch.testing.assert_close( + freqs_sin, + self.rope_with_attention_sink.freqs_sin.narrow( + 0, expected_result_pos, seq_len + ), + ) + + @parameterized.expand( + [ + [128, 127], # Rotate left + [128, 128], # No rotation + [128, 129], # Rotate right + ] + ) + def test_rotate(self, original_position, new_position): + seq_len = 32 + + size = (1, seq_len, self.params.n_heads, self.params.head_dim) + q = torch.rand(*size, dtype=torch.float32) + k = torch.rand( + *size, + dtype=torch.float32, + ) + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([original_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, pre_rotated_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + rerotated_k = self.rope_with_attention_sink.rerotate_k( + k=pre_rotated_k, + original_position=original_position, + new_position=new_position, + ) + + freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( + input_pos=torch.tensor([new_position], dtype=torch.int32), + seq_len=seq_len, + ) + _, expected_k = self.rope_with_attention_sink.forward( + q=q, + k=k, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + + torch.testing.assert_close(rerotated_k, expected_k) + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + + _single_evict_test_cases = [ + [4, 1], + ] + + _batch_evict_test_cases = [ + [4, 8], + ] + + _sliding_window_test_cases = [ + [0, 1], + ] + + def _init_cache(self, sink_size, eviction_batch_size): + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=self.window_size + sink_size, + ) + self.rope_with_attention_sink = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + ) + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.params.n_heads, + head_dim=self.params.head_dim, + 
enable_dynamic_shape=self.params.enable_dynamic_shape, + rope=self.rope_with_attention_sink, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=sink_size, + eviction_batch_size=eviction_batch_size, + dtype=self.dtype, + ) + + def _rand_kv_with_length(self, seq_len): + size = ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) + return k, v + + def _zero_kv_with_length(self, seq_len): + size = ( + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) + return k, v + + def _get_dim_to_slice(self): + return 2 + + def _get_expected_rotated_k(self, k, original_position, new_position): + return self.rope_with_attention_sink.rerotate_k( + k=k.transpose(1, 2), + original_position=original_position, + new_position=new_position, + ).transpose(1, 2) + + def setUp(self): + torch.manual_seed(42) + self.max_batch_size = 1 + self.window_size = 28 + self.dtype = torch.float32 + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_empty_cache(self, sink_size, eviction_batch_size): + self._init_cache(sink_size, eviction_batch_size) + + # KV cache is empty, evict does nothing + input_pos = torch.tensor([0], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand( + _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases + ) + def test_evict_without_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = 2 + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has enough spaces for new tokens, no shift + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(10) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 1) == 0 + + zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) + + expected_k = torch.cat( + [ + k, + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_some_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 24) == -2 + + zero_k, zero_v = self._zero_kv_with_length(24) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + 
self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 1, 4), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_single_evict_test_cases) + def test_evict_with_all_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(27) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([32], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = self._zero_kv_with_length(6) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 5, 22), 10, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 5, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_some_shift_for_sliding_window( + self, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([10], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 20) == -2 + + zero_k, zero_v = self._zero_kv_with_length(20) + expected_k = torch.cat( + [ + self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), + self._get_expected_rotated_k(k1, 5, 3), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 2, 3), + v1, + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_sliding_window_test_cases) + def test_evict_with_all_shift_for_sliding_window( + self, sink_size, eviction_batch_size + ): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(23) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([28], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -6 + + zero_k, zero_v = 
self._zero_kv_with_length(6) + expected_k = torch.cat( + [ + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 1, 22), 6, 0 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v1.narrow(dimension_to_slice, 1, 22), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has some spaces for new tokens but not all, shift some tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 12) == -10 + + zero_k, zero_v = self._zero_kv_with_length(12) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 9, 16), 14, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 9, 16), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) + + @parameterized.expand(_batch_evict_test_cases) + def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() + + self._init_cache(sink_size, eviction_batch_size) + + # KV cache has no spaces for new tokens, shift all tokens + input_pos = torch.tensor([0], dtype=torch.int32) + k, v = self._rand_kv_with_length(5) + + self.kv_cache.update(input_pos, k, v) + + input_pos = torch.tensor([5], dtype=torch.int32) + k1, v1 = self._rand_kv_with_length(25) + + self.kv_cache.update(input_pos, k1, v1) + + input_pos = torch.tensor([30], dtype=torch.int32) + assert self.kv_cache.evict_tokens(input_pos, 6) == -8 + + zero_k, zero_v = self._zero_kv_with_length(10) + expected_k = torch.cat( + [ + k.narrow(dimension_to_slice, 0, sink_size), + self._get_expected_rotated_k( + k1.narrow(dimension_to_slice, 7, 18), 12, 4 + ), + zero_k, + ], + dim=dimension_to_slice, + ) + expected_v = torch.cat( + [ + v.narrow(dimension_to_slice, 0, sink_size), + v1.narrow(dimension_to_slice, 7, 18), + zero_v, + ], + dim=dimension_to_slice, + ) + + torch.testing.assert_close(self.kv_cache.k_cache, expected_k) + torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py index 485a5b13bdd..92f54b8e468 100644 --- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -36,7 +36,8 @@ class CachePositionsManagerWithSinkTest(unittest.TestCase): def setUp(self): self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) - self.manager = CachePositionsManagerWithSink(self.cache_size) + # Default: no sink (simple ring buffer) + self.manager = 
CachePositionsManagerWithSink(self.cache_size, sink_size=0) def test_initial_positions_are_zero(self): """Cache positions should start as zeros.""" @@ -57,10 +58,22 @@ def test_simple_update(self): for i in range(5): self.assertEqual(self.manager.cache_positions[i].item(), i) - def test_wraparound(self): - """Test ring buffer wraparound.""" - # sink_size=0 (default) in setUp, so it wraps to 0 - # Let's test non-zero sink size + def test_wraparound_no_sink(self): + """Test ring buffer wraparound with sink_size=0.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_wraparound_with_sink(self): + """Test ring buffer wraparound with sink_size > 0.""" sink_size = 4 cache_size = 32 manager = CachePositionsManagerWithSink(cache_size, sink_size) @@ -73,18 +86,30 @@ def test_wraparound(self): input_pos = torch.tensor([30], dtype=torch.long) indices = manager.calculate_positions_and_update_indices(input_pos, 5) - # Indices logic: # Ring size = 32 - 4 = 28 # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 - # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4) + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4, not 0!) # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) torch.testing.assert_close(indices, expected_indices) - def test_cache_positions_track_original_positions(self): - """Cache positions should track which original position is at each index.""" + def test_cache_positions_track_original_positions_no_sink(self): + """Cache positions should track which original position is at each index (no sink).""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + def test_cache_positions_track_original_positions_with_sink(self): + """Cache positions should track positions, and sink tokens are never overwritten.""" sink_size = 4 cache_size = 32 manager = CachePositionsManagerWithSink(cache_size, sink_size) @@ -93,20 +118,21 @@ def test_cache_positions_track_original_positions(self): input_pos = torch.tensor([0], dtype=torch.long) manager.calculate_positions_and_update_indices(input_pos, 32) - # Indices 0-3 should have pos 0-3 (Sink) + # Indices 0-3 should have pos 0-3 (Sink tokens) for i in range(4): self.assertEqual(manager.cache_positions[i].item(), i) - - # Now add position 32. + + # Now add position 32. # (32-4)%28 = 0. So index = 4 + 0 = 4. 
input_pos = torch.tensor([32], dtype=torch.long) manager.calculate_positions_and_update_indices(input_pos, 1) - + # Index 4 should now contain original position 32 self.assertEqual(manager.cache_positions[4].item(), 32) - # Index 0 (sink) should STILL contain position 0 - self.assertEqual(manager.cache_positions[0].item(), 0) + # Index 0-3 (sink) should STILL contain positions 0-3 (unchanged) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) class CausalMaskTest(unittest.TestCase): From e4ccab473630315f821080fa50ae2495a65063f9 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 15:14:30 -0800 Subject: [PATCH 09/11] Remove obsolete test_attention_sink.py --- .../test_attention_sink.py | 514 ------------------ 1 file changed, 514 deletions(-) delete mode 100644 examples/models/llama/source_transformation/test_attention_sink.py diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - 
input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - 
- self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - 
input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = 
torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) From bbca983c34c4b1d074008db133e65fc9c782cb03 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 5 Feb 2026 15:36:56 -0800 Subject: [PATCH 10/11] Border --- .../llama/source_transformation/attention_sink.py | 11 ++++++++--- .../test_attention_sink_ring_buffer.py | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index ba8b5298d66..3a69d85e61b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -126,7 +126,11 @@ def _create_causal_mask_for_attention_sink( is_sink = cache_positions < sink_size # Window tokens must be within sliding window - is_in_window = delta < window_size + # Use <= to include the boundary token. For window_size=124, we want to attend + # to the last 124 tokens BEFORE the current position (delta 1 to 124), plus + # position 4 (first non-sink token) which has delta exactly = window_size. + # This ensures sink_size + window_size tokens are visible when cache is full. 
+    is_in_window = delta <= window_size
 
     # Final mask: valid AND (is_sink OR is_in_window)
     attn_mask = is_valid & (is_sink | is_in_window)
@@ -151,10 +155,11 @@ def __init__(self, cache_size: int, sink_size: int = 0):
         self.sink_size = sink_size
         # Ring buffer size = cache_size - sink_size
         self.ring_size = cache_size - sink_size
-        # Use zeros like original CachePositionsManager
+        # Initialize to -1 to mark unwritten positions
+        # The mask uses (cache_positions >= 0) to check if a position is valid
         self.register_buffer(
             "cache_positions",
-            torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"),
+            torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"),
         )
 
     def calculate_positions_and_update_indices(
diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
index 92f54b8e468..953033f5cd8 100644
--- a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
+++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py
@@ -39,9 +39,9 @@ def setUp(self):
         # Default: no sink (simple ring buffer)
         self.manager = CachePositionsManagerWithSink(self.cache_size, sink_size=0)
 
-    def test_initial_positions_are_zero(self):
-        """Cache positions should start as zeros."""
-        expected = torch.zeros(self.cache_size, dtype=torch.long)
+    def test_initial_positions_are_minus_one(self):
+        """Cache positions should start as -1 (unwritten)."""
+        expected = torch.full((self.cache_size,), -1, dtype=torch.long)
         torch.testing.assert_close(self.manager.cache_positions, expected)
 
     def test_simple_update(self):

From fec5eb9505d9517ca6ca89fe60df2c2b2b7d6002 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Feb 2026 16:06:50 -0800
Subject: [PATCH 11/11] Test commands

---
 test_attention_sink.md | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 test_attention_sink.md

diff --git a/test_attention_sink.md b/test_attention_sink.md
new file mode 100644
index 00000000000..0e70723cc20
--- /dev/null
+++ b/test_attention_sink.md
@@ -0,0 +1,37 @@
+# Build the runtime
+
+Do this after modifying runtime (C++) code:
+```sh
+cmake --workflow llm-debug
+pushd examples/models/llama
+cmake --workflow --preset llama-debug
+popd
+```
+
+# Export model
+Take a look at examples/models/llama/README.md.
+
+The checkpoint is in ~/executorch/.
+
+Make sure you are in the executorch conda env.
+
+```sh
+# No quantization. Set these paths to point to the downloaded files.
+LLAMA_CHECKPOINT=path/to/consolidated.00.pth
+LLAMA_PARAMS=path/to/params.json
+
+python -m extension.llm.export.export_llm \
+    --config examples/models/llama/config/llama_bf16.yaml \
+    +base.model_class="llama3_2" \
+    +base.checkpoint="$LLAMA_CHECKPOINT" \
+    +base.params="$LLAMA_PARAMS"
+```
+
+# Run
+
+Also take a look at examples/models/llama/runner and make sure it can emit enough tokens to exceed the context size.
+
+Check whether the output makes sense:
+```sh
+cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt=
+```
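
For reference, below is a small standalone sketch (not part of the patches above) of the slot mapping and visibility rule that `CachePositionsManagerWithSink` and `_create_causal_mask_for_attention_sink` implement. The helper names `cache_index` and `is_visible` and the `__main__` demo are illustrative only; the constants mirror the comments in the ring-buffer tests (cache_size=32, sink_size=4, so ring_size=28).

```python
# Illustrative sketch of the attention-sink cache indexing and masking rule.
# cache_index mirrors CachePositionsManagerWithSink.calculate_positions_and_update_indices;
# is_visible mirrors _create_causal_mask_for_attention_sink with the `delta <= window_size`
# boundary from the last patch. Not the exported implementation, just the arithmetic.


def cache_index(pos: int, cache_size: int, sink_size: int) -> int:
    """Map an absolute token position to a KV-cache slot."""
    if sink_size == 0:
        # Plain ring buffer.
        return pos % cache_size
    if pos < sink_size:
        # Sink tokens live at fixed slots 0..sink_size-1 and are never overwritten.
        return pos
    ring_size = cache_size - sink_size
    # Later positions cycle through slots sink_size..cache_size-1 only.
    return sink_size + (pos - sink_size) % ring_size


def is_visible(q_pos: int, kv_pos: int, window_size: int, sink_size: int) -> bool:
    """Visibility rule: sink tokens are always attendable; others must be causal and in-window."""
    if kv_pos < 0:
        # -1 marks a cache slot that has never been written.
        return False
    delta = q_pos - kv_pos
    if delta < 0:
        # Never attend to future tokens.
        return False
    return kv_pos < sink_size or delta <= window_size


if __name__ == "__main__":
    cache_size, sink_size = 32, 4
    # Positions 30..34 land at slots 30, 31, 4, 5, 6: the wraparound skips the sink slots.
    print([cache_index(p, cache_size, sink_size) for p in range(30, 35)])
```

Running it prints `[30, 31, 4, 5, 6]`, the same slots expected in `test_wraparound_with_sink`: the wraparound re-enters the cache at index `sink_size`, so sink entries are never evicted, while the mask keeps them visible and restricts every other key to the sliding window.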