Fix double prepare_grads / loss-scaler double update in train_one_step #1842

Open

jthomy wants to merge 1 commit into THUDM:main from jthomy:fix_double_unscale_fp16
Conversation

@jthomy jthomy commented Apr 17, 2026

When training with fp16 (and therefore args.check_for_nan_in_loss_and_grad=False), train_one_step calls optimizer.prepare_grads() and then optimizer.step(). Megatron's MixedPrecisionOptimizer.step() calls prepare_grads() internally, so prepare_grads runs twice per step. The gradients are therefore unscaled twice, which leaves them extremely small in fp16.

Since optimizer.step() automatically skips the update when inf is detected, we can run optimizer.step() first and check afterwards whether the step was valid. The behavior for non-fp16 is unchanged.
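To make the bug concrete, here is a minimal numerical sketch (numbers are made up, not from the PR): if prepare_grads() divides gradients by the loss scale, calling it once from train_one_step and again inside optimizer.step() divides by the scale twice, shrinking gradients by a factor of the scale.

```python
# Illustrative only: shows the effect of unscaling twice, assuming
# prepare_grads() divides gradients by the current loss scale.
loss_scale = 4096.0
true_grad = 0.5
scaled_grad = true_grad * loss_scale   # grad as produced by backward()

once = scaled_grad / loss_scale        # correct single unscale -> 0.5
twice = once / loss_scale              # buggy double unscale -> ~1.2e-4

assert once == true_grad
assert twice == true_grad / loss_scale
```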

Results with a GLM 4.7 30B Flash:

- BF16 (for comparison): [image]
- FP16 without fix (zero means the grad scaler was too high and is reducing): [image]
- FP16 with fix: [image]

When args.check_for_nan_in_loss_and_grad=False, train_one_step called
optimizer.prepare_grads() and then optimizer.step(). Megatron's
MixedPrecisionOptimizer.step() calls prepare_grads() internally, so
prepare_grads ran twice per step. With fp16 + a grad scaler, that:

  - advanced grad_scaler.update() twice per step, breaking dynamic loss
    scaling cadence;
  - on configurations where model_param.main_grad persists (typical DDP),
    re-copied scaled grads into main grads and unscaled them again.

Additionally, when optimizer.step() returned (False, None, None) on a
grad-scaler overflow, the subsequent assert update_successful fired even
though that's a legitimate skipped-step signal, not a programming error.

Call optimizer.step() exactly once and derive the overflow signal from
its return values, skipping the LR scheduler advance and emitting a
warning when an overflow occurs. Move the MTP CI gradient check before
optimizer.step() since step() modifies gradients.
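The corrected control flow described above can be sketched as follows. This is a hedged, self-contained mock, not Megatron's actual code: MockMixedPrecisionOptimizer and the callback names are illustrative, and the only assumption taken from the PR is that step() runs prepare_grads() internally and returns (update_successful, grad_norm, num_zeros_in_grad), with (False, None, None) on a grad-scaler overflow.

```python
class MockMixedPrecisionOptimizer:
    """Stand-in for Megatron's MixedPrecisionOptimizer: step() calls
    prepare_grads() itself, so the caller must not call it too."""

    def __init__(self, overflow=False):
        self.overflow = overflow
        self.prepare_calls = 0

    def prepare_grads(self):
        # Unscales grads and advances the loss scaler; returns True on inf/nan.
        self.prepare_calls += 1
        return self.overflow

    def step(self):
        found_inf = self.prepare_grads()
        if found_inf:
            return False, None, None   # legitimate skipped step, not an error
        return True, 1.0, 0


def train_one_step(optimizer, advance_lr_scheduler):
    # Fixed flow: call optimizer.step() exactly once and derive the overflow
    # signal from its return values; no separate prepare_grads() call, no
    # assert on update_successful.
    update_successful, grad_norm, num_zeros = optimizer.step()
    if update_successful:
        advance_lr_scheduler()         # only advance LR on a real update
    else:
        print("warning: grad-scaler overflow detected, skipping step")
    return update_successful


# Normal step: prepare_grads runs exactly once, LR scheduler advances.
lr_steps = []
opt = MockMixedPrecisionOptimizer()
assert train_one_step(opt, lambda: lr_steps.append(1)) is True
assert opt.prepare_calls == 1
assert lr_steps == [1]

# Overflow step: update is skipped, LR scheduler does not advance.
skipped_lr = []
opt_bad = MockMixedPrecisionOptimizer(overflow=True)
assert train_one_step(opt_bad, lambda: skipped_lr.append(1)) is False
assert skipped_lr == []
```

The design point is that the overflow signal already exists in step()'s return value, so re-deriving it via an extra prepare_grads() call (and then asserting on it) is both redundant and harmful.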
