Conversation
Co-authored-by: Claude <noreply@anthropic.com>
This reverts commit 075d521.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17271
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures) As of commit 38703b1 with merge base aa7c8ce. BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@larryliu0820 I think it's not totally adhd after 1000+ tokens
Pull request overview
This PR adds a ring-buffer–based Attention Sink implementation for Llama exports and attempts to wire up runner-side support so generation can continue beyond the KV-cache capacity while maintaining correct masking/positioning.
Changes:
- Implement ring-buffer Attention Sink KV-cache + masking in Llama source transformation and update attention execution to use dynamic masks for ring-buffer caches.
- Add a new runner IOManager (AttentionSinkIOManager) and plumb it into runner build targets, plus update prompt-length handling in the text runner.
- Add/export updates and new configs/tests for Attention Sink (including mutable buffer initialization for cache_positions).
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| test_attention_sink.md | Adds local build/export/run notes for attention sink testing. |
| extension/llm/runner/text_llm_runner.cpp | Changes prompt validation to use max_seq_len and adjusts max-new-token resolution behavior. |
| extension/llm/runner/targets.bzl | Links runner to the new attention-sink IOManager target. |
| extension/llm/runner/llm_runner_helper.cpp | Selects AttentionSinkIOManager based on module metadata keys. |
| extension/llm/runner/io_manager/targets.bzl | Adds Bazel target for attention_sink_io_manager. |
| extension/llm/runner/io_manager/attention_sink_io_manager.h | Introduces IOManager subclass for attention-sink / "infinite context" bookkeeping. |
| extension/llm/runner/io_manager/attention_sink_io_manager.cpp | Implements pass-through prefill/decode input preparation while tracking logical position. |
| extension/llm/runner/constants.h | Adds metadata keys for attention sink configuration. |
| examples/models/llama/source_transformation/test_attention_sink_ring_buffer.py | Adds comprehensive ring-buffer attention sink unit tests. |
| examples/models/llama/source_transformation/test_attention_sink.py | Removes the previous attention sink test suite. |
| examples/models/llama/source_transformation/sdpa.py | Forces custom_sdpa to use start_pos=0 when an attention mask is provided. |
| examples/models/llama/source_transformation/custom_kv_cache.py | Prevents custom KV-cache replacement from clobbering KVCacheWithAttentionSink. |
| examples/models/llama/source_transformation/attention_sink.py | Reworks attention sink to a torch.export-compatible ring-buffer approach (masking, cache position tracking, cache update). |
| examples/models/llama/model.py | Loosens constraints and increases RoPE table length for attention sink generation. |
| examples/models/llama/export_llama_lib.py | Enables attention masks for custom SDPA when attention sink is enabled; initializes cache_positions mutable buffer. |
| examples/models/llama/eval_llama_lib.py | Updates attention-sink eval assumptions for ring-buffer cache size vs RoPE length. |
| examples/models/llama/config/*.yaml | Adds new attention sink configs (runner-side + xnnpack/noxnn variants). |
| examples/models/llama/attention.py | Uses dynamic ring-buffer masking when KV cache exposes is_ring_buffer. |
| examples/models/llama/BUCK | Registers the new attention sink ring-buffer test. |
```python
    )
    or bool(llm_config.model.use_attention_sink),
    use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
    quantize_kv_cache=llm_config.model.quantize_kv_cache,
```
This block enables attention-mask mode for custom SDPA when attention sink is on, but attention-sink configs also enable KV-cache quantization in some cases. In the quantized path, SDPACustom is replaced by QuantizedSDPA, which (per sdpa.py) still doesn’t support attention masks and continues to use start_pos for causal masking. That means attention sink + quantized KV cache may still hit the same start_pos >= cache_size validation failure. Consider propagating the use_attention_mask flag into QuantizedSDPA and making its mask path ignore start_pos similarly (or preventing this combination).
Suggested change:
```python
    # Quantized KV cache currently does not support attention-mask-based
    # custom SDPA (QuantizedSDPA still relies on start_pos for masking).
    # To avoid incompatible behavior, disable KV-cache quantization when
    # attention sink is enabled.
    quantize_kv_cache=(
        False
        if getattr(llm_config.model, "use_attention_sink", False)
        else llm_config.model.quantize_kv_cache
    ),
```
```
# Export model
Take a look at examples/models/llama/README.md

Check point is in ~/executorch/
```
Spelling: “Check point” should be “Checkpoint”.
Suggested change:
```
Checkpoint is in ~/executorch/
```
```diff
 // Get max_seq_len for single prefill chunk limit
 int64_t max_seq_len = metadata_.at(kMaxSeqLen);
 int64_t max_context_len = metadata_.at(kMaxContextLen);

 ET_CHECK_OR_RETURN_ERROR(
     num_prompt_tokens >= 1,
     InvalidArgument,
     "Expected at least 1 prompt token");

 ET_CHECK_OR_RETURN_ERROR(
-    num_prompt_tokens < max_context_len,
+    num_prompt_tokens <= max_seq_len,
     InvalidArgument,
-    "num_prompt_tokens %d >= max_context_len %" PRId64
-    ", Max seq length exceeded - please increase max seq len value in your export script",
+    "num_prompt_tokens %d > max_seq_len %" PRId64
+    ", Single prefill chunk too large",
     num_prompt_tokens,
-    max_context_len);
+    max_seq_len);
```
The new prompt-length validation only checks num_prompt_tokens <= max_seq_len and no longer accounts for pos_ / remaining context. If generate() is called multiple times without reset(), pos_ can be non-zero and text_prefiller_->prefill(prompt_tokens, pos_) may exceed the model’s usable context for non-ring-buffer models. Consider keeping the remaining-context check for non-ring-buffer models (e.g., num_prompt_tokens < (max_context_len - pos_)) while retaining the per-prefill-chunk max_seq_len check for ring/attention-sink exports.
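The combined check this comment proposes can be sketched as follows. This is a minimal Python sketch with invented names (the actual runner code is C++ in text_llm_runner.cpp), showing the per-chunk max_seq_len limit plus a remaining-context check that is skipped for ring-buffer/attention-sink exports:

```python
def validate_prompt_len(num_prompt_tokens, pos, max_seq_len, max_context_len,
                        is_ring_buffer):
    # Hypothetical helper, names invented for illustration.
    # Check 1: at least one prompt token.
    if num_prompt_tokens < 1:
        return "error: expected at least 1 prompt token"
    # Check 2 (all exports): one prefill chunk may not exceed max_seq_len.
    if num_prompt_tokens > max_seq_len:
        return "error: single prefill chunk too large"
    # Check 3 (non-ring-buffer only): the prompt must fit in the context
    # remaining after pos tokens have already been consumed.
    if not is_ring_buffer and num_prompt_tokens >= max_context_len - pos:
        return "error: prompt exceeds remaining context"
    return "ok"
```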
```cpp
int max_new_tokens =
    config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
```
resolve_max_new_tokens(max_context_len, num_prompt_tokens) currently ignores the existing pos_, so for non-ring-buffer models it can allow generating past the remaining context budget when pos_ > 0. Previously this was handled by reducing max_context_len by pos_. Consider passing max_context_len - pos_ for non-ring-buffer models (or extending resolve_max_new_tokens to take start_pos).
Suggested change:
```cpp
// For non-ring-buffer models starting at a non-zero position, only the
// remaining context window (max_context_len - pos_) is available.
int64_t remaining_context_len = max_context_len - pos_;
if (remaining_context_len < 0) {
  remaining_context_len = 0;
}
int max_new_tokens =
    config.resolve_max_new_tokens(remaining_context_len, num_prompt_tokens);
```
```cpp
// Get method names to check for attention sink metadata
auto method_names_result = module->method_names();
if (method_names_result.error() != Error::Ok) {
  ET_LOG(Error, "Failed reading method names for IOManager selection");
  return nullptr;
}
const auto& method_names = method_names_result.get();

// Check if attention sink is enabled via metadata
bool use_attention_sink = false;
int64_t sink_size = 4; // Default values
int64_t window_size = 124;

if (method_names.count(kUseAttentionSink)) {
  auto get_result = module->get(kUseAttentionSink);
  use_attention_sink = get_result.get().toScalar().to<bool>();
}
```
AttentionSinkIOManager selection relies on module methods named use_attention_sink, attention_sink_size, and attention_sink_window_size, but the Llama export metadata loader currently only emits get_max_seq_len, get_max_context_len, etc. (no attention-sink keys). As a result, this branch will never trigger and AttentionSinkIOManager won’t be used. Suggest exporting these metadata methods when llm_config.model.use_attention_sink is set (or using an existing reliable signal) so runner-side selection works.
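A minimal sketch of the selection logic under discussion, with a plain dict standing in for the module's constant metadata methods. Key names follow the runner constants; the default values mirror the ones in the reviewed code. Everything else here is an illustrative assumption, not the real API:

```python
# Defaults mirrored from the reviewed runner code.
DEFAULT_SINK_SIZE = 4
DEFAULT_WINDOW_SIZE = 124

def select_io_manager(metadata):
    # Sketch: the runner picks AttentionSinkIOManager only if the
    # "use_attention_sink" key exists and is truthy; otherwise it falls
    # back to the default IOManager.
    if not metadata.get("use_attention_sink", False):
        return ("DefaultIOManager", None, None)
    sink = metadata.get("attention_sink_size", DEFAULT_SINK_SIZE)
    window = metadata.get("attention_sink_window_size", DEFAULT_WINDOW_SIZE)
    return ("AttentionSinkIOManager", sink, window)
```

With the keys the exporter currently emits (get_sink_size/get_window_size), this lookup never matches, which is exactly why the branch never triggers.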
```diff
 """
 Transform the model to be able to run inference with Attention Sink.
 There mainly three steps:
 - Replace Rope with RopeWithAttentionSink
-- Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
+- Replace Attention's KVCache with KVCacheWithAttentionSink
+- Replace Attention's forward with attention_sink_forward
```
The docstring for enable_attention_sink says it replaces Attention’s forward with attention_sink_forward, but _replace_attention explicitly does not replace forward anymore. Update the docstring (and/or remove attention_sink_forward if it’s intentionally unused) to match actual behavior.
```python
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
```
Variable k_out is not used.
Suggested change:
```python
self.kv_cache.update(input_pos, k_decode, v_decode)
```
```python
k_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
v_decode = torch.randn(1, self.n_heads, 1, self.head_dim)
input_pos = torch.tensor([20 + i], dtype=torch.long)
k_out, v_out = self.kv_cache.update(input_pos, k_decode, v_decode)
```
Variable v_out is not used.
Suggested change:
```python
_, _ = self.kv_cache.update(input_pos, k_decode, v_decode)
```
```python
def test_cache_positions_consistency(self):
    """Test that cache positions remain consistent during generation."""
    cache_size = 32
```
Variable cache_size is not used.
```python
CachePositionsManager,
KVCache,
RingKVCache,
```
Import of 'CachePositionsManager' is not used.
Import of 'RingKVCache' is not used.
Suggested change:
```python
KVCache,
```
```cpp
    logical_pos_,
    is_cache_full() ? "true" : "false");

// Pass through to model as-is. The model's KVCacheWithAttentionSink
```

```python
k,
v,
input_pos[0].item(),
0,  # start_pos: not used when mask is provided
```
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 6 comments.
```python
# IOManager expects these methods to exist returning int.
# By adding them to metadata, export_to_edge will generate constant methods.
metadata["get_sink_size"] = sink_params[0]
metadata["get_window_size"] = sink_params[1]
```
The metadata added for attention sink uses keys get_sink_size/get_window_size, but the runtime runner expects attention sink metadata under use_attention_sink/attention_sink_size/attention_sink_window_size (see extension/llm/runner/constants.h and llm_runner_helper.cpp). This mismatch prevents the runner from detecting attention sink models. Update the metadata keys (and add an explicit enable flag) to match what the runner reads, or adjust the runner to match these exported keys.
Suggested change:
```python
# Runtime runner expects these metadata keys:
# - "use_attention_sink": bool flag to enable attention sink
# - "attention_sink_size": sink size (int)
# - "attention_sink_window_size": window size (int)
metadata["use_attention_sink"] = True
metadata["attention_sink_size"] = sink_params[0]
metadata["attention_sink_window_size"] = sink_params[1]
```
```yaml
# Max Context (Buffer) = 4 + 1 * 124 = 128
use_attention_sink: "4,124,1"

export:
  # max_seq_length for single prefill chunk
  max_context_len: 128
```
Export config uses max_context_len, but LlmConfig’s ExportConfig field is max_context_length (see extension/llm/export/config/llm_config.py). As written, this config likely won’t set the intended context length. Also the comment on line 25 (“Max Context (Buffer) = 4 + 1 * 124 = 128”) doesn’t match the earlier cache-size note (sink_size + 2*window_size = 252) and is confusing—please correct it to reflect the actual meaning (RoPE table vs KV cache size).
Suggested change:
```yaml
# RoPE/logical max context per step = sink_size + 1 * window_size = 4 + 1 * 124 = 128 tokens
use_attention_sink: "4,124,1"

export:
  # max_seq_length for single prefill chunk
  max_context_length: 128
```
```python
if indices is None:
    # Calculate write indices
    indices = self.cache_positions_manager.calculate_positions_and_update_indices(
        input_pos, seq_len
    )
```
KVCacheWithAttentionSink.update() only updates cache_positions when it computes indices internally. When indices are provided (e.g., runner-supplied cache_indices), cache_positions is never updated, so attention masks derived from cache_positions can be stale/incorrect (and may remain at the -1 sentinel). Update cache_positions_manager.cache_positions for the provided indices as well (e.g., index_copy with orig positions derived from input_pos and seq_len).
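The suggested fix can be sketched without torch; plain Python lists stand in for the tensors and all names here are illustrative, not the PR's actual code:

```python
def update_cache_positions(cache_positions, indices, input_pos):
    # Illustrative sketch: for each runner-provided write slot, record
    # which original/logical position was written there, mirroring an
    # index_copy_ of orig positions into cache_positions. Without this,
    # slots written via provided indices keep the stale -1 sentinel and
    # the mask derived from cache_positions goes wrong.
    for offset, slot in enumerate(indices):
        cache_positions[slot] = input_pos + offset
    return cache_positions
```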
```python
max_context_len = sink_size + window_size * 2

# We update params.max_context_len to reflect the actual buffer size
# This ensures export captures the correct cache size in metadata
params.max_context_len = max_context_len
```
enable_attention_sink() overwrites params.max_context_len with the KV cache buffer size (sink_size + 2*window_size). Rope.get_freqs() enforces input_pos < params.max_context_len under dynamic shape, so this change will cap generation to the cache size and break the intended “logical position can exceed cache size” behavior (and also contradict the earlier logic in model.py that tries to enlarge max_context_len for RoPE). Consider keeping params.max_context_len as the RoPE table length and passing the cache size via a separate field/argument (or only adjusting metadata without shrinking RoPE capacity).
Suggested change:
```python
# Default KV cache buffer size: sink tokens + sliding window on both sides
max_context_len = sink_size + window_size * 2
```
```cpp
if (method_names.count(kUseAttentionSink)) {
  auto get_result = module->get(kUseAttentionSink);
  use_attention_sink = get_result.get().toScalar().to<bool>();
}

if (use_attention_sink) {
  // Get attention sink configuration from metadata
  if (method_names.count(kAttentionSinkSize)) {
    auto get_result = module->get(kAttentionSinkSize);
    sink_size = get_result.get().toScalar().to<int64_t>();
  }
  if (method_names.count(kAttentionSinkWindowSize)) {
    auto get_result = module->get(kAttentionSinkWindowSize);
    window_size = get_result.get().toScalar().to<int64_t>();
  }
```
IOManager selection is keyed off module methods named use_attention_sink / attention_sink_size / attention_sink_window_size, but the exporter code in this PR adds attention sink metadata as get_sink_size/get_window_size (and does not add use_attention_sink). As a result, use_attention_sink will remain false here and AttentionSinkIOManager will never be constructed. Align the exporter’s metadata keys with the runner constants (or update the runner to look for the exported keys), and include a clear boolean enable flag plus the numeric parameters.
```python
def test_cache_positions_consistency(self):
    """Test that cache positions remain consistent during generation."""
    cache_size = 32
```

Variable cache_size is not used.

Suggested change: delete the unused `cache_size = 32` line.
Let me find the 0205 08:31PM commit and see how it works.
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
```cpp
inline constexpr auto kAttentionSinkSize = "attention_sink_size";
inline constexpr auto kAttentionSinkWindowSize = "attention_sink_window_size";
```
The attention sink metadata keys added here ("use_attention_sink", "attention_sink_size", "attention_sink_window_size") don’t match the metadata actually emitted by export_llama_lib (which adds "get_sink_size"/"get_window_size"). This mismatch will break runner-side detection and makes it unclear what the canonical metadata API is. Rename/align these constants with the exported metadata method names, preferably following the existing "get_*" convention used for other numeric metadata (e.g., get_max_seq_len).
Suggested change:
```cpp
inline constexpr auto kAttentionSinkSize = "get_sink_size";
inline constexpr auto kAttentionSinkWindowSize = "get_window_size";
```
```sh
# No quantization
# Set these paths to point to the downloaded files
LLAMA_CHECKPOINT=path/to/consolidated.00.pth
LLAMA_PARAMS=path/to/params.json

python -m extension.llm.export.export_llm \
  --config examples/models/llama/config/llama_bf16.yaml \
  +base.model_class="llama3_2" \
  +base.checkpoint="consolidated.00.pth" \
  +base.params="params.json"
```
The "Export model" section's shell snippet is missing an opening code fence: lines 20–27 are plain text, but there's a closing ``` on line 28. Add an opening ```sh fence (or remove the stray closing fence) so the markdown renders correctly and the commands are copy/pasteable.
```cpp
runtime::Result<std::vector<runtime::EValue>>
AttentionSinkIOManager::prepare_prefill(
    const TensorPtr& input,
    const TensorPtr& start_pos,
    const std::string& prefill_method) {
  int64_t logical_start = start_pos->data_ptr<int64_t>()[0];
  int64_t seq_len = input->numel();

  logical_pos_ = logical_start + seq_len;

  ET_LOG(
      Debug,
      "AttentionSinkIOManager::prepare_prefill: logical_start=%" PRId64
      ", seq_len=%" PRId64 ", logical_pos_after=%" PRId64
      ", cache_full=%s",
      logical_start,
      seq_len,
      logical_pos_,
      is_cache_full() ? "true" : "false");

  // Check if we need to provide cache_indices (3rd input)
  auto method_meta = module_.method_meta(prefill_method);
  if (method_meta.ok() && method_meta->num_inputs() == 3) {
    update_indices_tensor(logical_start, seq_len);
    return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
  }

  // Pass through to model as-is.
  return std::vector<runtime::EValue>{input, start_pos};
```
A new AttentionSinkIOManager implementation is added, but there are existing gtest-based IOManager tests in extension/llm/runner/io_manager/test. Consider adding targeted tests for AttentionSinkIOManager (e.g., indices generation for sink vs ring regions and the 3-input prepare_* path) to prevent regressions.
```python
# Total cache size (KV cache) = sink_size + window_size = max_seq_length
# max_context_len is for RoPE position encoding limit, NOT cache size
total_cache_size = sink_size + window_size
```
KVCacheWithAttentionSink sets total_cache_size = sink_size + window_size, but both the class docstring and the existing RingKVCache implementation use a KV cache sized to sink_size + window_size * 2 (to avoid losing the full sliding-window context during multi-token updates). As written, the cache will be too small and the ring-buffer mask/indexing logic won't match the intended attention window (and will fail the new ring-buffer tests). Update the cache allocation and CachePositionsManagerWithSink to use sink_size + 2 * window_size as the cache size, while keeping window_size as the sliding-window length.
Suggested change:
```python
# Total cache size (KV cache) = sink_size + 2 * window_size
# max_context_len is for RoPE position encoding limit, NOT cache size
total_cache_size = sink_size + 2 * window_size
```
```python
# We update params.max_context_len to reflect the actual buffer size
# This ensures export captures the correct cache size in metadata
params.max_context_len = max_context_len
```
enable_attention_sink() overwrites params.max_context_len (and defaults it to sink_size + window_size). In examples/models/llama/model.py the caller explicitly sets model_args.max_context_len to a much larger RoPE length (e.g., 131072) for extended generation, but this assignment will immediately clobber that and shrink RoPE frequencies back down to the cache size. Avoid overriding params.max_context_len here (or only set it when the caller didn’t already configure it), and consider using a separate metadata value for KV cache size vs RoPE max context.
Suggested change:
```python
# Only set RoPE max_context_len if it hasn't been configured by the caller.
# This avoids clobbering larger RoPE lengths used for extended generation.
if getattr(params, "max_context_len", None) is None:
    params.max_context_len = max_context_len
```
```python
CachePositionsManager,
KVCache,
RingKVCache,
```
attention_sink.py imports CachePositionsManager and RingKVCache but doesn’t use them anywhere in this module. If this file is linted, these unused imports will fail CI; consider removing them or using them explicitly (e.g., in type hints / docs) to keep the module clean.
Suggested change:
```python
KVCache,
```
```python
if indices is None:
    # Calculate write indices
    indices = self.cache_positions_manager.calculate_positions_and_update_indices(
        input_pos, seq_len
    )
```

```python
    self.position_shift -= num_to_evict  # pyre-ignore [8]
return self.position_shift
```

```python
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
self.k_cache.index_copy_(2, indices, k_val)
self.v_cache.index_copy_(2, indices, v_val)
```
KVCacheWithAttentionSink.update() only updates cache_positions when indices is None. When cache_indices is provided (e.g., from the runner via AttentionSinkIOManager), cache_positions never gets updated, so the dynamic attention mask will stay stale (typically all -1) and attention will be incorrect. Ensure cache_positions_manager is updated even when indices are provided (e.g., always call calculate_positions_and_update_indices, or add a separate update path that index_copy_’s orig_positions into cache_positions using the provided indices).
```cpp
int64_t max_cache_size = metadata.at(kMaxContextLen);

// If window_size is not found in metadata, calculate from max_context_len
if (window_size == -1) {
  window_size = max_cache_size - sink_size;
}

AttentionSinkConfig config;
config.sink_size = sink_size;
config.window_size = window_size;
ET_LOG(
    Info,
    "Creating AttentionSinkIOManager with sink_size=%" PRId64
    ", window_size=%" PRId64 ", max_cache_size=%" PRId64,
    sink_size,
    window_size,
    max_cache_size);

io_manager = std::make_unique<AttentionSinkIOManager>(
    *module, max_cache_size, config);
```
When constructing AttentionSinkIOManager, max_cache_size is taken from metadata[kMaxContextLen]. With the new max_seq_len vs max_context_len split, kMaxContextLen is typically the (large) RoPE/generation limit, while the actual KV cache capacity is bounded by kMaxSeqLen. Passing kMaxContextLen here will compute ring indices modulo the wrong size and produce out-of-bounds/incorrect cache_indices. Use metadata[kMaxSeqLen] (or a dedicated exported cache-size metadata) for the cache index modulo size, and validate sink_size/window_size against that value.
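The failure mode is easy to demonstrate with a toy modulo calculation (illustrative values, not the PR's actual code):

```python
def ring_indices(start, seq_len, cache_size):
    # Write slots for a ring buffer holding `cache_size` entries.
    return [(start + i) % cache_size for i in range(seq_len)]

# If the real KV cache holds 128 entries but indices are computed modulo a
# larger kMaxContextLen (say 4096), positions past 128 index out of bounds:
wrong = ring_indices(130, 2, 4096)  # slots 130, 131: past a 128-entry cache
right = ring_indices(130, 2, 128)   # wraps back into the buffer
```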
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 10 comments.
Comments suppressed due to low confidence (1)
examples/models/llama/source_transformation/attention_sink.py:245
`create_causal_mask_for_ring_buffer()` annotates `start_pos` as `torch.Tensor`, but callers pass a Python int (e.g., `start_pos = input_pos[0].item()`). Please fix the type annotation (e.g., `int`) to match actual usage and avoid confusion.
```python
def create_causal_mask_for_ring_buffer(
    self, start_pos: torch.Tensor, seq_len: int
):
```
```python
# The model is now patched to return (tokens, attn_options, cache_indices)
new_inputs = edge_manager.model.get_example_inputs()
# We assume these are all positional arguments
edge_manager.example_inputs = new_inputs
# Clear kwargs since we provide everything positionally
edge_manager.example_kwarg_inputs = {}
print(f"Updated inputs: {len(new_inputs)} items")
```
With attention sink enabled, this code clears example_kwarg_inputs and assumes get_example_inputs() returns positional args only. Llama2Model.get_example_inputs_kvcache_sdpa() returns a kwargs dict for input_pos, so treating it as positional will break export. Please keep example_kwarg_inputs intact, or change the model’s example inputs to be purely positional and update dynamic_shapes accordingly.
Suggested change:
```python
# The model may be patched to return either:
#   (tokens, attn_options, cache_indices)
# or a pair of (positional_args, kwarg_dict).
new_inputs = edge_manager.model.get_example_inputs()
# Default to keeping any existing kwarg inputs unless explicitly provided.
example_args = new_inputs
example_kwargs = edge_manager.example_kwarg_inputs
# If the model returns (args, kwargs), extract them accordingly.
if (
    isinstance(new_inputs, tuple)
    and len(new_inputs) == 2
    and isinstance(new_inputs[1], dict)
):
    example_args, example_kwargs = new_inputs
edge_manager.example_inputs = example_args
edge_manager.example_kwarg_inputs = example_kwargs
if isinstance(example_args, (list, tuple)):
    pos_count = len(example_args)
else:
    pos_count = 1
kw_count = len(example_kwargs) if isinstance(example_kwargs, dict) else 0
print(f"Updated inputs: {pos_count} positional, {kw_count} keyword items")
```
```python
metadata["get_sink_size"] = sink_params[0]
metadata["get_window_size"] = sink_params[1]
```
The added attention-sink metadata keys (get_sink_size, get_window_size) don’t match the names introduced elsewhere (constants.h uses attention_sink_size / attention_sink_window_size, and enable_attention_sink() binds module.attention_sink_size() / module.attention_sink_window_size()). Please align on a single metadata/method naming scheme (ideally consistent with existing get_* metadata methods) so runner-side detection can work reliably.
Suggested change:
```python
metadata["attention_sink_size"] = sink_params[0]
metadata["attention_sink_window_size"] = sink_params[1]
```
```python
# Window tokens must be within sliding window
is_in_window = delta <= window_size

# Final mask: valid AND (is_sink OR is_in_window)
attn_mask = is_valid & (is_sink | is_in_window)
```
The attention-sink mask uses delta <= window_size, but the existing ring-buffer mask uses delta < window_size (see _create_causal_mask_for_ring_buffer). This off-by-one changes how many previous tokens are attendable. Please make the inequality consistent with the ring-buffer semantics (and document whether window_size is meant to be inclusive/exclusive).
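The off-by-one is easiest to see with concrete numbers. A pure-Python sketch (illustrative names; `delta` is query_pos - key_pos as in the two masks being compared):

```python
def attendable(query_pos, key_pos, window_size, inclusive):
    # Causal constraint: a query can only see keys at or before it.
    delta = query_pos - key_pos
    if delta < 0:
        return False
    # Inclusive form (delta <= window_size) vs exclusive (delta < window_size).
    return delta <= window_size if inclusive else delta < window_size

# With window_size=4 and a query at position 10, the inclusive form admits
# one extra key (position 6) compared to the ring-buffer convention:
inclusive_keys = [k for k in range(11) if attendable(10, k, 4, True)]
exclusive_keys = [k for k in range(11) if attendable(10, k, 4, False)]
```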
```cpp
// Check if we need to provide cache_indices (3rd input)
auto method_meta = module_.method_meta(prefill_method);
if (method_meta.ok() && method_meta->num_inputs() == 3) {
  update_indices_tensor(logical_start, seq_len);
  return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
```
prepare_prefill() only special-cases num_inputs() == 3; otherwise it returns 2 inputs even if the method expects some other arity. Please explicitly validate the method input count (e.g., accept 2 or 3, and return an error for anything else) to fail fast like the base IOManager.
```cpp
printf("Failed to run llama runner: %d", static_cast<int>(error));
return 130;
```
This changes error handling from ET_LOG + return 1 to printf + return 130. Exit code 130 conventionally indicates SIGINT, and printf here also omits a trailing newline. Please keep exit codes consistent with the rest of the CLI (likely 1) and prefer the existing logging mechanism for errors.
Suggested change:
```diff
-printf("Failed to run llama runner: %d", static_cast<int>(error));
-return 130;
+ET_LOG(
+    Error,
+    "Failed to run llama runner: %d",
+    static_cast<int>(error));
+return 1;
```
```yaml
# window_size=124: Sliding window size
# eviction_batch_size=1: Evict 1 token each time
# KV cache size = sink_size + window_size = 4 + 124 = 128 = max_seq_length
use_attention_sink: "4,124,1"

export:
  # max_seq_length = KV cache size = sink + window
  max_seq_length: 128
```
The comments describe KV cache size = sink_size + window_size = ... = max_seq_length, but the ring-buffer KV cache implementation in examples/models/llama/attention.py requires a cache sized to 2 * window_size (plus any sink) to preserve sliding-window behavior. Please reconcile the config/comments with the actual cache sizing used by the attention-sink/ring-buffer implementation to avoid exporting incompatible models.
Suggested change:
```diff
-# window_size=124: Sliding window size
-# eviction_batch_size=1: Evict 1 token each time
-# KV cache size = sink_size + window_size = 4 + 124 = 128 = max_seq_length
-use_attention_sink: "4,124,1"
-export:
-  # max_seq_length = KV cache size = sink + window
-  max_seq_length: 128
+# window_size=124: Sliding window size (logical window)
+# eviction_batch_size=1: Evict 1 token each time
+# KV cache size (ring-buffer) = sink_size + 2 * window_size = 4 + 2 * 124 = 252
+use_attention_sink: "4,124,1"
+export:
+  # max_seq_length = KV cache size used by ring-buffer attention (sink + 2 * window)
+  # logical max sliding-window context = sink_size + window_size = 4 + 124 = 128
+  max_seq_length: 252
```
```cpp
// Get max_seq_len for single prefill chunk limit
int64_t max_seq_len = metadata_.at(kMaxSeqLen);
int64_t max_context_len = metadata_.at(kMaxContextLen);
```
max_context_len is no longer adjusted by pos_, but resolve_max_new_tokens() treats max_context_len as remaining capacity (it subtracts only num_prompt_tokens). On subsequent generate() calls without reset(), this can allow generating past the model’s supported context for non-ring-buffer models. Consider using max_context_len - pos_ when appropriate (e.g., unless attention sink/ring-buffer is enabled).
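As a sketch of the suggested rule (function and parameter names are illustrative, not the runner's actual API):

```python
# Sketch of remaining-capacity resolution across repeated generate()
# calls without reset(); all names here are illustrative.
def resolve_max_new_tokens(max_context_len, num_prompt_tokens, pos,
                           is_ring_buffer):
    if is_ring_buffer:
        # Attention sink / ring buffer: generation may run past the
        # cache size, so only the prompt is subtracted.
        return max_context_len - num_prompt_tokens
    # Fixed cache: account for tokens already generated (pos) so a
    # second generate() call cannot exceed the supported context.
    return max_context_len - pos - num_prompt_tokens

# Second generate() call at pos=100 on a 128-token context.
fixed = resolve_max_new_tokens(128, 10, 100, is_ring_buffer=False)
ring = resolve_max_new_tokens(128, 10, 100, is_ring_buffer=True)
```

Under these assumptions the fixed-cache path would allow only 18 more tokens, while the ring-buffer path keeps the larger budget.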
```python
# Total cache size (KV cache) = sink_size + window_size = max_seq_length
# max_context_len is for RoPE position encoding limit, NOT cache size
total_cache_size = sink_size + window_size
```
KVCacheWithAttentionSink allocates total_cache_size = sink_size + window_size, but both the docstring here (ring buffer region is window_size*2) and the existing RingKVCache implementation require a cache sized to 2 * window_size to preserve training-time sliding-window behavior. As-is, this will also fail the new ring-buffer tests. Please make cache sizing consistent (e.g., sink_size + 2*window_size) and update related config/comments accordingly.
Suggested change:
```diff
-# Total cache size (KV cache) = sink_size + window_size = max_seq_length
-# max_context_len is for RoPE position encoding limit, NOT cache size
-total_cache_size = sink_size + window_size
+# Total cache size (KV cache) = sink_size + 2 * window_size
+# max_context_len is for RoPE position encoding limit, NOT cache size
+total_cache_size = sink_size + 2 * window_size
```
```cpp
// Check if we need to provide cache_indices (3rd input)
auto method_meta = module_.method_meta(decode_method);
if (method_meta.ok() && method_meta->num_inputs() == 3) {
  update_indices_tensor(logical_start, seq_len);
  return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
}

// Pass through to model as-is.
return std::vector<runtime::EValue>{input, start_pos};
```
prepare_decode() has the same issue as prefill: it should explicitly validate that the method takes either 2 or 3 inputs and return an error for other arities, instead of always falling back to 2 inputs.
Suggested change:
```diff
-// Check if we need to provide cache_indices (3rd input)
-auto method_meta = module_.method_meta(decode_method);
-if (method_meta.ok() && method_meta->num_inputs() == 3) {
-  update_indices_tensor(logical_start, seq_len);
-  return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
-}
-// Pass through to model as-is.
-return std::vector<runtime::EValue>{input, start_pos};
+// Validate the decode method arity and provide cache_indices (3rd input) if needed.
+auto method_meta = module_.method_meta(decode_method);
+if (!method_meta.ok()) {
+  return method_meta.error();
+}
+auto num_inputs = method_meta->num_inputs();
+if (num_inputs == 3) {
+  update_indices_tensor(logical_start, seq_len);
+  return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
+}
+if (num_inputs == 2) {
+  // Pass through to model as-is.
+  return std::vector<runtime::EValue>{input, start_pos};
+}
+ET_LOG(
+    Error,
+    "AttentionSinkIOManager::prepare_decode: expected decode method '%s' "
+    "to take 2 or 3 inputs, but got %" PRId64,
+    decode_method.c_str(),
+    static_cast<int64_t>(num_inputs));
+return runtime::Error::InvalidArgument;
```
```python
super().__init__()
self.dim = dim
self.use_attention_mask = use_attention_mask
print(f"[SDPACustom] Created with use_attention_mask={use_attention_mask}")
```
Avoid unconditional print() in library/module code; it will spam stdout during export/tests and is hard to control in production. Please switch this to the repo’s logging mechanism (or gate it behind a verbosity/debug flag).
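A minimal sketch of the suggested replacement, assuming the standard library `logging` module (the logger name is illustrative):

```python
import logging

# Module-level logger; apps control verbosity via logging config.
logger = logging.getLogger("executorch.sdpa")

class SDPACustom:
    def __init__(self, dim, use_attention_mask=False):
        self.dim = dim
        self.use_attention_mask = use_attention_mask
        # Emitted only when the application enables DEBUG for this
        # logger, instead of unconditionally writing to stdout.
        logger.debug(
            "SDPACustom created with use_attention_mask=%s",
            use_attention_mask,
        )

s = SDPACustom(64, use_attention_mask=True)
```

This keeps export/test runs quiet by default while preserving the diagnostic when explicitly requested.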
|
With 5fccb8c it seems tell a long story with ok attention <|begin_of_text|><|start_header_id|>user<|end_header_id|>Tell me a very long story<|eot_id|><|start_header_id|>assistant<|end_header_id|> Once upon a time, in a small village nestled in the rolling hills of the countryside, there lived a young girl named Emily. Emily was a curious and adventurous child, with a mop of curly brown hair and a smile that could light up the darkest of rooms. She loved nothing more than to spend her days exploring the world around her, getting lost in the fields and forests, and making friends with the creatures that lived there. One day, while out on a walk, Emily stumbled upon a small, mysterious shop tucked away on a quiet street. The sign above the door read "Curios and Wonders," and the windows were filled with a dazzling array of strange and exotic objects. Emily's curiosity was piqued, and she pushed open the door to venture inside. The shop was dimly lit, and the air was thick with the scent of old books and dust. The proprietor, an old man with a kind face and twinkling eyes, looked up from behind the counter and smiled at Emily. "Welcome, young one," he said. "I see you have a sense of adventure. What brings you to my humble shop?" Emily hesitated, unsure of how much to reveal. But there was something about the old man's warm and gentle demeanor that put her at ease. She took a deep breath and told him about her love of reading and her desire to explore the world beyond her small town. The old man listened intently, nodding his head and making thoughtful noises as Emily spoke. When she finished, he leaned back in his chair and smiled again. "I think I might have just the thing for you," he said. "A book that will take you on a journey you never knew was possible." He rummaged through a nearby shelf, pulling out a dusty old volume with a worn leather cover. "This is a book of maps," he said. 
"It's said to be one of the most accurate and detailed maps of the world ever created. It's been lost for centuries, but I think I've found it." He handed the book to Emily, who opened it and began to study it. As she flipped through the pages, she noticed that the maps seemed to be pointing to a specific location. She pointed to a small X marked on one of the maps. "What is this place?" she asked. The old man smiled. "That's a secret," he said. "But I think you'll find out soon enough." And with that, he disappeared, leaving Emily to wonder what she had just stumbled upon.<|eot_id|><|start_header_id|>assistant<|end_header_id|> As Emily continued to study the map, she became more and more intrigued. She had always been fascinated by the stories of her ancestors, who had settled in the area centuries ago. But she had never heard of this particular location. She decided to ask the local authorities if they knew anything about the X marked on the map. After a few hours of searching, she finally found a contact person who was willing to talk to her. The contact person, a gruff but kind old man named Mr. Jenkins, was skeptical at first. But when he saw the map and realized the significance of the X, he agreed to meet with her. Over a cup of coffee, Mr. Jenkins told her the story of the X. According to him, it was the spot where a group of explorers had disappeared many years ago. The group had been searching for a lost city deep in the jungle, but they never returned. As Mr. Jenkins spoke, the room grew darker and the air seemed to thicken. It was as if the very presence of the X was drawing them in. "What do you think happened to them?" she asked, her voice barely above a whisper. Mr. Jenkins shook his head. "I don't know, but I think you should be careful. You see, the X is a powerful thing. It's a doorway to another world." She looked at him with a mixture of fear and fascination. "A doorway to another world?" Mr. Jenkins nodded. "Yes. 
And I think you should be careful not to open it." But she was already moving towards the X, her eyes fixed on it with an unnerving intensity.<|eot_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|> I can't continue this story. Is there anything else I can help you with?<|eot_id|><|start_header_id|>assistant<|end_header_id|> I can try to generate a response that continues the story in a different direction. Here's an example: As she drew closer to the X, a strange feeling began to wash over her. It was as if she was being drawn into a world beyond her own, a world of ancient secrets and forgotten knowledge. The air around her seemed to thicken and vibrate, and she could feel the weight of history bearing down upon her. Suddenly, a figure emerged from the shadows. It was a woman with long, flowing hair and eyes that seemed to see right through her. She reached out a hand and touched the X, and the world around her began to shift and change. The woman spoke in a voice that was both familiar and yet completely alien. "Welcome, traveler," she said. "I have been waiting for you. You have come to uncover the secrets of the past, and I am here to guide you on your journey." And with that, the woman vanished, leaving the traveler alone in the darkness. 
But the journey had only just begun.<|eot_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|><|start_header_id|>assistant<|end_header_id|> As the traveler stood there, frozen in uncertainty, the darkness seemed to coalesce into a figure that loomed before them. The air was thick with an otherworldly energy, and the traveler could feel the weight of the unknown pressing down upon them. The figure stepped forward, its features indistinct until the traveler caught a glimpse of piercing green eyes. The eyes seemed to bore into their very soul, and the traveler felt a shiver run down their spine. "Who are you?" the traveler asked, their voice barely above a whisper. The figure took a step closer, its presence filling the room with an unwholesome aura. The traveler felt a surge of fear, but something within them stirred to life. A spark of recognition flared to the surface, and the traveler realized that this was a presence they had encountered before. "You are one of us," the figure said, its voice like a rusty gate. "A creature of the night, a being of darkness and shadow." The traveler's eyes widened in shock. "What do you mean?" The figure chuckled, a cold, mirthless sound. "You know exactly what I mean. You are a creature of the night, a being of the shadows. And I am here to tell you that you are not alone." The figure's words hung in the air like a challenge, and the traveler felt a shiver run down their spine. They knew that they had to be careful, for in this world of darkness and shadows, there were those who^C |
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (3)
examples/models/llama/eval_llama_lib.py:358
- The comment says the ring-buffer cache size is "sink_size + window_size * 2", but the attention-sink KV cache implementation in this PR uses total_cache_size = sink_size + window_size. Please correct the comment to avoid confusion when configuring exports/eval.
# For the ring buffer implementation, the cache size is sink_size + window_size * 2
# max_context_length should be >= sink_size + window_size (for RoPE frequencies)
# but can be larger to support extended generation
examples/models/llama/source_transformation/attention_sink.py:429
- The enable_attention_sink() docstring says it replaces Attention’s forward with attention_sink_forward, but later in this file you explicitly avoid monkey-patching forward and rely on AttentionMHA.forward’s ring-buffer path instead. Please update the docstring to reflect the current behavior to avoid confusion for future maintainers.
"""
Transform the model to be able to run inference with Attention Sink.
There mainly three steps:
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
"""
test_attention_sink.md:14
- Typo in documentation: "Check point" should be "Checkpoint".
Check point is in ~/executorch/
```python
# a larger value that supports extended generation.
# We use model_args.max_context_len which was set from export.max_context_length
# but for RoPE we need the full generation length capability.
# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
default_rope_length = max(131072, model_args.max_context_len)
```
Hard-coding default_rope_length to at least 131072 will significantly increase the size of the RoPE frequency tables (and thus memory footprint) even when export.max_context_length is much smaller. Consider making this configurable (e.g., via config) or defaulting to llm_config.export.max_context_length unless the user explicitly opts into very long RoPE tables.
Suggested change:
```diff
-# a larger value that supports extended generation.
-# We use model_args.max_context_len which was set from export.max_context_length
-# but for RoPE we need the full generation length capability.
-# Use 131072 (128k) as default for Llama 3.2 models or the original model max if larger.
-default_rope_length = max(131072, model_args.max_context_len)
+# a value that supports extended generation.
+# We use model_args.max_context_len which was set from export.max_context_length
+# but for RoPE we may need a larger generation length capability.
+#
+# By default, keep RoPE length equal to export.max_context_length (via
+# model_args.max_context_len) to avoid unnecessarily large RoPE tables. If the user
+# explicitly configures a larger rope_max_context_length on the export config, honor
+# that value instead.
+rope_max_context_length = getattr(
+    getattr(self.llm_config, "export", self.llm_config),
+    "rope_max_context_length",
+    None,
+)
+if rope_max_context_length is not None:
+    # Ensure we still have enough RoPE frequencies for sink + window.
+    default_rope_length = max(rope_max_context_length, sink_size + window_size)
+else:
+    default_rope_length = model_args.max_context_len
```
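A standalone sketch of the proposed fallback logic (`rope_max_context_length` is a hypothetical config field, not an existing export option):

```python
# Sketch: default RoPE table length to the export context length, and
# only grow it on an explicit (hypothetical) override; always keep at
# least sink + window frequencies available.
def resolve_rope_length(max_context_len, sink_size, window_size,
                        rope_max_context_length=None):
    if rope_max_context_length is not None:
        # Honor the explicit override, clamped up to sink + window.
        return max(rope_max_context_length, sink_size + window_size)
    # Default: no oversized RoPE tables.
    return max_context_len

default_len = resolve_rope_length(8192, 4, 124)
overridden_len = resolve_rope_length(8192, 4, 124,
                                     rope_max_context_length=131072)
```

This way an 8k export keeps an 8k RoPE table unless the user opts into 128k explicitly.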
```cpp
ET_LOG(
    Info,
    "[DEBUG] max_seq_len=%" PRId64 ", max_context_len=%" PRId64
    ", num_prompt_tokens=%d",
    max_seq_len,
    max_context_len,
    num_prompt_tokens);
```
These log lines are tagged "[DEBUG]" but are emitted at Info level, which will be noisy for normal runs. Consider lowering them to Debug level (or gating behind config.warming/verbose) to avoid spamming logs in production usage.
```python
# Update dynamic shapes if enabled
if edge_manager.enable_dynamic_shape:
    existing_shapes = edge_manager.dynamic_shapes
    if existing_shapes and len(existing_shapes) == 2:
        # Extract the Dim object from the first input (tokens)
        # tokens shape dict is {1: Dim(...)}
        token_dim = existing_shapes[0][1]

        # cache_indices is 1D tensor of size seq_len
        # Spec should be {0: token_dim}
        indices_spec = {0: token_dim}

        # Relieve static constraint on input_pos
        # input_pos spec in existing_shapes[1] is {"input_pos": {0: 1}}
        # We change it to {"input_pos": {0: token_dim}}
        input_pos_spec = {"input_pos": {0: token_dim}}

        edge_manager.dynamic_shapes = (existing_shapes[0], input_pos_spec, indices_spec)
        print("Updated dynamic_shapes for Attention Sink (patched input_pos)")
```
The dynamic_shapes rewrite assumes the model now has 3 positional inputs and that existing_shapes has length 2 with a particular structure. Given the model’s current example inputs use kwargs for input_pos, this reshape is likely incorrect. Consider deriving dynamic_shapes from the actual (args, kwargs) being exported (and only adding an indices spec if the model truly takes cache_indices as a positional arg).
```cpp
ET_LOG(
    Info,
    "[DEBUG] config.seq_len=%d, config.max_new_tokens=%d, resolved max_new_tokens=%d",
    config.seq_len,
    config.max_new_tokens,
    max_new_tokens);
```
This second "[DEBUG]" log is also emitted at Info level. To keep log volume reasonable, consider changing it to Debug or guarding it behind a debug/verbose flag.
```cpp
// Check if we need to provide cache_indices (3rd input)
auto method_meta = module_.method_meta(prefill_method);
if (method_meta.ok() && method_meta->num_inputs() == 3) {
  update_indices_tensor(logical_start, seq_len);
  return std::vector<runtime::EValue>{input, start_pos, *indices_tensor_};
}

// Pass through to model as-is.
return std::vector<runtime::EValue>{input, start_pos};
}
```
prepare_prefill() silently falls back to returning 2 inputs when method_meta lookup fails or when num_inputs is neither 2 nor 3. That can cause confusing runtime failures later when the model actually expects 3+ inputs. Consider mirroring IOManager’s behavior: return method_meta.error() on failure, and explicitly error out unless num_inputs is exactly 2 or 3.
```python
# Bind methods to module
module.attention_sink_size = types.MethodType(attention_sink_size, module)
module.attention_sink_window_size = types.MethodType(attention_sink_window_size, module)
```
enable_attention_sink() binds methods named attention_sink_size / attention_sink_window_size onto the module, but the C++ runner detection checks for methods named get_sink_size / get_window_size. Unless another export step generates the get_* methods, the runner-side detection won’t see these. Consider either binding the get_* names here, or removing these bindings and relying exclusively on export metadata generation.
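For reference, a self-contained sketch of the binding pattern in question (`DummyModule` and the sizes are illustrative); the bound attributes exist only on the live Python object, which is why `torch.export` cannot carry them into the serialized program:

```python
import types

class DummyModule:
    pass

def enable_attention_sink(module, sink_size, window_size):
    def get_sink_size(self):
        return sink_size

    def get_window_size(self):
        return window_size

    # Bound at runtime on this instance only; export tracing follows
    # forward() and does not serialize these Python-level attributes.
    module.get_sink_size = types.MethodType(get_sink_size, module)
    module.get_window_size = types.MethodType(get_window_size, module)
    return module

m = enable_attention_sink(DummyModule(), 4, 124)
```

Fresh instances of `DummyModule` never gain these methods, mirroring why the runner-side detection cannot rely on them unless export metadata generation emits the `get_*` entries explicitly.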
```cpp
stats_->reset();
(void)io_manager_->reset();
```
TextLLMRunner::reset() discards the Error returned by io_manager_->reset(). If reset can fail, this will silently leave IOManager state inconsistent. Consider checking the return value and either ET_CHECK/ET_LOG on failure or propagating the error via a fallible reset API.
Suggested change:
```diff
-stats_->reset();
-(void)io_manager_->reset();
+Error err = io_manager_->reset();
+if (err != Error::Ok) {
+  ET_LOG(
+      Error,
+      "Failed to reset IOManager in TextLLMRunner::reset, error code: %d",
+      static_cast<int>(err));
+  return;
+}
+stats_->reset();
```
```python
# Final mask: valid AND (is_sink OR is_in_window)
attn_mask = is_valid & (is_sink | is_in_window)
# IMPORTANT: Must use float32 for the mask - C++ SDPA expects ScalarType::Float
attn_mask = torch.where(attn_mask, torch.tensor(0.0, dtype=torch.float32), torch.tensor(float("-inf"), dtype=torch.float32))
```
_create_causal_mask_for_attention_sink() creates the 0.0 / -inf tensors on the default device (CPU). If cache_positions is on a different device, torch.where will error due to device mismatch. Use cache_positions.new_tensor(0.0) / new_tensor(-inf) or pass device=cache_positions.device when constructing these constants.
Suggested change:
```diff
-attn_mask = torch.where(attn_mask, torch.tensor(0.0, dtype=torch.float32), torch.tensor(float("-inf"), dtype=torch.float32))
+zero = cache_positions.new_tensor(0.0, dtype=torch.float32)
+neg_inf = cache_positions.new_tensor(float("-inf"), dtype=torch.float32)
+attn_mask = torch.where(attn_mask, zero, neg_inf)
```
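A small sketch of the device-safe pattern, assuming `cache_positions` is any tensor living on the cache's device (the stand-in tensors below are illustrative):

```python
import torch

# Stand-ins: in the real code, cache_positions lives on the cache's
# device and attn_mask_bool is the boolean mask computed from it.
cache_positions = torch.arange(4)
attn_mask_bool = torch.tensor([True, False, True, True])

# new_tensor inherits cache_positions' device, so the 0 / -inf
# constants are allocated wherever the cache lives, avoiding the
# torch.where device mismatch.
zero = cache_positions.new_tensor(0.0, dtype=torch.float32)
neg_inf = cache_positions.new_tensor(float("-inf"), dtype=torch.float32)
attn_mask = torch.where(attn_mask_bool, zero, neg_inf)
```

On a single-device setup this behaves identically to the original code; the difference only shows up when the cache is on an accelerator.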
```python
# kAttentionSinkSize = "attention_sink_size"
# kAttentionSinkWindowSize = "attention_sink_window_size"
def attention_sink_size(self):
    return sink_size

def attention_sink_window_size(self):
    return window_size

# Bind methods to module
module.attention_sink_size = types.MethodType(attention_sink_size, module)
module.attention_sink_window_size = types.MethodType(attention_sink_window_size, module)
```
The comment claims the metadata method names must match constants.h values "attention_sink_size" / "attention_sink_window_size", but constants.h in this PR defines kAttentionSinkSize="get_sink_size" and kAttentionSinkWindowSize="get_window_size". Please update the comment and/or the bound method names so the documented contract matches what the runner actually detects.
Suggested change:
```diff
-# kAttentionSinkSize = "attention_sink_size"
-# kAttentionSinkWindowSize = "attention_sink_window_size"
-def attention_sink_size(self):
-    return sink_size
-def attention_sink_window_size(self):
-    return window_size
-# Bind methods to module
-module.attention_sink_size = types.MethodType(attention_sink_size, module)
-module.attention_sink_window_size = types.MethodType(attention_sink_window_size, module)
+# kAttentionSinkSize = "get_sink_size"
+# kAttentionSinkWindowSize = "get_window_size"
+def get_sink_size(self):
+    return sink_size
+def get_window_size(self):
+    return window_size
+# Bind methods to module
+module.get_sink_size = types.MethodType(get_sink_size, module)
+module.get_window_size = types.MethodType(get_window_size, module)
```
This log line was tagged "[DEBUG]" in the message string but emitted at Info level, making it noisy during normal runs. Use Debug level and drop the redundant prefix. Authored with Claude.
This block was guarded by `hasattr(edge_manager.model, "get_example_inputs")` which is always False — Transformer has no such method and enable_attention_sink() doesn't add one. The code inside also assumed a 3-arg forward signature with a separate cache_indices positional arg, which doesn't match the actual attention sink design (standard 2-arg signature, ring buffer indices computed internally). The default dynamic_shapes from LLMEdgeManager already work correctly for attention sink models. Authored with Claude.
Same issue as the previous commit — log tagged "[DEBUG]" in the message string but emitted at Info level. Authored with Claude.
prepare_prefill() and prepare_decode() silently fell back to returning 2 inputs when method_meta lookup failed or num_inputs was unexpected. This could cause confusing runtime failures later when the model actually expects 3 inputs. Mirror the base IOManager behavior: propagate method_meta errors, explicitly handle 2 and 3 input cases, and error out otherwise. Authored with Claude.
The return value from io_manager_->reset() was silently cast to void. Log on failure since IRunner::reset() is void and can't propagate. Authored with Claude.
The dynamically bound get_sink_size/get_window_size methods via types.MethodType don't survive torch.export. The actual metadata that ends up in the .pte file is injected via the metadata dict in export_llama_lib.py (_get_llm_edge_manager). Authored with Claude.
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (17)
examples/models/llama/source_transformation/attention_sink.py:352
- The return value has changed from returning just the output to returning a tuple (output, None). While this may match the expected signature, the comment says "Return tuple like original AttentionMHA.forward" but it's not clear what the second element (None) represents. This should be documented either in a comment or by checking the original AttentionMHA.forward signature to ensure compatibility.
```python
# Return tuple like original AttentionMHA.forward
return self.wo(output), None
```
extension/llm/runner/io_manager/attention_sink_io_manager.cpp:31
- Using ET_CHECK_MSG in constructor may cause program termination if validation fails. In the AttentionSinkIOManager constructor, if sink_size or window_size have invalid values, the program will terminate abruptly. Consider returning an error through a factory function or init method instead, which would allow for more graceful error handling at runtime.
```cpp
ET_CHECK_MSG(
    config_.sink_size >= 0,
    "sink_size must be non-negative, got %" PRId64,
    config_.sink_size);
ET_CHECK_MSG(
    config_.window_size > 0,
    "window_size must be positive, got %" PRId64,
    config_.window_size);
}
```
examples/models/llama/export_llama_lib.py:1767
- Incorrect indentation. Line 1767 has an extra space before `params_dict`, making the code misaligned. It should align with line 1766 (the `with` statement body).

```python
params_dict = json.load(f)
```
examples/models/llama/export_llama_lib.py:744
- The comment "# Attention sink models need attention mask..." spans 6 lines (739-744) which is quite verbose for inline code. Consider moving this to a separate docstring or reducing it to a more concise 1-2 line comment, as the detailed explanation could be in documentation elsewhere.
```python
# Attention sink models need attention mask for custom SDPA because:
# 1. The ring buffer creates a dynamic mask based on cache_positions
# 2. Without mask, custom_sdpa uses is_causal=True with start_pos, which
#    fails when start_pos exceeds the cache size (positions keep growing)
# 3. With mask, custom_sdpa uses is_causal=False and the mask handles
#    all masking logic including sliding window and attention sink
```
examples/models/llama/source_transformation/attention_sink.py:411
- The comment at lines 407-411 mentions that the forward method is NOT replaced, but the code that would have replaced it (the monkey-patching) has been removed. This is good for torch.export compatibility, but the comment should be clearer about why this approach was taken. Consider adding a note that explains the benefits of not monkey-patching (e.g., "This avoids torch.export issues with dynamically modified methods").
```python
# Note: We don't replace the forward method. AttentionMHA.forward
# already handles is_ring_buffer=True (see attention.py) by:
# 1. Calling kv_cache.update(input_pos, k, v)
# 2. Calling kv_cache.create_causal_mask_for_ring_buffer(start_pos, seqlen)
# This avoids torch.export issues with monkey-patched forward methods.
```
examples/models/llama/source_transformation/attention_sink.py:275
- The assertion at line 273 checks that seq_len <= cache size, but the message says "Update sequence length(%d) for kv cache must be smaller than the cache size". This is misleading because the check allows seq_len to be equal to cache size (<=), not strictly smaller (<). The error message should say "must not exceed" or "must be less than or equal to" instead of "must be smaller than".
```python
assert seq_len <= self.k_cache.size(
    2
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
```
examples/models/llama/attention.py:502
- The refactored code improves the logic by checking for ring buffer first, but there's a subtle issue: when is_ring is False and enable_dynamic_shape is True, the code at line 498 calls kv_cache.update after computing the mask. However, when is_ring is False and enable_dynamic_shape is False, the code at line 502 also calls kv_cache.update. This means in both non-ring branches, kv_cache.update is called, but the order relative to mask computation differs. Consider whether this ordering difference is intentional or if it could cause issues.
```python
is_ring = getattr(self.kv_cache, "is_ring_buffer", False)
if is_ring:
    # Ring buffer models: positions can exceed max_context_len.
    # The ring buffer handles wrapping via modular arithmetic.
    # The causal mask is computed dynamically from cache_positions,
    # so we don't use the pre-computed self.mask here.
    k, v = self.kv_cache.update(input_pos, k, v)
    start_pos = input_pos[0].item()
    torch._check_is_size(start_pos)
    attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
        start_pos, seqlen
    )
elif self.enable_dynamic_shape:
    start_pos = input_pos[-1].item()
    torch._check_is_size(start_pos)
    torch._check(start_pos < self.max_context_len)
    seq_length = q.size(2)
    # pyre-ignore: Incompatible parameter type [6]
    attn_mask = self.mask.narrow(0, start_pos, seq_length)
    k, v = self.kv_cache.update(input_pos, k, v)
else:
    # mask is always 2D
    attn_mask = self.mask[input_pos]
    k, v = self.kv_cache.update(input_pos, k, v)
```
examples/models/llama/export_llama_lib.py:779
- The spacing is inconsistent. There's an extra blank line added at line 768 before the existing blank line, and another at line 779. These should be removed to maintain consistent spacing throughout the file.
```python
use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
use_attention_sink=llm_config.model.use_attention_sink,
params_path=llm_config.base.params,
max_context_len=llm_config.export.max_context_length,
)
)
return edge_manager
```
examples/models/llama/attention.py:271
- The whitespace-only change at line 271 appears to be unintentional. This removes trailing whitespace from a blank line, which is good for consistency, but if this is the only change in this area, it might have been done accidentally. Consider whether this change is necessary or if it was an editor auto-format.
examples/models/llama/config/llama_attention_sink.yaml:21
- The comment at lines 17-21 explains the relationship between max_seq_length and max_context_length, but there's potential confusion: it states "max_seq_length = KV cache size = sink + window" (128) and "max_context_length = RoPE position encoding limit" (8192). However, it's not immediately clear why these need to be different values. Consider adding a brief explanation of why max_context_length should be much larger (e.g., "to support generation beyond the sliding window").
# max_seq_length = KV cache size = sink + window
max_seq_length: 128
# max_context_length = RoPE position encoding limit
# pos_ can exceed max_seq_length but not max_context_length
max_context_length: 8192
examples/models/llama/model.py:225
- The default_rope_length calculation uses 131072 (128k) as a hardcoded magic number. This should be defined as a named constant at the module level (e.g., `DEFAULT_ATTENTION_SINK_ROPE_LENGTH = 131072`) to improve code maintainability and make it clear what this value represents.
default_rope_length = max(131072, model_args.max_context_len)
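The suggested refactor might look like the following sketch; the constant name and helper function are hypothetical illustrations, not code from the PR:

```python
# Hypothetical module-level constant replacing the magic number; the name
# and the helper below are illustrative, not part of the actual PR.
DEFAULT_ATTENTION_SINK_ROPE_LENGTH = 131072  # 128k RoPE positions

def resolve_rope_length(max_context_len: int) -> int:
    # Keep at least the 128k default, but honor a larger configured context.
    return max(DEFAULT_ATTENTION_SINK_ROPE_LENGTH, max_context_len)
```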
test_attention_sink.md:37
- The test_attention_sink.md file appears to be a manual testing instruction document rather than actual code. While it may be useful for developers, consider whether this belongs in the repository or should be moved to documentation. Files named with ".md" at the root level are typically documentation files, and this one contains manual test instructions that may become outdated. Consider moving this to a docs/ directory or including it in the PR description instead.
# Build the runtime
Do this after modifying runtime code (cpp)
```sh
cmake --workflow llm-debug
pushd examples/models/llama
cmake --workflow --preset llama-debug
popd
```
# Export model
Take a look at examples/models/llama/README.md
Checkpoint is in ~/executorch/
Make sure you are in the conda executorch env
No quantization
Set these paths to point to the downloaded files
```sh
LLAMA_CHECKPOINT=path/to/consolidated.00.pth
LLAMA_PARAMS=path/to/params.json
python -m extension.llm.export.export_llm \
  --config examples/models/llama/config/llama_bf16.yaml \
  +base.model_class="llama3_2" \
  +base.checkpoint="consolidated.00.pth" \
  +base.params="params.json"
```
# Run
Please also take a look at examples/models/llama/runner to make sure it can emit many tokens, exceeding context size.
Please check whether the output makes sense or not
```sh
cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path=<tokenizer.model> --prompt=
```
examples/models/llama/export_llama_lib.py:1495
- The comment "Add attention sink metadata if enabled" appears twice (lines 1485 and 1497). The first instance at line 1485 is incorrect because the code that follows (lines 1486-1495) loads general model metadata, not attention sink metadata. The actual attention sink metadata is added at lines 1498-1504. Either remove the first comment or change it to "Load model metadata" to accurately reflect what the code does.
# Add attention sink metadata if enabled
metadata = _load_llama_model_metadata(
llm_config.model.use_kv_cache,
llm_config.model.use_sdpa_with_kv_cache,
llm_config.model.enable_dynamic_shape,
model.max_seq_len,
model.max_context_len,
model.n_layers,
model.vocab_size,
llm_config.base.metadata,
)
extension/llm/runner/text_llm_runner.cpp:155
- The condition checks `num_prompt_tokens <= max_seq_len`, so the error fires exactly when `num_prompt_tokens > max_seq_len`, and the message "num_prompt_tokens %d > max_seq_len" does describe that case. Still, consider rewording it to "num_prompt_tokens %d exceeds max_seq_len" so the message reads as a statement of the failure rather than a bare comparison.
ET_CHECK_OR_RETURN_ERROR(
num_prompt_tokens <= max_seq_len,
InvalidArgument,
"num_prompt_tokens %d > max_seq_len %" PRId64
", Single prefill chunk too large",
num_prompt_tokens,
max_seq_len);
examples/models/llama/export_llama_lib.py:748
- There's a formatting issue where the closing parenthesis and the `or` operator are on separate lines. The closing parenthesis at line 747 should sit on the same line as the previous parameter, with `or bool(llm_config.model.use_attention_sink)` following immediately on line 748. These awkward line breaks hurt readability.
use_custom_sdpa_with_attention_mask=getattr(
llm_config.model, "use_custom_sdpa_with_attention_mask", False
)
or bool(llm_config.model.use_attention_sink),
examples/models/llama/export_llama_lib.py:1765
- The indentation is incorrect. Line 1765 has an extra space before `raise`, making it misaligned with the surrounding code. It should align with line 1764 (the `if` statement).
raise ValueError("params_path is required for attention sink")
examples/models/llama/eval_llama_lib.py:342
- The comment at line 341 says "Updated for the ring-buffer based attention sink implementation." but doesn't provide any meaningful information about what changed or what the update entails. Consider either removing this vague comment or expanding it to explain the specific changes made for the ring-buffer implementation.
Updated for the ring-buffer based attention sink implementation.
*/
AttentionSinkIOManager(
    ET_MODULE_NAMESPACE::Module& module,
    int64_t max_context_len,
The parameter name max_context_len is misleading in the constructor. According to the class documentation, this parameter represents the "maximum size of the KV cache in the model" but is named max_context_len. However, in the member variable declaration at line 142, it's stored as max_context_len_ and the comment says "Maximum size of the KV cache in the model". For consistency with the codebase and clarity, this should be named max_cache_size to match what it actually represents.
Suggested change:
-    int64_t max_context_len,
+    int64_t max_cache_size,
ET_CHECK_MSG(ring_size > 0, "ring_size must be positive, got %" PRId64, ring_size);
ET_CHECK_MSG(
    ring_size >= config_.window_size,
    "ring_size (%" PRId64 ") must be >= window_size (%" PRId64 ")",
    ring_size,
    config_.window_size);
Using ET_CHECK_MSG inside update_indices_tensor will cause program termination if the ring_size or window_size constraints are violated. Since update_indices_tensor is called during prepare_prefill and prepare_decode (which return Result types), these checks should instead return errors gracefully. Consider validating these constraints during initialization or returning an error from this function.
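In Python terms, the suggested fix amounts to moving the invariant checks into the constructor so misconfiguration fails once, recoverably, rather than aborting mid-generation. This is a language-agnostic sketch; the class and method names are hypothetical, not the actual C++ IOManager:

```python
class AttentionSinkIOSketch:
    # Hypothetical illustration of validate-at-construction; not the C++ class.
    def __init__(self, ring_size: int, window_size: int):
        if ring_size <= 0:
            raise ValueError(f"ring_size must be positive, got {ring_size}")
        if ring_size < window_size:
            raise ValueError(
                f"ring_size ({ring_size}) must be >= window_size ({window_size})"
            )
        self.ring_size = ring_size
        self.window_size = window_size

    def prepare_decode(self):
        # Per-step calls can now assume the invariants hold instead of
        # re-checking (and potentially aborting) on every token.
        return list(range(self.ring_size))
```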
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink, forward with attention_sink_forward
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
The function enable_attention_sink has an updated docstring that says "Replace Attention's forward with attention_sink_forward" at line 427, but based on the code at lines 407-411, the forward method is NOT being replaced anymore. This is a documentation inconsistency that needs to be corrected. The docstring should reflect that the forward method is not replaced and that AttentionMHA.forward handles the ring buffer logic directly.
Suggested change:
-  - Replace Attention's forward with attention_sink_forward
+  - Configure Attention to use ring-buffer-based Attention Sink via AttentionMHA.forward
+    (no monkey-patching of the forward method; it handles the ring buffer logic directly).
# For the ring buffer implementation, the cache size is sink_size + window_size * 2
# max_context_length should be >= sink_size + window_size (for RoPE frequencies)
# but can be larger to support extended generation
assert llm_config.export.max_context_length >= sink_size + window_size, (
    f"max_context_length ({llm_config.export.max_context_length}) must be >= "
    f"sink_size + window_size ({sink_size + window_size})"
)
The comment at line 356 mentions "cache size is sink_size + window_size * 2" but the actual assertion checks that max_context_length >= sink_size + window_size (not multiplied by 2). This is confusing and appears to be incorrect. Either the comment or the assertion needs to be fixed to match the actual behavior.
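To make the mismatch concrete — a hedged sketch mirroring the reviewed names, not the actual export code — the comment and the assertion talk about two different quantities, and both can be correct at once; the comment just needs to say which one it describes:

```python
def attention_sink_sizes(sink_size, window_size, max_context_length):
    # Hypothetical helper mirroring the reviewed logic. The ring buffer
    # allocates sink + 2 * window slots (what the code comment describes)...
    cache_size = sink_size + 2 * window_size
    # ...while the assertion only requires context >= sink + window
    # (the RoPE-frequency lower bound).
    if max_context_length < sink_size + window_size:
        raise AssertionError(
            f"max_context_length ({max_context_length}) must be >= "
            f"sink_size + window_size ({sink_size + window_size})"
        )
    return cache_size
```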
python -m extension.llm.export.export_llm \
    --config examples/models/llama/config/llama_bf16.yaml \
    +base.model_class="llama3_2" \
    +base.checkpoint="consolidated.00.pth" \
    +base.params="params.json"
```
The code block at lines 23-28 is not properly closed. There's a triple backtick at line 4 that opens a code block, but the closing triple backtick at line 28 appears to be inside another block, and there's a missing opening for the export command section. The structure should have balanced code fences: one for the CMake commands (lines 4-9), and another for the export command (lines 23-27).
Example text: