Skip to content

[bugfix]: preserve FSDP hooks for RMSNorm qk norms#1513

Open
macthecadillac wants to merge 2 commits into
hao-ai-lab:mainfrom
macthecadillac:investigate/rmsnorm-fsdp-forward-native
Open

[bugfix]: preserve FSDP hooks for RMSNorm qk norms#1513
macthecadillac wants to merge 2 commits into
hao-ai-lab:mainfrom
macthecadillac:investigate/rmsnorm-fsdp-forward-native

Conversation

@macthecadillac

@macthecadillac macthecadillac commented Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Summary

Fixes #1379.

This replaces the remaining direct q/k RMSNorm forward_native(...) calls with normal module dispatch so FSDP hooks run:

  • self.norm_q.forward_native(query) -> self.norm_q(query)
  • self.norm_k.forward_native(key) -> self.norm_k(key)

The remaining production call sites were in causal Wan and MatrixGame2 causal transformer blocks. This keeps the selected RMSNorm implementation unchanged while avoiding bypassing nn.Module.__call__ under FSDP.

Root Cause

Calling RMSNorm.forward_native(...) directly bypasses FSDP2 module hooks. With sharded RMSNorm weights, this can leave self.weight as a DTensor, and the native multiply fails with a mixed Tensor/DTensor error. Calling RMSNorm(x) lets FSDP unwrap/localize parameters as expected.

Issue #1379 Investigation Checklist

  • Reliable reproducer: added fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py and reproduced on Modal 2x L40S with PyTorch 2.11.0+cu128, FSDP2 fully_shard, and both regular sharding and CPUOffloadPolicy.
  • Upstream issue determination: no exact matching PyTorch issue was found. The repro shows RMSNorm(x) works while direct RMSNorm.forward_native(x) fails, so this appears to be FastVideo bypassing FSDP module hooks rather than an FSDP-side bug.
  • Unit test for CPU-offload + DTensor weight path: the new test covers direct_cpu_offload, where norm.weight remains a CPU DTensor and direct forward_native fails with the mixed Tensor/DTensor error; it also verifies module dispatch succeeds with CPU offload.
  • Behavioral spec: not final-layer-only. The failure can affect any FSDP-sharded RMSNorm(has_weight=True) called directly through forward_native; it reproduces with and without CPU offload.

Tests

  • uvx pre-commit==4.0.1 run --all-files
    • passed after the review follow-up commit
  • Modal L40S: pytest fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py -vs
    • 2x L40S, latest branch includes the 120-second torchrun timeout follow-up
    • 5 passed, 14 warnings in 38.28s
  • Modal L40S: pytest fastvideo/tests/transformers/test_wanvideo.py -vs
    • 1x L40S, production fix applied
    • 1 passed, 14 warnings in 50.05s
  • Modal L40S targeted Wan T2V SSIM:
    • FASTVIDEO_SSIM_MODEL_ID=Wan2.1-T2V-1.3B-Diffusers pytest fastvideo/tests/ssim/test_wan_t2v_similarity.py -vs
    • 2x L40S, commit 75e1659310ae306d342b6bfcd0e8ef7ebb0b5caa
    • 2 passed, 18 warnings in 123.41s
    • FLASH_ATTN-parametrized mean SSIM: 0.976557461420695
    • TORCH_SDPA mean SSIM: 0.9821627881791857
    • Note: the Modal image's FlashAttention import failed and the FLASH_ATTN-parametrized run fell back to Torch SDPA.

Checklist

  • I ran pre-commit run --all-files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified targeted Wan T2V SSIM regression tests pass on L40S
  • I updated the support matrix if adding a new model

Documentation/support-matrix notes: no docs or support matrix update was needed; this does not add a model or user-facing API. GPU memory impact should be neutral because this only routes existing RMSNorm calls through module dispatch so FSDP hooks can run.

@mergify mergify Bot added type: bugfix Bug fix scope: infra CI, tests, Docker, build scope: model Model architecture (DiTs, encoders, VAEs) labels Jun 29, 2026
@mergify

mergify Bot commented Jun 29, 2026

Copy link
Copy Markdown
Contributor

Merge Protections

🔴 1 of 1 protections blocking · waiting on 🤖 CI

Protection Waiting on
🔴 PR merge requirements 🤖 CI

🔴 PR merge requirements

Waiting for

  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces direct calls to .forward_native() on RMSNorm layers with standard module calls in causal_wanvideo.py and causal_model.py to prevent bypassing FSDP hooks. It also introduces a new test suite to verify this behavior under FSDP. The feedback suggests adding a timeout to the subprocess.run call when executing torchrun in the tests to prevent potential deadlocks or hangs in CI/CD pipelines.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py Outdated
@macthecadillac macthecadillac marked this pull request as ready for review June 29, 2026 12:38

@SolitaryThinker SolitaryThinker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the production dispatch changes and regression coverage. One inline correctness finding; CI status intentionally excluded.

Comment thread fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py
@SolitaryThinker

Copy link
Copy Markdown
Collaborator

/merge

@github-actions github-actions Bot added the ready PR is ready to merge label Jun 30, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready PR is ready to merge scope: infra CI, tests, Docker, build scope: model Model architecture (DiTs, encoders, VAEs) type: bugfix Bug fix

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RMSNorm.forward_native doesn't unwrap DTensor weight under FSDP CPU offload

2 participants