From 4a8bf02bf4bb7ec908ba27e77d27c1771829f95a Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Sun, 15 Feb 2026 22:14:11 +0800 Subject: [PATCH 01/12] Add DGPO (Difficulty-Aware Group Policy Optimization, ICLR 2026) support to GRPO --- docs/source/grpo_trainer.md | 8 ++++ docs/source/paper_index.md | 19 +++++++++ tests/test_grpo_trainer.py | 25 ++++++++++++ trl/trainer/grpo_config.py | 49 +++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 78 ++++++++++++++++++++++++++++++++++++- 5 files changed, 177 insertions(+), 2 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 8216b34f2ef..aea7149656f 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -165,6 +165,14 @@ They recommend using asymmetric temperatures, \\( \tau_{\text{neg}} > \tau_{\te To use this formulation, set `loss_type="sapo"` in the [`GRPOConfig`]. +### DGPO (Difficulty-Aware Group Policy Optimization) + +DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`]. + +- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. advantage = (reward - mean) / (MAD + eps), which can address the implicit imbalance +where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty. +- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal. + ## Logged metrics While training and evaluating, we record the following reward metrics: diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index b4a73346074..c49bee59b89 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -222,6 +222,25 @@ training_args = GRPOConfig( ) ``` +### DGPO: Difficulty-Aware Group Policy Optimization + +**📜 Paper**: https://huggingface.co/papers/2601.20614 + +DGPO extends GRPO with difficulty-aware mechanisms: DGAE (difficulty-balanced group advantage estimation using MAD instead of std) and DQW (difficulty-aware question-level weighting). To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]: + +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig( + ..., + use_dgpo_dgae=True, + use_dgpo_dqw=True, + dgpo_dqw_temp=2.0, + dgpo_dqw_acc_reward_index=0, +) +trainer = GRPOTrainer(..., args=training_args, reward_funcs=[...], train_dataset=...) +``` + ### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO) **📜 Paper**: https://huggingface.co/papers/2508.08221 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 5cc5f498146..511ad9c13e3 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -231,6 +231,31 @@ def test_training_loss_types(self, loss_type): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_training_dgpo(self): + """Test DGPO (Difficulty-Aware Group Policy Optimization) runs without error.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=32, + report_to="none", + use_dgpo_dgae=True, + use_dgpo_dqw=True, + dgpo_dqw_temp=2.0, + dgpo_dqw_acc_reward_index=0, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + trainer.train() + assert trainer.state.log_history[-1]["train_loss"] is not None + def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 38dc4d5d115..3cb426ba24a 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -292,6 +292,24 @@ class GRPOConfig(TrainingArguments): Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556). + use_dgpo_dgae (`bool`, *optional*, defaults to `False`): + Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, the denominator when + scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation, i.e. + advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge + paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + use_dgpo_dqw (`bool`, *optional*, defaults to `False`): + Whether to use difficulty-aware question-level weighting (DQW). When `True`, question weights (softmax over + negative mean accuracy reward at `dgpo_dqw_acc_reward_index`) are multiplied directly onto the advantages, + so harder questions get larger effective advantages. Introduced in the [MathForge + paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`): + Temperature for the DQW softmax over negative mean (accuracy) reward. Higher values make the weighting more + uniform; lower values concentrate weight on harder questions. Introduced in the [MathForge + paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`): + Index of the accuracy reward in `reward_funcs` used by DQW for difficulty measure. The mean reward at this + index (per question) is used to compute question weights: lower mean accuracy means harder question. + Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). > Parameters that control the logging @@ -805,6 +823,37 @@ class GRPOConfig(TrainingArguments): }, ) + use_dgpo_dgae: bool = field( + default=False, + metadata={ + "help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When True, the denominator " + "when scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard " + "deviation. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + }, + ) + use_dgpo_dqw: bool = field( + default=False, + metadata={ + "help": "Whether to use difficulty-aware question-level weighting (DQW). When True, question weights are " + "multiplied directly onto the advantages. Introduced in the [MathForge paper](https://huggingface.co/" + "papers/2601.20614) (ICLR 2026)." + }, + ) + dgpo_dqw_temp: float = field( + default=2.0, + metadata={ + "help": "Temperature for the DQW softmax over negative mean (accuracy) reward. Introduced in the " + "[MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + }, + ) + dgpo_dqw_acc_reward_index: int = field( + default=0, + metadata={ + "help": "Index of the accuracy reward in reward_funcs used by DQW for difficulty measure. Introduced in " + "the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e2a6d41a142..64f3da8fc6f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -491,6 +491,16 @@ def __init__( raise NotImplementedError( "Liger Kernels don't currently support masking token positions based on entropy." ) + # DGPO (Difficulty-Aware Group Policy Optimization) + self.use_dgpo_dgae = getattr(args, "use_dgpo_dgae", False) + self.use_dgpo_dqw = getattr(args, "use_dgpo_dqw", False) + self.dgpo_dqw_temp = getattr(args, "dgpo_dqw_temp", 2.0) + self.dgpo_dqw_acc_reward_index = getattr(args, "dgpo_dqw_acc_reward_index", 0) + if self.use_dgpo_dqw and (self.dgpo_dqw_acc_reward_index < 0 or self.dgpo_dqw_acc_reward_index >= len(self.reward_funcs)): + raise ValueError( + f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got " + f"{self.dgpo_dqw_acc_reward_index}." + ) if self.use_liger_kernel and not self.importance_sampling_level == "token": raise NotImplementedError( "Liger Kernels currently only support token-level importance sampling. Please set" @@ -1751,7 +1761,23 @@ def _generate_and_score_completions( advantages = rewards - mean_grouped_rewards if self.scale_rewards != "none": - advantages = advantages / (std_rewards + 1e-4) + if self.use_dgpo_dgae: + # DGAE: use Mean Absolute Deviation (MAD) as denominator instead of std + # MAD = mean(|rewards - mean|), per group when scale_rewards=="group" + if self.scale_rewards == "group" and num_generations > 1: + mad_rewards = ( + (rewards - mean_grouped_rewards) + .abs() + .view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations, dim=0) + ) + else: + # batch scaling or single sample per group: use global MAD + mad_rewards = (rewards - mean_grouped_rewards).abs().mean().expand_as(rewards) + advantages = advantages / (mad_rewards + 1e-4) + else: + advantages = advantages / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging elif self.multi_objective_aggregation == "normalize_then_sum": @@ -1762,7 +1788,12 @@ def _generate_and_score_completions( reward_k = reward_k.view(-1, len(self.reward_funcs)) rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + if self.use_dgpo_dgae: + # DGAE: use MAD as denominator instead of std + mad_rewards = (rewards - rewards.mean()).abs().mean().expand_as(rewards) + advantages = (rewards - rewards.mean()) / (mad_rewards + 1e-4) + else: + advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging else: @@ -1771,6 +1802,46 @@ def _generate_and_score_completions( "'sum_then_normalize' or 'normalize_then_sum'." ) + # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice + # Valid sample = std_rewards not zero (is_std_zero is already defined above) + if self.use_dgpo_dgae or self.use_dgpo_dqw: + completion_length = completion_mask.sum(dim=1) # (N_local,) valid tokens per sample + gathered_completion_length = self.accelerator.gather(completion_length) + global_completion_length_sum = gathered_completion_length.sum().clamp(min=1e-8) + local_completion_length_sum = completion_length.sum() + global_balancing_ratio = ( + self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum + ) + if (~is_std_zero).any(): + valid_completion_length_sum = gathered_completion_length[~is_std_zero].sum().clamp(min=1e-8) + zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum + else: + zero_mask_ratio = torch.tensor(1.0, device=advantages.device, dtype=advantages.dtype) + advantages = advantages * zero_mask_ratio + + # DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1 + if self.use_dgpo_dqw: + num_questions = rewards.size(0) // num_generations + acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] # (N,) + mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,) + std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,) + is_std_zero_q = std_per_q_acc < 1e-8 + num_zero_variance_questions = is_std_zero_q.sum().item() + difficulty_balancing_weights = torch.ones( + num_questions, device=advantages.device, dtype=advantages.dtype + ) + if num_zero_variance_questions < num_questions: + # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight + mean_per_q_acc_modified = mean_per_q_acc.clone() + mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0 + difficulty_balancing_weights[~is_std_zero_q] = ( + num_questions - num_zero_variance_questions + ) * torch.nn.functional.softmax( + -mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0 + ) + question_weights_expanded = difficulty_balancing_weights.repeat_interleave(num_generations) + advantages = advantages * question_weights_expanded + # Slice to keep only the local part of the data process_slice = slice( self.accelerator.process_index * len(prompts), @@ -1779,6 +1850,9 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] + if self.use_dgpo_dgae or self.use_dgpo_dqw: + advantages = advantages * global_balancing_ratio + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() From b0f72efff62b282572f469723ab0af81a910b3e2 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 14:51:07 +0800 Subject: [PATCH 02/12] Revise DGPO description and usage instructions Updated the DGPO section to clarify its mechanisms and usage in TRL, including details on DGAE and DQW. --- docs/source/paper_index.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index c49bee59b89..075a9e557e5 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -226,7 +226,12 @@ training_args = GRPOConfig( **📜 Paper**: https://huggingface.co/papers/2601.20614 -DGPO extends GRPO with difficulty-aware mechanisms: DGAE (difficulty-balanced group advantage estimation using MAD instead of std) and DQW (difficulty-aware question-level weighting). To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]: +DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`]. + +- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. advantage = (reward - mean) / (MAD + eps), which can address the implicit imbalance where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty. +- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight, so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal. + +To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]: ```python from trl import GRPOConfig, GRPOTrainer From 90fc3f5f69a07f939a46dedc2d999c0882223212 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 14:54:49 +0800 Subject: [PATCH 03/12] Remove DGPO section from grpo_trainer.md Removed DGPO section and its related details from the documentation. --- docs/source/grpo_trainer.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index aea7149656f..8216b34f2ef 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -165,14 +165,6 @@ They recommend using asymmetric temperatures, \\( \tau_{\text{neg}} > \tau_{\te To use this formulation, set `loss_type="sapo"` in the [`GRPOConfig`]. -### DGPO (Difficulty-Aware Group Policy Optimization) - -DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`]. - -- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. advantage = (reward - mean) / (MAD + eps), which can address the implicit imbalance -where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty. -- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal. - ## Logged metrics While training and evaluating, we record the following reward metrics: From 34a69ebbdcda8ba7cf9d3bff77433c5ae1afa862 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 14:56:45 +0800 Subject: [PATCH 04/12] Remove ICLR 2026 --- trl/trainer/grpo_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 3cb426ba24a..ffc4714d8bc 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -296,20 +296,20 @@ class GRPOConfig(TrainingArguments): Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, the denominator when scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation, i.e. advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge - paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + paper](https://huggingface.co/papers/2601.20614). use_dgpo_dqw (`bool`, *optional*, defaults to `False`): Whether to use difficulty-aware question-level weighting (DQW). When `True`, question weights (softmax over negative mean accuracy reward at `dgpo_dqw_acc_reward_index`) are multiplied directly onto the advantages, so harder questions get larger effective advantages. Introduced in the [MathForge - paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + paper](https://huggingface.co/papers/2601.20614). dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`): Temperature for the DQW softmax over negative mean (accuracy) reward. Higher values make the weighting more uniform; lower values concentrate weight on harder questions. Introduced in the [MathForge - paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + paper](https://huggingface.co/papers/2601.20614). dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`): Index of the accuracy reward in `reward_funcs` used by DQW for difficulty measure. The mean reward at this index (per question) is used to compute question weights: lower mean accuracy means harder question. - Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026). + Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). > Parameters that control the logging From da8c4454d491ed0840cf082b82cc037a90589328 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 15:21:06 +0800 Subject: [PATCH 05/12] Apply all other suggestions from code review Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bbe8ec08a52..7269fed195f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -491,11 +491,10 @@ def __init__( raise NotImplementedError( "Liger Kernels don't currently support masking token positions based on entropy." ) - # DGPO (Difficulty-Aware Group Policy Optimization) - self.use_dgpo_dgae = getattr(args, "use_dgpo_dgae", False) - self.use_dgpo_dqw = getattr(args, "use_dgpo_dqw", False) - self.dgpo_dqw_temp = getattr(args, "dgpo_dqw_temp", 2.0) - self.dgpo_dqw_acc_reward_index = getattr(args, "dgpo_dqw_acc_reward_index", 0) + self.use_dgpo_dgae = args.use_dgpo_dgae + self.use_dgpo_dqw = args.use_dgpo_dqw + self.dgpo_dqw_temp = args.dgpo_dqw_temp + self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index if self.use_dgpo_dqw and (self.dgpo_dqw_acc_reward_index < 0 or self.dgpo_dqw_acc_reward_index >= len(self.reward_funcs)): raise ValueError( f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got " @@ -1812,18 +1811,24 @@ def _generate_and_score_completions( # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice # Valid sample = std_rewards not zero (is_std_zero is already defined above) if self.use_dgpo_dgae or self.use_dgpo_dqw: - completion_length = completion_mask.sum(dim=1) # (N_local,) valid tokens per sample - gathered_completion_length = self.accelerator.gather(completion_length) - global_completion_length_sum = gathered_completion_length.sum().clamp(min=1e-8) - local_completion_length_sum = completion_length.sum() + completion_length_local = completion_mask.sum(dim=1) + completion_length_global = self.accelerator.gather(completion_length_local) + valid_mask_global = ~is_std_zero + + global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8) + local_completion_length_sum = completion_length_local.sum() + + # Rebalance each process by its share of valid tokens. global_balancing_ratio = ( self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum ) - if (~is_std_zero).any(): - valid_completion_length_sum = gathered_completion_length[~is_std_zero].sum().clamp(min=1e-8) + + if valid_mask_global.any(): + valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8) zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum else: zero_mask_ratio = torch.tensor(1.0, device=advantages.device, dtype=advantages.dtype) + advantages = advantages * zero_mask_ratio # DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1 From 25def334d29e9ec4e405762c33b092ec42a13851 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 15:28:25 +0800 Subject: [PATCH 06/12] Polish the description of accuracy handling logic in DQW Updated comment to clarify handling of mean accuracy for questions. --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7269fed195f..320e4be47fb 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1843,7 +1843,7 @@ def _generate_and_score_completions( num_questions, device=advantages.device, dtype=advantages.dtype ) if num_zero_variance_questions < num_questions: - # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight + # mean accuracy == 0 (all wrong) or NaN are remapped to 1.0 before softmax so they get less weight``` mean_per_q_acc_modified = mean_per_q_acc.clone() mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0 difficulty_balancing_weights[~is_std_zero_q] = ( From 1ecdae6fd7216970a3e094bd787f14a500b35071 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 17:19:46 +0800 Subject: [PATCH 07/12] Rewrite the DGPO code --- trl/trainer/grpo_trainer.py | 160 ++++++++++++++++++++++-------------- 1 file changed, 100 insertions(+), 60 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 320e4be47fb..caea6c1c481 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1531,6 +1531,90 @@ def _generate(self, prompts: list): extra_fields, ) + def _compute_advantages_with_dgae( + self, + rewards: torch.Tensor, + num_generations: int, + *, + use_group_mad: bool | None = None, + ) -> torch.Tensor: + """Compute advantages using MAD (DGAE) as denominator. Call only when use_dgpo_dgae is True.""" + advantages = rewards - rewards.mean() + if self.scale_rewards != "none": + if use_group_mad is None: + use_group_mad = self.scale_rewards == "group" and num_generations > 1 + if use_group_mad: + mad_rewards = ( + advantages.abs() + .view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations, dim=0) + ) + else: + mad_rewards = advantages.abs().mean().expand_as(rewards) + advantages = advantages / (mad_rewards + 1e-4) + return advantages + + def _compute_valid_token_balancing_ratios( + self, + completion_mask: torch.Tensor, + is_std_zero: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute valid token-level balancing ratios (zero_mask_ratio and global_balancing_ratio). + Returns (zero_mask_ratio, global_balancing_ratio). Apply zero_mask_ratio to advantages before slice, + global_balancing_ratio after slice. Call only when use_dgpo_dgae or use_dgpo_dqw is True. + """ + completion_length_local = completion_mask.sum(dim=1) + completion_length_global = gather(completion_length_local) + + global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8) + local_completion_length_sum = completion_length_local.sum() + + global_balancing_ratio = ( + self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum + ) + + valid_mask_global = ~gather(is_std_zero) + if valid_mask_global.any(): + valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8) + zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum + else: + zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device, dtype=completion_mask.dtype) + + return zero_mask_ratio, global_balancing_ratio + + def _compute_dqw_weights( + self, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + num_generations: int, + ) -> torch.Tensor: + """ + Compute question-level difficulty balancing weights (DQW). + Returns difficulty_balancing_weights (num_questions,); expand with repeat_interleave at call site. + Weights sum to num_questions; zero-variance questions get weight 1. + Call only when use_dgpo_dqw is True. + """ + num_questions = rewards.size(0) // num_generations + acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] # (N,) + mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,) + std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,) + is_std_zero_q = std_per_q_acc < 1e-8 + num_zero_variance_questions = is_std_zero_q.sum().item() + difficulty_balancing_weights = torch.ones( + num_questions, device=rewards.device, dtype=rewards.dtype + ) + if num_zero_variance_questions < num_questions: + mean_per_q_acc_modified = mean_per_q_acc.clone() + mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0 + difficulty_balancing_weights[~is_std_zero_q] = ( + num_questions - num_zero_variance_questions + ) * torch.nn.functional.softmax( + -mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0 + ) + return difficulty_balancing_weights + def _generate_and_score_completions( self, inputs: list[dict[str, torch.Tensor | Any]] ) -> dict[str, torch.Tensor | Any]: @@ -1765,26 +1849,14 @@ def _generate_and_score_completions( f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." ) - advantages = rewards - mean_grouped_rewards - if self.scale_rewards != "none": - if self.use_dgpo_dgae: - # DGAE: use Mean Absolute Deviation (MAD) as denominator instead of std - # MAD = mean(|rewards - mean|), per group when scale_rewards=="group" - if self.scale_rewards == "group" and num_generations > 1: - mad_rewards = ( - (rewards - mean_grouped_rewards) - .abs() - .view(-1, num_generations) - .mean(dim=1) - .repeat_interleave(num_generations, dim=0) - ) - else: - # batch scaling or single sample per group: use global MAD - mad_rewards = (rewards - mean_grouped_rewards).abs().mean().expand_as(rewards) - advantages = advantages / (mad_rewards + 1e-4) - else: + if self.use_dgpo_dgae: + advantages = self._compute_advantages_with_dgae( + rewards, num_generations + ) + else: + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": advantages = advantages / (std_rewards + 1e-4) - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging elif self.multi_objective_aggregation == "normalize_then_sum": grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) @@ -1795,12 +1867,11 @@ def _generate_and_score_completions( rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) if self.use_dgpo_dgae: - # DGAE: use MAD as denominator instead of std - mad_rewards = (rewards - rewards.mean()).abs().mean().expand_as(rewards) - advantages = (rewards - rewards.mean()) / (mad_rewards + 1e-4) + advantages = self._compute_advantages_with_dgae( + rewards, num_generations, use_group_mad=False + ) else: advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging else: raise ValueError( @@ -1809,50 +1880,19 @@ def _generate_and_score_completions( ) # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice - # Valid sample = std_rewards not zero (is_std_zero is already defined above) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging if self.use_dgpo_dgae or self.use_dgpo_dqw: - completion_length_local = completion_mask.sum(dim=1) - completion_length_global = self.accelerator.gather(completion_length_local) - valid_mask_global = ~is_std_zero - - global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8) - local_completion_length_sum = completion_length_local.sum() - - # Rebalance each process by its share of valid tokens. - global_balancing_ratio = ( - self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum + zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios( + completion_mask, is_std_zero ) - - if valid_mask_global.any(): - valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8) - zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum - else: - zero_mask_ratio = torch.tensor(1.0, device=advantages.device, dtype=advantages.dtype) - advantages = advantages * zero_mask_ratio # DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1 if self.use_dgpo_dqw: - num_questions = rewards.size(0) // num_generations - acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] # (N,) - mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,) - std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,) - is_std_zero_q = std_per_q_acc < 1e-8 - num_zero_variance_questions = is_std_zero_q.sum().item() - difficulty_balancing_weights = torch.ones( - num_questions, device=advantages.device, dtype=advantages.dtype + difficulty_balancing_weights = self._compute_dqw_weights( + rewards, rewards_per_func, num_generations ) - if num_zero_variance_questions < num_questions: - # mean accuracy == 0 (all wrong) or NaN are remapped to 1.0 before softmax so they get less weight``` - mean_per_q_acc_modified = mean_per_q_acc.clone() - mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0 - difficulty_balancing_weights[~is_std_zero_q] = ( - num_questions - num_zero_variance_questions - ) * torch.nn.functional.softmax( - -mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0 - ) - question_weights_expanded = difficulty_balancing_weights.repeat_interleave(num_generations) - advantages = advantages * question_weights_expanded + advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations) # Slice to keep only the local part of the data process_slice = slice( From 6291ae35fd207fb4eefe42d115d0dfc40f10a0f6 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 17:30:49 +0800 Subject: [PATCH 08/12] Remove ICLR 2026 --- trl/trainer/grpo_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ffc4714d8bc..05049dc0090 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -828,7 +828,7 @@ class GRPOConfig(TrainingArguments): metadata={ "help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When True, the denominator " "when scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard " - "deviation. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + "deviation. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) use_dgpo_dqw: bool = field( @@ -836,21 +836,21 @@ class GRPOConfig(TrainingArguments): metadata={ "help": "Whether to use difficulty-aware question-level weighting (DQW). When True, question weights are " "multiplied directly onto the advantages. Introduced in the [MathForge paper](https://huggingface.co/" - "papers/2601.20614) (ICLR 2026)." + "papers/2601.20614)." }, ) dgpo_dqw_temp: float = field( default=2.0, metadata={ "help": "Temperature for the DQW softmax over negative mean (accuracy) reward. Introduced in the " - "[MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + "[MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) dgpo_dqw_acc_reward_index: int = field( default=0, metadata={ "help": "Index of the accuracy reward in reward_funcs used by DQW for difficulty measure. Introduced in " - "the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026)." + "the [MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) From df1fb48bc8a03afbf26f72ae3451a42d60889a62 Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Fri, 20 Feb 2026 17:39:43 +0800 Subject: [PATCH 09/12] Recover the code position of is_std_zero --- trl/trainer/grpo_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index caea6c1c481..59d7db8e202 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1857,6 +1857,7 @@ def _generate_and_score_completions( advantages = rewards - mean_grouped_rewards if self.scale_rewards != "none": advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging elif self.multi_objective_aggregation == "normalize_then_sum": grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) @@ -1872,6 +1873,7 @@ def _generate_and_score_completions( ) else: advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging else: raise ValueError( @@ -1880,7 +1882,6 @@ def _generate_and_score_completions( ) # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging if self.use_dgpo_dgae or self.use_dgpo_dqw: zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios( completion_mask, is_std_zero From 89e1384506c3dc451cdb8a6cf7d220e62f112e4f Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Wed, 25 Feb 2026 21:45:16 +0800 Subject: [PATCH 10/12] Apply suggestions from code review Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> --- trl/trainer/grpo_config.py | 61 ++++++++++++++++++++++++------------- trl/trainer/grpo_trainer.py | 37 ++++++++++++++++------ 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index fe0e0a03a3a..aae1548a332 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -295,22 +295,24 @@ class GRPOConfig(BaseConfig): KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556). use_dgpo_dgae (`bool`, *optional*, defaults to `False`): - Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, the denominator when - scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation, i.e. - advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge - paper](https://huggingface.co/papers/2601.20614). + Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative + advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard + deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced + in the [MathForge paper](https://huggingface.co/papers/2601.20614). use_dgpo_dqw (`bool`, *optional*, defaults to `False`): - Whether to use difficulty-aware question-level weighting (DQW). When `True`, question weights (softmax over - negative mean accuracy reward at `dgpo_dqw_acc_reward_index`) are multiplied directly onto the advantages, - so harder questions get larger effective advantages. Introduced in the [MathForge + Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a weight + based on its estimated difficulty, and that weight is multiplied into the advantages—so harder questions + produce larger effective updates. Difficulty is computed as a softmax over the negative per-question mean + accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`): - Temperature for the DQW softmax over negative mean (accuracy) reward. Higher values make the weighting more - uniform; lower values concentrate weight on harder questions. Introduced in the [MathForge - paper](https://huggingface.co/papers/2601.20614). + Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher values make + weights more uniform across questions; lower values concentrate weight on the hardest questions. + Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`): - Index of the accuracy reward in `reward_funcs` used by DQW for difficulty measure. The mean reward at this - index (per question) is used to compute question weights: lower mean accuracy means harder question. + Index into `reward_funcs` selecting the reward used as the "accuracy" signal for DQW's difficulty + estimate. For each question, DQW uses the mean reward at this index across the group to compute the + difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW softmax. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). > Parameters that control the logging @@ -794,34 +796,49 @@ class GRPOConfig(BaseConfig): }, ) + use_bias_correction_kl: bool = field( + default=False, + metadata={ + "help": "Whether to use the unbiased KL divergence estimator with importance sampling correction. This " + "corrects the KL divergence estimate by multiplying it with the importance sampling ratio. " + "This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)." + }, + ) + use_dgpo_dgae: bool = field( default=False, metadata={ - "help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When True, the denominator " - "when scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard " - "deviation. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." + "help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative " + "advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard " + "deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced " + "in the [MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) use_dgpo_dqw: bool = field( default=False, metadata={ - "help": "Whether to use difficulty-aware question-level weighting (DQW). When True, question weights are " - "multiplied directly onto the advantages. Introduced in the [MathForge paper](https://huggingface.co/" - "papers/2601.20614)." + "help": "Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a " + "weight based on its estimated difficulty, and that weight is multiplied into the advantages—so harder " + "questions produce larger effective updates. Difficulty is computed as a softmax over the negative " + "per-question mean accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the " + "[MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) dgpo_dqw_temp: float = field( default=2.0, metadata={ - "help": "Temperature for the DQW softmax over negative mean (accuracy) reward. Introduced in the " - "[MathForge paper](https://huggingface.co/papers/2601.20614)." + "help": "Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher " + "values make weights more uniform across questions; lower values concentrate weight on the hardest " + "questions. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) dgpo_dqw_acc_reward_index: int = field( default=0, metadata={ - "help": "Index of the accuracy reward in reward_funcs used by DQW for difficulty measure. Introduced in " - "the [MathForge paper](https://huggingface.co/papers/2601.20614)." + "help": "Index into `reward_funcs` selecting the reward used as the \"accuracy\" signal for DQW's " + "difficulty estimate. For each question, DQW uses the mean reward at this index across the group to " + "compute the difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW " + "softmax. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." }, ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ccf46bc8cdb..5a48b9f2739 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1624,9 +1624,19 @@ def _compute_valid_token_balancing_ratios( is_std_zero: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Compute valid token-level balancing ratios (zero_mask_ratio and global_balancing_ratio). - Returns (zero_mask_ratio, global_balancing_ratio). Apply zero_mask_ratio to advantages before slice, - global_balancing_ratio after slice. Call only when use_dgpo_dgae or use_dgpo_dqw is True. + Compute token-level balancing ratios for distributed training with filtered questions. + + When zero-variance questions are masked out, the effective number of valid tokens changes across + processes. This method produces two correction factors: `zero_mask_ratio` compensates for the tokens + lost by masking zero-variance questions, and `global_balancing_ratio` corrects for uneven token counts + across processes. + + Args: + completion_mask: Boolean tensor of shape `(batch_size, seq_len)` indicating valid completion tokens. + is_std_zero: Boolean tensor of shape `(batch_size,)` indicating zero-variance questions. + + Returns: + A tuple `(zero_mask_ratio, global_balancing_ratio)` of scalar tensors. """ completion_length_local = completion_mask.sum(dim=1) completion_length_global = gather(completion_length_local) @@ -1655,12 +1665,22 @@ def _compute_dqw_weights( ) -> torch.Tensor: """ Compute question-level difficulty balancing weights (DQW). - Returns difficulty_balancing_weights (num_questions,); expand with repeat_interleave at call site. - Weights sum to num_questions; zero-variance questions get weight 1. - Call only when use_dgpo_dqw is True. + + Assigns higher weight to harder questions (lower mean accuracy) using a temperature-scaled softmax over + per-question accuracy means. Zero-variance questions receive a neutral weight of 1. + The returned weights sum to `num_questions`. + + Args: + rewards: Tensor of shape `(num_questions * num_generations,)` with per-generation rewards. + rewards_per_func: Tensor of shape `(num_questions * num_generations, num_reward_funcs)` with + per-reward-function scores; the column at `dgpo_dqw_acc_reward_index` is used as accuracy. + num_generations: Number of generations per question. + + Returns: + Tensor of shape `(num_questions,)` with difficulty balancing weights. """ num_questions = rewards.size(0) // num_generations - acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] # (N,) + acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,) std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,) is_std_zero_q = std_per_q_acc < 1e-8 @@ -1951,14 +1971,13 @@ def _generate_and_score_completions( "'sum_then_normalize' or 'normalize_then_sum'." ) - # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice + # zero_mask_ratio must be applied before the process slice; global_balancing_ratio after if self.use_dgpo_dgae or self.use_dgpo_dqw: zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios( completion_mask, is_std_zero ) advantages = advantages * zero_mask_ratio - # DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1 if self.use_dgpo_dqw: difficulty_balancing_weights = self._compute_dqw_weights( rewards, rewards_per_func, num_generations From e07db903fa04c5c5378ab80196ff4ba0d09e3eda Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Wed, 25 Feb 2026 21:48:59 +0800 Subject: [PATCH 11/12] Remove repeated use_bias_correction_kl in suggestions from grpo_config Removed the repeated use_bias_correction_kl configuration option. --- trl/trainer/grpo_config.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index aae1548a332..38fcb71ce59 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -796,15 +796,6 @@ class GRPOConfig(BaseConfig): }, ) - use_bias_correction_kl: bool = field( - default=False, - metadata={ - "help": "Whether to use the unbiased KL divergence estimator with importance sampling correction. This " - "corrects the KL divergence estimate by multiplying it with the importance sampling ratio. " - "This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)." - }, - ) - use_dgpo_dgae: bool = field( default=False, metadata={ From f34d3f16e37d4d7ff9e85f9175436f81e0d524ac Mon Sep 17 00:00:00 2001 From: Yanqi Dai Date: Wed, 25 Feb 2026 22:47:43 +0800 Subject: [PATCH 12/12] Modify _compute_advantages_with_dgae and implement it directly in _generate_and_score_completions --- trl/trainer/grpo_trainer.py | 47 ++++++++++--------------------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 5a48b9f2739..12f39986432 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1594,30 +1594,6 @@ def _generate(self, prompts: list): extra_fields, ) - def _compute_advantages_with_dgae( - self, - rewards: torch.Tensor, - num_generations: int, - *, - use_group_mad: bool | None = None, - ) -> torch.Tensor: - """Compute advantages using MAD (DGAE) as denominator. Call only when use_dgpo_dgae is True.""" - advantages = rewards - rewards.mean() - if self.scale_rewards != "none": - if use_group_mad is None: - use_group_mad = self.scale_rewards == "group" and num_generations > 1 - if use_group_mad: - mad_rewards = ( - advantages.abs() - .view(-1, num_generations) - .mean(dim=1) - .repeat_interleave(num_generations, dim=0) - ) - else: - mad_rewards = advantages.abs().mean().expand_as(rewards) - advantages = advantages / (mad_rewards + 1e-4) - return advantages - def _compute_valid_token_balancing_ratios( self, completion_mask: torch.Tensor, @@ -1939,13 +1915,13 @@ def _generate_and_score_completions( f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." ) - if self.use_dgpo_dgae: - advantages = self._compute_advantages_with_dgae( - rewards, num_generations - ) - else: - advantages = rewards - mean_grouped_rewards - if self.scale_rewards != "none": + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": + if self.use_dgpo_dgae: + mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1) + mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0) + advantages = advantages / (mad_rewards + 1e-4) + else: advantages = advantages / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging @@ -1957,12 +1933,13 @@ def _generate_and_score_completions( reward_k = reward_k.view(-1, len(self.reward_funcs)) rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + advantages = rewards - rewards.mean() if self.use_dgpo_dgae: - advantages = self._compute_advantages_with_dgae( - rewards, num_generations, use_group_mad=False - ) + mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1) + mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0) + advantages = advantages / (mad_rewards + 1e-4) else: - advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + advantages = advantages / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging else: