4 changes: 2 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
@@ -1,6 +1,6 @@
services:
trinity-node-1:
image: trinity-rft-unittest:20251225
image: trinity-rft-unittest:20260115
pull_policy: never
command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block"
environment:
@@ -30,7 +30,7 @@ services:
capabilities: [gpu]

trinity-node-2:
image: trinity-rft-unittest:20251225
image: trinity-rft-unittest:20260115
pull_policy: never
command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block"
environment:
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"verl==0.7.0",
"ray[default]>=2.50.0",
"tensordict",
"wandb",
@@ -79,7 +79,9 @@ megatron = [
# if you found "undefined symbol" error in transformer engine
# reinstall it with --no-build-isolation and `--no-cache-dir` flag
# "transformer_engine[pytorch]==2.8.0",
"mbridge>=0.13.0",

# Install mbridge from main branch (unreleased version)
"mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612",
]
tinker = [
"tinker; python_version >= '3.11'",
4 changes: 4 additions & 0 deletions tests/explorer/workflow_test.py
@@ -731,12 +731,16 @@ async def mock_get_api_server_url_remote():
async def mock_get_model_version_remote():
return 1

async def mock_get_api_key_remote():
return "dummy_api_key"

async def mock_get_model_config_remote():
return InferenceModelConfig(model_path="dummy_model")

model = MagicMock()
model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)
model.get_api_key.remote = MagicMock(side_effect=mock_get_api_key_remote)
model.get_model_config.remote = MagicMock(side_effect=mock_get_model_config_remote)

runner = WorkflowRunner(
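The added `get_api_key` mock follows the same pattern as the surrounding mocks: each Ray-style `.remote()` handle is a `MagicMock` whose `side_effect` is an async function, so calling it returns a coroutine that the workflow code can await. A minimal, self-contained sketch of that pattern (illustrative only, not part of the test file):

```python
import asyncio
from unittest.mock import MagicMock

async def mock_get_api_key_remote():
    return "dummy_api_key"

model = MagicMock()
# side_effect is an async function, so model.get_api_key.remote() returns a
# coroutine; the code under test can await it much like a real Ray call.
model.get_api_key.remote = MagicMock(side_effect=mock_get_api_key_remote)

print(asyncio.run(model.get_api_key.remote()))  # -> dummy_api_key
```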
28 changes: 21 additions & 7 deletions trinity/common/config.py
@@ -94,14 +94,15 @@ class OptimizerConfig:
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
warmup_style: str = "constant"
warmup_style: Optional[str] = None # deprecated !
lr_scheduler_type: str = "constant"
optimizer_type: str = "adam"
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
weight_decay: float = 0.01
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_decay_steps: Optional[int] = None
lr_decay_style: str = "constant"
lr_decay_style: str = "constant" # duplicated with lr_scheduler_type in veRL
min_lr: float = 0.0


@@ -116,6 +117,8 @@ class LoRAConfig:
lora_alpha: int = 32
lora_dtype: str = "auto"
target_modules: str = "all-linear"
exclude_modules: Optional[str] = None
is_dummy: bool = False # DO NOT SET, automatically set


@Experimental
@@ -1167,6 +1170,14 @@ def check_and_set(name, registry, args_attr):
# override loss_agg_mode in policy_loss_fn_args
self.algorithm.policy_loss_fn_args["loss_agg_mode"] = self.algorithm.loss_agg_mode # type: ignore [index]

optim_config = self.algorithm.optimizer
if optim_config.warmup_style is not None:
optim_config.lr_scheduler_type = optim_config.warmup_style
logger.warning(
"`warmup_style` is deprecated. Please use `lr_scheduler_type` instead. "
f"And `lr_scheduler_type` is set to {optim_config.lr_scheduler_type}."
)

def _check_model(self) -> None:
model = self.model
if not model.critic_model_path:
@@ -1363,16 +1374,19 @@ def _check_explorer(self) -> None:
self.explorer.rollout_model.enable_lora = True
if len(self.model.lora_configs) > 1:
raise ValueError("Only one lora adapter is supported for now.")
if self.model.lora_configs[0].path is None:
lora_config = self.model.lora_configs[0]
if lora_config.path is None:
logger.info("Creating dummy lora, since no lora_path is provided.")
lora_path = create_dummy_lora(
model_path=self.model.model_path,
checkpoint_job_dir=self.checkpoint_job_dir,
lora_rank=self.model.lora_configs[0].lora_rank,
lora_alpha=self.model.lora_configs[0].lora_alpha,
target_modules=self.model.lora_configs[0].target_modules,
lora_rank=lora_config.lora_rank,
lora_alpha=lora_config.lora_alpha,
target_modules=lora_config.target_modules,
exclude_modules=lora_config.exclude_modules,
)
self.model.lora_configs[0].path = lora_path
lora_config.path = lora_path
lora_config.is_dummy = True
self.explorer.rollout_model.lora_modules = [
{
"lora_int_id": i + 1,
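The deprecation shim added above only rewrites the field during config validation. A minimal sketch of the intended migration, assuming `OptimizerConfig` is a plain dataclass constructed with keyword arguments as the diff suggests (the "cosine" value is illustrative):

```python
from trinity.common.config import OptimizerConfig

# Old style: warmup_style used to be a plain string field. It is now Optional
# and defaults to None; when it is set, config validation copies the value into
# lr_scheduler_type and logs a deprecation warning.
legacy = OptimizerConfig(warmup_style="cosine")

# New style: set lr_scheduler_type directly and leave warmup_style unset.
current = OptimizerConfig(lr_scheduler_type="cosine")
```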
2 changes: 1 addition & 1 deletion trinity/common/models/vllm_worker.py
@@ -13,7 +13,7 @@
class WorkerExtension:
def apply_patches(self):
"""Apply necessary patches to vLLM."""
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

patch_vllm_moe_model_weight_loader(self.model_runner.model)
patch_vllm_prompt_logprobs(self.model_runner)
80 changes: 69 additions & 11 deletions trinity/common/verl_config.py
@@ -4,7 +4,9 @@
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
from verl.workers.config import PolicyLossConfig, RouterReplayConfig

from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.config import Config, SynchronizerConfig, set_if_none
from trinity.common.constants import EXPLORER_NAME
from trinity.utils.log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
lora_rank: int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
lora_alpha: int = 32
target_modules: Optional[str] = "all-linear"
exclude_modules: Optional[str] = None
lora_adapter_path: Optional[str] = None

# rope configs
rope_scaling: Optional[dict] = None
@@ -51,14 +55,15 @@
class Optim:
# For actor, most fields are set in algorithm.optimizer
# For critic, you can set trainer_config.critic.optim
optimizer: str = "adam"
optimizer_impl: str = "torch.optim"
lr: float = 1e-6
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
warmup_style: str = "constant"
lr_scheduler_type: str = "constant"
total_training_steps: int = -1 # ! DO NOT SET, use trainer.total_steps
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
optimizer: str = "adam"
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_decay_steps: Optional[int] = None
@@ -69,6 +74,7 @@ class Optim:
lr_wsd_decay_style: str = "exponential"
lr_wsd_decay_steps: Optional[int] = None
use_checkpoint_opt_param_scheduler: bool = False
override_optimizer_config: Optional[dict] = None


@dataclass
@@ -78,6 +84,7 @@ class WrapPolicy:

@dataclass
class FSDPConfig:
_target_: str = "verl.workers.config.FSDPEngineConfig" # DO NOT SET
param_offload: bool = False
optimizer_offload: bool = False
offload_policy: bool = False
@@ -92,15 +99,15 @@ class FSDPConfig:
class Checkpoint:
load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
async_save: bool = False # do not set, async save has bug in verl megatron training
async_save: bool = False # TODO: testing async save


@dataclass
class OverrideTransformerConfig:
recompute_granularity: Optional[str] = None
recompute_granularity: Optional[str] = "full"
recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"])
recompute_method: Optional[str] = None
recompute_num_layers: Optional[int] = None
recompute_method: Optional[str] = "uniform"
recompute_num_layers: Optional[int] = 1


@dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
default_factory=OverrideTransformerConfig
)
use_mbridge: bool = False
dtype: str = "bfloat16"
use_remove_padding: bool = True


@dataclass
@@ -157,6 +166,9 @@ class Actor:
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
# do not set
loss_agg_mode: str = "token-mean"
clip_ratio: float = 0.2
@@ -182,6 +194,8 @@ class Ref:
megatron: MegatronConfig = field(default_factory=MegatronConfig)
profile: ProfileConfig = field(default_factory=ProfileConfig)
load_weight: bool = True
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)


@dataclass
Expand Down Expand Up @@ -214,6 +228,7 @@ class ActorRolloutRef:
actor: Actor = field(default_factory=Actor)
ref: Ref = field(default_factory=Ref)
rollout: Rollout = field(default_factory=Rollout)
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
synchronizer: Optional[SynchronizerConfig] = None
explorer_name: str = EXPLORER_NAME

@@ -229,9 +244,14 @@ class CriticModel:
use_remove_padding: bool = True
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

# rope configs
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None


@dataclass
class Critic:
enable: bool = False
strategy: Optional[str] = None
optim: Optim = field(default_factory=Optim)
model: CriticModel = field(default_factory=CriticModel)
@@ -255,7 +275,9 @@ class Critic:
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
ray_namespace: str = "" # automatically generated
profiler: dict = field(default_factory=dict)


@dataclass
@@ -278,6 +300,7 @@ class RewardModel:
use_dynamic_bsz: bool = False
forward_max_token_len_per_gpu: int = 0
reward_manager: str = "naive"
use_reward_loop: bool = True


@dataclass
@@ -294,8 +317,24 @@ class KL_Ctrl:
target_kl: float = 0.1


@dataclass
class RolloutCorrection:
rollout_is: Optional[str] = None
rollout_is_threshold: float = 2.0
rollout_rs: Optional[str] = None
rollout_rs_threshold: Optional[float] = None
rollout_rs_threshold_lower: Optional[float] = None
rollout_token_veto_threshold: Optional[float] = None
# Because rollout and training in Trinity run separately,
# rollout_is_batch_normalize defaults to True
bypass_mode: bool = True
loss_type: str = "ppo_clip"
rollout_is_batch_normalize: bool = False


@dataclass
class Algorithm:
rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
# ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
# and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
# if they are really needed (e.g., for GAE advantage/returns computation)
@@ -349,6 +388,7 @@ class veRLConfig:
custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
algorithm: Algorithm = field(default_factory=Algorithm)
trainer: Trainer = field(default_factory=Trainer)
global_profiler: dict = field(default_factory=dict)
synchronizer: Optional[SynchronizerConfig] = None
enable_preview: bool = True

@@ -426,8 +466,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
) # kept to pass RayPPOTrainer._validate_config

self.synchronizer = config.synchronizer
self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
self.actor_rollout_ref.synchronizer = config.synchronizer
self.actor_rollout_ref.explorer_name = config.explorer.name
algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
self.critic.enable = algorithm.use_critic
self.critic.nccl_timeout = config.synchronizer.sync_timeout
self.critic.ray_namespace = config.synchronizer.ray_namespace

# Actor / Rollout Config
@@ -507,6 +551,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
set_if_none(self.critic, "strategy", config.trainer.trainer_strategy)
self.critic.model.path = config.model.critic_model_path
self.critic.model.tokenizer_path = config.model.critic_model_path
self.critic.model.rope_scaling = config.model.rope_scaling
self.critic.model.rope_theta = config.model.rope_theta
self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
self.critic.rollout_n = self.actor_rollout_ref.rollout.n
self.critic.optim.total_training_steps = self.trainer.total_training_steps
@@ -542,11 +588,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901

# LoRA related config
if config.model.lora_configs is not None:
self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank
self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha
self.actor_rollout_ref.model.target_modules = config.model.lora_configs[
0
].target_modules
lora_config = config.model.lora_configs[0]
actor_model_config = self.actor_rollout_ref.model
for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
setattr(actor_model_config, attr, getattr(lora_config, attr))
if not lora_config.is_dummy:
actor_model_config.lora_adapter_path = lora_config.path
if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]:
logger.warning(
f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp."
@@ -565,6 +612,17 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
# fix optimizer type for fsdp
if config.trainer.trainer_strategy.startswith("fsdp"):
optim_map = {
"adam": "AdamW",
"adamw": "AdamW",
"sgd": "SGD",
}
actor_optim = self.actor_rollout_ref.actor.optim
actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
critic_optim = self.critic.optim
critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
# TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
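The optimizer-name fix near the end of `synchronize_config` is easiest to read in isolation: Trinity configs use lowercase optimizer names, and for the fsdp/fsdp2 strategies they are rewritten to the capitalized class names that verl's torch.optim-backed workers appear to expect. A standalone sketch of that mapping (the helper below is illustrative, not part of the codebase):

```python
# Lowercase Trinity optimizer names are rewritten to torch.optim class names
# when the trainer strategy starts with "fsdp".
OPTIM_MAP = {"adam": "AdamW", "adamw": "AdamW", "sgd": "SGD"}

def normalize_optimizer_name(name: str) -> str:
    # Unmapped names (e.g. a custom optimizer) pass through unchanged.
    return OPTIM_MAP.get(name, name)

assert normalize_optimizer_name("adam") == "AdamW"
assert normalize_optimizer_name("muon") == "muon"
```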
8 changes: 4 additions & 4 deletions trinity/manager/config_manager.py
@@ -302,7 +302,7 @@ def _expert_verl_training_part(self):
def _expert_verl_actor_part(self):
st.subheader("Actor Model Config")

self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio")
self.get_configs("actor_lr", "actor_lr_scheduler_type", "actor_lr_warmup_steps_ratio")

self.get_configs("actor_grad_clip", "actor_ulysses_sequence_parallel_size")

@@ -324,7 +324,7 @@ def _expert_verl_critic_part(self):
"critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"
)

self.get_configs("critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio")
self.get_configs("critic_lr", "critic_lr_scheduler_type", "critic_lr_warmup_steps_ratio")

self.get_configs("critic_grad_clip", "critic_cliprange_value")
self.get_configs("critic_load_checkpoint", "critic_save_checkpoint")
@@ -490,7 +490,7 @@ def _generate_verl_config(self):
"optim": {
"lr": st.session_state["critic_lr"],
"lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
"warmup_style": st.session_state["critic_warmup_style"],
"lr_scheduler_type": st.session_state["critic_lr_scheduler_type"],
},
"model": {
"override_config": {},
@@ -550,7 +550,7 @@ def _gen_algorithm_config(self):
optimizer_config = {
"lr": st.session_state["actor_lr"],
"lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
"warmup_style": st.session_state["actor_warmup_style"],
"lr_scheduler_type": st.session_state["actor_lr_scheduler_type"],
}
algorithm_config["optimizer"] = optimizer_config
return algorithm_config