[bugfix]: preserve FSDP hooks for RMSNorm qk norms #1513
Code Review
This pull request replaces direct calls to .forward_native() on RMSNorm layers with standard module calls in causal_wanvideo.py and causal_model.py to prevent bypassing FSDP hooks. It also introduces a new test suite to verify this behavior under FSDP. The feedback suggests adding a timeout to the subprocess.run call when executing torchrun in the tests to prevent potential deadlocks or hangs in CI/CD pipelines.
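The timeout suggestion above can be sketched as follows. This is a minimal illustration, not the PR's test code: the helper name `run_with_timeout` and the stand-in commands are hypothetical, and a real test would pass its `torchrun` command line instead.

```python
import subprocess
import sys


def run_with_timeout(cmd, timeout_s):
    """Run cmd and return (returncode, timed_out). Hypothetical helper."""
    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout_s,  # raises TimeoutExpired instead of hanging
        )
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed and waited for the direct child.
        return None, True
    return completed.returncode, False


# Stand-in commands (assumption): a real test would launch torchrun here.
fast_cmd = [sys.executable, "-c", "print('ok')"]
slow_cmd = [sys.executable, "-c", "import time; time.sleep(30)"]

fast_result = run_with_timeout(fast_cmd, timeout_s=20)
slow_result = run_with_timeout(slow_cmd, timeout_s=1)
```

Note that `subprocess.run` only kills the direct child when the timeout expires. Worker processes started by a launcher are not necessarily cleaned up by this alone, so a real distributed test may need extra handling.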
SolitaryThinker left a comment
Reviewed the production dispatch changes and regression coverage. One inline correctness finding; CI status intentionally excluded.
/merge
Summary
Fixes #1379.
This replaces the remaining direct q/k RMSNorm `forward_native(...)` calls with normal module dispatch so FSDP hooks run:
- `self.norm_q.forward_native(query)` -> `self.norm_q(query)`
- `self.norm_k.forward_native(key)` -> `self.norm_k(key)`
The remaining production call sites were in causal Wan and MatrixGame2 causal transformer blocks. This keeps the selected RMSNorm implementation unchanged while avoiding bypassing `nn.Module.__call__` under FSDP.
Root Cause
Calling `RMSNorm.forward_native(...)` directly bypasses FSDP2 module hooks. With sharded RMSNorm weights, this can leave `self.weight` as a `DTensor`, and the native multiply fails with a mixed Tensor/DTensor error. Calling `RMSNorm(x)` lets FSDP unwrap/localize parameters as expected.
Issue #1379 Investigation Checklist
- `fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py` and reproduced on Modal 2x L40S with PyTorch `2.11.0+cu128`, FSDP2 `fully_shard`, and both regular sharding and `CPUOffloadPolicy`.
- `RMSNorm(x)` works while direct `RMSNorm.forward_native(x)` fails, so this appears to be FastVideo bypassing FSDP module hooks rather than an FSDP-side bug.
- `direct_cpu_offload`, where `norm.weight` remains a CPU `DTensor` and direct `forward_native` fails with the mixed Tensor/DTensor error; it also verifies module dispatch succeeds with CPU offload.
- `RMSNorm(has_weight=True)` called directly through `forward_native`; it reproduces with and without CPU offload.
Tests
- `uvx pre-commit==4.0.1 run --all-files`
- `pytest fastvideo/tests/layers/test_rmsnorm_forward_dispatch.py -vs`
  - `torchrun` timeout follow-up
  - `5 passed, 14 warnings in 38.28s`
- `pytest fastvideo/tests/transformers/test_wanvideo.py -vs`
  - `1 passed, 14 warnings in 50.05s`
- `FASTVIDEO_SSIM_MODEL_ID=Wan2.1-T2V-1.3B-Diffusers pytest fastvideo/tests/ssim/test_wan_t2v_similarity.py -vs`
  - `75e1659310ae306d342b6bfcd0e8ef7ebb0b5caa`
  - `2 passed, 18 warnings in 123.41s`
  - `FLASH_ATTN`-parametrized mean SSIM: `0.976557461420695`
  - `TORCH_SDPA` mean SSIM: `0.9821627881791857`
  - `FLASH_ATTN`-parametrized run fell back to Torch SDPA.
Checklist
For model/pipeline changes, also check:
Documentation/support-matrix notes: no docs or support matrix update was needed; this does not add a model or user-facing API. GPU memory impact should be neutral because this only routes existing RMSNorm calls through module dispatch so FSDP hooks can run.
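The dispatch difference described under Root Cause can be illustrated with a toy model. This is a pure-Python sketch, not PyTorch or FSDP: `ToyNorm`, `ShardedWeight`, `unshard`, and `reshard` are hypothetical stand-ins for an RMSNorm module, a sharded `DTensor` weight, and the FSDP pre/post-forward hooks. The forward only scales by the weight and omits the actual RMS normalization.

```python
class ShardedWeight:
    """Stand-in for a sharded parameter; deliberately not a plain list."""

    def __init__(self, values):
        self.values = values


class ToyNorm:
    """Toy module: __call__ runs hooks around forward_native."""

    def __init__(self, weight):
        self.weight = ShardedWeight(weight)  # kept sharded between calls
        self.pre_hooks = []
        self.post_hooks = []

    def __call__(self, x):
        # Module dispatch: pre-hooks, then forward, then post-hooks.
        for hook in self.pre_hooks:
            hook(self)
        try:
            return self.forward_native(x)
        finally:
            for hook in self.post_hooks:
                hook(self)

    def forward_native(self, x):
        # Needs a plain list weight; raises TypeError if still sharded.
        return [xi * wi for xi, wi in zip(x, self.weight)]


def unshard(module):
    # Pre-forward hook: expose the plain weight for the forward pass.
    if isinstance(module.weight, ShardedWeight):
        module.weight = module.weight.values


def reshard(module):
    # Post-forward hook: wrap the weight again after the forward pass.
    if not isinstance(module.weight, ShardedWeight):
        module.weight = ShardedWeight(module.weight)


norm = ToyNorm([1.0, 2.0, 3.0])
norm.pre_hooks.append(unshard)
norm.post_hooks.append(reshard)

via_call = norm([2.0, 2.0, 2.0])  # hooks run, so the multiply succeeds

try:
    norm.forward_native([2.0, 2.0, 2.0])  # hooks skipped
    direct_error = None
except TypeError as exc:
    direct_error = type(exc).__name__
```

In this toy, `norm(x)` succeeds because the pre-hook exposes a plain weight first, while `norm.forward_native(x)` sees the still-wrapped weight and fails, which mirrors the mixed Tensor/DTensor failure in spirit only.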