Skip to content
6 changes: 3 additions & 3 deletions cookbook/megatron/tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor
# Construct a device_mesh, tp=pp=cp=2, dp=1
# Construct a device_mesh, tp=pp=dp=2
device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
Expand All @@ -19,7 +19,7 @@
def eval(model):
# 100 Samples
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.encode()
dataloader = DataLoader(dataset=dataset, batch_size=16)
Expand All @@ -33,7 +33,7 @@ def train():
# 1000 samples
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
# Set template to prepare encoding
dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
# Preprocess the dataset to standard format
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
# Encode dataset
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ transformers = [
"torchvision",
]
kernels = ["kernels"]
megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"]
megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"]
vllm = ["vllm>=0.11"]
ray = ["ray[serve]"]
tinker = ["tinker==0.14.0"]
Expand Down
4 changes: 3 additions & 1 deletion src/twinkle/checkpoint_engine/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional, TypedDict

if TYPE_CHECKING:
import torch
Expand Down Expand Up @@ -38,6 +38,8 @@ class CheckpointEngine(ABC):
>>> engine.finalize()
"""

rank: Optional[int] = None

@abstractmethod
def prepare(self) -> dict[str, Any]:
"""Prepare the checkpoint engine before weight synchronization.
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/model/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

from twinkle import Platform, torch_util
from twinkle.data_format import InputFeature, ModelOutput
Expand Down
Loading
Loading