Skip to content
43 changes: 35 additions & 8 deletions docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@ trainer = GRPOTrainer(
)
```

### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2505.07291

INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
delta=4, # δ in section 4.1 of the paper
epsilon=0.2, # ε in section 4.1 of the paper
beta=0.001, # KL divergence coefficient in section 4.1 of the paper
num_generations=16, # responses per prompt in section 4.1 of the paper
learning_rate=3e-7, # section 4.1 of the paper
)
```

### Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning

**📜 Paper**: https://huggingface.co/papers/2506.01939
Expand Down Expand Up @@ -573,21 +591,30 @@ training_args = GRPOConfig(
)
```

### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning
### VESPO: Variational Sequence-Level Soft Policy Optimization for Stable Off-Policy LLM Training

**📜 Paper**: https://huggingface.co/papers/2505.07291
**📜 Paper**: https://huggingface.co/papers/2602.10693

INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration:
VESPO addresses training instability in off-policy RL caused by policy staleness, asynchronous updates, and train-inference mismatches. Rather than relying on heuristic token-level clipping (GRPO) or sequence-length normalization (GSPO), VESPO derives a principled reshaping kernel from a variational framework. In practice, this yields a smooth, asymmetric Gamma weighting function that gracefully suppresses extreme sequence-level importance weights without introducing length bias.

$$
\mathcal{L}_{\text{VESPO}}(\theta) = - \mathbb{E}_{\tau \sim \mu} \left[ \underbrace{W(\tau)^{k} \cdot \exp\left(\lambda
(1 - W(\tau))\right)}_{\phi(W) \text{ detached }} \cdot \mathcal{A}(\tau) \cdot \log \pi_\theta(\tau) \right]
$$

where \\( W(\tau) = \frac{\pi_\theta(\tau)}{\mu(\tau)} \\) is the sequence-level importance sampling ratio, and \\( \phi(W) \\) is detached from the computation graph so that it serves only as a gradient scaling coefficient.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
delta=4, # δ in section 4.1 of the paper
epsilon=0.2, # ε in section 4.1 of the paper
beta=0.001, # KL divergence coefficient in section 4.1 of the paper
num_generations=16, # responses per prompt in section 4.1 of the paper
learning_rate=3e-7, # section 4.1 of the paper
loss_type="vespo",
use_vllm=True, # or False if not using any token-level `vllm_importance_sampling_correction` methods
vllm_importance_sampling_mode="token_truncate", # default correction mode for VESPO, `token_mask` also supported
vespo_k_pos=2.0, # power exponent (c1 in paper Section 3.4) for positive advantages
vespo_lambda_pos=3.0, # decay factor (c2 in paper Section 3.4) for positive advantages
vespo_k_neg=3.0, # power exponent (c1 in paper Section 3.4) for negative advantages
vespo_lambda_neg=2.0, # decay factor (c2 in paper Section 3.4) for negative advantages
)
```

Expand Down
2 changes: 1 addition & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"])
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

Expand Down
49 changes: 49 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ class GRPOConfig(_BaseConfig):
sapo_temperature_pos (`float`, *optional*, defaults to `1.0`):
Temperature for tokens with positive advantage scores used in the `sapo` loss function. This parameter is
introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347).
vespo_k_pos (`float`, *optional*, defaults to `2.0`):
    k parameter for positive advantages; it is the power exponent in the VESPO loss. Controls how aggressively
    we down-weight samples with low importance weights (when the importance sampling ratio < 1).
vespo_lambda_pos (`float`, *optional*, defaults to `3.0`):
    lambda parameter for positive advantages; it is the exponential decay factor in the VESPO loss. Controls
    how aggressively we down-weight samples with high importance weights (when the importance sampling ratio >
    1).
vespo_k_neg (`float`, *optional*, defaults to `3.0`):
    k parameter for negative advantages; it is the power exponent in the VESPO loss. Controls how aggressively
    we down-weight samples with low importance weights (when the importance sampling ratio < 1).
