
Updated test to include CK/AITER V2/V3 test in single backend case #454

Merged
Micky774 merged 9 commits into dev from zain/ck-pytorch-test
Feb 27, 2026

Conversation

@Micky774
Contributor

@Micky774 Micky774 commented Feb 20, 2026

Description

Modifies the PyTorch FA tests to treat the CK backend as two virtual backends, so the V2 and V3 implementations are compared directly even when no other backend is available.

Note that V3 sometimes falls back to V2, so this is not a rigorous test and is no substitute for having another backend to compare against, but it enables more tests that would otherwise be skipped.

Fixes https://github.com/ROCm/frameworks-internal/issues/15114

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

  # FusedAttention backend
  if fused_attn_supported:
-     if len(fused_attn_backends) == 1:
+     if len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] not in fused_attn_backends:
Collaborator

FusedAttnBackend["CK"] should be guarded: it is an invalid key on the NV platform.

Contributor Author

Done
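
The guard requested above can be sketched as follows. This is a minimal illustration, not the merged test code: `has_ck_backend` is written here as a hypothetical helper function, and the backend maps passed to it are plain dictionaries standing in for the real `FusedAttnBackend` mapping. It relies only on the fact that Python's `and` short-circuits, so the `"CK"` lookup is never evaluated on a non-ROCm build.

```python
def has_ck_backend(backend_map, fused_attn_backends, is_hip_extension):
    """Return True only on a ROCm build where CK is among the selected backends.

    backend_map is a hypothetical stand-in for FusedAttnBackend, and
    is_hip_extension is a stand-in for the IS_HIP_EXTENSION flag.
    """
    # `and` short-circuits: when is_hip_extension is False, backend_map["CK"]
    # is never evaluated, so a map without a "CK" key does not raise KeyError.
    return is_hip_extension and backend_map["CK"] in fused_attn_backends


# Hypothetical maps for illustration only; the names and values are made up.
nv_like_map = {"OTHER": 0}
rocm_like_map = {"OTHER": 0, "CK": 1}

nv_result = has_ck_backend(nv_like_map, [0], is_hip_extension=False)
rocm_result = has_ck_backend(rocm_like_map, [1], is_hip_extension=True)
```

Computing the flag once and reusing it keeps the platform check in a single place instead of repeating the lookup in every condition.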

)
if len(fused_attn_backends) == 2:
# We can consider the CK backend as being two, since we have V2/V3 kernels
elif len(fused_attn_backends) == 1:
Collaborator

To avoid code duplication, it should rather be outside of the len(fused_attn_backends) checks, under a separate 'if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends'.

Contributor Author

Done

Collaborator

Doesn't seem so. The code is still duplicated for len(fused_attn_backends) of 1 and 2

@ipanfilo
Collaborator

How many previously skipped test cases are run with that change? It only allows running configs that are not supported by AOTriton but are supported by Unfused or FA.
If the goal is to make the CK V3/V2 comparison in all configs, the skip condition on the total number of supported backends should also be updated.

@Micky774
Contributor Author

> How many previously skipped test cases are run with that change? It only allows running configs that are not supported by AOTriton but supported by Unfused or FA. If the goal is to make CK V3/V2 comparison in all configs, skip condition for total number of supported backends should also be updated

Good point, I've updated the skip conditions to double-count CK directly.

is_training,
)
if len(fused_attn_backends) == 2:
# We can consider the CK backend as being two, since we have V2/V3 kernels
Collaborator

This code duplicates the same addition under len(fused_attn_backends) == 2.

Running V2 is independent of len(fused_attn_backends); it is controlled by the presence of FusedAttnBackend["CK"], so it should be outside of that if.

Contributor Author

I've reorganized the test a bit now, thanks.
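
The structure the reviewer asks for in this thread can be sketched as below. This is a simplified, assumption-laden model rather than the actual test: `plan_comparisons` is a hypothetical helper that only lists which output pairs would be compared, and the `"CK-V3"` / `"CK-V2"` labels are illustrative names, not identifiers from the test suite. The point it shows is that the V2/V3 comparison depends only on CK being present, so it is written once, outside the length check.

```python
def plan_comparisons(fused_attn_backends, has_ck_backend):
    """List the (reference, candidate) pairs a test run would compare.

    Hypothetical helper for illustration; backends are plain labels here.
    """
    pairs = []
    # A cross-backend comparison needs two distinct fused backends.
    if len(fused_attn_backends) == 2:
        pairs.append((fused_attn_backends[0], fused_attn_backends[1]))
    # The V2/V3 comparison is controlled only by CK being present, so it
    # lives outside the length check and appears exactly once.
    if has_ck_backend:
        pairs.append(("CK-V3", "CK-V2"))
    return pairs


single_backend_plan = plan_comparisons(["CK"], has_ck_backend=True)
two_backend_plan = plan_comparisons(["CK", "OTHER"], has_ck_backend=True)
```

With this layout, the single-backend and two-backend cases share one V2/V3 branch instead of each carrying its own copy.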

# Double-count the CK backend since we want to compare V2/V3 kernels
if (
len(fused_attn_backends) +
int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) +
Collaborator

Why use int? Because FusedAttnBackend["CK"] provides two variants, its presence alone is enough to continue, so add 'not (IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends)' to the original condition.

Contributor Author

@Micky774 Micky774 Feb 26, 2026

The int expression double-counts CK as a backend, so if fused_attn_backends contains only FusedAttnBackend["CK"], then int(...) == 1 and the total count becomes 2. It seemed like applying this to the condition explained why the skip was being avoided in a more literal way (i.e., adjusting the count of the number of backends directly, rather than adding an excluding condition).

I'm happy to adjust to your approach though.
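
The two formulations discussed in this thread can be compared side by side. The sketch below is an illustration under stated assumptions: the full skip condition is not visible in the snippet, so `other_backends` is a hypothetical count of whatever non-fused backends the real condition adds, and both function names are made up for this example.

```python
def should_skip_double_count(num_fused, has_ck, other_backends):
    """Skip when fewer than two backends exist, counting CK twice.

    other_backends is a hypothetical count of the non-fused backends.
    """
    # CK contributes an extra 1 because its V2 and V3 kernels can be
    # compared against each other.
    return num_fused + int(has_ck) + other_backends < 2


def should_skip_exclusion(num_fused, has_ck, other_backends):
    """Skip when fewer than two backends exist, unless CK is present."""
    return num_fused + other_backends < 2 and not has_ck


# When CK is present it is itself one of the fused backends, so only
# inputs with num_fused >= 1 are consistent for has_ck=True.
consistent_inputs = [
    (num_fused, has_ck, other)
    for num_fused in range(3)
    for has_ck in (False, True)
    for other in range(3)
    if not (has_ck and num_fused == 0)
]
forms_agree = all(
    should_skip_double_count(*args) == should_skip_exclusion(*args)
    for args in consistent_inputs
)
```

Under that consistency assumption the two forms make the same decision, so the choice between them is about readability rather than behavior.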

torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
if (
Collaborator

Just "if has_ck_backend", and you can remove the V2/V3 comparison from the len(fused_attn_backends) == 2 section.

RoPE,
is_training,
)
if has_ck_backend:
Collaborator

Apply the same logic as in test_dot_product_attention.

@Micky774 Micky774 merged commit ca2965d into dev Feb 27, 2026
2 checks passed
@Micky774 Micky774 deleted the zain/ck-pytorch-test branch February 27, 2026 23:22

3 participants