Merged
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -475,6 +475,7 @@ trainer:
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
max_checkpoints_to_keep: 5
trainer_config: null
```

@@ -499,6 +500,7 @@ trainer:
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
- `ulysses_sequence_parallel_size`: Sequence parallel size.
- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep. Once this limit is exceeded, the oldest checkpoints are deleted. If not specified, all checkpoints are kept.
- `trainer_config`: The trainer configuration provided inline.
---

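For context, a quick sketch (not part of this diff) of the retention behavior the new option describes; the interval and step counts below are illustrative, not taken from the PR:

```python
# Illustrative only: which saved steps survive pruning for a hypothetical run
# with save_interval=100, 1000 total steps, and max_checkpoints_to_keep=5.
save_interval = 100
total_steps = 1000
max_checkpoints_to_keep = 5

saved_steps = list(range(save_interval, total_steps + 1, save_interval))
kept_steps = saved_steps[-max_checkpoints_to_keep:]
print(kept_steps)  # [600, 700, 800, 900, 1000]; older checkpoints are deleted
```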
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -472,6 +472,7 @@ trainer:
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
max_checkpoints_to_keep: 5
trainer_config: null
```

@@ -496,6 +497,7 @@ trainer:
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: The maximum token length per GPU during training. Effective when `use_dynamic_bsz=true`.
- `ulysses_sequence_parallel_size`: The degree of sequence parallelism, i.e., the number of GPUs used to split a single sequence.
- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep. Once this limit is exceeded, the oldest checkpoints are deleted. If not specified, all checkpoints are kept.
- `trainer_config`: The trainer configuration provided inline.

---
47 changes: 40 additions & 7 deletions tests/trainer/trainer_test.py
@@ -344,6 +344,10 @@ def test_trainer(self, mock_load):

# sft warmup stage
sft_config = stage_configs[0]
self.assertEqual(
sft_config.synchronizer.sync_interval,
sft_config.trainer.save_interval,
)
parser = TensorBoardParser(os.path.join(sft_config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(len(rollout_metrics), 0)
Expand Down Expand Up @@ -374,11 +378,15 @@ def test_trainer(self, mock_load):
self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
# test save checkpoint when sft finish
for i in range(3):
self.assertFalse(
os.path.exists(os.path.join(sft_config.checkpoint_job_dir, f"global_step_{i}"))
)
self.assertEqual(
get_checkpoint_dir_with_step_num(
checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=2
checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=3
)[1],
2,
3,
)
# test save checkpoint at last step
checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
@@ -749,7 +757,7 @@ def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)
self.config = get_template_config()
self.config.buffer.total_epochs = 1
self.config.buffer.total_steps = 6
self.config.buffer.batch_size = 4
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -762,21 +770,20 @@ def setUp(self):
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
self.config.explorer.eval_interval = 4
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.trainer.save_interval = 4
self.config.trainer.save_interval = 2
self.config.trainer.save_hf_checkpoint = "last"
self.config.trainer.trainer_strategy = self.strategy
self.config.trainer.max_checkpoints_to_keep = 2
self.config.check_and_update()
self.process_list = []

def test_trainer(self):
def test_trainer(self): # noqa: C901
"""Test the checkpoint saving."""
_trainer_config = self.config.trainer.trainer_config
if self.strategy == "megatron":
_trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
_trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
_trainer_config.critic.megatron.tensor_model_parallel_size = 2
_trainer_config.trainer.max_actor_ckpt_to_keep = 2
_trainer_config.trainer.max_critic_ckpt_to_keep = 2

stop_event = multiprocessing.Event()
trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))
@@ -839,6 +846,10 @@ def test_trainer(self):
# print(f"State dict check at {state_dict_iteration} iteration passed.") # for debug

if checkpoint_iteration > 0:
flag_file = os.path.join(
default_local_dir, f"global_step_{checkpoint_iteration}", ".full_checkpoint"
)
self.assertTrue(os.path.exists(flag_file))
for sub_dir_name in ["critic", "actor"]:
iteration_dir = os.path.join(
default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name
@@ -882,6 +893,28 @@ def test_trainer(self):
# print(f"Checkpoint check at {checkpoint_iteration} iteration passed.") # for debug
if not stop_event.is_set():
self.fail("Training process failed to stop.")
# check only full checkpoint dirs are kept
for sync_step in [1, 3, 5]:
state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
self.assertFalse(
os.path.exists(state_dict_dir),
f"Found unexpected state dict dir at step {sync_step}",
)
for checkpoint_step in [4, 6]:
checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
self.assertTrue(
os.path.exists(checkpoint_dir),
f"Missing expected checkpoint dir at step {checkpoint_step}",
)
actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
self.assertTrue(os.path.exists(actor_checkpoint_dir))
# step 2's checkpoint should have been pruned: the dir remains but actor/critic are removed
checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
self.assertTrue(os.path.exists(checkpoint_dir))
actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
self.assertFalse(os.path.exists(actor_checkpoint_dir))
critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
self.assertFalse(os.path.exists(critic_checkpoint_dir))
trainer_process.join(timeout=10)
self.assertIn("model.safetensors", huggingface_dir_files)

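A short sketch of the directory layout the updated test asserts, assuming per-step checkpoint sync with `save_interval=2`, `total_steps=6`, and `max_checkpoints_to_keep=2` (values taken from the `setUp` above; the variables below are illustrative):

```python
# Which global_step_* directories should remain after the run finishes.
full_checkpoint_steps = [2, 4, 6]              # trainer saves a full checkpoint every 2 steps
kept_weights = full_checkpoint_steps[-2:]      # [4, 6]: actor/critic dirs survive pruning
flag_only = [s for s in full_checkpoint_steps  # [2]: dir and .full_checkpoint flag remain,
             if s not in kept_weights]         # but actor/critic subdirs have been pruned
removed_sync_dirs = [1, 3, 5]                  # state-dict-only sync dirs deleted by the synchronizer
```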
22 changes: 22 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -63,6 +63,27 @@ def default_config(cls) -> Dict:
"entropy_loss_fn": "none",
}

@classmethod
def check_config(cls, config: Config) -> None:
if config.mode == "train":
if (
config.buffer.trainer_input.experience_buffer is None
or not config.buffer.trainer_input.experience_buffer.path
):
raise ValueError(
"`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == sft`"
)
elif config.mode in ["both", "explore"]:
raise ValueError(f"SFT does not support `{config.mode}` mode")

if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"SFT only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)

config.synchronizer.sync_interval = config.trainer.save_interval


class PPOAlgorithm(AlgorithmType):
"""PPO Algorithm."""
@@ -232,6 +253,7 @@ def check_config(cls, config: Config) -> None:
logger.warning(
"DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
config.synchronizer.sync_interval = config.trainer.save_interval
if config.algorithm.repeat_times != 2:
config.algorithm.repeat_times = 2 # Fake repeat times
if config.algorithm.kl_loss_fn in {"none", None}:
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -735,6 +735,7 @@ class TrainerConfig:
# TODO: extract more train-related params from underlying trainer engine

save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
max_checkpoints_to_keep: Optional[int] = None

trainer_config: Any = field(default_factory=dict)
trainer_config_path: str = "" # deprecated, use `trainer_config` instead
4 changes: 4 additions & 0 deletions trinity/common/models/utils.py
@@ -199,6 +199,10 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T
Args:
checkpoint_dir (str): The checkpoint directory.
config (TrainerConfig): The trainer config; only the "verl" trainer type is supported for now.

Returns:
Union[dict, Tuple[str, str]]: The state dict. If the checkpoint uses
megatron dist checkpointing, a tuple of (method, checkpoint_dir) is returned.
"""
if config.trainer_type == "verl":
strategy = config.trainer_strategy
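A minimal sketch of how a caller is expected to branch on the documented return type (the call below is hypothetical; see the `vllm_worker.py` change that follows for the real call site):

```python
result = load_state_dict(checkpoint_dir, trainer_config)  # hypothetical arguments
if isinstance(result, tuple):
    # megatron dist checkpointing: defer loading and convert from the checkpoint dir
    method, ckpt_dir = result
else:
    # a plain state dict of tensors, ready to load into the rollout model
    state_dict = result
```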
1 change: 1 addition & 0 deletions trinity/common/models/vllm_worker.py
@@ -68,6 +68,7 @@ def update_weight(self):
if self._weight_update_rank == 0:
state_dict, model_version = ray.get(self.synchronizer.get_model_state_dict.remote())
if isinstance(state_dict, tuple):
# currently only megatron returns a tuple
method, checkpoint_dir = state_dict
if method == "megatron":
if self._checkpoint_converter is None:
3 changes: 3 additions & 0 deletions trinity/common/verl_config.py
@@ -413,6 +413,9 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.trainer.group_name = config.group
self.trainer.experiment_name = config.name
self.trainer.default_local_dir = config.checkpoint_job_dir
if config.trainer.max_checkpoints_to_keep is not None:
self.trainer.max_actor_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
self.trainer.max_critic_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
if not config.continue_from_checkpoint:
self.trainer.resume_mode = "disable"
else:
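A hedged usage sketch of the propagation added above: the single Trinity-level option drives both verl-side limits (the verl field names are from the hunk; the surrounding object setup and call pattern are assumed):

```python
config.trainer.max_checkpoints_to_keep = 5
verl_config.synchronize_config(config)  # assumed call pattern
assert verl_config.trainer.max_actor_ckpt_to_keep == 5
assert verl_config.trainer.max_critic_ckpt_to_keep == 5
```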
24 changes: 19 additions & 5 deletions trinity/manager/synchronizer.py
@@ -2,6 +2,7 @@

import asyncio
import os
import shutil
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

@@ -95,13 +96,14 @@ async def _find_verl_latest_state_dict(self) -> None:
)
while True:
if os.path.exists(local_latest_state_dict_iteration):
current_model_version = self.model_version
try:
with open(local_latest_state_dict_iteration, "r") as f:
latest_model_version = int(f.read().strip())
except (IOError, ValueError) as e:
self.logger.warning(f"Failed to read or parse state dict iteration file: {e}")
continue
if latest_model_version > self.model_version:
if latest_model_version > current_model_version:
self.logger.info(
f"Synchronizer has found a new model state dict at step {latest_model_version}."
)
@@ -119,8 +121,22 @@
f"Synchronizer has loaded model state dict from checkpoint {latest_model_version}."
)
await self.set_model_state_dict(model_state_dict, latest_model_version)
# remove the previous checkpoints to save disk space
await self._remove_previous_state_dict(current_model_version)
await asyncio.sleep(1)

async def _remove_previous_state_dict(self, previous_model_version: int) -> None:
previous_state_dict_dir = os.path.join(
self.config.checkpoint_job_dir, f"global_step_{previous_model_version}"
)
if os.path.exists(previous_state_dict_dir):
# only remove state-dict-only sync checkpoints; full checkpoints are kept
if not os.path.exists(os.path.join(previous_state_dict_dir, ".full_checkpoint")):
self.logger.info(
f"Removing previous checkpoint for sync at step {previous_model_version}."
)
shutil.rmtree(previous_state_dict_dir, ignore_errors=True)

async def _find_tinker_latest_state_dict(self) -> None:
default_local_dir = self.config.checkpoint_job_dir
local_latest_state_dict_iteration = os.path.join(
@@ -320,17 +336,15 @@ async def get_latest_model_version(self) -> int:
async with self._ready_condition:
return self.model_version

async def ready_to_nccl_sync(
self, module: str, trainer_step: Optional[int] = None
) -> Union[int, None]:
async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:
"""
Prepare for NCCL-based synchronization between modules.

Only supports one explorer currently.

Args:
module: Either 'trainer' or 'explorer'.
trainer_step: Optional step number from the trainer.
trainer_step: Step number from the trainer.

Returns:
The model version if both sides are ready; otherwise None.
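A minimal sketch of the `.full_checkpoint` flag protocol used above: the trainer writes the flag when it saves a full checkpoint, and the synchronizer only deletes step directories that lack it (the helper below is illustrative, not part of the diff):

```python
import os
import shutil

def remove_if_sync_only(step_dir: str) -> None:
    """Delete a global_step_* dir only if it is a state-dict-only sync dir."""
    if os.path.isdir(step_dir) and not os.path.exists(
        os.path.join(step_dir, ".full_checkpoint")
    ):
        shutil.rmtree(step_dir, ignore_errors=True)
```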
9 changes: 9 additions & 0 deletions trinity/trainer/tinker_trainer.py
@@ -282,6 +282,15 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
f"global_step_{self.train_step_num}",
)
os.makedirs(local_path, exist_ok=True)

# save a flag to indicate this is a full checkpoint dir
# make sure this flag is created before notifying the synchronizer
# to avoid the synchronizer recognizing it as a state_dict-only checkpoint
# TODO: use a better way to indicate full checkpoint
flag_path = os.path.join(local_path, ".full_checkpoint")
with open(flag_path, "w") as f:
f.write("")

remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt")
with open(remote_checkpoint_path, "w") as f:
f.write(self.latest_remote_checkpoint_path)
6 changes: 6 additions & 0 deletions trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -48,6 +48,7 @@

from trinity.manager.synchronizer import Synchronizer
from trinity.trainer.verl_trainer import CheckpointMonitor
from trinity.utils.log import get_logger


class FSDPCheckpointManager(OldFSDPCheckpointManager):
@@ -62,6 +63,7 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager):

def __init__(self, *args, ray_namespace: str = "", **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger()
self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
self.checkpoint_monitor = CheckpointMonitor.get_actor(
namespace=ray_namespace,
@@ -439,6 +441,10 @@ def save_checkpoint(
and local_path != self.previous_saved_paths[-1] # type: ignore
): # last step may save twice
keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore
self.logger.info(
"Checkpoint manager is removing previous checkpoints at "
+ str(self.previous_saved_paths[:keep_start]) # type: ignore
)
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore
self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore

6 changes: 6 additions & 0 deletions trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -40,6 +40,7 @@

from trinity.manager.synchronizer import Synchronizer
from trinity.trainer.verl_trainer import CheckpointMonitor
from trinity.utils.log import get_logger


class MegatronCheckpointManager(OldMegatronCheckpointManager):
@@ -59,6 +60,7 @@ def __init__(
*args,
**kwargs,
)
self.logger = get_logger()
self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
self.checkpoint_monitor = CheckpointMonitor.get_actor(
namespace=ray_namespace,
@@ -340,6 +342,10 @@ def save_checkpoint(
and local_path != self.previous_saved_paths[-1] # type: ignore
): # last step may save twice
keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore
self.logger.info(
"Checkpoint manager is removing previous checkpoints at "
+ str(self.previous_saved_paths[:keep_start]) # type: ignore
)
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore
self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore

9 changes: 9 additions & 0 deletions trinity/trainer/verl_trainer.py
@@ -494,6 +494,15 @@ def _save_checkpoint(self, save_as_hf: bool = False):
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
)

# save a flag to indicate this is a full checkpoint dir
# make sure this flag is created before notifying the synchronizer
# to avoid the synchronizer recognizing it as a state_dict-only checkpoint
# TODO: use a better way to indicate full checkpoint
os.makedirs(local_global_step_folder, exist_ok=True)
flag_path = os.path.join(local_global_step_folder, ".full_checkpoint")
with open(flag_path, "w") as f:
f.write("")

self.logger.info(f"local_global_step_folder: {local_global_step_folder}")
actor_local_path = os.path.join(local_global_step_folder, "actor")
