diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK index 9d9897a2819..57ed846dd86 100644 --- a/examples/models/llama/BUCK +++ b/examples/models/llama/BUCK @@ -283,6 +283,18 @@ fbcode_target(_kind = runtime.python_test, ], ) +fbcode_target(_kind = runtime.python_test, + name = "attention_sink_ring_buffer_test", + srcs = [ + "source_transformation/test_attention_sink_ring_buffer.py", + ], + supports_static_listing = False, + deps = [ + "//caffe2:torch", + ":export_library", + ], +) + fbcode_target(_kind = runtime.python_test, name = "quantized_sdpa_source_transform_test", srcs = [ diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 7b7691a2304..5be41830ea8 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -278,7 +278,7 @@ def __init__( [0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the current step still has access to [pos - sliding_window_size, pos] tokens. - + To make sure we dont over attend, i.e. we dont have pos = 5 to attend to pos = 1, mask calculaton has to account for the sliding window size. @@ -486,21 +486,30 @@ def forward( if self.use_kv_cache: assert input_pos is not None - if self.enable_dynamic_shape: + is_ring = getattr(self.kv_cache, "is_ring_buffer", False) + if is_ring: + # Ring buffer models: positions can exceed max_context_len. + # The ring buffer handles wrapping via modular arithmetic. + # The causal mask is computed dynamically from cache_positions, + # so we don't use the pre-computed self.mask here. + k, v = self.kv_cache.update(input_pos, k, v) + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + start_pos, seqlen + ) + elif self.enable_dynamic_shape: start_pos = input_pos[-1].item() torch._check_is_size(start_pos) torch._check(start_pos < self.max_context_len) seq_length = q.size(2) # pyre-ignore: Incompatible parameter type [6] attn_mask = self.mask.narrow(0, start_pos, seq_length) + k, v = self.kv_cache.update(input_pos, k, v) else: # mask is always 2D attn_mask = self.mask[input_pos] - k, v = self.kv_cache.update(input_pos, k, v) - if getattr(self.kv_cache, "is_ring_buffer", False): - attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( - input_pos[0].item(), seqlen - ) + k, v = self.kv_cache.update(input_pos, k, v) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) return self.wo(output), None diff --git a/examples/models/llama/config/llama_attention_sink.yaml b/examples/models/llama/config/llama_attention_sink.yaml new file mode 100644 index 00000000000..1d859035d74 --- /dev/null +++ b/examples/models/llama/config/llama_attention_sink.yaml @@ -0,0 +1,31 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + # Attention Sink: "sink_size,window_size,eviction_batch_size" + # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) + # window_size=124: 滑动窗口大小 + # eviction_batch_size=1: 每次驱逐 1 个 token + # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + use_attention_sink: "4,124,1" + +export: + # max_context_length controls the RoPE frequency table size. 
+ # It must be >= sink_size + window_size (128), but larger values are + # recommended to support generation beyond the sliding window. + # The model default (e.g., 8192 or 131072) is typically used if not specified. + # For testing, we use the model's default by not setting this explicitly. + +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/llama/config/llama_attention_sink_noxnn.yaml b/examples/models/llama/config/llama_attention_sink_noxnn.yaml new file mode 100644 index 00000000000..17ae69df4d8 --- /dev/null +++ b/examples/models/llama/config/llama_attention_sink_noxnn.yaml @@ -0,0 +1,29 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True # Now supported! We set use_attention_mask=True on SDPACustom + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + # Attention Sink: "sink_size,window_size,eviction_batch_size" + # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt) + # window_size=124: 滑动窗口大小 + # eviction_batch_size=1: 每次驱逐 1 个 token + # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 + use_attention_sink: "4,124,1" + +export: + max_seq_length: 252 + max_context_length: 512 + +# Quantization enabled for this test +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + +# No XNNPACK for this test +backend: + xnnpack: + enabled: False diff --git a/examples/models/llama/config/llama_attention_sink_xnnpack.yaml b/examples/models/llama/config/llama_attention_sink_xnnpack.yaml new file mode 100644 index 00000000000..6a63f3a914a --- /dev/null +++ b/examples/models/llama/config/llama_attention_sink_xnnpack.yaml @@ -0,0 +1,25 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + use_attention_sink: "4,124,1" + +export: + max_seq_length: 252 + max_context_length: 512 + +# No quantization +# quantization: +# qmode: 8da4w +# group_size: 128 +# embedding_quantize: 4,32 + +# XNNPACK enabled +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/llama/config/llama_runner_attention_sink.yaml b/examples/models/llama/config/llama_runner_attention_sink.yaml new file mode 100644 index 00000000000..d23180cdd0c --- /dev/null +++ b/examples/models/llama/config/llama_runner_attention_sink.yaml @@ -0,0 +1,41 @@ +## +## Runner-side Attention Sink configuration +## +## This uses KVCacheWithAttentionSink (model-side) together with +## the runner's AttentionSinkIOManager for position bookkeeping. 
+## +## Key behavior: +## - Model has KVCacheWithAttentionSink which preserves sink tokens and +## uses a ring buffer for the sliding window (is_ring_buffer=True) +## - Runner's AttentionSinkIOManager tracks logical position and allows +## generation to continue past max_context_len +## - KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252 +## + +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + enable_dynamic_shape: True + # Attention Sink: "sink_size,window_size,eviction_batch_size" + # sink_size=4, window_size=124, eviction_batch_size=1 + # Max Context (Buffer) = 4 + 1 * 124 = 128 + use_attention_sink: "4,124,1" + +export: + # max_seq_length for single prefill chunk + max_context_length: 128 + max_seq_length: 128 + +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 03f8f5cd759..e13a5299e61 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -338,6 +338,8 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse Evaluate the model's perplexity when AttentionSink is enabled. This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py + + Updated for the ring-buffer based attention sink implementation. """ # Convert args to LlmConfig from executorch.extension.llm.export.config.llm_config import LlmConfig @@ -351,7 +353,13 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert llm_config.export.max_seq_length == sink_size + window_size + # For the ring buffer implementation, the cache size is sink_size + window_size * 2 + # max_context_length should be >= sink_size + window_size (for RoPE frequencies) + # but can be larger to support extended generation + assert llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) device = "cuda" if torch.cuda.is_available() else "cpu" manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 219cc71ded1..23d26e1865f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -40,6 +40,7 @@ get_vulkan_partitioner, get_xnnpack_partitioner, ) +from executorch.examples.models.llama.model_args import ModelArgs from executorch.extension.llm.export.quantizer_lib import ( get_coreml_quantizer, get_ov_quantizer, @@ -57,6 +58,7 @@ get_model_with_r1_r2, ) from .source_transformation.attention import replace_attention_to_attention_sha +from .source_transformation.attention_sink import enable_attention_sink from .source_transformation.custom_kv_cache import ( replace_kv_cache_with_custom_kv_cache, replace_kv_cache_with_quantized_kv_cache, @@ -728,9 +730,16 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: calibration_limit=llm_config.quantization.calibration_limit, calibration_seq_length=llm_config.quantization.calibration_seq_length, expand_rope_table=llm_config.model.expand_rope_table, + # 
Attention sink models need attention mask for custom SDPA because: + # 1. The ring buffer creates a dynamic mask based on cache_positions + # 2. Without mask, custom_sdpa uses is_causal=True with start_pos, which + # fails when start_pos exceeds the cache size (positions keep growing) + # 3. With mask, custom_sdpa uses is_causal=False and the mask handles + # all masking logic including sliding window and attention sink use_custom_sdpa_with_attention_mask=getattr( llm_config.model, "use_custom_sdpa_with_attention_mask", False - ), + ) + or bool(llm_config.model.use_attention_sink), use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, quantize_kv_cache=llm_config.model.quantize_kv_cache, use_kv_cache=llm_config.model.use_kv_cache, @@ -750,13 +759,49 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: preq_embedding_quantize=llm_config.base.preq_embedding_quantize, local_global_attention=llm_config.model.local_global_attention, use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear, + use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding, + use_attention_sink=llm_config.model.use_attention_sink, + params_path=llm_config.base.params, + max_context_len=llm_config.export.max_context_length, ) ) + if llm_config.model.use_attention_sink: + print("Refreshing example inputs for Attention Sink...") + if hasattr(edge_manager.model, "get_example_inputs"): + # The model is now patched to return (tokens, attn_options, cache_indices) + new_inputs = edge_manager.model.get_example_inputs() + # We assume these are all positional arguments + edge_manager.example_inputs = new_inputs + # Clear kwargs since we provide everything positionally + edge_manager.example_kwarg_inputs = {} + print(f"Updated inputs: {len(new_inputs)} items") + + # Update dynamic shapes if enabled + if edge_manager.enable_dynamic_shape: + existing_shapes = edge_manager.dynamic_shapes + if existing_shapes and len(existing_shapes) == 2: + # Extract the Dim object from the first input (tokens) + # tokens shape dict is {1: Dim(...)} + token_dim = existing_shapes[0][1] + + # cache_indices is 1D tensor of size seq_len + # Spec should be {0: token_dim} + indices_spec = {0: token_dim} + + # Relieve static constraint on input_pos + # input_pos spec in existing_shapes[1] is {"input_pos": {0: 1}} + # We change it to {"input_pos": {0: token_dim}} + input_pos_spec = {"input_pos": {0: token_dim}} + + edge_manager.dynamic_shapes = (existing_shapes[0], input_pos_spec, indices_spec) + print("Updated dynamic_shapes for Attention Sink (patched input_pos)") + return edge_manager + def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( ( @@ -1118,6 +1163,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + # For attention sink models, the cache_positions buffer must be initialized + # to -1 (sentinel for "empty slot"). Without this pass, ExecuTorch only + # serializes shape+dtype for mutated buffers, leaving them uninitialized + # at runtime, which corrupts the attention mask computation. 
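To make the sentinel concrete, here is a standalone sketch (not part of the pass itself) of why the buffer must start at -1: the ring-buffer mask only trusts slots whose recorded position is non-negative, so a zero-initialized buffer would make every empty slot look like it holds the token at position 0.

```python
import torch

# Fresh cache_positions buffer: every slot marked empty with the -1 sentinel.
cache_positions = torch.full((8,), -1, dtype=torch.long)
cache_positions[:3] = torch.arange(3)  # three tokens have been written so far

# The mask logic only attends to slots with a non-negative recorded position.
print(cache_positions >= 0)
# tensor([ True,  True,  True, False, False, False, False, False])
#
# A zero-initialized buffer would instead report True for every slot, so the
# mask would treat unwritten K/V entries as real (sink) tokens at position 0.
```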
+ if llm_config.model.use_attention_sink: + additional_passes.append( + InitializedMutableBufferPass(["cache_positions"]) + ) + # export_to_edge builder_exported = _prepare_for_llama_export(llm_config).export() builder_exported.run_canonical_optimizations() @@ -1282,6 +1336,28 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": model_class_name, llm_config=llm_config, ) + + # Add attention sink metadata if enabled + metadata = _load_llama_model_metadata( + llm_config.model.use_kv_cache, + llm_config.model.use_sdpa_with_kv_cache, + llm_config.model.enable_dynamic_shape, + model.max_seq_len, + model.max_context_len, + model.n_layers, + model.vocab_size, + llm_config.base.metadata, + ) + + # Add attention sink metadata if enabled + if llm_config.model.use_attention_sink: + # Format: sink_size,window_size,eviction_batch_size + sink_params = [int(x) for x in llm_config.model.use_attention_sink.split(",")] + # IOManager expects these methods to exist returning int. + # By adding them to metadata, export_to_edge will generate constant methods. + metadata["get_sink_size"] = sink_params[0] + metadata["get_window_size"] = sink_params[1] + # Convert dtype override string to actual type. dtype_override = DType[llm_config.model.dtype_override.value] @@ -1296,31 +1372,14 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": example_kwarg_inputs=example_kwarg_inputs, dynamic_shapes=dynamic_shapes, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + save_exported_program=llm_config.export.export_only, calibration_tasks=llm_config.quantization.calibration_tasks, calibration_limit=llm_config.quantization.calibration_limit, calibration_seq_length=llm_config.quantization.calibration_seq_length, calibration_data=llm_config.quantization.calibration_data, tokenizer_path=llm_config.base.tokenizer_path, - save_exported_program=llm_config.export.export_only, verbose=llm_config.debug.verbose, - metadata=_load_llama_model_metadata( - llm_config.model.use_kv_cache, - llm_config.model.use_sdpa_with_kv_cache, - llm_config.model.enable_dynamic_shape, - # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got - # `Union[Tensor, Module]`. - model.max_seq_len, - # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got - # `Union[Tensor, Module]`. - model.max_context_len, - # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor, - # Module]`. - model.n_layers, - # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor, - # Module]`. - model.vocab_size, - llm_config.base.metadata, - ), + metadata=metadata, ) @@ -1359,6 +1418,9 @@ def _get_source_transforms( # noqa use_torchao_kernels_linear: bool = False, use_torchao_kernels_tied_embedding: bool = False, quantize_with_hqq: bool = True, + use_attention_sink: Optional[str] = None, + params_path: Optional[str] = None, + max_context_len: Optional[int] = None, ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: """ Return a list of functions that transform a graph. 
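The `use_attention_sink` string parsed above ("sink_size,window_size,eviction_batch_size") determines the cache geometry and the metadata constants used throughout this change. A minimal sketch of that arithmetic (the helper name is illustrative, not part of this diff):

```python
def attention_sink_geometry(spec: str) -> dict:
    # Format: "sink_size,window_size,eviction_batch_size", e.g. "4,124,1".
    sink_size, window_size, eviction_batch_size = (int(x) for x in spec.split(","))
    return {
        "get_sink_size": sink_size,
        "get_window_size": window_size,
        "eviction_batch_size": eviction_batch_size,
        # Ring-buffer KV cache keeps the sink plus a 2x sliding window of slots.
        "kv_cache_size": sink_size + window_size * 2,
        # RoPE frequency table must cover at least sink + window positions.
        "min_max_context_length": sink_size + window_size,
    }

print(attention_sink_geometry("4,124,1"))
# {'get_sink_size': 4, 'get_window_size': 124, 'eviction_batch_size': 1,
#  'kv_cache_size': 252, 'min_max_context_length': 128}
```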
@@ -1470,7 +1532,6 @@ def _get_source_transforms( # noqa transforms.append(materialze_broadcast_of_rope_freq_cis) use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask - if use_sdpa_with_kv_cache: transforms.append(replace_kv_cache_with_custom_kv_cache) # todo: do this optionally @@ -1546,6 +1607,33 @@ def _get_source_transforms( # noqa ) ) + if use_attention_sink: + sink_params = [int(x) for x in use_attention_sink.split(",")] + + # Load ModelArgs for attention sink + if not params_path: + raise ValueError("params_path is required for attention sink") + with open(params_path, "r") as f: + params_dict = json.load(f) + + # Ensure use_kv_cache is propagated from config + params_dict["use_kv_cache"] = True # Attention Sink requires KV Cache + # ModelArgs might expect other fields usually handled by Llama2Model init + # We try to pass minimal set needed for Rope/Attention + + model_args = ModelArgs(**params_dict) + + transforms.append( + partial( + enable_attention_sink, + params=model_args, + sink_size=sink_params[0], + window_size=sink_params[1], + eviction_batch_size=sink_params[2], + max_context_len=max_context_len, + ) + ) + return transforms diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 1ec85936f7a..3be00d78711 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -218,7 +218,22 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.llm_config.export.max_context_length == sink_size + window_size + # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies + # A larger max_context_length is allowed (and recommended) to support generation beyond + # the sliding window size. + assert self.llm_config.export.max_context_length >= sink_size + window_size, ( + f"max_context_length ({self.llm_config.export.max_context_length}) must be >= " + f"sink_size + window_size ({sink_size + window_size})" + ) + + # IMPORTANT: For attention sink, we need RoPE frequencies for all possible generation + # positions, not just the cache size. Override the model's max_context_len to use + # a larger value that supports extended generation. + # We use model_args.max_context_len which was set from export.max_context_length + # but for RoPE we need the full generation length capability. + # Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger. + default_rope_length = max(131072, model_args.max_context_len) + model_args.max_context_len = default_rope_length self.model_ = enable_attention_sink( module=self.model_, diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index 22bd8a3e228..0f63b685f7b 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -7,12 +7,21 @@ # Components for supporting Attention Sink. See # https://arxiv.org/abs/2309.17453 for more details about Attention Sink. +# This implementation is torch.export compatible using a ring buffer approach +# for the sliding window portion while preserving the sink tokens. 
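A minimal sketch of the indexing rule this header describes (it mirrors what `CachePositionsManagerWithSink` below computes; the standalone function is only for illustration): sink slots stay fixed, and every later position cycles through a ring of `cache_size - sink_size` slots.

```python
import torch

def write_indices(start_pos: int, seq_len: int, cache_size: int, sink_size: int) -> torch.Tensor:
    # Positions < sink_size map to themselves; everything else wraps inside the
    # ring that starts at index sink_size, so sink slots are never overwritten.
    positions = torch.arange(seq_len, dtype=torch.long) + start_pos
    ring_size = cache_size - sink_size
    ring_idx = sink_size + (positions - sink_size) % ring_size
    return torch.where(positions < sink_size, positions, ring_idx)

# cache_size=32, sink_size=4: position 32 wraps back to slot 4, not slot 0.
print(write_indices(30, 5, cache_size=32, sink_size=4))
# tensor([30, 31,  4,  5,  6])
```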
+ import types -from typing import Optional +from typing import Optional, Tuple import torch - -from executorch.examples.models.llama.attention import AttentionMHA, KVCache +import torch.nn as nn +from executorch.examples.models.llama.attention import ( + _create_causal_mask_for_ring_buffer, + AttentionMHA, + CachePositionsManager, + KVCache, + RingKVCache, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, @@ -27,6 +36,13 @@ class RopeWithAttentionSink(Rope): Rope that helps adjust position encoding when tokens are shifted in KVCache. For AttentionSink, when tokens are shifted in KVCache, we need to use positions in KVCache instead of positions in the actual text. + + For torch.export compatibility, this just passes through the position - the + actual position adjustment is handled by the cache update logic. + + Note: This class uses the model's max_context_len (params.max_context_len) for + RoPE frequency table size, which should be large enough to support generation + beyond the sliding window. The actual KV cache size is sink_size + window_size * 2. """ def __init__( @@ -41,28 +57,22 @@ def __init__( self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k else: self.apply_rotary_emb_to_k = apply_rotary_emb_to_k - self.max_context_length = window_size + sink_size - assert self.max_context_length == self.params.max_context_len + # The KV cache size is sink_size + window_size * 2 (ring buffer needs 2x) + self.kv_cache_size = sink_size + window_size * 2 + self.window_size = window_size + self.sink_size = sink_size + # max_context_len from params is used for RoPE frequencies (should be large) + self.max_context_length = self.params.max_context_len self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get rotary embedding frequencies. + For attention sink, we use the original position - the sliding window + is handled by the cache index management, not by position shifting. + """ assert input_pos is not None - - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. - # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return super().get_freqs(input_pos + self.position_shift, seq_len) + return super().get_freqs(input_pos, seq_len) def rerotate_k( self, @@ -71,15 +81,8 @@ def rerotate_k( new_position: int, ): """ - Rerotate k from original_position to new_position. This is done by rerotating - k with (new_position * theta - original_position * theta) with the following matrix: - (cos(delta), -sin(delta) - sin(delta), cos(delta)) - where delta = new_position * theta - original_position * theta - - The shape of k is (batch_size, seq_len, n_local_heads, head_dim) - - Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961 + Rerotate k from original_position to new_position. 
+ The shape of k is (batch_size, seq_len, n_local_heads, head_dim) """ seq_len = k.shape[1] original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len) @@ -96,15 +99,113 @@ def rerotate_k( return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin) +def _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len +): + """ + Create causal mask for attention sink. + + Unlike regular ring buffer mask, this mask: + 1. ALWAYS allows attending to sink tokens (positions 0 to sink_size-1) + 2. Uses sliding window for other tokens + + Args: + cache_positions: Tensor of actual positions stored at each cache index + window_size: Size of the sliding window + sink_size: Number of sink tokens to always attend to + start_pos: Starting position of the current query + seq_len: Length of the current query sequence + """ + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + + # Valid if position is filled (>= 0) and causal (delta >= 0) + is_valid = (cache_positions >= 0) & (delta >= 0) + + # Sink tokens (original positions 0 to sink_size-1) are always visible + is_sink = cache_positions < sink_size + + # Window tokens must be within sliding window + # Use <= to include the boundary token. For window_size=124, we want to attend + # to the last 124 tokens BEFORE the current position (delta 1 to 124), plus + # position 4 (first non-sink token) which has delta exactly = window_size. + # This ensures sink_size + window_size tokens are visible when cache is full. + is_in_window = delta <= window_size + + # Final mask: valid AND (is_sink OR is_in_window) + attn_mask = is_valid & (is_sink | is_in_window) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + +class CachePositionsManagerWithSink(nn.Module): + """ + Manages cache positions for attention sink + sliding window. + + For sink_size=0: behaves exactly like original CachePositionsManager (simple ring buffer). + For sink_size>0: sink tokens (indices 0 to sink_size-1) are NEVER overwritten. + Ring buffer only cycles through indices sink_size to cache_size-1. + + IMPORTANT: cache_size should be the actual cache dimension size (sink_size + 2*window_size). + """ + + def __init__(self, cache_size: int, sink_size: int = 0): + super().__init__() + self.max_context_length = cache_size + self.sink_size = sink_size + # Ring buffer size = cache_size - sink_size + self.ring_size = cache_size - sink_size + # Initialize to -1 to mark unwritten positions + # The mask uses (cache_positions >= 0) to check if a position is valid + self.register_buffer( + "cache_positions", + torch.full((self.max_context_length,), -1, dtype=torch.long, device="cpu"), + ) + + def calculate_positions_and_update_indices( + self, input_pos: torch.Tensor, seq_len: int + ) -> torch.Tensor: + """ + Calculate indices into k_cache, v_cache for placing k_val, v_val. + + Index calculation: + - Position < sink_size: index = position (sink tokens at fixed indices) + - Position >= sink_size: index = sink_size + (position - sink_size) % ring_size + + This ensures sink tokens (indices 0 to sink_size-1) are NEVER overwritten. 
+ """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + # Original positions for the sequence + orig_positions = torch.arange(seq_len, dtype=torch.long) + start_pos + + if self.sink_size == 0: + # Simple ring buffer: just mod by cache size + indices = orig_positions % self.max_context_length + else: + # Shifted ring buffer: sink tokens at fixed indices, rest in ring buffer + # For position >= sink_size: index = sink_size + (position - sink_size) % ring_size + shifted = orig_positions - self.sink_size + ring_indices = self.sink_size + (shifted % self.ring_size) + # For position < sink_size: use position directly + indices = torch.where(orig_positions < self.sink_size, orig_positions, ring_indices) + + # Update cache_positions to track what position is at each index + # Only update the indices we're writing to + self.cache_positions.index_copy_(0, indices, orig_positions) + + return indices + + class KVCacheWithAttentionSink(KVCache): """ - KV cache that supports attention sink. It keeps the initial few tokens as attention sink. - For other tokens, it uses a sliding window to keep the most recent tokens. + KV cache that supports attention sink with torch.export compatibility. - Parameters: - window_size: the size of the sliding window - sink_size: the number of initial tokens to keep as attention sink - eviction_batch_size: the number of tokens to evict in batch when there is not enough space in the KV cache + Uses a ring buffer approach for the sliding window portion while keeping + the first sink_size tokens fixed. This avoids dynamic shape operations. + + Cache layout: [sink: 0 to sink_size-1] [ring_buffer: sink_size to sink_size + window_size*2 - 1] """ def __init__( @@ -117,11 +218,18 @@ def __init__( sink_size: int, eviction_batch_size: int, max_batch_size: int = 1, + max_context_len: Optional[int] = None, dtype=torch.float32, ): + # Total cache size is max_context_len if provided, else sink_size + window_size * 2 + if max_context_len is None: + total_cache_size = sink_size + window_size * 2 + else: + total_cache_size = max_context_len + super().__init__( max_batch_size=max_batch_size, - max_context_length=window_size + sink_size, + max_context_length=total_cache_size, n_heads=n_heads, head_dim=head_dim, enable_dynamic_shape=enable_dynamic_shape, @@ -131,78 +239,67 @@ def __init__( self.window_size = window_size self.sink_size = sink_size self.eviction_batch_size = eviction_batch_size - self.position_shift = 0 + self.is_ring_buffer = True - def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: - """ - Evict old tokens from the cache to make rooms for new tokens. - - Parameters: - input_pos: the start position of the incoming token in the actual sequence - seq_len: the length of the incoming sequence - rope: the rope object to use for rerotating k + # Cache positions manager for determining write locations + # Pass the total cache size (same as self.max_context_length after super().__init__) + self.cache_positions_manager = CachePositionsManagerWithSink(total_cache_size, sink_size) - Returns: - the number of tokens to evict from the cache which is also the number of - positions to shift for incoming tokens + def create_causal_mask_for_ring_buffer( + self, start_pos: torch.Tensor, seq_len: int + ): """ - input_pos_item = input_pos.item() - torch._check_is_size(input_pos_item) - if input_pos_item + self.position_shift + seq_len > self.max_context_length: - # There are not enough spaces in the cache to store the new tokens. 
- # We need to evict some old tokens and shift some recent tokens. - num_to_evict = max( - input_pos_item - + self.position_shift - - self.max_context_length - + seq_len, - self.eviction_batch_size, - ) - num_to_keep = ( - input_pos_item + self.position_shift - self.sink_size - num_to_evict - ) - num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 - k_to_keep = self.k_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] + Create causal mask for the attention with attention sink. + Sink tokens are ALWAYS visible, plus recent tokens in the window. + """ + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + # Use attention sink mask that always allows attending to sink tokens + return _create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len ) - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] - new_position=self.sink_size, - ).transpose(1, 2) - self.k_cache = torch.cat( - [ - self.k_cache.narrow(dim_to_slice, 0, self.sink_size), - k_to_keep, - torch.zeros_like( - self.k_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + else: + # Pure ring buffer mode - use original mask with window_size = actual window + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len ) - self.v_cache = torch.cat( - [ - self.v_cache.narrow(dim_to_slice, 0, self.sink_size), - self.v_cache.narrow( - dim_to_slice, - self.sink_size + num_to_evict, # pyre-ignore [6] - num_to_keep, # pyre-ignore [6] - ), - torch.zeros_like( - self.v_cache.narrow( - dim_to_slice, 0, num_empty_space # pyre-ignore [6] - ) - ), - ], - dim=dim_to_slice, + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + indices: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV cache with new key-value pairs. + Uses ring buffer indexing for positions >= sink_size. + """ + seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" + + if indices is None: + # Calculate write indices + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len ) - self.position_shift -= num_to_evict # pyre-ignore [8] - return self.position_shift + + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + + return self.k_cache, self.v_cache + + def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: + """ + For ring buffer implementation, no explicit eviction is needed. + The ring buffer automatically overwrites old values. + Returns 0 to indicate no position shift is needed. + """ + return 0 def attention_sink_forward( @@ -210,28 +307,49 @@ def attention_sink_forward( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, + **kwargs, ): + """ + Forward function for attention with attention sink KV cache. + Uses ring buffer masking for proper attention patterns. 
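A toy numeric sketch of the mask rule this forward relies on (mirroring `_create_causal_mask_for_attention_sink` above; the values are made up for illustration): a slot is attendable only if it has been written, is causal, and is either a sink token or within the sliding window.

```python
import torch

def sink_window_mask(cache_positions, window_size, sink_size, start_pos, seq_len):
    # Same rule as _create_causal_mask_for_attention_sink, kept minimal.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    valid = (cache_positions >= 0) & (delta >= 0)          # written and causal
    keep = valid & ((cache_positions < sink_size) | (delta <= window_size))
    return torch.where(keep, 0.0, float("-inf"))

# Toy cache of 8 slots holding positions 0..7; query at position 7,
# sink_size=2, window_size=3: sinks 0-1 stay visible, 2-3 fall out of the window.
print(sink_window_mask(torch.arange(8), window_size=3, sink_size=2, start_pos=7, seq_len=1))
# tensor([[0., 0., -inf, -inf, 0., 0., 0., 0.]])
```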
+ """ assert self.use_kv_cache + + input_pos = kwargs.get("input_pos") assert input_pos is not None + + # Extract cache_indices if provided (injected by Transformer forward) + cache_indices = kwargs.get("cache_indices") bsz, seqlen, _ = x.shape # QKV q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - # Prepare for space in KV cache and get position shift - position_shift = self.kv_cache.evict_tokens(input_pos, seqlen) - - # RoPE relative positional embeddings with shifted position in KV cache + # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) + # Transpose for attention: [B, H, S, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update KV cache + k, v = self.kv_cache.update(input_pos, k, v, cache_indices) + + # Use ring buffer mask since we have is_ring_buffer=True + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen) + + # SDPA + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + + # Return tuple like original AttentionMHA.forward + return self.wo(output), None def _replace_rope( @@ -252,6 +370,7 @@ def _replace_attention( sink_size: int, window_size: int, eviction_batch_size: int, + max_context_len: int, ): for _, child_module in module._modules.items(): if len(list(child_module.children())) > 0: # pyre-ignore [16] @@ -261,10 +380,13 @@ def _replace_attention( sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, ) if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache + # Always use KVCacheWithAttentionSink, even for sink_size=0 + # This ensures we don't get replaced by CustomKVCache when use_sdpa_with_kv_cache=True kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, @@ -274,10 +396,18 @@ def _replace_attention( window_size=window_size, sink_size=sink_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink - child_module.forward = types.MethodType( # pyre-ignore + + # If using SDPACustom (fused SDPA op), enable attention mask support + # so it uses our ring buffer / attention sink mask instead of simple causal mask + if "SDPACustom" in child_module.SDPA.__class__.__name__: + child_module.SDPA.use_attention_mask = True + + # Replace forward with our custom forward that handles cache_indices + child_module.forward = types.MethodType( attention_sink_forward, child_module ) @@ -288,13 +418,22 @@ def enable_attention_sink( sink_size: int, window_size: int, eviction_batch_size: int, + max_context_len: Optional[int] = None, ) -> torch.nn.Module: """ Transform the model to be able to run inference with Attention Sink. 
There mainly three steps: - Replace Rope with RopeWithAttentionSink - - Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward + - Replace Attention's KVCache with KVCacheWithAttentionSink + - Replace Attention's forward with attention_sink_forward """ + if max_context_len is None: + max_context_len = sink_size + window_size * 2 + + # We update params.max_context_len to reflect the actual buffer size + # This ensures export captures the correct cache size in metadata + params.max_context_len = max_context_len + rope_with_attention_sink = RopeWithAttentionSink( params=params, window_size=window_size, @@ -308,5 +447,72 @@ def enable_attention_sink( sink_size=sink_size, window_size=window_size, eviction_batch_size=eviction_batch_size, + max_context_len=max_context_len, ) + + # Add metadata methods for IOManager + def get_sink_size(self): + return sink_size + + def get_window_size(self): + return window_size + + # Bind methods to module + # Note: For torch.export, we might need these to be part of the class or properly bound. + # Monkey patching instance methods works for some cases but let's verify. + # Ideally we subclass or mixin. But here we modify in place. + module.get_sink_size = types.MethodType(get_sink_size, module) + module.get_window_size = types.MethodType(get_window_size, module) + + # Monkey patch Transformer methods to handle cache_indices input + + # 1. New get_example_inputs that includes cache_indices + def get_example_inputs_with_sink(self): + # Create inputs manually to avoid relying on Llama2Model helper methods + # that might not be available on the Transformer module. + + # Use a small sequence length for example + seq_len = 3 + # Use simple tokens + tokens = torch.tensor([[2, 3, 4]], dtype=torch.long) + # Use corresponding input_pos matching seq_len + input_pos = torch.arange(seq_len, dtype=torch.long) + + # cache_indices matches input_pos/tokens length logic + # For export example, we can use simple indices + cache_indices = torch.arange(seq_len, dtype=torch.long) + + # Note: The original generic get_example_inputs usually returns ({tokens}, {input_pos}) + # input_pos shape depends on dynamic shape setting. + # But for export purposes, providing valid tensors is key. + # If dynamic shapes are enabled, input_pos might be expected to be 1D. + + return (tokens, {"input_pos": input_pos}, cache_indices) + + module.get_example_inputs = types.MethodType(get_example_inputs_with_sink, module) + + # 2. New forward that accepts cache_indices and passes it in attn_options + original_forward = module.forward + + def forward_with_sink( + self, + tokens: Optional[torch.LongTensor] = None, + attn_options: Optional[dict] = None, + cache_indices: Optional[torch.LongTensor] = None, + ): + # Allow cache_indices to be passed as positional or kwarg + # Note: top level export might pass inputs positionally if we aren't careful? + # LlamaTransformer.forward has (tokens, attn_options, h) + # We replace it. 
+ + if attn_options is None: + attn_options = {} + + if cache_indices is not None: + attn_options["cache_indices"] = cache_indices + + return original_forward(tokens=tokens, attn_options=attn_options) + + module.forward = types.MethodType(forward_with_sink, module) + return module diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 8d4d37e0e93..f8a268183b5 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -372,7 +372,16 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): for name, child in module.named_children(): - if isinstance(child, KVCache): + # Skip KVCacheWithAttentionSink as it has special evict_tokens logic + # that is not compatible with CustomKVCache. + # Check by class name because the class might come from different module paths + # (e.g., 'examples.models...' vs 'executorch.examples.models...') + child_class_name = type(child).__name__ + if child_class_name == "KVCacheWithAttentionSink": + logging.info(f"Skipping KVCacheWithAttentionSink at {name}") + _replace_kv_cache_with_custom_kv_cache(child) + elif isinstance(child, KVCache): + logging.info(f"Replacing KVCache at {name} (type={child_class_name})") cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, max_context_length, head_dim = cache_shape diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 59823b533a3..2885f9c12ea 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -50,11 +50,16 @@ def forward( v = v.to(dtype=torch.float) if self.use_attention_mask: + # When using attention mask, pass 0 as start_pos since: + # 1. The mask handles all masking logic (including ring buffer / attention sink) + # 2. is_causal=False so start_pos is not used for causal masking + # 3. This avoids validation errors when logical position >= cache size + # (e.g., with ring buffer where position 252 exceeds cache_size 252) output = torch.ops.llama.custom_sdpa( q, k, v, - input_pos[0].item(), + 0, # start_pos: not used when mask is provided mask, # Attention mask 0, # dropout probability. Ignored by the code False, # is_causal diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py deleted file mode 100644 index fc882ebf4ab..00000000000 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch -from executorch.examples.models.llama.model_args import ModelArgs - -from executorch.examples.models.llama.source_transformation.attention_sink import ( - KVCacheWithAttentionSink, - RopeWithAttentionSink, -) -from parameterized import parameterized - - -class RopeWithAttentionSinkTest(unittest.TestCase): - - def _init_rope(self, params: ModelArgs, eviction_batch_size: int): - return RopeWithAttentionSink( - params=params, - window_size=252, - sink_size=4, - eviction_batch_size=eviction_batch_size, - ) - - def setUp(self): - torch.manual_seed(42) - self.params = ModelArgs( - use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256 - ) - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=1 - ) - - @parameterized.expand( - [ - [0, 10, 1, 0], # No shift - [250, 10, 1, 246], # Some shift - [256, 10, 1, 246], # All shift - [0, 10, 30, 0], # No shift with batch eviction - [250, 10, 30, 220], # Some shift with batch eviction - [256, 10, 30, 226], # All shift with batch eviction - ] - ) - def test_get_freqs( - self, input_pos, seq_len, eviction_batch_size, expected_result_pos - ): - self.rope_with_attention_sink = self._init_rope( - params=self.params, eviction_batch_size=eviction_batch_size - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([input_pos], dtype=torch.int32), - seq_len=seq_len, - ) - - torch.testing.assert_close( - freqs_cos, - self.rope_with_attention_sink.freqs_cos.narrow( - 0, expected_result_pos, seq_len - ), - ) - torch.testing.assert_close( - freqs_sin, - self.rope_with_attention_sink.freqs_sin.narrow( - 0, expected_result_pos, seq_len - ), - ) - - @parameterized.expand( - [ - [128, 127], # Rotate left - [128, 128], # No rotation - [128, 129], # Rotate right - ] - ) - def test_rotate(self, original_position, new_position): - seq_len = 32 - - size = (1, seq_len, self.params.n_heads, self.params.head_dim) - q = torch.rand(*size, dtype=torch.float32) - k = torch.rand( - *size, - dtype=torch.float32, - ) - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([original_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, pre_rotated_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - rerotated_k = self.rope_with_attention_sink.rerotate_k( - k=pre_rotated_k, - original_position=original_position, - new_position=new_position, - ) - - freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs( - input_pos=torch.tensor([new_position], dtype=torch.int32), - seq_len=seq_len, - ) - _, expected_k = self.rope_with_attention_sink.forward( - q=q, - k=k, - freqs_cos=freqs_cos, - freqs_sin=freqs_sin, - ) - - torch.testing.assert_close(rerotated_k, expected_k) - - -class KVCacheWithAttentionSinkTest(unittest.TestCase): - - _single_evict_test_cases = [ - [4, 1], - ] - - _batch_evict_test_cases = [ - [4, 8], - ] - - _sliding_window_test_cases = [ - [0, 1], - ] - - def _init_cache(self, sink_size, eviction_batch_size): - self.params = ModelArgs( - use_kv_cache=True, - enable_dynamic_shape=True, - max_context_len=self.window_size + sink_size, - ) - self.rope_with_attention_sink = RopeWithAttentionSink( - params=self.params, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - ) - self.kv_cache = KVCacheWithAttentionSink( - n_heads=self.params.n_heads, - head_dim=self.params.head_dim, - 
enable_dynamic_shape=self.params.enable_dynamic_shape, - rope=self.rope_with_attention_sink, - max_batch_size=self.max_batch_size, - window_size=self.window_size, - sink_size=sink_size, - eviction_batch_size=eviction_batch_size, - dtype=self.dtype, - ) - - def _rand_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - return k, v - - def _zero_kv_with_length(self, seq_len): - size = ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - return k, v - - def _get_dim_to_slice(self): - return 2 - - def _get_expected_rotated_k(self, k, original_position, new_position): - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - - def setUp(self): - torch.manual_seed(42) - self.max_batch_size = 1 - self.window_size = 28 - self.dtype = torch.float32 - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_empty_cache(self, sink_size, eviction_batch_size): - self._init_cache(sink_size, eviction_batch_size) - - # KV cache is empty, evict does nothing - input_pos = torch.tensor([0], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand( - _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases - ) - def test_evict_without_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = 2 - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has enough spaces for new tokens, no shift - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(10) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - - zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) - - expected_k = torch.cat( - [ - k, - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - - zero_k, zero_v = self._zero_kv_with_length(24) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - 
self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 1, 4), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(27) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([32], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 5, 22), 10, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 5, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_some_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([10], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - - zero_k, zero_v = self._zero_kv_with_length(20) - expected_k = torch.cat( - [ - self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), - self._get_expected_rotated_k(k1, 5, 3), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 2, 3), - v1, - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_sliding_window_test_cases) - def test_evict_with_all_shift_for_sliding_window( - self, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(23) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([28], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - - zero_k, zero_v = 
self._zero_kv_with_length(6) - expected_k = torch.cat( - [ - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 1, 22), 6, 0 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v1.narrow(dimension_to_slice, 1, 22), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has some spaces for new tokens but not all, shift some tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - - zero_k, zero_v = self._zero_kv_with_length(12) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 9, 16), 14, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 9, 16), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) - - @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice() - - self._init_cache(sink_size, eviction_batch_size) - - # KV cache has no spaces for new tokens, shift all tokens - input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(5) - - self.kv_cache.update(input_pos, k, v) - - input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(25) - - self.kv_cache.update(input_pos, k1, v1) - - input_pos = torch.tensor([30], dtype=torch.int32) - assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - - zero_k, zero_v = self._zero_kv_with_length(10) - expected_k = torch.cat( - [ - k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - k1.narrow(dimension_to_slice, 7, 18), 12, 4 - ), - zero_k, - ], - dim=dimension_to_slice, - ) - expected_v = torch.cat( - [ - v.narrow(dimension_to_slice, 0, sink_size), - v1.narrow(dimension_to_slice, 7, 18), - zero_v, - ], - dim=dimension_to_slice, - ) - - torch.testing.assert_close(self.kv_cache.k_cache, expected_k) - torch.testing.assert_close(self.kv_cache.v_cache, expected_v) diff --git a/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py new file mode 100644 index 00000000000..953033f5cd8 --- /dev/null +++ b/examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the ring-buffer based attention sink implementation. 
+ +This tests the torch.export-compatible implementation that uses a ring buffer +for the sliding window rather than explicit token eviction. + +Usage: + # Run with pytest + python -m pytest examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py -v + + # Or run directly + python examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py +""" + +import unittest + +import torch +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + KVCacheWithAttentionSink, + RopeWithAttentionSink, + _create_causal_mask_for_attention_sink, +) + + +class CachePositionsManagerWithSinkTest(unittest.TestCase): + """Test the cache positions manager for ring buffer indexing.""" + + def setUp(self): + self.cache_size = 32 # Total cache size (e.g., sink_size + window_size * 2) + # Default: no sink (simple ring buffer) + self.manager = CachePositionsManagerWithSink(self.cache_size, sink_size=0) + + def test_initial_positions_are_minus_one(self): + """Cache positions should start as -1 (unwritten).""" + expected = torch.full((self.cache_size,), -1, dtype=torch.long) + torch.testing.assert_close(self.manager.cache_positions, expected) + + def test_simple_update(self): + """Test simple sequential update without wraparound.""" + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 5 + indices = self.manager.calculate_positions_and_update_indices(input_pos, seq_len) + + # Should return indices 0, 1, 2, 3, 4 + expected_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + # Cache positions at those indices should be the original positions + for i in range(5): + self.assertEqual(self.manager.cache_positions[i].item(), i) + + def test_wraparound_no_sink(self): + """Test ring buffer wraparound with sink_size=0.""" + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 - should wrap around + input_pos = torch.tensor([30], dtype=torch.long) + indices = self.manager.calculate_positions_and_update_indices(input_pos, 5) + + # Indices should wrap: 30 % 32 = 30, 31, 0, 1, 2 + expected_indices = torch.tensor([30, 31, 0, 1, 2], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_wraparound_with_sink(self): + """Test ring buffer wraparound with sink_size > 0.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + + # Fill cache to position 30 + input_pos = torch.tensor([0], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 30) + + # Add 5 more tokens at position 30 + input_pos = torch.tensor([30], dtype=torch.long) + indices = manager.calculate_positions_and_update_indices(input_pos, 5) + + # Ring size = 32 - 4 = 28 + # pos 30 -> idx = 4 + (30-4)%28 = 4 + 26 = 30 + # pos 31 -> idx = 4 + (31-4)%28 = 4 + 27 = 31 + # pos 32 -> idx = 4 + (32-4)%28 = 4 + 0 = 4 (WRAPS TO SINK_SIZE=4, not 0!) 
+ # pos 33 -> idx = 4 + (33-4)%28 = 4 + 1 = 5 + # pos 34 -> idx = 4 + (34-4)%28 = 4 + 2 = 6 + expected_indices = torch.tensor([30, 31, 4, 5, 6], dtype=torch.long) + torch.testing.assert_close(indices, expected_indices) + + def test_cache_positions_track_original_positions_no_sink(self): + """Cache positions should track which original position is at each index (no sink).""" + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 32) + + # Now add position 32 which wraps to index 0 + input_pos = torch.tensor([32], dtype=torch.long) + self.manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 0 should now contain original position 32 + self.assertEqual(self.manager.cache_positions[0].item(), 32) + + def test_cache_positions_track_original_positions_with_sink(self): + """Cache positions should track positions, and sink tokens are never overwritten.""" + sink_size = 4 + cache_size = 32 + manager = CachePositionsManagerWithSink(cache_size, sink_size) + + # Fill with positions 0-31 + input_pos = torch.tensor([0], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 32) + + # Indices 0-3 should have pos 0-3 (Sink tokens) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) + + # Now add position 32. + # (32-4)%28 = 0. So index = 4 + 0 = 4. + input_pos = torch.tensor([32], dtype=torch.long) + manager.calculate_positions_and_update_indices(input_pos, 1) + + # Index 4 should now contain original position 32 + self.assertEqual(manager.cache_positions[4].item(), 32) + + # Index 0-3 (sink) should STILL contain positions 0-3 (unchanged) + for i in range(4): + self.assertEqual(manager.cache_positions[i].item(), i) + + +class CausalMaskTest(unittest.TestCase): + """Test the causal mask generation for attention sink.""" + + def test_mask_allows_sink_tokens(self): + """Sink tokens should always be visible (mask = 0).""" + cache_size = 32 + sink_size = 4 + window_size = 14 # cache_size = sink_size + window_size * 2 + + # Create cache positions where positions 0-3 are sink tokens + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 # Query at position 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Sink tokens (indices 0-3, original positions 0-3) should have mask = 0 + for i in range(sink_size): + self.assertEqual(mask[0, i].item(), 0.0, f"Sink token at index {i} should be visible") + + def test_mask_blocks_future_tokens(self): + """Future tokens should be masked (-inf).""" + cache_size = 32 + sink_size = 4 + window_size = 14 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Future tokens (positions > 10) should have mask = -inf + for i in range(11, cache_size): + self.assertEqual(mask[0, i].item(), float('-inf'), f"Future token at position {i} should be masked") + + def test_mask_respects_window(self): + """Tokens outside the window should be masked.""" + cache_size = 32 + sink_size = 4 + window_size = 5 # Only allow 5 recent tokens + + cache_positions = torch.arange(cache_size, dtype=torch.long) + + start_pos = 20 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 16-20 
should be visible (within window of 5) + for pos in range(16, 21): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible (in window)") + + # Position 15 should be masked (outside window, not a sink) + self.assertEqual(mask[0, 15].item(), float('-inf'), f"Position 15 should be masked (outside window)") + + +class KVCacheWithAttentionSinkTest(unittest.TestCase): + """Test the KV cache with attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + self.max_batch_size = 1 + + # Create model args with enough context for RoPE + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, # Large enough for RoPE + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=self.max_batch_size, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_cache_size(self): + """Cache should be sink_size + window_size * 2.""" + expected_size = self.sink_size + self.window_size * 2 # 4 + 28 = 32 + self.assertEqual(self.kv_cache.k_cache.size(2), expected_size) + self.assertEqual(self.kv_cache.v_cache.size(2), expected_size) + + def test_is_ring_buffer(self): + """Cache should be marked as ring buffer.""" + self.assertTrue(self.kv_cache.is_ring_buffer) + + def test_update_stores_kv(self): + """Update should store key-value pairs.""" + k = torch.randn(1, self.n_heads, 5, self.head_dim) + v = torch.randn(1, self.n_heads, 5, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # First 5 positions should contain our values + torch.testing.assert_close(k_out[:, :, :5, :], k) + torch.testing.assert_close(v_out[:, :, :5, :], v) + + def test_evict_tokens_returns_zero(self): + """Ring buffer implementation doesn't shift, so evict returns 0.""" + input_pos = torch.tensor([100], dtype=torch.long) + shift = self.kv_cache.evict_tokens(input_pos, 10) + self.assertEqual(shift, 0) + + def test_extended_generation(self): + """Test that cache works for positions beyond cache size.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Fill cache with initial tokens + for pos in range(cache_size + 50): + k = torch.randn(1, self.n_heads, 1, self.head_dim) + v = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # Should not raise any errors + self.assertEqual(k_out.shape, self.kv_cache.k_cache.shape) + self.assertEqual(v_out.shape, self.kv_cache.v_cache.shape) + + +class RopeWithAttentionSinkTest(unittest.TestCase): + """Test RoPE for attention sink.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + dim=512, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=100, + sink_size=4, + eviction_batch_size=1, + ) + + def test_get_freqs_uses_original_position(self): + """RoPE frequencies should use the original position.""" + input_pos = torch.tensor([50], 
dtype=torch.long) + seq_len = 5 + + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + # Should get frequencies for positions 50-54 + expected_cos = self.rope.freqs_cos[50:55] + expected_sin = self.rope.freqs_sin[50:55] + + torch.testing.assert_close(freqs_cos, expected_cos) + torch.testing.assert_close(freqs_sin, expected_sin) + + def test_rerotate_k(self): + """Test re-rotation of k from one position to another.""" + batch_size = 1 + seq_len = 8 + n_heads = self.params.n_heads + head_dim = self.params.dim // n_heads + + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # Rotate k at position 100 + original_pos = 100 + freqs_cos, freqs_sin = self.rope.get_freqs( + torch.tensor([original_pos], dtype=torch.long), seq_len + ) + _, rotated_k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + # Re-rotate to position 50 + new_pos = 50 + rerotated_k = self.rope.rerotate_k(rotated_k, original_pos, new_pos) + + # This should be equivalent to directly rotating k at position 50 + freqs_cos_new, freqs_sin_new = self.rope.get_freqs( + torch.tensor([new_pos], dtype=torch.long), seq_len + ) + _, expected_k = self.rope.forward(q, k, freqs_cos_new, freqs_sin_new) + + torch.testing.assert_close(rerotated_k, expected_k, rtol=1e-4, atol=1e-4) + + +class CausalMaskWithWraparoundTest(unittest.TestCase): + """Test causal mask with ring buffer wraparound.""" + + def test_mask_after_wraparound(self): + """Test mask after cache has wrapped around.""" + cache_size = 16 + sink_size = 4 + window_size = 6 # cache_size = sink_size + window_size * 2 + + # Simulate cache after generating beyond cache_size: + # The ring buffer wraps, so indices 0-15 contain positions that wrap + # At position 50, with cache_size=16, the cache contains: + # positions 50-15=35 to 49 at various indices + cache_positions = torch.zeros(cache_size, dtype=torch.long) + # Fill with positions that would exist after generating 50 tokens + # idx = pos % cache_size, so: + # pos 34-49 occupy indices 2-15 and 0-1 + for pos in range(34, 50): + idx = pos % cache_size + cache_positions[idx] = pos + + start_pos = 49 # Query at position 49 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions within window (49-6+1=44 to 49) should be visible + visible_count = 0 + for i in range(cache_size): + pos = cache_positions[i].item() + if pos >= 44 and pos <= 49: # In window + self.assertEqual(mask[0, i].item(), 0.0, + f"Position {pos} at idx {i} should be visible (in window)") + visible_count += 1 + + # Should have some visible tokens + self.assertGreater(visible_count, 0, "Should have visible tokens in window") + + def test_mask_with_sink_size_zero(self): + """Test pure sliding window (sink_size=0).""" + cache_size = 16 + sink_size = 0 + window_size = 8 + + cache_positions = torch.arange(cache_size, dtype=torch.long) + start_pos = 10 + seq_len = 1 + + mask = _create_causal_mask_for_attention_sink( + cache_positions, window_size, sink_size, start_pos, seq_len + ) + + # Positions 3-10 should be visible (within window of 8) + for pos in range(3, 11): + self.assertEqual(mask[0, pos].item(), 0.0, f"Position {pos} should be visible") + + # Positions 0-2 should be masked (outside window) + for pos in range(0, 3): + self.assertEqual(mask[0, pos].item(), float('-inf'), + f"Position {pos} should be masked (outside window)") + + +class PrefillTest(unittest.TestCase): + """Test prefill scenarios.""" + 
+ def setUp(self): + torch.manual_seed(42) + self.window_size = 14 + self.sink_size = 4 + self.n_heads = 8 + self.head_dim = 64 + + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=self.n_heads, + n_kv_heads=self.n_heads, + dim=self.n_heads * self.head_dim, + ) + + self.rope = RopeWithAttentionSink( + params=self.params, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + ) + + self.kv_cache = KVCacheWithAttentionSink( + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + rope=self.rope, + max_batch_size=1, + window_size=self.window_size, + sink_size=self.sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + def test_prefill_entire_cache(self): + """Test prefill that fills entire cache.""" + cache_size = self.kv_cache.k_cache.size(2) + + k = torch.randn(1, self.n_heads, cache_size, self.head_dim) + v = torch.randn(1, self.n_heads, cache_size, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + k_out, v_out = self.kv_cache.update(input_pos, k, v) + + # All positions should be filled + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + + def test_prefill_larger_than_cache_raises_error(self): + """Test that prefill larger than cache size raises an assertion error.""" + cache_size = self.kv_cache.k_cache.size(2) + seq_len = cache_size + 10 + + k = torch.randn(1, self.n_heads, seq_len, self.head_dim) + v = torch.randn(1, self.n_heads, seq_len, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + + # This should raise an assertion error since seq_len > cache_size + with self.assertRaises(AssertionError): + self.kv_cache.update(input_pos, k, v) + + def test_prefill_followed_by_decode(self): + """Test prefill followed by decode steps.""" + cache_size = self.kv_cache.k_cache.size(2) + + # Prefill with 20 tokens + k_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + v_prefill = torch.randn(1, self.n_heads, 20, self.head_dim) + input_pos = torch.tensor([0], dtype=torch.long) + self.kv_cache.update(input_pos, k_prefill, v_prefill) + + # Decode 5 more tokens + for i in range(5): + k_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + v_decode = torch.randn(1, self.n_heads, 1, self.head_dim) + input_pos = torch.tensor([20 + i], dtype=torch.long) + k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode) + + # Verify cache positions are updated + expected_pos = 20 + i + cache_idx = expected_pos % cache_size + self.assertEqual( + self.kv_cache.cache_positions_manager.cache_positions[cache_idx].item(), + expected_pos + ) + + +class EnableAttentionSinkTest(unittest.TestCase): + """Test the enable_attention_sink transformation.""" + + def setUp(self): + torch.manual_seed(42) + self.params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=8, + n_kv_heads=8, + dim=512, + n_layers=2, + vocab_size=100, + ) + + def test_enable_attention_sink_transforms_model(self): + """Test that enable_attention_sink properly transforms the model.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + # Create a simple transformer + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + # Apply attention sink transformation + model = enable_attention_sink( + 
module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that KV caches are replaced + for layer in model.layers: + kv_cache = layer.attention.kv_cache + self.assertIsInstance(kv_cache, KVCacheWithAttentionSink) + self.assertEqual(kv_cache.sink_size, 4) + self.assertEqual(kv_cache.window_size, 100) + self.assertTrue(kv_cache.is_ring_buffer) + + def test_enable_attention_sink_replaces_rope(self): + """Test that RoPE is replaced with RopeWithAttentionSink.""" + from executorch.examples.models.llama.llama_transformer import construct_transformer + from executorch.examples.models.llama.source_transformation.attention_sink import ( + enable_attention_sink, + ) + + with torch.device("meta"): + model = construct_transformer(self.params) + model.to_empty(device="cpu") + + model = enable_attention_sink( + module=model, + params=self.params, + sink_size=4, + window_size=100, + eviction_batch_size=1, + ) + + # Check that rope is replaced + for layer in model.layers: + rope = layer.attention.rope + self.assertIsInstance(rope, RopeWithAttentionSink) + + +class IntegrationTest(unittest.TestCase): + """Integration tests for end-to-end scenarios.""" + + def setUp(self): + torch.manual_seed(42) + + def test_cache_positions_consistency(self): + """Test that cache positions remain consistent during generation.""" + cache_size = 32 + sink_size = 4 + window_size = 14 + n_heads = 8 + head_dim = 64 + + params = ModelArgs( + use_kv_cache=True, + enable_dynamic_shape=True, + max_context_len=1024, + n_heads=n_heads, + n_kv_heads=n_heads, + dim=n_heads * head_dim, + ) + + rope = RopeWithAttentionSink( + params=params, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + ) + + kv_cache = KVCacheWithAttentionSink( + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + rope=rope, + max_batch_size=1, + window_size=window_size, + sink_size=sink_size, + eviction_batch_size=1, + dtype=torch.float32, + ) + + # Generate 100 tokens + for pos in range(100): + k = torch.randn(1, n_heads, 1, head_dim) + v = torch.randn(1, n_heads, 1, head_dim) + input_pos = torch.tensor([pos], dtype=torch.long) + + kv_cache.update(input_pos, k, v) + + # Create mask and verify it's valid + mask = kv_cache.create_causal_mask_for_ring_buffer(pos, 1) + + # Mask should not be all -inf (would mean no tokens to attend to) + non_inf_count = (mask != float('-inf')).sum().item() + self.assertGreater(non_inf_count, 0, f"At pos {pos}, mask should have visible tokens") + + # For positions >= sink_size, sinks should always be visible + if pos >= sink_size: + for i in range(sink_size): + cache_pos = kv_cache.cache_positions_manager.cache_positions[i].item() + if cache_pos < sink_size: # This is actually a sink token + self.assertEqual(mask[0, i].item(), 0.0, + f"Sink at idx {i} should be visible at pos {pos}") + + +if __name__ == '__main__': + unittest.main() diff --git a/extension/llm/runner/constants.h b/extension/llm/runner/constants.h index d7b36077757..5e06b05cb40 100644 --- a/extension/llm/runner/constants.h +++ b/extension/llm/runner/constants.h @@ -19,6 +19,11 @@ inline constexpr auto kVocabSize = "get_vocab_size"; inline constexpr auto kUseKVCache = "use_kv_cache"; inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; +// Attention sink configuration metadata keys +inline constexpr auto kUseAttentionSink = "use_attention_sink"; +inline constexpr auto kAttentionSinkSize = "attention_sink_size"; +inline constexpr auto kAttentionSinkWindowSize = 
"attention_sink_window_size"; + // Multimodal method name conventions inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.cpp b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp new file mode 100644 index 00000000000..30d84887c7a --- /dev/null +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace extension { +namespace llm { + +AttentionSinkIOManager::AttentionSinkIOManager( + ET_MODULE_NAMESPACE::Module& module, + int64_t max_context_len, + AttentionSinkConfig config) + : IOManager(module), + max_context_len_(max_context_len), + config_(config), + logical_pos_(0) { + ET_CHECK_MSG( + config_.sink_size >= 0, + "sink_size must be non-negative, got %" PRId64, + config_.sink_size); + ET_CHECK_MSG( + config_.window_size > 0, + "window_size must be positive, got %" PRId64, + config_.window_size); +} + +runtime::Error AttentionSinkIOManager::load( + const std::string& prefill_method, + const std::string& decode_method) { + (void)prefill_method; + (void)decode_method; + + ET_LOG( + Info, + "AttentionSinkIOManager loaded: sink_size=%" PRId64 + ", window_size=%" PRId64 ", max_context_len=%" PRId64, + config_.sink_size, + config_.window_size, + max_context_len_); + + return runtime::Error::Ok; +} + +runtime::Error AttentionSinkIOManager::reset( + const std::string& prefill_method, + const std::string& decode_method) { + (void)prefill_method; + (void)decode_method; + + logical_pos_ = 0; + + ET_LOG(Debug, "AttentionSinkIOManager reset"); + return runtime::Error::Ok; +} + +runtime::Result> +AttentionSinkIOManager::prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& prefill_method) { + int64_t logical_start = start_pos->data_ptr()[0]; + int64_t seq_len = input->numel(); + + logical_pos_ = logical_start + seq_len; + + ET_LOG( + Debug, + "AttentionSinkIOManager::prepare_prefill: logical_start=%" PRId64 + ", seq_len=%" PRId64 ", logical_pos_after=%" PRId64 + ", cache_full=%s", + logical_start, + seq_len, + logical_pos_, + is_cache_full() ? "true" : "false"); + + // Check if we need to provide cache_indices (3rd input) + auto method_meta = module_.method_meta(prefill_method); + if (method_meta.ok() && method_meta->num_inputs() == 3) { + update_indices_tensor(logical_start, seq_len); + return std::vector{input, start_pos, *indices_tensor_}; + } + + // Pass through to model as-is. + return std::vector{input, start_pos}; +} + +runtime::Result> +AttentionSinkIOManager::prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& decode_method) { + int64_t logical_start = start_pos->data_ptr()[0]; + int64_t seq_len = input->numel(); + + logical_pos_ = logical_start + seq_len; + + ET_LOG( + Debug, + "AttentionSinkIOManager::prepare_decode: logical_start=%" PRId64 + ", logical_pos_after=%" PRId64 + ", cache_full=%s", + logical_start, + logical_pos_, + is_cache_full() ? 
"true" : "false"); + + // Check if we need to provide cache_indices (3rd input) + auto method_meta = module_.method_meta(decode_method); + if (method_meta.ok() && method_meta->num_inputs() == 3) { + update_indices_tensor(logical_start, seq_len); + return std::vector{input, start_pos, *indices_tensor_}; + } + + // Pass through to model as-is. + return std::vector{input, start_pos}; +} + +void AttentionSinkIOManager::update_indices_tensor( + int64_t logical_start, + int64_t seq_len) { + int64_t ring_size = max_context_len_ - config_.sink_size; + ET_CHECK_MSG(ring_size > 0, "ring_size must be positive, got %" PRId64, ring_size); + ET_CHECK_MSG( + ring_size >= config_.window_size, + "ring_size (%" PRId64 ") must be >= window_size (%" PRId64 ")", + ring_size, + config_.window_size); + indices_buffer_.resize(seq_len); + for (int64_t i = 0; i < seq_len; ++i) { + int64_t pos = logical_start + i; + if (pos < config_.sink_size) { + indices_buffer_[i] = pos; + } else { + indices_buffer_[i] = + config_.sink_size + (pos - config_.sink_size) % ring_size; + } + } + + // Wrap in tensor + if (!indices_tensor_impl_ || indices_tensor_impl_->size(0) != seq_len) { + sizes_vec_ = {static_cast(seq_len)}; + dim_order_vec_ = {0}; + strides_vec_ = {1}; + + indices_tensor_impl_ = std::make_unique( + exec_aten::ScalarType::Long, + 1, + sizes_vec_.data(), + static_cast(indices_buffer_.data()), + dim_order_vec_.data(), + strides_vec_.data(), + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + indices_tensor_ = + std::make_unique(indices_tensor_impl_.get()); + } else { + // Update logic if buffer moved (vector resize might reallocate) + // Just re-create to be safe as data ptr is used + sizes_vec_ = {static_cast(seq_len)}; + dim_order_vec_ = {0}; + strides_vec_ = {1}; + + indices_tensor_impl_ = std::make_unique( + exec_aten::ScalarType::Long, + 1, + sizes_vec_.data(), + static_cast(indices_buffer_.data()), + dim_order_vec_.data(), + strides_vec_.data(), + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + indices_tensor_ = + std::make_unique(indices_tensor_impl_.get()); + } +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/runner/io_manager/attention_sink_io_manager.h b/extension/llm/runner/io_manager/attention_sink_io_manager.h new file mode 100644 index 00000000000..a8bafa6feb4 --- /dev/null +++ b/extension/llm/runner/io_manager/attention_sink_io_manager.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace extension { +namespace llm { +namespace exec_aten = ::executorch::aten; + +/** + * @brief Configuration for attention sink behavior. + */ +struct AttentionSinkConfig { + /// Number of initial tokens to always keep (sink tokens). + int64_t sink_size = 4; + + /// Size of the sliding window for non-sink tokens. + int64_t window_size = 508; + + /// When the cache is full, evict this many tokens at once. + int64_t eviction_batch_size = 256; +}; + +/** + * @brief IOManager that supports attention sink models for infinite context. + * + * This IOManager is designed to work with models that have been exported with + * attention sink support (KVCacheWithAttentionSink). 
The model internally + * manages: + * - Cache write indices (sink tokens at fixed positions, rest in ring buffer) + * - Attention mask creation (sink tokens always visible + sliding window) + * - Position-based RoPE embeddings + * + * The IOManager's role is to: + * 1. Pass through input_pos as-is (the model handles position mapping) + * 2. Track logical position for runner bookkeeping + * 3. Allow generation to continue past max_cache_size without errors + * + * This works with models exported using: + * - KVCacheWithAttentionSink (model-side attention sink) + * - RingKVCache (standard ring buffer - provides sliding window without sink) + * + * Cache layout (managed by model, not runner): + * [sink tokens: 0 to sink_size-1] [ring buffer: sink_size to cache_size-1] + * + * Usage: + * 1. Export model with attention sink config (use_ring_kv_cache=True, etc.) + * 2. Runner detects attention sink metadata and creates this IOManager + * 3. IOManager passes positions through; model handles cache management + */ +class ET_EXPERIMENTAL AttentionSinkIOManager : public IOManager { + public: + /** + * @brief Construct an AttentionSinkIOManager. + * + * @param module The Module used for querying method metadata and execution. + * @param max_cache_size The maximum size of the KV cache in the model. + * @param config Configuration for attention sink behavior. + */ + AttentionSinkIOManager( + ET_MODULE_NAMESPACE::Module& module, + int64_t max_context_len, + AttentionSinkConfig config = AttentionSinkConfig()); + + /** + * @brief Load the IO manager with method metadata. + */ + ET_NODISCARD runtime::Error load( + const std::string& prefill_method, + const std::string& decode_method) override; + + /** + * @brief Reset the IO manager state. + * + * Resets the logical position counter. + */ + ET_NODISCARD runtime::Error reset( + const std::string& prefill_method, + const std::string& decode_method) override; + + /** + * @brief Prepare inputs for the prefill phase. + * + * Passes through input_pos to the model. The model's internal + * KVCacheWithAttentionSink handles position-to-index mapping and masking. + */ + runtime::Result> prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& prefill_method) override; + + /** + * @brief Prepare inputs for the decode phase. + * + * Passes through input_pos to the model. The model's internal + * KVCacheWithAttentionSink handles position-to-index mapping and masking. + */ + runtime::Result> prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& decode_method) override; + + /** + * @brief Get the current logical position. + * + * This is the position in the full context, which may exceed the cache size. + * The model handles wrapping internally via ring buffer. + */ + int64_t logical_position() const { + return logical_pos_; + } + + /** + * @brief Get the attention sink configuration. + */ + const AttentionSinkConfig& config() const { + return config_; + } + + /** + * @brief Check if the cache is in the "infinite context" regime. + * + * Returns true when the logical position exceeds the effective cache + * capacity, meaning the ring buffer has wrapped and old tokens are being + * overwritten. 
+ */ + bool is_cache_full() const { + return logical_pos_ >= max_context_len_; + } + + private: + /// Maximum size of the KV cache in the model + int64_t max_context_len_; + + /// Attention sink configuration + AttentionSinkConfig config_; + + /// Current logical position (may exceed max_cache_size) + int64_t logical_pos_ = 0; + + /** + * @brief Update the internal indices buffer and tensor for a given position and length. + */ + void update_indices_tensor(int64_t logical_start, int64_t seq_len); + + // Buffer for cache indices + std::vector indices_buffer_; + + // Tensor wrapper for indices + std::unique_ptr indices_tensor_impl_; + std::unique_ptr indices_tensor_; + + // Metadata storage for TensorImpl + std::vector sizes_vec_; + std::vector dim_order_vec_; + std::vector strides_vec_; +}; + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/runner/io_manager/targets.bzl b/extension/llm/runner/io_manager/targets.bzl index fe7fe9d56ae..4ea579423fd 100644 --- a/extension/llm/runner/io_manager/targets.bzl +++ b/extension/llm/runner/io_manager/targets.bzl @@ -17,3 +17,20 @@ def define_common_targets(): ], visibility = ["PUBLIC"], ) + + # Attention Sink IOManager for runner-side infinite context + runtime.cxx_library( + name = "attention_sink_io_manager" + aten_suffix, + srcs = [ + "attention_sink_io_manager.cpp", + ], + exported_headers = [ + "attention_sink_io_manager.h", + ], + exported_deps = [ + ":io_manager" + aten_suffix, + "//executorch/extension/tensor:tensor" + aten_suffix, + "//executorch/extension/module:module" + aten_suffix, + ], + visibility = ["PUBLIC"], + ) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 13f8d7a9db5..7ce2c4f7caa 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -9,6 +9,7 @@ // Implementation of helper utilities for creating and configuring LLM runners #include +#include #include #include #include @@ -233,8 +234,61 @@ std::unique_ptr create_text_llm_runner( auto eos_ids = std::make_unique>( llm::get_eos_ids(tokenizer.get(), module.get())); - // Create IOManager - std::unique_ptr io_manager = std::make_unique(*module); + // Create IOManager - use AttentionSinkIOManager if attention sink is enabled + std::unique_ptr io_manager; + + // Get method names to check for attention sink metadata + auto method_names_result = module->method_names(); + if (method_names_result.error() != Error::Ok) { + ET_LOG(Error, "Failed reading method names for IOManager selection"); + return nullptr; + } + const auto& method_names = method_names_result.get(); + + // Check if attention sink is enabled via metadata + bool use_attention_sink = false; + int64_t sink_size = 4; // Default values + int64_t window_size = -1; + + if (method_names.count(kUseAttentionSink)) { + auto get_result = module->get(kUseAttentionSink); + use_attention_sink = get_result.get().toScalar().to(); + } + + if (use_attention_sink) { + // Get attention sink configuration from metadata + if (method_names.count(kAttentionSinkSize)) { + auto get_result = module->get(kAttentionSinkSize); + sink_size = get_result.get().toScalar().to(); + } + if (method_names.count(kAttentionSinkWindowSize)) { + auto get_result = module->get(kAttentionSinkWindowSize); + window_size = get_result.get().toScalar().to(); + } + + int64_t max_cache_size = metadata.at(kMaxContextLen); + + // If window_size is not found in metadata, calculate from max_context_len + if (window_size == 
-1) { + window_size = max_cache_size - sink_size; + } + + AttentionSinkConfig config; + config.sink_size = sink_size; + config.window_size = window_size; + ET_LOG( + Info, + "Creating AttentionSinkIOManager with sink_size=%" PRId64 + ", window_size=%" PRId64 ", max_cache_size=%" PRId64, + sink_size, + window_size, + max_cache_size); + + io_manager = std::make_unique( + *module, max_cache_size, config); + } else { + io_manager = std::make_unique(*module); + } // Create text_decoder_runner. Use a shared_ptr so that it can be shared with // TextPrefiller and TextTokenGenerator diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2c9000d0137..8150647a733 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -132,6 +132,7 @@ def define_common_targets(): ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, + "//executorch/extension/llm/runner/io_manager:attention_sink_io_manager" + aten_suffix, "//pytorch/tokenizers:hf_tokenizer", "//pytorch/tokenizers:llama2c_tokenizer", "//pytorch/tokenizers:sentencepiece", diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 373f57d3a8e..82e80b0a0ce 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -129,22 +129,27 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - // Reduce max_context_len by pos_ - int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_; + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); + int64_t max_context_len = metadata_.at(kMaxContextLen); + ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); + ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens < max_context_len, + num_prompt_tokens <= max_seq_len, InvalidArgument, - "num_prompt_tokens %d >= max_context_len %" PRId64 - ", Max seq length exceeded - please increase max seq len value in your export script", + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large", num_prompt_tokens, - max_context_len); + max_seq_len); - // Determine max_new_tokens using the GenerationConfig's resolve method, - // then subtract pos_ for max_new_tokens. + // Determine max_new_tokens from GenerationConfig. + // For ring buffer / attention sink models, the model handles position + // wrapping internally so generation can continue past max_context_len. + // If user specified seq_len explicitly, use that as the overall token limit. 
   int max_new_tokens = config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
diff --git a/test_attention_sink.md b/test_attention_sink.md
new file mode 100644
index 00000000000..0e70723cc20
--- /dev/null
+++ b/test_attention_sink.md
@@ -0,0 +1,57 @@
+# Build the runtime
+
+Rebuild after modifying runtime (C++) code:
+```sh
+cmake --workflow llm-debug
+pushd examples/models/llama
+cmake --workflow --preset llama-debug
+popd
+```
+
+# Export model
+See examples/models/llama/README.md for details.
+
+The checkpoint is in ~/executorch/.
+
+Make sure you are in the conda executorch environment.
+
+# No quantization
+```sh
+# Set these paths to point to the downloaded files
+LLAMA_CHECKPOINT=path/to/consolidated.00.pth
+LLAMA_PARAMS=path/to/params.json
+
+python -m extension.llm.export.export_llm \
+    --config examples/models/llama/config/llama_bf16.yaml \
+    +base.model_class="llama3_2" \
+    +base.checkpoint="${LLAMA_CHECKPOINT}" \
+    +base.params="${LLAMA_PARAMS}"
+```
+
+# Run
+
+Also check examples/models/llama/runner to make sure it can emit many tokens, i.e. generation can continue past the context size.
+
+Check whether the output makes sense:
+```sh
+cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt=
+```
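+
+# Sanity-check the cache index math (optional)
+
+The write index used by the attention-sink ring buffer follows the same rule in the Python cache, the C++ IO manager, and the tests: sink positions keep their own slot and every later position wraps inside the remaining ring. The snippet below is a standalone sketch of that rule (plain Python, no ExecuTorch imports; `cache_index` and the 32/4 example values are illustrative, not part of the codebase).
+
+```python
+# Hypothetical helper mirroring the ring-buffer index rule:
+# sink tokens keep their index, later tokens wrap inside the remaining ring.
+def cache_index(pos: int, cache_size: int, sink_size: int) -> int:
+    ring_size = cache_size - sink_size
+    assert ring_size > 0, "cache must be larger than the sink"
+    if pos < sink_size:
+        return pos  # sink tokens are never overwritten
+    return sink_size + (pos - sink_size) % ring_size
+
+# With cache_size=32 and sink_size=4, position 32 wraps back to index 4,
+# matching the expected indices in test_wraparound_with_sink.
+print([cache_index(p, 32, 4) for p in (30, 31, 32, 33, 34)])  # [30, 31, 4, 5, 6]
+```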