diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml
index d81cfb6759..1003fb1f77 100644
--- a/benchmark/config/countdown-template.yaml
+++ b/benchmark/config/countdown-template.yaml
@@ -9,7 +9,7 @@ algorithm:
   optimizer:
     lr: 1e-06
     lr_warmup_steps_ratio: 0.0
-    warmup_style: constant
+    lr_scheduler_type: constant
   advantage_fn: ppo
 data_processor: {}
 model:
@@ -78,7 +78,7 @@ trainer:
     optim:
       lr: 1e-5
      lr_warmup_steps_ratio: 0.0
-      warmup_style: constant
+      lr_scheduler_type: constant
     ppo_max_token_len_per_gpu: 12800
     forward_max_token_len_per_gpu: 12800
     cliprange_value: 0.5
diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml
index 35db1831e4..92a1c8b5ff 100644
--- a/benchmark/config/gsm8k-template.yaml
+++ b/benchmark/config/gsm8k-template.yaml
@@ -9,7 +9,7 @@ algorithm:
   optimizer:
     lr: 1e-5
     lr_warmup_steps_ratio: 0.0
-    warmup_style: constant
+    lr_scheduler_type: constant
   sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
diff --git a/benchmark/config/guru_math-template.yaml b/benchmark/config/guru_math-template.yaml
index b0b32ff164..3abccf78fc 100644
--- a/benchmark/config/guru_math-template.yaml
+++ b/benchmark/config/guru_math-template.yaml
@@ -16,7 +16,7 @@ algorithm:
     lr: 1e-6
     weight_decay: 0.1
     lr_warmup_steps: 80
-    warmup_style: constant
+    lr_scheduler_type: constant
 cluster:
   node_num: 1
   gpu_per_node: 8
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index c0e16d0645..8f7f1d813a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -112,7 +112,7 @@ algorithm:
 - `optimizer`: Optimizer configuration for actor.
   - `lr`: Learning rate for actor.
   - `warmup_style`: Deprecated, use `lr_scheduler_type` instead. We will remove this field in future versions.
