Add support for DGPO (ICLR 2026) to GRPO #5102
YanqiDai wants to merge 20 commits into huggingface:main from
Conversation
a few comments.
also: `_generate_and_score_completions` is getting too dense with this PR. The DGAE/DQW + valid-token balancing logic and the existing multi-objective aggregation both add substantial branching/state, and it's becoming hard to follow and validate each transformation in isolation.
I think it makes sense to pull most of these out into separate helpers?
trl/trainer/grpo_trainer.py
Outdated
```python
        num_questions, device=advantages.device, dtype=advantages.dtype
    )
    if num_zero_variance_questions < num_questions:
        # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
```
Suggested change:
```diff
- # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
+ # mean accuracy == 0 (all wrong) or NaN are remapped to 1.0 before softmax so they get less weight
```
style nit
but also, doesn't this imply rewards have to be >0?
Thanks. This check and remapping only apply to the accuracy reward, whose range we assume to be [0, 1] by default.
Updated the DGPO section to clarify its mechanisms and usage in TRL, including details on DGAE and DQW.
Removed DGPO section and its related details from the documentation.
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Updated comment to clarify handling of mean accuracy for questions.
We have adopted all the suggestions (with some modifications). In particular, we rewrote the code logic and moved the main functional pieces into separate helper functions: `_compute_advantages_with_dgae`, `_compute_valid_token_balancing_ratios`, and `_compute_dqw_weights`. The final code also passes our tests.
trl/trainer/grpo_trainer.py
Outdated
```python
if self.use_dgpo_dgae:
    advantages = self._compute_advantages_with_dgae(
        rewards, num_generations
    )
```
this split feels a bit asymmetric: the DGAE path goes into `_compute_advantages_with_dgae` while the standard advantage computation (center by mean, divide by std) stays inline. When reading the code you now have to jump to a separate method for one path but not the other, even though they're doing the same conceptual thing.
I think there's motivation for a larger refactoring of the advantage calculations (lines 1844-1890), but I'd like a maintainer's thoughts on this. My suggestion would be to at least pull the `std_rewards` calculation out into a helper, but I'm open to a larger refactor as well.
Thank you for your suggestion. We found that the implementation of `_compute_advantages_with_dgae` could be simplified with some minor adjustments. After making these corrections, we implemented it directly in `_generate_and_score_completions` (it only takes 3 lines of code, just as concise as using `std_rewards`).
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Removed the repeated use_bias_correction_kl configuration option.
…nerate_and_score_completions
LeonEricsson
left a comment
lgtm now. the only thing I would consider is some refactoring of the advantage calculation, as discussed here
needs a maintainer's approval before merging
What does this PR do?
Add DGPO (Difficulty-Aware Group Policy Optimization) support to GRPO.
References: MathForge (ICLR 2026) GitHub, Paper.
Before submitting

- Did you read the contributor guideline, Pull Request section? Yes
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. No
- Did you make sure to update the documentation with your changes? Updated `docs/source/grpo_trainer.md` and `docs/source/paper_index.md`.
- Did you write any new necessary tests? Added `test_training_dgpo` in `tests/test_grpo_trainer.py`.

Motivation
This PR integrates DGPO (Difficulty-Aware Group Policy Optimization) from MathForge (ICLR 2026, paper) into the GRPO trainer. DGPO improves group-based RL by:

- DGAE: normalizing advantages by the mean absolute deviation (MAD) rather than the standard deviation.
- DQW: reweighting advantages by per-question difficulty so that harder questions contribute more.
- Valid token-level loss averaging: (a) only valid samples (those with `std_rewards != 0`) contribute to the effective normalizer, and (b) multi-GPU training is balanced by accounting for per-process valid token counts. This yields a proper token-level average loss across valid data and devices.

These options are useful for mathematical reasoning and other settings where question difficulty is heterogeneous and reward variance can be zero for some groups.
Changes
Configuration (`grpo_config.py`)

- `use_dgpo_dgae` (bool, default `False`): When True and `scale_rewards != "none"`, advantages are normalized by MAD instead of standard deviation (DGAE).
- `use_dgpo_dqw` (bool, default `False`): When True, advantages are multiplied by per-question difficulty weights (DQW). Zero-variance questions get weight 1; others are weighted by a softmax over negative mean accuracy so that harder questions get higher weight; weights sum to `num_questions`.
- `dgpo_dqw_temp` (float, default `2.0`): Temperature for the DQW softmax.
- `dgpo_dqw_acc_reward_index` (int, default `0`): Index of the accuracy reward in `reward_funcs` used to compute per-question mean accuracy for DQW.

All new parameters are documented with a reference to the MathForge paper (ICLR 2026).
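With this PR's branch installed, enabling all four options could look like the following sketch. The option names come from this PR and are not part of released TRL; `output_dir` is an arbitrary placeholder.

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="dgpo-run",
    use_dgpo_dgae=True,           # DGAE: MAD-normalized advantages
    use_dgpo_dqw=True,            # DQW: difficulty-aware question weights
    dgpo_dqw_temp=2.0,            # softmax temperature for DQW
    dgpo_dqw_acc_reward_index=0,  # position of the accuracy reward in reward_funcs
)
```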
Trainer (`grpo_trainer.py`)

DGAE

In both the `sum_then_normalize` and `normalize_then_sum` branches, when `use_dgpo_dgae` is True and rewards are scaled, the advantage denominator uses MAD instead of std: `advantage = (reward - mean) / (MAD + eps)`.
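A minimal sketch of the DGAE normalization described above; the function name, variable names, and `eps` value here are illustrative, not the trainer's actual helpers:

```python
import torch


def dgae_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Difficulty-aware advantage estimation: divide centered rewards by the
    per-question mean absolute deviation (MAD) instead of the std.

    rewards: shape (num_questions, num_generations).
    """
    mean = rewards.mean(dim=1, keepdim=True)
    centered = rewards - mean
    mad = centered.abs().mean(dim=1, keepdim=True)  # mean absolute deviation
    return centered / (mad + eps)


# one question with four rollouts and binary accuracy rewards
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
adv = dgae_advantages(rewards)
```

For binary rewards like this, MAD and std happen to coincide; they diverge on skewed reward distributions, where MAD is less sensitive to outliers.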
DQW

After advantage computation (and after the valid-token scaling described below), when `use_dgpo_dqw` is True:

- Per-question mean accuracies of the reward at index `dgpo_dqw_acc_reward_index` are computed.
- Questions with non-zero reward variance are weighted by `(num_questions - num_zero_variance) * softmax(-mean_acc / dgpo_dqw_temp)`; questions with mean accuracy 0 or NaN are treated as "easiest" (mean set to 1 in the softmax) so they receive less weight.
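The weighting step above can be sketched as follows (hypothetical function and variable names, assuming accuracy rewards in [0, 1] as stated earlier in the thread):

```python
import torch


def dqw_weights(mean_acc: torch.Tensor, is_zero_var: torch.Tensor,
                temp: float = 2.0) -> torch.Tensor:
    """Difficulty-aware question weights.

    mean_acc: per-question mean accuracy, shape (num_questions,).
    is_zero_var: True where the question's reward variance is zero.
    Zero-variance questions keep weight 1; the rest share a softmax over
    negative accuracy, so the total sums to num_questions.
    """
    num_questions = mean_acc.numel()
    num_zero_var = int(is_zero_var.sum())
    weights = torch.ones(num_questions)
    valid = ~is_zero_var
    if num_zero_var < num_questions:
        acc = mean_acc[valid].clone()
        # all-wrong (mean 0) or NaN questions are remapped to 1.0 ("easiest")
        acc[(acc == 0) | torch.isnan(acc)] = 1.0
        soft = torch.softmax(-acc / temp, dim=0)
        weights[valid] = (num_questions - num_zero_var) * soft
    return weights


mean_acc = torch.tensor([0.25, 0.75, 1.0])
is_zero_var = torch.tensor([False, False, True])  # third question: zero variance
w = dqw_weights(mean_acc, is_zero_var)  # harder questions get larger weights
```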
Valid token-level loss averaging (when `use_dgpo_dgae` or `use_dgpo_dqw` is True)

- A sample is valid when `std_rewards != 0` (i.e. `~is_std_zero`).
- `completion_length = completion_mask.sum(dim=1)` is gathered across processes to get `gathered_completion_length`.
- `global_balancing_ratio = num_processes * local_completion_length_sum / global_completion_length_sum` is computed (used later).
- If any sample is invalid, `zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum`, where `valid_completion_length_sum` is the sum of completion lengths over valid samples only; otherwise `zero_mask_ratio = 1.0`.
- Advantages are scaled by `zero_mask_ratio` so that the effective normalizer ignores invalid (zero-variance) samples.
- Advantages are also scaled by `global_balancing_ratio` so that the loss is balanced across processes by valid token count (valid token-level averaging across devices).
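A toy, single-process sketch of the two ratios. In the real trainer the lengths are gathered across processes; here a flat tensor stands in for the gathered result, and all names are hypothetical:

```python
import torch


def token_balancing_ratios(gathered_lengths: torch.Tensor,
                           gathered_is_std_zero: torch.Tensor,
                           local_slice: slice,
                           num_processes: int):
    """Compute the per-process balancing ratio and the zero-mask ratio.

    gathered_lengths: completion lengths from all processes, concatenated.
    gathered_is_std_zero: True where the sample's reward std is zero (invalid).
    local_slice: which entries belong to the current process.
    """
    global_sum = gathered_lengths.sum()
    local_sum = gathered_lengths[local_slice].sum()
    # balance each process's loss by its share of the global token count
    global_balancing_ratio = num_processes * local_sum / global_sum
    if bool(gathered_is_std_zero.any()):
        # inflate the normalizer so invalid samples are effectively ignored
        valid_sum = gathered_lengths[~gathered_is_std_zero].sum()
        zero_mask_ratio = global_sum / valid_sum
    else:
        zero_mask_ratio = torch.tensor(1.0)
    return global_balancing_ratio, zero_mask_ratio


# two processes with two completions each; the second sample is invalid
lengths = torch.tensor([10.0, 30.0, 20.0, 40.0])
is_std_zero = torch.tensor([False, True, False, False])
g, z = token_balancing_ratios(lengths, is_std_zero, slice(0, 2), num_processes=2)
```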
Tests

- `test_training_dgpo` in `tests/test_grpo_trainer.py`: runs a short training with `use_dgpo_dgae=True`, `use_dgpo_dqw=True`, `dgpo_dqw_temp=2.0`, and `dgpo_dqw_acc_reward_index=0`, and checks that `train_loss` is recorded.