Fix Muon optimizer conflict with gradient clipping in ZeRO 1/2 #7776
+50
−6
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR fixes a critical logic error in how gradient clipping interacts with the Muon optimizer within the ZeRO Stage 1 and 2 implementations.
Currently, the
muon_update(orthogonalization) is applied before gradient clipping. Since Muon's orthogonalization naturally changes the gradient magnitude (often resulting in norms significantly larger than the defaultclip_gradof 1.0), the subsequent standard clipping mechanism incorrectly identifies these valid updates as exploding gradients. This results in aggressive, erroneous clipping that severely degrades or breaks training.This fix ensures that gradient clipping is applied to the raw gradients before the Muon update occurs, ensuring the orthogonalization step receives properly clipped inputs, while preventing double-clipping during the optimizer step.
Root Cause
In the current implementation,
muon_updateis invoked insideget_flat_partitionto leverage the shaped gradients required by the algorithm. However,get_flat_partitionis called before the standard global gradient clipping phase.clip_gradvalue of 1.0. Because the gradients have already been processed by Muon by the time clipping occurs, their norms are naturally high.Fix
The fix rearranges the clipping logic for parameters utilizing the Muon optimizer:
independent_gradient_partition_epilogue, we now pre-calculate the unscaled global gradient norm (or the specific norm required for clipping) before the flattening process begins.get_flat_partition. Inside this method, we calculate theclip_factorand apply it (grad = grad / clip_factor) explicitly to the gradients before passing them tomuon_update. This ensures Muon operates on clipped, stable gradients.stepmethod, we modified the call tounscale_and_clip_gradsto accept askip_clipping=Trueflag. For Muon parameter groups, we set this flag to true, ensuring that we only perform unscaling (handling loss scale) without re-applying gradient clipping, as it was already handled in the previous step. Standard AdamW parameters retain the original behavior.