-  - `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `consine`.
+  - `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `cosine`.
 - `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The advantage function used for computing advantages.
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
diff --git a/examples/dpo_human_in_the_loop/dpo.yaml b/examples/dpo_human_in_the_loop/dpo.yaml
index f13dfc539e..3893913e4c 100644
--- a/examples/dpo_human_in_the_loop/dpo.yaml
+++ b/examples/dpo_human_in_the_loop/dpo.yaml
@@ -42,7 +42,7 @@ algorithm:
     lr: 5e-7
     lr_warmup_steps_ratio: 0.03 # the total steps will be injected during runtime
     min_lr_ratio: 0.1 # only useful for warmup with cosine
-    warmup_style: cosine # select from constant/cosine
+    lr_scheduler_type: cosine # select from constant/cosine
     betas: [0.9, 0.95]
   kl_loss_fn: k1
   kl_loss_fn_args:
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index f9948becca..a05785780a 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -7,7 +7,7 @@ algorithm:
     lr: 5e-7
     lr_warmup_steps_ratio: 0.03 # the total steps will be injected during runtime
     min_lr_ratio: 0.1 # only useful for warmup with cosine
-    warmup_style: cosine # select from constant/cosine
+    lr_scheduler_type: cosine # select from constant/cosine
     betas: [0.9, 0.95]
   kl_loss_fn: k1
   kl_loss_fn_args:
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index dcfd21d3db..259742a09e 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -65,7 +65,7 @@ trainer:
 # optimizer:
 #   lr: 5e-6
 #   lr_warmup_steps_ratio: 0.0
-#   warmup_style: constant
+#   lr_scheduler_type: constant
 # buffer:
 #   total_epochs: 1
 #   train_batch_size: 32
diff --git a/examples/learn_to_ask/train.yaml b/examples/learn_to_ask/train.yaml
index 9aee01b4a9..1e4d3972df 100644
--- a/examples/learn_to_ask/train.yaml
+++ b/examples/learn_to_ask/train.yaml
@@ -13,7 +13,7 @@ algorithm:
   optimizer:
     lr: 5.0e-07
     lr_warmup_steps_ratio: 0.0
-    warmup_style: constant
+    lr_scheduler_type: constant
 data_processor: {}
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
diff --git a/examples/tinker/tinker.yaml b/examples/tinker/tinker.yaml
index 93812e9f79..d6edbad669 100644
--- a/examples/tinker/tinker.yaml
+++ b/examples/tinker/tinker.yaml
@@ -11,7 +11,7 @@ algorithm:
   optimizer:
     lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
-    warmup_style: constant
+    lr_scheduler_type: constant
 data_processor: {}
 model:
   model_path: Qwen/Qwen3-4B-Instruct-2507
diff --git a/scripts/context_length_test/context_length.yaml b/scripts/context_length_test/context_length.yaml
index 27611dc5e6..dd209bba29 100644
--- a/scripts/context_length_test/context_length.yaml
+++ b/scripts/context_length_test/context_length.yaml
@@ -16,7 +16,7 @@ algorithm:
   optimizer:
     lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
-    warmup_style: constant
+    lr_scheduler_type: constant
 data_processor: {}
 model:
   model_path: ${oc.env:MODEL_PATH,Qwen/Qwen3-0.6B}
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 4abe8b7774..d31e09cf42 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -6,6 +6,8 @@
 import shutil
 import unittest
 
+import torch
+
 from tests.tools import get_template_config, get_unittest_dataset_config
 from trinity.common.config import InferenceModelConfig, load_config
 
@@ -143,10 +145,9 @@ def test_optimizer_config_propagation(self):
         config.algorithm.optimizer.lr = 1e-4
         config.algorithm.optimizer.weight_decay = 0.05
         config.algorithm.optimizer.clip_grad = 2.0
-        config.algorithm.optimizer.lr_decay_steps = 1000
-        config.algorithm.optimizer.lr_decay_style = "cosine"
-        config.algorithm.optimizer.lr_warmup_init = 1e-7
-        config.algorithm.optimizer.min_lr = 1e-6
+        config.trainer.total_steps = 1000
+        config.algorithm.optimizer.lr_scheduler_type = "cosine"
+        config.algorithm.optimizer.min_lr_ratio = 1e-2
         config.check_and_update()
         self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
         self.assertEqual(
@@ -159,10 +160,20 @@ def test_optimizer_config_propagation(self):
         self.assertEqual(
             config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "cosine"
         )
-        self.assertEqual(
-            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init, 1e-7
+        self.assertTrue(
+            torch.allclose(
+                torch.tensor(
+                    config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init
+                ),
+                torch.tensor(1e-6),
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                torch.tensor(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr),
+                torch.tensor(1e-6),
+            )
         )
-        self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr, 1e-6)
         # critic optimizer should not be affected
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
         self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 3a40988a07..9ffe3e7c3c 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -798,6 +798,9 @@ def __init__(self, config: Config):
                 self.config = config
                 self.synchronizer = Synchronizer.get_actor(config)
 
+            async def is_alive(self):
+                return True
+
         fake_trainer = FakeTrainer.remote(self.config)
         await fake_trainer.__ray_ready__.remote()
         await super()._setup_engines()
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 0bf1c2eaf8..deb4bfd9b1 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -8,7 +8,7 @@ algorithm:
   optimizer:
     lr: 1e-6
     lr_warmup_steps_ratio: 0.
-    warmup_style: constant # select from constant/cosine
+    lr_scheduler_type: constant # select from constant/cosine
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 136d95c366..99eb711975 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import json
+import math
 import multiprocessing
 import os
 import shutil
@@ -48,6 +49,8 @@
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.proxy.client import TrinityClient
 from trinity.manager.state_manager import StateManager
+from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.tinker_trainer import TinkerTrainerWrapper
 
 
 class BaseTrainerCase(RayUnittestBase):
@@ -1434,8 +1437,8 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
 
 
+@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
 class TestTinkerTrainer(BaseTrainerCase):
-    @unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
     def test_trainer(self):
         """Test GSM8K on tinker."""
         # test both mode
@@ -1448,7 +1451,7 @@ def test_trainer(self):
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.model.tinker.enable = True
-        self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507"
+        self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
         self.config.check_and_update()
         both(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -1464,6 +1467,43 @@ def test_trainer(self):
         self.assertGreater(len(response_metrics), 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
 
+    def test_trainer_class(self):
+        total_steps = 100
+        lr_warmup_steps = 10
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.model.tinker.enable = True
+        self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
+        self.config.trainer.total_steps = total_steps
+        self.config.algorithm.optimizer.lr_warmup_steps = lr_warmup_steps
+        self.config.algorithm.optimizer.lr_scheduler_type = "cosine"
+        self.config.check_and_update()
+        lr = self.config.algorithm.optimizer.lr
+
+        @ray.remote
+        class FakeExplorer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+            async def is_alive(self):
+                return True
+
+        fake_explorer = FakeExplorer.remote(self.config)
+        ray.get(fake_explorer.__ray_ready__.remote())
+
+        tinker_trainer = TinkerTrainerWrapper(self.config)
+        tinker_trainer._train_step_num = 5
+        self.assertEqual(tinker_trainer.current_learning_rate, lr * 0.5)
+        tinker_trainer._train_step_num = 50
+        self.assertEqual(
+            tinker_trainer.current_learning_rate,
+            lr
+            * (
+                0.5
+                * (1 + math.cos((50 - lr_warmup_steps) / (total_steps - lr_warmup_steps) * math.pi))
+            ),
+        )
+
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index ef7fcd8611..581e07d6a7 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -93,17 +93,13 @@ class OptimizerConfig:
     lr: float = 1e-6
     lr_warmup_steps: int = -1
     lr_warmup_steps_ratio: float = 0.0
-    min_lr_ratio: Optional[float] = 0.0
+    min_lr_ratio: float = 0.0
     warmup_style: Optional[str] = None  # deprecated !
     lr_scheduler_type: str = "constant"
     optimizer_type: str = "adam"
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     weight_decay: float = 0.01
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0
-    lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"  # duplicated with lr_scheduler_type in veRL
-    min_lr: float = 0.0
 
 
 @dataclass
diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py
index 0be19b855f..631795bf74 100644
--- a/trinity/common/models/tinker_model.py
+++ b/trinity/common/models/tinker_model.py
@@ -95,7 +95,8 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if with_chat_completion:
             create_time = int(time.time())
         output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
-        return_logprobs = kwargs.get("logprobs", self.config.logprobs is not None)
+        logprobs = kwargs.get("logprobs", self.config.logprobs)
+        return_logprobs = logprobs is not None and logprobs is not False
         experiences = [
             Experience(
                 tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32),
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 02a38cddae..2d823390cc 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -61,14 +61,15 @@ class Optim:
     lr_warmup_steps: int = -1
     lr_warmup_steps_ratio: float = 0.0
     min_lr_ratio: Optional[float] = 0.0
+    warmup_style: Optional[str] = None  # deprecated !
     lr_scheduler_type: str = "constant"
     total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0
+    lr_warmup_init: Optional[float] = None  # 0.0
     lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"
-    min_lr: float = 0.0
+    lr_decay_style: Optional[str] = None  # "constant"
+    min_lr: Optional[float] = None  # 0.0
     weight_decay: float = 0.01
     weight_decay_incr_style: str = "constant"
     lr_wsd_decay_style: str = "exponential"
@@ -606,12 +607,24 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.critic.strategy = "fsdp"
 
         # Algorithm related config
-        for field_name in config.algorithm.optimizer.__dataclass_fields__:
-            field_value = getattr(config.algorithm.optimizer, field_name)
+        actor_optim = self.actor_rollout_ref.actor.optim
+        critic_optim = self.critic.optim
+        optim_config = config.algorithm.optimizer
+        for field_name in optim_config.__dataclass_fields__:
+            field_value = getattr(optim_config, field_name)
             if field_name == "optimizer_type":
-                setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
-            elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
-                setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
+                setattr(actor_optim, "optimizer", field_value)
+            elif hasattr(actor_optim, field_name):
+                setattr(actor_optim, field_name, field_value)
+        # ensure megatron optimizer config compatibility
+        set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
+        set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(critic_optim, "lr_warmup_init", 0.0)
+        set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(critic_optim, "lr_decay_style", "constant")
+        set_if_none(critic_optim, "min_lr", 0.0)
         # fix optimizer type for fsdp
         if config.trainer.trainer_strategy.startswith("fsdp"):
             optim_map = {
@@ -619,9 +632,7 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
                 "adamw": "AdamW",
                 "sgd": "SGD",
             }
-            actor_optim = self.actor_rollout_ref.actor.optim
             actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
-            critic_optim = self.critic.optim
             critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
         self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
         self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py
index 1265348491..c58c953045 100644
--- a/trinity/trainer/tinker_trainer.py
+++ b/trinity/trainer/tinker_trainer.py
@@ -1,4 +1,6 @@
+import math
 import os
+import sys
 from typing import Dict, List
 
 import ray
@@ -36,7 +38,7 @@ def __init__(self, config: Config):
 
     def _init_algorithm(self):
         self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type)
-        algorithm_config = self.config.algorithm
+        self.algorithm_config = algorithm_config = self.config.algorithm
         if self.algorithm.compute_advantage_in_trainer:
             self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)(
                 **algorithm_config.advantage_fn_args
@@ -63,12 +65,60 @@ def _init_algorithm(self):
             and (self.loss_agg_mode == "token-mean")
         )
 
-        self.adam_params = types.AdamParams(
-            learning_rate=algorithm_config.optimizer.lr,
-            beta1=algorithm_config.optimizer.betas[0],
-            beta2=algorithm_config.optimizer.betas[1],
+        self.lr_scheduler_type = algorithm_config.optimizer.lr_scheduler_type
+        self.total_steps = self.config.trainer.total_steps or sys.maxsize
+        self.num_warmup_steps = algorithm_config.optimizer.lr_warmup_steps
+        if self.num_warmup_steps < 0:
+            self.num_warmup_steps = int(
+                algorithm_config.optimizer.lr_warmup_steps_ratio * self.total_steps
+            )
+        self.min_lr_ratio = algorithm_config.optimizer.min_lr_ratio
+        assert 0.0 <= self.min_lr_ratio <= 1.0
+        self.logger.info(
+            f"Total steps: {self.total_steps}, num_warmup_steps: {self.num_warmup_steps}"
+        )
+
+        if self.lr_scheduler_type not in {"constant", "cosine"}:
+            raise NotImplementedError(
+                f"LR scheduler type {self.lr_scheduler_type} is not supported"
+            )
+
+    @property
+    def _current_lr_factor(self):
+        train_step_num = self._train_step_num
+        # warmup
+        if train_step_num < self.num_warmup_steps:
+            factor = float(train_step_num) / float(max(1.0, self.num_warmup_steps))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+            return factor
+
+        # decay
+        if train_step_num >= self.total_steps:
+            progress = 1.0
+        else:
+            progress = float(train_step_num - self.num_warmup_steps) / float(
+                max(1.0, self.total_steps - self.num_warmup_steps)
+            )
+        if self.lr_scheduler_type == "constant":
+            factor = 1.0
+        elif self.lr_scheduler_type == "cosine":
+            num_cycles = 0.5  # TODO: may add to config
+            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+        return max(self.min_lr_ratio, factor)
+
+    @property
+    def current_learning_rate(self):
+        return self._current_lr_factor * self.algorithm_config.optimizer.lr
+
+    @property
+    def adam_params(self):
+        return types.AdamParams(
+            learning_rate=self.current_learning_rate,
+            beta1=self.algorithm_config.optimizer.betas[0],
+            beta2=self.algorithm_config.optimizer.betas[1],
             # eps is currently not in config
-            weight_decay=algorithm_config.optimizer.weight_decay,
+            weight_decay=self.algorithm_config.optimizer.weight_decay,
             grad_clip_norm=self.config.trainer.grad_clip,
         )