24 changes: 24 additions & 0 deletions docs/source/paper_index.md
@@ -221,6 +221,30 @@ training_args = GRPOConfig(
)
```

### DGPO: Difficulty-Aware Group Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2601.20614

DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].

- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled by the mean absolute deviation (MAD) of the group rewards instead of their standard deviation, i.e. `advantage = (reward - mean) / (MAD + eps)`. This counters an implicit imbalance of standard-deviation scaling, under which update magnitudes peak for questions of moderate difficulty and are suppressed for both easier and harder ones.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty, estimated from its mean accuracy reward. Harder questions receive higher weight, so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower values concentrate more on hard questions) and `dgpo_dqw_acc_reward_index` to select which reward in `reward_funcs` serves as the accuracy/difficulty signal.
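
As a rough standalone illustration of the DGAE scaling (a sketch on toy values, not TRL's internal implementation), compare standard-deviation scaling with MAD scaling for a single group of rewards:

```python
import torch

# One question with num_generations = 4 sampled completions.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])

centered = rewards - rewards.mean()
# Standard GRPO scales by the standard deviation of the group rewards.
adv_std = centered / (rewards.std() + 1e-4)
# DGAE scales by the mean absolute deviation (MAD) instead.
mad = centered.abs().mean()
adv_mad = centered / (mad + 1e-4)
```

Since MAD is never larger than the standard deviation, MAD-scaled advantages are at least as large in magnitude, which is where the rebalancing across difficulty levels comes from.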

To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
...,
use_dgpo_dgae=True,
use_dgpo_dqw=True,
dgpo_dqw_temp=2.0,
dgpo_dqw_acc_reward_index=0,
)
trainer = GRPOTrainer(..., args=training_args, reward_funcs=[...], train_dataset=...)
```
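
To see how `dgpo_dqw_temp` shapes the question-level weights, here is a small sketch on toy numbers (it mirrors the description above, not the trainer's exact code): DQW applies a temperature-scaled softmax to the negated per-question mean accuracy, then rescales so the weights sum to the number of questions.

```python
import torch

dgpo_dqw_temp = 2.0
# Per-question mean accuracy reward: lower mean = harder question.
mean_acc = torch.tensor([0.9, 0.5, 0.1])

num_questions = mean_acc.size(0)
# Harder questions (lower accuracy) get larger weights; weights sum to num_questions.
weights = num_questions * torch.softmax(-mean_acc / dgpo_dqw_temp, dim=0)
```

Lowering `dgpo_dqw_temp` sharpens the distribution toward the hardest questions; raising it flattens the weights toward uniform.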

### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)

**📜 Paper**: https://huggingface.co/papers/2508.08221
25 changes: 25 additions & 0 deletions tests/test_grpo_trainer.py
@@ -230,6 +230,31 @@ def test_training_loss_types(self, loss_type):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_dgpo(self):
"""Test DGPO (Difficulty-Aware Group Policy Optimization) runs without error."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
report_to="none",
use_dgpo_dgae=True,
use_dgpo_dqw=True,
dgpo_dqw_temp=2.0,
dgpo_dqw_acc_reward_index=0,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None

def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

57 changes: 57 additions & 0 deletions trl/trainer/grpo_config.py
@@ -295,6 +295,26 @@ class GRPOConfig(_BaseConfig):
Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the
KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the
[DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556).
use_dgpo_dgae (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative
advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard
deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced
in the [MathForge paper](https://huggingface.co/papers/2601.20614).
use_dgpo_dqw (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a weight
based on its estimated difficulty, and that weight is multiplied into the advantages—so harder questions
produce larger effective updates. Difficulty is computed as a softmax over the negative per-question mean
accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`):
Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher values make
weights more uniform across questions; lower values concentrate weight on the hardest questions.
Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`):
Index into `reward_funcs` selecting the reward used as the "accuracy" signal for DQW's difficulty
estimate. For each question, DQW uses the mean reward at this index across the group to compute the
difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW softmax.
Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614).

> Parameters that control the logging

@@ -784,6 +804,43 @@ class GRPOConfig(_BaseConfig):
},
)

use_dgpo_dgae: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative "
"advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard "
"deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced "
"in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
use_dgpo_dqw: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a "
"weight based on its estimated difficulty, and that weight is multiplied into the advantages—so harder "
"questions produce larger effective updates. Difficulty is computed as a softmax over the negative "
"per-question mean accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the "
"[MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
dgpo_dqw_temp: float = field(
default=2.0,
metadata={
"help": "Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher "
"values make weights more uniform across questions; lower values concentrate weight on the hardest "
"questions. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
dgpo_dqw_acc_reward_index: int = field(
default=0,
metadata={
"help": "Index into `reward_funcs` selecting the reward used as the \"accuracy\" signal for DQW's "
"difficulty estimate. For each question, DQW uses the mean reward at this index across the group to "
"compute the difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW "
"softmax. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
120 changes: 118 additions & 2 deletions trl/trainer/grpo_trainer.py
@@ -546,6 +546,15 @@ def __init__(
raise NotImplementedError(
"Liger Kernels don't currently support masking token positions based on entropy."
)
self.use_dgpo_dgae = args.use_dgpo_dgae
self.use_dgpo_dqw = args.use_dgpo_dqw
self.dgpo_dqw_temp = args.dgpo_dqw_temp
self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index
if self.use_dgpo_dqw and not (0 <= self.dgpo_dqw_acc_reward_index < len(self.reward_funcs)):
raise ValueError(
f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got "
f"{self.dgpo_dqw_acc_reward_index}."
)
if self.use_liger_kernel and not self.importance_sampling_level == "token":
raise NotImplementedError(
"Liger Kernels currently only support token-level importance sampling. Please set"
@@ -1591,6 +1600,86 @@ def _generate(self, prompts: list):
extra_fields,
)

def _compute_valid_token_balancing_ratios(
self,
completion_mask: torch.Tensor,
is_std_zero: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute token-level balancing ratios for distributed training with filtered questions.

When zero-variance questions are masked out, the effective number of valid tokens changes across
processes. This method produces two correction factors: `zero_mask_ratio` compensates for the tokens
lost by masking zero-variance questions, and `global_balancing_ratio` corrects for uneven token counts
across processes.

Args:
completion_mask: Boolean tensor of shape `(batch_size, seq_len)` indicating valid completion tokens.
is_std_zero: Boolean tensor of shape `(batch_size,)` indicating zero-variance questions.

Returns:
A tuple `(zero_mask_ratio, global_balancing_ratio)` of scalar tensors.
"""
completion_length_local = completion_mask.sum(dim=1)
completion_length_global = gather(completion_length_local)

global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8)
local_completion_length_sum = completion_length_local.sum()

global_balancing_ratio = (
self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum
)

valid_mask_global = ~gather(is_std_zero)
if valid_mask_global.any():
valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8)
zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum
else:
# completion_mask may be boolean, so don't inherit its dtype for a scalar ratio
zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device, dtype=torch.float32)

return zero_mask_ratio, global_balancing_ratio

def _compute_dqw_weights(
self,
rewards: torch.Tensor,
rewards_per_func: torch.Tensor,
num_generations: int,
) -> torch.Tensor:
"""
Compute question-level difficulty balancing weights (DQW).

Assigns higher weight to harder questions (lower mean accuracy) using a temperature-scaled softmax over
per-question accuracy means. Zero-variance questions receive a neutral weight of 1.
The returned weights sum to `num_questions`.

Args:
rewards: Tensor of shape `(num_questions * num_generations,)` with per-generation rewards.
rewards_per_func: Tensor of shape `(num_questions * num_generations, num_reward_funcs)` with
per-reward-function scores; the column at `dgpo_dqw_acc_reward_index` is used as accuracy.
num_generations: Number of generations per question.

Returns:
Tensor of shape `(num_questions,)` with difficulty balancing weights.
"""
num_questions = rewards.size(0) // num_generations
acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index]
mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,)
std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,)
is_std_zero_q = std_per_q_acc < 1e-8
num_zero_variance_questions = is_std_zero_q.sum().item()
difficulty_balancing_weights = torch.ones(
num_questions, device=rewards.device, dtype=rewards.dtype
)
if num_zero_variance_questions < num_questions:
mean_per_q_acc_modified = mean_per_q_acc.clone()
mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0
difficulty_balancing_weights[~is_std_zero_q] = (
num_questions - num_zero_variance_questions
) * torch.nn.functional.softmax(
-mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0
)
return difficulty_balancing_weights

def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
@@ -1840,7 +1929,12 @@ def _generate_and_score_completions(

advantages = rewards - mean_grouped_rewards
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
if self.use_dgpo_dgae:
mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1)
mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0)
advantages = advantages / (mad_rewards + 1e-4)
else:
advantages = advantages / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

elif self.multi_objective_aggregation == "normalize_then_sum":
@@ -1851,7 +1945,13 @@ def _generate_and_score_completions(
reward_k = reward_k.view(-1, len(self.reward_funcs))
rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
advantages = rewards - rewards.mean()
if self.use_dgpo_dgae:
mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1)
mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0)
advantages = advantages / (mad_rewards + 1e-4)
else:
advantages = advantages / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

else:
@@ -1860,6 +1960,19 @@ def _generate_and_score_completions(
"'sum_then_normalize' or 'normalize_then_sum'."
)

# zero_mask_ratio must be applied before the process slice; global_balancing_ratio after
if self.use_dgpo_dgae or self.use_dgpo_dqw:
zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios(
completion_mask, is_std_zero
)
advantages = advantages * zero_mask_ratio

if self.use_dgpo_dqw:
difficulty_balancing_weights = self._compute_dqw_weights(
rewards, rewards_per_func, num_generations
)
advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations)

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
@@ -1868,6 +1981,9 @@ def _generate_and_score_completions(
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]

if self.use_dgpo_dgae or self.use_dgpo_dqw:
advantages = advantages * global_balancing_ratio

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()