diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index a8734524a00..43f4b8bea9a 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -221,6 +221,30 @@ training_args = GRPOConfig( ) ``` +### DGPO: Difficulty-Aware Group Policy Optimization + +**📜 Paper**: https://huggingface.co/papers/2601.20614 + +DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`]. + +- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. `advantage = (reward - mean) / (MAD + eps)`, which can address the implicit imbalance in which update magnitudes peak for questions of moderate difficulty and are suppressed for both easier and harder ones. +- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight, so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal. + +To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]: + +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig( + ..., + use_dgpo_dgae=True, + use_dgpo_dqw=True, + dgpo_dqw_temp=2.0, + dgpo_dqw_acc_reward_index=0, +) +trainer = GRPOTrainer(..., args=training_args, reward_funcs=[...], train_dataset=...) +``` + ### Part I: Tricks or Traps? 
A Deep Dive into RL for LLM Reasoning (Lite PPO) **📜 Paper**: https://huggingface.co/papers/2508.08221 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b32506c2ca0..76f1136b2d0 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -230,6 +230,31 @@ def test_training_loss_types(self, loss_type): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_training_dgpo(self): + """Test DGPO (Difficulty-Aware Group Policy Optimization) runs without error.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=32, + report_to="none", + use_dgpo_dgae=True, + use_dgpo_dqw=True, + dgpo_dqw_temp=2.0, + dgpo_dqw_acc_reward_index=0, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + trainer.train() + assert trainer.state.log_history[-1]["train_loss"] is not None + def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 809ac441165..5ccbcd01080 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -295,6 +295,26 @@ class GRPOConfig(_BaseConfig): Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556). + use_dgpo_dgae (`bool`, *optional*, defaults to `False`): + Whether to use difficulty-balanced group advantage estimation (DGAE). 
When `True`, group-relative + advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard + deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced + in the [MathForge paper](https://huggingface.co/papers/2601.20614). + use_dgpo_dqw (`bool`, *optional*, defaults to `False`): + Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a weight + based on its estimated difficulty, and that weight is multiplied into the advantages—so harder questions + produce larger effective updates. Difficulty is computed as a softmax over the negative per-question mean + accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the [MathForge + paper](https://huggingface.co/papers/2601.20614). + dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`): + Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher values make + weights more uniform across questions; lower values concentrate weight on the hardest questions. + Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). + dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`): + Index into `reward_funcs` selecting the reward used as the "accuracy" signal for DQW's difficulty + estimate. For each question, DQW uses the mean reward at this index across the group to compute the + difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW softmax. + Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614). > Parameters that control the logging @@ -784,6 +804,43 @@ class GRPOConfig(_BaseConfig): }, ) + use_dgpo_dgae: bool = field( + default=False, + metadata={ + "help": "Whether to use difficulty-balanced group advantage estimation (DGAE). 
When `True`, group-relative " + "advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard " + "deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced " + "in the [MathForge paper](https://huggingface.co/papers/2601.20614)." + }, + ) + use_dgpo_dqw: bool = field( + default=False, + metadata={ + "help": "Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a " + "weight based on its estimated difficulty, and that weight is multiplied into the advantages—so harder " + "questions produce larger effective updates. Difficulty is computed as a softmax over the negative " + "per-question mean accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the " + "[MathForge paper](https://huggingface.co/papers/2601.20614)." + }, + ) + dgpo_dqw_temp: float = field( + default=2.0, + metadata={ + "help": "Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher " + "values make weights more uniform across questions; lower values concentrate weight on the hardest " + "questions. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." + }, + ) + dgpo_dqw_acc_reward_index: int = field( + default=0, + metadata={ + "help": "Index into `reward_funcs` selecting the reward used as the \"accuracy\" signal for DQW's " + "difficulty estimate. For each question, DQW uses the mean reward at this index across the group to " + "compute the difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW " + "softmax. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)." 
+ }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 5e44cee18d2..52601c00fa7 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -546,6 +546,15 @@ def __init__( raise NotImplementedError( "Liger Kernels don't currently support masking token positions based on entropy." ) + self.use_dgpo_dgae = args.use_dgpo_dgae + self.use_dgpo_dqw = args.use_dgpo_dqw + self.dgpo_dqw_temp = args.dgpo_dqw_temp + self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index + if self.use_dgpo_dqw and (self.dgpo_dqw_acc_reward_index < 0 or self.dgpo_dqw_acc_reward_index >= len(self.reward_funcs)): + raise ValueError( + f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got " + f"{self.dgpo_dqw_acc_reward_index}." + ) if self.use_liger_kernel and not self.importance_sampling_level == "token": raise NotImplementedError( "Liger Kernels currently only support token-level importance sampling. Please set" @@ -1591,6 +1600,86 @@ def _generate(self, prompts: list): extra_fields, ) + def _compute_valid_token_balancing_ratios( + self, + completion_mask: torch.Tensor, + is_std_zero: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute token-level balancing ratios for distributed training with filtered questions. + + When zero-variance questions are masked out, the effective number of valid tokens changes across + processes. This method produces two correction factors: `zero_mask_ratio` compensates for the tokens + lost by masking zero-variance questions, and `global_balancing_ratio` corrects for uneven token counts + across processes. + + Args: + completion_mask: Boolean tensor of shape `(batch_size, seq_len)` indicating valid completion tokens. + is_std_zero: Boolean tensor of shape `(batch_size,)` indicating zero-variance questions. 
+ + Returns: + A tuple `(zero_mask_ratio, global_balancing_ratio)` of scalar tensors. + """ + completion_length_local = completion_mask.sum(dim=1) + completion_length_global = gather(completion_length_local) + + global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8) + local_completion_length_sum = completion_length_local.sum() + + global_balancing_ratio = ( + self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum + ) + + valid_mask_global = ~gather(is_std_zero) + if valid_mask_global.any(): + valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8) + zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum + else: + zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device, dtype=completion_mask.dtype) + + return zero_mask_ratio, global_balancing_ratio + + def _compute_dqw_weights( + self, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + num_generations: int, + ) -> torch.Tensor: + """ + Compute question-level difficulty balancing weights (DQW). + + Assigns higher weight to harder questions (lower mean accuracy) using a temperature-scaled softmax over + per-question accuracy means. Zero-variance questions receive a neutral weight of 1. + The returned weights sum to `num_questions`. + + Args: + rewards: Tensor of shape `(num_questions * num_generations,)` with per-generation rewards. + rewards_per_func: Tensor of shape `(num_questions * num_generations, num_reward_funcs)` with + per-reward-function scores; the column at `dgpo_dqw_acc_reward_index` is used as accuracy. + num_generations: Number of generations per question. + + Returns: + Tensor of shape `(num_questions,)` with difficulty balancing weights. 
+ """ + num_questions = rewards.size(0) // num_generations + acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] + mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,) + std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,) + is_std_zero_q = std_per_q_acc < 1e-8 + num_zero_variance_questions = is_std_zero_q.sum().item() + difficulty_balancing_weights = torch.ones( + num_questions, device=rewards.device, dtype=rewards.dtype + ) + if num_zero_variance_questions < num_questions: + mean_per_q_acc_modified = mean_per_q_acc.clone() + mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0 + difficulty_balancing_weights[~is_std_zero_q] = ( + num_questions - num_zero_variance_questions + ) * torch.nn.functional.softmax( + -mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0 + ) + return difficulty_balancing_weights + def _generate_and_score_completions( self, inputs: list[dict[str, torch.Tensor | Any]] ) -> dict[str, torch.Tensor | Any]: @@ -1840,7 +1929,12 @@ def _generate_and_score_completions( advantages = rewards - mean_grouped_rewards if self.scale_rewards != "none": - advantages = advantages / (std_rewards + 1e-4) + if self.use_dgpo_dgae: + mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1) + mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0) + advantages = advantages / (mad_rewards + 1e-4) + else: + advantages = advantages / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging elif self.multi_objective_aggregation == "normalize_then_sum": @@ -1851,7 +1945,13 @@ def _generate_and_score_completions( reward_k = reward_k.view(-1, len(self.reward_funcs)) rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) - advantages = (rewards - 
rewards.mean()) / (std_rewards + 1e-4) + advantages = rewards - rewards.mean() + if self.use_dgpo_dgae: + mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1) + mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0) + advantages = advantages / (mad_rewards + 1e-4) + else: + advantages = advantages / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging else: @@ -1860,6 +1960,19 @@ def _generate_and_score_completions( "'sum_then_normalize' or 'normalize_then_sum'." ) + # zero_mask_ratio must be applied before the process slice; global_balancing_ratio after + if self.use_dgpo_dgae or self.use_dgpo_dqw: + zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios( + completion_mask, is_std_zero + ) + advantages = advantages * zero_mask_ratio + + if self.use_dgpo_dqw: + difficulty_balancing_weights = self._compute_dqw_weights( + rewards, rewards_per_func, num_generations + ) + advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations) + # Slice to keep only the local part of the data process_slice = slice( self.accelerator.process_index * len(prompts), @@ -1868,6 +1981,9 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] + if self.use_dgpo_dgae or self.use_dgpo_dqw: + advantages = advantages * global_balancing_ratio + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()