Conversation
> **Review skipped** — draft detected. Check the settings in the CodeRabbit UI or the ⚙️ Run configuration. Configuration used: `.coderabbit.yaml` · Review profile: CHILL · Plan: Pro
📝 **Walkthrough**

This pull request introduces a new Triton-based sparse attention backend supporting 2:4 structured sparsity for efficient LLM inference. It adds Triton kernels, a sparse attention method, configuration options, HuggingFace integration, and comprehensive GPU tests to enable high-performance sparse prefill and decode paths.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    participant User as User/CLI
    participant Config as Config/Conversion
    participant Kernel as Kernel Registration
    participant Model as HF Model
    participant Forward as Forward Pass
    User->>Config: select sparse24_triton backend
    activate Config
    Config->>Config: validate backend="triton"
    Config->>Kernel: _register_triton_backend_if_needed()
    deactivate Config
    activate Kernel
    Kernel->>Kernel: register_triton_attention()
    Kernel->>Model: set attn_implementation="modelopt_triton"
    Kernel->>Model: patch attention interface
    deactivate Kernel
    User->>Model: forward pass (prefill/decode)
    activate Model
    Model->>Forward: dispatch via AttentionInterface
    deactivate Model
    activate Forward
    Forward->>Forward: get_sparse_context() from Sparse24Triton
    Forward->>Forward: apply 2:4 sparsity mask
    Forward->>Forward: unified_attention (2D/3D path)
    Forward-->>Forward: return sparse attention output
    deactivate Forward
    Forward-->>User: sparse attention result
```
```mermaid
sequenceDiagram
    participant Scores as Attention Scores
    participant Method as Sparse24Triton Method
    participant Mask as Sparsity Mask
    participant Triton as Triton Kernel
    Scores->>Method: calculate_sparsity(scores)
    activate Method
    Method->>Mask: _sparse24_mask_along_last_dim()
    activate Mask
    Mask->>Mask: select top-2 per 4-group
    Mask-->>Method: binary sparsity mask
    deactivate Mask
    Method-->>Scores: stats dict (phase, counts)
    deactivate Method
    Scores->>Method: apply_sparsity(scores, mask)
    activate Method
    Method->>Scores: mask * scores
    Method-->>Scores: sparse scores
    deactivate Method
    Scores->>Triton: context_attention_fwd(q, k, v, mask)
    activate Triton
    Triton->>Triton: detect packed vs. unpacked layout
    alt packed layout
        Triton->>Triton: derive seq_lens, pack tensors
        Triton->>Triton: unified_attention (2D kernel)
        Triton->>Triton: unpack output
    else unpacked layout
        Triton->>Triton: reshape for segment layout
        Triton->>Triton: unified_attention (3D kernel + buffers)
        Triton->>Triton: reshape back
    end
    Triton-->>Scores: output attention
    deactivate Triton
```
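The "select top-2 per 4-group" step above can be sketched in plain PyTorch. This is a minimal dense reference for the masking logic, not the Triton kernel; the function name mirrors `_sparse24_mask_along_last_dim` from the diagram but is a hypothetical reimplementation:

```python
import torch

def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
    """Binary mask keeping the 2 largest-magnitude entries in every group of 4."""
    *lead, n = scores.shape
    assert n % 4 == 0, "last dim must be a multiple of 4 for 2:4 sparsity"
    groups = scores.reshape(*lead, n // 4, 4)
    # Indices of the two largest magnitudes within each group of 4.
    top2 = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, True)  # mark the kept positions
    return mask.reshape(*lead, n)

scores = torch.tensor([[4.0, 1.0, 3.0, 2.0, -5.0, 0.5, 0.1, 2.0]])
mask = sparse24_mask_along_last_dim(scores)
# group [4, 1, 3, 2] keeps indices 0 and 2; group [-5, 0.5, 0.1, 2] keeps 0 and 3
```

Applying the mask as `mask * scores` (the `apply_sparsity` step) then zeroes exactly half the entries per group, which is what the 2:4 hardware/kernel format exploits.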
**Estimated code review effort:** 🎯 4 (Complex) · ⏱️ ~60 minutes

🚥 **Pre-merge checks:** ✅ 3 passed
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)
**Lines 249-255:** ⚠️ Potential issue | 🟡 Minor
**`--backend` argument is parsed but never used.**

`args.backend` is not referenced anywhere in `main()` or the helper functions; the backend is already embedded in each `SPARSE_ATTN_CFG_CHOICES` entry. Either wire it up to override the `backend` key in the resolved config, or remove the argument.

🛡️ Minimal fix to either wire it up or remove dead code
Option A — wire it up to override the config backend:

```diff
 sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
+# Override backend from CLI if explicitly specified
+sparse_config = copy.deepcopy(sparse_config)
+for key, cfg in sparse_config.get("sparse_cfg", {}).items():
+    if isinstance(cfg, dict) and "backend" in cfg:
+        cfg["backend"] = args.backend
 # Override calibration options if provided via CLI
```

Option B — remove the unused argument:

```diff
-parser.add_argument(
-    "--backend",
-    type=str,
-    default="pytorch",
-    choices=["pytorch", "triton"],
-    help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
-)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 249 - 255, The --backend arg added via parser.add_argument is parsed into args.backend but never used; either remove that parser entry or propagate args.backend into the sparse-attention config before it is used. To fix, in main() (where args is available and the config is selected from SPARSE_ATTN_CFG_CHOICES and/or built into a resolved_cfg) overwrite the backend key with args.backend (e.g., set resolved_cfg["backend"] = args.backend) so the chosen backend actually takes effect, or remove the parser.add_argument("--backend", ...) line to eliminate the dead argument.
🧹 Nitpick comments (2)
pyproject.toml (1)
**Line 84:** Consider adding `E731` for consistency with sibling Triton entries.

The entries for `modelopt/torch/quantization/triton/*` (line 83) and `examples/deepseek/ds_kernel.py` (line 85) both include `E731` in their suppression rules, while the new entry omits it. While no lambda expressions currently exist in the kernel files, adding `E731` maintains consistency across the pattern established for Triton-based code and future-proofs against potential additions.

🔧 Proposed fix for consistency

```diff
-"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806"]  # triton kernel style
+"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806", "E731"]  # triton kernel style
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` at line 84, add the missing E731 suppression to the Triton kernel entry `"modelopt/torch/sparsity/attention_sparsity/kernels/*"` so it matches the sibling Triton entries; update the rule that currently lists `["N803", "N806"]` to include `"E731"` as well to maintain consistency and future-proof against lambda usage in the kernels.

modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
**Lines 73-83:** `get_sparse_context` should be `@abstractmethod`.

`SparseAttentionModule.forward()` calls `get_sparse_context()` unconditionally — every registered method must implement it. All peer methods (`calculate_sparsity`, `apply_sparsity`, `name`) are `@abstractmethod`; this should be too. Without the decorator, a subclass can be instantiated without implementing it, silently deferring the `NotImplementedError` to runtime.

♻️ Proposed fix

```diff
-    def get_sparse_context(self, module: torch.nn.Module):
-        """Return a context manager that activates this method's sparsity during forward.
-
-        Each method subclass implements its own activation mechanism:
-        - Softmax-patching methods replace F.softmax during the forward pass.
-        - Kernel-fused methods set flags on ``module`` that the kernel reads.
-
-        Args:
-            module: The SparseAttentionModule wrapping the attention layer.
-        """
-        raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")
+    @abstractmethod
+    def get_sparse_context(self, module: torch.nn.Module):
+        """Return a context manager that activates this method's sparsity during forward.
+
+        Each method subclass implements its own activation mechanism:
+        - Softmax-patching methods replace F.softmax during the forward pass.
+        - Kernel-fused methods set flags on ``module`` that the kernel reads.
+
+        Args:
+            module: The SparseAttentionModule wrapping the attention layer.
+        """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines 73 - 83, Mark get_sparse_context as an abstract method like the other abstract APIs so subclasses cannot be instantiated without implementing it: add the `@abstractmethod` decorator above the get_sparse_context method in the registry (matching how calculate_sparsity, apply_sparsity, and name are decorated) and ensure abstractmethod is imported if not already; this enforces that any subclass used by SparseAttentionModule.forward() must implement get_sparse_context.
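The enforcement this nitpick asks for is standard `abc` behavior: with the decorator, an incomplete subclass fails at construction time instead of deferring a `NotImplementedError` to the first forward pass. A minimal illustration with hypothetical class names (not the actual registry code):

```python
from abc import ABC, abstractmethod

class SparseAttentionMethod(ABC):
    """Stand-in for the method base class in the registry."""

    @abstractmethod
    def calculate_sparsity(self, scores): ...

    @abstractmethod
    def get_sparse_context(self, module): ...

class IncompleteMethod(SparseAttentionMethod):
    # Implements one abstract API but forgets get_sparse_context.
    def calculate_sparsity(self, scores):
        return scores

try:
    IncompleteMethod()
except TypeError as exc:
    # abc rejects instantiation, naming the missing abstract method.
    print(f"rejected at construction: {exc}")
```

This is why the decorator matters more than the `raise NotImplementedError` body: the error surfaces where the subclass is defined and instantiated, not deep inside `SparseAttentionModule.forward()`.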
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 149-158: The before/after comparison is invalid because the model
is loaded with the default attention backend (e.g., "sdpa") via
AutoModelForCausalLM.from_pretrained but after mtsa.sparsify()
validate_eager_attention forces "eager", so differences mix backend changes with
sparsity effects; fix by either (1) explicitly passing
attn_implementation="eager" into AutoModelForCausalLM.from_pretrained when the
sparse config indicates a PyTorch backend/flash_skip_softmax path (detect via
the sparsity config or args), or (2) add a clear comment in the block around
AutoModelForCausalLM.from_pretrained / mtsa.sparsify() documenting this
limitation and that comparisons for flash_skip_softmax should load with
attn_implementation set to "eager" to ensure a fair baseline.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Line 429: Update the inaccurate comment string "# 2:4 structured sparsity via
Triton prefill kernel (prefill-only)" to indicate the Triton kernel supports
both prefill (2D) and decode (3D) paths with the paged KV cache; locate the
comment in the attention sparsity config where "# 2:4 structured sparsity via
Triton prefill kernel (prefill-only)" appears and change it to something like "#
2:4 structured sparsity via unified Triton kernel (supports prefill 2D and
decode 3D with paged KV cache)" so it correctly documents the kernel
capabilities.
In `@modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py`:
- Around line 142-150: get_sparse_context currently sets module._apply_sparse24
and module._skip_diagonal_blocks but only resets _apply_sparse24 on exit;
preserve and restore the original _skip_diagonal_blocks value to avoid mutating
module state. Fix get_sparse_context by reading the original value (use getattr
to handle missing attribute), set module._skip_diagonal_blocks =
self.skip_diagonal_blocks on entry, and in the finally block restore the
original value (use setattr or delattr if the attribute did not exist
originally). Ensure you still clear _apply_sparse24 as before and reference
get_sparse_context, module._apply_sparse24, module._skip_diagonal_blocks, and
self.skip_diagonal_blocks when implementing the change.
---
Outside diff comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 249-255: The --backend arg added via parser.add_argument is parsed
into args.backend but never used; either remove that parser entry or propagate
args.backend into the sparse-attention config before it is used. To fix, in
main() (where args is available and the config is selected from
SPARSE_ATTN_CFG_CHOICES and/or built into a resolved_cfg) overwrite the backend
key with args.backend (e.g., set resolved_cfg["backend"] = args.backend) so the
chosen backend actually takes effect, or remove the
parser.add_argument("--backend", ...) line to eliminate the dead argument.
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py`:
- Around line 73-83: Mark get_sparse_context as an abstract method like the
other abstract APIs so subclasses cannot be instantiated without implementing
it: add the `@abstractmethod` decorator above the get_sparse_context method in the
registry (matching how calculate_sparsity, apply_sparsity, and name are
decorated) and ensure abstractmethod is imported if not already; this enforces
that any subclass used by SparseAttentionModule.forward() must implement
get_sparse_context.
In `@pyproject.toml`:
- Line 84: Add the missing E731 suppression to the Triton kernel entry
"modelopt/torch/sparsity/attention_sparsity/kernels/*" so it matches the sibling
Triton entries; update the rule that currently lists ["N803", "N806"] to include
"E731" as well to maintain consistency and future-proof against lambda usage in
the kernels.
```python
# Select attn_implementation based on sparse method:
# - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa)
# - sparse24_triton requires "modelopt_triton" (fused Triton kernel)
# No need to specify attn_implementation here — mtsa.sparsify() handles it
# automatically based on the sparse config (sets "modelopt_triton" for triton
# backend, keeps "eager" for pytorch backend).
model = AutoModelForCausalLM.from_pretrained(
    args.pyt_ckpt_path,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
```
Before/after comparison uses different attention backends for flash_skip_softmax.
Before sparsify() the model runs with whatever attn_implementation was selected at load time (likely "sdpa"); after sparsify() validate_eager_attention forces "eager". Any output difference now conflates sparsity effects with the SDPA → eager backend switch. For the sparse24_triton path this is less of a concern, but the skip_softmax path should still load with a consistent backend for a meaningful comparison.
Consider documenting this limitation in the comment block at lines 149-154, or conditionally set attn_implementation="eager" when the config uses a pytorch backend:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 149 - 158,
The before/after comparison is invalid because the model is loaded with the
default attention backend (e.g., "sdpa") via
AutoModelForCausalLM.from_pretrained but after mtsa.sparsify()
validate_eager_attention forces "eager", so differences mix backend changes with
sparsity effects; fix by either (1) explicitly passing
attn_implementation="eager" into AutoModelForCausalLM.from_pretrained when the
sparse config indicates a PyTorch backend/flash_skip_softmax path (detect via
the sparsity config or args), or (2) add a clear comment in the block around
AutoModelForCausalLM.from_pretrained / mtsa.sparsify() documenting this
limitation and that comparisons for flash_skip_softmax should load with
attn_implementation set to "eager" to ensure a fair baseline.
```python
    },
}


# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
```
Comment says "prefill-only" but the kernel supports both prefill and decode.
The PR description explicitly states the unified Triton kernel supports both prefill (2D kernel) and decode (3D kernel) paths with paged KV cache. The comment at line 429 is inaccurate and should be corrected to avoid misleading users.
📝 Proposed fix
```diff
-# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
+# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` at line 429, Update the
inaccurate comment string "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" to indicate the Triton kernel supports both prefill (2D) and
decode (3D) paths with the paged KV cache; locate the comment in the attention
sparsity config where "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" appears and change it to something like "# 2:4 structured
sparsity via unified Triton kernel (supports prefill 2D and decode 3D with paged
KV cache)" so it correctly documents the kernel capabilities.
```python
@contextlib.contextmanager
def get_sparse_context(self, module: torch.nn.Module):
    """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
    module._apply_sparse24 = True
    module._skip_diagonal_blocks = self.skip_diagonal_blocks
    try:
        yield
    finally:
        module._apply_sparse24 = False
```
_skip_diagonal_blocks is not restored in the finally block.
get_sparse_context sets both _apply_sparse24 and _skip_diagonal_blocks on the module but only resets _apply_sparse24 on exit. If a module had a different _skip_diagonal_blocks value before entering the context, it will be silently overwritten.
Proposed fix
```diff
 @contextlib.contextmanager
 def get_sparse_context(self, module: torch.nn.Module):
     """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
+    prev_sparse24 = getattr(module, "_apply_sparse24", False)
+    prev_skip_diag = getattr(module, "_skip_diagonal_blocks", True)
     module._apply_sparse24 = True
     module._skip_diagonal_blocks = self.skip_diagonal_blocks
     try:
         yield
     finally:
-        module._apply_sparse24 = False
+        module._apply_sparse24 = prev_sparse24
+        module._skip_diagonal_blocks = prev_skip_diag
```

📝 Committable suggestion
```python
@contextlib.contextmanager
def get_sparse_context(self, module: torch.nn.Module):
    """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
    prev_sparse24 = getattr(module, "_apply_sparse24", False)
    prev_skip_diag = getattr(module, "_skip_diagonal_blocks", True)
    module._apply_sparse24 = True
    module._skip_diagonal_blocks = self.skip_diagonal_blocks
    try:
        yield
    finally:
        module._apply_sparse24 = prev_sparse24
        module._skip_diagonal_blocks = prev_skip_diag
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py` around
lines 142 - 150, get_sparse_context currently sets module._apply_sparse24 and
module._skip_diagonal_blocks but only resets _apply_sparse24 on exit; preserve
and restore the original _skip_diagonal_blocks value to avoid mutating module
state. Fix get_sparse_context by reading the original value (use getattr to
handle missing attribute), set module._skip_diagonal_blocks =
self.skip_diagonal_blocks on entry, and in the finally block restore the
original value (use setattr or delattr if the attribute did not exist
originally). Ensure you still clear _apply_sparse24 as before and reference
get_sparse_context, module._apply_sparse24, module._skip_diagonal_blocks, and
self.skip_diagonal_blocks when implementing the change.
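The save-and-restore pattern this prompt describes — snapshot with a sentinel, restore or delete in `finally` — can be sketched generically. The attribute and function names below mirror the review but are a standalone toy, not the ModelOpt implementation:

```python
import contextlib

_MISSING = object()  # sentinel distinguishing "attribute absent" from any real value

@contextlib.contextmanager
def sparse_flags(module, skip_diagonal_blocks: bool):
    """Temporarily set kernel flags on `module`, restoring prior state on exit."""
    prev_apply = getattr(module, "_apply_sparse24", _MISSING)
    prev_skip = getattr(module, "_skip_diagonal_blocks", _MISSING)
    module._apply_sparse24 = True
    module._skip_diagonal_blocks = skip_diagonal_blocks
    try:
        yield
    finally:
        # Restore prior values; delete attributes that did not exist before.
        for name, prev in (("_apply_sparse24", prev_apply),
                           ("_skip_diagonal_blocks", prev_skip)):
            if prev is _MISSING:
                delattr(module, name)
            else:
                setattr(module, name, prev)

class Module:  # minimal stand-in for a SparseAttentionModule
    pass

m = Module()
m._skip_diagonal_blocks = False  # pre-existing value that must survive the context
with sparse_flags(m, skip_diagonal_blocks=True):
    assert m._apply_sparse24 and m._skip_diagonal_blocks
```

After the `with` block, `m._skip_diagonal_blocks` is back to `False` and `_apply_sparse24` is gone entirely, so nested or repeated contexts cannot leak flag state between forward passes.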
**Codecov Report**

❌ Patch coverage check failed. Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #916      +/-   ##
==========================================
- Coverage   73.09%   72.06%    -1.03%
==========================================
  Files         205      208        +3
  Lines       22301    22722      +421
==========================================
+ Hits        16300    16375       +75
- Misses       6001     6347      +346
```

☔ View full report in Codecov by Sentry.
Do we need to update anything in example readme or changelog?
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Force-pushed `05da7ea` to `8f3c0ec` (compare).
Signed-off-by: Kai Xu <kaix@nvidia.com>
Pull request overview
Adds a Triton-based unified attention backend with optional 2:4 structured sparsity, and integrates it into the existing sparse-attention framework (including HuggingFace attn_implementation="modelopt_triton" registration and mtsa.sparsify() flow).
Changes:
- Introduces a Triton unified attention kernel (prefill + decode, forward + backward) with optional 2:4 sparsification and diagonal-tile skipping.
- Refactors sparse-attention "activation" to be method-driven via `SparseAttentionMethod.get_sparse_context(...)`, enabling both softmax-patching and kernel-flag approaches.
- Adds configs/examples/tests for the new `sparse24_triton` method and HF integration.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| `modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py` | New Triton unified attention kernel with optional 2:4 sparsity and autograd support. |
| `modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py` | HF AttentionInterface backend (`attn_implementation="modelopt_triton"`) + packing/unpacking utilities. |
| `modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py` | Kernel availability detection + optional HF registration on import. |
| `modelopt/torch/sparsity/attention_sparsity/methods/sparse24_triton.py` | New sparse-attention method that activates 2:4 sparsity via kernel flags. |
| `modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` | Moves softmax patching into the method via `get_sparse_context`. |
| `modelopt/torch/sparsity/attention_sparsity/methods/registry.py` | Adds `get_sparse_context(...)` API to the method base class. |
| `modelopt/torch/sparsity/attention_sparsity/methods/__init__.py` | Ensures new method module is imported to trigger registration. |
| `modelopt/torch/sparsity/attention_sparsity/sparse_attention.py` | Delegates sparse activation to method-provided context manager. |
| `modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py` | Switches prints to logging; relaxes/changes attention-impl validation logic. |
| `modelopt/torch/sparsity/attention_sparsity/conversion.py` | Auto-registers Triton backend + sets HF `_attn_implementation` during sparsify when needed. |
| `modelopt/torch/sparsity/attention_sparsity/config.py` | Adds `triton` backend support, `skip_diagonal_blocks`, and `SPARSE24_TRITON` preset. |
| `tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py` | New GPU tests for kernel vs SDPA, sparse24 correctness, HF integration, and backward. |
| `examples/llm_sparsity/attention_sparsity/hf_sa.py` | Adds sparse24 config choice; adjusts CLI/options messaging around backends. |
| `examples/llm_sparsity/attention_sparsity/README.md` | Documents sparse24 Triton method, usage, and differences vs skip-softmax. |
| `pyproject.toml` | Adds ruff per-file ignores for kernel directory naming conventions. |
```python
batch, num_heads, seq_len, head_dim = query.shape
seq_k = key.shape[2]
is_cross_attention = seq_len != seq_k

# Decode: one query token per sequence, full context in K/V
if seq_len <= 1:
    attn_output = _decode_attention(module, query, key, value, attention_mask, scaling)
    return (attn_output, None)

device = query.device
num_kv_heads = key.shape[1]
is_causal = not is_cross_attention
apply_sparse24 = kwargs.get("apply_sparse24", getattr(module, "_apply_sparse24", False))
skip_diagonal_blocks = kwargs.get(
    "skip_diagonal_blocks", getattr(module, "_skip_diagonal_blocks", True)
)

needs_grad = torch.is_grad_enabled() and (
    query.requires_grad or key.requires_grad or value.requires_grad
)

use_packed = attention_mask is not None and _attention_mask_supported_for_triton(attention_mask)
if use_packed:
    q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len = (
        _derive_seq_lens_and_pack(query, key, value, attention_mask)
    )
```
use_packed is enabled solely based on attention_mask shape, but _derive_seq_lens_and_pack() assumes the same per-row lengths apply to Q/K/V. For cross-attention (seq_len != seq_k), HF masks typically describe only the K/V side (shape [batch, seq_k]), so packing query[..., token_indices, :] can index past seq_len or silently pack the wrong tokens. Consider disabling the packed path for cross-attention (require seq_len == seq_k) or implementing separate packing metadata for Q vs K/V (e.g., derive b_seq_len from query.shape[2] and b_seq_len_k from the mask).
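One reading of the suggested guard: derive per-row lengths from the 2-D mask and only allow the packed path when the mask is a left-aligned padding pattern *and* Q and K/V share the same sequence length (self-attention). The helper below is a hypothetical sketch; the real `_attention_mask_supported_for_triton` / `_derive_seq_lens_and_pack` signatures may differ:

```python
import torch

def packed_path_ok(attention_mask: torch.Tensor, seq_len_q: int, seq_len_k: int) -> bool:
    """Allow packing only for a 2-D [batch, seq_k] mask of left-aligned ones
    where Q and K/V lengths match, so one set of row lengths describes both."""
    if attention_mask.dim() != 2 or seq_len_q != seq_len_k:
        return False  # cross-attention/decode: mask lengths describe K/V only
    lens = attention_mask.to(torch.int64).sum(dim=-1)
    # Left-aligned check: the first `len` positions of each row must be ones.
    arange = torch.arange(attention_mask.shape[-1])
    expected = (arange.unsqueeze(0) < lens.unsqueeze(1)).to(attention_mask.dtype)
    return bool(torch.equal(attention_mask, expected))

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
```

With this guard, `packed_path_ok(mask, 4, 4)` admits the prefill self-attention case, while a decode or cross-attention call with `seq_len_q != seq_len_k` falls through to the unpacked 3-D path instead of indexing `query` with K/V-derived token indices.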
```python
def register_triton_attention() -> bool:
    """Register the Triton backend with HF AttentionInterface.

    Call after importing this module so that attn_implementation="modelopt_triton"
    is available when loading models.

    Returns:
        True if registration succeeded.
    """
    try:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        ALL_ATTENTION_FUNCTIONS.register("modelopt_triton", triton_attention_forward)
        return True
    except Exception:
        return False
```
register_triton_attention() catches all exceptions and returns False without any logging, which can make failed registration (e.g., API mismatch across transformers versions) very hard to debug—especially since other code may still set model.config._attn_implementation = "modelopt_triton". Consider catching narrower exceptions or at least logging the exception (and/or warning) when registration fails.
```diff
 parser.add_argument(
     "--backend",
     type=str,
     default="pytorch",
-    choices=["pytorch"],
-    help="Backend for sparse attention (default: pytorch). More backends coming soon.",
+    choices=["pytorch", "triton"],
+    help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
 )
```
The --backend CLI option is now advertised as selecting between pytorch and triton, but args.backend is never read—backend selection currently comes only from --sparse_attn (the chosen config already hard-codes its backend). Either remove --backend (and update docs) or use it to override sparse_config["sparse_cfg"]["*attn*"]["backend"] / validate consistency with --sparse_attn so the flag actually affects behavior.
```bash
    --sparse_attn sparse24_triton \
    --backend triton
```
The README instructs running hf_sa.py with --backend triton, but the script currently does not use args.backend at all (backend is implicitly determined by --sparse_attn via the selected config). This makes the documented CLI invocation misleading unless the script is updated to honor --backend.
```diff
-    --sparse_attn sparse24_triton \
-    --backend triton
+    --sparse_attn sparse24_triton
```
```python
@contextlib.contextmanager
def get_sparse_context(self, module: torch.nn.Module):
    """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel."""
    module._apply_sparse24 = True
    # Diagonal skip only applies to causal self-attention; for cross-attention
    # there is no diagonal relationship between Q and K positions.
    module._skip_diagonal_blocks = self.skip_diagonal_blocks and self.is_causal
    try:
        yield
    finally:
        module._apply_sparse24 = False
```
SparseAttentionModule collects stats only when module._last_stats is populated during the forward context. For sparse24_triton, get_sparse_context() just sets kernel flags and never sets _last_stats, so enabling collect_stats=True in the config will silently collect nothing. Consider explicitly disallowing collect_stats for this method/backend (with a clear error/warning), or wiring stats collection in a way that actually produces per-forward stats for this method.
```diff
 attn_impl = getattr(model.config, "_attn_implementation", None)
-if attn_impl and attn_impl != "eager":
+if attn_impl and attn_impl not in ("eager", "modelopt_triton"):
     warnings.warn(
-        f"Sparse attention requires attn_implementation='eager', but model uses '{attn_impl}'. "
-        "Forcing eager attention implementation."
+        f"Sparse attention expects attn_implementation='eager' or 'modelopt_triton', "
+        f"but model uses '{attn_impl}'. Forcing eager attention implementation."
     )
     model.config._attn_implementation = "eager"
```
Allowing attn_implementation="modelopt_triton" here can break softmax-patching methods (e.g. flash_skip_softmax with backend="pytorch"): in that setup the model must be forced to eager, otherwise torch.nn.functional.softmax is never called and sparsity/stats won’t apply. Consider making the enforcement conditional on the configured sparse-attention backend/method (e.g. force eager when any enabled sparse cfg uses backend="pytorch"; set/allow modelopt_triton only for the Triton kernel path).
What does this PR do?
Type of change: New feature

Overview: adds `attn_implementation="modelopt_triton"`, which is automatically set by `mtsa.sparsify()`.

Usage
Testing
`tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py`
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Tests