diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 33bfb02542..51fe27c8e1 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -165,6 +165,24 @@ trainer = GRPOTrainer( ) ``` +### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning + +**📜 Paper**: https://huggingface.co/papers/2505.07291 + +INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + delta=4, # δ in section 4.1 of the paper + epsilon=0.2, # ε in section 4.1 of the paper + beta=0.001, # KL divergence coefficient in section 4.1 of the paper + num_generations=16, # responses per prompt in section 4.1 of the paper + learning_rate=3e-7, # section 4.1 of the paper +) +``` + ### Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning **📜 Paper**: https://huggingface.co/papers/2506.01939 @@ -573,21 +591,30 @@ training_args = GRPOConfig( ) ``` -### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning +### VESPO: Variational Sequence-Level Soft Policy Optimization for Stable Off-Policy LLM Training -**📜 Paper**: https://huggingface.co/papers/2505.07291 +**📜 Paper**: https://huggingface.co/papers/2602.10693 -INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. 
The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration: +VESPO addresses training instability in off-policy RL caused by policy staleness, asynchronous updates, and train-inference mismatches. Rather than relying on heuristic token-level clipping (GRPO) or sequence-length normalization (GSPO), VESPO derives a principled reshaping kernel from a variational framework. In practice, this yields a smooth, asymmetric Gamma weighting function that gracefully suppresses extreme sequence-level importance weights without introducing length bias. + +$$ +\mathcal{L}_{\text{VESPO}}(\theta) = - \mathbb{E}_{\tau \sim \mu} \left[ \underbrace{W(\tau)^{k} \cdot \exp\left(\lambda +(1 - W(\tau))\right)}_{\phi(W) \text{ detached }} \cdot \mathcal{A}(\tau) \cdot \log \pi_\theta(\tau) \right] +$$ + +where \\( W(\tau) = \frac{\pi_\theta(\tau)}{\mu(\tau)} \\) is the sequence-level importance ratio, and \\( \phi(W) \\) is detached from the computation graph to serve as a gradient scaling coefficient. 
```python from trl import GRPOConfig training_args = GRPOConfig( - delta=4, # δ in section 4.1 of the paper - epsilon=0.2, # ε in section 4.1 of the paper - beta=0.001, # KL divergence coefficient in section 4.1 of the paper - num_generations=16, # responses per prompt in section 4.1 of the paper - learning_rate=3e-7, # section 4.1 of the paper + loss_type="vespo", + use_vllm=True, # or False if not using any token-level `vllm_importance_sampling_correction` methods + vllm_importance_sampling_mode="token_truncate", # default correction mode for VESPO, `token_mask` also supported + vespo_k_pos=2.0, # power exponent (c1 in paper Section 3.4) for positive advantages + vespo_lambda_pos=3.0, # decay factor (c2 in paper Section 3.4) for positive advantages + vespo_k_neg=3.0, # power exponent (c1 in paper Section 3.4) for negative advantages + vespo_lambda_neg=2.0, # decay factor (c2 in paper Section 3.4) for negative advantages ) ``` diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 2b275bae81..988596701d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -278,7 +278,7 @@ def test_training(self, config_name): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"]) + @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"]) def test_training_loss_types(self, loss_type): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 36147251f1..57c529b299 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -186,6 +186,19 @@ class GRPOConfig(_BaseConfig): sapo_temperature_pos (`float`, *optional*, defaults to `1.0`): Temperature for tokens with positive advantage scores used in the `sapo` loss function. 
This parameter is introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). + vespo_k_pos (`float`, *optional*, defaults to `2.0`): + k parameter for positive advantages, it is the power exponent in the VESPO loss. Controls how aggressively + we down-weight samples with low importance weights (when the importance sampling ratio < 1). + vespo_lambda_pos (`float`, *optional*, defaults to `3.0`): + lambda parameter for positive advantages, it is the decay factor in the VESPO loss. Controls how + aggressively we down-weight samples with high importance weights (when the importance sampling ratio > 1). + vespo_k_neg (`float`, *optional*, defaults to `3.0`): + k parameter for negative advantages, it is the power exponent in the VESPO loss. Controls how aggressively + we down-weight samples with low importance weights (when the importance sampling ratio < 1). + vespo_lambda_neg (`float`, *optional*, defaults to `2.0`): + lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. Controls + how aggressively we down-weight samples with high importance weights (when the importance sampling ratio > + 1). importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -243,6 +256,9 @@ class GRPOConfig(_BaseConfig): sequence's loss by its length. This is a modification of GSPO and requires `importance_sampling_level="sequence"`. Introduced in the [LUSPO paper](https://huggingface.co/papers/2602.05261). + - `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, + asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in + the [VESPO paper](https://huggingface.co/papers/2602.10693). 
mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -619,6 +635,36 @@ class GRPOConfig(_BaseConfig): "paper](https://huggingface.co/papers/2511.20347)." }, ) + vespo_k_pos: float = field( + default=2.0, + metadata={ + "help": "k parameter for positive advantages, it is the power exponent in the VESPO loss. Controls how " + "aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)." + }, + ) + vespo_lambda_pos: float = field( + default=3.0, + metadata={ + "help": "lambda parameter for positive advantages, it is the decay factor in the VESPO loss. Controls " + "how aggressively we down-weight samples with high importance weights (when the importance sampling ratio " + "> 1)." + }, + ) + vespo_k_neg: float = field( + default=3.0, + metadata={ + "help": "k parameter for negative advantages, it is the power exponent in the VESPO loss. Controls how " + "aggressively we down-weight samples with low importance weights (when the importance sampling ratio < 1)." + }, + ) + vespo_lambda_neg: float = field( + default=2.0, + metadata={ + "help": "lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. " + "Controls how aggressively we down-weight samples with high importance weights (when the importance " + "sampling ratio > 1)." + }, + ) importance_sampling_level: str = field( default="token", metadata={ @@ -690,6 +736,9 @@ class GRPOConfig(_BaseConfig): "sequence's loss by its length. This is a modification of GSPO and requires " "`importance_sampling_level='sequence'`. Introduced in the [LUSPO " "paper](https://huggingface.co/papers/2602.05261)." + "'vespo': Variational Sequence-Level Soft Policy Optimization. 
Replaces hard clipping with a smooth, " + "asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in " + "the [VESPO paper](https://huggingface.co/papers/2602.10693)." }, ) mask_truncated_completions: bool = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e5b3d21e96..564d55c1ad 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -17,6 +17,7 @@ import copy import importlib.resources as pkg_resources import inspect +import math import os import sys import textwrap @@ -578,6 +579,19 @@ def __init__( "paper's setup." ) + if args.loss_type == "vespo" and args.importance_sampling_level != "token": + logger.warning( + "VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be " + "set to `'token'` (the default)." + ) + + if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction: + if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]: + raise ValueError( + f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or " + f"'token_mask'. Got: {self.vllm_importance_sampling_mode}." + ) + # Multi-step self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper self.epsilon_low = args.epsilon @@ -2098,6 +2112,56 @@ def get_off_policy_mask( is_low_kl = avg_seq_kl <= off_policy_threshold return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1) + @staticmethod + @torch.no_grad() + def get_gamma_weights( + advantages: torch.Tensor, + log_ratio_per_token: torch.Tensor, + mask: torch.Tensor, + importance_sampling_ratio: torch.Tensor | None, # (B, T) + k_pos: float = 2.0, + lambda_pos: float = 3.0, + k_neg: float = 3.0, + lambda_neg: float = 2.0, + ) -> torch.Tensor: + """ + Computes the Gamma weights for the VESPO loss. 
For reference: + φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so φ(1)=1) + with w = sequence-level importance sampling ratio + note: we will compute φ(w) in log space + + φ(w) is detached via @torch.no_grad(), only acts as gradient scaling coefficient + + VESPO loss = -φ(w) × A × log_prob, gradient naturally gives φ(w) × A × ∇log π + """ + # reducing clamp range directly to log(1e-8) ~ -18.42, to avoid recomputing log_w=log(w.clamp(min=1e-8)) later + # This is solely for matching truthfully the original implementation, otherwise keeping -20 could be fine. + lower_clamp = math.log(1e-8) + + # Sequence-level log ratio Σ log(π_θ/π_old) (not a mean like for `log_importance_weights`) + log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0) + seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True) # (B, 1) + + # Apply token-level TIS or MIS correction (in log space) + if importance_sampling_ratio is not None: + log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0) + # log(w) = log(π_θ/π_old) + log(π_old/π_sampler) + seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True) + + log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0) + w_seq = torch.exp(log_w_seq) + + # compute k and lambda based on advantage sign + is_nonneg_adv = advantages >= 0 + k_seq = torch.where(is_nonneg_adv, k_pos, k_neg) + lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4) + + # log(φ(w)) = λ + k × log(w) - λ × w + log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq + phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + + return phi_seq # (B, 1) + def _compute_loss(self, model, inputs): # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] @@ -2199,6 +2263,18 @@ def _compute_loss(self, model, inputs): temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, 
self.args.sapo_temperature_neg) soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures per_token_loss = -soft_coef_1 * advantages + elif self.loss_type == "vespo": + phi_seq = self.get_gamma_weights( + advantages=advantages, + log_ratio_per_token=log_ratio, + mask=mask, + importance_sampling_ratio=inputs.get("importance_sampling_ratio"), + k_pos=self.args.vespo_k_pos, + lambda_pos=self.args.vespo_lambda_pos, + k_neg=self.args.vespo_k_neg, + lambda_neg=self.args.vespo_lambda_neg, + ) + per_token_loss = -phi_seq * advantages * per_token_logps else: raise ValueError(f"Unknown loss type: {self.loss_type}") @@ -2208,7 +2284,7 @@ def _compute_loss(self, model, inputs): if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask - if self.use_vllm and self.vllm_importance_sampling_correction: + if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo": per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: @@ -2227,7 +2303,7 @@ def _compute_loss(self, model, inputs): loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval loss = loss / normalizer - elif self.loss_type in ["cispo", "dapo"]: + elif self.loss_type in ["cispo", "dapo", "vespo"]: normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * mask).sum() / normalizer elif self.loss_type == "luspo": @@ -2277,6 +2353,9 @@ def masked_batch_mean(x): cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) + elif self.loss_type == "vespo": + gathered_phi_seq = self.accelerator.gather(phi_seq) + 
self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item()) return loss