From 5cd3c0fb3496f7a0d15ad765e041a40ac4e85866 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 12:00:57 +0800 Subject: [PATCH 01/28] docs: add transformers resume design spec --- ...7-transformers-checkpoint-resume-design.md | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md new file mode 100644 index 00000000..a821402a --- /dev/null +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -0,0 +1,353 @@ +# Transformers Strict Resume Design + +## Summary + +This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. + +The implementation aligns the resume semantics of `TransformersModel` with the existing `MegatronModel` behavior: + +- normal weight loading remains available +- strict resume restores model weights and training state together +- strict resume does not silently fall back to weight-only loading when required state is missing + +Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. + +## Goals + +- Support true checkpoint resume for `TransformersModel` +- Restore model weights, optimizer state, scheduler state, RNG state, and step counters +- Support dataset progress skipping for map-style datasets +- Expose Swift-like resume controls without adding a new trainer class +- Preserve existing weight-only loading and saving behavior +- Keep backward compatibility for existing checkpoints where possible + +## Non-Goals + +- Do not introduce a new `Trainer` class or resume manager class +- Do not guarantee exact sample-by-sample replay when retry-based sampling changes sample order +- Do not support exact data-progress resume for `IterableDataset` or streaming datasets +- Do not attempt to persist transient runtime state such as in-flight batch tensors, current loss tensors, or metric caches + +## User-Facing Resume Controls + +Resume behavior is controlled by existing training entrypoints through three new parameters: + +- `resume_from_checkpoint: Optional[str] = None` +- `resume_only_model: bool = False` +- `ignore_data_skip: bool = False` + +### Parameter semantics + +#### `resume_from_checkpoint` + +- Specifies the checkpoint directory or checkpoint path to resume from +- When unset, training starts normally from scratch +- When set, the training entrypoint reads the checkpoint and restores model state through existing model APIs + +#### `resume_only_model` + +- Defaults to `False` +- When `False`, resume restores full training state +- When `True`, resume restores only model weights + +#### `ignore_data_skip` + +- Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` +- Defaults to `False` +- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, or RNG +- When `True`, the system restores only model weights and does not restore training progress or skip consumed data + +### Effective behavior matrix + +#### Case 1: `resume_from_checkpoint is None` + +- Start a new training run + +#### Case 2: `resume_from_checkpoint is not None` and 
`resume_only_model=False` + +- Restore model weights +- Restore optimizer state +- Restore scheduler state +- Restore RNG state +- Restore step counters +- Attempt to skip already consumed training data +- If required model training state is missing, fail without fallback + +#### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` + +- Restore model weights only +- Do not restore optimizer, scheduler, or RNG +- Restore step/progress metadata needed for data skipping +- Attempt to skip already consumed training data + +#### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` + +- Restore model weights only +- Do not restore optimizer, scheduler, RNG, step counters, or data progress +- Restart the training loop from step 0 with no skipping + +## Checkpoint Layout + +Existing checkpoint layout remains valid. New resume metadata is added alongside current files. + +### Existing files preserved + +- model weights saved by `save_pretrained` +- LoRA weights saved as `adapter_model.safetensors` +- tokenizer artifacts +- `optimizer.pt` +- `scheduler.pt` + +### New file + +- `training_state.pt` + +### `training_state.pt` contents + +`training_state.pt` stores a small dictionary with the following fields: + +- `checkpoint_version` +- `cur_step` +- `gradient_accumulation_steps` +- `scaler_state_dict` +- `scaler_has_nan` +- `rng_state` +- `data_progress` + +### `rng_state` contents + +- Python `random` state +- NumPy RNG state +- PyTorch CPU RNG state +- CUDA RNG state + +### `data_progress` contents + +First version stores progress in a compact form: + +- `consumed_train_samples` +- optionally `consumed_batches` when this is easier to compute reliably in a given entrypoint + +The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. + +## Model Save and Load Semantics + +## `TransformersModel.save` + +`TransformersModel.save(..., save_optimizer=True)` is extended to: + +1. Save weights exactly as today +2. Save tokenizer exactly as today +3. Save `optimizer.pt` and `scheduler.pt` exactly as today +4. Save `training_state.pt` + +When `save_optimizer=False`, save remains weight-only and does not produce strict resume metadata. + +## `TransformersModel.load` + +`TransformersModel.load(..., load_optimizer=False)` keeps current behavior: + +- load model weights only + +`TransformersModel.load(..., load_optimizer=True)` becomes strict model-state resume: + +1. Resolve checkpoint directory +2. Load model weights +3. Load optimizer and scheduler state +4. Load `training_state.pt` +5. Restore scaler state +6. Restore RNG state +7. Restore `cur_step` and `gradient_accumulation_steps` + +### Failure behavior + +When `load_optimizer=True`, missing required model training state is an error: + +- missing `training_state.pt` -> fail +- missing `optimizer.pt` when optimizer restore is required -> fail +- missing `scheduler.pt` when scheduler restore is required -> fail +- malformed required fields in `training_state.pt` -> fail + +This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. + +This matches the `MegatronModel` contract more closely than the current `TransformersModel` behavior. + +## Training Progress and Data Skipping + +Twinkle does not currently have a central trainer abstraction. 
Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. + +## Dataloader extensions + +Existing dataloader and sampler code is extended rather than replaced: + +- `twinkle.dataloader.DataLoader` +- `twinkle.dataloader.DeviceMeshSampler` +- retry-aware sampler flow + +The dataloader gains resume-oriented arguments: + +- `skip_samples: int = 0` +- optionally `skip_batches: int = 0` + +Map-style datasets use this progress to skip already consumed data before yielding new training batches. + +## Map-style dataset behavior + +For datasets with `__len__`, Twinkle attempts to skip previously consumed data using sampler or batch-sampler level skipping. + +Preferred behavior: + +- preserve existing sharding logic +- apply skip before data is yielded to the training loop +- keep the solution compatible with current `DeviceMeshSampler` wrapping + +## Iterable and streaming behavior + +`IterableDataset` and streaming datasets do not support exact progress skipping in this design. + +Behavior for these datasets: + +- restore model state according to the selected resume mode +- log a clear warning that consumed-data skipping is not supported +- continue training without skipping historical samples + +This is the only fallback allowed in the design. It applies only to dataset progress skipping, not to model-state resume. + +## Entry Point Integration + +No new trainer class is introduced. + +Resume parameters are threaded through existing training entrypoints: + +- direct local training loops using `TwinkleModel` / `TransformersModel` +- current client/server training flows that already support checkpoint save and load + +The practical integration model is: + +1. Parse or receive the three resume parameters +2. If `resume_from_checkpoint` is unset, construct dataloader normally +3. If `resume_only_model=False`, call existing model load with strict restore semantics +4. If `resume_only_model=True`, call weight-only model load +5. If data skipping is enabled, read progress metadata from `training_state.pt` +6. Recreate the dataloader with skip arguments applied + +This keeps the training loop explicit and compatible with current Twinkle examples. + +## Server and Client Behavior + +Server-side checkpoint save/load behavior should preserve current APIs while adding richer metadata. + +### Save path + +When server-side save endpoints request optimizer save: + +- save the model checkpoint as today +- save `optimizer.pt`, `scheduler.pt`, and `training_state.pt` +- persist checkpoint metadata through the existing checkpoint manager + +### Load path + +Current `load_optimizer=True` behavior is retained as the trigger for strict model-state restore. + +The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: + +- call strict resume +- call weight-only resume +- request data skipping + +The underlying server model APIs do not need a new trainer object to support this. + +## Compatibility Strategy + +### Existing checkpoints + +Existing checkpoints remain loadable in weight-only mode. + +Examples: + +- `model.load(path, load_optimizer=False)` continues to work +- inference-only consumers remain unaffected + +### Old checkpoints under strict resume + +Old checkpoints that lack `training_state.pt` are not valid for strict `TransformersModel` resume. 
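+
+For example, with a legacy checkpoint directory, the contract above implies roughly the following (sketch only; the concrete exception type is an implementation detail, and `legacy_ckpt_dir` is a placeholder):
+
+```python
+model.load(legacy_ckpt_dir, load_optimizer=False)  # weight-only load: still works
+model.load(legacy_ckpt_dir, load_optimizer=True)   # strict resume: fails, training_state.pt is missing
+```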
+ +Expected behavior: + +- strict resume fails clearly +- weight-only load continues to work when requested explicitly + +### `resume_only_model=True` + +For `resume_only_model=True`, old checkpoints may still be usable if weight files are present. + +If data skipping is requested but no progress metadata exists, the entrypoint should fail clearly rather than silently train from the beginning while claiming resumed progress. + +## Risks and Constraints + +### RetrySampler interaction + +`RetrySampler` may retry or replace failed samples, including random backfill behavior at the tail of an epoch. + +Because of that: + +- progress skipping can preserve approximate data position +- exact sample-for-sample replay is not guaranteed when retry or backfill paths are exercised + +This limitation should be documented explicitly. + +### Dataset shape changes + +If dataset definition, slicing, filtering, or shuffle configuration changes between save and resume, data skipping semantics may become invalid. + +The user guidance should state that resume should be done with unchanged training parameters and unchanged dataset configuration. + +### Distributed consistency + +Skip logic must be compatible with current device-mesh sharding. The implementation should ensure skip is applied consistently before per-rank slicing causes divergence. + +## Testing Strategy + +Tests should cover: + +### Model-state save/load + +- `training_state.pt` is written when optimizer save is enabled +- scaler, RNG, `cur_step`, and accumulation settings are restored +- strict resume fails when required files are missing + +### Weight-only compatibility + +- legacy checkpoints still load in weight-only mode +- `resume_only_model=True` restores weights without optimizer and RNG + +### Data progress skipping + +- map-style datasets skip consumed data correctly +- skip behavior remains correct with device-mesh sharding +- iterable and streaming datasets emit warnings and continue without skipping + +### Failure cases + +- missing progress metadata when data skipping is requested +- malformed `training_state.pt` +- mismatch between requested strict resume and available checkpoint contents + +## Implementation Outline + +1. Extend `TransformersModel.save/load` to persist and restore `training_state.pt` +2. Add helper methods for RNG save/load and training-state serialization +3. Extend dataloader and sampler stack to support skip arguments for map-style datasets +4. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints +5. Add warnings for unsupported iterable/streaming data skipping +6. 
Update docs and examples to prefer trainer-level resume parameters over ad hoc `model.load(..., load_optimizer=True)` logic + +## User Guidance + +Recommended guidance text: + +- To resume training, keep other parameters unchanged and provide `resume_from_checkpoint` +- `resume_only_model=False` performs full resume +- `resume_only_model=True` restores only model weights +- `ignore_data_skip=True` disables progress restore and starts from step 0 +- Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data From 91eeaebb0077b6f2ea456fef1dc64a653ec8eaa8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:33:11 +0800 Subject: [PATCH 02/28] docs: refine transformers resume design spec --- ...7-transformers-checkpoint-resume-design.md | 200 +++++++++++------- 1 file changed, 125 insertions(+), 75 deletions(-) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index a821402a..9a1d45c2 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -4,22 +4,23 @@ This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. -The implementation aligns the resume semantics of `TransformersModel` with the existing `MegatronModel` behavior: +The design supports both full-parameter training and LoRA training: -- normal weight loading remains available -- strict resume restores model weights and training state together -- strict resume does not silently fall back to weight-only loading when required state is missing +- full-parameter training restores weights during model initialization +- LoRA training restores adapter weights through the existing load path +- both modes share the same training-state resume contract +- strict model-state resume does not silently fall back to weight-only loading when required state is missing Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. 
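+
+The intended shape of an entrypoint is sketched below. This is illustrative only: `make_dataloader` and `read_trainer_state` are placeholders rather than Twinkle APIs, and only `load_training_state` is specified in the sections that follow.
+
+```python
+# Sketch only: how an explicit training entrypoint threads the three controls.
+def setup_resume(model, make_dataloader, resume_from_checkpoint=None,
+                 resume_only_model=False, ignore_data_skip=False):
+    if resume_from_checkpoint and not resume_only_model:
+        # Full resume: optimizer/scheduler/scaler/RNG/step counters + data skip.
+        state = model.load_training_state(resume_from_checkpoint)
+        return make_dataloader(skip_samples=state['consumed_train_samples'])
+    if resume_from_checkpoint and not ignore_data_skip:
+        # Weights-only resume that still skips consumed data.
+        state = read_trainer_state(resume_from_checkpoint)  # progress metadata only
+        return make_dataloader(skip_samples=state['consumed_train_samples'])
+    return make_dataloader()
+```
+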
## Goals - Support true checkpoint resume for `TransformersModel` -- Restore model weights, optimizer state, scheduler state, RNG state, and step counters +- Support both full-parameter and LoRA training resume +- Restore optimizer state, scheduler state, scaler state, RNG state, and step counters - Support dataset progress skipping for map-style datasets - Expose Swift-like resume controls without adding a new trainer class - Preserve existing weight-only loading and saving behavior -- Keep backward compatibility for existing checkpoints where possible ## Non-Goals @@ -54,7 +55,7 @@ Resume behavior is controlled by existing training entrypoints through three new - Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` - Defaults to `False` -- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, or RNG +- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, scaler, or RNG - When `True`, the system restores only model weights and does not restore training progress or skip consumed data ### Effective behavior matrix @@ -68,6 +69,7 @@ Resume behavior is controlled by existing training entrypoints through three new - Restore model weights - Restore optimizer state - Restore scheduler state +- Restore scaler state - Restore RNG state - Restore step counters - Attempt to skip already consumed training data @@ -76,102 +78,141 @@ Resume behavior is controlled by existing training entrypoints through three new #### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` - Restore model weights only -- Do not restore optimizer, scheduler, or RNG +- Do not restore optimizer, scheduler, scaler, or RNG - Restore step/progress metadata needed for data skipping - Attempt to skip already consumed training data #### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` - Restore model weights only -- Do not restore optimizer, scheduler, RNG, step counters, or data progress +- Do not restore optimizer, scheduler, scaler, RNG, step counters, or data progress - Restart the training loop from step 0 with no skipping ## Checkpoint Layout -Existing checkpoint layout remains valid. New resume metadata is added alongside current files. +Existing weight layouts remain valid. New training-state files are added alongside current checkpoint contents. ### Existing files preserved -- model weights saved by `save_pretrained` +- full-model weights saved by `save_pretrained` - LoRA weights saved as `adapter_model.safetensors` - tokenizer artifacts - `optimizer.pt` - `scheduler.pt` -### New file +### New training-state files -- `training_state.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -### `training_state.pt` contents +### `trainer_state.json` contents -`training_state.pt` stores a small dictionary with the following fields: +`trainer_state.json` stores lightweight training metadata: - `checkpoint_version` - `cur_step` - `gradient_accumulation_steps` -- `scaler_state_dict` -- `scaler_has_nan` -- `rng_state` -- `data_progress` +- `consumed_train_samples` +- optionally `consumed_batches` + +The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. 
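+
+For illustration, a checkpoint taken after 120 optimizer steps with accumulation 4 and a per-micro-batch size of 8 might contain (all values hypothetical):
+
+```json
+{
+  "checkpoint_version": 1,
+  "cur_step": 120,
+  "gradient_accumulation_steps": 4,
+  "consumed_train_samples": 3840
+}
+```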
+ +### `scaler.pt` contents -### `rng_state` contents +- AMP scaler state dict +- optional scaler-related flags such as `scaler_has_nan` + +### `rng_state.pt` contents - Python `random` state - NumPy RNG state - PyTorch CPU RNG state - CUDA RNG state -### `data_progress` contents +## Restore Paths -First version stores progress in a compact form: +## Full-Parameter Training -- `consumed_train_samples` -- optionally `consumed_batches` when this is easier to compute reliably in a given entrypoint +For full-parameter training, model weights are restored during initialization. -The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. +### Full-parameter restore flow -## Model Save and Load Semantics +1. Construct `TransformersModel(model_id=ckpt_dir, ...)` +2. `__init__` uses `from_pretrained(ckpt_dir, ...)` to restore weights +3. Create optimizer, scheduler, and scaler objects +4. Call `load_training_state(ckpt_dir)` to restore training state +5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` + +This means full-parameter resume does not need a separate model-weight loading method after initialization. It only needs explicit training-state restoration. + +## LoRA Training + +For LoRA training, the existing adapter-weight load path remains in place. + +### LoRA restore flow + +1. Construct the model and adapter objects as today +2. Restore adapter weights through the existing `load()` path +3. Create optimizer, scheduler, and scaler objects +4. Call the same `load_training_state(ckpt_dir)` method to restore training state +5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` + +## Unified training-state method + +The model layer gains a shared helper such as `load_training_state(ckpt_dir)`. + +This method restores: -## `TransformersModel.save` +- `optimizer.pt` +- `scheduler.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -`TransformersModel.save(..., save_optimizer=True)` is extended to: +It assumes the corresponding optimizer, scheduler, and scaler objects have already been created before invocation. + +## Model Save and Load Semantics -1. Save weights exactly as today -2. Save tokenizer exactly as today -3. Save `optimizer.pt` and `scheduler.pt` exactly as today -4. Save `training_state.pt` +## Save behavior -When `save_optimizer=False`, save remains weight-only and does not produce strict resume metadata. +When saving with optimizer state enabled, the checkpoint includes: -## `TransformersModel.load` +- weights in the existing full-model or LoRA format +- tokenizer artifacts +- `optimizer.pt` +- `scheduler.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -`TransformersModel.load(..., load_optimizer=False)` keeps current behavior: +When optimizer save is disabled, save remains weight-only and does not produce strict resume metadata. -- load model weights only +## Strict training-state restore -`TransformersModel.load(..., load_optimizer=True)` becomes strict model-state resume: +Strict model-state resume restores: -1. Resolve checkpoint directory -2. Load model weights -3. Load optimizer and scheduler state -4. Load `training_state.pt` -5. Restore scaler state -6. Restore RNG state -7. 
Restore `cur_step` and `gradient_accumulation_steps` +- optimizer state +- scheduler state +- scaler state +- RNG state +- `cur_step` +- `gradient_accumulation_steps` +- data-progress metadata ### Failure behavior -When `load_optimizer=True`, missing required model training state is an error: +When strict training-state restore is requested, missing required model training state is an error: -- missing `training_state.pt` -> fail +- missing `trainer_state.json` -> fail - missing `optimizer.pt` when optimizer restore is required -> fail - missing `scheduler.pt` when scheduler restore is required -> fail -- malformed required fields in `training_state.pt` -> fail +- missing `scaler.pt` when scaler restore is required -> fail +- missing `rng_state.pt` when RNG restore is required -> fail +- malformed required fields -> fail This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. -This matches the `MegatronModel` contract more closely than the current `TransformersModel` behavior. - ## Training Progress and Data Skipping Twinkle does not currently have a central trainer abstraction. Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. @@ -226,10 +267,12 @@ The practical integration model is: 1. Parse or receive the three resume parameters 2. If `resume_from_checkpoint` is unset, construct dataloader normally -3. If `resume_only_model=False`, call existing model load with strict restore semantics -4. If `resume_only_model=True`, call weight-only model load -5. If data skipping is enabled, read progress metadata from `training_state.pt` -6. Recreate the dataloader with skip arguments applied +3. Construct model weights through the appropriate path + - full-parameter: restore through `__init__` + - LoRA: restore through existing adapter load logic +4. If `resume_only_model=False`, call `load_training_state(ckpt_dir)` +5. If `resume_only_model=True` and `ignore_data_skip=False`, read `trainer_state.json` for progress only +6. Recreate the dataloader with skip arguments applied when skipping is enabled This keeps the training loop explicit and compatible with current Twinkle examples. @@ -242,17 +285,17 @@ Server-side checkpoint save/load behavior should preserve current APIs while add When server-side save endpoints request optimizer save: - save the model checkpoint as today -- save `optimizer.pt`, `scheduler.pt`, and `training_state.pt` +- save `optimizer.pt`, `scheduler.pt`, `scaler.pt`, `trainer_state.json`, and `rng_state.pt` - persist checkpoint metadata through the existing checkpoint manager ### Load path -Current `load_optimizer=True` behavior is retained as the trigger for strict model-state restore. +Current model load APIs remain the weight-loading trigger. The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: -- call strict resume -- call weight-only resume +- restore full training state +- restore weight only - request data skipping The underlying server model APIs do not need a new trainer object to support this. @@ -265,12 +308,13 @@ Existing checkpoints remain loadable in weight-only mode. 
Examples: -- `model.load(path, load_optimizer=False)` continues to work +- weight-only initialization for full-parameter checkpoints continues to work +- existing LoRA weight loading continues to work - inference-only consumers remain unaffected ### Old checkpoints under strict resume -Old checkpoints that lack `training_state.pt` are not valid for strict `TransformersModel` resume. +Old checkpoints that lack the new training-state files are not valid for strict resume. Expected behavior: @@ -310,16 +354,25 @@ Skip logic must be compatible with current device-mesh sharding. The implementat Tests should cover: -### Model-state save/load +### Full-parameter training resume + +- initializing with `model_id=ckpt_dir` restores weights +- `load_training_state(ckpt_dir)` restores optimizer, scheduler, scaler, RNG, and step metadata + +### LoRA training resume + +- adapter-weight restore continues to work +- `load_training_state(ckpt_dir)` restores shared training state correctly + +### Strict restore failures -- `training_state.pt` is written when optimizer save is enabled -- scaler, RNG, `cur_step`, and accumulation settings are restored - strict resume fails when required files are missing +- malformed state files fail clearly ### Weight-only compatibility - legacy checkpoints still load in weight-only mode -- `resume_only_model=True` restores weights without optimizer and RNG +- `resume_only_model=True` restores weights without optimizer, scheduler, scaler, or RNG ### Data progress skipping @@ -327,20 +380,16 @@ Tests should cover: - skip behavior remains correct with device-mesh sharding - iterable and streaming datasets emit warnings and continue without skipping -### Failure cases - -- missing progress metadata when data skipping is requested -- malformed `training_state.pt` -- mismatch between requested strict resume and available checkpoint contents - ## Implementation Outline -1. Extend `TransformersModel.save/load` to persist and restore `training_state.pt` -2. Add helper methods for RNG save/load and training-state serialization -3. Extend dataloader and sampler stack to support skip arguments for map-style datasets -4. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints -5. Add warnings for unsupported iterable/streaming data skipping -6. Update docs and examples to prefer trainer-level resume parameters over ad hoc `model.load(..., load_optimizer=True)` logic +1. Add model helpers for saving and loading split training-state files +2. Implement `load_training_state(ckpt_dir)` with shared behavior for full-parameter and LoRA training +3. Keep full-parameter weight restore in `__init__` +4. Keep LoRA weight restore in the existing adapter load path +5. Extend dataloader and sampler stack to support skip arguments for map-style datasets +6. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints +7. Add warnings for unsupported iterable and streaming data skipping +8. 
Update docs and examples to show the new resume contract ## User Guidance @@ -350,4 +399,5 @@ Recommended guidance text: - `resume_only_model=False` performs full resume - `resume_only_model=True` restores only model weights - `ignore_data_skip=True` disables progress restore and starts from step 0 +- Full-parameter checkpoints restore weights during model initialization and restore training state afterward - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data From 6eebda8d049abc7fc045b1d445502c192a81424b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:41:55 +0800 Subject: [PATCH 03/28] docs: trim resume state fields --- .../specs/2026-03-27-transformers-checkpoint-resume-design.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index 9a1d45c2..3b38a910 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -114,7 +114,6 @@ Existing weight layouts remain valid. New training-state files are added alongsi - `cur_step` - `gradient_accumulation_steps` - `consumed_train_samples` -- optionally `consumed_batches` The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. From cdd9c1bc44fed6087ab278b04c27d7e06faa7237 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:44:30 +0800 Subject: [PATCH 04/28] docs: add npu resume compatibility requirements --- ...7-transformers-checkpoint-resume-design.md | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index 3b38a910..7b90baba 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -20,6 +20,7 @@ Because Twinkle keeps the training loop explicit in user code, the design extend - Restore optimizer state, scheduler state, scaler state, RNG state, and step counters - Support dataset progress skipping for map-style datasets - Expose Swift-like resume controls without adding a new trainer class +- Keep training-state save and load compatible with NPU (Ascend) environments - Preserve existing weight-only loading and saving behavior ## Non-Goals @@ -129,6 +130,35 @@ The design prefers storing `consumed_train_samples` as the canonical progress va - PyTorch CPU RNG state - CUDA RNG state +## Accelerator Compatibility + +Training-state save and load must be accelerator-compatible, including Ascend NPU environments already supported by Twinkle. + +### Device-agnostic serialization + +Training-state files must use device-agnostic serialization: + +- optimizer, scheduler, scaler, and RNG payloads should be serialized in CPU-safe form +- JSON metadata stays in plain text files +- loading should first read state from CPU-safe files and then apply it to objects created on the current runtime device + +This avoids tying resume files to a specific device object layout during save. 
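+
+A minimal save-side sketch of this convention (the helper name is illustrative, not a Twinkle API):
+
+```python
+import torch
+
+def save_cpu_safe(state_dict, path):
+    # Recursively move tensors to CPU so the saved file has no device affinity.
+    def to_cpu(obj):
+        if torch.is_tensor(obj):
+            return obj.detach().cpu()
+        if isinstance(obj, dict):
+            return {k: to_cpu(v) for k, v in obj.items()}
+        if isinstance(obj, (list, tuple)):
+            return type(obj)(to_cpu(v) for v in obj)
+        return obj
+
+    torch.save(to_cpu(state_dict), path)
+```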
+ +### RNG compatibility requirements + +RNG save and restore must branch by current accelerator backend: + +- CUDA runtime uses `torch.cuda` RNG APIs +- NPU runtime uses `torch.npu` RNG APIs +- CPU RNG and Python/NumPy RNG are always restored + +The implementation must not assume CUDA-only RNG helpers when saving or restoring training state. + +### Scope of compatibility + +The design requires resume support to work correctly in NPU environments. + +The design does not require cross-accelerator resume guarantees such as saving on GPU and resuming on NPU, or saving on NPU and resuming on GPU. The compatibility target is correct save and restore within the active supported accelerator backend. ## Restore Paths ## Full-Parameter Training @@ -400,3 +430,4 @@ Recommended guidance text: - `ignore_data_skip=True` disables progress restore and starts from step 0 - Full-parameter checkpoints restore weights during model initialization and restore training state afterward - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data + From 1542492f82db3e2a77abe9e38886e33bd2d2b005 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 16:17:12 +0800 Subject: [PATCH 05/28] chore: ignore local worktrees --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 58f495d4..afdfcae9 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ images /custom/ megatron_output/ .qoder +.worktrees/ # Pytorch *.pth From 98831180162fab51893e04ff877f27a52ad843d2 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 08:57:49 +0800 Subject: [PATCH 06/28] wip --- client_tools/client_generator.py | 23 +- ...00\344\275\263\345\256\236\350\267\265.md" | 24 +- src/twinkle/dataloader/dataloader.py | 53 ++- src/twinkle/dataloader/device_mesh_sampler.py | 15 +- src/twinkle/dataloader/retry_sampler.py | 20 +- .../transformers/multi_lora_transformers.py | 2 +- .../model/transformers/transformers.py | 124 +++++++- src/twinkle/server/model/twinkle_handlers.py | 51 +++ .../model/multi_lora_transformers.py | 18 ++ src/twinkle_client/types/__init__.py | 3 + src/twinkle_client/types/model.py | 21 ++ tests/dataloader/test_dataloader.py | 69 +++- tests/dataloader/test_sampler.py | 46 +++ .../transformers/test_checkpoint_resume.py | 301 ++++++++++++++++++ .../model/test_twinkle_resume_routes.py | 180 +++++++++++ 15 files changed, 924 insertions(+), 26 deletions(-) create mode 100644 tests/model/transformers/test_checkpoint_resume.py create mode 100644 tests/server/model/test_twinkle_resume_routes.py diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 8927315d..f724c7c1 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. +# Copyright (c) ModelScope Contributors. All rights reserved. 
import ast from pathlib import Path from typing import Dict, List, Set, Tuple @@ -448,6 +448,7 @@ def generate_models(): GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -617,6 +618,23 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> None: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( @@ -850,4 +868,5 @@ def apply_patch(self, patch_cls: str, **kwargs) -> None: generate_samplers() print('\n' + '=' * 60) - print('\n✓ All client code generation complete!\n') + print('\nAll client code generation complete!\n') + diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index ad78e28d..f0112042 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -410,9 +410,19 @@ def train(): model.set_lr_scheduler('LinearLR') # 恢复训练(如有检查点) - if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + resume_from_checkpoint = resume_path + resume_only_model = False + ignore_data_skip = False + if resume_from_checkpoint: + logger.info(f'Resuming training from {resume_from_checkpoint}') + model.load(name=resume_from_checkpoint) + + if not resume_only_model: + trainer_state = model.load_training_state(resume_from_checkpoint) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + elif not ignore_data_skip: + progress = model.read_training_progress(resume_from_checkpoint) + dataloader.skip_consumed_samples(progress['consumed_train_samples']) logger.info(model.get_train_configs()) @@ -445,6 +455,14 @@ if __name__ == '__main__': - 支持断点续训、检查点管理 - 可动态切换 LoRA 适配器、损失函数、优化器等组件 +Resume 模式: + +- `resume_from_checkpoint=None`:开始新的训练任务。 +- `resume_only_model=False`:恢复权重、optimizer、scheduler、scaler、RNG 和进度元数据。 +- `resume_only_model=True` 且 `ignore_data_skip=False`:恢复权重,读取进度元数据,并跳过已消费样本。 +- `resume_only_model=True` 且 `ignore_data_skip=True`:只恢复权重,训练步数和数据进度从 0 开始。 +- `skip_consumed_samples(...)` 不适用于 iterable / streaming dataset。 + ### 3.2 Tinker Client:简洁即用 Tinker 是一个轻量级训练 API。Twinkle 对 Tinker 客户端提供完整支持,几行代码就能拉起训练。已有 Tinker 代码的项目可以直接迎移到 Twinkle 服务端。 diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..e2ef57ce 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -1,4 +1,6 @@ # Copyright (c) 
ModelScope Contributors. All rights reserved. +import copy +import warnings from functools import partial from typing import Callable, Optional, Type, Union @@ -51,6 +53,9 @@ def __init__(self, self.dataloader_params['batch_size'] = batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None + self._skip_samples = 0 + self._base_batch_sampler = None + self._base_sampler = None self._set_work_init_fn() def _set_work_init_fn(self): @@ -97,7 +102,9 @@ def _lazy_init_dataloader(self): if not isinstance(self.dataset, IterableDataset): self.dataloader.__initialized = False - self._repeat_sample_and_shard() + self._base_batch_sampler = self.dataloader.batch_sampler + self._base_sampler = self.dataloader.sampler + self._rebuild_sampler_stack() self.dataloader.__initialized = True @remote_function() @@ -116,11 +123,39 @@ def __iter__(self): max_retries=self.max_retries) return _iter - def _repeat_sample_and_shard(self): - if self.dataloader.batch_sampler is not None and hasattr(self.dataloader.batch_sampler, 'sampler'): - self.dataloader.batch_sampler.sampler = RetrySampler( - self.dataloader.batch_sampler.sampler, self.dataset, max_retries=self.max_retries) - self.dataloader.batch_sampler = DeviceMeshSampler(self.dataloader.batch_sampler, self.device_mesh, - self.min_batch_size) - elif self.dataloader.sampler is not None: - self.dataloader.sampler = RetrySampler(self.dataloader.sampler, self.dataset, max_retries=self.max_retries) + @remote_function() + def skip_consumed_samples(self, consumed_train_samples: int) -> None: + from torch.utils.data import IterableDataset + + if isinstance(self.dataset, IterableDataset): + warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') + self._skip_samples = 0 + return + + self._skip_samples = max(int(consumed_train_samples), 0) + if self.dataloader is not None: + self.dataloader.__initialized = False + self._rebuild_sampler_stack() + self.dataloader.__initialized = True + + def _rebuild_sampler_stack(self): + if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): + batch_sampler = copy.copy(self._base_batch_sampler) + batch_sampler.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + ) + self.dataloader.batch_sampler = DeviceMeshSampler( + batch_sampler, + self.device_mesh, + self.min_batch_size, + skip_samples=self._skip_samples, + ) + elif self._base_sampler is not None: + self.dataloader.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + skip_samples=self._skip_samples, + ) diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py index 955b85cd..1f649de3 100644 --- a/src/twinkle/dataloader/device_mesh_sampler.py +++ b/src/twinkle/dataloader/device_mesh_sampler.py @@ -12,15 +12,28 @@ class DeviceMeshSampler(BatchSampler): device_mesh: The device mesh. 
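+        min_batch_size: Minimum number of samples required per yielded batch;
+            defaults to the data-parallel world size of ``device_mesh``.
+        skip_samples: Number of already-consumed samples to drop before batches
+            are sliced across the device mesh; used for checkpoint resume.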
""" - def __init__(self, original_sampler: BatchSampler, device_mesh: DeviceMesh, min_batch_size: int = None): + def __init__(self, + original_sampler: BatchSampler, + device_mesh: DeviceMesh, + min_batch_size: int = None, + skip_samples: int = 0): self.original_sampler = original_sampler self.device_mesh = device_mesh self.min_batch_size = min_batch_size + self.skip_samples = skip_samples if self.min_batch_size is None and self.device_mesh is not None: self.min_batch_size = self.device_mesh.data_world_size def __iter__(self): + skipped = 0 for batch in self.original_sampler: + if skipped < self.skip_samples: + if skipped + len(batch) <= self.skip_samples: + skipped += len(batch) + continue + batch = batch[self.skip_samples - skipped:] + skipped = self.skip_samples + if not self.device_mesh: yield batch else: diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 62f05660..43307b1a 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -14,13 +14,16 @@ class RetrySampler(Sampler): max_retries: The maximum number of retries. """ - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20): + def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20, skip_samples: int = 0): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries + self.skip_samples = skip_samples def __iter__(self): - total = 0 + emitted = 0 + seen_valid = 0 + target_total = max(len(self.dataset) - self.skip_samples, 0) for idx in self.original_sampler: for _ in range(self.max_retries): try: @@ -29,8 +32,11 @@ def __iter__(self): data = self.dataset[idx] if not data: continue + seen_valid += 1 + if seen_valid <= self.skip_samples: + break yield idx - total += 1 + emitted += 1 break except Exception: # noqa import traceback @@ -39,12 +45,11 @@ def __iter__(self): else: raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') - origin_dataset_len = len(self.dataset) - if total >= origin_dataset_len: + if emitted >= target_total: return for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): - if total >= origin_dataset_len: + if emitted >= target_total: raise StopIteration for _ in range(self.max_retries): try: @@ -53,7 +58,8 @@ def __iter__(self): if not data: continue yield idx - total += 1 + emitted += 1 + break except Exception: # noqa import traceback traceback.print_exc() diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 0900f52b..7db01146 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -237,7 +237,7 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self.load_training_state(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 520aaf9f..d588b8e7 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -3,8 +3,10 @@ import contextlib import json import os +import random import re import threading +import numpy as np import torch 
import torch.distributed as dist import transformers @@ -866,6 +868,11 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if kwargs.get('save_optimizer', False): self._save_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._save_training_state( + checkpoint_dir, + adapter_name=adapter_name, + consumed_train_samples=kwargs.get('consumed_train_samples', 0), + ) return checkpoint_dir @@ -881,6 +888,33 @@ def _save_optimizer(self, output_dir, **kwargs): if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + def _save_training_state(self, output_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if not Platform.is_master(): + return + + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + } + with open(os.path.join(output_dir, 'trainer_state.json'), 'w', encoding='utf-8') as f: + json.dump(trainer_state, f) + + if optimizer_config.scaler is not None: + torch.save( + { + 'scaler_state_dict': optimizer_config.scaler.state_dict(), + 'scaler_has_nan': optimizer_config.scaler_has_nan, + }, + os.path.join(output_dir, 'scaler.pt'), + ) + + torch.save(self._get_training_rng_state(), os.path.join(output_dir, 'rng_state.pt')) + def _save_tokenizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] @@ -946,20 +980,106 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): def _load_optimizer(self, checkpoint_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + strict = kwargs.pop('strict', False) # assume optimizer and lr_scheduler are created optimizer_config = self.optimizer_group[adapter_name] optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt') scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt') + if strict and not os.path.exists(optimizer_path): + raise FileNotFoundError(optimizer_path) + if strict and not os.path.exists(scheduler_path): + raise FileNotFoundError(scheduler_path) + if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: - state_dict = torch.load(optimizer_path, map_location='cpu') + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: - state_dict = torch.load(scheduler_path, map_location='cpu') + state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False) optimizer_config.lr_scheduler.load_state_dict(state_dict) + def _load_scaler_state(self, scaler_path, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + if optimizer_config.scaler is None: + raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}') + + scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=False) + optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict']) + optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False) + + def _get_training_rng_state(self): + state = { + 'python_rng_state': random.getstate(), + 
'numpy_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + } + if hasattr(torch, 'npu') and torch.npu.is_available(): + state['device_type'] = 'npu' + state['device_rng_state'] = torch.npu.get_rng_state() + elif torch.cuda.is_available(): + state['device_type'] = 'cuda' + state['device_rng_state'] = torch.cuda.get_rng_state_all() + else: + state['device_type'] = 'cpu' + state['device_rng_state'] = None + return state + + def _load_rng_state(self, rng_path): + rng_state = torch.load(rng_path, map_location='cpu', weights_only=False) + random.setstate(rng_state['python_rng_state']) + np.random.set_state(rng_state['numpy_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + + device_type = rng_state.get('device_type') + device_rng_state = rng_state.get('device_rng_state') + if device_type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available() and device_rng_state is not None: + torch.npu.set_rng_state(device_rng_state) + elif device_type == 'cuda' and torch.cuda.is_available() and device_rng_state is not None: + torch.cuda.set_rng_state_all(device_rng_state) + + @remote_function() + def read_training_progress(self, checkpoint_dir, **kwargs): + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if not os.path.exists(trainer_state_path): + raise FileNotFoundError(trainer_state_path) + + with open(trainer_state_path, 'r', encoding='utf-8') as f: + trainer_state = json.load(f) + + required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'} + missing_keys = required_keys - trainer_state.keys() + if missing_keys: + raise ValueError(f'Missing trainer_state keys: {sorted(missing_keys)}') + return trainer_state + + @remote_function() + def load_training_state(self, checkpoint_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + required_paths = { + 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), + 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), + 'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'), + 'scaler': os.path.join(checkpoint_dir, 'scaler.pt'), + 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), + } + for path in required_paths.values(): + if not os.path.exists(path): + raise FileNotFoundError(path) + + trainer_state = self.read_training_progress(checkpoint_dir) + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True) + self._load_scaler_state(required_paths['scaler'], adapter_name=adapter_name) + self._load_rng_state(required_paths['rng']) + + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + return trainer_state + @remote_function(collect='first') def get_state_dict(self, **kwargs): return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group())) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 535a3bd7..259c50c4 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -10,6 +10,8 @@ import torch import traceback +import os +from pathlib import Path from fastapi import Depends, FastAPI, HTTPException, Request from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable @@ -347,6 +349,55 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, 
task_type='load')) + @app.post('/twinkle/load_training_state') + async def load_training_state( + request: Request, + body: types.LoadTrainingStateRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + self.model.load_training_state( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + + await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) + + @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) + async def read_training_progress( + request: Request, + body: types.ReadTrainingProgressRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.TrainingProgressResponse: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + ret = self.model.read_training_progress( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='read_training_progress')) + @app.post('/twinkle/upload_to_hub') async def upload_to_hub( request: Request, diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 743125d9..2a8afdce 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -19,6 +19,7 @@ GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -188,6 +189,23 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> None: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 00b1f967..0e5d37e1 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -23,11 +23,13 @@ GetStateDictRequest, 
GetStateDictResponse, GetTrainConfigsResponse, + LoadTrainingStateRequest, LoadRequest, LoadResponse, LrStepResponse, ModelResult, OkResponse, + ReadTrainingProgressRequest, SaveRequest, SaveResponse, SetLossRequest, @@ -41,6 +43,7 @@ SetTemplateRequest, SetTemplateResponse, StepResponse, + TrainingProgressResponse, UploadToHubRequest, UploadToHubResponse, ZeroGradResponse, diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index e594bae4..0b6d7c08 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -89,6 +89,22 @@ class Config: extra = 'allow' +class LoadTrainingStateRequest(BaseModel): + adapter_name: str + name: str + + class Config: + extra = 'allow' + + +class ReadTrainingProgressRequest(BaseModel): + adapter_name: str + name: str + + class Config: + extra = 'allow' + + class AddAdapterRequest(BaseModel): adapter_name: str config: str @@ -212,6 +228,11 @@ class SaveResponse(BaseModel): checkpoint_dir: Optional[str] = None +class TrainingProgressResponse(BaseModel): + """Response for /read_training_progress endpoint (returns progress metadata).""" + result: Dict[str, Any] + + # --- Void responses (return None → OkResponse) --- class BackwardResponse(OkResponse): diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 79bf78ad..232c6fe6 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,14 +1,29 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import concurrent.futures import numpy as np import os import pytest from pathlib import Path +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import IterableDataset as TorchIterableDataset + + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor import twinkle from twinkle import DeviceMesh from twinkle.data_format import Message from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta +from twinkle.dataset import Dataset, DatasetMeta, IterableDataset from twinkle.processor import InputProcessor twinkle.initialize(mode='local') @@ -22,6 +37,36 @@ def convert_to_messages(example): return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]} +def _build_resume_rows(): + return [ + {'text': 'Hello world'}, + {'text': 'Test data'}, + {'text': 'Another example'}, + {'text': 'Sample text'}, + ] + + +class _InMemoryDataset(TorchDataset): + + def __init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + +class _InMemoryIterableDataset(TorchIterableDataset): + + def __init__(self, rows): + self.rows = rows + + def __iter__(self): + return iter(self.rows) + + class TestDataLoaderBasic: def test_dataloader_basic(self): @@ -157,3 +202,25 @@ def test_retry_sampler_length(self): total_samples = sum(len(batch) for batch in dataloader) assert total_samples == original_len + + +class TestResumeSkip: + + def test_dataloader_skip_consumed_samples_for_map_style_dataset(self): + dataset = _InMemoryDataset(_build_resume_rows()) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + + dataloader.skip_consumed_samples(2) + batches = list(dataloader) + + texts = [item['text'] for 
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
index 79bf78ad..232c6fe6 100644
--- a/tests/dataloader/test_dataloader.py
+++ b/tests/dataloader/test_dataloader.py
@@ -1,14 +1,29 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import concurrent.futures
 import numpy as np
 import os
 import pytest
 from pathlib import Path
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import IterableDataset as TorchIterableDataset
+
+
+class _NoOpProcessPoolExecutor:
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def submit(self, fn, *args, **kwargs):
+        raise RuntimeError('Process pool is disabled in this test environment.')
+
+
+concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor

 import twinkle
 from twinkle import DeviceMesh
 from twinkle.data_format import Message
 from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.dataset import Dataset, DatasetMeta, IterableDataset
 from twinkle.processor import InputProcessor

 twinkle.initialize(mode='local')
@@ -22,6 +37,36 @@ def convert_to_messages(example):
     return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}


+def _build_resume_rows():
+    return [
+        {'text': 'Hello world'},
+        {'text': 'Test data'},
+        {'text': 'Another example'},
+        {'text': 'Sample text'},
+    ]
+
+
+class _InMemoryDataset(TorchDataset):
+
+    def __init__(self, rows):
+        self.rows = rows
+
+    def __len__(self):
+        return len(self.rows)
+
+    def __getitem__(self, idx):
+        return self.rows[idx]
+
+
+class _InMemoryIterableDataset(TorchIterableDataset):
+
+    def __init__(self, rows):
+        self.rows = rows
+
+    def __iter__(self):
+        return iter(self.rows)
+
+
 class TestDataLoaderBasic:

     def test_dataloader_basic(self):
@@ -157,3 +202,25 @@ def test_retry_sampler_length(self):
         total_samples = sum(len(batch) for batch in dataloader)

         assert total_samples == original_len
+
+
+class TestResumeSkip:
+
+    def test_dataloader_skip_consumed_samples_for_map_style_dataset(self):
+        dataset = _InMemoryDataset(_build_resume_rows())
+        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        batches = list(dataloader)
+
+        texts = [item['text'] for batch in batches for item in batch]
+        assert texts[0] == 'Another example'
+
+    def test_dataloader_warns_when_skip_requested_for_iterable_dataset(self, recwarn):
+        dataset = _InMemoryIterableDataset(_build_resume_rows())
+        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        next(iter(dataloader))
+
+        assert 'does not support consumed-data skipping' in str(recwarn.pop(UserWarning).message)
diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py
index b8438207..1dd7d7ca 100644
--- a/tests/dataloader/test_sampler.py
+++ b/tests/dataloader/test_sampler.py
@@ -1,10 +1,26 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import concurrent.futures
 import os
+import numpy as np
 import pytest
 from pathlib import Path
 from torch.utils.data import RandomSampler, SequentialSampler
+from torch.utils.data import Dataset as TorchDataset
+
+
+class _NoOpProcessPoolExecutor:
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def submit(self, fn, *args, **kwargs):
+        raise RuntimeError('Process pool is disabled in this test environment.')
+
+
+concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor

 import twinkle
+from twinkle import DeviceMesh
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta

@@ -162,3 +178,33 @@ def test_sequential_vs_random_order(self):
         different = seq_texts != rand_texts

         assert different or len(seq_texts) == 1
+
+
+class TestResumeSkipSamplerOrdering:
+
+    def test_sequential_sampler_skip_happens_before_device_mesh_slice(self):
+        class _InMemoryDataset(TorchDataset):
+
+            def __init__(self, rows):
+                self.rows = rows
+
+            def __len__(self):
+                return len(self.rows)
+
+            def __getitem__(self, idx):
+                return self.rows[idx]
+
+        dataset = _InMemoryDataset([
+            {'text': 'Hello world'},
+            {'text': 'Test data'},
+            {'text': 'Another example'},
+            {'text': 'Sample text'},
+        ])
+        sampler = SequentialSampler(dataset)
+        device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', ))
+        dataloader = DataLoader(dataset=dataset, batch_size=4, sampler=sampler, device_mesh=device_mesh, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        first_batch = list(dataloader)[0]
+
+        assert first_batch[0]['text'] == 'Another example'
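Together these tests pin down the skipping contract: for a map-style dataset, `skip_consumed_samples(n)` drops the first `n` sampler indices globally, before any data-parallel slicing, while iterable datasets only warn. An illustrative sketch of that ordering (assumptions only, not the twinkle implementation):

```python
# Sketch only: skip-then-slice ordering for a map-style dataset.
from itertools import islice

def indices_for_rank(sampler, consumed: int, dp_rank: int, dp_world_size: int):
    """Yield the indices one dp rank sees after skipping `consumed` samples."""
    remaining = islice(iter(sampler), consumed, None)  # global skip happens first
    for position, index in enumerate(remaining):
        if position % dp_world_size == dp_rank:        # then slice per dp rank
            yield index

# 4 samples, consumed=2, sequential order: rank 0 of a 2-rank mesh sees index 2
# first, matching `first_batch[0]['text'] == 'Another example'` above.
assert list(indices_for_rank(range(4), consumed=2, dp_rank=0, dp_world_size=2)) == [2]
```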
diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py
new file mode 100644
index 00000000..d616156a
--- /dev/null
+++ b/tests/model/transformers/test_checkpoint_resume.py
@@ -0,0 +1,301 @@
+import concurrent.futures
+import sys
+import types
+import uuid
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+from peft import LoraConfig
+from tokenizers import Tokenizer
+from tokenizers.models import WordLevel
+from tokenizers.pre_tokenizers import Whitespace
+from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast
+
+
+class _NoOpProcessPoolExecutor:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def submit(self, fn, *args, **kwargs):
+        raise RuntimeError('Process pool is disabled in this test environment.')
+
+
+concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor
+
+if 'zmq' not in sys.modules:
+    zmq_stub = types.ModuleType('zmq')
+
+    class _ZmqError:
+        class Again(Exception):
+            pass
+
+    class _ZmqSocket:
+        def setsockopt(self, *args, **kwargs):
+            pass
+
+        def setsockopt_string(self, *args, **kwargs):
+            pass
+
+        def bind(self, *args, **kwargs):
+            pass
+
+        def connect(self, *args, **kwargs):
+            pass
+
+        def send_string(self, *args, **kwargs):
+            pass
+
+        def send_pyobj(self, *args, **kwargs):
+            pass
+
+        def recv_string(self, *args, **kwargs):
+            return ''
+
+        def recv_pyobj(self, *args, **kwargs):
+            return None
+
+    class _ZmqContext:
+        def socket(self, *args, **kwargs):
+            return _ZmqSocket()
+
+    zmq_stub.Context = _ZmqContext
+    zmq_stub.Socket = _ZmqSocket
+    zmq_stub.REQ = 0
+    zmq_stub.REP = 1
+    zmq_stub.PUB = 2
+    zmq_stub.SUB = 3
+    zmq_stub.SNDMORE = 4
+    zmq_stub.IPV6 = 5
+    zmq_stub.SUBSCRIBE = 6
+    zmq_stub.RCVTIMEO = 7
+    zmq_stub.SNDTIMEO = 8
+    zmq_stub.LINGER = 9
+    zmq_stub.error = _ZmqError
+    sys.modules['zmq'] = zmq_stub
+
+from twinkle import DeviceMesh
+from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel
+
+
+def build_tiny_tokenizer():
+    vocab = {'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, 'hello': 4, 'world': 5}
+    backend = Tokenizer(WordLevel(vocab=vocab, unk_token='[UNK]'))
+    backend.pre_tokenizer = Whitespace()
+    tokenizer = PreTrainedTokenizerFast(
+        tokenizer_object=backend,
+        pad_token='[PAD]',
+        bos_token='[BOS]',
+        eos_token='[EOS]',
+        unk_token='[UNK]',
+    )
+    return tokenizer
+
+
+@pytest.fixture
+def tmp_path():
+    root = Path(r'C:\Users\weika\.codex\memories') / 'twinkle-tests'
+    root.mkdir(exist_ok=True)
+    path = root / uuid.uuid4().hex
+    path.mkdir()
+    return path
+
+
+@pytest.fixture
+def tiny_local_model_dir(tmp_path):
+    model_dir = tmp_path / 'tiny-gpt2'
+    model_dir.mkdir()
+
+    tokenizer = build_tiny_tokenizer()
+    tokenizer.save_pretrained(model_dir)
+
+    config = GPT2Config(
+        vocab_size=tokenizer.vocab_size,
+        n_layer=1,
+        n_head=1,
+        n_embd=16,
+        n_positions=32,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    GPT2LMHeadModel(config).save_pretrained(model_dir)
+    return model_dir
+
+
+def build_full_param_model(model_dir):
+    return TransformersModel(
+        model_cls='GPT2LMHeadModel',
+        model_id=str(model_dir),
+        mixed_precision='no',
+        grad_scaler_config={},
+    )
+
+
+def build_multi_lora_model(model_dir):
+    device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1)
+    model = MultiLoraTransformersModel(
+        model_cls='GPT2LMHeadModel',
+        model_id=str(model_dir),
+        mixed_precision='no',
+        device_mesh=device_mesh,
+        grad_scaler_config={},
+    )
+    model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['c_attn']))
+    return model
+
+
+def prepare_full_param_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6):
+    model = build_full_param_model(model_dir)
+    model.set_optimizer('AdamW', lr=1e-4)
+    model.set_lr_scheduler('LinearLR')
+    model.set_grad_scaler(device='cpu')
+    model.optimizer_group[''].cur_step = cur_step
+    return Path(
+        model.save(
+            name='full-resume',
+            output_dir=str(tmp_path),
+            save_optimizer=True,
+            consumed_train_samples=consumed_train_samples,
+        ))
+
+
+def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6):
+    model = build_multi_lora_model(model_dir)
+    model.set_optimizer('AdamW', adapter_name='default', lr=1e-4)
+    model.set_lr_scheduler('LinearLR', adapter_name='default')
+    model.set_grad_scaler(adapter_name='default', device='cpu')
+    model.optimizer_group['default'].cur_step = cur_step
+    return Path(
+        model.save(
+            name='lora-resume',
+            output_dir=str(tmp_path),
+            adapter_name='default',
+            save_optimizer=True,
+            consumed_train_samples=consumed_train_samples,
+        ))
+
+
+def run_resume_only_model_flow(model, dataloader, checkpoint_dir, ignore_data_skip):
+    if ignore_data_skip:
+        return None
+    progress = model.read_training_progress(str(checkpoint_dir))
+    dataloader.skip_consumed_samples(progress['consumed_train_samples'])
+    model.optimizer_group[''].cur_step = progress['cur_step']
+    model.optimizer_group[''].gradient_accumulation_steps = progress['gradient_accumulation_steps']
+    return progress
+
+
+def _ensure_file_exists(path: Path):
+    if path.exists():
+        return
+    if path.name == 'trainer_state.json':
+        path.write_text('{"cur_step": 3}', encoding='utf-8')
+    else:
+        path.write_bytes(b'placeholder')
+
+
+def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir):
+    model = build_full_param_model(tiny_local_model_dir)
+    model.set_optimizer('AdamW', lr=1e-4)
+    model.set_lr_scheduler('LinearLR')
+    model.set_grad_scaler(device='cpu')
+    model.optimizer_group[''].cur_step = 7
+
+    ckpt_dir = Path(model.save(name='resume-step', output_dir=str(tmp_path), save_optimizer=True))
+
+    assert (ckpt_dir / 'optimizer.pt').exists()
+    assert (ckpt_dir / 'scheduler.pt').exists()
+    assert (ckpt_dir / 'scaler.pt').exists()
+    assert (ckpt_dir / 'trainer_state.json').exists()
+    assert (ckpt_dir / 'rng_state.pt').exists()
+
+
+@pytest.mark.parametrize(
+    'missing_name, expected_pattern',
+    [
+        ('optimizer.pt', 'optimizer.pt'),
+        ('scheduler.pt', 'scheduler.pt'),
+        ('scaler.pt', 'scaler.pt'),
+        ('rng_state.pt', 'rng_state.pt'),
+        ('trainer_state.json', 'trainer_state.json'),
+    ],
+)
+def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name,
+                                                             expected_pattern):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir)
+    _ensure_file_exists(ckpt_dir / missing_name)
+    (ckpt_dir / missing_name).unlink()
+
+    restored = build_full_param_model(ckpt_dir)
+    restored.set_optimizer('AdamW', lr=1e-4)
+    restored.set_lr_scheduler('LinearLR')
+    restored.set_grad_scaler(device='cpu')
+
+    with pytest.raises(FileNotFoundError, match=expected_pattern):
+        restored.load_training_state(str(ckpt_dir))
+
+
+def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir)
+    (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8')
+
+    restored = build_full_param_model(ckpt_dir)
+    restored.set_optimizer('AdamW', lr=1e-4)
+    restored.set_lr_scheduler('LinearLR')
+    restored.set_grad_scaler(device='cpu')
+
+    with pytest.raises((KeyError, ValueError), match='gradient_accumulation_steps|consumed_train_samples'):
+        restored.load_training_state(str(ckpt_dir))
+
+
+def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18)
+    restored = build_full_param_model(ckpt_dir)
+    restored.set_optimizer('AdamW', lr=1e-4)
+    restored.set_lr_scheduler('LinearLR')
+    restored.set_grad_scaler(device='cpu')
+
+    trainer_state = restored.load_training_state(str(ckpt_dir))
+
+    assert trainer_state['cur_step'] == 9
+    assert restored.optimizer_group[''].cur_step == 9
+
+
+def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_lora_checkpoint(tmp_path, tiny_local_model_dir, cur_step=5, consumed_train_samples=10)
+    restored = build_multi_lora_model(tiny_local_model_dir)
+    restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default')
+    restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4)
+    restored.set_lr_scheduler('LinearLR', adapter_name='default')
+    restored.set_grad_scaler(adapter_name='default', device='cpu')
+
+    trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default')
+
+    assert trainer_state['cur_step'] == 5
+
+
+def test_read_training_progress_supports_resume_only_model(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12)
+    restored = build_full_param_model(ckpt_dir)
+    dataloader = Mock()
+
+    progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=False)
+
+    assert progress['cur_step'] == 6
+    assert progress['consumed_train_samples'] == 12
+    assert restored.optimizer_group[''].cur_step == 6
+    dataloader.skip_consumed_samples.assert_called_once_with(12)
+
+
+def test_resume_only_model_ignore_data_skip_leaves_progress_unrestored(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12)
+    restored = build_full_param_model(ckpt_dir)
+    dataloader = Mock()
+    restored.read_training_progress = Mock(wraps=restored.read_training_progress)
+
+    progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=True)
+
+    assert progress is None
+    assert restored.optimizer_group[''].cur_step == 0
+    restored.read_training_progress.assert_not_called()
+    dataloader.skip_consumed_samples.assert_not_called()
diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py
new file mode 100644
index 00000000..7932ca20
--- /dev/null
+++ b/tests/server/model/test_twinkle_resume_routes.py
@@ -0,0 +1,180 @@
+import concurrent.futures
+import importlib.util
+import sys
+import types
+from pathlib import Path
+from unittest.mock import Mock
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+
+class _NoOpProcessPoolExecutor:
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def submit(self, fn, *args, **kwargs):
+        raise RuntimeError('Process pool is disabled in this test environment.')
+
+
+concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor
+
+if 'tinker' not in sys.modules:
+    tinker_module = types.ModuleType('tinker')
+    tinker_types_module = types.ModuleType('tinker.types')
+
+    class _TinkerPlaceholder:
+        pass
+
+    for name in (
+            'CreateModelRequest',
+            'TrainingRun',
+            'TrainingRunsResponse',
+            'Cursor',
+            'Checkpoint',
+            'CheckpointsListResponse',
+            'ParsedCheckpointTinkerPath',
+            'WeightsInfoResponse',
+    ):
+        setattr(tinker_types_module, name, _TinkerPlaceholder)
+    tinker_module.types = tinker_types_module
+    sys.modules['tinker'] = tinker_module
+    sys.modules['tinker.types'] = tinker_types_module
+
+if 'twinkle.server.common' not in sys.modules:
+    common_module = types.ModuleType('twinkle.server.common')
+    checkpoint_factory_module = types.ModuleType('twinkle.server.common.checkpoint_factory')
+    checkpoint_factory_module.create_checkpoint_manager = lambda token, client_type='twinkle': None
+    checkpoint_factory_module.create_training_run_manager = lambda token, client_type='twinkle': None
+    common_module.checkpoint_factory = checkpoint_factory_module
+    sys.modules['twinkle.server.common'] = common_module
+    sys.modules['twinkle.server.common.checkpoint_factory'] = checkpoint_factory_module
+
+from twinkle_client.types.checkpoint import ResolvedLoadPath
+
+_HANDLERS_PATH = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py'
+_HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH)
+handlers = importlib.util.module_from_spec(_HANDLERS_SPEC)
+sys.modules[_HANDLERS_SPEC.name] = handlers
+_HANDLERS_SPEC.loader.exec_module(handlers)
+
+
+class _FakeCheckpointManager:
+
+    def resolve_load_path(self, path: str) -> ResolvedLoadPath:
+        return ResolvedLoadPath(
+            checkpoint_name='ckpt-1',
+            checkpoint_dir='D:/resolved/weights',
+            is_twinkle_path=True,
+            training_run_id='run-1',
+            checkpoint_id='weights/ckpt-1',
+        )
+
+
+class _FakeModelManagement:
+
+    def __init__(self):
+        self.model = Mock()
+
+    async def _on_request_start(self, request):
+        request.state.request_id = 'req-1'
+        return 'token-1'
+
+    def assert_resource_exists(self, adapter_name):
+        return None
+
+    async def schedule_task_and_wait(self, task, task_type=''):
+        return await task()
+
+
+def _build_test_client(monkeypatch):
+    management = _FakeModelManagement()
+    checkpoint_manager = _FakeCheckpointManager()
+    monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager)
+
+    app = FastAPI()
+    handlers._register_twinkle_routes(app, lambda: management)
+    return TestClient(app), management
+
+
+def _run_remote_resume_case(client: TestClient, *, resume_from_checkpoint, resume_only_model, ignore_data_skip):
+    if resume_from_checkpoint is None:
+        return None
+    if not resume_only_model:
+        return client.post('/twinkle/load_training_state', json={'name': resume_from_checkpoint, 'adapter_name': ''})
+    if not ignore_data_skip:
+        return client.post('/twinkle/read_training_progress', json={'name': resume_from_checkpoint, 'adapter_name': ''})
+    return None
+
+
+def test_case_1_no_resume_call_leaves_remote_resume_helpers_unused(monkeypatch):
+    """Case 1: resume_from_checkpoint is None, so no remote resume helper should be called."""
+    client, management = _build_test_client(monkeypatch)
+
+    response = _run_remote_resume_case(
+        client,
+        resume_from_checkpoint=None,
+        resume_only_model=False,
+        ignore_data_skip=False,
+    )
+
+    assert response is None
+    management.model.load_training_state.assert_not_called()
+    management.model.read_training_progress.assert_not_called()
+
+
+def test_case_2_resume_only_model_false_uses_load_training_state_route(monkeypatch):
+    """Case 2: resume_only_model=False should use load_training_state()."""
+    client, management = _build_test_client(monkeypatch)
+
+    response = _run_remote_resume_case(
+        client,
+        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
+        resume_only_model=False,
+        ignore_data_skip=False,
+    )
+
+    assert response.status_code == 200
+    management.model.load_training_state.assert_called_once_with(
+        'D:/resolved/weights/ckpt-1',
+        adapter_name=None,
+    )
+    management.model.read_training_progress.assert_not_called()
+
+
+def test_case_3_resume_only_model_true_without_ignore_data_skip_reads_progress_only(monkeypatch):
+    """Case 3: resume_only_model=True and ignore_data_skip=False should use read_training_progress() only."""
+    client, management = _build_test_client(monkeypatch)
+    management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12}
+
+    response = _run_remote_resume_case(
+        client,
+        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
+        resume_only_model=True,
+        ignore_data_skip=False,
+    )
+
+    assert response.status_code == 200
+    assert response.json()['result']['consumed_train_samples'] == 12
+    management.model.read_training_progress.assert_called_once_with(
+        'D:/resolved/weights/ckpt-1',
+        adapter_name=None,
+    )
+    management.model.load_training_state.assert_not_called()
+
+
+def test_case_4_resume_only_model_true_with_ignore_data_skip_uses_neither_helper(monkeypatch):
+    """Case 4: resume_only_model=True and ignore_data_skip=True should call neither remote helper."""
+    client, management = _build_test_client(monkeypatch)
+
+    response = _run_remote_resume_case(
+        client,
+        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
+        resume_only_model=True,
+        ignore_data_skip=True,
+    )
+
+    assert response is None
+    management.model.load_training_state.assert_not_called()
+    management.model.read_training_progress.assert_not_called()
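The four route tests above encode the behavior matrix for the three user-facing resume knobs. The dispatch they exercise, condensed into a stand-alone helper (the function name is illustrative, not part of the patch):

```python
# Sketch only: which remote resume helper each knob combination should trigger.
def resolve_resume_action(resume_from_checkpoint, resume_only_model, ignore_data_skip):
    if resume_from_checkpoint is None:
        return None                       # Case 1: fresh run, no helper called
    if not resume_only_model:
        return 'load_training_state'      # Case 2: strict full-state resume
    if not ignore_data_skip:
        return 'read_training_progress'   # Case 3: weights plus data skipping
    return None                           # Case 4: weights only, restart at step 0

assert resolve_resume_action(None, False, False) is None
assert resolve_resume_action('ckpt', False, False) == 'load_training_state'
assert resolve_resume_action('ckpt', True, False) == 'read_training_progress'
assert resolve_resume_action('ckpt', True, True) is None
```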
From d41a634f3e02d8cb60837c79b24894f5283ac316 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 11:51:04 +0800
Subject: [PATCH 07/28] wip

---
 .gitignore                                    |   1 +
 .../transformers/test_checkpoint_resume.py    | 166 ++++++++----------
 .../model/test_twinkle_resume_routes.py       | 135 ++++++--------
 3 files changed, 123 insertions(+), 179 deletions(-)

diff --git a/.gitignore b/.gitignore
index afdfcae9..ae4d2cb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,6 +29,7 @@ wheels/
 /temp
 MANIFEST
 .locks/
+tmp_test_checkpoints/

 # PyInstaller
 #  Usually these files are written by a python script from a template
diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py
index d616156a..f6b5b6a1 100644
--- a/tests/model/transformers/test_checkpoint_resume.py
+++ b/tests/model/transformers/test_checkpoint_resume.py
@@ -3,14 +3,14 @@
 import types
 import uuid
 from pathlib import Path
-from unittest.mock import Mock
+from types import ModuleType

 import pytest
 from peft import LoraConfig
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
-from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast
+from transformers import Qwen3Config, Qwen3ForCausalLM, PreTrainedTokenizerFast


 class _NoOpProcessPoolExecutor:
@@ -23,58 +23,21 @@ def submit(self, fn, *args, **kwargs):

 concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor

-if 'zmq' not in sys.modules:
-    zmq_stub = types.ModuleType('zmq')
-
-    class _ZmqError:
-        class Again(Exception):
-            pass
-
-    class _ZmqSocket:
-        def setsockopt(self, *args, **kwargs):
-            pass
-
-        def setsockopt_string(self, *args, **kwargs):
-            pass
-
-        def bind(self, *args, **kwargs):
-            pass
-
-        def connect(self, *args, **kwargs):
-            pass
-
-        def send_string(self, *args, **kwargs):
-            pass
-
-        def send_pyobj(self, *args, **kwargs):
-            pass
+ROOT = Path(__file__).resolve().parents[3]
+SRC = ROOT / 'src'
+if str(SRC) not in sys.path:
+    sys.path.insert(0, str(SRC))

-        def recv_string(self, *args, **kwargs):
-            return ''
-
-        def recv_pyobj(self, *args, **kwargs):
-            return None
-
-    class _ZmqContext:
-        def socket(self, *args, **kwargs):
-            return _ZmqSocket()
-
-    zmq_stub.Context = _ZmqContext
-    zmq_stub.Socket = _ZmqSocket
-    zmq_stub.REQ = 0
-    zmq_stub.REP = 1
-    zmq_stub.PUB = 2
-    zmq_stub.SUB = 3
-    zmq_stub.SNDMORE = 4
-    zmq_stub.IPV6 = 5
-    zmq_stub.SUBSCRIBE = 6
-    zmq_stub.RCVTIMEO = 7
-    zmq_stub.SNDTIMEO = 8
-    zmq_stub.LINGER = 9
-    zmq_stub.error = _ZmqError
-    sys.modules['zmq'] = zmq_stub
+if 'zmq' not in sys.modules:
+    fake_zmq = ModuleType('zmq')
+    fake_zmq.Socket = object
+    fake_zmq.Context = object
+    fake_zmq.REP = 0
+    fake_zmq.REQ = 1
+    fake_zmq.IPV6 = 2
+    fake_zmq.error = types.SimpleNamespace(Again=RuntimeError)
+    sys.modules['zmq'] = fake_zmq

-from twinkle import DeviceMesh
 from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel
@@ -94,7 +57,7 @@ def build_tiny_tokenizer():

 @pytest.fixture
 def tmp_path():
-    root = Path(r'C:\Users\weika\.codex\memories') / 'twinkle-tests'
+    root = Path(__file__).parent / 'tmp_test_checkpoints'
     root.mkdir(exist_ok=True)
     path = root / uuid.uuid4().hex
     path.mkdir()
@@ -103,28 +66,28 @@ def tmp_path():

 @pytest.fixture
 def tiny_local_model_dir(tmp_path):
-    model_dir = tmp_path / 'tiny-gpt2'
+    model_dir = tmp_path / 'tiny-qwen3'
     model_dir.mkdir()

     tokenizer = build_tiny_tokenizer()
     tokenizer.save_pretrained(model_dir)

-    config = GPT2Config(
+    config = Qwen3Config(
         vocab_size=tokenizer.vocab_size,
-        n_layer=1,
-        n_head=1,
-        n_embd=16,
-        n_positions=32,
+        num_hidden_layers=1,
+        num_attention_heads=1,
+        hidden_size=16,
+        max_position_embeddings=32,
         bos_token_id=tokenizer.bos_token_id,
         eos_token_id=tokenizer.eos_token_id,
     )
-    GPT2LMHeadModel(config).save_pretrained(model_dir)
+    Qwen3ForCausalLM(config).save_pretrained(model_dir)
     return model_dir


 def build_full_param_model(model_dir):
     return TransformersModel(
-        model_cls='GPT2LMHeadModel',
+        model_cls='Qwen3ForCausalLM',
         model_id=str(model_dir),
         mixed_precision='no',
         grad_scaler_config={},
@@ -134,13 +97,13 @@ def build_full_param_model(model_dir):
 def build_multi_lora_model(model_dir):
     device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1)
     model = MultiLoraTransformersModel(
-        model_cls='GPT2LMHeadModel',
+        model_cls='Qwen3ForCausalLM',
         model_id=str(model_dir),
         mixed_precision='no',
         device_mesh=device_mesh,
         grad_scaler_config={},
     )
-    model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['c_attn']))
+    model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']))
     return model

@@ -175,16 +138,6 @@ def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samp
         ))


-def run_resume_only_model_flow(model, dataloader, checkpoint_dir, ignore_data_skip):
-    if ignore_data_skip:
-        return None
-    progress = model.read_training_progress(str(checkpoint_dir))
-    dataloader.skip_consumed_samples(progress['consumed_train_samples'])
-    model.optimizer_group[''].cur_step = progress['cur_step']
-    model.optimizer_group[''].gradient_accumulation_steps = progress['gradient_accumulation_steps']
-    return progress
-
-
 def _ensure_file_exists(path: Path):
     if path.exists():
         return
@@ -250,6 +203,19 @@ def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny
     ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18)
+    saved = build_full_param_model(tiny_local_model_dir)
+    saved.set_optimizer('AdamW', lr=1e-4)
+    saved.set_lr_scheduler('LinearLR')
+    saved.set_grad_scaler(device='cpu')
+    saved.optimizer_group[''].cur_step = 9
+    saved.optimizer_group[''].gradient_accumulation_steps = 4
+    ckpt_dir = Path(
+        saved.save(
+            name='full-resume',
+            output_dir=str(tmp_path),
+            save_optimizer=True,
+            consumed_train_samples=18,
+        ))
     restored = build_full_param_model(ckpt_dir)
     restored.set_optimizer('AdamW', lr=1e-4)
     restored.set_lr_scheduler('LinearLR')
@@ -258,13 +224,36 @@ def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny
     trainer_state = restored.load_training_state(str(ckpt_dir))

     assert trainer_state['cur_step'] == 9
+    assert trainer_state['gradient_accumulation_steps'] == 4
+    assert trainer_state['consumed_train_samples'] == 18
     assert restored.optimizer_group[''].cur_step == 9
+    assert restored.optimizer_group[''].gradient_accumulation_steps == 4
+
+
+def test_lora_load_does_not_restore_training_state_without_explicit_resume(tmp_path, tiny_local_model_dir):
+    saved = build_multi_lora_model(tiny_local_model_dir)
+    saved.set_optimizer('AdamW', adapter_name='default', lr=1e-4)
+    saved.set_lr_scheduler('LinearLR', adapter_name='default')
+    saved.set_grad_scaler(adapter_name='default', device='cpu')
+    saved.optimizer_group['default'].cur_step = 5
+    saved.optimizer_group['default'].gradient_accumulation_steps = 3
+    ckpt_dir = Path(
+        saved.save(
+            name='lora-resume',
+            output_dir=str(tmp_path),
+            adapter_name='default',
+            save_optimizer=True,
+            consumed_train_samples=10,
+        ))
+    restored = build_multi_lora_model(tiny_local_model_dir)
+    assert restored.optimizer_group['default'].cur_step == 0
+    assert restored.optimizer_group['default'].gradient_accumulation_steps == 1

-def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, tiny_local_model_dir):
-    ckpt_dir = prepare_lora_checkpoint(tmp_path, tiny_local_model_dir, cur_step=5, consumed_train_samples=10)
-    restored = build_multi_lora_model(tiny_local_model_dir)
     restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default')
+    assert restored.optimizer_group['default'].cur_step == 0
+    assert restored.optimizer_group['default'].gradient_accumulation_steps == 1
+
     restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4)
     restored.set_lr_scheduler('LinearLR', adapter_name='default')
     restored.set_grad_scaler(adapter_name='default', device='cpu')
@@ -272,30 +261,19 @@ def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, t
     trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default')

     assert trainer_state['cur_step'] == 5
+    assert trainer_state['gradient_accumulation_steps'] == 3
+    assert trainer_state['consumed_train_samples'] == 10
+    assert restored.optimizer_group['default'].cur_step == 5
+    assert restored.optimizer_group['default'].gradient_accumulation_steps == 3


-def test_read_training_progress_supports_resume_only_model(tmp_path, tiny_local_model_dir):
+def test_read_training_progress_returns_metadata_without_mutating_optimizer_state(tmp_path, tiny_local_model_dir):
     ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12)
     restored = build_full_param_model(ckpt_dir)
-    dataloader = Mock()

-    progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=False)
+    progress = restored.read_training_progress(str(ckpt_dir))

     assert progress['cur_step'] == 6
     assert progress['consumed_train_samples'] == 12
-    assert restored.optimizer_group[''].cur_step == 6
-    dataloader.skip_consumed_samples.assert_called_once_with(12)
-
-
-def test_resume_only_model_ignore_data_skip_leaves_progress_unrestored(tmp_path, tiny_local_model_dir):
-    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12)
-    restored = build_full_param_model(ckpt_dir)
-    dataloader = Mock()
-    restored.read_training_progress = Mock(wraps=restored.read_training_progress)
-
-    progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=True)
-
-    assert progress is None
     assert restored.optimizer_group[''].cur_step == 0
-    restored.read_training_progress.assert_not_called()
-    dataloader.skip_consumed_samples.assert_not_called()
+    assert restored.optimizer_group[''].gradient_accumulation_steps == 1
diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py
index 7932ca20..f3523913 100644
--- a/tests/server/model/test_twinkle_resume_routes.py
+++ b/tests/server/model/test_twinkle_resume_routes.py
@@ -1,8 +1,8 @@
 import concurrent.futures
 import importlib.util
 import sys
-import types
 from pathlib import Path
+from types import ModuleType
 from unittest.mock import Mock

 from fastapi import FastAPI
 from fastapi.testclient import TestClient
@@ -20,40 +20,20 @@ def submit(self, fn, *args, **kwargs):

 concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor

-if 'tinker' not in sys.modules:
-    tinker_module = types.ModuleType('tinker')
-    tinker_types_module = types.ModuleType('tinker.types')
+ROOT = Path(__file__).resolve().parents[3]
+SRC = ROOT / 'src'
+if str(SRC) not in sys.path:
+    sys.path.insert(0, str(SRC))

-    class _TinkerPlaceholder:
-        pass
-
-    for name in (
-            'CreateModelRequest',
-            'TrainingRun',
-            'TrainingRunsResponse',
-            'Cursor',
-            'Checkpoint',
-            'CheckpointsListResponse',
-            'ParsedCheckpointTinkerPath',
-            'WeightsInfoResponse',
-    ):
-        setattr(tinker_types_module, name, _TinkerPlaceholder)
-    tinker_module.types = tinker_types_module
-    sys.modules['tinker'] = tinker_module
-    sys.modules['tinker.types'] = tinker_types_module
-
-if 'twinkle.server.common' not in sys.modules:
-    common_module = types.ModuleType('twinkle.server.common')
-    checkpoint_factory_module = types.ModuleType('twinkle.server.common.checkpoint_factory')
-    checkpoint_factory_module.create_checkpoint_manager = lambda token, client_type='twinkle': None
-    checkpoint_factory_module.create_training_run_manager = lambda token, client_type='twinkle': None
-    common_module.checkpoint_factory = checkpoint_factory_module
-    sys.modules['twinkle.server.common'] = common_module
-    sys.modules['twinkle.server.common.checkpoint_factory'] = checkpoint_factory_module
+if 'twinkle.server.common.checkpoint_factory' not in sys.modules:
+    fake_checkpoint_factory = ModuleType('twinkle.server.common.checkpoint_factory')
+    fake_checkpoint_factory.create_checkpoint_manager = lambda *args, **kwargs: None
+    fake_checkpoint_factory.create_training_run_manager = lambda *args, **kwargs: None
+    sys.modules['twinkle.server.common.checkpoint_factory'] = fake_checkpoint_factory

 from twinkle_client.types.checkpoint import ResolvedLoadPath

-_HANDLERS_PATH = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py'
+_HANDLERS_PATH = SRC / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py'
 _HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH)
 handlers = importlib.util.module_from_spec(_HANDLERS_SPEC)
 sys.modules[_HANDLERS_SPEC.name] = handlers
@@ -62,10 +42,13 @@ class _TinkerPlaceholder:

 class _FakeCheckpointManager:

+    def __init__(self, checkpoint_dir='./resolved/weights'):
+        self._checkpoint_dir = checkpoint_dir
+
     def resolve_load_path(self, path: str) -> ResolvedLoadPath:
         return ResolvedLoadPath(
             checkpoint_name='ckpt-1',
-            checkpoint_dir='D:/resolved/weights',
+            checkpoint_dir=self._checkpoint_dir,
             is_twinkle_path=True,
             training_run_id='run-1',
             checkpoint_id='weights/ckpt-1',
@@ -88,9 +71,9 @@ async def schedule_task_and_wait(self, task, task_type=''):
         return await task()


-def _build_test_client(monkeypatch):
+def _build_test_client(monkeypatch, checkpoint_manager=None):
     management = _FakeModelManagement()
-    checkpoint_manager = _FakeCheckpointManager()
+    checkpoint_manager = checkpoint_manager or _FakeCheckpointManager()
     monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager)

     app = FastAPI()
@@ -98,83 +81,65 @@ def _build_test_client(monkeypatch):
     return TestClient(app), management


-def _run_remote_resume_case(client: TestClient, *, resume_from_checkpoint, resume_only_model, ignore_data_skip):
-    if resume_from_checkpoint is None:
-        return None
-    if not resume_only_model:
-        return client.post('/twinkle/load_training_state', json={'name': resume_from_checkpoint, 'adapter_name': ''})
-    if not ignore_data_skip:
-        return client.post('/twinkle/read_training_progress', json={'name': resume_from_checkpoint, 'adapter_name': ''})
-    return None
-
-
-def test_case_1_no_resume_call_leaves_remote_resume_helpers_unused(monkeypatch):
-    """Case 1: resume_from_checkpoint is None, so no remote resume helper should be called."""
+def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monkeypatch):
     client, management = _build_test_client(monkeypatch)

-    response = _run_remote_resume_case(
-        client,
-        resume_from_checkpoint=None,
-        resume_only_model=False,
-        ignore_data_skip=False,
+    response = client.post(
+        '/twinkle/load_training_state',
+        json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''},
     )

-    assert response is None
-    management.model.load_training_state.assert_not_called()
+    assert response.status_code == 200
+    management.model.load_training_state.assert_called_once_with(
+        'resolved/weights/ckpt-1',
+        adapter_name=None,
+    )
     management.model.read_training_progress.assert_not_called()


-def test_case_2_resume_only_model_false_uses_load_training_state_route(monkeypatch):
-    """Case 2: resume_only_model=False should use load_training_state()."""
+def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch):
     client, management = _build_test_client(monkeypatch)

-    response = _run_remote_resume_case(
-        client,
-        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
-        resume_only_model=False,
-        ignore_data_skip=False,
+    response = client.post(
+        '/twinkle/load_training_state',
+        json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': 'adapter-a'},
     )

     assert response.status_code == 200
     management.model.load_training_state.assert_called_once_with(
-        'D:/resolved/weights/ckpt-1',
-        adapter_name=None,
+        'resolved/weights/ckpt-1',
+        adapter_name='req-1-adapter-a',
     )
-    management.model.read_training_progress.assert_not_called()


-def test_case_3_resume_only_model_true_without_ignore_data_skip_reads_progress_only(monkeypatch):
-    """Case 3: resume_only_model=True and ignore_data_skip=False should use read_training_progress() only."""
-    client, management = _build_test_client(monkeypatch)
-    management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12}
+def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(monkeypatch):
+    client, management = _build_test_client(monkeypatch, checkpoint_manager=_FakeCheckpointManager(checkpoint_dir=None))

-    response = _run_remote_resume_case(
-        client,
-        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
-        resume_only_model=True,
-        ignore_data_skip=False,
+    response = client.post(
+        '/twinkle/load_training_state',
+        json={'name': 'local-checkpoint-dir', 'adapter_name': ''},
     )

     assert response.status_code == 200
+    management.model.load_training_state.assert_called_once_with(
+        'local-checkpoint-dir',
+        adapter_name=None,
     )
-    management.model.load_training_state.assert_not_called()


-def test_case_4_resume_only_model_true_with_ignore_data_skip_uses_neither_helper(monkeypatch):
-    """Case 4: resume_only_model=True and ignore_data_skip=True should call neither remote helper."""
+def test_read_training_progress_route_returns_progress_and_calls_model(monkeypatch):
     client, management = _build_test_client(monkeypatch)
+    management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12}

-    response = _run_remote_resume_case(
-        client,
-        resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1',
-        resume_only_model=True,
-        ignore_data_skip=True,
+    response = client.post(
+        '/twinkle/read_training_progress',
+        json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''},
     )

-    assert response is None
+    assert response.status_code == 200
+    assert response.json()['result']['consumed_train_samples'] == 12
+    management.model.read_training_progress.assert_called_once_with(
+        'resolved/weights/ckpt-1',
+        adapter_name=None,
+    )
     management.model.load_training_state.assert_not_called()
-    management.model.read_training_progress.assert_not_called()
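The reworked route tests assert two resolution rules on the server side: the resolved checkpoint directory is joined with the checkpoint name, falling back to the raw request name when resolution yields no directory, and non-empty adapter names are scoped with the request id. A condensed sketch of both rules (helper names are hypothetical; the real logic lives in `twinkle_handlers.py`):

```python
# Sketch only: the path and adapter-name resolution the tests above pin down.
from pathlib import Path
from typing import Optional

def resolve_checkpoint_dir(name: str, resolved_dir: Optional[str], resolved_name: str) -> str:
    if resolved_dir:
        return Path(resolved_dir, resolved_name).as_posix()
    return name  # no resolution: treat the request name as a local directory

def scoped_adapter_name(request_id: str, adapter_name: str) -> Optional[str]:
    return f'{request_id}-{adapter_name}' if adapter_name else None

assert resolve_checkpoint_dir('x', './resolved/weights', 'ckpt-1') == 'resolved/weights/ckpt-1'
assert resolve_checkpoint_dir('local-checkpoint-dir', None, 'ckpt-1') == 'local-checkpoint-dir'
assert scoped_adapter_name('req-1', 'adapter-a') == 'req-1-adapter-a'
assert scoped_adapter_name('req-1', '') is None
```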
From 21f9918c0d7764cf091e5db94770f62d534c9f96 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 15:29:27 +0800
Subject: [PATCH 08/28] wip

---
 src/twinkle/model/transformers/transformers.py   |  5 +++--
 .../model/transformers/test_checkpoint_resume.py | 16 +++++++++++++++-
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index d588b8e7..e2cd9206 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -1064,7 +1064,6 @@ def load_training_state(self, checkpoint_dir, **kwargs):
             'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'),
             'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'),
             'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'),
-            'scaler': os.path.join(checkpoint_dir, 'scaler.pt'),
             'rng': os.path.join(checkpoint_dir, 'rng_state.pt'),
         }
         for path in required_paths.values():
@@ -1073,7 +1072,9 @@ def load_training_state(self, checkpoint_dir, **kwargs):

         trainer_state = self.read_training_progress(checkpoint_dir)
         self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True)
-        self._load_scaler_state(required_paths['scaler'], adapter_name=adapter_name)
+        scaler_path = os.path.join(checkpoint_dir, 'scaler.pt')
+        if os.path.exists(scaler_path) and optimizer_config.scaler is not None:
+            self._load_scaler_state(scaler_path, adapter_name=adapter_name)
         self._load_rng_state(required_paths['rng'])

         optimizer_config.cur_step = trainer_state['cur_step']
diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py
index f6b5b6a1..7541cbaa 100644
--- a/tests/model/transformers/test_checkpoint_resume.py
+++ b/tests/model/transformers/test_checkpoint_resume.py
@@ -168,7 +168,6 @@ def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir):
     [
         ('optimizer.pt', 'optimizer.pt'),
         ('scheduler.pt', 'scheduler.pt'),
-        ('scaler.pt', 'scaler.pt'),
         ('rng_state.pt', 'rng_state.pt'),
         ('trainer_state.json', 'trainer_state.json'),
     ],
@@ -188,6 +187,21 @@ def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_loc
         restored.load_training_state(str(ckpt_dir))


+def test_load_training_state_allows_missing_scaler_file(tmp_path, tiny_local_model_dir):
+    ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=8, consumed_train_samples=16)
+    (ckpt_dir / 'scaler.pt').unlink()
+
+    restored = build_full_param_model(ckpt_dir)
+    restored.set_optimizer('AdamW', lr=1e-4)
+    restored.set_lr_scheduler('LinearLR')
+
+    trainer_state = restored.load_training_state(str(ckpt_dir))
+
+    assert trainer_state['cur_step'] == 8
+    assert trainer_state['consumed_train_samples'] == 16
+    assert restored.optimizer_group[''].cur_step == 8
+
+
 def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir):
     ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir)
     (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8')
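After this patch, scaler state is the one optional artifact in the strict-resume contract; everything else still fails fast, matching the Megatron-style semantics in the design spec. A small sketch of the resulting file contract, assuming a flat checkpoint directory:

```python
# Sketch only: required vs. optional files for strict resume after this patch.
import os

REQUIRED = ('trainer_state.json', 'optimizer.pt', 'scheduler.pt', 'rng_state.pt')
OPTIONAL = ('scaler.pt',)  # loaded only when present and a grad scaler is configured

def check_strict_resume_files(checkpoint_dir: str) -> None:
    for filename in REQUIRED:
        path = os.path.join(checkpoint_dir, filename)
        if not os.path.exists(path):
            raise FileNotFoundError(path)  # no silent weight-only fallback
```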
From 1e595317e71baca5e08ef8102c2087ea1d308421 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 15:51:40 +0800
Subject: [PATCH 09/28] fix

---
 src/twinkle/model/transformers/transformers.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index e2cd9206..1c4c57fd 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -148,6 +148,19 @@ def calculate_metrics(self, is_training):
         return results


+def _normalize_checkpoint_state(value: Any):
+    """Convert nested DTensor state into plain CPU tensors for checkpointing."""
+    if isinstance(value, dict):
+        return {k: _normalize_checkpoint_state(v) for k, v in value.items()}
+    if isinstance(value, list):
+        return [_normalize_checkpoint_state(v) for v in value]
+    if isinstance(value, tuple):
+        return tuple(_normalize_checkpoint_state(v) for v in value)
+    if torch.is_tensor(value):
+        return Torch.to_local_tensor(value).cpu()
+    return value
+
+
 _default_adapter_name = ''
 DEFAULT_LEARNING_RATE = 1e-5
 DEFAULT_WEIGHT_DECAY = 0.01
@@ -884,7 +897,8 @@ def _save_optimizer(self, output_dir, **kwargs):
             optimizer = optimizer_config.optimizer
             lr_scheduler = optimizer_config.lr_scheduler
             if optimizer is not None:
-                torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                state_dict = _normalize_checkpoint_state(optimizer.state_dict())
+                torch.save(state_dict, os.path.join(output_dir, 'optimizer.pt'))
             if lr_scheduler is not None:
                 torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))

From 9bb3f39e0ad7c0cdf61cc1c787336f97d78bccaa Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 16:13:05 +0800
Subject: [PATCH 10/28] wip

---
 .../model/transformers/strategy/accelerate.py | 49 +++++++++++++++++++
 .../model/transformers/transformers.py        | 36 ++++++--------
 2 files changed, 64 insertions(+), 21 deletions(-)

diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 89b497e2..214623fa 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -124,6 +124,55 @@ def wrap_model(self, model, *args):
     def unwrap_model(self, model):
         return self.accelerator.unwrap_model(model, keep_torch_compile=False)

+    def _prepare_fsdp2_sd_options(self):
+        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2:
+            return None
+
+        from torch.distributed.checkpoint.state_dict import StateDictOptions
+        from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
+        return StateDictOptions(
+            full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT,
+            cpu_offload=getattr(fsdp_plugin.state_dict_config, 'offload_to_cpu', False),
+            broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, 'rank0_only', False),
+        )
+
+    def needs_wrapped_optimizer_state(self) -> bool:
+        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2
+
+    def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
+        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
+            from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+            import torch
+
+            optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options())
+            if self.accelerator.process_index == 0:
+                torch.save(optim_state, output_path)
+            return
+
+        import torch
+        if self.accelerator.process_index == 0:
+            torch.save(optimizer.state_dict(), output_path)
+
+    def load_optimizer_checkpoint(self, model, optimizer, input_path: str):
+        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
+            from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
+            import torch
+
+            optim_state = None
+            rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False)
+            if self.accelerator.process_index == 0 or not rank0_only:
+                optim_state = torch.load(input_path, weights_only=True)
+            set_optimizer_state_dict(model, optimizer, optim_state, options=self._prepare_fsdp2_sd_options())
+            return
+
+        import torch
+        optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False))
+
     def get_full_state_dict(self, model) -> dict:
         """Collect full state dict."""
         from twinkle.utils import torch_util
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 1c4c57fd..40473420 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -147,20 +147,6 @@ def calculate_metrics(self, is_training):
         self.outputs = None
         return results

-
-def _normalize_checkpoint_state(value: Any):
-    """Convert nested DTensor state into plain CPU tensors for checkpointing."""
-    if isinstance(value, dict):
-        return {k: _normalize_checkpoint_state(v) for k, v in value.items()}
-    if isinstance(value, list):
-        return [_normalize_checkpoint_state(v) for v in value]
-    if isinstance(value, tuple):
-        return tuple(_normalize_checkpoint_state(v) for v in value)
-    if torch.is_tensor(value):
-        return Torch.to_local_tensor(value).cpu()
-    return value
-
-
 _default_adapter_name = ''
 DEFAULT_LEARNING_RATE = 1e-5
 DEFAULT_WEIGHT_DECAY = 0.01
@@ -892,13 +878,16 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int

     def _save_optimizer(self, output_dir, **kwargs):
         adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
         optimizer_config = self.optimizer_group[adapter_name]
+        optimizer = optimizer_config.optimizer
+        lr_scheduler = optimizer_config.lr_scheduler
+        if optimizer is not None:
+            optimizer_path = os.path.join(output_dir, 'optimizer.pt')
+            if hasattr(self.strategy, 'save_optimizer_checkpoint'):
+                self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path)
+            elif Platform.is_master():
+                torch.save(optimizer.state_dict(), optimizer_path)
         if Platform.is_master():
-            optimizer = optimizer_config.optimizer
-            lr_scheduler = optimizer_config.lr_scheduler
-            if optimizer is not None:
-                state_dict = _normalize_checkpoint_state(optimizer.state_dict())
-                torch.save(state_dict, os.path.join(output_dir, 'optimizer.pt'))
             if lr_scheduler is not None:
                 torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))

@@ -1007,8 +996,13 @@ def _load_optimizer(self, checkpoint_dir, **kwargs):
             raise FileNotFoundError(scheduler_path)

         if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None:
-            state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
-            optimizer_config.optimizer.load_state_dict(state_dict)
+            if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped:
+                self._lazy_wrap_model()
+            if hasattr(self.strategy, 'load_optimizer_checkpoint'):
+                self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path)
+            else:
+                state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
+                optimizer_config.optimizer.load_state_dict(state_dict)

         if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None:
             state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False)
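The `_load_optimizer` rework above replaces an unconditional plain-state-dict load with duck-typed dispatch through optional strategy hooks. The same fallback written as a stand-alone function (illustrative only; in the patch this lives inside `TransformersModel`):

```python
# Sketch only: prefer a strategy checkpoint hook, else load a plain state dict.
import torch

def load_optimizer_state(strategy, model, optimizer, optimizer_path: str) -> None:
    if hasattr(strategy, 'load_optimizer_checkpoint'):
        # FSDP-aware strategies reshard/broadcast the state as needed.
        strategy.load_optimizer_checkpoint(model, optimizer, optimizer_path)
    else:
        state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
        optimizer.load_state_dict(state_dict)
```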
From 6cf51606ade61c7a0d4a07190e589d22d8c8fc68 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 16:54:01 +0800
Subject: [PATCH 11/28] fix

---
 src/twinkle/model/transformers/transformers.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 40473420..d7e8183a 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -961,10 +961,11 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
         model_sd = model.state_dict()
         converted_weights = {}
         for key, value in adapter_weights.items():
-            if f'.{adapter_name}.weight' not in key:
-                key = key.replace('.weight', f'.{adapter_name}.weight')
-            if key in model_sd:
-                param = model_sd[key]
+            model_key = key
+            if f'.{adapter_name}.weight' not in model_key:
+                model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
+            if model_key in model_sd:
+                param = model_sd[model_key]
                 if isinstance(param, DTensor) and not isinstance(value, DTensor):
                     value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
                 converted_weights[key] = value

From fdf1f71942be0627d8f12f09ab22abdb7a98d555 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Mon, 30 Mar 2026 18:34:31 +0800
Subject: [PATCH 12/28] wip

---
 .../model/transformers/strategy/accelerate.py  | 12 +++--
 .../transformers/strategy/native_fsdp.py       | 47 +++++++++++++++++++
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 214623fa..bcfd5e30 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -124,8 +124,12 @@ def wrap_model(self, model, *args):
     def unwrap_model(self, model):
         return self.accelerator.unwrap_model(model, keep_torch_compile=False)

+    def _get_fsdp_plugin(self):
+        state = self.accelerator.state
+        return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None
+
     def _prepare_fsdp2_sd_options(self):
-        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        fsdp_plugin = self._get_fsdp_plugin()
         if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2:
             return None

@@ -139,11 +143,11 @@ def _prepare_fsdp2_sd_options(self):
         )

     def needs_wrapped_optimizer_state(self) -> bool:
-        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        fsdp_plugin = self._get_fsdp_plugin()
         return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2

     def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
-        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        fsdp_plugin = self._get_fsdp_plugin()
         if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
             from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
             import torch
@@ -158,7 +162,7 @@ def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
             torch.save(optimizer.state_dict(), output_path)

     def load_optimizer_checkpoint(self, model, optimizer, input_path: str):
-        fsdp_plugin = self.accelerator.state.fsdp_plugin
+        fsdp_plugin = self._get_fsdp_plugin()
         if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
             from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
             import torch
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 48a1da85..ad675006 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -151,6 +151,53 @@ def wrap_model(self, model, optimizer=None):

         return model, optimizer

+    def _prepare_optimizer_state_dict_options(self, *, for_load: bool):
+        from torch.distributed.checkpoint.state_dict import StateDictOptions
+
+        return StateDictOptions(
+            full_state_dict=True,
+            cpu_offload=not for_load,
+            broadcast_from_rank0=for_load,
+        )
+
+    def needs_wrapped_optimizer_state(self) -> bool:
+        return self.device_mesh is not None
+
+    def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
+        import torch
+        if not self.needs_wrapped_optimizer_state():
+            if Platform.is_master():
+                torch.save(optimizer.state_dict(), output_path)
+            return
+
+        from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+
+        optim_state = get_optimizer_state_dict(
+            model,
+            optimizer,
+            options=self._prepare_optimizer_state_dict_options(for_load=False),
+        )
+        if Platform.is_master():
+            torch.save(optim_state, output_path)
+
+    def load_optimizer_checkpoint(self, model, optimizer, input_path: str):
+        import torch
+        if not self.needs_wrapped_optimizer_state():
+            optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False))
+            return
+
+        from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
+
+        optim_state = {}
+        if Platform.is_master():
+            optim_state = torch.load(input_path, map_location='cpu', weights_only=True)
+        set_optimizer_state_dict(
+            model,
+            optimizer,
+            optim_state,
+            options=self._prepare_optimizer_state_dict_options(for_load=True),
+        )
+
     def get_ep_clip_kwargs(self, model) -> Dict[str, Any]:
         """Return EP-aware kwargs for normalize_and_clip_grad_norm."""
         model = self.unwrap_model(model)
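The native-FSDP strategy keys its `StateDictOptions` off a single `for_load` flag: saving gathers a full state dict offloaded to CPU on rank 0, loading broadcasts the rank-0 state back out to all ranks. The same mapping written out on its own (the lint patch below adds tests pinning exactly these option values):

```python
# Sketch only: the save/load asymmetry encoded by _prepare_optimizer_state_dict_options.
from torch.distributed.checkpoint.state_dict import StateDictOptions

def optimizer_state_dict_options(for_load: bool) -> StateDictOptions:
    return StateDictOptions(
        full_state_dict=True,           # always a full, unsharded optimizer state
        cpu_offload=not for_load,       # offload to CPU when gathering for save
        broadcast_from_rank0=for_load,  # fan rank-0 state back out when loading
    )
```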
From e21f870c395215b0055fcf20e42645ab95178f46 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Tue, 31 Mar 2026 09:52:22 +0800
Subject: [PATCH 13/28] lint

---
 ...7-transformers-checkpoint-resume-design.md |   1 -
 .../model/transformers/strategy/accelerate.py |   4 +-
 .../model/transformers/transformers.py        |   5 +-
 src/twinkle/server/model/twinkle_handlers.py  |  14 ++-
 tests/dataloader/test_dataloader.py           |  16 ++-
 tests/dataloader/test_sampler.py              |  21 +++-
 .../transformers/test_checkpoint_resume.py    | 103 +++++++++++++++++-
 .../model/test_twinkle_resume_routes.py       |  25 +++--
 8 files changed, 155 insertions(+), 34 deletions(-)

diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md
index 7b90baba..4f41c64f 100644
--- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md
+++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md
@@ -430,4 +430,3 @@ Recommended guidance text:
 - `ignore_data_skip=True` disables progress restore and starts from step 0
 - Full-parameter checkpoints restore weights during model initialization and restore training state afterward
 - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data
-
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index bcfd5e30..b9b81f59 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -149,8 +149,8 @@ def needs_wrapped_optimizer_state(self) -> bool:
     def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
         fsdp_plugin = self._get_fsdp_plugin()
         if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
-            from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
             import torch
+            from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict

             optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options())
             if self.accelerator.process_index == 0:
@@ -164,8 +164,8 @@ def save_optimizer_checkpoint(self, model, optimizer, output_path: str):
     def load_optimizer_checkpoint(self, model, optimizer, input_path: str):
         fsdp_plugin = self._get_fsdp_plugin()
         if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
-            from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
             import torch
+            from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict

             optim_state = None
             rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index d7e8183a..88366c96 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -2,11 +2,11 @@
 import asyncio
 import contextlib
 import json
+import numpy as np
 import os
 import random
 import re
 import threading
-import numpy as np
 import torch
 import torch.distributed as dist
 import transformers
@@ -147,6 +147,7 @@ def calculate_metrics(self, is_training):
         self.outputs = None
         return results

+
 _default_adapter_name = ''
 DEFAULT_LEARNING_RATE = 1e-5
 DEFAULT_WEIGHT_DECAY = 0.01
@@ -1055,7 +1056,7 @@ def read_training_progress(self, checkpoint_dir, **kwargs):
         if not os.path.exists(trainer_state_path):
             raise FileNotFoundError(trainer_state_path)

-        with open(trainer_state_path, 'r', encoding='utf-8') as f:
+        with open(trainer_state_path, encoding='utf-8') as f:
             trainer_state = json.load(f)

         required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'}
diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py
index 259c50c4..b48b4e12 100644
--- a/src/twinkle/server/model/twinkle_handlers.py
+++ b/src/twinkle/server/model/twinkle_handlers.py
@@ -8,11 +8,11 @@
 """
 from __future__ import annotations

+import os
 import torch
 import traceback
-import os
-from pathlib import Path
 from fastapi import Depends, FastAPI, HTTPException, Request
+from pathlib import Path
 from peft import LoraConfig
 from typing import TYPE_CHECKING, Any, Callable

@@ -363,8 +363,9 @@ async def _task():
             extra_kwargs = body.model_extra or {}
             checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
             resolved = checkpoint_manager.resolve_load_path(body.name)
-            checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix()
-                              if resolved.checkpoint_dir else body.name)
+            checkpoint_dir = (
+                Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix()
+                if resolved.checkpoint_dir else body.name)
             self.model.load_training_state(
                 checkpoint_dir,
                 adapter_name=adapter_name,
@@ -387,8 +388,9 @@ async def _task():
             extra_kwargs = body.model_extra or {}
             checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
             resolved = checkpoint_manager.resolve_load_path(body.name)
-            checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix()
-                              if resolved.checkpoint_dir else body.name)
+            checkpoint_dir = (
+                Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix()
+                if resolved.checkpoint_dir else body.name)
             ret = self.model.read_training_progress(
                 checkpoint_dir,
                 adapter_name=adapter_name,
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
index 232c6fe6..82a4f41b 100644
--- a/tests/dataloader/test_dataloader.py
+++ b/tests/dataloader/test_dataloader.py
@@ -39,10 +39,18 @@ def convert_to_messages(example):

 def _build_resume_rows():
     return [
-        {'text': 'Hello world'},
-        {'text': 'Test data'},
-        {'text': 'Another example'},
-        {'text': 'Sample text'},
+        {
+            'text': 'Hello world'
+        },
+        {
+            'text': 'Test data'
+        },
+        {
+            'text': 'Another example'
+        },
+        {
+            'text': 'Sample text'
+        },
     ]
diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py
index 1dd7d7ca..d5c97dbc 100644
--- a/tests/dataloader/test_sampler.py
+++ b/tests/dataloader/test_sampler.py
@@ -1,11 +1,11 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import concurrent.futures
-import os
 import numpy as np
+import os
 import pytest
 from pathlib import Path
-from torch.utils.data import RandomSampler, SequentialSampler
 from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import RandomSampler, SequentialSampler


 class _NoOpProcessPoolExecutor:
@@ -183,6 +183,7 @@ def test_sequential_vs_random_order(self):
 class TestResumeSkipSamplerOrdering:

     def test_sequential_sampler_skip_happens_before_device_mesh_slice(self):
+
         class _InMemoryDataset(TorchDataset):

             def __init__(self, rows):
@@ -195,10 +196,18 @@ def __getitem__(self, idx):
                 return self.rows[idx]

         dataset = _InMemoryDataset([
-            {'text': 'Hello world'},
-            {'text': 'Test data'},
-            {'text': 'Another example'},
-            {'text': 'Sample text'},
+            {
+                'text': 'Hello world'
+            },
+            {
+                'text': 'Test data'
+            },
+            {
+                'text': 'Another example'
+            },
+            {
+                'text': 'Sample text'
+            },
         ])
         sampler = SequentialSampler(dataset)
         device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', ))
diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py
index 7541cbaa..3e903d82 100644
--- a/tests/model/transformers/test_checkpoint_resume.py
+++ b/tests/model/transformers/test_checkpoint_resume.py
@@ -1,19 +1,20 @@
 import concurrent.futures
+import pytest
 import sys
+import torch
 import types
 import uuid
 from pathlib import Path
-from types import ModuleType
-
-import pytest
 from peft import LoraConfig
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
-from transformers import Qwen3Config, Qwen3ForCausalLM, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizerFast, Qwen3Config, Qwen3ForCausalLM
+from types import ModuleType


 class _NoOpProcessPoolExecutor:
+
     def __init__(self, *args, **kwargs):
         pass

@@ -39,6 +40,7 @@ def submit(self, fn, *args, **kwargs):
     sys.modules['zmq'] = fake_zmq

 from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel
+from twinkle.model.transformers.strategy import NativeFSDPStrategy


 def build_tiny_tokenizer():
@@ -94,6 +96,16 @@ def build_full_param_model(model_dir):
     )


+def build_native_fsdp_strategy():
+    device_mesh = types.SimpleNamespace(
+        world_size=2,
+        ep_size=1,
+        ep_fsdp_size=None,
+        device_type='cpu',
+    )
+    return NativeFSDPStrategy(device_mesh=device_mesh)
+
+
 def build_multi_lora_model(model_dir):
     device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1)
     model = MultiLoraTransformersModel(
@@ -103,7 +115,8 @@ def build_multi_lora_model(model_dir):
         device_mesh=device_mesh,
         grad_scaler_config={},
     )
-    model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']))
+    model.add_adapter_to_model('default',
+                               LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']))
     return model

@@ -173,7 +186,7 @@ def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir):
     ],
 )
 def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name,
-                                                             expected_pattern):
+                                                              expected_pattern):
     ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir)
     _ensure_file_exists(ckpt_dir / missing_name)
     (ckpt_dir / missing_name).unlink()
@@ -291,3 +304,81 @@ def test_read_training_progress_returns_metadata_without_mutating_optimizer_stat
     assert progress['consumed_train_samples'] == 12
     assert restored.optimizer_group[''].cur_step == 0
     assert restored.optimizer_group[''].gradient_accumulation_steps == 1
+
+
+def test_native_fsdp_strategy_requires_wrapped_optimizer_state():
+    strategy = build_native_fsdp_strategy()
+
+    assert strategy.needs_wrapped_optimizer_state() is True
+
+
+def test_native_fsdp_strategy_save_optimizer_checkpoint_uses_full_state_dict(monkeypatch, tmp_path):
+    strategy = build_native_fsdp_strategy()
+    model = object()
+    optimizer = object()
+    optimizer_path = tmp_path / 'optimizer.pt'
+    captured = {}
+
+    from torch.distributed.checkpoint import state_dict as checkpoint_state_dict
+
+    def fake_get_optimizer_state_dict(model_arg, optimizer_arg, *, options=None):
+        captured['model'] = model_arg
+        captured['optimizer'] = optimizer_arg
+        captured['options'] = options
+        return {'state': {'step': 3}}
+
+    def fake_save(obj, path):
+        captured['saved_obj'] = obj
+        captured['saved_path'] = path
+
+    monkeypatch.setattr(checkpoint_state_dict, 'get_optimizer_state_dict', fake_get_optimizer_state_dict)
+    monkeypatch.setattr(torch, 'save', fake_save)
+
+    strategy.save_optimizer_checkpoint(model, optimizer, str(optimizer_path))
+
+    assert captured['model'] is model
+    assert captured['optimizer'] is optimizer
+    assert captured['saved_obj'] == {'state': {'step': 3}}
+    assert captured['saved_path'] == str(optimizer_path)
+    assert captured['options'].full_state_dict is True
+    assert captured['options'].cpu_offload is True
+    assert captured['options'].broadcast_from_rank0 is False
+
+
+def test_native_fsdp_strategy_load_optimizer_checkpoint_broadcasts_from_rank0(monkeypatch, tmp_path):
+    strategy = build_native_fsdp_strategy()
+    model = object()
+    optimizer = object()
+    optimizer_path = tmp_path / 'optimizer.pt'
+    optimizer_path.write_bytes(b'placeholder')
+    expected_state = {'state': {'step': 7}}
+    captured = {}
+
+    from torch.distributed.checkpoint import state_dict as checkpoint_state_dict
+
+    def fake_load(path, map_location=None, weights_only=None):
+        captured['loaded_path'] = path
+        captured['map_location'] = map_location
+        captured['weights_only'] = weights_only
+        return expected_state
+
+    def fake_set_optimizer_state_dict(model_arg, optimizer_arg, optim_state_dict, *, options=None):
+        captured['model'] = model_arg
+        captured['optimizer'] = optimizer_arg
+        captured['optim_state_dict'] = optim_state_dict
+        captured['options'] = options
+
+    monkeypatch.setattr(torch, 'load', fake_load)
+    monkeypatch.setattr(checkpoint_state_dict, 'set_optimizer_state_dict', fake_set_optimizer_state_dict)
+
+    strategy.load_optimizer_checkpoint(model, optimizer, str(optimizer_path))
+
+    assert captured['loaded_path'] == str(optimizer_path)
+    assert captured['map_location'] == 'cpu'
+    assert captured['weights_only'] is True
+    assert captured['model'] is model
+    assert captured['optimizer'] is optimizer
+    assert captured['optim_state_dict'] == expected_state
+    assert captured['options'].full_state_dict is True
+    assert captured['options'].cpu_offload is False
+    assert captured['options'].broadcast_from_rank0 is True
diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py
index f3523913..cd0fcfde 100644
--- a/tests/server/model/test_twinkle_resume_routes.py
+++ b/tests/server/model/test_twinkle_resume_routes.py
@@ -1,13 +1,12 @@
 import concurrent.futures
 import importlib.util
 import sys
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
 from pathlib import Path
 from types import ModuleType
 from unittest.mock import Mock
-from fastapi import FastAPI -from fastapi.testclient import TestClient - class _NoOpProcessPoolExecutor: @@ -86,7 +85,10 @@ def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monk response = client.post( '/twinkle/load_training_state', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': '' + }, ) assert response.status_code == 200 @@ -102,7 +104,10 @@ def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch): response = client.post( '/twinkle/load_training_state', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': 'adapter-a'}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': 'adapter-a' + }, ) assert response.status_code == 200 @@ -117,7 +122,10 @@ def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(mon response = client.post( '/twinkle/load_training_state', - json={'name': 'local-checkpoint-dir', 'adapter_name': ''}, + json={ + 'name': 'local-checkpoint-dir', + 'adapter_name': '' + }, ) assert response.status_code == 200 @@ -133,7 +141,10 @@ def test_read_training_progress_route_returns_progress_and_calls_model(monkeypat response = client.post( '/twinkle/read_training_progress', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': '' + }, ) assert response.status_code == 200 From 70ebe50c4ffc1559bf6f3cef0d71608e3fcc6b3f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 09:54:01 +0800 Subject: [PATCH 14/28] wip --- ...7-transformers-checkpoint-resume-design.md | 432 ------------------ .../transformers/test_checkpoint_resume.py | 384 ---------------- .../model/test_twinkle_resume_routes.py | 156 ------- 3 files changed, 972 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md delete mode 100644 tests/model/transformers/test_checkpoint_resume.py delete mode 100644 tests/server/model/test_twinkle_resume_routes.py diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md deleted file mode 100644 index 4f41c64f..00000000 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ /dev/null @@ -1,432 +0,0 @@ -# Transformers Strict Resume Design - -## Summary - -This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. - -The design supports both full-parameter training and LoRA training: - -- full-parameter training restores weights during model initialization -- LoRA training restores adapter weights through the existing load path -- both modes share the same training-state resume contract -- strict model-state resume does not silently fall back to weight-only loading when required state is missing - -Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. 
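For concreteness, the intended entrypoint-side call pattern can be sketched with the APIs this series adds (`load_training_state`, `read_training_progress`, `skip_consumed_samples`). This is a minimal sketch, not the shipped helper: weight restore itself stays on the existing `__init__` / `load()` paths, and the three parameters are the resume controls specified below.

```python
from typing import Optional


def apply_resume_controls(model,
                          dataloader,
                          resume_from_checkpoint: Optional[str] = None,
                          resume_only_model: bool = False,
                          ignore_data_skip: bool = False) -> int:
    """Return how many training samples were already consumed (sketch only)."""
    if resume_from_checkpoint is None:
        return 0  # fresh run, nothing to restore or skip
    if not resume_only_model:
        # Full resume: optimizer, scheduler, scaler, RNG, and step counters.
        state = model.load_training_state(resume_from_checkpoint)
    elif not ignore_data_skip:
        # Weights only, but keep progress metadata so data skipping still works.
        state = model.read_training_progress(resume_from_checkpoint)
    else:
        return 0  # weights only, restart at step 0 with no skipping
    dataloader.skip_consumed_samples(state['consumed_train_samples'])
    return int(state['consumed_train_samples'])
```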
- -## Goals - -- Support true checkpoint resume for `TransformersModel` -- Support both full-parameter and LoRA training resume -- Restore optimizer state, scheduler state, scaler state, RNG state, and step counters -- Support dataset progress skipping for map-style datasets -- Expose Swift-like resume controls without adding a new trainer class -- Keep training-state save and load compatible with NPU (Ascend) environments -- Preserve existing weight-only loading and saving behavior - -## Non-Goals - -- Do not introduce a new `Trainer` class or resume manager class -- Do not guarantee exact sample-by-sample replay when retry-based sampling changes sample order -- Do not support exact data-progress resume for `IterableDataset` or streaming datasets -- Do not attempt to persist transient runtime state such as in-flight batch tensors, current loss tensors, or metric caches - -## User-Facing Resume Controls - -Resume behavior is controlled by existing training entrypoints through three new parameters: - -- `resume_from_checkpoint: Optional[str] = None` -- `resume_only_model: bool = False` -- `ignore_data_skip: bool = False` - -### Parameter semantics - -#### `resume_from_checkpoint` - -- Specifies the checkpoint directory or checkpoint path to resume from -- When unset, training starts normally from scratch -- When set, the training entrypoint reads the checkpoint and restores model state through existing model APIs - -#### `resume_only_model` - -- Defaults to `False` -- When `False`, resume restores full training state -- When `True`, resume restores only model weights - -#### `ignore_data_skip` - -- Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` -- Defaults to `False` -- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, scaler, or RNG -- When `True`, the system restores only model weights and does not restore training progress or skip consumed data - -### Effective behavior matrix - -#### Case 1: `resume_from_checkpoint is None` - -- Start a new training run - -#### Case 2: `resume_from_checkpoint is not None` and `resume_only_model=False` - -- Restore model weights -- Restore optimizer state -- Restore scheduler state -- Restore scaler state -- Restore RNG state -- Restore step counters -- Attempt to skip already consumed training data -- If required model training state is missing, fail without fallback - -#### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` - -- Restore model weights only -- Do not restore optimizer, scheduler, scaler, or RNG -- Restore step/progress metadata needed for data skipping -- Attempt to skip already consumed training data - -#### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` - -- Restore model weights only -- Do not restore optimizer, scheduler, scaler, RNG, step counters, or data progress -- Restart the training loop from step 0 with no skipping - -## Checkpoint Layout - -Existing weight layouts remain valid. New training-state files are added alongside current checkpoint contents. 
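The training-state files line up with what the save path is expected to write; a strict-resume checkpoint directory can be validated against a list like the following (illustrative constant, mirroring the assertions in `test_save_training_state_writes_split_files`; the individual files are detailed in the sections below).

```python
# Files expected alongside the existing weight and tokenizer artifacts when
# saving with optimizer state enabled; the last three are new in this design.
EXPECTED_RESUME_FILES = [
    'optimizer.pt',
    'scheduler.pt',
    'scaler.pt',           # new
    'trainer_state.json',  # new
    'rng_state.pt',        # new
]
```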
- -### Existing files preserved - -- full-model weights saved by `save_pretrained` -- LoRA weights saved as `adapter_model.safetensors` -- tokenizer artifacts -- `optimizer.pt` -- `scheduler.pt` - -### New training-state files - -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -### `trainer_state.json` contents - -`trainer_state.json` stores lightweight training metadata: - -- `checkpoint_version` -- `cur_step` -- `gradient_accumulation_steps` -- `consumed_train_samples` - -The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. - -### `scaler.pt` contents - -- AMP scaler state dict -- optional scaler-related flags such as `scaler_has_nan` - -### `rng_state.pt` contents - -- Python `random` state -- NumPy RNG state -- PyTorch CPU RNG state -- CUDA RNG state - -## Accelerator Compatibility - -Training-state save and load must be accelerator-compatible, including Ascend NPU environments already supported by Twinkle. - -### Device-agnostic serialization - -Training-state files must use device-agnostic serialization: - -- optimizer, scheduler, scaler, and RNG payloads should be serialized in CPU-safe form -- JSON metadata stays in plain text files -- loading should first read state from CPU-safe files and then apply it to objects created on the current runtime device - -This avoids tying resume files to a specific device object layout during save. - -### RNG compatibility requirements - -RNG save and restore must branch by current accelerator backend: - -- CUDA runtime uses `torch.cuda` RNG APIs -- NPU runtime uses `torch.npu` RNG APIs -- CPU RNG and Python/NumPy RNG are always restored - -The implementation must not assume CUDA-only RNG helpers when saving or restoring training state. - -### Scope of compatibility - -The design requires resume support to work correctly in NPU environments. - -The design does not require cross-accelerator resume guarantees such as saving on GPU and resuming on NPU, or saving on NPU and resuming on GPU. The compatibility target is correct save and restore within the active supported accelerator backend. -## Restore Paths - -## Full-Parameter Training - -For full-parameter training, model weights are restored during initialization. - -### Full-parameter restore flow - -1. Construct `TransformersModel(model_id=ckpt_dir, ...)` -2. `__init__` uses `from_pretrained(ckpt_dir, ...)` to restore weights -3. Create optimizer, scheduler, and scaler objects -4. Call `load_training_state(ckpt_dir)` to restore training state -5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` - -This means full-parameter resume does not need a separate model-weight loading method after initialization. It only needs explicit training-state restoration. - -## LoRA Training - -For LoRA training, the existing adapter-weight load path remains in place. - -### LoRA restore flow - -1. Construct the model and adapter objects as today -2. Restore adapter weights through the existing `load()` path -3. Create optimizer, scheduler, and scaler objects -4. Call the same `load_training_state(ckpt_dir)` method to restore training state -5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` - -## Unified training-state method - -The model layer gains a shared helper such as `load_training_state(ckpt_dir)`. 
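One plausible shape for that helper is sketched below. `_load_optimizer` exists in this series; the other private helpers are hypothetical names standing in for the scaler, RNG, and step-counter restore steps.

```python
def load_training_state(self, checkpoint_dir: str, **kwargs) -> dict:
    # Strict resume: required files are checked up front, then applied.
    trainer_state = self._read_trainer_state(checkpoint_dir)  # trainer_state.json (hypothetical helper)
    self._load_optimizer(checkpoint_dir, **kwargs)            # optimizer.pt / scheduler.pt
    self._load_scaler(checkpoint_dir)                         # scaler.pt (hypothetical helper)
    self._restore_rng_state(checkpoint_dir)                   # rng_state.pt (hypothetical helper)
    self._apply_step_counters(trainer_state)                  # cur_step, gradient_accumulation_steps (hypothetical)
    return trainer_state
```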
- -This method restores: - -- `optimizer.pt` -- `scheduler.pt` -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -It assumes the corresponding optimizer, scheduler, and scaler objects have already been created before invocation. - -## Model Save and Load Semantics - -## Save behavior - -When saving with optimizer state enabled, the checkpoint includes: - -- weights in the existing full-model or LoRA format -- tokenizer artifacts -- `optimizer.pt` -- `scheduler.pt` -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -When optimizer save is disabled, save remains weight-only and does not produce strict resume metadata. - -## Strict training-state restore - -Strict model-state resume restores: - -- optimizer state -- scheduler state -- scaler state -- RNG state -- `cur_step` -- `gradient_accumulation_steps` -- data-progress metadata - -### Failure behavior - -When strict training-state restore is requested, missing required model training state is an error: - -- missing `trainer_state.json` -> fail -- missing `optimizer.pt` when optimizer restore is required -> fail -- missing `scheduler.pt` when scheduler restore is required -> fail -- missing `scaler.pt` when scaler restore is required -> fail -- missing `rng_state.pt` when RNG restore is required -> fail -- malformed required fields -> fail - -This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. - -## Training Progress and Data Skipping - -Twinkle does not currently have a central trainer abstraction. Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. - -## Dataloader extensions - -Existing dataloader and sampler code is extended rather than replaced: - -- `twinkle.dataloader.DataLoader` -- `twinkle.dataloader.DeviceMeshSampler` -- retry-aware sampler flow - -The dataloader gains resume-oriented arguments: - -- `skip_samples: int = 0` -- optionally `skip_batches: int = 0` - -Map-style datasets use this progress to skip already consumed data before yielding new training batches. - -## Map-style dataset behavior - -For datasets with `__len__`, Twinkle attempts to skip previously consumed data using sampler or batch-sampler level skipping. - -Preferred behavior: - -- preserve existing sharding logic -- apply skip before data is yielded to the training loop -- keep the solution compatible with current `DeviceMeshSampler` wrapping - -## Iterable and streaming behavior - -`IterableDataset` and streaming datasets do not support exact progress skipping in this design. - -Behavior for these datasets: - -- restore model state according to the selected resume mode -- log a clear warning that consumed-data skipping is not supported -- continue training without skipping historical samples - -This is the only fallback allowed in the design. It applies only to dataset progress skipping, not to model-state resume. - -## Entry Point Integration - -No new trainer class is introduced. - -Resume parameters are threaded through existing training entrypoints: - -- direct local training loops using `TwinkleModel` / `TransformersModel` -- current client/server training flows that already support checkpoint save and load - -The practical integration model is: - -1. Parse or receive the three resume parameters -2. If `resume_from_checkpoint` is unset, construct dataloader normally -3. 
Construct model weights through the appropriate path - - full-parameter: restore through `__init__` - - LoRA: restore through existing adapter load logic -4. If `resume_only_model=False`, call `load_training_state(ckpt_dir)` -5. If `resume_only_model=True` and `ignore_data_skip=False`, read `trainer_state.json` for progress only -6. Recreate the dataloader with skip arguments applied when skipping is enabled - -This keeps the training loop explicit and compatible with current Twinkle examples. - -## Server and Client Behavior - -Server-side checkpoint save/load behavior should preserve current APIs while adding richer metadata. - -### Save path - -When server-side save endpoints request optimizer save: - -- save the model checkpoint as today -- save `optimizer.pt`, `scheduler.pt`, `scaler.pt`, `trainer_state.json`, and `rng_state.pt` -- persist checkpoint metadata through the existing checkpoint manager - -### Load path - -Current model load APIs remain the weight-loading trigger. - -The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: - -- restore full training state -- restore weight only -- request data skipping - -The underlying server model APIs do not need a new trainer object to support this. - -## Compatibility Strategy - -### Existing checkpoints - -Existing checkpoints remain loadable in weight-only mode. - -Examples: - -- weight-only initialization for full-parameter checkpoints continues to work -- existing LoRA weight loading continues to work -- inference-only consumers remain unaffected - -### Old checkpoints under strict resume - -Old checkpoints that lack the new training-state files are not valid for strict resume. - -Expected behavior: - -- strict resume fails clearly -- weight-only load continues to work when requested explicitly - -### `resume_only_model=True` - -For `resume_only_model=True`, old checkpoints may still be usable if weight files are present. - -If data skipping is requested but no progress metadata exists, the entrypoint should fail clearly rather than silently train from the beginning while claiming resumed progress. - -## Risks and Constraints - -### RetrySampler interaction - -`RetrySampler` may retry or replace failed samples, including random backfill behavior at the tail of an epoch. - -Because of that: - -- progress skipping can preserve approximate data position -- exact sample-for-sample replay is not guaranteed when retry or backfill paths are exercised - -This limitation should be documented explicitly. - -### Dataset shape changes - -If dataset definition, slicing, filtering, or shuffle configuration changes between save and resume, data skipping semantics may become invalid. - -The user guidance should state that resume should be done with unchanged training parameters and unchanged dataset configuration. - -### Distributed consistency - -Skip logic must be compatible with current device-mesh sharding. The implementation should ensure skip is applied consistently before per-rank slicing causes divergence. 
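The required ordering can be shown with plain index arithmetic: skip on the global sample stream first, then slice per rank, so every rank drops the same prefix. The names below are illustrative, not the actual `DeviceMeshSampler` API.

```python
def rank_indices(num_samples: int, consumed: int, dp_rank: int, dp_world_size: int):
    # Global skip first: all ranks agree on which samples were consumed.
    remaining = range(consumed, num_samples)
    # Then per-rank slicing, mirroring device-mesh sharding.
    return [i for i in remaining if i % dp_world_size == dp_rank]


# With 8 samples, consumed=2, and 2 dp ranks:
#   rank 0 -> [2, 4, 6], rank 1 -> [3, 5, 7]
# Slicing first and then skipping the global count on each rank would instead
# drop 2 samples per rank (4 globally), silently over-skipping the dataset.
```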
- -## Testing Strategy - -Tests should cover: - -### Full-parameter training resume - -- initializing with `model_id=ckpt_dir` restores weights -- `load_training_state(ckpt_dir)` restores optimizer, scheduler, scaler, RNG, and step metadata - -### LoRA training resume - -- adapter-weight restore continues to work -- `load_training_state(ckpt_dir)` restores shared training state correctly - -### Strict restore failures - -- strict resume fails when required files are missing -- malformed state files fail clearly - -### Weight-only compatibility - -- legacy checkpoints still load in weight-only mode -- `resume_only_model=True` restores weights without optimizer, scheduler, scaler, or RNG - -### Data progress skipping - -- map-style datasets skip consumed data correctly -- skip behavior remains correct with device-mesh sharding -- iterable and streaming datasets emit warnings and continue without skipping - -## Implementation Outline - -1. Add model helpers for saving and loading split training-state files -2. Implement `load_training_state(ckpt_dir)` with shared behavior for full-parameter and LoRA training -3. Keep full-parameter weight restore in `__init__` -4. Keep LoRA weight restore in the existing adapter load path -5. Extend dataloader and sampler stack to support skip arguments for map-style datasets -6. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints -7. Add warnings for unsupported iterable and streaming data skipping -8. Update docs and examples to show the new resume contract - -## User Guidance - -Recommended guidance text: - -- To resume training, keep other parameters unchanged and provide `resume_from_checkpoint` -- `resume_only_model=False` performs full resume -- `resume_only_model=True` restores only model weights -- `ignore_data_skip=True` disables progress restore and starts from step 0 -- Full-parameter checkpoints restore weights during model initialization and restore training state afterward -- Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py deleted file mode 100644 index 3e903d82..00000000 --- a/tests/model/transformers/test_checkpoint_resume.py +++ /dev/null @@ -1,384 +0,0 @@ -import concurrent.futures -import pytest -import sys -import torch -import types -import uuid -from pathlib import Path -from peft import LoraConfig -from tokenizers import Tokenizer -from tokenizers.models import WordLevel -from tokenizers.pre_tokenizers import Whitespace -from transformers import PreTrainedTokenizerFast, Qwen3Config, Qwen3ForCausalLM -from types import ModuleType - - -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor - -ROOT = Path(__file__).resolve().parents[3] -SRC = ROOT / 'src' -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -if 'zmq' not in sys.modules: - fake_zmq = ModuleType('zmq') - fake_zmq.Socket = object - fake_zmq.Context = object - fake_zmq.REP = 0 - fake_zmq.REQ = 1 - fake_zmq.IPV6 = 2 - fake_zmq.error = types.SimpleNamespace(Again=RuntimeError) - sys.modules['zmq'] = fake_zmq - -from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel -from 
twinkle.model.transformers.strategy import NativeFSDPStrategy - - -def build_tiny_tokenizer(): - vocab = {'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, 'hello': 4, 'world': 5} - backend = Tokenizer(WordLevel(vocab=vocab, unk_token='[UNK]')) - backend.pre_tokenizer = Whitespace() - tokenizer = PreTrainedTokenizerFast( - tokenizer_object=backend, - pad_token='[PAD]', - bos_token='[BOS]', - eos_token='[EOS]', - unk_token='[UNK]', - ) - return tokenizer - - -@pytest.fixture -def tmp_path(): - root = Path(__file__).parent / 'tmp_test_checkpoints' - root.mkdir(exist_ok=True) - path = root / uuid.uuid4().hex - path.mkdir() - return path - - -@pytest.fixture -def tiny_local_model_dir(tmp_path): - model_dir = tmp_path / 'tiny-qwen3' - model_dir.mkdir() - - tokenizer = build_tiny_tokenizer() - tokenizer.save_pretrained(model_dir) - - config = Qwen3Config( - vocab_size=tokenizer.vocab_size, - num_hidden_layers=1, - num_attention_heads=1, - hidden_size=16, - max_position_embeddings=32, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - ) - Qwen3ForCausalLM(config).save_pretrained(model_dir) - return model_dir - - -def build_full_param_model(model_dir): - return TransformersModel( - model_cls='Qwen3ForCausalLM', - model_id=str(model_dir), - mixed_precision='no', - grad_scaler_config={}, - ) - - -def build_native_fsdp_strategy(): - device_mesh = types.SimpleNamespace( - world_size=2, - ep_size=1, - ep_fsdp_size=None, - device_type='cpu', - ) - return NativeFSDPStrategy(device_mesh=device_mesh) - - -def build_multi_lora_model(model_dir): - device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1) - model = MultiLoraTransformersModel( - model_cls='Qwen3ForCausalLM', - model_id=str(model_dir), - mixed_precision='no', - device_mesh=device_mesh, - grad_scaler_config={}, - ) - model.add_adapter_to_model('default', - LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])) - return model - - -def prepare_full_param_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): - model = build_full_param_model(model_dir) - model.set_optimizer('AdamW', lr=1e-4) - model.set_lr_scheduler('LinearLR') - model.set_grad_scaler(device='cpu') - model.optimizer_group[''].cur_step = cur_step - return Path( - model.save( - name='full-resume', - output_dir=str(tmp_path), - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - )) - - -def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): - model = build_multi_lora_model(model_dir) - model.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - model.set_lr_scheduler('LinearLR', adapter_name='default') - model.set_grad_scaler(adapter_name='default', device='cpu') - model.optimizer_group['default'].cur_step = cur_step - return Path( - model.save( - name='lora-resume', - output_dir=str(tmp_path), - adapter_name='default', - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - )) - - -def _ensure_file_exists(path: Path): - if path.exists(): - return - if path.name == 'trainer_state.json': - path.write_text('{"cur_step": 3}', encoding='utf-8') - else: - path.write_bytes(b'placeholder') - - -def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir): - model = build_full_param_model(tiny_local_model_dir) - model.set_optimizer('AdamW', lr=1e-4) - model.set_lr_scheduler('LinearLR') - model.set_grad_scaler(device='cpu') - model.optimizer_group[''].cur_step = 7 - - ckpt_dir = 
Path(model.save(name='resume-step', output_dir=str(tmp_path), save_optimizer=True)) - - assert (ckpt_dir / 'optimizer.pt').exists() - assert (ckpt_dir / 'scheduler.pt').exists() - assert (ckpt_dir / 'scaler.pt').exists() - assert (ckpt_dir / 'trainer_state.json').exists() - assert (ckpt_dir / 'rng_state.pt').exists() - - -@pytest.mark.parametrize( - 'missing_name, expected_pattern', - [ - ('optimizer.pt', 'optimizer.pt'), - ('scheduler.pt', 'scheduler.pt'), - ('rng_state.pt', 'rng_state.pt'), - ('trainer_state.json', 'trainer_state.json'), - ], -) -def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name, - expected_pattern): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) - _ensure_file_exists(ckpt_dir / missing_name) - (ckpt_dir / missing_name).unlink() - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - with pytest.raises(FileNotFoundError, match=expected_pattern): - restored.load_training_state(str(ckpt_dir)) - - -def test_load_training_state_allows_missing_scaler_file(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=8, consumed_train_samples=16) - (ckpt_dir / 'scaler.pt').unlink() - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - - trainer_state = restored.load_training_state(str(ckpt_dir)) - - assert trainer_state['cur_step'] == 8 - assert trainer_state['consumed_train_samples'] == 16 - assert restored.optimizer_group[''].cur_step == 8 - - -def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) - (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8') - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - with pytest.raises((KeyError, ValueError), match='gradient_accumulation_steps|consumed_train_samples'): - restored.load_training_state(str(ckpt_dir)) - - -def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18) - saved = build_full_param_model(tiny_local_model_dir) - saved.set_optimizer('AdamW', lr=1e-4) - saved.set_lr_scheduler('LinearLR') - saved.set_grad_scaler(device='cpu') - saved.optimizer_group[''].cur_step = 9 - saved.optimizer_group[''].gradient_accumulation_steps = 4 - ckpt_dir = Path( - saved.save( - name='full-resume', - output_dir=str(tmp_path), - save_optimizer=True, - consumed_train_samples=18, - )) - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - trainer_state = restored.load_training_state(str(ckpt_dir)) - - assert trainer_state['cur_step'] == 9 - assert trainer_state['gradient_accumulation_steps'] == 4 - assert trainer_state['consumed_train_samples'] == 18 - assert restored.optimizer_group[''].cur_step == 9 - assert restored.optimizer_group[''].gradient_accumulation_steps == 4 - - -def test_lora_load_does_not_restore_training_state_without_explicit_resume(tmp_path, 
tiny_local_model_dir): - saved = build_multi_lora_model(tiny_local_model_dir) - saved.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - saved.set_lr_scheduler('LinearLR', adapter_name='default') - saved.set_grad_scaler(adapter_name='default', device='cpu') - saved.optimizer_group['default'].cur_step = 5 - saved.optimizer_group['default'].gradient_accumulation_steps = 3 - ckpt_dir = Path( - saved.save( - name='lora-resume', - output_dir=str(tmp_path), - adapter_name='default', - save_optimizer=True, - consumed_train_samples=10, - )) - restored = build_multi_lora_model(tiny_local_model_dir) - - assert restored.optimizer_group['default'].cur_step == 0 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 - - restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default') - assert restored.optimizer_group['default'].cur_step == 0 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 - - restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - restored.set_lr_scheduler('LinearLR', adapter_name='default') - restored.set_grad_scaler(adapter_name='default', device='cpu') - - trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default') - - assert trainer_state['cur_step'] == 5 - assert trainer_state['gradient_accumulation_steps'] == 3 - assert trainer_state['consumed_train_samples'] == 10 - assert restored.optimizer_group['default'].cur_step == 5 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 3 - - -def test_read_training_progress_returns_metadata_without_mutating_optimizer_state(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) - restored = build_full_param_model(ckpt_dir) - - progress = restored.read_training_progress(str(ckpt_dir)) - - assert progress['cur_step'] == 6 - assert progress['consumed_train_samples'] == 12 - assert restored.optimizer_group[''].cur_step == 0 - assert restored.optimizer_group[''].gradient_accumulation_steps == 1 - - -def test_native_fsdp_strategy_requires_wrapped_optimizer_state(): - strategy = build_native_fsdp_strategy() - - assert strategy.needs_wrapped_optimizer_state() is True - - -def test_native_fsdp_strategy_save_optimizer_checkpoint_uses_full_state_dict(monkeypatch, tmp_path): - strategy = build_native_fsdp_strategy() - model = object() - optimizer = object() - optimizer_path = tmp_path / 'optimizer.pt' - captured = {} - - from torch.distributed.checkpoint import state_dict as checkpoint_state_dict - - def fake_get_optimizer_state_dict(model_arg, optimizer_arg, *, options=None): - captured['model'] = model_arg - captured['optimizer'] = optimizer_arg - captured['options'] = options - return {'state': {'step': 3}} - - def fake_save(obj, path): - captured['saved_obj'] = obj - captured['saved_path'] = path - - monkeypatch.setattr(checkpoint_state_dict, 'get_optimizer_state_dict', fake_get_optimizer_state_dict) - monkeypatch.setattr(torch, 'save', fake_save) - - strategy.save_optimizer_checkpoint(model, optimizer, str(optimizer_path)) - - assert captured['model'] is model - assert captured['optimizer'] is optimizer - assert captured['saved_obj'] == {'state': {'step': 3}} - assert captured['saved_path'] == str(optimizer_path) - assert captured['options'].full_state_dict is True - assert captured['options'].cpu_offload is True - assert captured['options'].broadcast_from_rank0 is False - - -def 
test_native_fsdp_strategy_load_optimizer_checkpoint_broadcasts_from_rank0(monkeypatch, tmp_path): - strategy = build_native_fsdp_strategy() - model = object() - optimizer = object() - optimizer_path = tmp_path / 'optimizer.pt' - optimizer_path.write_bytes(b'placeholder') - expected_state = {'state': {'step': 7}} - captured = {} - - from torch.distributed.checkpoint import state_dict as checkpoint_state_dict - - def fake_load(path, map_location=None, weights_only=None): - captured['loaded_path'] = path - captured['map_location'] = map_location - captured['weights_only'] = weights_only - return expected_state - - def fake_set_optimizer_state_dict(model_arg, optimizer_arg, optim_state_dict, *, options=None): - captured['model'] = model_arg - captured['optimizer'] = optimizer_arg - captured['optim_state_dict'] = optim_state_dict - captured['options'] = options - - monkeypatch.setattr(torch, 'load', fake_load) - monkeypatch.setattr(checkpoint_state_dict, 'set_optimizer_state_dict', fake_set_optimizer_state_dict) - - strategy.load_optimizer_checkpoint(model, optimizer, str(optimizer_path)) - - assert captured['loaded_path'] == str(optimizer_path) - assert captured['map_location'] == 'cpu' - assert captured['weights_only'] is True - assert captured['model'] is model - assert captured['optimizer'] is optimizer - assert captured['optim_state_dict'] == expected_state - assert captured['options'].full_state_dict is True - assert captured['options'].cpu_offload is False - assert captured['options'].broadcast_from_rank0 is True diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py deleted file mode 100644 index cd0fcfde..00000000 --- a/tests/server/model/test_twinkle_resume_routes.py +++ /dev/null @@ -1,156 +0,0 @@ -import concurrent.futures -import importlib.util -import sys -from fastapi import FastAPI -from fastapi.testclient import TestClient -from pathlib import Path -from types import ModuleType -from unittest.mock import Mock - - -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor - -ROOT = Path(__file__).resolve().parents[3] -SRC = ROOT / 'src' -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -if 'twinkle.server.common.checkpoint_factory' not in sys.modules: - fake_checkpoint_factory = ModuleType('twinkle.server.common.checkpoint_factory') - fake_checkpoint_factory.create_checkpoint_manager = lambda *args, **kwargs: None - fake_checkpoint_factory.create_training_run_manager = lambda *args, **kwargs: None - sys.modules['twinkle.server.common.checkpoint_factory'] = fake_checkpoint_factory - -from twinkle_client.types.checkpoint import ResolvedLoadPath - -_HANDLERS_PATH = SRC / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py' -_HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH) -handlers = importlib.util.module_from_spec(_HANDLERS_SPEC) -sys.modules[_HANDLERS_SPEC.name] = handlers -_HANDLERS_SPEC.loader.exec_module(handlers) - - -class _FakeCheckpointManager: - - def __init__(self, checkpoint_dir='./resolved/weights'): - self._checkpoint_dir = checkpoint_dir - - def resolve_load_path(self, path: str) -> ResolvedLoadPath: - return ResolvedLoadPath( - checkpoint_name='ckpt-1', - checkpoint_dir=self._checkpoint_dir, - is_twinkle_path=True, - 
training_run_id='run-1', - checkpoint_id='weights/ckpt-1', - ) - - -class _FakeModelManagement: - - def __init__(self): - self.model = Mock() - - async def _on_request_start(self, request): - request.state.request_id = 'req-1' - return 'token-1' - - def assert_resource_exists(self, adapter_name): - return None - - async def schedule_task_and_wait(self, task, task_type=''): - return await task() - - -def _build_test_client(monkeypatch, checkpoint_manager=None): - management = _FakeModelManagement() - checkpoint_manager = checkpoint_manager or _FakeCheckpointManager() - monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager) - - app = FastAPI() - handlers._register_twinkle_routes(app, lambda: management) - return TestClient(app), management - - -def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monkeypatch): - client, management = _build_test_client(monkeypatch) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name=None, - ) - management.model.read_training_progress.assert_not_called() - - -def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch): - client, management = _build_test_client(monkeypatch) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': 'adapter-a' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name='req-1-adapter-a', - ) - - -def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(monkeypatch): - client, management = _build_test_client(monkeypatch, checkpoint_manager=_FakeCheckpointManager(checkpoint_dir=None)) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'local-checkpoint-dir', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'local-checkpoint-dir', - adapter_name=None, - ) - - -def test_read_training_progress_route_returns_progress_and_calls_model(monkeypatch): - client, management = _build_test_client(monkeypatch) - management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12} - - response = client.post( - '/twinkle/read_training_progress', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - assert response.json()['result']['consumed_train_samples'] == 12 - management.model.read_training_progress.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name=None, - ) - management.model.load_training_state.assert_not_called() From 483778d7482eb506138343f308fe25e1d1c22b04 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 10:29:29 +0800 Subject: [PATCH 15/28] wip --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index ae4d2cb3..58f495d4 100644 --- a/.gitignore +++ b/.gitignore @@ -29,7 +29,6 @@ wheels/ /temp MANIFEST .locks/ -tmp_test_checkpoints/ # PyInstaller # Usually these files are written by a python script from a template @@ -144,7 +143,6 @@ images /custom/ 
megatron_output/ .qoder -.worktrees/ # Pytorch *.pth From 039789b72f695dd9f20fbf9a396d5bdf44f5e835 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 14:32:32 +0800 Subject: [PATCH 16/28] wip --- client_tools/client_generator.py | 3 ++- .../client/twinkle/self_host/self_congnition.py | 15 ++++++++++++--- src/twinkle/server/model/twinkle_handlers.py | 9 +++++---- src/twinkle_client/dataloader/dataloader.py | 13 +++++++++++++ .../model/multi_lora_transformers.py | 3 ++- 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index f724c7c1..3dc99eba 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -618,13 +618,14 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> None: + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" response = http_post( url=f'{self.server_url}/load_training_state', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() + return TrainingProgressResponse(**response.json()).result def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: """Read progress-only checkpoint metadata for resume-only-model flows.""" diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index e31daaba..3975f2a8 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -99,9 +99,13 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint + consumed_train_samples = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) @@ -114,6 +118,7 @@ def train(): # Step model.clip_grad_and_step() + consumed_train_samples += len(batch) # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -131,7 +136,11 @@ def train(): logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 9: Upload the checkpoint to ModelScope Hub diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index b48b4e12..250fdc57 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -349,12 +349,12 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, task_type='load')) - @app.post('/twinkle/load_training_state') + @app.post('/twinkle/load_training_state', response_model=types.TrainingProgressResponse) async def 
load_training_state( request: Request, body: types.LoadTrainingStateRequest, self: ModelManagement = Depends(self_fn), - ) -> None: + ) -> types.TrainingProgressResponse: token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -366,13 +366,14 @@ async def _task(): checkpoint_dir = ( Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() if resolved.checkpoint_dir else body.name) - self.model.load_training_state( + ret = self.model.load_training_state( checkpoint_dir, adapter_name=adapter_name, **extra_kwargs, ) + return {'result': ret} - await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) + return await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) async def read_training_progress( diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 0a067ddd..f6a24fe4 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -82,4 +82,17 @@ def __next__(self): ) response.raise_for_status() return response.json()["result"] + + + def skip_consumed_samples(self, consumed_train_samples: int): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'skip_consumed_samples', + **{'consumed_train_samples': consumed_train_samples}, + } + ) + response.raise_for_status() + return response.json()["result"] \ No newline at end of file diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 2a8afdce..37eac765 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -189,13 +189,14 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> None: + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" response = http_post( url=f'{self.server_url}/load_training_state', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() + return TrainingProgressResponse(**response.json()).result def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: """Read progress-only checkpoint metadata for resume-only-model flows.""" From 54de1a44b9f36316f2c6198b3e13f0ae63a7dc7a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:43:57 +0800 Subject: [PATCH 17/28] wip --- src/twinkle/model/transformers/transformers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index fca3909c..acda5da0 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -39,11 +39,15 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template +from twinkle.utils.logger import get_logger from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +logger = get_logger() + + @dataclass class 
OptimizerGroup(BaseOptimizerGroup): """Optimizer group for Transformers training.""" @@ -983,8 +987,10 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): if strict and not os.path.exists(optimizer_path): raise FileNotFoundError(optimizer_path) - if strict and not os.path.exists(scheduler_path): - raise FileNotFoundError(scheduler_path) + if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path): + logger.warning( + f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', + ) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: @@ -1062,7 +1068,6 @@ def load_training_state(self, checkpoint_dir, **kwargs): required_paths = { 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), - 'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'), 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), } for path in required_paths.values(): From 920ab869a7a96e065af3412fedc408d15835dd24 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:45:28 +0800 Subject: [PATCH 18/28] wip --- .../twinkle/self_host/self_congnition.py | 9 +++-- cookbook/transformers/resume_utils.py | 40 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 cookbook/transformers/resume_utils.py diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 3975f2a8..967cddbe 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -100,25 +100,28 @@ def train(): # Step 6: Optionally resume from a previous checkpoint consumed_train_samples = 0 + global_step = 0 if resume_path: logger.info(f'Resuming model weights from {resume_path}') model.load(resume_path) trainer_state = model.load_training_state(resume_path) dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for _, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() consumed_train_samples += len(batch) + global_step += 1 # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -130,10 +133,10 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - if step % 2 == 0: + if global_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint twinkle_path = model.save( diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py new file mode 100644 index 00000000..86745d0a --- /dev/null +++ b/cookbook/transformers/resume_utils.py @@ -0,0 +1,40 @@ +from pathlib import Path +from typing import Any, Optional +from twinkle 
import get_logger + + +logger = get_logger() + + +def resume_from_checkpoint( + model: Any, + dataloader: Any, + checkpoint_path: Path, + *, + resume_only_model: bool, + ignore_data_skip: bool, + adapter_name: Optional[str] = None) -> int: + checkpoint_dir = str(checkpoint_path) + model_kwargs = {} + if adapter_name is not None: + model_kwargs['adapter_name'] = adapter_name + + if resume_only_model: + if ignore_data_skip: + logger.info('Resumed weights only and restarted progress from step 0.') + return 0 + progress = model.read_training_progress(checkpoint_dir, **model_kwargs) + dataloader.skip_consumed_samples(progress['consumed_train_samples']) + optimizer_group_name = adapter_name if adapter_name is not None else '' + model.optimizer_group[optimizer_group_name].cur_step = progress['cur_step'] + model.optimizer_group[optimizer_group_name].gradient_accumulation_steps = progress[ + 'gradient_accumulation_steps'] + consumed_train_samples = int(progress['consumed_train_samples']) + logger.info(f'Skipped {consumed_train_samples} consumed samples.') + return consumed_train_samples + + trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.') + return consumed_train_samples From ffd630484a2e4d196bba9afd39a4a8f1ee10981f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:46:32 +0800 Subject: [PATCH 19/28] lint --- src/twinkle/model/transformers/transformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index acda5da0..4e641217 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -39,11 +39,10 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template -from twinkle.utils.logger import get_logger from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm - +from twinkle.utils.logger import get_logger logger = get_logger() @@ -989,8 +988,7 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): raise FileNotFoundError(optimizer_path) if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path): logger.warning( - f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', - ) + f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', ) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: From 582bd41c5e1a059cad7f96d3d2b99dd87a04d0e9 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:53:09 +0800 Subject: [PATCH 20/28] wip --- client_tools/client_generator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 3dc99eba..c9f43fab 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
+# Copyright (c) ModelScope Contributors. All rights reserved. import ast from pathlib import Path from typing import Dict, List, Set, Tuple @@ -869,5 +869,4 @@ def apply_patch(self, patch_cls: str, **kwargs) -> None: generate_samplers() print('\n' + '=' * 60) - print('\nAll client code generation complete!\n') - + print('\n✓ All client code generation complete!\n') From 9cb6106b2fc76fccd4c5d6f6cbab7219f948429f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 09:27:29 +0800 Subject: [PATCH 21/28] wip --- ...00\344\275\263\345\256\236\350\267\265.md" | 24 +++---------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index f0112042..ad78e28d 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -410,19 +410,9 @@ def train(): model.set_lr_scheduler('LinearLR') # 恢复训练(如有检查点) - resume_from_checkpoint = resume_path - resume_only_model = False - ignore_data_skip = False - if resume_from_checkpoint: - logger.info(f'Resuming training from {resume_from_checkpoint}') - model.load(name=resume_from_checkpoint) - - if not resume_only_model: - trainer_state = model.load_training_state(resume_from_checkpoint) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - elif not ignore_data_skip: - progress = model.read_training_progress(resume_from_checkpoint) - dataloader.skip_consumed_samples(progress['consumed_train_samples']) + if resume_path: + logger.info(f'Resuming training from {resume_path}') + model.load(resume_path, load_optimizer=True) logger.info(model.get_train_configs()) @@ -455,14 +445,6 @@ if __name__ == '__main__': - 支持断点续训、检查点管理 - 可动态切换 LoRA 适配器、损失函数、优化器等组件 -Resume 模式: - -- `resume_from_checkpoint=None`:开始新的训练任务。 -- `resume_only_model=False`:恢复权重、optimizer、scheduler、scaler、RNG 和进度元数据。 -- `resume_only_model=True` 且 `ignore_data_skip=False`:恢复权重,读取进度元数据,并跳过已消费样本。 -- `resume_only_model=True` 且 `ignore_data_skip=True`:只恢复权重,训练步数和数据进度从 0 开始。 -- `skip_consumed_samples(...)` 不适用于 iterable / streaming dataset。 - ### 3.2 Tinker Client:简洁即用 Tinker 是一个轻量级训练 API。Twinkle 对 Tinker 客户端提供完整支持,几行代码就能拉起训练。已有 Tinker 代码的项目可以直接迎移到 Twinkle 服务端。 From c0cf72e8089e2c40596b2d28d7ee7ea863ba613e Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 10:27:01 +0800 Subject: [PATCH 22/28] wip --- cookbook/transformers/fsdp2.py | 86 ++++++++++++++++++++------- cookbook/transformers/resume_utils.py | 19 ++++-- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 7b6bd2a8..47512629 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm @@ -8,21 +10,39 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, fsdp_size=2, dp=4 -device_mesh = DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +from resume_utils import resume_from_checkpoint logger = 
get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +FSDP_SIZE = 2 +DP_SIZE = 4 +BATCH_SIZE = 8 +LEARNING_RATE = 1e-4 +GRADIENT_ACCUMULATION_STEPS = 2 +LOG_INTERVAL = 20 +EVAL_INTERVAL = 40 + +OUTPUT_DIR = './output/fsdp2' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +# Construct a device_mesh +device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + def eval(model): # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(100))) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) for step, batch in tqdm(enumerate(dataloader)): model.forward_only(inputs=batch) model.calculate_loss() @@ -32,29 +52,41 @@ def eval(model): def train(): # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset dataset.encode() # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = TransformersModel(model_id=MODEL_ID) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + consumed_train_samples = 0 + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=checkpoint_path, + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name=ADAPTER_NAME, + ) + logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) @@ -67,18 +99,32 @@ def train(): model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + consumed_train_samples += BATCH_SIZE + cur_step = model.optimizer_group[ADAPTER_NAME].cur_step + if cur_step % LOG_INTERVAL == 0: # Print 
metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: metrics = eval(model) logger.info(f'Eval metric: {metrics}') - metrics['step'] = step + metrics['step'] = cur_step if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') + model.save( + f'checkpoint-{cur_step}', + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + model.save( + 'last-checkpoint', + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) if __name__ == '__main__': diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py index 86745d0a..3d2d197b 100644 --- a/cookbook/transformers/resume_utils.py +++ b/cookbook/transformers/resume_utils.py @@ -15,24 +15,33 @@ def resume_from_checkpoint( ignore_data_skip: bool, adapter_name: Optional[str] = None) -> int: checkpoint_dir = str(checkpoint_path) + adapter_name = adapter_name or '' model_kwargs = {} - if adapter_name is not None: + if adapter_name != '': + # Load adapter checkpoint. model_kwargs['adapter_name'] = adapter_name + model.load( + name=checkpoint_path.name, + output_dir=str(checkpoint_path.parent), + **model_kwargs, + ) if resume_only_model: + # Only load model weights, optionally skip data. if ignore_data_skip: logger.info('Resumed weights only and restarted progress from step 0.') return 0 progress = model.read_training_progress(checkpoint_dir, **model_kwargs) + # Skip consumed samples in dataloader and move optimizer to the right step. dataloader.skip_consumed_samples(progress['consumed_train_samples']) - optimizer_group_name = adapter_name if adapter_name is not None else '' - model.optimizer_group[optimizer_group_name].cur_step = progress['cur_step'] - model.optimizer_group[optimizer_group_name].gradient_accumulation_steps = progress[ - 'gradient_accumulation_steps'] + model.optimizer_group[adapter_name].cur_step = progress['cur_step'] + model.optimizer_group[adapter_name].gradient_accumulation_steps = progress['gradient_accumulation_steps'] + consumed_train_samples = int(progress['consumed_train_samples']) logger.info(f'Skipped {consumed_train_samples} consumed samples.') return consumed_train_samples + # Load full training state, including model weights, optimizer states, and training progress. 
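+    # Requires trainer_state.json, optimizer.pt and rng_state.pt in checkpoint_dir;
+    # restores optimizer, scheduler, scaler and RNG state alongside the progress metadata.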
trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) consumed_train_samples = int(trainer_state['consumed_train_samples']) From 505a75cbeb0564a9c914d85749161333405537b3 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 10:41:08 +0800 Subject: [PATCH 23/28] wip --- cookbook/transformers/fsdp2.py | 76 +++++++++++++-------------- cookbook/transformers/resume_utils.py | 24 +++++---- 2 files changed, 53 insertions(+), 47 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 47512629..45dd8ac1 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -16,6 +16,9 @@ MODEL_ID = 'ms://Qwen/Qwen3.5-4B' DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' FSDP_SIZE = 2 DP_SIZE = 4 BATCH_SIZE = 8 @@ -23,6 +26,8 @@ GRADIENT_ACCUMULATION_STEPS = 2 LOG_INTERVAL = 20 EVAL_INTERVAL = 40 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 OUTPUT_DIR = './output/fsdp2' RESUME_FROM_CHECKPOINT = None @@ -36,29 +41,34 @@ twinkle.initialize(mode='local', global_device_mesh=device_mesh) -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, consumed_train_samples: int): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): model.forward_only(inputs=batch) model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() + dataset = build_dataset(TRAIN_SAMPLES) # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel @@ -68,7 +78,7 @@ def train(): lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` @@ -91,40 +101,30 @@ def train(): # 
Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 + optimizer_group = model.optimizer_group[ADAPTER_NAME] + best_loss = float('inf') # lora: 8G * 8 # full: 18G * 8 - for step, batch in enumerate(dataloader): + for batch in dataloader: # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() consumed_train_samples += BATCH_SIZE - cur_step = model.optimizer_group[ADAPTER_NAME].cur_step + cur_step = optimizer_group.cur_step if cur_step % LOG_INTERVAL == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: - metrics = eval(model) + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = cur_step - if loss_metric > float(metrics['loss']): - model.save( - f'checkpoint-{cur_step}', - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - ) - loss_metric = float(metrics['loss']) - model.save( - 'last-checkpoint', - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - ) + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', consumed_train_samples) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', consumed_train_samples) if __name__ == '__main__': diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py index 3d2d197b..fdacf075 100644 --- a/cookbook/transformers/resume_utils.py +++ b/cookbook/transformers/resume_utils.py @@ -1,11 +1,18 @@ from pathlib import Path from typing import Any, Optional + from twinkle import get_logger logger = get_logger() +def _build_model_kwargs(adapter_name: str) -> dict: + if not adapter_name: + return {} + return {'adapter_name': adapter_name} + + def resume_from_checkpoint( model: Any, dataloader: Any, @@ -14,12 +21,11 @@ def resume_from_checkpoint( resume_only_model: bool, ignore_data_skip: bool, adapter_name: Optional[str] = None) -> int: - checkpoint_dir = str(checkpoint_path) adapter_name = adapter_name or '' - model_kwargs = {} - if adapter_name != '': + checkpoint_dir = str(checkpoint_path) + model_kwargs = _build_model_kwargs(adapter_name) + if model_kwargs: # Load adapter checkpoint. - model_kwargs['adapter_name'] = adapter_name model.load( name=checkpoint_path.name, output_dir=str(checkpoint_path.parent), @@ -33,17 +39,17 @@ def resume_from_checkpoint( return 0 progress = model.read_training_progress(checkpoint_dir, **model_kwargs) # Skip consumed samples in dataloader and move optimizer to the right step. 
-        dataloader.skip_consumed_samples(progress['consumed_train_samples'])
-        model.optimizer_group[adapter_name].cur_step = progress['cur_step']
-        model.optimizer_group[adapter_name].gradient_accumulation_steps = progress['gradient_accumulation_steps']
-        consumed_train_samples = int(progress['consumed_train_samples'])
+        consumed_train_samples = int(progress['consumed_train_samples'])
+        dataloader.skip_consumed_samples(consumed_train_samples)
+        optimizer_group = model.optimizer_group[adapter_name]
+        optimizer_group.cur_step = progress['cur_step']
+        optimizer_group.gradient_accumulation_steps = progress['gradient_accumulation_steps']
         logger.info(f'Skipped {consumed_train_samples} consumed samples.')
         return consumed_train_samples
 
     # Load full training state, including model weights, optimizer states, and training progress.
     # Requires trainer_state.json, optimizer.pt and rng_state.pt in checkpoint_dir;
     # restores optimizer, scheduler, scaler and RNG state alongside the progress metadata.
     trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs)
-    dataloader.skip_consumed_samples(trainer_state['consumed_train_samples'])
     consumed_train_samples = int(trainer_state['consumed_train_samples'])
+    dataloader.skip_consumed_samples(consumed_train_samples)
     logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.')
     return consumed_train_samples

From a222b5b169b67bbad270105ecf8f37fa0e631190 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Wed, 1 Apr 2026 10:56:27 +0800
Subject: [PATCH 24/28] fix

---
 src/twinkle/dataloader/dataloader.py    | 15 +++++++++++++++
 src/twinkle/dataloader/retry_sampler.py | 12 +++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py
index e2ef57ce..0725ee47 100644
--- a/src/twinkle/dataloader/dataloader.py
+++ b/src/twinkle/dataloader/dataloader.py
@@ -1,5 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import copy
+import os
 import warnings
 from functools import partial
 from typing import Callable, Optional, Type, Union
@@ -56,6 +57,7 @@ def __init__(self,
         self._skip_samples = 0
         self._base_batch_sampler = None
         self._base_sampler = None
+        self._retry_sampler_seed = self._resolve_retry_sampler_seed()
         self._set_work_init_fn()
 
     def _set_work_init_fn(self):
@@ -65,6 +67,17 @@ def _set_work_init_fn(self):
             num_workers=num_workers,
             rank=self.device_mesh.data_rank if self.device_mesh else 0)
 
+    @staticmethod
+    def _resolve_retry_sampler_seed() -> int:
+        env_seed = os.environ.get('TWINKLE_SEED')
+        if env_seed is not None:
+            return int(env_seed)
+        try:
+            from twinkle.infra import _seed
+            return int(_seed)
+        except Exception:
+            return 42
+
     @remote_function()
     def __len__(self):
         self._lazy_init_dataloader()
@@ -145,6 +158,7 @@ def _rebuild_sampler_stack(self):
                 self._base_sampler,
                 self.dataset,
                 max_retries=self.max_retries,
+                seed=self._retry_sampler_seed,
             )
             self.dataloader.batch_sampler = DeviceMeshSampler(
                 batch_sampler,
@@ -158,4 +172,5 @@ def _rebuild_sampler_stack(self):
             self.dataset,
             max_retries=self.max_retries,
             skip_samples=self._skip_samples,
+            seed=self._retry_sampler_seed,
         )
diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py
index 43307b1a..27ef3819 100644
--- a/src/twinkle/dataloader/retry_sampler.py
+++ b/src/twinkle/dataloader/retry_sampler.py
@@ -14,11 +14,17 @@ class RetrySampler(Sampler):
         max_retries: The maximum number of retries.
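+        skip_samples: Number of already consumed samples to skip when resuming.
+        seed: Seed for the fallback permutation used to top up samples once the
+            original sampler is exhausted before enough valid samples are emitted.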
""" - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20, skip_samples: int = 0): + def __init__(self, + original_sampler: Sampler, + dataset: Dataset, + max_retries=20, + skip_samples: int = 0, + seed: int = 42): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries self.skip_samples = skip_samples + self.seed = int(seed) def __iter__(self): emitted = 0 @@ -48,9 +54,9 @@ def __iter__(self): if emitted >= target_total: return - for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): + for idx in np.random.RandomState(self.seed).permutation(len(self.dataset)).tolist(): if emitted >= target_total: - raise StopIteration + return for _ in range(self.max_retries): try: # Skip None values and raises From 7499e00f500507b9aacd59432aaa36037abf0576 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 11:09:50 +0800 Subject: [PATCH 25/28] wip --- src/twinkle/dataloader/retry_sampler.py | 2 +- src/twinkle/model/transformers/transformers.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 27ef3819..4d8c92e0 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -49,7 +49,7 @@ def __iter__(self): traceback.print_exc() continue else: - raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') + raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.') if emitted >= target_total: return diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 4e641217..060990ae 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -996,11 +996,11 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): if hasattr(self.strategy, 'load_optimizer_checkpoint'): self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) else: - state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True) optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: - state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False) + state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=True) optimizer_config.lr_scheduler.load_state_dict(state_dict) def _load_scaler_state(self, scaler_path, **kwargs): @@ -1009,7 +1009,7 @@ def _load_scaler_state(self, scaler_path, **kwargs): if optimizer_config.scaler is None: raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}') - scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=False) + scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=True) optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict']) optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False) From cd0b09411ed3cf902ca4662b30a204d783709516 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 14:40:17 +0800 Subject: [PATCH 26/28] doc --- .../Components/Model/TransformersModel.md | 13 ++++ docs/source_en/Usage Guide/Quick-Start.md | 67 +++++++++++++++++++ .../Server and Client/Twinkle-Client.md | 49 +++++++++----- ...53\351\200\237\345\274\200\345\247\213.md" | 66 
++++++++++++++++++ ...le\345\256\242\346\210\267\347\253\257.md" | 49 +++++++++----- .../TransformersModel.md" | 13 ++++ 6 files changed, 221 insertions(+), 36 deletions(-) diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index f10b0351..ba4ba1b4 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## Checkpoint and Resume + +`TransformersModel.save()` can save either weights only or a resumable training checkpoint. + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. +- `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights. +- `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`. +- `model.load_training_state(checkpoint_dir, ...)` restores optimizer-related state and returns the training progress dictionary. + +For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `load_training_state(...)` to restore optimizer state and training progress. + +For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 6a05a53f..d8b72ace 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -230,6 +230,71 @@ When running, you need to launch training like this: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### Resume from Checkpoint + +The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py` together with `cookbook/transformers/resume_utils.py`. + +When saving a checkpoint intended for resumption, save both model weights and training progress: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. 
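+
+Before resuming, you can inspect what a checkpoint recorded: `read_training_progress` reads this metadata back without restoring any optimizer state. A minimal sketch (the checkpoint path and adapter name are illustrative):
+
+```python
+progress = model.read_training_progress('./output/fsdp2/last-checkpoint', adapter_name='default')
+print(progress['cur_step'], progress['gradient_accumulation_steps'], progress['consumed_train_samples'])
+```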
+ +To resume training, restore the checkpoint before entering the main loop: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +This helper provides two common resume modes: + +- Full resume: restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. +- Weights-only resume: restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. + +When `RESUME_ONLY_MODEL=True`, `IGNORE_DATA_SKIP=False` still skips already consumed samples based on `trainer_state.json`. If you want to reload weights but restart the dataset from the beginning, set `IGNORE_DATA_SKIP=True`. + +The flow above is intended for LoRA / adapter training. For full-parameter training, restore model weights by passing the checkpoint path as `model_id` when constructing `TransformersModel`, instead of calling `model.load(...)`. For example: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +In other words: + +- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.load(...)` or `resume_from_checkpoint(...)`. +- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `load_training_state(...)` to restore optimizer state and training progress. + ### Ray Training [Ray](https://github.com/ray-project/ray) is a commonly used scheduling middleware framework for multi-machine model training and inference scenarios. It provides additional optimizations for multi-model, multi-device execution and resource management, and supports integration with Kubernetes systems for production deployment. These characteristics make it particularly suitable for complex training scenarios such as RL and GKD. @@ -413,6 +478,8 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. +Checkpoint resumption is also supported in client-server training. The recommended flow is to restore weights with `model.load(resume_path)`, then restore optimizer and progress metadata with `model.load_training_state(resume_path)`, and finally call `dataloader.skip_consumed_samples(...)`. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. + Suppose we start a service using eight GPUs. 
First, we need to start the Ray cluster: ```shell diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 66d98eec..e373e30c 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: Training loop -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # Forward propagation + backward propagation - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # Gradient clipping - model.clip_grad_norm(1.0) - - # Optimizer update - model.step() - - # Zero gradients - model.zero_grad() - - # Learning rate scheduling - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: Save checkpoint -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved checkpoint: {twinkle_path}") # Step 8: Upload to ModelScope Hub (optional) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +For checkpoint resumption, the recommended client-side flow is: + +1. Query the server for an existing checkpoint path with `client.list_checkpoints(...)` or `client.get_latest_checkpoint_path(...)`. +2. Call `model.load(resume_path)` to restore adapter weights. +3. Call `model.load_training_state(resume_path)` to restore optimizer, scheduler, RNG, and progress metadata. +4. Call `dataloader.skip_consumed_samples(...)` with `consumed_train_samples` from the returned trainer state. + +This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_congnition.py`. 
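+
+Put together, the four steps form a short sketch (the checkpoint lookup in step 1 is illustrative; use whichever of the two query helpers fits your deployment):
+
+```python
+resume_path = client.get_latest_checkpoint_path()  # step 1; may be None on a fresh run
+if resume_path:
+    model.load(resume_path)  # step 2: restore adapter weights
+    # step 3: restore optimizer, scheduler, RNG, and progress metadata
+    trainer_state = model.load_training_state(resume_path)
+    # step 4: skip samples that the previous run already consumed
+    dataloader.skip_consumed_samples(trainer_state['consumed_train_samples'])
+```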
+ ## Differences with Megatron Backend When using the Megatron backend, the main differences in client code: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 0b8e386a..a7d98732 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -231,6 +231,71 @@ if __name__ == '__main__': CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### 断点续训 + +上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 + +如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 `trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。 + +恢复训练时,建议在进入主训练循环之前先加载 checkpoint: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +这个辅助函数覆盖了两种常见恢复模式: + +- 完整续训:恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 +- 仅恢复权重:只加载模型权重,不恢复优化器等训练状态。适合希望沿用参数初始化、但重新开始优化过程的场景。 + +当 `RESUME_ONLY_MODEL=True` 且 `IGNORE_DATA_SKIP=False` 时,仍会根据 `trainer_state.json` 跳过已训练过的数据;如果你只想加载权重、但从数据集开头重新训练,可以把 `IGNORE_DATA_SKIP=True`。 + +上面的恢复流程默认针对 LoRA / adapter 训练。对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 `model_id` 设为 checkpoint 路径,而不是再调用 `model.load(...)`。例如: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +也就是说: + +- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.load(...)` 或 `resume_from_checkpoint(...)` 恢复。 +- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + ### Ray训练 [Ray](https://github.com/ray-project/ray)是多机模型训练和推理场景中常用的调度中间件框架。它针对多模型、多设备的执行和资源管理进行了额外优化, @@ -412,6 +477,7 @@ python train.py ``` ### 远程训练 +client-server 训练场景同样支持断点续训。推荐流程是先通过 `model.load(resume_path)` 恢复权重,再通过 `model.load_training_state(resume_path)` 恢复优化器和训练进度元数据,最后调用 `dataloader.skip_consumed_samples(...)` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 Twinkle的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行lora训练,这样可以极大减小服务端部署成本。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" 
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index fd81ac1b..5d4bafe7 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: 训练循环 -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # 前向传播 + 反向传播 - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # 梯度裁剪 - model.clip_grad_norm(1.0) - - # 优化器更新 - model.step() - - # 梯度清零 - model.zero_grad() - - # 学习率调度 - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: 保存检查点 -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved checkpoint: {twinkle_path}") # Step 8: 上传到 ModelScope Hub(可选) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +Twinkle Client 场景下,推荐的断点续训流程是: + +1. 先通过 `client.list_checkpoints(...)` 或 `client.get_latest_checkpoint_path(...)` 获取已有 checkpoint 路径。 +2. 调用 `model.load(resume_path)` 恢复 adapter 权重。 +3. 调用 `model.load_training_state(resume_path)` 恢复优化器、调度器、随机数状态和训练进度元数据。 +4. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.skip_consumed_samples(...)`,跳过已经训练过的数据。 + +完整示例可直接参考 `cookbook/client/twinkle/self_host/self_congnition.py`。 + ## Megatron 后端的差异 使用 Megatron 后端时,客户端代码的主要差异: diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index b7b9cf0f..bb494131 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) 
model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## 检查点保存与续训 + +`TransformersModel.save()` 既可以只保存权重,也可以保存可续训的训练检查点。 + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 会在保存权重的同时,保存优化器、学习率调度器、梯度缩放器、随机数状态以及 `trainer_state.json`。 +- `model.load(name, output_dir=..., adapter_name=...)` 用于恢复 LoRA / adapter 模型权重。 +- `model.read_training_progress(checkpoint_dir, ...)` 用于读取 checkpoint 中的训练进度元数据,例如 `cur_step`、`gradient_accumulation_steps` 和 `consumed_train_samples`。 +- `model.load_training_state(checkpoint_dir, ...)` 用于恢复优化器等训练状态,并返回训练进度字典。 + +对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + +如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 From abf2c2f6b5f1dd027dc0772b883d4c430cee4370 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 17:08:07 +0800 Subject: [PATCH 27/28] wip --- .../model/transformers/transformers.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 060990ae..5bc2cb9f 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -43,6 +43,7 @@ from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm from twinkle.utils.logger import get_logger +from twinkle.utils.platforms import Platform logger = get_logger() @@ -1019,12 +1020,12 @@ def _get_training_rng_state(self): 'numpy_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if hasattr(torch, 'npu') and torch.npu.is_available(): - state['device_type'] = 'npu' - state['device_rng_state'] = torch.npu.get_rng_state() - elif torch.cuda.is_available(): - state['device_type'] = 'cuda' - state['device_rng_state'] = torch.cuda.get_rng_state_all() + + device_prefix = Platform.device_prefix() + device_module = getattr(torch, device_prefix, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + state['device_type'] = device_prefix + state['device_rng_state'] = device_module.get_rng_state() else: state['device_type'] = 'cpu' state['device_rng_state'] = None @@ -1038,10 +1039,10 @@ def _load_rng_state(self, rng_path): device_type = rng_state.get('device_type') device_rng_state = rng_state.get('device_rng_state') - if device_type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available() and device_rng_state is not None: - torch.npu.set_rng_state(device_rng_state) - elif device_type == 'cuda' and torch.cuda.is_available() and device_rng_state is not None: - torch.cuda.set_rng_state_all(device_rng_state) + if device_type != 'cpu' and device_rng_state is not None: + device_module = getattr(torch, device_type, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + device_module.set_rng_state(device_rng_state) @remote_function() def read_training_progress(self, checkpoint_dir, **kwargs): From 8bf7a6ad0e975326771664942e280e19a01f5833 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 17:31:12 +0800 Subject: [PATCH 28/28] lint --- src/twinkle/model/transformers/transformers.py | 1 - tests/dataloader/test_dataloader.py | 14 +++++++------- tests/dataloader/test_sampler.py | 10 +++++----- 3 files changed, 12 
insertions(+), 13 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5bc2cb9f..e2ee5243 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -43,7 +43,6 @@ from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm from twinkle.utils.logger import get_logger -from twinkle.utils.platforms import Platform logger = get_logger() diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 82a4f41b..edad0dd3 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -7,6 +7,13 @@ from torch.utils.data import Dataset as TorchDataset from torch.utils.data import IterableDataset as TorchIterableDataset +import twinkle +from twinkle import DeviceMesh +from twinkle.data_format import Message +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta, IterableDataset +from twinkle.processor import InputProcessor + class _NoOpProcessPoolExecutor: @@ -19,13 +26,6 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -import twinkle -from twinkle import DeviceMesh -from twinkle.data_format import Message -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta, IterableDataset -from twinkle.processor import InputProcessor - twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index d5c97dbc..d90c8725 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -7,6 +7,11 @@ from torch.utils.data import Dataset as TorchDataset from torch.utils.data import RandomSampler, SequentialSampler +import twinkle +from twinkle import DeviceMesh +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta + class _NoOpProcessPoolExecutor: @@ -19,11 +24,6 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -import twinkle -from twinkle import DeviceMesh -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta - twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'