Conversation

@fy817 fy817 commented Jan 12, 2026

Summary

This PR fixes a critical logic error in how gradient clipping interacts with the Muon optimizer within the ZeRO Stage 1 and 2 implementations.

Currently, the muon_update (orthogonalization) is applied before gradient clipping. Since Muon's orthogonalization naturally changes the gradient magnitude (often resulting in norms significantly larger than the default clip_grad of 1.0), the subsequent standard clipping mechanism incorrectly identifies these valid updates as exploding gradients. This results in aggressive, erroneous clipping that severely degrades or breaks training.

This fix applies gradient clipping to the raw gradients before the Muon update occurs, so the orthogonalization step receives properly clipped inputs, and it skips the redundant clipping during the optimizer step to prevent double-clipping.

Root Cause

In the current implementation, muon_update is invoked inside get_flat_partition to leverage the shaped gradients required by the algorithm. However, get_flat_partition is called before the standard global gradient clipping phase.

  1. Muon Characteristics: The Muon optimizer performs an orthogonalization step (via Newton-Schulz iteration), which typically results in gradient tensors with Frobenius norms much larger than 1.0.
  2. Clipping Conflict: DeepSpeed defaults to a clip_grad value of 1.0. Because the gradients have already been processed by Muon by the time clipping occurs, their norms are naturally high.
  3. Result: The global clipper perceives these high norms as gradient explosion and scales the gradients down by a massive factor (e.g., dividing by 100+), effectively neutralizing the Muon update and causing training failure.
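The conflict above can be reproduced numerically. The sketch below is illustrative, not DeepSpeed code: it uses an exact SVD-based orthogonalization as a stand-in for Muon's Newton-Schulz iteration, and shows that an orthogonalized matrix (singular values ≈ 1) has a Frobenius norm of roughly sqrt(min(m, n)), which a clip_grad=1.0 threshold then scales away.

```python
# Hypothetical sketch (not DeepSpeed internals): why a post-Muon gradient
# trips the default clip_grad=1.0 threshold.
import numpy as np

rng = np.random.default_rng(0)
grad = rng.standard_normal((256, 256)) * 0.01  # small, healthy raw gradient

# Stand-in for Muon's Newton-Schulz step: exact orthogonalization via SVD.
u, _, vt = np.linalg.svd(grad, full_matrices=False)
ortho = u @ vt  # singular values are all ~1 after orthogonalization

clip_grad = 1.0                        # DeepSpeed default
norm = np.linalg.norm(ortho)           # Frobenius norm ~sqrt(256) = 16
clip_factor = max(norm / clip_grad, 1.0)

print(f"post-Muon norm ~ {norm:.1f}, clip factor ~ {clip_factor:.1f}")
# The global clipper divides the already-valid update by ~16x here;
# for larger hidden dimensions the factor grows as sqrt(min(m, n)).
```

For a 256x256 layer the update is shrunk ~16x; for the wide matrices in real transformer layers the factor easily exceeds 100, matching the failure described above.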

Fix

The fix rearranges the clipping logic for parameters utilizing the Muon optimizer:

  1. Pre-computation: In independent_gradient_partition_epilogue, we now pre-calculate the unscaled global gradient norm (or the specific norm required for clipping) before the flattening process begins.
  2. Pre-Update Clipping: This calculated norm is passed into get_flat_partition. Inside this method, we calculate the clip_factor and apply it (grad = grad / clip_factor) explicitly to the gradients before passing them to muon_update. This ensures Muon operates on clipped, stable gradients.
  3. Skip Redundant Clipping: In the step method, we modified the call to unscale_and_clip_grads to accept a skip_clipping=True flag. For Muon parameter groups, we set this flag to true, ensuring that we only perform unscaling (handling loss scale) without re-applying gradient clipping, as it was already handled in the previous step. Standard AdamW parameters retain the original behavior.
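The three steps above can be sketched as follows. This is a simplified, hypothetical mock-up (the helper names mirror the PR's description, but the bodies are illustrative stand-ins, not the actual ZeRO Stage 1/2 implementation; SVD substitutes for Newton-Schulz):

```python
# Illustrative sketch of the reordered clip -> muon_update -> unscale flow.
import numpy as np

def muon_update(grad):
    # Stand-in for Muon's Newton-Schulz orthogonalization.
    u, _, vt = np.linalg.svd(grad, full_matrices=False)
    return u @ vt

def get_flat_partition(grads, global_norm, clip_grad=1.0, use_muon=True):
    # Step 2: clip the RAW gradients first, then apply the Muon update,
    # so orthogonalization operates on stable, clipped inputs.
    clip_factor = max(global_norm / clip_grad, 1.0)
    out = []
    for g in grads:
        g = g / clip_factor            # pre-update clipping
        if use_muon:
            g = muon_update(g)
        out.append(g.ravel())
    return np.concatenate(out)

def unscale_and_clip_grads(flat, loss_scale, global_norm=None,
                           clip_grad=1.0, skip_clipping=False):
    # Step 3: for Muon groups, only unscale (handle loss scale);
    # clipping already happened before the update.
    flat = flat / loss_scale
    if not skip_clipping and global_norm is not None:
        flat = flat / max(global_norm / clip_grad, 1.0)
    return flat

# Step 1: pre-compute the unscaled global gradient norm before flattening.
grads = [np.random.default_rng(1).standard_normal((64, 64)) * 0.01]
global_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))

flat = get_flat_partition(grads, global_norm)
flat = unscale_and_clip_grads(flat, loss_scale=1.0, skip_clipping=True)
```

With skip_clipping=True, the large post-orthogonalization norm (here ~8 for a 64x64 partition) survives the step, whereas the old ordering would have divided it back down toward 1.0. Non-Muon (e.g. AdamW) groups would call unscale_and_clip_grads with skip_clipping=False and keep the original behavior.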

@fy817 fy817 force-pushed the fix-muon-grad-clip-conflict branch from c63f84f to 1a693b5 on January 12, 2026 11:55
@sfc-gh-truwase (Collaborator)

@PKUWZP, please help review

@sfc-gh-truwase sfc-gh-truwase requested a review from PKUWZP January 13, 2026 16:12