diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 8927315d..c9f43fab 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -448,6 +448,7 @@ def generate_models(): GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -617,6 +618,24 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index e31daaba..967cddbe 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -99,21 +99,29 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint + consumed_train_samples = 0 + global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = 
model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for _, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -125,13 +133,17 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - if step % 2 == 0: + if global_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 9: Upload the checkpoint to ModelScope Hub diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 7b6bd2a8..45dd8ac1 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm @@ -8,77 +10,121 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, fsdp_size=2, dp=4 -device_mesh = 
DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +from resume_utils import resume_from_checkpoint logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +FSDP_SIZE = 2 +DP_SIZE = 4 +BATCH_SIZE = 8 +LEARNING_RATE = 1e-4 +GRADIENT_ACCUMULATION_STEPS = 2 +LOG_INTERVAL = 20 +EVAL_INTERVAL = 40 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +OUTPUT_DIR = './output/fsdp2' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +# Construct a device_mesh +device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, consumed_train_samples: int): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): 
model.forward_only(inputs=batch) model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() + dataset = build_dataset(TRAIN_SAMPLES) # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = TransformersModel(model_id=MODEL_ID) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + consumed_train_samples = 0 + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=checkpoint_path, + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + 
adapter_name=ADAPTER_NAME, + ) + logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 + optimizer_group = model.optimizer_group[ADAPTER_NAME] + best_loss = float('inf') # lora: 8G * 8 # full: 18G * 8 - for step, batch in enumerate(dataloader): + for batch in dataloader: # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + consumed_train_samples += BATCH_SIZE + cur_step = optimizer_group.cur_step + if cur_step % LOG_INTERVAL == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: - metrics = eval(model) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') - metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + metrics['step'] = cur_step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', consumed_train_samples) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', consumed_train_samples) if __name__ == '__main__': diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py new file mode 100644 index 00000000..fdacf075 --- /dev/null +++ b/cookbook/transformers/resume_utils.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import Any, Optional + +from twinkle import get_logger + + +logger = get_logger() + + +def _build_model_kwargs(adapter_name: str) -> dict: + if not adapter_name: + return {} + return {'adapter_name': adapter_name} + + +def resume_from_checkpoint( + 
model: Any, + dataloader: Any, + checkpoint_path: Path, + *, + resume_only_model: bool, + ignore_data_skip: bool, + adapter_name: Optional[str] = None) -> int: + adapter_name = adapter_name or '' + checkpoint_dir = str(checkpoint_path) + model_kwargs = _build_model_kwargs(adapter_name) + if model_kwargs: + # Load adapter checkpoint. + model.load( + name=checkpoint_path.name, + output_dir=str(checkpoint_path.parent), + **model_kwargs, + ) + + if resume_only_model: + # Only load model weights, optionally skip data. + if ignore_data_skip: + logger.info('Resumed weights only and restarted progress from step 0.') + return 0 + progress = model.read_training_progress(checkpoint_dir, **model_kwargs) + # Skip consumed samples in dataloader and move optimizer to the right step. + consumed_train_samples = int(progress['consumed_train_samples']) + dataloader.skip_consumed_samples(consumed_train_samples) + optimizer_group = model.optimizer_group[adapter_name] + optimizer_group.cur_step = progress['cur_step'] + optimizer_group.gradient_accumulation_steps = progress['gradient_accumulation_steps'] + logger.info(f'Skipped {consumed_train_samples} consumed samples.') + return consumed_train_samples + + # Load full training state, including model weights, optimizer states, and training progress. + trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + dataloader.skip_consumed_samples(consumed_train_samples) + logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.') + return consumed_train_samples diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index f10b0351..ba4ba1b4 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) 
model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## Checkpoint and Resume + +`TransformersModel.save()` can save either weights only or a resumable training checkpoint. + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. +- `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights. +- `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`. +- `model.load_training_state(checkpoint_dir, ...)` restores optimizer-related state and returns the training progress dictionary. + +For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `load_training_state(...)` to restore optimizer state and training progress. + +For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 6a05a53f..d8b72ace 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -230,6 +230,71 @@ When running, you need to launch training like this: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### Resume from Checkpoint + +The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py` together with `cookbook/transformers/resume_utils.py`. 
+ +When saving a checkpoint intended for resumption, save both model weights and training progress: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. + +To resume training, restore the checkpoint before entering the main loop: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +This helper provides two common resume modes: + +- Full resume: restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. +- Weights-only resume: restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. + +When `RESUME_ONLY_MODEL=True` and `IGNORE_DATA_SKIP=False`, the helper still reads `trainer_state.json`, skips the samples that were already consumed, and restores `cur_step` and `gradient_accumulation_steps` on the optimizer group; the optimizer state itself is not loaded. If you want to reload weights but restart the dataset from the beginning, set `IGNORE_DATA_SKIP=True`. + +The flow above is intended for LoRA / adapter training. For full-parameter training, restore model weights by passing the checkpoint path as `model_id` when constructing `TransformersModel`, instead of calling `model.load(...)`. 
For example: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +In other words: + +- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.load(...)` or `resume_from_checkpoint(...)`. +- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `load_training_state(...)` to restore optimizer state and training progress. + ### Ray Training [Ray](https://github.com/ray-project/ray) is a commonly used scheduling middleware framework for multi-machine model training and inference scenarios. It provides additional optimizations for multi-model, multi-device execution and resource management, and supports integration with Kubernetes systems for production deployment. These characteristics make it particularly suitable for complex training scenarios such as RL and GKD. @@ -413,6 +478,8 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. +Checkpoint resumption is also supported in client-server training. The recommended flow is to restore weights with `model.load(resume_path)`, then restore optimizer and progress metadata with `model.load_training_state(resume_path)`, and finally call `dataloader.skip_consumed_samples(...)`. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. + Suppose we start a service using eight GPUs. 
First, we need to start the Ray cluster: ```shell diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 66d98eec..e373e30c 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: Training loop -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # Forward propagation + backward propagation - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # Gradient clipping - model.clip_grad_norm(1.0) - - # Optimizer update - model.step() - - # Zero gradients - model.zero_grad() - - # Learning rate scheduling - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: Save checkpoint -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved 
checkpoint: {twinkle_path}") # Step 8: Upload to ModelScope Hub (optional) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +For checkpoint resumption, the recommended client-side flow is: + +1. Query the server for an existing checkpoint path with `client.list_checkpoints(...)` or `client.get_latest_checkpoint_path(...)`. +2. Call `model.load(resume_path)` to restore adapter weights. +3. Call `model.load_training_state(resume_path)` to restore optimizer, scheduler, RNG, and progress metadata. +4. Call `dataloader.skip_consumed_samples(...)` with `consumed_train_samples` from the returned trainer state. + +This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_congnition.py`. + ## Differences with Megatron Backend When using the Megatron backend, the main differences in client code: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 0b8e386a..a7d98732 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -231,6 +231,71 @@ if __name__ == '__main__': CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### 断点续训 + +上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 + +如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 
`trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。 + +恢复训练时,建议在进入主训练循环之前先加载 checkpoint: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +这个辅助函数覆盖了两种常见恢复模式: + +- 完整续训:恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 +- 仅恢复权重:只加载模型权重,不恢复优化器等训练状态。适合希望沿用已训练的权重、但重新开始优化过程的场景。 + +当 `RESUME_ONLY_MODEL=True` 且 `IGNORE_DATA_SKIP=False` 时,仍会根据 `trainer_state.json` 跳过已训练过的数据,并恢复 `cur_step` 和 `gradient_accumulation_steps`,但不会加载优化器状态;如果你只想加载权重、但从数据集开头重新训练,可以将 `IGNORE_DATA_SKIP` 设为 `True`。 + +上面的恢复流程默认针对 LoRA / adapter 训练。对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 `model_id` 设为 checkpoint 路径,而不是再调用 `model.load(...)`。例如: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +也就是说: + +- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.load(...)` 或 `resume_from_checkpoint(...)` 恢复。 +- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + ### Ray训练 [Ray](https://github.com/ray-project/ray)是多机模型训练和推理场景中常用的调度中间件框架。它针对多模型、多设备的执行和资源管理进行了额外优化, @@ -412,6 +477,7 @@ python train.py ``` ### 远程训练 +client-server 训练场景同样支持断点续训。推荐流程是先通过 `model.load(resume_path)` 恢复权重,再通过 `model.load_training_state(resume_path)` 恢复优化器和训练进度元数据,最后调用 `dataloader.skip_consumed_samples(...)` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 
Twinkle的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行lora训练,这样可以极大减小服务端部署成本。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index fd81ac1b..5d4bafe7 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: 训练循环 -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # 前向传播 + 反向传播 - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # 梯度裁剪 - model.clip_grad_norm(1.0) - - # 优化器更新 - model.step() - - # 梯度清零 - model.zero_grad() - - 
# 学习率调度 - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: 保存检查点 -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved checkpoint: {twinkle_path}") # Step 8: 上传到 ModelScope Hub(可选) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +Twinkle Client 场景下,推荐的断点续训流程是: + +1. 先通过 `client.list_checkpoints(...)` 或 `client.get_latest_checkpoint_path(...)` 获取已有 checkpoint 路径。 +2. 调用 `model.load(resume_path)` 恢复 adapter 权重。 +3. 调用 `model.load_training_state(resume_path)` 恢复优化器、调度器、随机数状态和训练进度元数据。 +4. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.skip_consumed_samples(...)`,跳过已经训练过的数据。 + +完整示例可直接参考 `cookbook/client/twinkle/self_host/self_congnition.py`。 + ## Megatron 后端的差异 使用 Megatron 后端时,客户端代码的主要差异: diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index b7b9cf0f..bb494131 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) 
model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## 检查点保存与续训 + +`TransformersModel.save()` 既可以只保存权重,也可以保存可续训的训练检查点。 + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 会在保存权重的同时,保存优化器、学习率调度器、梯度缩放器、随机数状态以及 `trainer_state.json`。 +- `model.load(name, output_dir=..., adapter_name=...)` 用于恢复 LoRA / adapter 模型权重。 +- `model.read_training_progress(checkpoint_dir, ...)` 用于读取 checkpoint 中的训练进度元数据,例如 `cur_step`、`gradient_accumulation_steps` 和 `consumed_train_samples`。 +- `model.load_training_state(checkpoint_dir, ...)` 用于恢复优化器等训练状态,并返回训练进度字典。 + +对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + +如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..0725ee47 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -1,4 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import copy +import os +import warnings from functools import partial from typing import Callable, Optional, Type, Union @@ -51,6 +54,10 @@ def __init__(self, self.dataloader_params['batch_size'] = batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None + self._skip_samples = 0 + self._base_batch_sampler = None + self._base_sampler = None + self._retry_sampler_seed = self._resolve_retry_sampler_seed() self._set_work_init_fn() def _set_work_init_fn(self): @@ -60,6 +67,17 @@ def _set_work_init_fn(self): num_workers=num_workers, rank=self.device_mesh.data_rank if self.device_mesh else 0) + @staticmethod + def _resolve_retry_sampler_seed() -> int: + env_seed = os.environ.get('TWINKLE_SEED') + if env_seed is not None: + return int(env_seed) + try: + from twinkle.infra import _seed + return int(_seed) + except Exception: + return 42 + @remote_function() def __len__(self): self._lazy_init_dataloader() @@ -97,7 +115,9 @@ def _lazy_init_dataloader(self): if not isinstance(self.dataset, IterableDataset): self.dataloader.__initialized = False - self._repeat_sample_and_shard() + self._base_batch_sampler = self.dataloader.batch_sampler + self._base_sampler = self.dataloader.sampler + self._rebuild_sampler_stack() self.dataloader.__initialized = True @remote_function() @@ -116,11 +136,41 @@ def __iter__(self): max_retries=self.max_retries) return _iter - def _repeat_sample_and_shard(self): - if self.dataloader.batch_sampler is not None and hasattr(self.dataloader.batch_sampler, 'sampler'): - self.dataloader.batch_sampler.sampler = RetrySampler( - self.dataloader.batch_sampler.sampler, self.dataset, max_retries=self.max_retries) - self.dataloader.batch_sampler = DeviceMeshSampler(self.dataloader.batch_sampler, self.device_mesh, - self.min_batch_size) - elif self.dataloader.sampler is not None: - self.dataloader.sampler = RetrySampler(self.dataloader.sampler, self.dataset, max_retries=self.max_retries) + @remote_function() + def 
skip_consumed_samples(self, consumed_train_samples: int) -> None: + from torch.utils.data import IterableDataset + + if isinstance(self.dataset, IterableDataset): + warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') + self._skip_samples = 0 + return + + self._skip_samples = max(int(consumed_train_samples), 0) + if self.dataloader is not None: + self.dataloader.__initialized = False + self._rebuild_sampler_stack() + self.dataloader.__initialized = True + + def _rebuild_sampler_stack(self): + if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): + batch_sampler = copy.copy(self._base_batch_sampler) + batch_sampler.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + seed=self._retry_sampler_seed, + ) + self.dataloader.batch_sampler = DeviceMeshSampler( + batch_sampler, + self.device_mesh, + self.min_batch_size, + skip_samples=self._skip_samples, + ) + elif self._base_sampler is not None: + self.dataloader.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + skip_samples=self._skip_samples, + seed=self._retry_sampler_seed, + ) diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py index 955b85cd..1f649de3 100644 --- a/src/twinkle/dataloader/device_mesh_sampler.py +++ b/src/twinkle/dataloader/device_mesh_sampler.py @@ -12,15 +12,28 @@ class DeviceMeshSampler(BatchSampler): device_mesh: The device mesh. 
""" - def __init__(self, original_sampler: BatchSampler, device_mesh: DeviceMesh, min_batch_size: int = None): + def __init__(self, + original_sampler: BatchSampler, + device_mesh: DeviceMesh, + min_batch_size: int = None, + skip_samples: int = 0): self.original_sampler = original_sampler self.device_mesh = device_mesh self.min_batch_size = min_batch_size + self.skip_samples = skip_samples if self.min_batch_size is None and self.device_mesh is not None: self.min_batch_size = self.device_mesh.data_world_size def __iter__(self): + skipped = 0 for batch in self.original_sampler: + if skipped < self.skip_samples: + if skipped + len(batch) <= self.skip_samples: + skipped += len(batch) + continue + batch = batch[self.skip_samples - skipped:] + skipped = self.skip_samples + if not self.device_mesh: yield batch else: diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 62f05660..4d8c92e0 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -14,13 +14,22 @@ class RetrySampler(Sampler): max_retries: The maximum number of retries. 
""" - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20): + def __init__(self, + original_sampler: Sampler, + dataset: Dataset, + max_retries=20, + skip_samples: int = 0, + seed: int = 42): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries + self.skip_samples = skip_samples + self.seed = int(seed) def __iter__(self): - total = 0 + emitted = 0 + seen_valid = 0 + target_total = max(len(self.dataset) - self.skip_samples, 0) for idx in self.original_sampler: for _ in range(self.max_retries): try: @@ -29,23 +38,25 @@ def __iter__(self): data = self.dataset[idx] if not data: continue + seen_valid += 1 + if seen_valid <= self.skip_samples: + break yield idx - total += 1 + emitted += 1 break except Exception: # noqa import traceback traceback.print_exc() continue else: - raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') + raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.') - origin_dataset_len = len(self.dataset) - if total >= origin_dataset_len: + if emitted >= target_total: return - for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): - if total >= origin_dataset_len: - raise StopIteration + for idx in np.random.RandomState(self.seed).permutation(len(self.dataset)).tolist(): + if emitted >= target_total: + return for _ in range(self.max_retries): try: # Skip None values and raises @@ -53,7 +64,8 @@ def __iter__(self): if not data: continue yield idx - total += 1 + emitted += 1 + break except Exception: # noqa import traceback traceback.print_exc() diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 77a25b18..7339c1f0 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -240,7 +240,7 @@ def load(self, name: Optional[str] = None, output_dir: 
Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self.load_training_state(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 89b497e2..b9b81f59 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -124,6 +124,59 @@ def wrap_model(self, model, *args): def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) + def _get_fsdp_plugin(self): + state = self.accelerator.state + return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None + + def _prepare_fsdp2_sd_options(self): + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2: + return None + + from torch.distributed.checkpoint.state_dict import StateDictOptions + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + return StateDictOptions( + full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT, + cpu_offload=getattr(fsdp_plugin.state_dict_config, 'offload_to_cpu', False), + broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, 'rank0_only', False), + ) + + def needs_wrapped_optimizer_state(self) -> bool: + fsdp_plugin = self._get_fsdp_plugin() + return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2 + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + import torch + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options()) + if 
self.accelerator.process_index == 0: + torch.save(optim_state, output_path) + return + + import torch + if self.accelerator.process_index == 0: + torch.save(optimizer.state_dict(), output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + import torch + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + optim_state = None + rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False) + if self.accelerator.process_index == 0 or not rank0_only: + optim_state = torch.load(input_path, weights_only=True) + set_optimizer_state_dict(model, optimizer, optim_state, options=self._prepare_fsdp2_sd_options()) + return + + import torch + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + def get_full_state_dict(self, model) -> dict: """Collect full state dict.""" from twinkle.utils import torch_util diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 48a1da85..ad675006 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -151,6 +151,53 @@ def wrap_model(self, model, optimizer=None): return model, optimizer + def _prepare_optimizer_state_dict_options(self, *, for_load: bool): + from torch.distributed.checkpoint.state_dict import StateDictOptions + + return StateDictOptions( + full_state_dict=True, + cpu_offload=not for_load, + broadcast_from_rank0=for_load, + ) + + def needs_wrapped_optimizer_state(self) -> bool: + return self.device_mesh is not None + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + if Platform.is_master(): + torch.save(optimizer.state_dict(), output_path) + return + + from 
torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict( + model, + optimizer, + options=self._prepare_optimizer_state_dict_options(for_load=False), + ) + if Platform.is_master(): + torch.save(optim_state, output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + return + + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + optim_state = {} + if Platform.is_master(): + optim_state = torch.load(input_path, map_location='cpu', weights_only=True) + set_optimizer_state_dict( + model, + optimizer, + optim_state, + options=self._prepare_optimizer_state_dict_options(for_load=True), + ) + def get_ep_clip_kwargs(self, model) -> Dict[str, Any]: """Return EP-aware kwargs for normalize_and_clip_grad_norm.""" model = self.unwrap_model(model) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6097ffe2..e2ee5243 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -2,7 +2,9 @@ import asyncio import contextlib import json +import numpy as np import os +import random import re import threading import torch @@ -40,6 +42,9 @@ from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.utils.logger import get_logger + +logger = get_logger() @dataclass @@ -855,21 +860,57 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if kwargs.get('save_optimizer', False): self._save_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._save_training_state( + checkpoint_dir, + adapter_name=adapter_name, + 
consumed_train_samples=kwargs.get('consumed_train_samples', 0), + ) return checkpoint_dir def _save_optimizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] + optimizer = optimizer_config.optimizer + lr_scheduler = optimizer_config.lr_scheduler + if optimizer is not None: + optimizer_path = os.path.join(output_dir, 'optimizer.pt') + if hasattr(self.strategy, 'save_optimizer_checkpoint'): + self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path) + elif Platform.is_master(): + torch.save(optimizer.state_dict(), optimizer_path) if Platform.is_master(): - optimizer = optimizer_config.optimizer - lr_scheduler = optimizer_config.lr_scheduler - if optimizer is not None: - torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt')) if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + def _save_training_state(self, output_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if not Platform.is_master(): + return + + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + } + with open(os.path.join(output_dir, 'trainer_state.json'), 'w', encoding='utf-8') as f: + json.dump(trainer_state, f) + + if optimizer_config.scaler is not None: + torch.save( + { + 'scaler_state_dict': optimizer_config.scaler.state_dict(), + 'scaler_has_nan': optimizer_config.scaler_has_nan, + }, + os.path.join(output_dir, 'scaler.pt'), + ) + + torch.save(self._get_training_rng_state(), os.path.join(output_dir, 'rng_state.pt')) + def _save_tokenizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) 
optimizer_config = self.optimizer_group[adapter_name] @@ -913,10 +954,11 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): model_sd = model.state_dict() converted_weights = {} for key, value in adapter_weights.items(): - if f'.{adapter_name}.weight' not in key: - key = key.replace('.weight', f'.{adapter_name}.weight') - if key in model_sd: - param = model_sd[key] + model_key = key + if f'.{adapter_name}.weight' not in model_key: + model_key = model_key.replace('.weight', f'.{adapter_name}.weight') + if model_key in model_sd: + param = model_sd[model_key] if isinstance(param, DTensor) and not isinstance(value, DTensor): value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements) converted_weights[key] = value @@ -935,20 +977,112 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): def _load_optimizer(self, checkpoint_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + strict = kwargs.pop('strict', False) # assume optimizer and lr_scheduler are created optimizer_config = self.optimizer_group[adapter_name] optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt') scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt') + if strict and not os.path.exists(optimizer_path): + raise FileNotFoundError(optimizer_path) + if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path): + logger.warning( + f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', ) + if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: - state_dict = torch.load(optimizer_path, map_location='cpu') - optimizer_config.optimizer.load_state_dict(state_dict) + if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: + self._lazy_wrap_model() + if hasattr(self.strategy, 'load_optimizer_checkpoint'): + 
self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) + else: + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True) + optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: - state_dict = torch.load(scheduler_path, map_location='cpu') + state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=True) optimizer_config.lr_scheduler.load_state_dict(state_dict) + def _load_scaler_state(self, scaler_path, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + if optimizer_config.scaler is None: + raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}') + + scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=True) + optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict']) + optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False) + + def _get_training_rng_state(self): + state = { + 'python_rng_state': random.getstate(), + 'numpy_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + } + + device_prefix = Platform.device_prefix() + device_module = getattr(torch, device_prefix, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + state['device_type'] = device_prefix + state['device_rng_state'] = device_module.get_rng_state() + else: + state['device_type'] = 'cpu' + state['device_rng_state'] = None + return state + + def _load_rng_state(self, rng_path): + rng_state = torch.load(rng_path, map_location='cpu', weights_only=False) + random.setstate(rng_state['python_rng_state']) + np.random.set_state(rng_state['numpy_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + + device_type = rng_state.get('device_type') + device_rng_state = rng_state.get('device_rng_state') + 
if device_type != 'cpu' and device_rng_state is not None: + device_module = getattr(torch, device_type, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + device_module.set_rng_state(device_rng_state) + + @remote_function() + def read_training_progress(self, checkpoint_dir, **kwargs): + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if not os.path.exists(trainer_state_path): + raise FileNotFoundError(trainer_state_path) + + with open(trainer_state_path, encoding='utf-8') as f: + trainer_state = json.load(f) + + required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'} + missing_keys = required_keys - trainer_state.keys() + if missing_keys: + raise ValueError(f'Missing trainer_state keys: {sorted(missing_keys)}') + return trainer_state + + @remote_function() + def load_training_state(self, checkpoint_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + required_paths = { + 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), + 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), + 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), + } + for path in required_paths.values(): + if not os.path.exists(path): + raise FileNotFoundError(path) + + trainer_state = self.read_training_progress(checkpoint_dir) + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True) + scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') + if os.path.exists(scaler_path) and optimizer_config.scaler is not None: + self._load_scaler_state(scaler_path, adapter_name=adapter_name) + self._load_rng_state(required_paths['rng']) + + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + return trainer_state + @remote_function(collect='first') def 
get_state_dict(self, **kwargs): return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group())) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 535a3bd7..250fdc57 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -8,9 +8,11 @@ """ from __future__ import annotations +import os import torch import traceback from fastapi import Depends, FastAPI, HTTPException, Request +from pathlib import Path from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable @@ -347,6 +349,58 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, task_type='load')) + @app.post('/twinkle/load_training_state', response_model=types.TrainingProgressResponse) + async def load_training_state( + request: Request, + body: types.LoadTrainingStateRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.TrainingProgressResponse: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = ( + Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + ret = self.model.load_training_state( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) + + @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) + async def read_training_progress( + request: Request, + body: types.ReadTrainingProgressRequest, + self: ModelManagement = Depends(self_fn), + ) -> 
types.TrainingProgressResponse: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = ( + Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + ret = self.model.read_training_progress( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='read_training_progress')) + @app.post('/twinkle/upload_to_hub') async def upload_to_hub( request: Request, diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 0a067ddd..f6a24fe4 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -82,4 +82,17 @@ def __next__(self): ) response.raise_for_status() return response.json()["result"] + + + def skip_consumed_samples(self, consumed_train_samples: int): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'skip_consumed_samples', + **{'consumed_train_samples': consumed_train_samples}, + } + ) + response.raise_for_status() + return response.json()["result"] \ No newline at end of file diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 743125d9..37eac765 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -19,6 +19,7 @@ GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -188,6 +189,24 @@ def load(self, name: str, **kwargs) -> None: 
) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 00b1f967..0e5d37e1 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -23,11 +23,13 @@ GetStateDictRequest, GetStateDictResponse, GetTrainConfigsResponse, + LoadTrainingStateRequest, LoadRequest, LoadResponse, LrStepResponse, ModelResult, OkResponse, + ReadTrainingProgressRequest, SaveRequest, SaveResponse, SetLossRequest, @@ -41,6 +43,7 @@ SetTemplateRequest, SetTemplateResponse, StepResponse, + TrainingProgressResponse, UploadToHubRequest, UploadToHubResponse, ZeroGradResponse, diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index e594bae4..0b6d7c08 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -89,6 +89,22 @@ class Config: extra = 'allow' +class LoadTrainingStateRequest(BaseModel): + adapter_name: str + name: str + + class Config: + extra = 'allow' + + +class ReadTrainingProgressRequest(BaseModel): + adapter_name: str + name: 
str + + class Config: + extra = 'allow' + + class AddAdapterRequest(BaseModel): adapter_name: str config: str @@ -212,6 +228,11 @@ class SaveResponse(BaseModel): checkpoint_dir: Optional[str] = None +class TrainingProgressResponse(BaseModel): + """Response for /read_training_progress endpoint (returns progress metadata).""" + result: Dict[str, Any] + + # --- Void responses (return None → OkResponse) --- class BackwardResponse(OkResponse): diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 79bf78ad..edad0dd3 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,16 +1,31 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import concurrent.futures import numpy as np import os import pytest from pathlib import Path +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import IterableDataset as TorchIterableDataset import twinkle from twinkle import DeviceMesh from twinkle.data_format import Message from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta +from twinkle.dataset import Dataset, DatasetMeta, IterableDataset from twinkle.processor import InputProcessor + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor + twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' @@ -22,6 +37,44 @@ def convert_to_messages(example): return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]} +def _build_resume_rows(): + return [ + { + 'text': 'Hello world' + }, + { + 'text': 'Test data' + }, + { + 'text': 'Another example' + }, + { + 'text': 'Sample text' + }, + ] + + +class _InMemoryDataset(TorchDataset): + + def 
__init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + +class _InMemoryIterableDataset(TorchIterableDataset): + + def __init__(self, rows): + self.rows = rows + + def __iter__(self): + return iter(self.rows) + + class TestDataLoaderBasic: def test_dataloader_basic(self): @@ -157,3 +210,25 @@ def test_retry_sampler_length(self): total_samples = sum(len(batch) for batch in dataloader) assert total_samples == original_len + + +class TestResumeSkip: + + def test_dataloader_skip_consumed_samples_for_map_style_dataset(self): + dataset = _InMemoryDataset(_build_resume_rows()) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + + dataloader.skip_consumed_samples(2) + batches = list(dataloader) + + texts = [item['text'] for batch in batches for item in batch] + assert texts[0] == 'Another example' + + def test_dataloader_warns_when_skip_requested_for_iterable_dataset(self, recwarn): + dataset = _InMemoryIterableDataset(_build_resume_rows()) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + + dataloader.skip_consumed_samples(2) + next(iter(dataloader)) + + assert 'does not support consumed-data skipping' in str(recwarn.pop(UserWarning).message) diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index b8438207..d90c8725 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -1,13 +1,29 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import concurrent.futures +import numpy as np import os import pytest from pathlib import Path +from torch.utils.data import Dataset as TorchDataset from torch.utils.data import RandomSampler, SequentialSampler import twinkle +from twinkle import DeviceMesh from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor + twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' @@ -162,3 +178,42 @@ def test_sequential_vs_random_order(self): different = seq_texts != rand_texts assert different or len(seq_texts) == 1 + + +class TestResumeSkipSamplerOrdering: + + def test_sequential_sampler_skip_happens_before_device_mesh_slice(self): + + class _InMemoryDataset(TorchDataset): + + def __init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + dataset = _InMemoryDataset([ + { + 'text': 'Hello world' + }, + { + 'text': 'Test data' + }, + { + 'text': 'Another example' + }, + { + 'text': 'Sample text' + }, + ]) + sampler = SequentialSampler(dataset) + device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', )) + dataloader = DataLoader(dataset=dataset, batch_size=4, sampler=sampler, device_mesh=device_mesh, num_workers=0) + + dataloader.skip_consumed_samples(2) + first_batch = list(dataloader)[0] + + assert first_batch[0]['text'] == 'Another example'