4 changes: 2 additions & 2 deletions benchmark/config/countdown-template.yaml
@@ -9,7 +9,7 @@ algorithm:
optimizer:
lr: 1e-06
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
advantage_fn: ppo
data_processor: {}
model:
@@ -78,7 +78,7 @@ trainer:
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
ppo_max_token_len_per_gpu: 12800
forward_max_token_len_per_gpu: 12800
cliprange_value: 0.5
2 changes: 1 addition & 1 deletion benchmark/config/gsm8k-template.yaml
@@ -9,7 +9,7 @@ algorithm:
optimizer:
lr: 1e-5
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
sample_strategy: default
policy_loss_fn: ppo
advantage_fn: grpo
2 changes: 1 addition & 1 deletion benchmark/config/guru_math-template.yaml
@@ -16,7 +16,7 @@ algorithm:
lr: 1e-6
weight_decay: 0.1
lr_warmup_steps: 80
warmup_style: constant
lr_scheduler_type: constant
cluster:
node_num: 1
gpu_per_node: 8
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -112,7 +112,7 @@ algorithm:
- `optimizer`: Optimizer configuration for actor.
- `lr`: Learning rate for actor.
- `warmup_style`: Deprecated, use `lr_scheduler_type` instead. We will remove this field in future versions.
- `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `consine`.
- `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `cosine`.
- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
- `advantage_fn`: The advantage function used for computing advantages.
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
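For readers following the rename documented above: a minimal sketch, based on the updated `tests/common/config_test.py` in this PR, of how the new `lr_scheduler_type` field is set and propagated into the veRL-side optimizer config via `check_and_update()`; the attribute paths are taken from that test and may differ in other versions.

```python
from tests.tools import get_template_config

config = get_template_config()
config.algorithm.optimizer.lr = 1e-4
config.algorithm.optimizer.lr_scheduler_type = "cosine"  # replaces the deprecated `warmup_style`
config.algorithm.optimizer.min_lr_ratio = 1e-2           # LR floor as a fraction of `lr`
config.trainer.total_steps = 1000
config.check_and_update()  # propagates algorithm.optimizer into the trainer config

actor_optim = config.trainer.trainer_config.actor_rollout_ref.actor.optim
assert actor_optim.lr == 1e-4
assert actor_optim.lr_decay_style == "cosine"  # veRL keeps its own field name
```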
2 changes: 1 addition & 1 deletion examples/dpo_human_in_the_loop/dpo.yaml
@@ -42,7 +42,7 @@ algorithm:
lr: 5e-7
lr_warmup_steps_ratio: 0.03 # the total steps will be injected during runtime
min_lr_ratio: 0.1 # only useful for warmup with cosine
warmup_style: cosine # select from constant/cosine
lr_scheduler_type: cosine # select from constant/cosine
betas: [0.9, 0.95]
kl_loss_fn: k1
kl_loss_fn_args:
2 changes: 1 addition & 1 deletion examples/dpo_humanlike/dpo.yaml
@@ -7,7 +7,7 @@ algorithm:
lr: 5e-7
lr_warmup_steps_ratio: 0.03 # the total steps will be injected during runtime
min_lr_ratio: 0.1 # only useful for warmup with cosine
warmup_style: cosine # select from constant/cosine
lr_scheduler_type: cosine # select from constant/cosine
betas: [0.9, 0.95]
kl_loss_fn: k1
kl_loss_fn_args:
2 changes: 1 addition & 1 deletion examples/grpo_alfworld/alfworld.yaml
@@ -65,7 +65,7 @@ trainer:
# optimizer:
# lr: 5e-6
# lr_warmup_steps_ratio: 0.0
# warmup_style: constant
# lr_scheduler_type: constant
# buffer:
# total_epochs: 1
# train_batch_size: 32
2 changes: 1 addition & 1 deletion examples/learn_to_ask/train.yaml
@@ -13,7 +13,7 @@ algorithm:
optimizer:
lr: 5.0e-07
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
data_processor: {}
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
2 changes: 1 addition & 1 deletion examples/tinker/tinker.yaml
@@ -11,7 +11,7 @@ algorithm:
optimizer:
lr: 1.0e-05
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
data_processor: {}
model:
model_path: Qwen/Qwen3-4B-Instruct-2507
2 changes: 1 addition & 1 deletion scripts/context_length_test/context_length.yaml
@@ -16,7 +16,7 @@ algorithm:
optimizer:
lr: 1.0e-05
lr_warmup_steps_ratio: 0.0
warmup_style: constant
lr_scheduler_type: constant
data_processor: {}
model:
model_path: ${oc.env:MODEL_PATH,Qwen/Qwen3-0.6B}
25 changes: 18 additions & 7 deletions tests/common/config_test.py
@@ -6,6 +6,8 @@
import shutil
import unittest

import torch

from tests.tools import get_template_config, get_unittest_dataset_config
from trinity.common.config import InferenceModelConfig, load_config

@@ -143,10 +145,9 @@ def test_optimizer_config_propagation(self):
config.algorithm.optimizer.lr = 1e-4
config.algorithm.optimizer.weight_decay = 0.05
config.algorithm.optimizer.clip_grad = 2.0
config.algorithm.optimizer.lr_decay_steps = 1000
config.algorithm.optimizer.lr_decay_style = "cosine"
config.algorithm.optimizer.lr_warmup_init = 1e-7
config.algorithm.optimizer.min_lr = 1e-6
config.trainer.total_steps = 1000
config.algorithm.optimizer.lr_scheduler_type = "cosine"
config.algorithm.optimizer.min_lr_ratio = 1e-2
config.check_and_update()
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
self.assertEqual(
@@ -159,10 +160,20 @@ def test_optimizer_config_propagation(self):
self.assertEqual(
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "cosine"
)
self.assertEqual(
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init, 1e-7
self.assertTrue(
torch.allclose(
torch.tensor(
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init
),
torch.tensor(1e-6),
)
)
self.assertTrue(
torch.allclose(
torch.tensor(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr),
torch.tensor(1e-6),
)
)
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr, 1e-6)
# critic optimizer should not be affected
self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
3 changes: 3 additions & 0 deletions tests/common/vllm_test.py
@@ -798,6 +798,9 @@ def __init__(self, config: Config):
self.config = config
self.synchronizer = Synchronizer.get_actor(config)

async def is_alive(self):
return True

fake_trainer = FakeTrainer.remote(self.config)
await fake_trainer.__ray_ready__.remote()
await super()._setup_engines()
2 changes: 1 addition & 1 deletion tests/template/config.yaml
@@ -8,7 +8,7 @@ algorithm:
optimizer:
lr: 1e-6
lr_warmup_steps_ratio: 0.
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
policy_loss_fn: ppo
policy_loss_fn_args:
clip_range: 0.2
44 changes: 42 additions & 2 deletions tests/trainer/trainer_test.py
@@ -2,6 +2,7 @@

import asyncio
import json
import math
import multiprocessing
import os
import shutil
@@ -48,6 +49,8 @@
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.explorer.proxy.client import TrinityClient
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.trainer.tinker_trainer import TinkerTrainerWrapper


class BaseTrainerCase(RayUnittestBase):
@@ -1434,8 +1437,8 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)


@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
class TestTinkerTrainer(BaseTrainerCase):
@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
def test_trainer(self):
"""Test GSM8K on tinker."""
# test both mode
@@ -1448,7 +1451,7 @@ def test_trainer(self):
self.config.buffer.total_epochs = 1
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.model.tinker.enable = True
self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507"
self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
self.config.check_and_update()
both(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -1464,6 +1467,43 @@ def test_trainer(self):
self.assertGreater(len(response_metrics), 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)

def test_trainer_class(self):
total_steps = 100
lr_warmup_steps = 10
self.config.algorithm.algorithm_type = "grpo"
self.config.model.tinker.enable = True
self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
self.config.trainer.total_steps = total_steps
self.config.algorithm.optimizer.lr_warmup_steps = lr_warmup_steps
self.config.algorithm.optimizer.lr_scheduler_type = "cosine"
self.config.check_and_update()
lr = self.config.algorithm.optimizer.lr

@ray.remote
class FakeExplorer:
def __init__(self, config: Config):
self.config = config
self.synchronizer = Synchronizer.get_actor(config)

async def is_alive(self):
return True

fake_explorer = FakeExplorer.remote(self.config)
ray.get(fake_explorer.__ray_ready__.remote())

tinker_trainer = TinkerTrainerWrapper(self.config)
tinker_trainer._train_step_num = 5
self.assertEqual(tinker_trainer.current_learning_rate, lr * 0.5)
tinker_trainer._train_step_num = 50
self.assertEqual(
tinker_trainer.current_learning_rate,
lr
* (
0.5
* (1 + math.cos((50 - lr_warmup_steps) / (total_steps - lr_warmup_steps) * math.pi))
),
)

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
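As context for `test_trainer_class` above: a standalone sketch of the warmup-plus-cosine schedule the test asserts for `TinkerTrainerWrapper.current_learning_rate`. The real implementation lives in `trinity/trainer/tinker_trainer.py` (not shown in this diff), so this is only an illustration of the expected shape.

```python
import math

def scheduled_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int,
                 scheduler_type: str = "cosine") -> float:
    """Learning rate at a given train step, matching the assertions in test_trainer_class."""
    if 0 < warmup_steps and step < warmup_steps:
        # linear warmup: step 5 of 10 yields 0.5 * base_lr
        return base_lr * step / warmup_steps
    if scheduler_type == "constant":
        return base_lr
    # cosine decay over the post-warmup portion of training
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(progress * math.pi))

print(scheduled_lr(step=5, base_lr=1e-5, warmup_steps=10, total_steps=100))   # 5e-06
print(scheduled_lr(step=50, base_lr=1e-5, warmup_steps=10, total_steps=100))  # ~5.9e-06
```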
6 changes: 1 addition & 5 deletions trinity/common/config.py
@@ -93,17 +93,13 @@ class OptimizerConfig:
lr: float = 1e-6
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
min_lr_ratio: float = 0.0
warmup_style: Optional[str] = None # deprecated !
lr_scheduler_type: str = "constant"
optimizer_type: str = "adam"
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
weight_decay: float = 0.01
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_decay_steps: Optional[int] = None
lr_decay_style: str = "constant" # duplicated with lr_scheduler_type in veRL
min_lr: float = 0.0


@dataclass
3 changes: 2 additions & 1 deletion trinity/common/models/tinker_model.py
@@ -95,7 +95,8 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
if with_chat_completion:
create_time = int(time.time())
output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
return_logprobs = kwargs.get("logprobs", self.config.logprobs is not None)
logprobs = kwargs.get("logprobs", self.config.logprobs)
return_logprobs = logprobs is not None and logprobs is not False
experiences = [
Experience(
tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32),
31 changes: 21 additions & 10 deletions trinity/common/verl_config.py
@@ -61,14 +61,15 @@ class Optim:
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
warmup_style: Optional[str] = None # deprecated !
lr_scheduler_type: str = "constant"
total_training_steps: int = -1 # ! DO NOT SET, use trainer.total_steps
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_warmup_init: Optional[float] = None # 0.0
lr_decay_steps: Optional[int] = None
lr_decay_style: str = "constant"
min_lr: float = 0.0
lr_decay_style: Optional[str] = None # "constant"
min_lr: Optional[float] = None # 0.0
weight_decay: float = 0.01
weight_decay_incr_style: str = "constant"
lr_wsd_decay_style: str = "exponential"
@@ -606,22 +607,32 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.critic.strategy = "fsdp"

# Algorithm related config
for field_name in config.algorithm.optimizer.__dataclass_fields__:
field_value = getattr(config.algorithm.optimizer, field_name)
actor_optim = self.actor_rollout_ref.actor.optim
critic_optim = self.critic.optim
optim_config = config.algorithm.optimizer
for field_name in optim_config.__dataclass_fields__:
field_value = getattr(optim_config, field_name)
if field_name == "optimizer_type":
setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
setattr(actor_optim, "optimizer", field_value)
elif hasattr(actor_optim, field_name):
setattr(actor_optim, field_name, field_value)
# ensure megatron optimizer config compatibility
set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
set_if_none(critic_optim, "lr_warmup_init", 0.0)
set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
set_if_none(critic_optim, "lr_decay_style", "constant")
set_if_none(critic_optim, "min_lr", 0.0)
# fix optimizer type for fsdp
if config.trainer.trainer_strategy.startswith("fsdp"):
optim_map = {
"adam": "AdamW",
"adamw": "AdamW",
"sgd": "SGD",
}
actor_optim = self.actor_rollout_ref.actor.optim
actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
critic_optim = self.critic.optim
critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
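A note on the `set_if_none` calls above: the helper is assumed here to fill a field only when it is still `None`, so values given explicitly are never overwritten while the Megatron-specific fields (`lr_warmup_init`, `lr_decay_steps`, `lr_decay_style`, `min_lr`) pick up defaults derived from `lr_scheduler_type`, `min_lr_ratio * lr`, and `trainer.total_training_steps`. A hypothetical equivalent:

```python
from typing import Any

def set_if_none(obj: Any, name: str, value: Any) -> None:
    # Fill in the derived default only when the field was left unset (None),
    # so a value set explicitly in the config always wins.
    if getattr(obj, name, None) is None:
        setattr(obj, name, value)
```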