vespo_lambda_neg (`float`, *optional*, defaults to `2.0`):
    lambda parameter for negative advantages; it is the exponential decay factor in the VESPO loss. Controls
    how aggressively we down-weight samples with high importance weights (when the importance sampling ratio >
    1).
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
Expand Down Expand Up @@ -243,6 +256,9 @@ class GRPOConfig(_BaseConfig):
sequence's loss by its length. This is a modification of GSPO and requires
`importance_sampling_level="sequence"`. Introduced in the [LUSPO
paper](https://huggingface.co/papers/2602.05261).
- `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth,
asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in
the [VESPO paper](https://huggingface.co/papers/2602.10693).
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
Expand Down Expand Up @@ -619,6 +635,36 @@ class GRPOConfig(_BaseConfig):
"paper](https://huggingface.co/papers/2511.20347)."
},
)
vespo_k_pos: float = field(
default=2.0,
metadata={
"help": "k parameter for positive advantages, it is the power exponent in the VESPO loss. Controls how "
"aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)."
},
)
vespo_lambda_pos: float = field(
default=3.0,
metadata={
"help": "lambda parameter for positive advantages, it is the decay factor in the VESPO loss. Controls "
"how aggressively we down-weight samples with high importance weights (when the importance sampling ratio "
"> 1)."
},
)
vespo_k_neg: float = field(
default=3.0,
metadata={
"help": "k parameter for negative advantages, it is the power exponent in the VESPO loss. Controls how "
"aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)."
},
)
vespo_lambda_neg: float = field(
default=2.0,
metadata={
"help": "lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. "
"Controls how aggressively we down-weight samples with high importance weights (when the importance "
"sampling ratio > 1)."
},
)
importance_sampling_level: str = field(
default="token",
metadata={
Expand Down Expand Up @@ -690,6 +736,9 @@ class GRPOConfig(_BaseConfig):
"sequence's loss by its length. This is a modification of GSPO and requires "
"`importance_sampling_level='sequence'`. Introduced in the [LUSPO "
"paper](https://huggingface.co/papers/2602.05261)."
"'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, "
"asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in "
"the [VESPO paper](https://huggingface.co/papers/2602.10693)."
},
)
mask_truncated_completions: bool = field(
Expand Down
83 changes: 81 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import importlib.resources as pkg_resources
import inspect
import math
import os
import sys
import textwrap
Expand Down Expand Up @@ -578,6 +579,19 @@ def __init__(
"paper's setup."
)

if args.loss_type == "vespo" and args.importance_sampling_level != "token":
logger.warning(
"VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be "
"set to `'token'` (the default)."
)

if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction:
if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]:
raise ValueError(
f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or "
f"'token_mask'. Got: {self.vllm_importance_sampling_mode}."
)

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
Expand Down Expand Up @@ -2098,6 +2112,56 @@ def get_off_policy_mask(
is_low_kl = avg_seq_kl <= off_policy_threshold
return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1)

@staticmethod
@torch.no_grad()
def get_gamma_weights(
advantages: torch.Tensor,
log_ratio_per_token: torch.Tensor,
mask: torch.Tensor,
importance_sampling_ratio: torch.Tensor | None, # (B, T)
k_pos: float = 2.0,
lambda_pos: float = 3.0,
k_neg: float = 3.0,
lambda_neg: float = 2.0,
) -> torch.Tensor:
"""
Computes the Gamma weights for the VESPO loss. For reference:
φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so φ(1)=1)
with w = sequence-level importance sampling ratio
note: we will compute φ(w) in log space

φ(w) is detached via @torch.no_grad(), only acts as gradient scaling coefficient

VESPO loss = -φ(w) × A × log_prob, gradient naturally gives φ(w) × A × ∇log π
"""
# reducing clamp range directly to log(1e-8) ~ -18.42, to avoid recomputing log_w=log(w.clamp(min=1e-8)) later
# This is solely for matching truthfully the original implementation, otherwise keeping -20 could be fine.
lower_clamp = math.log(1e-8)

# Sequence-level log ratio Σ log(π_θ/π_old) (not a mean like for `log_importance_weights`)
log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0)
seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True) # (B, 1)

# Apply token-level TIS or MIS correction (in log space)
if importance_sampling_ratio is not None:
log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0)
# log(w) = log(π_θ/π_old) + log(π_old/π_sampler)
seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True)

log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0)
w_seq = torch.exp(log_w_seq)

# compute k and lambda based on advantage sign
is_nonneg_adv = advantages >= 0
k_seq = torch.where(is_nonneg_adv, k_pos, k_neg)
lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4)

# log(φ(w)) = λ + k × log(w) - λ × w
log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq
phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)

return phi_seq # (B, 1)

def _compute_loss(self, model, inputs):
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
Expand Down Expand Up @@ -2199,6 +2263,18 @@ def _compute_loss(self, model, inputs):
temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
per_token_loss = -soft_coef_1 * advantages
elif self.loss_type == "vespo":
phi_seq = self.get_gamma_weights(
advantages=advantages,
log_ratio_per_token=log_ratio,
mask=mask,
importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
k_pos=self.args.vespo_k_pos,
lambda_pos=self.args.vespo_lambda_pos,
k_neg=self.args.vespo_k_neg,
lambda_neg=self.args.vespo_lambda_neg,
)
per_token_loss = -phi_seq * advantages * per_token_logps
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

Expand All @@ -2208,7 +2284,7 @@ def _compute_loss(self, model, inputs):
if entropy_mask is not None:
per_token_loss = per_token_loss * entropy_mask

if self.use_vllm and self.vllm_importance_sampling_correction:
if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]

if self.beta != 0.0:
Expand All @@ -2227,7 +2303,7 @@ def _compute_loss(self, model, inputs):
loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type in ["cispo", "dapo"]:
elif self.loss_type in ["cispo", "dapo", "vespo"]:
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * mask).sum() / normalizer
elif self.loss_type == "luspo":
Expand Down Expand Up @@ -2277,6 +2353,9 @@ def masked_batch_mean(x):
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
elif self.loss_type == "vespo":
gathered_phi_seq = self.accelerator.gather(phi_seq)
self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())

return loss

Expand Down
Loading