diff --git a/README.md b/README.md index 046425d7..cd7eccfd 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Or use ModelScope's [official image](https://www.modelscope.cn/docs/intro/enviro ## Changelog +- ๐ŸŽ‰2026-03-28 Support DPO training with both Transformers and Megatron backends. See [dpo_full.py](cookbook/rl/dpo_full.py) and [dpo_lora.py](cookbook/rl/dpo_lora.py). - ๐ŸŽ‰2026-03-24 Twinkle Web site is now live at https://modelscope.github.io/twinkle-web/ - ๐ŸŽ‰2026-03-19 Support GKD training ๏ผŒplease refer to this [cookbook](cookbook/rl/gkd_on_policy.py). - ๐ŸŽ‰2026-02-13 Initial version of Twinkleโœจ released, including SFT/PT/RL support for text models. @@ -137,8 +138,8 @@ supported on Twinkleโœจ framework. | qwen3 series | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | 0.6B/1.7B/4B/8B/14B | transformers>=4.51 | โœ” | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) | | | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen3-32B) | 0.6B/1.7B/4B/8B/14B/32B | transformers>=4.51 | โœ” | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | | qwen3_moe series | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | 30B-A3B/A3B-Base,235B-A22B | transformers>=4.51 | โœ” | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) | -| qwen3.5 moe series | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. | transformers>=5.20 | โœ” | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | -| qwen3.5 series | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.20 | โœ” | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | +| qwen3.5 moe series | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. 
| transformers>=5.2.0 | โœ” | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | +| qwen3.5 series | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.2.0 | โœ” | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | | qwen2 series | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | 0.5B/1.5B/7B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | | | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | 0.5B/1.5B/7B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) | | | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | 0.5B/1.5B/3B/7B/14B/32B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | diff --git a/README_ZH.md b/README_ZH.md index 5b4d1191..e404508a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -91,6 +91,7 @@ Twinkleโœจๆ”ฏๆŒ็›ธๅŒ็š„็ฎ—ๆณ•ๆŽฅๅฃ่ฟ่กŒๅœจๅ•GPUใ€torchrunๅคšๆœบใ€Rayใ€Cl ## ๆ›ดๆ–ฐๆ—ฅๅฟ— +๐ŸŽ‰2026-03-28 ๆ”ฏๆŒ DPO ่ฎญ็ปƒ๏ผŒๅŒๆ—ถๆ”ฏๆŒ Transformers ๅ’Œ Megatron ๅŽ็ซฏใ€‚ๅ‚่€ƒ [dpo_full.py](cookbook/rl/dpo_full.py) ๅ’Œ [dpo_lora.py](cookbook/rl/dpo_lora.py)ใ€‚ ๐ŸŽ‰2026-03-24 Twinkle ็ซ™็‚นไธŠ็บฟ๏ผŒ่ฎฟ้—ฎๅœฐๅ€ https://modelscope.github.io/twinkle-web/ ๐ŸŽ‰2026-03-19 ๆ”ฏๆŒGKD่’ธ้ฆ่ƒฝๅŠ›๏ผŒๅ‚่€ƒ[cookbook](cookbook/rl/gkd_on_policy.py)ใ€‚ ๐ŸŽ‰2026-02-13 Twinkleโœจ ๅˆๅง‹็‰ˆๆœฌๅ‘ๅธƒ๏ผŒๆ”ฏๆŒๆ–‡ๆœฌๆจกๅž‹็š„SFT/PT/RL่ฎญ็ปƒใ€‚ๆˆ‘ไปฌ่ฟ˜้€š่ฟ‡ๅ…ผๅฎนTinker็š„API๏ผŒๅœจ้ญ”ๆญ็คพๅŒบไธŠๆไพ›ไบ†ๆ— ๆœๅŠกๅ™จ่ฎญ็ปƒๅŠŸ่ƒฝใ€‚ @@ -120,8 +121,8 @@ Twinkleโœจๆ”ฏๆŒ็›ธๅŒ็š„็ฎ—ๆณ•ๆŽฅๅฃ่ฟ่กŒๅœจๅ•GPUใ€torchrunๅคšๆœบใ€Rayใ€Cl | qwen3 ๅ…จ็ณปๅˆ— | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | 0.6B/1.7B/4B/8B/14B | transformers>=4.51 | โœ” | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) | | | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen3-32B) | 0.6B/1.7B/4B/8B/14B/32B | transformers>=4.51 | โœ” | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | | qwen3_moe ๅ…จ็ณปๅˆ— | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | 30B-A3B/A3B-Base๏ผŒ235B-A22B | transformers>=4.51 | โœ” | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) | -| qwen3.5 moe ๅ…จ็ณปๅˆ— | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. | transformers>=5.20 | โœ” | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | -| qwen3.5 ๅ…จ็ณปๅˆ— | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.20 | โœ” | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | +| qwen3.5 moe ๅ…จ็ณปๅˆ— | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. 
| transformers>=5.2.0 | โœ” | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | +| qwen3.5 ๅ…จ็ณปๅˆ— | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.2.0 | โœ” | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | | qwen2 ๅ…จ็ณปๅˆ— | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | 0.5B/1.5B/7B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | | | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | 0.5B/1.5B/7B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) | | | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | 0.5B/1.5B/3B/7B/14B/32B/72B | transformers>=4.37 | โœ” | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo_full.py new file mode 100644 index 00000000..8c3e5a6f --- /dev/null +++ b/cookbook/rl/dpo_full.py @@ -0,0 +1,258 @@ +"""DPO (Direct Preference Optimization) Full-Parameter Training via Ray. + +Off-policy preference alignment: trains the model to prefer chosen responses +over rejected responses using preference data, without explicit reward modeling. + +Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag. + +Pipeline: + 1. Load preference dataset with chosen/rejected pairs. + 2. Encode positive and negative separately. + 3. Compute reference model log probabilities (frozen). + 4. Train policy model using DPO loss (full-parameter, no LoRA). + +Architecture (Ray): + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ Driver (CPU) โ”‚ + โ”‚ dataloader โ”€โ”€โ–บ batched preference pairs โ”‚ + โ”‚ ref_model.forward_only() โ”€โ”€โ–บ reference log probs โ”‚ + โ”‚ policy_model.forward_backward() โ”€โ”€โ–บ DPO loss + gradient โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ โ”‚ โ”‚ + DataLoader RefModel (frozen) PolicyModel (trainable) + (ref GPUs) (policy GPUs) + +DPO data format (after preprocessing): + - positive: List[Trajectory] - chosen responses + - negative: List[Trajectory] - rejected responses + +For SimPO/ORPO variants that don't require a reference model, +set REF_MODEL_GPUS=0 to skip reference model computation. 
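To make the objective being optimized concrete, here is a minimal, self-contained sketch of the sigmoid DPO loss on sequence-level log probabilities. It is illustrative only; the actual implementation used by this script is `twinkle.loss.DPOLoss` (src/twinkle/loss/dpo.py), and the tensor values below are made up.

```python
# Illustrative stand-alone sketch of the sigmoid DPO objective (not the twinkle code).
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Log-ratios of the trainable policy against the frozen reference.
    chosen_logratios = policy_chosen - ref_chosen
    rejected_logratios = policy_rejected - ref_rejected
    # Preference margin scaled by beta; loss = -log(sigmoid(beta * margin)).
    margin = chosen_logratios - rejected_logratios
    return -F.logsigmoid(beta * margin).mean()

# Toy sequence-level log probs for two preference pairs.
loss = dpo_sigmoid_loss(
    policy_chosen=torch.tensor([-12.0, -8.0]),
    policy_rejected=torch.tensor([-15.0, -9.5]),
    ref_chosen=torch.tensor([-13.0, -8.5]),
    ref_rejected=torch.tensor([-14.0, -9.0]),
)
print(float(loss))  # smaller when the policy prefers chosen more than the reference does
```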
+ +Environment variables (all optional): + USE_MEGATRON โ€“ Use Megatron backend (default: 0, use Transformers) + MODEL_ID โ€“ (default: ms://Qwen/Qwen3-4B) + DATASET_ID โ€“ (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) + MODEL_GPUS โ€“ GPUs for policy model (default: 4) + REF_MODEL_GPUS โ€“ GPUs for reference model (default: 4, 0 to disable) + BATCH_SIZE โ€“ global batch size (preference pairs) (default: 8) + MAX_STEPS โ€“ total optimization steps (default: 1000) + LR โ€“ learning rate (default: 1e-5) + DPO_BETA โ€“ DPO temperature parameter (default: 0.1) + LOSS_TYPE โ€“ DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) + SAVE_STEPS โ€“ checkpoint save interval (default: 100) + MAX_LENGTH โ€“ max sequence length (default: 2048) +""" + +import os +from typing import Any, Dict, List + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss +from twinkle.metric import DPOMetric +from twinkle.preprocessor import EmojiDPOProcessor +from twinkle.processor import InputProcessor + +logger = get_logger() + +# โ”€โ”€ Configuration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization +LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) +SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') + + +def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(6000))) + dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': SYSTEM_PROMPT, + } + ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode(load_from_cache_file=True) + return dataset + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. + + Args: + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) + + Returns: + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. 
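As an illustration, the interleaving behaves roughly like the toy example below (field names other than 'positive'/'negative' are hypothetical, and it assumes the DataLoader hands each DP rank a contiguous slice):

```python
# Hypothetical toy batch; only the 'positive'/'negative' keys come from the real format.
batch = [
    {'question': 'q1', 'positive': {'input_ids': [1, 2]}, 'negative': {'input_ids': [3]}},
    {'question': 'q2', 'positive': {'input_ids': [4]}, 'negative': {'input_ids': [5, 6]}},
]
interleaved = prepare_dpo_batch(batch)
# interleaved == [
#     {'question': 'q1', 'input_ids': [1, 2]},  # pos_1
#     {'question': 'q1', 'input_ids': [3]},     # neg_1
#     {'question': 'q2', 'input_ids': [4]},     # pos_2
#     {'question': 'q2', 'input_ids': [5, 6]},  # neg_2
# ]
# With two DP workers each taking a contiguous half, worker 0 sees (pos_1, neg_1)
# and worker 1 sees (pos_2, neg_2), so no worker ends up with an unpaired sample.
```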
+ """ + result = [] + + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result + + +# โ”€โ”€ Loss Factory โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +def create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False): + """Create the appropriate loss function based on configuration.""" + if loss_type == 'simpo': + return SimPOLoss(beta=beta, gamma=0.5) + elif loss_type == 'orpo': + return ORPOLoss(lambda_orpo=beta) + elif loss_type == 'cpo': + return CPOLoss(beta=beta, bc_coef=1.0) + else: + # Standard DPO variants: sigmoid, hinge, ipo + return DPOLoss( + beta=beta, + loss_type=loss_type, + reference_free=reference_free, + sft_weight=sft_weight, + ) + + +# โ”€โ”€ Main Training Loop โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +def main(): + # Set up device groups + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + + # Configure device mesh based on backend + if USE_MEGATRON: + # Megatron: dp=2, pp=2 for each model + from twinkle.model import MegatronModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=4) + ModelClass = MegatronModel + else: + # Transformers: fsdp=2, dp=2 for each model + from twinkle.model import TransformersModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=2, dp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=4) + ModelClass = TransformersModel + + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) + + # โ”€โ”€ DataLoader Setup โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + dataloader = DataLoader( + dataset=create_dpo_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=policy_mesh, + ) + + # โ”€โ”€ Policy Model Setup โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + policy_model = ModelClass( + model_id=MODEL_ID, + device_mesh=policy_mesh, + remote_group='policy', + ) + MAX_STEPS = len(dataloader) + + # Determine if we need reference model based on loss type + reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] + + # Set up loss function and metrics + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=reference_free) + + # Configure optimizer based on backend (full-parameter training) + if USE_MEGATRON: + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01) + 
policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) + else: + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + + policy_model.set_loss(loss_fn) + policy_model.add_metric(DPOMetric, beta=DPO_BETA) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) + + # โ”€โ”€ Reference Model Setup โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + ref_model = None + if not reference_free: + ref_model = ModelClass( + model_id=MODEL_ID, + device_mesh=ref_mesh, + remote_group='reference', + ) + ref_model.set_processor(InputProcessor) + ref_model.set_template('Template', model_id=MODEL_ID) + logger.info('Reference model initialized for DPO training') + else: + logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') + + optim_step = 0 + backend_name = 'Megatron' if USE_MEGATRON else 'Transformers' + logger.info(get_device_placement()) + logger.info(f'Starting DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}') + + # โ”€โ”€ Training Loop โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + for batch in dataloader: + # batch is List[Dict] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch) + + # Get reference outputs (lazy - not collected to driver) + ref_outputs = None + if ref_model is not None: + ref_outputs = ref_model.forward_only(inputs=dpo_batch) + + # Forward-backward pass with DPO loss + policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs) + policy_model.clip_grad_and_step() + + optim_step += 1 + + # Logging + if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = policy_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') + + # Checkpointing + if optim_step % SAVE_STEPS == 0: + policy_model.save(f'dpo-checkpoint-{optim_step}') + + # โ”€โ”€ Save Final Checkpoint โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + logger.info(f'Training completed. Total steps: {optim_step}') + policy_model.save('dpo-final-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py new file mode 100644 index 00000000..861e72a4 --- /dev/null +++ b/cookbook/rl/dpo_lora.py @@ -0,0 +1,229 @@ +"""DPO (Direct Preference Optimization) Training with LoRA (Single GPU Group). + +LoRA-based DPO training: uses the base model (without LoRA adapter) as reference +model by calling forward_only with disable_lora=True. This eliminates the need for +a separate reference model GPU group. + +Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag. + +Pipeline: + 1. Load preference dataset with chosen/rejected pairs. + 2. Encode positive and negative separately. + 3. Compute reference model log probabilities using base model (disable_lora=True). + 4. Train policy model (with LoRA adapter) using DPO loss. 
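The single-model reference trick can be pictured with plain peft as in the sketch below. This is an assumption about what `forward_only(disable_lora=True)` amounts to (temporarily bypassing the adapter), not a transcript of the twinkle backend.

```python
# Rough conceptual sketch (plain transformers + peft), not the twinkle backend.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-4B')  # mirrors MODEL_ID
model = get_peft_model(base, LoraConfig(target_modules='all-linear', r=8, lora_alpha=32))

input_ids = torch.tensor([[1, 2, 3, 4]])

with torch.no_grad(), model.disable_adapter():
    ref_logits = model(input_ids).logits       # base weights only -> reference outputs

policy_logits = model(input_ids).logits        # base + LoRA adapter -> trainable policy
```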
+ +Architecture (Ray - Single Group): + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ Driver (CPU) โ”‚ + โ”‚ dataloader โ”€โ”€โ–บ batched preference pairs โ”‚ + โ”‚ policy_model.forward_only(disable_lora=True) โ”€โ”€โ–บ ref logps โ”‚ + โ”‚ policy_model.forward_backward() โ”€โ”€โ–บ DPO loss + gradient โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + PolicyModel (with LoRA adapter) + - forward_only(disable_lora=True) โ†’ base model inference (reference) + - forward_backward() โ†’ LoRA adapter training (policy) + +DPO data format (after preprocessing): + - positive: List[Trajectory] - chosen responses + - negative: List[Trajectory] - rejected responses + +Environment variables (all optional): + USE_MEGATRON โ€“ Use Megatron backend (default: 0, use Transformers) + MODEL_ID โ€“ (default: ms://Qwen/Qwen3-4B) + DATASET_ID โ€“ (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) + MODEL_GPUS โ€“ GPUs for policy model (default: 8) + BATCH_SIZE โ€“ global batch size (preference pairs) (default: 8) + MAX_STEPS โ€“ total optimization steps (default: 1000) + LR โ€“ learning rate (default: 1e-4) + DPO_BETA โ€“ DPO temperature parameter (default: 0.1) + LOSS_TYPE โ€“ DPO variant (sigmoid/hinge/ipo) (default: sigmoid) + SAVE_STEPS โ€“ checkpoint save interval (default: 100) + MAX_LENGTH โ€“ max sequence length (default: 2048) +""" + +import os +from typing import Any, Dict, List, Optional + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import DPOLoss +from twinkle.metric import DPOMetric +from twinkle.preprocessor import EmojiDPOProcessor +from twinkle.processor import InputProcessor + +logger = get_logger() + +# โ”€โ”€ Configuration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) +LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) +DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization +LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) +ADAPTER_NAME = 'default' +SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') + + +def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(6000))) + dataset.set_template('Template', 
model_id=MODEL_ID, max_length=MAX_LENGTH) + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': SYSTEM_PROMPT, + } + ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode(load_from_cache_file=True) + return dataset + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. + + Args: + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) + + Returns: + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. + """ + result = [] + + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result + + +# โ”€โ”€ Main Training Loop โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +def main(): + # Set up device groups - only one group for LoRA training + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + ] + + # Configure device mesh based on backend + if USE_MEGATRON: + # Megatron: dp=4, pp=2 + from twinkle.model import MegatronModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2) + ModelClass = MegatronModel + else: + # Transformers: dp_size=8 + # FSDP2 forward_only & forward has problems with `with unwrapped_model.disable_adapter()` + from twinkle.model import TransformersModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=8) + ModelClass = TransformersModel + + twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) + + # โ”€โ”€ DataLoader Setup โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + dataloader = DataLoader( + dataset=create_dpo_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=policy_mesh, + ) + + # โ”€โ”€ Policy Model Setup with LoRA โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + lora_config = LoraConfig( + target_modules='all-linear', + r=8, + lora_alpha=32, + lora_dropout=0.05, + ) + + policy_model = ModelClass( + model_id=MODEL_ID, + device_mesh=policy_mesh, + remote_group='policy', + ) + MAX_STEPS = len(dataloader) + policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + + # Configure optimizer based on backend + if USE_MEGATRON: + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) + policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME) + else: + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + 
policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + + # Set up loss function and metrics + loss_fn = DPOLoss( + beta=DPO_BETA, + loss_type=LOSS_TYPE, + reference_free=False, # We use base model as reference via disable_lora=True + sft_weight=SFT_WEIGHT, + ) + + policy_model.set_loss(loss_fn) + policy_model.add_metric(DPOMetric, beta=DPO_BETA) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) + + optim_step = 0 + backend_name = 'Megatron' if USE_MEGATRON else 'Transformers' + logger.info(get_device_placement()) + logger.info(f'Starting LoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}') + logger.info(f'Using base model (disable_lora=True) as reference model') + + # โ”€โ”€ Training Loop โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + for batch in dataloader: + # batch is List[Dict] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch) + + # Get reference outputs using base model (without LoRA adapter) + # disable_lora=True tells the model to skip LoRA and use base weights + ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True) + policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs) + policy_model.clip_grad_and_step() + + optim_step += 1 + + # Logging + if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = policy_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') + + # Checkpointing + if optim_step % SAVE_STEPS == 0: + policy_model.save(f'dpo-lora-checkpoint-{optim_step}') + + # โ”€โ”€ Save Final Checkpoint โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + logger.info(f'Training completed. 
Total steps: {optim_step}') + policy_model.save('dpo-lora-final-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 10d75df6..7b6bd2a8 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,16 +1,15 @@ -import os from peft import LoraConfig from tqdm import tqdm import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2) +# Construct a device_mesh, fsdp_size=2, dp=4 +device_mesh = DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) @@ -20,7 +19,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkleๅคงๆจกๅž‹', 'ModelScope็คพๅŒบ')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +34,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkleๅคงๆจกๅž‹', 'ModelScope็คพๅŒบ')) # Encode dataset diff --git a/docs/source_en/Components/Data Format/Trajectory.md b/docs/source_en/Components/Data Format/Trajectory.md index d0c14aec..0efc6b6e 100644 --- a/docs/source_en/Components/Data Format/Trajectory.md +++ b/docs/source_en/Components/Data Format/Trajectory.md @@ -5,12 +5,14 @@ The raw data structure input to Template after dataset ETL is `Trajectory` (traj ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: A list of Message messages, representing the multi-turn conversations actually conducted by the model, usually alternating between `user` and `assistant`. -- extend_message: In training such as DPO and PPO, unusable trajectories or low-score trajectories are usually needed, which will be placed in extend_message - tools: A list of all available tools for the model in this call +- user_data: User-defined data, such as labels in KTO training + +For preference alignment training like DPO, preprocessors return `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format. Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. The format conversion goes from the original dataset to Trajectory, and then to InputFeature. 
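For example, a preference preprocessor for DPO might emit something like the following (illustrative values; the `role`/`content` message fields are the usual chat-message layout assumed here):

```python
from twinkle.data_format import Trajectory

# One chosen and one rejected trajectory sharing the same prompt (made-up content).
chosen: Trajectory = {
    'messages': [
        {'role': 'user', 'content': 'Say hello'},
        {'role': 'assistant', 'content': 'Hello there! 😀'},
    ],
}
rejected: Trajectory = {
    'messages': [
        {'role': 'user', 'content': 'Say hello'},
        {'role': 'assistant', 'content': 'Hello.'},
    ],
}
preprocessed = {'positive': [chosen], 'negative': [rejected]}
```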
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" index f7ed4f12..5281999c 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" @@ -5,12 +5,14 @@ ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: Messageๆถˆๆฏ็š„ๅˆ—่กจ๏ผŒไปฃ่กจๆจกๅž‹ๅฎž้™…่ฟ›่กŒ็š„ๅคš่ฝฎๅฏน่ฏ๏ผŒ้€šๅธธๆ˜ฏ`user`ๅ’Œ`assistant`ไบคๆ›ฟๅ‡บ็Žฐใ€‚ -- extend_message: ๅœจDPOใ€PPO็ญ‰่ฎญ็ปƒไธญ้€šๅธธ้œ€่ฆไธๅฏ็”จ่ฝจ่ฟน๏ผŒๆˆ–ไฝŽๅˆ†่ฝจ่ฟน๏ผŒ่ฏฅ่ฝจ่ฟนไผšๆ”พๅœจextend_messageไธญ - tools: ๆจกๅž‹ๅœจๆœฌๆฌก่ฐƒ็”จไธญ็š„ๆ‰€ๆœ‰ๅฏ็”จๅทฅๅ…ทๅˆ—่กจ +- user_data: ็”จๆˆท่‡ชๅฎšไน‰ๆ•ฐๆฎ๏ผŒๅฆ‚KTO่ฎญ็ปƒไธญ็š„label + +ๅฏนไบŽDPO็ญ‰ๅๅฅฝๅฏน้ฝ่ฎญ็ปƒ๏ผŒ้ข„ๅค„็†ๅ™จ่ฟ”ๅ›ž`{'positive': List[Trajectory], 'negative': List[Trajectory]}`ๆ ผๅผใ€‚ Trajectoryๆ˜ฏtwinkleไธญๆ‰€ๆœ‰ๆ•ฐๆฎ้›†้ข„ๅค„็†่พ“ๅ‡บ๏ผŒๆจกๆฟ่พ“ๅ…ฅ็š„ๆ ‡ๅ‡†ๆŽฅๅฃใ€‚ๆ ผๅผ่ฝฌๆขไธบ็”ฑๅŽŸๅง‹ๆ•ฐๆฎ้›†่ฝฌๆขไธบTrajectory๏ผŒๅ†ๅˆฐInputFeatureใ€‚ diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index c7742d75..51a21fc5 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -13,6 +13,5 @@ class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 00adc984..6448cc2b 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -86,9 +86,9 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt) with processing_lock('dataset'): # use a default lock because encode is to all datasets - self.dataset = self.dataset.map(encode_fn, - **kwargs).filter(lambda batch: [len(x) > 0 for x in batch['input_ids']], - **kwargs) + self.dataset = self.dataset.map(encode_fn, **kwargs).filter( + lambda batch: [True] * len(next(iter(batch.values()))) + if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) @remote_function() def check(self, **kwargs): diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index af4d6d6e..aaa60819 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -1,10 +1,8 @@ import numpy as np -from typing import TYPE_CHECKING, Any, Dict, List +from typing import Any, Dict, List from twinkle import DeviceMesh - -if TYPE_CHECKING: - import torch +from twinkle.utils import pad_and_stack_tensors def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) -> Dict[str, Any]: @@ -19,7 +17,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) for d in outputs: all_keys.update(d.keys()) - outputs = [r for i, r in enumerate(outputs) if i in device_mesh.get_pp_last_ranks()] + outputs = [r for i, r in enumerate(outputs) if i in device_mesh.get_collect_ranks()] result = {} for key in all_keys: values = [d[key] for d in outputs if key in d] @@ -39,7 +37,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], 
device_mesh: DeviceMesh) result[key] = merged elif isinstance(first_value, torch.Tensor): - result[key] = _pad_and_stack_tensors(values) + result[key] = pad_and_stack_tensors(values) elif isinstance(first_value, dict): result[key] = collect_tensor_dict(values) @@ -53,36 +51,3 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) if 'loss' in result and len(result['loss']) > 1: result['loss'] = np.mean(result['loss']) return result - - -def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor': - import torch - if not tensors: - raise ValueError('Empty tensor list') - - if len(tensors) == 1: - return tensors[0].unsqueeze(0) - - max_ndim = max(t.ndim for t in tensors) - expanded_tensors = [] - for t in tensors: - while t.ndim < max_ndim: - t = t.unsqueeze(0) - expanded_tensors.append(t) - - max_shape = [] - for dim in range(max_ndim): - max_shape.append(max(t.shape[dim] for t in expanded_tensors)) - - padded_tensors = [] - for t in expanded_tensors: - if list(t.shape) == max_shape: - padded_tensors.append(t) - else: - pad_params = [] - for dim in range(max_ndim - 1, -1, -1): - pad_params.extend([0, max_shape[dim] - t.shape[dim]]) - padded = torch.nn.functional.pad(t, pad_params, value=pad_value) - padded_tensors.append(padded) - - return torch.cat(padded_tensors, dim=0) diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 7870f5a4..4e4d0e82 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,6 +2,7 @@ from .base import Loss from .chunked_cross_entropy import ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .dpo import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss @@ -19,4 +20,9 @@ 'cispo': CISPOLoss, 'bnpo': BNPOLoss, 'dr_grpo': DRGRPOLoss, + # DPO family losses + 'dpo': DPOLoss, + 'simpo': SimPOLoss, + 'cpo': CPOLoss, + 'orpo': ORPOLoss, } diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py new file mode 100644 index 00000000..44d81f82 --- /dev/null +++ b/src/twinkle/loss/dpo.py @@ -0,0 +1,549 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Loss Implementation. + +Reference: + "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" + (https://arxiv.org/abs/2305.18290) +""" +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from twinkle.data_format import LossOutput +from twinkle.loss.base import Loss +from twinkle.utils.torch_utils import selective_log_softmax + +if TYPE_CHECKING: + import torch + + +class PreferenceLossBase(Loss): + """Base class for preference optimization losses with shared utilities.""" + + def __init__(self, ignore_index: int = -100): + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. 
+ + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + return selective_log_softmax(logits, masked_labels) + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized (average) log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def _get_logps_from_outputs( + self, + outputs: Dict, + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Extract or compute log probabilities from model outputs.""" + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + return logps + + def _split_chosen_rejected( + self, + tensor: 'torch.Tensor', + ) -> tuple: + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + # Even indices = chosen (positive), odd indices = rejected (negative) + return tensor[0::2], tensor[1::2] + + +class DPOLoss(PreferenceLossBase): + """Direct Preference Optimization (DPO) Loss. + + DPO directly optimizes the policy using preference data without explicit reward modeling. + The loss function is derived from the Bradley-Terry preference model: + + L_DPO = -log(ฯƒ(ฮฒ * (log ฯ€(y_w|x)/ฯ€_ref(y_w|x) - log ฯ€(y_l|x)/ฯ€_ref(y_l|x)))) + + where: + - y_w is the preferred (chosen) response + - y_l is the dispreferred (rejected) response + - ฮฒ is the temperature parameter controlling deviation from reference + - ฯ€ is the current policy + - ฯ€_ref is the reference policy (frozen) + + Args: + beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1). + label_smoothing: Label smoothing parameter for soft labels (default: 0.0). + ignore_index: Index to ignore in labels (default: -100). + loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid'). + reference_free: Whether to use reference-free DPO (default: False). + sft_weight: Weight for SFT loss on chosen responses to prevent likelihood displacement (default: 0.0). 
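A toy invocation showing the expected shapes (made-up values; in the cookbook scripts the loss is attached via `policy_model.set_loss(...)` rather than called directly, and the batch is interleaved as [chosen_1, rejected_1, ...]):

```python
import torch
from twinkle.loss import DPOLoss

loss_fn = DPOLoss(beta=0.1, loss_type='sigmoid', sft_weight=1.0)

# One preference pair, 3 tokens each; -100 masks the prompt position.
labels = torch.tensor([
    [-100, 5, 6],   # chosen_1
    [-100, 7, 8],   # rejected_1
])
outputs = {'logps': -torch.rand(2, 3)}   # policy per-token log probs
ref_logps = -torch.rand(2, 3)            # frozen reference per-token log probs

out = loss_fn({'labels': labels}, outputs, ref_logps=ref_logps)
# out is a LossOutput carrying the scalar DPO(+SFT) loss with num_tokens=0.
```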
+ """ + + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + ignore_index: int = -100, + loss_type: str = 'sigmoid', + reference_free: bool = False, + sft_weight: float = 0.0, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + self.reference_free = reference_free + self.sft_weight = sft_weight + + def _align_logps( + self, + logps: 'torch.Tensor', + target_shape: tuple, + device: 'torch.device', + dtype: 'torch.dtype', + ) -> 'torch.Tensor': + """Align log probabilities to target shape. + + Args: + logps: Input log probabilities tensor + target_shape: Target (batch, seq_len) shape + device: Target device + dtype: Target dtype + + Returns: + Aligned tensor of shape target_shape + """ + import torch + + if not torch.is_tensor(logps): + raise TypeError(f'Expected torch.Tensor, got {type(logps)}') + + if logps.dim() == 1: + logps = logps.unsqueeze(0) + + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle tensor with different sequence length + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate right (keep left part) - may happen in Ray result merging + return logps[:, :target_seq_len] + else: + raise ValueError(f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). ' + f'This should not happen when both models process the same batch.') + + raise ValueError(f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}') + + def _compute_dpo_loss( + self, + policy_chosen_logps: 'torch.Tensor', + policy_rejected_logps: 'torch.Tensor', + reference_chosen_logps: 'torch.Tensor', + reference_rejected_logps: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute the DPO loss. 
+ + Args: + policy_chosen_logps: [batch/2] log probs of chosen under current policy + policy_rejected_logps: [batch/2] log probs of rejected under current policy + reference_chosen_logps: [batch/2] log probs of chosen under reference policy + reference_rejected_logps: [batch/2] log probs of rejected under reference policy + + Returns: + loss: Scalar DPO loss + """ + import torch + import torch.nn.functional as F + + # Compute log ratios + if self.reference_free: + # Reference-free: only use policy log probs + chosen_logratios = policy_chosen_logps + rejected_logratios = policy_rejected_logps + else: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + # Compute preference margin + logits = self.beta * (chosen_logratios - rejected_logratios) + + if self.loss_type == 'sigmoid': + # Standard DPO loss: -log(sigmoid(beta * margin)) + losses = -F.logsigmoid(logits) + elif self.loss_type == 'hinge': + # Hinge loss variant + losses = torch.relu(1 - logits) + elif self.loss_type == 'ipo': + # IPO (Identity Preference Optimization) loss + # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback" + losses = (logits - 1 / (2 * self.beta))**2 + elif self.loss_type == 'kto_pair': + # KTO pair loss (simplified version) + chosen_logratios_scaled = self.beta * chosen_logratios + rejected_logratios_scaled = self.beta * rejected_logratios + chosen_losses = 1 - F.sigmoid(chosen_logratios_scaled) + rejected_losses = F.sigmoid(rejected_logratios_scaled) + losses = chosen_losses + rejected_losses + else: + raise ValueError(f'Unknown loss_type: {self.loss_type}') + + # Apply label smoothing if specified + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses + + return losses.mean() + + def __call__( + self, + inputs: Dict, + outputs: Dict, + *, + ref_outputs: Optional[Dict] = None, + ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None, + ref_chosen_logps: Optional['torch.Tensor'] = None, + ref_rejected_logps: Optional['torch.Tensor'] = None, + **kwargs, + ) -> LossOutput: + """Compute DPO loss. + + The inputs should contain concatenated chosen and rejected examples: + - First half of batch: chosen responses + - Second half of batch: rejected responses + + Args: + inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. + Batch should be organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + outputs: Dict containing either: + - 'logps': [batch, seq_len] pre-computed log probs, OR + - 'logits': [batch, seq_len, vocab] from which logps will be computed + ref_outputs: Dict from reference model forward, containing 'logps'. + ref_logps: [batch, seq_len] or List[List[float]] reference model log probs. + Can also be provided as separate ref_chosen_logps and ref_rejected_logps. + ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen. + ref_rejected_logps: [batch/2] pre-computed reference log probs for rejected. + **kwargs: Additional arguments. + + Returns: + LossOutput with DPO loss and metrics. 
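A quick numeric sanity check of the 'sigmoid' branch (pure arithmetic, values invented):

```python
import math

beta = 0.1
chosen_logratio = -1.0     # policy_chosen_logps - reference_chosen_logps
rejected_logratio = -3.0   # policy_rejected_logps - reference_rejected_logps

logits = beta * (chosen_logratio - rejected_logratio)     # 0.2
loss = -math.log(1.0 / (1.0 + math.exp(-logits)))         # -log(sigmoid(0.2))
print(round(loss, 3))  # 0.598
```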
+ """ + import torch + + # Extract ref_logps from ref_outputs if provided + if ref_outputs is not None and ref_logps is None: + ref_logps = ref_outputs.get('logps') + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, 'Batch size must be even (chosen + rejected pairs)' + + # Get log probabilities from outputs + logps = self._get_logps_from_outputs(outputs, labels) + device = logps.device + dtype = logps.dtype + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute sequence-level log probs for policy + policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + policy_rejected_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Handle reference log probs + if ref_chosen_logps is not None and ref_rejected_logps is not None: + # Pre-computed sequence-level reference log probs provided + reference_chosen_logps = ref_chosen_logps.to(device=device, dtype=dtype) + reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) + elif ref_logps is not None: + # Per-token reference log probs provided, need to align and sum + ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype) + ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) + reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) + reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) + elif self.reference_free: + # Reference-free mode: no reference model needed + reference_chosen_logps = torch.zeros_like(policy_chosen_logps) + reference_rejected_logps = torch.zeros_like(policy_rejected_logps) + else: + return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) + + # Compute DPO loss + dpo_loss = self._compute_dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + # Add SFT loss on chosen responses to prevent likelihood displacement + if self.sft_weight > 0: + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + loss = dpo_loss + self.sft_weight * sft_loss + else: + loss = dpo_loss + + # Return 0 to skip gradient normalization by num_tokens + # DPO loss is already per-sample mean, unlike SFT which sums per-token loss + # When num_tokens=0, normalize_and_clip_grad_norm defaults to 1 (no division) + return LossOutput(loss=loss, num_tokens=0) + + +class SimPOLoss(PreferenceLossBase): + """SimPO (Simple Preference Optimization) Loss. + + SimPO is a simpler variant of DPO that doesn't require a reference model. + It uses length-normalized log probabilities. + + Reference: + "SimPO: Simple Preference Optimization with a Reference-Free Reward" + (https://arxiv.org/abs/2405.14734) + + Args: + beta: Temperature parameter (default: 2.5). + gamma: Target reward margin (default: 0.5). + ignore_index: Index to ignore in labels (default: -100). 
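A small reference-free sketch of the SimPO margin (made-up length-normalized log probs):

```python
import torch
import torch.nn.functional as F

beta, gamma = 2.5, 0.5
chosen_avg_logps = torch.tensor([-1.2, -0.9])     # length-normalized, one value per pair
rejected_avg_logps = torch.tensor([-1.8, -1.1])

logits = beta * (chosen_avg_logps - rejected_avg_logps) - gamma
loss = -F.logsigmoid(logits).mean()   # no reference model involved
```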
+ """ + + def __init__( + self, + beta: float = 2.5, + gamma: float = 0.5, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.gamma = gamma + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute SimPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, 'Batch size must be even (chosen + rejected pairs)' + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute length-normalized log probs + chosen_rewards = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_avg_logps(rejected_logps, rejected_labels) + + # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) + logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma + loss = -F.logsigmoid(logits).mean() + + return LossOutput(loss=loss, num_tokens=0) + + +class CPOLoss(PreferenceLossBase): + """CPO (Contrastive Preference Optimization) Loss. + + CPO adds a behavior cloning term to preference optimization. + + Reference: + "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation" + (https://arxiv.org/abs/2401.08417) + + Args: + beta: Temperature parameter for preference (default: 0.1). + bc_coef: Behavior cloning coefficient (default: 1.0). + ignore_index: Index to ignore in labels (default: -100). + """ + + def __init__( + self, + beta: float = 0.1, + bc_coef: float = 1.0, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.bc_coef = bc_coef + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute CPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, 'Batch size must be even' + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute sequence-level log probs + chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + rejected_seq_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Preference loss (reference-free DPO) + logits = self.beta * (chosen_seq_logps - rejected_seq_logps) + preference_loss = -F.logsigmoid(logits).mean() + + # Behavior cloning loss on chosen + bc_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Combined loss + loss = preference_loss + self.bc_coef * bc_loss + + return LossOutput(loss=loss, num_tokens=0) + + +class ORPOLoss(PreferenceLossBase): + """ORPO (Odds Ratio Preference Optimization) Loss. + + ORPO combines SFT and preference alignment in a single objective using odds ratios. 
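The odds-ratio term can be sketched as follows (illustrative numbers; in this implementation the term is added to the SFT loss on the chosen response with weight lambda_orpo):

```python
import torch
import torch.nn.functional as F

# Length-normalized probabilities of chosen/rejected under the policy (made up).
p_chosen = torch.exp(torch.tensor([-0.8])).clamp(1e-7, 1 - 1e-7)    # ~0.449
p_rejected = torch.exp(torch.tensor([-1.6])).clamp(1e-7, 1 - 1e-7)  # ~0.202

log_odds = (torch.log(p_chosen) - torch.log(1 - p_chosen)) \
    - (torch.log(p_rejected) - torch.log(1 - p_rejected))
odds_ratio_term = -F.logsigmoid(log_odds)   # loss = sft_loss + lambda_orpo * this
```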
+ + Reference: + "ORPO: Monolithic Preference Optimization without Reference Model" + (https://arxiv.org/abs/2403.07691) + + Args: + lambda_orpo: Weight for the odds ratio term (default: 0.1). + ignore_index: Index to ignore in labels (default: -100). + """ + + def __init__( + self, + lambda_orpo: float = 0.1, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.lambda_orpo = lambda_orpo + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute ORPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, 'Batch size must be even' + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # SFT loss on chosen + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Compute average log probs for odds ratio + chosen_avg_logps = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) + + # Odds ratio: log(odds_chosen / odds_rejected) + # log_odds = log(p/(1-p)) = log(p) - log(1-p) + # Use numerically stable computation + prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1 - 1e-7) + prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1 - 1e-7) + log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) + log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) + + # ORPO odds ratio loss + odds_ratio = log_odds_chosen - log_odds_rejected + orpo_loss = -F.logsigmoid(odds_ratio).mean() + + # Combined loss + loss = sft_loss + self.lambda_orpo * orpo_loss + + return LossOutput(loss=loss, num_tokens=0) diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index 739c7a0d..59d5bbeb 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -2,5 +2,6 @@ from .accuracy import Accuracy from .base import Metric from .completion_and_reward import CompletionRewardMetric +from .dpo import DPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py new file mode 100644 index 00000000..5ce61410 --- /dev/null +++ b/src/twinkle/metric/dpo.py @@ -0,0 +1,209 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO-specific metrics for preference optimization training.""" +from typing import List, Union + +from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import pad_and_stack_tensors +from .base import Metric + + +class DPOMetric(Metric): + """Metrics for DPO (Direct Preference Optimization) training. 
+ + Computes TRL-style metrics: + - logps/chosen: Average sequence-level log prob of chosen responses + - logps/rejected: Average sequence-level log prob of rejected responses + - rewards/chosen: ฮฒ * (policy_chosen - ref_chosen) + - rewards/rejected: ฮฒ * (policy_rejected - ref_rejected) + - rewards/margins: chosen_reward - rejected_reward + - rewards/accuracies: Percentage where chosen_reward > rejected_reward + + Args: + device_mesh: The device mesh + process_group: The process group to collect data from + ignore_index: Label index to ignore (default: -100) + beta: DPO beta parameter for reward scaling (default: 0.1) + """ + + def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs): + super().__init__(device_mesh, process_group, **kwargs) + self.ignore_index = ignore_index + self.beta = beta + self.reset() + + def _compute_sequence_logps(self, per_token_logps, labels): + """Compute sequence-level log probs by summing valid token logps.""" + import torch + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _align_logps(self, logps, target_shape, device, dtype): + """Align per-token logps to target shape by padding or truncating. + + Args: + logps: [batch, seq_len] tensor to align + target_shape: Target shape (batch, target_seq_len) + device: Target device + dtype: Target dtype + + Returns: + Aligned tensor with shape matching target_shape + """ + import torch + logps = logps.to(device=device, dtype=dtype) + batch_size, src_len = logps.shape + _, target_len = target_shape + + if src_len == target_len: + return logps + elif src_len < target_len: + raise ValueError(f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). ' + f'This should not happen when both models process the same batch.') + else: + return logps[:, :target_len] + + def _split_chosen_rejected(self, tensor): + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + return tensor[0::2], tensor[1::2] + + def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + """Accumulate DPO metrics from model outputs. 
+ + Expects: + - outputs['logps']: [batch, seq_len] per-token log probabilities + - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens + - kwargs['ref_outputs']: Optional reference model outputs with 'logps' + """ + import torch + logps = outputs.get('logps') + if logps is None or len(logps) == 0: + return + + if isinstance(logps, list) and logps: + logps = pad_and_stack_tensors(logps) + + # Get labels from inputs + if isinstance(inputs, list): + labels = [input['labels'] for input in inputs] + if len(labels) == 1: + labels = labels[0] + else: + labels = pad_and_stack_tensors(labels) + inputs = {'labels': labels} + + labels = torch.as_tensor(inputs['labels']) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + # Ensure logps and labels have same device + if logps.device != labels.device: + labels = labels.to(logps.device) + + # Align sequence lengths if needed (truncate right) + if logps.shape[1] != labels.shape[1]: + min_len = min(logps.shape[1], labels.shape[1]) + logps = logps[:, :min_len] + labels = labels[:, :min_len] + + # Compute sequence-level logps + seq_logps = self._compute_sequence_logps(logps, labels) + + # Split into chosen and rejected (interleaved format) + chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps) + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + + # Accumulate policy logps + self.total_chosen_logps += chosen_logps.sum().item() + self.total_rejected_logps += rejected_logps.sum().item() + + # Compute rewards if ref_outputs available + ref_outputs = kwargs.get('ref_outputs') + if ref_outputs is not None: + ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + # Align ref_logps to match labels shape (handles different seq lengths) + ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) + + ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) + ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) + + # Accumulate ref logps + self.total_ref_chosen_logps += ref_chosen_logps.sum().item() + self.total_ref_rejected_logps += ref_rejected_logps.sum().item() + + # Compute rewards: ฮฒ * (policy - ref) + chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps) + rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps) + + self.total_chosen_rewards += chosen_rewards.sum().item() + self.total_rejected_rewards += rejected_rewards.sum().item() + margins = chosen_rewards - rejected_rewards + self.total_reward_margin += margins.sum().item() + self.total_reward_correct += (margins > 0).sum().item() + self.has_rewards = True + + self.total_count += chosen_logps.shape[0] + + def reset(self): + """Reset all accumulated values.""" + self.total_chosen_logps = 0.0 + self.total_rejected_logps = 0.0 + self.total_ref_chosen_logps = 0.0 + self.total_ref_rejected_logps = 0.0 + self.total_chosen_rewards = 0.0 + self.total_rejected_rewards = 0.0 + self.total_reward_margin = 0.0 + self.total_reward_correct = 0 + self.total_count = 0 + self.has_rewards = False + + def calculate(self): + """Calculate and return aggregated metrics.""" + local_results = [{ + 'chosen_logps': self.total_chosen_logps, + 'rejected_logps': self.total_rejected_logps, + 'ref_chosen_logps': self.total_ref_chosen_logps, + 'ref_rejected_logps': self.total_ref_rejected_logps, + 'chosen_rewards': self.total_chosen_rewards, + 'rejected_rewards': self.total_rejected_rewards, + 'reward_margin': self.total_reward_margin, + 'reward_correct': 
self.total_reward_correct, + 'count': self.total_count, + 'has_rewards': self.has_rewards, + }] + all_results = self.gather_results(local_results) + + total_chosen_logps = sum(r['chosen_logps'] for r in all_results) + total_rejected_logps = sum(r['rejected_logps'] for r in all_results) + total_ref_chosen_logps = sum(r['ref_chosen_logps'] for r in all_results) + total_ref_rejected_logps = sum(r['ref_rejected_logps'] for r in all_results) + total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results) + total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results) + total_reward_margin = sum(r['reward_margin'] for r in all_results) + total_reward_correct = sum(r['reward_correct'] for r in all_results) + total_count = sum(r['count'] for r in all_results) + has_rewards = any(r['has_rewards'] for r in all_results) + + self.reset() + + if total_count == 0: + return {} + + results = { + 'logps/chosen': f'{total_chosen_logps / total_count:.2f}', + 'logps/rejected': f'{total_rejected_logps / total_count:.2f}', + } + + if has_rewards: + results['logps/ref_chosen'] = f'{total_ref_chosen_logps / total_count:.2f}' + results['logps/ref_rejected'] = f'{total_ref_rejected_logps / total_count:.2f}' + results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}' + results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}' + results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}' + results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%' + self.reset() + return results diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 52f50fdd..8f4ad0c9 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -52,7 +52,6 @@ def calculate(self): 'grad_norm': self.grad_norm, 'num_tokens': self.num_tokens }] - all_results = self.gather_results(local_results) total_loss = sum(r['loss'] for r in all_results) @@ -61,8 +60,10 @@ def calculate(self): num_tokens = sum(r['num_tokens'] for r in all_results) if num_tokens > 0: avg_loss = total_loss / num_tokens - else: + elif total_count > 0: avg_loss = total_loss / total_count + else: + avg_loss = 0.0 self.reset() results = {} if avg_loss is not None: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4b37973c..d5923f86 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -30,6 +30,7 @@ from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template @@ -38,84 +39,38 @@ @dataclass -class MegatronOptimizerGroup: +class MegatronOptimizerGroup(BaseOptimizerGroup): """Optimizer group for Megatron training. Similar to OptimizerGroup but adapted for Megatron's distributed training. 
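As a quick sanity check on the reward metrics that DPOMetric.calculate() reports above, here is the arithmetic with invented sequence log-probabilities and the default beta of 0.1 (all numbers are made up for illustration).

beta = 0.1
policy_chosen, ref_chosen = -12.0, -15.0          # invented sequence log-probs
policy_rejected, ref_rejected = -20.0, -18.0
chosen_reward = beta * (policy_chosen - ref_chosen)        # ~0.3
rejected_reward = beta * (policy_rejected - ref_rejected)  # ~-0.2
margin = chosen_reward - rejected_reward                   # ~0.5, counted in rewards/accuracies since > 0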
""" - adapter_name: str = None - adapter_config: Any = None - optimizer: Optimizer = None - lr_scheduler: LRScheduler = None - inputs: List[InputFeature] = None - outputs: ModelOutput = None - loss_instance: Loss = None - loss_value: Any = None - template: Template = None - processor: InputProcessor = None - gradient_accumulation_steps: int = 1 - cur_step: int = 0 - _dp_group = None - train_metrics: List[Metric] = field(default_factory=list) - eval_metrics: List[Metric] = field(default_factory=list) - _device_mesh: DeviceMesh = None - # Megatron optimizer specific fields - _last_grad_norm: float = 0.0 + # Megatron-specific fields _last_step_success: bool = True - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: - if gradient_accumulation_steps is None: - gradient_accumulation_steps = self.gradient_accumulation_steps - else: - self.gradient_accumulation_steps = gradient_accumulation_steps - return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 - def __post_init__(self): if self._device_mesh.data_world_size > 1: self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp']) - self.train_metrics = [ + train_metrics = [ LossMetric(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.train_status = TrainStatus(metrics=train_metrics) - self.eval_metrics = [ + eval_metrics = [ LossMetric(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.eval_status = TrainStatus(metrics=eval_metrics) def _get_lr(self): _lrs = [] + if self.optimizer is None: + return _lrs _default_lr = self.optimizer.chained_optimizers[0].config.lr for param_group in self.optimizer.param_groups: _lrs.append(param_group.get('lr', _default_lr)) return _lrs - def accumulate_metrics(self, is_training): - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - if len(metrics) > 0 and self.inputs is not None and self.outputs is not None: - for metric in metrics: - metric.accumulate( - self.inputs, - self.outputs, - lr=self._get_lr(), - step=self.cur_step - 1, - gradient_accumulation_steps=self.gradient_accumulation_steps, - grad_norm=self._last_grad_norm) - - def calculate_metrics(self, is_training): - self.accumulate_metrics(is_training) - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - results = {} - for metric in metrics: - results.update(metric.calculate()) - return results - _default_adapter_name = '' @@ -245,7 +200,7 @@ def __init__( def _construct_default_optimizer_group(self): return MegatronOptimizerGroup( - loss_instance=CrossEntropyLoss(), + loss_instance=CrossEntropyLoss(reduction='sum'), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh, framework='megatron'), _device_mesh=self.device_mesh, @@ -297,6 +252,38 @@ def _not_encoded(inputs): assert isinstance(inputs, dict) return 'input_ids' not in inputs and 'input_embedding' not in inputs + @staticmethod + def _slice_value_for_microbatch(value, mb_start: int, mb_end: int, micro_batch_size: int): + """Recursively slice a value for microbatch processing. + + Handles nested dicts (e.g., ref_outputs: {"logps": tensor}) by recursively + slicing internal tensors. 
+ + Args: + value: The value to slice (tensor, ndarray, list, dict, or scalar) + mb_start: Start index of the microbatch + mb_end: End index of the microbatch + micro_batch_size: Size of each microbatch + + Returns: + Sliced value with the same structure + """ + if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, dict): + # Recursively slice dict values (e.g., ref_outputs: {"logps": tensor}) + return { + k: MegatronModel._slice_value_for_microbatch(v, mb_start, mb_end, micro_batch_size) + for k, v in value.items() + } + else: + # Scalars, small tensors, or non-sliceable values pass through as-is + return value + def _postprocess_tensor_cp(self, tensor): """All-gather and reconstruct full sequence from CP-split tensor. @@ -405,6 +392,7 @@ def forward_backward(self, from megatron.core.pipeline_parallel import get_forward_backward_func adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + disable_lora = kwargs.pop('disable_lora', False) temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) return_logits = kwargs.pop('return_logits', False) @@ -424,7 +412,8 @@ def forward_backward(self, assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' if micro_batch_size is None: - micro_batch_size = 1 + # Compatible with DPO + micro_batch_size = min(2, len(inputs)) inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths) # Get parallelism settings for sequence padding and splitting @@ -455,17 +444,10 @@ def forward_backward(self, for mb_idx in range(num_microbatches): mb_start = mb_idx * micro_batch_size mb_end = mb_start + micro_batch_size - mb_kwargs = {} - for key, value in kwargs.items(): - if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - else: - # Scalars, small tensors, or non-sliceable values pass through as-is - mb_kwargs[key] = value + mb_kwargs = { + key: self._slice_value_for_microbatch(value, mb_start, mb_end, micro_batch_size) + for key, value in kwargs.items() + } loss_extra_kwargs_per_mb.append(mb_kwargs) _mb_counter = [0] # mutable counter for closure @@ -479,6 +461,14 @@ def post_loss_function(output_tensor, inputs, logps): losses = result['loss'] counts = result['num_tokens'] if not counts: + # Later will gather this value, so it becomes: + # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) + # = gradient_accumulation_steps * world_size + # Then, grad will divided by this value: + # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens + # 2. 
PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) + # / (gradient_accumulation_steps * world_size ) = avg_per_token_grad counts = torch.tensor(1, device=losses.device) return self.strategy.reduce_loss(losses, counts, output_tensor, logps) @@ -487,7 +477,13 @@ def post_loss_function(output_tensor, inputs, logps): def forward_step_func(data_iterator, model): batch = next(data_iterator) labels = batch.pop('labels', None) - output_tensor = model(**batch) + # Handle disable_lora for base model inference (e.g., reference in DPO) + unwrapped_model = self.strategy.unwrap_model([model])[0] + if disable_lora and isinstance(unwrapped_model, PeftModel): + with unwrapped_model.disable_adapter(): + output_tensor = model(**batch) + else: + output_tensor = model(**batch) batch['labels'] = labels logps = None if labels is not None and mpu.is_pipeline_last_stage(): @@ -571,7 +567,6 @@ def forward_step_func(data_iterator, model): dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group) - optimizer_config.inputs = inputs if logps and len({_logps.shape[1] for _logps in logps}) == 1: logps = torch.cat(logps, dim=0) if logits and len({_logits.shape[1] for _logits in logits}) == 1: @@ -580,8 +575,14 @@ def forward_step_func(data_iterator, model): loss = loss.detach().cpu().float().numpy() if not return_logits: logits = None - if not forward_only: - optimizer_config.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + if forward_only: + optimizer_config.eval_status.inputs = inputs + optimizer_config.eval_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.eval_status.forward_kwargs = kwargs + else: + optimizer_config.train_status.inputs = inputs + optimizer_config.train_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.train_status.forward_kwargs = kwargs return ModelOutput(logits=logits, loss=loss, logps=logps) @remote_function(dispatch='all') @@ -712,6 +713,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature optimizer_config = self.optimizer_group[adapter_name] optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): """Add an eval metric @@ -727,9 +729,9 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] kwargs['device_mesh'] = self.device_mesh kwargs['process_group'] = optimizer_config._dp_group if is_training is None or is_training is True: - optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.train_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) if not is_training: - optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.eval_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) @remote_function(dispatch='all') def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs): @@ -793,16 +795,16 @@ def _create_megatron_optimizer(self, **kwargs): opt_config = OptimizerConfig( optimizer='adam', lr=lr, - min_lr=kwargs.get('min_lr', 0.0), - weight_decay=kwargs.get('weight_decay', 0.01), - adam_beta1=kwargs.get('adam_beta1', 0.9), - adam_beta2=kwargs.get('adam_beta2', 
0.999), - adam_eps=kwargs.get('adam_eps', 1e-8), - clip_grad=kwargs.get('clip_grad', 1.0), - bf16=kwargs.get('bf16', True), + min_lr=kwargs.pop('min_lr', 0.0), + weight_decay=kwargs.pop('weight_decay', 0.01), + adam_beta1=kwargs.pop('adam_beta1', 0.9), + adam_beta2=kwargs.pop('adam_beta2', 0.999), + adam_eps=kwargs.pop('adam_eps', 1e-8), + clip_grad=kwargs.pop('clip_grad', 1.0), + bf16=kwargs.pop('bf16', True), use_distributed_optimizer=use_distributed_optimizer, - overlap_param_gather=kwargs.get('overlap_param_gather', False), - log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), + overlap_param_gather=kwargs.pop('overlap_param_gather', False), + log_num_zeros_in_grad=kwargs.pop('log_num_zeros_in_grad', False), **kwargs, ) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index b674ef46..fc35b2b0 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -118,16 +118,18 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T Args: inputs: Model inputs. - **kwargs: Additional arguments. + **kwargs: Additional arguments including disable_lora. Returns: Model outputs. """ - self._check_adapter_valid(kwargs.get('adapter_name')) - with self.multi_adapter.adapter(kwargs.get('adapter_name')): + adapter_name = kwargs.get('adapter_name') + disable_lora = kwargs.get('disable_lora', False) + self._check_adapter_valid(adapter_name) + with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora): return super().forward_only(inputs=inputs, **kwargs) - @remote_function(dispatch='slice_dp', collect='mean', sync=True) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -262,6 +264,7 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable self._check_adapter_valid(kwargs.get('adapter_name')) super().set_processor(processor_cls, **kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().add_metric(metric_cls, is_training, **kwargs) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 9780973b..cebde214 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -45,17 +45,21 @@ def _get_available_lora(self) -> Optional[LoraTenant]: def _count_available_loras(self): return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None]) - def activate_adapter(self, tenant_adapter_name: str): + def activate_adapter(self, tenant_adapter_name: str, call_enable=False): if not self.has_lora(tenant_adapter_name): raise ValueError(f'Adapter {tenant_adapter_name} does not exist') adapter_name = self.find_lora_by_tenant(tenant_adapter_name).adapter_name if isinstance(self.module, list): for _module in self.module: - # _module.enable_adapter_layers() + if call_enable: + # This will cost time + _module.enable_adapter_layers() if _module.active_adapter != adapter_name: _module.set_adapter(adapter_name) else: - # self.module.enable_adapter_layers() + if call_enable: + # This will cost time + self.module.enable_adapter_layers() if self.module.active_adapter != adapter_name: self.module.set_adapter(adapter_name) @@ -67,10 +71,20 @@ def deactivate_adapter(self): 
self.module.disable_adapter_layers() @contextmanager - def adapter(self, tenant_adapter_name: str): + def adapter(self, tenant_adapter_name: str, disable_lora: bool = False): self.activate_adapter(tenant_adapter_name) - yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name - # self.deactivate_adapter() + if disable_lora: + # Temporarily disable all adapters while keeping optimizer_group active + with self._disable_lora_context(tenant_adapter_name): + yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name + else: + yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name + + @contextmanager + def _disable_lora_context(self, tenant_adapter_name): + self.deactivate_adapter() + yield + self.activate_adapter(tenant_adapter_name, call_enable=True) @contextmanager def save_context(self, tenant_adapter_name: str): @@ -191,7 +205,7 @@ def _linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: lora_A_keys = self.lora_A.keys() for active_adapter in self.active_adapters: - if active_adapter not in lora_A_keys: + if active_adapter not in lora_A_keys or self.disable_adapters: continue _lora = _self.find_lora(active_adapter) target_modules = _lora.tenant_config.target_modules @@ -224,7 +238,7 @@ def _embedding_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: lora_embedding_A_keys = self.lora_embedding_A.keys() for active_adapter in self.active_adapters: - if active_adapter not in lora_embedding_A_keys: + if active_adapter not in lora_embedding_A_keys or self.disable_adapters: continue _lora = self.find_lora(active_adapter) target_modules = _lora.tenant_config.target_modules diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py new file mode 100644 index 00000000..5fdb89f7 --- /dev/null +++ b/src/twinkle/model/optimizer_group.py @@ -0,0 +1,85 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from dataclasses import dataclass, field +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from typing import Any, Dict, List, Optional + +from twinkle import DeviceMesh +from twinkle.data_format import InputFeature, ModelOutput +from twinkle.loss import Loss +from twinkle.metric import Metric +from twinkle.processor import InputProcessor +from twinkle.template import Template + + +@dataclass +class TrainStatus: + """Status for training or evaluation. + + Encapsulates inputs, outputs, loss, tokens count and metrics for a training/eval step. + """ + inputs: List[InputFeature] = None + outputs: ModelOutput = None + loss_value: Any = None + num_tokens: int = 0 + metrics: List[Metric] = field(default_factory=list) + forward_kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class BaseOptimizerGroup: + """Base optimizer group with common fields for training. 
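To make the refactor concrete, a small hedged sketch of the TrainStatus container introduced above; the import path is taken from this diff and assumes twinkle is installed.

from twinkle.model.optimizer_group import TrainStatus  # path as created in this diff

status = TrainStatus(metrics=[])
status.num_tokens += 128                      # accumulated during calculate_loss
status.forward_kwargs = {'temperature': 1.0}
assert status.inputs is None and status.outputs is None  # filled by the next forward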
+ + Subclasses: OptimizerGroup (Transformers), MegatronOptimizerGroup (Megatron) + """ + adapter_name: str = None + adapter_config: Any = None + optimizer: Optimizer = None + lr_scheduler: LRScheduler = None + loss_instance: Loss = None + train_status: TrainStatus = None + eval_status: TrainStatus = None + template: Template = None + processor: InputProcessor = None + gradient_accumulation_steps: int = 1 + cur_step: int = 0 + _dp_group: Any = None + _device_mesh: DeviceMesh = None + _last_grad_norm: float = 0.0 + + def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: + if gradient_accumulation_steps is None: + gradient_accumulation_steps = self.gradient_accumulation_steps + else: + self.gradient_accumulation_steps = gradient_accumulation_steps + return gradient_accumulation_steps == 1 or ((self.cur_step - 1) % gradient_accumulation_steps == 0 + and self.cur_step > 1) + + def _get_lr(self): + """Get learning rates from optimizer. Override in subclass.""" + return [] + + def accumulate_metrics(self, is_training): + """Accumulate metrics for train/eval status. Override in subclass if needed.""" + status = self.train_status if is_training else self.eval_status + if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None: + for metric in status.metrics: + metric.accumulate( + status.inputs, + status.outputs, + lr=self._get_lr(), + step=self.cur_step - 1, + gradient_accumulation_steps=self.gradient_accumulation_steps, + grad_norm=self._last_grad_norm, + **status.forward_kwargs) + + def calculate_metrics(self, is_training): + """Calculate and return metrics.""" + self.accumulate_metrics(is_training) + status = self.train_status if is_training else self.eval_status + results = {} + for metric in status.metrics: + results.update(metric.calculate()) + status.inputs = None + status.outputs = None + return results diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 0900f52b..77a25b18 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -10,6 +10,7 @@ from twinkle import DeviceMesh, remote_class, remote_function, template from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation +from twinkle.infra import collect_tensor_dict from twinkle.loss import Loss from twinkle.metric import Metric from twinkle.processor import InputProcessor @@ -88,7 +89,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): def _lazy_wrap_model(self): pass - @remote_function(dispatch='slice_dp', collect='mean') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) optimizer_config = self.optimizer_group[kwargs.get('adapter_name')] @@ -104,10 +105,12 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, with self.multi_adapter.adapter(kwargs.get('adapter_name')): return super().forward(inputs=inputs, **kwargs) - @remote_function(dispatch='slice_dp', collect='flatten') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): - self._check_adapter_valid(kwargs.get('adapter_name')) - optimizer_config = 
self.optimizer_group[kwargs.get('adapter_name')] + adapter_name = kwargs.get('adapter_name') + disable_lora = kwargs.get('disable_lora', False) + self._check_adapter_valid(adapter_name) + optimizer_config = self.optimizer_group[adapter_name] if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list) and self._not_encoded(inputs[0])): # Trajectory or List[Trajectory] @@ -117,7 +120,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T inputs = [inputs] inputs = optimizer_config.template.batch_encode(inputs) # noqa self.multi_adapter.check_length(inputs) - with self.multi_adapter.adapter(kwargs.get('adapter_name')): + with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora): return super().forward_only(inputs=inputs, **kwargs) @remote_function() @@ -244,6 +247,7 @@ def set_grad_scaler(self, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().set_grad_scaler(**kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().add_metric(metric_cls, is_training, **kwargs) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 520aaf9f..6097ffe2 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -31,6 +31,7 @@ from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import Accuracy, LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.model.transformers.moe import apply_expert_parallel from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy from twinkle.patch import Patch, apply_patch @@ -42,53 +43,33 @@ @dataclass -class OptimizerGroup: - adapter_name: str = None +class OptimizerGroup(BaseOptimizerGroup): + """Optimizer group for Transformers training.""" adapter_config: PeftConfig = None - optimizer: Optimizer = None - lr_scheduler: LRScheduler = None - inputs: List[InputFeature] = None - outputs: ModelOutput = None loss_instance: Loss = CrossEntropyLoss - loss_value: Any = None - template: Template = None - processor: InputProcessor = None scaler: GradScaler = None - _last_grad_norm: float = 0.0 scaler_has_nan: bool = False - gradient_accumulation_steps: int = 1 - cur_step: int = 0 - num_tokens: int = 0 - train_metrics: List[Metric] = field(default_factory=list) - eval_metrics: List[Metric] = field(default_factory=list) checkpoint_engine: CheckpointEngine = None - _dp_group = None - _device_mesh: DeviceMesh = None _handler: Any = None - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: - if gradient_accumulation_steps is None: - gradient_accumulation_steps = self.gradient_accumulation_steps - else: - self.gradient_accumulation_steps = gradient_accumulation_steps - return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 - def __post_init__(self): self._ensure_dp_group() self._build_metrics() def _build_metrics(self): - self.train_metrics = [ - LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'), + train_metrics = [ + LossMetric(self._device_mesh, self._dp_group), Accuracy(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.train_status = TrainStatus(metrics=train_metrics) - 
self.eval_metrics = [ - LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'), + eval_metrics = [ + LossMetric(self._device_mesh, self._dp_group), Accuracy(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.eval_status = TrainStatus(metrics=eval_metrics) def _ensure_dp_group(self): if self._dp_group is not None or self._device_mesh is None: @@ -117,33 +98,18 @@ def _get_lr(self): def accumulate_metrics(self, is_training): self._ensure_dp_group() - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - if len(metrics) > 0 and self.inputs is not None and self.outputs is not None: - for metric in metrics: + status = self.train_status if is_training else self.eval_status + if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None: + for metric in status.metrics: metric.accumulate( - self.inputs, - self.outputs, + status.inputs, + status.outputs, lr=self._get_lr(), step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, grad_norm=self._last_grad_norm, - loss_reduction=getattr(self.loss_instance, 'reduction', 'mean')) - - def calculate_metrics(self, is_training): - self.accumulate_metrics(is_training) - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - results = {} - for metric in metrics: - results.update(metric.calculate()) - self.inputs = None - self.outputs = None - return results + loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'), + **status.forward_kwargs) _default_adapter_name = '' @@ -278,7 +244,7 @@ def _ensure_sp_strategy(self) -> None: ) def _get_default_group(self): - """Get the only group has optimizer, else return the default one""" + """Get the only group, else return the default one""" if len(self.optimizer_group) == 1: return next(iter(self.optimizer_group)) return self.active_group @@ -403,9 +369,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec if self.sp_strategy is not None and labels is None: outputs = self.sp_strategy.postprocess_outputs(outputs) inputs['labels'] = labels - optimizer_config.inputs = inputs - optimizer_config.outputs = outputs - optimizer_config.loss_value = outputs.get('aux_loss', 0) + optimizer_config.train_status.inputs = inputs + optimizer_config.train_status.outputs = outputs + optimizer_config.train_status.forward_kwargs = kwargs + optimizer_config.train_status.loss_value = outputs.get('aux_loss', 0) if labels is not None: loss_mask = (labels != -100).bool() masked_labels = labels.clone() @@ -427,10 +394,12 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T inputs: The model inputs. Can be an encoded batch, or a list of `Trajectory` **kwargs: adapter_name: Lora adapter name. + disable_lora: If True, disable LoRA and use base model for inference. Returns: The output of the model forward. 
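The new disable_lora flag leans on PEFT's adapter-disabling context. Below is a standalone sketch of that mechanism on a toy custom module (editor's illustration, not twinkle code; the module and target-module names are made up).

import torch
from peft import LoraConfig, get_peft_model

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

model = get_peft_model(Tiny(), LoraConfig(target_modules=['proj']))
x = torch.randn(2, 8)
with model.disable_adapter():    # the mechanism used when disable_lora=True
    ref_y = model(x)             # base weights only, LoRA deltas skipped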
""" adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + disable_lora = kwargs.pop('disable_lora', False) temperature = float(kwargs.pop('temperature', 1.0)) return_logits = kwargs.pop('return_logits', False) optimizer_config = self.optimizer_group[adapter_name] @@ -454,13 +423,19 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T inputs = self.sp_strategy.preprocess_inputs(inputs) labels = inputs.pop('labels', None) optimizer_config.accumulate_metrics(False) - outputs = self.model(**inputs) + unwrapped_model = self.strategy.unwrap_model(self.model) + if disable_lora and isinstance(unwrapped_model, PeftModel): + with unwrapped_model.disable_adapter(): + outputs = self.model(**inputs) + else: + outputs = self.model(**inputs) if self.sp_strategy is not None and labels is None: outputs = self.sp_strategy.postprocess_outputs(outputs) inputs['labels'] = labels - optimizer_config.inputs = inputs - optimizer_config.outputs = outputs - optimizer_config.loss_value = outputs.get('aux_loss', 0) + optimizer_config.eval_status.inputs = inputs + optimizer_config.eval_status.outputs = outputs + optimizer_config.eval_status.forward_kwargs = kwargs + optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0) if labels is not None: loss_mask = (labels != -100).bool() masked_labels = labels.clone() @@ -489,24 +464,38 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance assert isinstance(loss_instance, Loss), 'Set a loss_instance before calculating loss' - inputs = optimizer_config.inputs - outputs = optimizer_config.outputs + inputs = optimizer_config.train_status.inputs + outputs = optimizer_config.train_status.outputs assert inputs is not None and outputs is not None, 'Cannot calculate loss of empty inputs and outputs' result = loss_instance(inputs, outputs, **kwargs) loss_value = result['loss'] counts = result['num_tokens'] if not counts: - counts = torch.tensor(0, device=loss_value.device) + counts = torch.tensor(1, device=loss_value.device) + # Later will gather this value, so it becomes: + # 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size ) + # = gradient_accumulation_steps + # Then, grad will divided by this value: + # 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size) + # = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size) + # = global_sum_grad/global_num_tokens + # 2. 
PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) + # / gradient_accumulation_steps + # = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) + # / gradient_accumulation_steps + # = global_per_token_grad / dp_world_size = avg_per_token_grad + counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] - optimizer_config.num_tokens += counts.item() + optimizer_config.train_status.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: reduction = getattr(loss_instance, 'reduction', None) if reduction is not None: self.sp_strategy.sp_config['loss_reduction'] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) - optimizer_config.loss_value += loss_value - outputs['loss'] = optimizer_config.loss_value - return optimizer_config.loss_value.item() + optimizer_config.train_status.loss_value += loss_value + outputs['loss'] = optimizer_config.train_status.loss_value + return optimizer_config.train_status.loss_value.item() @remote_function() def backward(self, **kwargs): @@ -519,7 +508,7 @@ def backward(self, **kwargs): """ adapter_name = kwargs.pop('adapter_name', self._get_default_group()) optimizer_config = self.optimizer_group[adapter_name] - loss_value = optimizer_config.loss_value + loss_value = optimizer_config.train_status.loss_value assert loss_value is not None, 'Do forwarding and calculating loss before backward' scaler = optimizer_config.scaler if scaler is None and self.mixed_precision == 'fp16': @@ -531,7 +520,7 @@ def backward(self, **kwargs): else: loss_value.backward() optimizer_config.cur_step += 1 - optimizer_config.loss_value = None + optimizer_config.train_status.loss_value = None @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -582,7 +571,7 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): scaler.unscale_(optimizer) optimizer_config._ensure_dp_group() - num_tokens = optimizer_config.num_tokens + num_tokens = optimizer_config.train_status.num_tokens num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group) num_tokens = sum(num_tokens) parameters = list(self._get_trainable_parameters(adapter_name).values()) @@ -599,7 +588,7 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): **ep_clip_kwargs, ) optimizer_config._last_grad_norm = grad_norm - optimizer_config.num_tokens = 0 + optimizer_config.train_status.num_tokens = 0 return grad_norm @remote_function(dispatch='all') @@ -1016,8 +1005,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str else: unwrapped_model.add_adapter(adapter_name, config) - self.optimizer_group[adapter_name] = self.optimizer_group.pop(_default_adapter_name, - self._construct_default_optimizer_group()) + self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config _gas_default = kwargs.get('gradient_accumulation_steps', 1) @@ -1086,6 +1074,7 @@ def set_grad_scaler(self, **kwargs): grad_scaler_config.update(kwargs) optimizer_config.scaler = GradScaler(**grad_scaler_config) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): """Add an eval metric 
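To make the normalization comment in calculate_loss above concrete, a small numeric sketch of the SUM-loss case with two data-parallel ranks (all numbers invented): each rank contributes its local token count divided by the world size, so the gathered divisor matches the mean-gathered gradients.

dp_world_size = 2
local_num_tokens = [30, 50]                                            # invented per-rank counts
gathered_divisor = sum(n / dp_world_size for n in local_num_tokens)    # 40.0
global_num_tokens = sum(local_num_tokens)                              # 80
# grads are collected with a mean (already divided by dp_world_size), so dividing by
# gathered_divisor reproduces global_sum_grad / global_num_tokens
assert gathered_divisor == global_num_tokens / dp_world_size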
@@ -1101,9 +1090,9 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] kwargs['device_mesh'] = self.device_mesh kwargs['process_group'] = optimizer_config._dp_group if is_training is None or is_training is True: - optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.train_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) if not is_training: - optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.eval_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) def _get_nb_trainable_parameters(self, adapter_name, model): return PeftModel.get_nb_trainable_parameters(model) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 13b52d99..3294e623 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor +from .dpo import EmojiDPOProcessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py new file mode 100644 index 00000000..c299d74e --- /dev/null +++ b/src/twinkle/preprocessor/dpo.py @@ -0,0 +1,76 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Data Preprocessors. + +These preprocessors convert various preference dataset formats into the standard +format required by Twinkle for DPO training. + +DPO output format: + - positive: Trajectory - chosen response trajectory + - negative: Trajectory - rejected response trajectory +""" +from typing import Any, Dict, List, Optional, Union + +from twinkle.data_format import Message, Trajectory +from .base import Preprocessor + + +class EmojiDPOProcessor(Preprocessor): + """Preprocessor for shareAI/DPO-zh-en-emoji dataset format. + + Dataset format: + { + 'prompt': str, + 'answer_zh': str, # chosen response (Chinese) + 'answer_en': str, # rejected response (English) + } + + Output format: + - positive: Trajectory with chosen (answer_zh) + - negative: Trajectory with rejected (answer_en) + + Args: + system: Optional system prompt. + chosen_key: Key for chosen response (default: 'answer_zh'). + rejected_key: Key for rejected response (default: 'answer_en'). + prompt_key: Key for prompt (default: 'prompt'). 
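A standalone sketch of the row-level conversion EmojiDPOProcessor performs, with plain dicts standing in for twinkle's Message/Trajectory types and an invented sample row.

row = {'prompt': 'ไป‹็ปไธ€ไธ‹ไฝ ่‡ชๅทฑ', 'answer_zh': 'ไฝ ๅฅฝ๏ผๆˆ‘ๆ˜ฏไธ€ไธชๅŠฉๆ‰‹ ๐Ÿ˜Š', 'answer_en': 'Hello! I am an assistant.'}
prompt_msgs = [{'role': 'user', 'content': row['prompt']}]
positive = {'messages': prompt_msgs + [{'role': 'assistant', 'content': row['answer_zh']}]}   # chosen
negative = {'messages': prompt_msgs + [{'role': 'assistant', 'content': row['answer_en']}]}   # rejected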
+ """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'answer_zh', + rejected_key: str = 'answer_en', + prompt_key: str = 'prompt', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row.""" + prompt = row.get(self.prompt_key, '') + chosen = row.get(self.chosen_key, '') + rejected = row.get(self.rejected_key, '') + + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=prompt)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 167a459d..acc17d2e 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,9 +179,6 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for (_, messages) in trajectory.get('extend_message', []): - if messages and messages[0]['role'] == 'user': - messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] def _to_standard_reasoning_content(self, trajectory: Trajectory) -> List[Trajectory]: @@ -206,33 +203,47 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: return result trajectory['messages'] = _extract_reasoning_content(trajectory['messages']) - extra_messages = trajectory.get('extend_message', []) - if extra_messages: - result = [] - for key, extra_message in trajectory.get('extend_message', []): - result.append((key, _extract_reasoning_content(extra_message))) - trajectory['extend_message'] = result return [trajectory] + def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: + """Truncate input_ids and labels in a single InputFeature.""" + length = len(feature['input_ids']) + if length <= self.max_length: + return feature + if strategy == 'raise': + raise ValueError(f'Input length {length} exceeds max_length {self.max_length}') + result = dict(feature) + if strategy == 'left': + result['input_ids'] = result['input_ids'][-self.max_length:] + if 'labels' in result: + result['labels'] = result['labels'][-self.max_length:] + elif strategy == 'right': + result['input_ids'] = result['input_ids'][:self.max_length] + if 'labels' in result: + result['labels'] = result['labels'][:self.max_length] + return InputFeature(**result) + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: - if self.max_length and len(input_feature['input_ids']) > self.max_length: - if self.truncation_strategy == 'raise': - raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} ' - f'exceeds the maximum length({self.max_length})') - elif 
self.truncation_strategy == 'left': - return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})] - elif self.truncation_strategy == 'right': - return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})] - else: # split - result = [] - total_length = len(input_feature['input_ids']) - for start in range(0, total_length, self.max_length): - end = min(start + self.max_length, total_length) - result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()})) - return result - else: + if not self.max_length: return [input_feature] + strategy = self.truncation_strategy + + # Split strategy + if strategy == 'split': + results = [] + for start in range(0, len(input_feature['input_ids']), self.max_length): + end = min(start + self.max_length, len(input_feature['input_ids'])) + feat = dict(input_feature) + feat['input_ids'] = feat['input_ids'][start:end] + if 'labels' in feat: + feat['labels'] = feat['labels'][start:end] + results.append(InputFeature(**feat)) + return results + + # left/right/raise + return [self._truncate_feature(input_feature, strategy)] + def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: input_ids = input_feature['input_ids'] input_feature['attention_mask'] = np.ones_like(input_ids) @@ -244,8 +255,8 @@ def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) return [input_feature] - def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - messages = trajectory['messages'] + def _process_mm_messages(self, messages: List) -> List: + """Process multimodal content in a list of messages.""" new_messages = [] for message in messages: message = copy(message) @@ -265,8 +276,10 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: new_messages.append( transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, self.audio_placeholder, self.is_mm)) + return new_messages - trajectory['messages'] = new_messages + def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: + trajectory['messages'] = self._process_mm_messages(trajectory['messages']) return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -283,7 +296,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs - def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + """Encode a single trajectory's messages into InputFeature.""" if self.use_chat_template: if add_generation_prompt: # For inference: just get input_ids with generation prompt, no labels needed @@ -306,11 +320,17 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> input_ids = self.tokenizer.encode(text) encoded = {} labels = deepcopy(input_ids) - return InputFeature( + + input_feature = InputFeature( input_ids=np.array(input_ids), labels=np.array(labels), **encoded, ) + trajectory.update(input_feature) + return trajectory + + def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + return self._encode_messages(trajectory, add_generation_prompt) @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): @@ -338,18 +358,67 @@ def 
map_row_to_col(rows: List[Union[Dict[str, Any], InputFeature]]) -> Dict[str, return columns - def batch_encode(self, - trajectories: Union[Dict[str, Any], List[Trajectory]], - add_generation_prompt: bool = False) -> List[InputFeature]: - output = [] + def _is_trajectory(self, obj: Any) -> bool: + """Check if an object is a Trajectory (has 'messages' key).""" + return isinstance(obj, Mapping) and 'messages' in obj + + def _get_trajectory_keys(self, columnar: Mapping) -> List[str]: + """Get keys whose values are lists of Trajectories in columnar format.""" + keys = [] + for k, v in columnar.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + return keys + + def batch_encode( + self, + trajectories: Union[Dict[str, Any], List[Trajectory]], + add_generation_prompt: bool = False, + ) -> Union[Dict[str, Any], List[InputFeature]]: + """Encode trajectories into InputFeatures. + + Args: + trajectories: Either List[Trajectory] or columnar Dict[str, List]. + For DPO, columnar format with 'positive'/'negative' keys containing + List[Trajectory] is supported. + add_generation_prompt: Whether to add generation prompt. + + Returns: + List[InputFeature] or columnar Dict[str, List[InputFeature]]. + """ _transfer = False + if isinstance(trajectories, Mapping): _transfer = True - trajectories = self.map_col_to_row(trajectories) + # Check if it has trajectory list columns (DPO format) + traj_keys = self._get_trajectory_keys(trajectories) + if traj_keys: + # DPO format: encode each trajectory list separately, keep other columns + result = {} + for key in trajectories: + if key in traj_keys: + # Encode this trajectory list + result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) + else: + # Keep non-trajectory columns as-is + result[key] = trajectories[key] + return result + else: + # Standard columnar format + trajectories = self.map_col_to_row(trajectories) + + # Process List[Trajectory] trajectories = self._invoke_pre_pipeline(trajectories) - for trajectory in trajectories: - output.append(self.encode(trajectory, add_generation_prompt=add_generation_prompt)) + + # Use thread pool for parallel encoding + from concurrent.futures import ThreadPoolExecutor + from functools import partial + encode_fn = partial(self.encode, add_generation_prompt=add_generation_prompt) + with ThreadPoolExecutor() as executor: + output = list(executor.map(encode_fn, trajectories)) + output = self._invoke_post_pipeline(output) + if _transfer: output = self.map_row_to_col(output) return output diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 1b018773..ce91cbf8 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,7 +10,8 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device +from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax, + stateless_init_process_group, to_device) from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py index 9f5aa9e7..4393fed7 100644 --- 
a/src/twinkle/utils/device_mesh.py +++ b/src/twinkle/utils/device_mesh.py @@ -462,6 +462,39 @@ def get_pp_last_ranks(self) -> Optional[list[int]]: pp_world_size = self.pp_world_size or 1 return self.get_pp_stage_ranks(pp_world_size - 1) + def get_collect_ranks(self) -> list[int]: + """Get ranks for collecting slice_dp dispatch results. + + For slice_dp dispatch, data is split by data_rank (DP/FSDP dimensions). + Within each data_rank group, outputs are identical (after all-gather for TP, etc.), + so we only need one representative per data_rank. + + Returns: One rank per data_rank, preferring PP last stage. + """ + data_ws = self.data_world_size + collect_ranks = [] + + # For each data_rank, find a representative global rank + for data_rank in range(data_ws): + # Find all global ranks that map to this data_rank + candidates = [ + r for r in self.mesh.flatten().tolist() if self.get_data_rank_from_global_rank(r) == data_rank + ] + if not candidates: + continue + + # Prefer PP last stage if PP exists + pp_last = self.get_pp_last_ranks() + if pp_last: + pp_candidates = [r for r in candidates if r in pp_last] + if pp_candidates: + candidates = pp_candidates + + # Take the smallest rank as representative + collect_ranks.append(min(candidates)) + + return sorted(collect_ranks) + def has_dim(self, dim_name: str) -> bool: if self.mesh_dim_names is None: return False diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 4a721a2a..34506d6f 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -190,3 +190,36 @@ def stateless_init_process_group( communicator = Communicator(pg, device=device) return communicator + + +def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor': + import torch + if not tensors: + raise ValueError('Empty tensor list') + + if len(tensors) == 1: + return tensors[0] + + max_ndim = max(t.ndim for t in tensors) + expanded_tensors = [] + for t in tensors: + while t.ndim < max_ndim: + t = t.unsqueeze(0) + expanded_tensors.append(t) + + max_shape = [] + for dim in range(max_ndim): + max_shape.append(max(t.shape[dim] for t in expanded_tensors)) + + padded_tensors = [] + for t in expanded_tensors: + if list(t.shape) == max_shape: + padded_tensors.append(t) + else: + pad_params = [] + for dim in range(max_ndim - 1, -1, -1): + pad_params.extend([0, max_shape[dim] - t.shape[dim]]) + padded = torch.nn.functional.pad(t, pad_params, value=pad_value) + padded_tensors.append(padded) + + return torch.cat(padded_tensors, dim=0)
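Finally, a small usage sketch for pad_and_stack_tensors: ragged per-sample tensors are padded to a common shape (pad_value defaults to -200) and stacked along dim 0. The import path follows the export added in this diff and assumes twinkle is installed.

import torch
from twinkle.utils import pad_and_stack_tensors   # export added by this diff

a = torch.zeros(1, 5)
b = torch.ones(1, 8)
stacked = pad_and_stack_tensors([a, b])            # shorter sequence padded with -200
assert stacked.shape == (2, 8)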