From 563fbddf979faf915f10ed2a4b49cd8aaff0be3b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 21:05:38 +0800 Subject: [PATCH 01/31] support dpo --- cookbook/rl/dpo.py | 267 +++++++++++ cookbook/rl/dpo.sh | 84 ++++ src/twinkle/loss/__init__.py | 6 + src/twinkle/loss/dpo.py | 655 +++++++++++++++++++++++++++ src/twinkle/preprocessor/__init__.py | 1 + src/twinkle/preprocessor/dpo.py | 387 ++++++++++++++++ 6 files changed, 1400 insertions(+) create mode 100644 cookbook/rl/dpo.py create mode 100644 cookbook/rl/dpo.sh create mode 100644 src/twinkle/loss/dpo.py create mode 100644 src/twinkle/preprocessor/dpo.py diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py new file mode 100644 index 00000000..1e03f84b --- /dev/null +++ b/cookbook/rl/dpo.py @@ -0,0 +1,267 @@ +"""DPO (Direct Preference Optimization) Training via Ray. + +Off-policy preference alignment: trains the model to prefer chosen responses +over rejected responses using preference data, without explicit reward modeling. + +Pipeline: + 1. Load preference dataset with chosen/rejected pairs. + 2. Compute reference model log probabilities (frozen). + 3. Train policy model using DPO loss. + +Architecture (Ray): + ┌─────────────────────────────────────────────────────────────────┐ + │ Driver (CPU) │ + │ dataloader ──► batched preference pairs │ + │ ref_model.forward_only() ──► reference log probs │ + │ policy_model.forward_backward() ──► DPO loss + gradient │ + └─────────────────────────────────────────────────────────────────┘ + │ │ │ + DataLoader RefModel (frozen) PolicyModel (trainable) + (ref GPUs) (policy GPUs) + +For SimPO/ORPO variants that don't require a reference model, +set USE_REFERENCE_MODEL=0 to skip reference model computation. + +Environment variables (all optional): + MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) + DATASET_ID – (default: ms://argilla/ultrafeedback-binarized-preferences-cleaned) + MODEL_GPUS – GPUs for policy model (default: 4) + REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) + USE_REFERENCE_MODEL – Whether to use reference model (default: 1) + BATCH_SIZE – global batch size (pairs) (default: 8) + MICRO_BATCH_SIZE – per-device micro batch size (default: 2) + MAX_STEPS – total optimization steps (default: 1000) + LR – learning rate (default: 5e-6) + DPO_BETA – DPO temperature parameter (default: 0.1) + LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) + SAVE_STEPS – checkpoint save interval (default: 100) + MAX_LENGTH – max sequence length (default: 2048) +""" + +import os +from typing import Any, Dict, List, Optional + +import torch +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss +from twinkle.model import TransformersModel +from twinkle.preprocessor import DPOProcessor +from twinkle.processor import InputProcessor +from twinkle.template import Template + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://argilla/ultrafeedback-binarized-preferences-cleaned') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) +USE_REFERENCE_MODEL = bool(int(os.environ.get('USE_REFERENCE_MODEL', 1))) + +# 
Adjust total GPUs based on whether reference model is used +if USE_REFERENCE_MODEL and REF_MODEL_GPUS > 0: + NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS +else: + NUM_GPUS = MODEL_GPUS + USE_REFERENCE_MODEL = False + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +LEARNING_RATE = float(os.environ.get('LR', 5e-6)) +DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) +ADAPTER_NAME = 'default' + + +# ── Dataset ─────────────────────────────────────────────────────────────────── + +def create_dpo_dataset(): + """Create preference dataset for DPO training. + + The dataset should contain 'chosen' and 'rejected' columns after preprocessing. + Each sample will be duplicated: first the chosen, then the rejected version. + """ + dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) + dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) + + # Use DPOProcessor to convert dataset to standard format + # Adjust processor based on your dataset format + dataset.map(DPOProcessor( + system='You are a helpful, harmless, and honest assistant.', + chosen_key='chosen', + rejected_key='rejected', + prompt_key='prompt', + )) + + # Encode both chosen and rejected trajectories + dataset.encode() + return dataset + + +def collate_preference_batch(batch: List[Dict[str, Any]]) -> Dict[str, List]: + """Collate preference pairs into DPO batch format. + + DPO loss expects: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + """ + chosen_samples = [] + rejected_samples = [] + + for item in batch: + if 'chosen' in item and 'rejected' in item: + chosen_samples.append(item['chosen']) + rejected_samples.append(item['rejected']) + else: + # Assume alternating format if not explicitly separated + chosen_samples.append(item) + + # Concatenate: all chosen first, then all rejected + return chosen_samples + rejected_samples + + +# ── Loss Factory ────────────────────────────────────────────────────────────── + +def create_loss(loss_type: str, beta: float, reference_free: bool = False): + """Create the appropriate loss function based on configuration.""" + if loss_type == 'simpo': + return SimPOLoss(beta=beta, gamma=0.5) + elif loss_type == 'orpo': + return ORPOLoss(lambda_orpo=beta) + elif loss_type == 'cpo': + return CPOLoss(beta=beta, bc_coef=1.0) + else: + # Standard DPO variants: sigmoid, hinge, ipo + return DPOLoss( + beta=beta, + loss_type=loss_type, + reference_free=reference_free, + ) + + +# ── Main Training Loop ──────────────────────────────────────────────────────── + +def main(): + # Set up device groups + if USE_REFERENCE_MODEL: + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + else: + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + ] + + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + # ── Policy Model Setup ──────────────────────────────────────────────────── + 
lora_config = LoraConfig( + target_modules=[ + 'q_proj', 'k_proj', 'v_proj', 'o_proj', + 'gate_proj', 'up_proj', 'down_proj', + ], + r=16, + lora_alpha=32, + lora_dropout=0.05, + ) + + policy_model = TransformersModel( + model_id=MODEL_ID, + device_mesh=policy_mesh, + remote_group='policy', + ) + policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + + # Determine if we need reference model based on loss type + reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] + + # Set up loss function + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=not USE_REFERENCE_MODEL) + policy_model.set_loss(loss_fn) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) + + # ── Reference Model Setup (if needed) ───────────────────────────────────── + ref_model = None + if USE_REFERENCE_MODEL and not reference_free: + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS) + ref_model = TransformersModel( + model_id=MODEL_ID, + device_mesh=ref_mesh, + remote_group='reference', + ) + ref_model.set_processor(InputProcessor) + ref_model.set_template('Template', model_id=MODEL_ID) + logger.info('Reference model initialized for DPO training') + else: + logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') + + # ── DataLoader Setup ────────────────────────────────────────────────────── + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + dataloader = DataLoader( + dataset=create_dpo_dataset, + batch_size=GLOBAL_BATCH_SIZE, + min_batch_size=GLOBAL_BATCH_SIZE, + device_mesh=policy_mesh, + remote_group='policy', + ) + + optim_step = 0 + logger.info(get_device_placement()) + logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, ' + f'use_ref_model={USE_REFERENCE_MODEL}') + + # ── Training Loop ───────────────────────────────────────────────────────── + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + # Collate preference pairs: [chosen..., rejected...] + preference_batch = collate_preference_batch(batch if isinstance(batch, list) else [batch]) + + # Compute reference log probabilities if using reference model + ref_logps = None + if ref_model is not None: + with torch.no_grad(): + ref_outputs = ref_model.forward_only(inputs=preference_batch) + ref_logps = ref_outputs.get('logps') + + # Forward-backward pass with DPO loss + policy_model.forward_backward( + inputs=preference_batch, + ref_logps=ref_logps, + micro_batch_size=MICRO_BATCH_SIZE, + ) + + # Gradient clipping and optimizer step + policy_model.clip_grad_and_step() + optim_step += 1 + + # Logging + if optim_step % 10 == 0: + metrics = policy_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') + + # Checkpointing + if optim_step % SAVE_STEPS == 0: + policy_model.save(f'dpo-checkpoint-{optim_step}') + + # ── Save Final Checkpoint ───────────────────────────────────────────────── + logger.info(f'Training completed. 
Total steps: {optim_step}') + policy_model.save('dpo-final-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/cookbook/rl/dpo.sh b/cookbook/rl/dpo.sh new file mode 100644 index 00000000..65206839 --- /dev/null +++ b/cookbook/rl/dpo.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# DPO Training Script for Ray Mode +# +# This script launches DPO (Direct Preference Optimization) training using Ray +# for distributed training across multiple GPUs. +# +# Usage: +# ./dpo.sh # Default settings (8 GPUs: 4 policy + 4 ref) +# ./dpo.sh simpo # Use SimPO (no reference model needed) +# ./dpo.sh orpo # Use ORPO (no reference model needed) +# +# Environment variables can be set to customize training: +# MODEL_ID - Model to train (default: ms://Qwen/Qwen3.5-4B) +# DATASET_ID - Preference dataset (default: UltraFeedback) +# MODEL_GPUS - GPUs for policy model (default: 4) +# REF_MODEL_GPUS - GPUs for reference model (default: 4) +# USE_REFERENCE_MODEL - Use reference model (default: 1) +# BATCH_SIZE - Global batch size (default: 8) +# MAX_STEPS - Training steps (default: 1000) +# LR - Learning rate (default: 5e-6) +# DPO_BETA - DPO beta parameter (default: 0.1) +# LOSS_TYPE - Loss variant: sigmoid/hinge/ipo/simpo/orpo/cpo (default: sigmoid) + +set -e + +# Parse command line argument for loss type +LOSS_TYPE_ARG=${1:-sigmoid} + +# Set default environment variables if not already set +export MODEL_ID=${MODEL_ID:-"ms://Qwen/Qwen3.5-4B"} +export DATASET_ID=${DATASET_ID:-"ms://argilla/ultrafeedback-binarized-preferences-cleaned"} +export MODEL_GPUS=${MODEL_GPUS:-4} +export BATCH_SIZE=${BATCH_SIZE:-8} +export MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2} +export MAX_STEPS=${MAX_STEPS:-1000} +export LR=${LR:-5e-6} +export DPO_BETA=${DPO_BETA:-0.1} +export SAVE_STEPS=${SAVE_STEPS:-100} +export MAX_LENGTH=${MAX_LENGTH:-2048} + +# Set loss type from argument or environment +export LOSS_TYPE=${LOSS_TYPE:-$LOSS_TYPE_ARG} + +# Reference-free losses don't need reference model +if [[ "$LOSS_TYPE" == "simpo" || "$LOSS_TYPE" == "orpo" || "$LOSS_TYPE" == "cpo" ]]; then + export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-0} + export REF_MODEL_GPUS=${REF_MODEL_GPUS:-0} + echo "Using $LOSS_TYPE loss (reference-free)" +else + export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-1} + export REF_MODEL_GPUS=${REF_MODEL_GPUS:-4} + echo "Using $LOSS_TYPE loss with reference model" +fi + +# Calculate total GPUs +if [[ "$USE_REFERENCE_MODEL" == "1" && "$REF_MODEL_GPUS" -gt 0 ]]; then + TOTAL_GPUS=$((MODEL_GPUS + REF_MODEL_GPUS)) +else + TOTAL_GPUS=$MODEL_GPUS +fi + +echo "==========================================" +echo "DPO Training Configuration" +echo "==========================================" +echo "Model: $MODEL_ID" +echo "Dataset: $DATASET_ID" +echo "Loss Type: $LOSS_TYPE" +echo "DPO Beta: $DPO_BETA" +echo "Policy GPUs: $MODEL_GPUS" +echo "Reference GPUs: $REF_MODEL_GPUS" +echo "Total GPUs: $TOTAL_GPUS" +echo "Batch Size: $BATCH_SIZE" +echo "Micro Batch Size: $MICRO_BATCH_SIZE" +echo "Max Steps: $MAX_STEPS" +echo "Learning Rate: $LR" +echo "Max Length: $MAX_LENGTH" +echo "Save Steps: $SAVE_STEPS" +echo "==========================================" + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Run training +python "$SCRIPT_DIR/dpo.py" diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 7870f5a4..4e4d0e82 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,6 +2,7 @@ from .base import Loss from .chunked_cross_entropy import 
ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .dpo import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss @@ -19,4 +20,9 @@ 'cispo': CISPOLoss, 'bnpo': BNPOLoss, 'dr_grpo': DRGRPOLoss, + # DPO family losses + 'dpo': DPOLoss, + 'simpo': SimPOLoss, + 'cpo': CPOLoss, + 'orpo': ORPOLoss, } diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py new file mode 100644 index 00000000..0f3e364a --- /dev/null +++ b/src/twinkle/loss/dpo.py @@ -0,0 +1,655 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Loss Implementation. + +Reference: + "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" + (https://arxiv.org/abs/2305.18290) +""" +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import numpy as np + +from twinkle.data_format import LossOutput +from twinkle.kernel import selective_log_softmax +from twinkle.loss.base import Loss + +if TYPE_CHECKING: + import torch + + +class DPOLoss(Loss): + """Direct Preference Optimization (DPO) Loss. + + DPO directly optimizes the policy using preference data without explicit reward modeling. + The loss function is derived from the Bradley-Terry preference model: + + L_DPO = -log(σ(β * (log π(y_w|x)/π_ref(y_w|x) - log π(y_l|x)/π_ref(y_l|x)))) + + where: + - y_w is the preferred (chosen) response + - y_l is the dispreferred (rejected) response + - β is the temperature parameter controlling deviation from reference + - π is the current policy + - π_ref is the reference policy (frozen) + + Args: + beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1). + label_smoothing: Label smoothing parameter for soft labels (default: 0.0). + ignore_index: Index to ignore in labels (default: -100). + loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid'). + reference_free: Whether to use reference-free DPO (default: False). + """ + + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + ignore_index: int = -100, + loss_type: str = 'sigmoid', + reference_free: bool = False, + **kwargs, + ): + self.beta = beta + self.label_smoothing = label_smoothing + self.ignore_index = ignore_index + self.loss_type = loss_type + self.reference_free = reference_free + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. + + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps. 
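+
+        Concretely, a masked sum over response tokens:
+            seq_logps[i] = sum_t per_token_logps[i, t] * (labels[i, t] != ignore_index)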
+ + Args: + per_token_logps: [batch, seq_len] per-token log probabilities + labels: [batch, seq_len] labels for computing mask + + Returns: + seq_logps: [batch] sequence-level log probabilities + """ + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _pad_and_align_logps( + self, + logps: Union['torch.Tensor', List[List[float]]], + target_shape: tuple, + loss_mask: 'torch.Tensor', + device: 'torch.device', + dtype: 'torch.dtype', + ) -> 'torch.Tensor': + """Pad and align log probabilities to target shape. + + Args: + logps: Input log probabilities (tensor or ragged list) + target_shape: Target (batch, seq_len) shape + loss_mask: Boolean mask for valid positions + device: Target device + dtype: Target dtype + + Returns: + Aligned tensor of shape target_shape + """ + import torch + + if torch.is_tensor(logps): + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + elif logps.dim() == 1: + logps = logps.unsqueeze(0) + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle ragged list input + if isinstance(logps, (list, tuple)): + batch_size, seq_len = target_shape + padded = torch.zeros(target_shape, device=device, dtype=dtype) + for i, row in enumerate(logps): + if row is None: + continue + row_t = torch.as_tensor(row, device=device, dtype=dtype) + valid_positions = loss_mask[i].nonzero(as_tuple=True)[0] + length = min(len(row_t), len(valid_positions)) + if length > 0: + padded[i, valid_positions[:length]] = row_t[:length] + return padded + + return logps.to(device=device, dtype=dtype) + + def _compute_dpo_loss( + self, + policy_chosen_logps: 'torch.Tensor', + policy_rejected_logps: 'torch.Tensor', + reference_chosen_logps: 'torch.Tensor', + reference_rejected_logps: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute the DPO loss. 
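+
+        For loss_type='sigmoid' this evaluates, per pair,
+            -logsigmoid(beta * (chosen_logratio - rejected_logratio)),
+        where each logratio is the policy minus reference sequence log
+        probability (or the raw policy logp when reference_free=True),
+        averaged over the half-batch of pairs.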
+ + Args: + policy_chosen_logps: [batch/2] log probs of chosen under current policy + policy_rejected_logps: [batch/2] log probs of rejected under current policy + reference_chosen_logps: [batch/2] log probs of chosen under reference policy + reference_rejected_logps: [batch/2] log probs of rejected under reference policy + + Returns: + loss: Scalar DPO loss + """ + import torch + import torch.nn.functional as F + + # Compute log ratios + if self.reference_free: + # Reference-free: only use policy log probs + chosen_logratios = policy_chosen_logps + rejected_logratios = policy_rejected_logps + else: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + # Compute preference margin + logits = self.beta * (chosen_logratios - rejected_logratios) + + if self.loss_type == 'sigmoid': + # Standard DPO loss: -log(sigmoid(beta * margin)) + losses = -F.logsigmoid(logits) + elif self.loss_type == 'hinge': + # Hinge loss variant + losses = torch.relu(1 - logits) + elif self.loss_type == 'ipo': + # IPO (Identity Preference Optimization) loss + # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback" + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == 'kto_pair': + # KTO pair loss (simplified version) + chosen_logratios_scaled = self.beta * chosen_logratios + rejected_logratios_scaled = self.beta * rejected_logratios + chosen_losses = 1 - F.sigmoid(chosen_logratios_scaled) + rejected_losses = F.sigmoid(rejected_logratios_scaled) + losses = chosen_losses + rejected_losses + else: + raise ValueError(f"Unknown loss_type: {self.loss_type}") + + # Apply label smoothing if specified + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses + + return losses.mean() + + def __call__( + self, + inputs: Dict, + outputs: Dict, + *, + ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None, + ref_chosen_logps: Optional['torch.Tensor'] = None, + ref_rejected_logps: Optional['torch.Tensor'] = None, + **kwargs, + ) -> LossOutput: + """Compute DPO loss. + + The inputs should contain concatenated chosen and rejected examples: + - First half of batch: chosen responses + - Second half of batch: rejected responses + + Args: + inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. + Batch should be organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + outputs: Dict containing either: + - 'logps': [batch, seq_len] pre-computed log probs, OR + - 'logits': [batch, seq_len, vocab] from which logps will be computed + ref_logps: [batch, seq_len] or List[List[float]] reference model log probs. + Can also be provided as separate ref_chosen_logps and ref_rejected_logps. + ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen. + ref_rejected_logps: [batch/2] pre-computed reference log probs for rejected. + **kwargs: Additional arguments. + + Returns: + LossOutput with DPO loss and metrics. 
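+
+        Example (a minimal sketch; tensor names and shapes are illustrative):
+            >>> import torch
+            >>> loss_fn = DPOLoss(beta=0.1)
+            >>> labels = torch.randint(0, 100, (4, 16))  # 2 chosen + 2 rejected
+            >>> logits = torch.randn(4, 16, 100)
+            >>> ref = torch.randn(4, 16)                 # per-token ref logps
+            >>> out = loss_fn({'labels': labels}, {'logits': logits}, ref_logps=ref)
+            >>> out.loss                                 # scalar tensor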
+ """ + import torch + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + half_batch = batch_size // 2 + + # Get log probabilities from outputs + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + device = logps.device + dtype = logps.dtype + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute sequence-level log probs for policy + policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + policy_rejected_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Handle reference log probs + if ref_chosen_logps is not None and ref_rejected_logps is not None: + # Pre-computed sequence-level reference log probs provided + reference_chosen_logps = ref_chosen_logps.to(device=device, dtype=dtype) + reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) + elif ref_logps is not None: + # Per-token reference log probs provided, need to align and sum + loss_mask = (labels != self.ignore_index).bool() + ref_logps_aligned = self._pad_and_align_logps( + ref_logps, labels.shape, loss_mask, device, dtype + ) + ref_chosen = ref_logps_aligned[:half_batch] + ref_rejected = ref_logps_aligned[half_batch:] + reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) + reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) + elif self.reference_free: + # Reference-free mode: no reference model needed + reference_chosen_logps = torch.zeros_like(policy_chosen_logps) + reference_rejected_logps = torch.zeros_like(policy_rejected_logps) + else: + raise ValueError( + "ref_logps or (ref_chosen_logps, ref_rejected_logps) must be provided " + "unless reference_free=True" + ) + + # Compute DPO loss + loss = self._compute_dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + return LossOutput(loss=loss, num_tokens=0) + + +class SimPOLoss(Loss): + """SimPO (Simple Preference Optimization) Loss. + + SimPO is a simpler variant of DPO that doesn't require a reference model. + It uses length-normalized log probabilities. + + Reference: + "SimPO: Simple Preference Optimization with a Reference-Free Reward" + (https://arxiv.org/abs/2405.14734) + + Args: + beta: Temperature parameter (default: 2.5). + gamma: Target reward margin (default: 0.5). + ignore_index: Index to ignore in labels (default: -100). 
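+
+    Example (a minimal sketch; tensor names are illustrative):
+        >>> loss_fn = SimPOLoss(beta=2.5, gamma=0.5)
+        >>> out = loss_fn({'labels': labels}, {'logits': logits})  # no reference model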
+ """ + + def __init__( + self, + beta: float = 2.5, + gamma: float = 0.5, + ignore_index: int = -100, + **kwargs, + ): + self.beta = beta + self.gamma = gamma + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_length_normalized_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized sequence log probabilities. + + Args: + per_token_logps: [batch, seq_len] per-token log probabilities + labels: [batch, seq_len] labels for computing mask + + Returns: + normalized_logps: [batch] length-normalized log probabilities + """ + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + seq_logps = (per_token_logps * loss_mask).sum(dim=-1) + return seq_logps / seq_lengths + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute SimPO loss. + + Args: + inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. + Batch: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with SimPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute length-normalized log probs + chosen_rewards = self._compute_length_normalized_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_length_normalized_logps(rejected_logps, rejected_labels) + + # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) + logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma + loss = -F.logsigmoid(logits).mean() + + return LossOutput(loss=loss, num_tokens=0) + + +class CPOLoss(Loss): + """CPO (Contrastive Preference Optimization) Loss. + + CPO adds a behavior cloning term to preference optimization. + + Reference: + "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation" + (https://arxiv.org/abs/2401.08417) + + Args: + beta: Temperature parameter for preference (default: 0.1). + bc_coef: Behavior cloning coefficient (default: 1.0). + ignore_index: Index to ignore in labels (default: -100). 
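+
+    The combined objective, as implemented below, is
+        L = -logsigmoid(beta * (logp_chosen - logp_rejected)) + bc_coef * NLL(chosen),
+    i.e. a reference-free preference term plus behavior cloning on the chosen response.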
+ """ + + def __init__( + self, + beta: float = 0.1, + bc_coef: float = 1.0, + ignore_index: int = -100, + **kwargs, + ): + self.beta = beta + self.bc_coef = bc_coef + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss for chosen responses.""" + loss_mask = (labels != self.ignore_index).float() + nll = -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + return nll + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute CPO loss. + + Args: + inputs: Dict containing 'labels' [batch, seq_len]. + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with CPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute sequence-level log probs + chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + rejected_seq_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Preference loss (reference-free DPO) + logits = self.beta * (chosen_seq_logps - rejected_seq_logps) + preference_loss = -F.logsigmoid(logits).mean() + + # Behavior cloning loss on chosen + bc_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Combined loss + loss = preference_loss + self.bc_coef * bc_loss + + return LossOutput(loss=loss, num_tokens=0) + + +class ORPOLoss(Loss): + """ORPO (Odds Ratio Preference Optimization) Loss. + + ORPO combines SFT and preference alignment in a single objective using odds ratios. + + Reference: + "ORPO: Monolithic Preference Optimization without Reference Model" + (https://arxiv.org/abs/2403.07691) + + Args: + lambda_orpo: Weight for the odds ratio term (default: 0.1). + ignore_index: Index to ignore in labels (default: -100). 
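+
+    The combined objective, as implemented below, is
+        L = NLL(chosen) + lambda_orpo * (-logsigmoid(log_odds(chosen) - log_odds(rejected))),
+    where log_odds(y) = log(p / (1 - p)) and p is the length-normalized sequence probability.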
+ """ + + def __init__( + self, + lambda_orpo: float = 0.1, + ignore_index: int = -100, + **kwargs, + ): + self.lambda_orpo = lambda_orpo + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute average log probabilities over valid tokens.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute ORPO loss. + + Args: + inputs: Dict containing 'labels' [batch, seq_len]. + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with ORPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # SFT loss on chosen + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Compute average log probs for odds ratio + chosen_avg_logps = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) + + # Odds ratio: log(odds_chosen / odds_rejected) + # log(p/(1-p)) ≈ log(p) - log(1-p) ≈ log(p) + p for small p + # Simplified: log_odds = avg_logp (since p is small) + log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps).clamp(max=1-1e-7)) + log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps).clamp(max=1-1e-7)) + + # ORPO odds ratio loss + odds_ratio = log_odds_chosen - log_odds_rejected + orpo_loss = -F.logsigmoid(odds_ratio).mean() + + # Combined loss + loss = sft_loss + self.lambda_orpo * orpo_loss + + return LossOutput(loss=loss, num_tokens=0) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 13b52d99..18dae667 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import DataFilter, Preprocessor +from .dpo import DPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, UltraFeedbackProcessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py new file mode 100644 index 00000000..c951caab --- /dev/null +++ b/src/twinkle/preprocessor/dpo.py @@ -0,0 +1,387 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Data Preprocessors. + +These preprocessors convert various preference dataset formats into the standard +Trajectory format required by Twinkle for DPO training. +""" +from typing import Any, Dict, List, Optional, Union + +from twinkle.data_format import Message, Trajectory +from .base import Preprocessor + + +class DPOProcessor(Preprocessor): + """Generic DPO preference data preprocessor. + + Converts preference data with chosen/rejected pairs into Trajectory format. + Supports multiple common dataset formats. + + Expected input format (one of): + 1. {'prompt': str, 'chosen': str, 'rejected': str} + 2. {'prompt': str, 'chosen': List[Message], 'rejected': List[Message]} + 3. {'messages': List[Message], 'chosen': str, 'rejected': str} + 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) + + Output: Each sample generates TWO Trajectories: + - First: chosen response trajectory + - Second: rejected response trajectory + The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + + Args: + system: Optional system prompt to prepend. + chosen_key: Key for chosen response (default: 'chosen'). + rejected_key: Key for rejected response (default: 'rejected'). + prompt_key: Key for prompt/question (default: 'prompt'). + messages_key: Key for conversation messages (default: 'messages'). 
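+
+    Example (illustrative input row):
+        >>> proc = DPOProcessor(system='Be concise.')
+        >>> pair = proc.preprocess({'prompt': 'Hi', 'chosen': 'Hello!', 'rejected': 'Bye.'})
+        >>> [m.role for m in pair['chosen'].messages]
+        ['system', 'user', 'assistant']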
+ """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'chosen', + rejected_key: str = 'rejected', + prompt_key: str = 'prompt', + messages_key: str = 'messages', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + self.messages_key = messages_key + + def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: + """Parse response into list of Messages.""" + if isinstance(response, str): + return [Message(role='assistant', content=response)] + elif isinstance(response, list): + messages = [] + for msg in response: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', ''))) + return messages + return [Message(role='assistant', content=str(response))] + + def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: + """Build prompt messages from row data.""" + messages = [] + + # Add system message if provided + if self.system: + messages.append(Message(role='system', content=self.system)) + + # Check for messages field (conversation format) + if self.messages_key in row and row[self.messages_key]: + raw_messages = row[self.messages_key] + for msg in raw_messages: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + return messages + + # Check for prompt field + if self.prompt_key in row and row[self.prompt_key]: + prompt = row[self.prompt_key] + if isinstance(prompt, str): + messages.append(Message(role='user', content=prompt)) + elif isinstance(prompt, list): + for msg in prompt: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row into chosen and rejected trajectories. + + Returns: + Dict with 'chosen' and 'rejected' Trajectory objects. + """ + # Build prompt messages + prompt_messages = self._build_prompt_messages(row) + + # Get chosen response + chosen_raw = row.get(self.chosen_key, '') + chosen_messages = self._parse_response(chosen_raw) + + # Get rejected response + rejected_raw = row.get(self.rejected_key, '') + rejected_messages = self._parse_response(rejected_raw) + + # Build full trajectories + chosen_trajectory = Trajectory(messages=prompt_messages + chosen_messages) + rejected_trajectory = Trajectory(messages=prompt_messages + rejected_messages) + + return { + 'chosen': chosen_trajectory, + 'rejected': rejected_trajectory, + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Process batched data into paired trajectories. + + Note: Output maintains separate 'chosen' and 'rejected' columns. + The DataLoader/collator should handle pairing them appropriately + for the DPO loss (concatenating chosen batch + rejected batch). + """ + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) + + +class HHRLHFProcessor(Preprocessor): + """Preprocessor for Anthropic HH-RLHF dataset format. + + HH-RLHF format: + {'chosen': "Human: ... Assistant: ...", 'rejected': "Human: ... Assistant: ..."} + + The conversations use "Human:" and "Assistant:" prefixes. 
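+
+    Example record (illustrative):
+        {'chosen': '\n\nHuman: Hi\n\nAssistant: Hello!',
+         'rejected': '\n\nHuman: Hi\n\nAssistant: Go away.'}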
+    """
+
+    def __init__(self, system: Optional[str] = None):
+        self.system = system
+
+    def _parse_hh_conversation(self, text: str) -> List[Message]:
+        """Parse HH-RLHF style conversation text into Messages."""
+        messages = []
+
+        if self.system:
+            messages.append(Message(role='system', content=self.system))
+
+        # Split by Human/Assistant markers
+        parts = text.split('\n\nHuman: ')
+        for i, part in enumerate(parts):
+            if i == 0:
+                # Text before the first "\n\nHuman: " marker; strip an
+                # optional leading "Human: " prefix and skip empty preambles.
+                if part.startswith('Human: '):
+                    part = part[len('Human: '):]
+                if not part.strip():
+                    continue
+
+            # Split Human and Assistant parts
+            if '\n\nAssistant: ' in part:
+                human_part, assistant_part = part.split('\n\nAssistant: ', 1)
+                messages.append(Message(role='user', content=human_part.strip()))
+                messages.append(Message(role='assistant', content=assistant_part.strip()))
+            else:
+                messages.append(Message(role='user', content=part.strip()))
+
+        return messages
+
+    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]:
+        """Process HH-RLHF format row."""
+        chosen_text = row.get('chosen', '')
+        rejected_text = row.get('rejected', '')
+
+        chosen_messages = self._parse_hh_conversation(chosen_text)
+        rejected_messages = self._parse_hh_conversation(rejected_text)
+
+        return {
+            'chosen': Trajectory(messages=chosen_messages),
+            'rejected': Trajectory(messages=rejected_messages),
+        }
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        rows = self.map_col_to_row(rows)
+        processed = [self.preprocess(row) for row in rows]
+        return self.map_row_to_col(processed)
+
+
+class UltraFeedbackProcessor(Preprocessor):
+    """Preprocessor for UltraFeedback dataset format.
+
+    UltraFeedback format:
+    {
+        'instruction': str,
+        'completions': [
+            {'response': str, 'overall_score': float, ...},
+            ...
+        ]
+    }
+
+    Selects highest and lowest scored completions as chosen/rejected.
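+
+    Example (illustrative): for completions scored 9.0 and 2.0, the 9.0
+    response becomes 'chosen' and the 2.0 response becomes 'rejected'.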
+    """
+
+    def __init__(
+        self,
+        system: Optional[str] = None,
+        instruction_key: str = 'instruction',
+        completions_key: str = 'completions',
+        response_key: str = 'response',
+        score_key: str = 'overall_score',
+    ):
+        self.system = system
+        self.instruction_key = instruction_key
+        self.completions_key = completions_key
+        self.response_key = response_key
+        self.score_key = score_key
+
+    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]:
+        """Process UltraFeedback format row."""
+        instruction = row.get(self.instruction_key, '')
+        completions = row.get(self.completions_key, [])
+
+        if len(completions) < 2:
+            return None  # Need at least 2 completions for preference
+
+        # Sort by score
+        scored_completions = [
+            (c.get(self.score_key, 0), c.get(self.response_key, ''))
+            for c in completions
+            if c.get(self.response_key)
+        ]
+
+        if len(scored_completions) < 2:
+            return None
+
+        scored_completions.sort(key=lambda x: x[0], reverse=True)
+        chosen_response = scored_completions[0][1]
+        rejected_response = scored_completions[-1][1]
+
+        # Build messages
+        messages = []
+        if self.system:
+            messages.append(Message(role='system', content=self.system))
+        messages.append(Message(role='user', content=instruction))
+
+        chosen_trajectory = Trajectory(
+            messages=messages + [Message(role='assistant', content=chosen_response)]
+        )
+        rejected_trajectory = Trajectory(
+            messages=messages + [Message(role='assistant', content=rejected_response)]
+        )
+
+        return {
+            'chosen': chosen_trajectory,
+            'rejected': rejected_trajectory,
+        }
+
+    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        rows = self.map_col_to_row(rows)
+        # Preprocess each row once; drop rows without enough completions
+        processed = [p for p in map(self.preprocess, rows) if p is not None]
+        if not processed:
+            return {'chosen': [], 'rejected': []}
+        return self.map_row_to_col(processed)
+
+
+class ShareGPTDPOProcessor(Preprocessor):
+    """Preprocessor for ShareGPT-style DPO datasets.
+
+    Expected format:
+    {
+        'conversations': [
+            {'from': 'human', 'value': '...'},
+            {'from': 'gpt', 'value': '...'},
+            ...
+ ], + 'chosen': {'from': 'gpt', 'value': '...'}, + 'rejected': {'from': 'gpt', 'value': '...'} + } + """ + + ROLE_MAPPING = { + 'human': 'user', + 'gpt': 'assistant', + 'system': 'system', + 'user': 'user', + 'assistant': 'assistant', + } + + def __init__(self, system: Optional[str] = None): + self.system = system + + def _parse_sharegpt_message(self, msg: Dict) -> Message: + """Parse ShareGPT format message.""" + role = self.ROLE_MAPPING.get(msg.get('from', ''), 'user') + content = msg.get('value', '') or msg.get('content', '') + return Message(role=role, content=content) + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process ShareGPT DPO format row.""" + conversations = row.get('conversations', []) + + # Build prompt messages (excluding last assistant turn if present) + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + + for msg in conversations: + messages.append(self._parse_sharegpt_message(msg)) + + # Remove last message if it's assistant (will be replaced by chosen/rejected) + if messages and messages[-1].role == 'assistant': + messages = messages[:-1] + + # Get chosen and rejected + chosen_msg = row.get('chosen', {}) + rejected_msg = row.get('rejected', {}) + + if isinstance(chosen_msg, dict): + chosen_response = Message( + role='assistant', + content=chosen_msg.get('value', '') or chosen_msg.get('content', '') + ) + else: + chosen_response = Message(role='assistant', content=str(chosen_msg)) + + if isinstance(rejected_msg, dict): + rejected_response = Message( + role='assistant', + content=rejected_msg.get('value', '') or rejected_msg.get('content', '') + ) + else: + rejected_response = Message(role='assistant', content=str(rejected_msg)) + + return { + 'chosen': Trajectory(messages=messages + [chosen_response]), + 'rejected': Trajectory(messages=messages + [rejected_response]), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) + + +class IntelOrcaDPOProcessor(Preprocessor): + """Preprocessor for Intel ORCA DPO dataset format. 
+ + Expected format: + { + 'system': str, + 'question': str, + 'chosen': str, + 'rejected': str + } + """ + + def __init__(self, default_system: Optional[str] = None): + self.default_system = default_system + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process Intel ORCA DPO format row.""" + system = row.get('system', self.default_system) + question = row.get('question', '') + chosen = row.get('chosen', '') + rejected = row.get('rejected', '') + + messages = [] + if system: + messages.append(Message(role='system', content=system)) + messages.append(Message(role='user', content=question)) + + return { + 'chosen': Trajectory(messages=messages + [Message(role='assistant', content=chosen)]), + 'rejected': Trajectory(messages=messages + [Message(role='assistant', content=rejected)]), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) From 52978e9312cc7091aa89265c59ac43e9404814f7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:03:16 +0800 Subject: [PATCH 02/31] fix --- cookbook/rl/dpo.py | 48 ++++++++++++++++++++------------- src/twinkle/loss/dpo.py | 12 ++++----- src/twinkle/preprocessor/dpo.py | 38 ++++++++++++++++++++------ 3 files changed, 66 insertions(+), 32 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 1e03f84b..0a776f29 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -71,8 +71,8 @@ NUM_GPUS = MODEL_GPUS USE_REFERENCE_MODEL = False -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # Must be even (chosen + rejected) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -88,41 +88,49 @@ def create_dpo_dataset(): """Create preference dataset for DPO training. - The dataset should contain 'chosen' and 'rejected' columns after preprocessing. - Each sample will be duplicated: first the chosen, then the rejected version. + The dataset will contain interleaved chosen/rejected pairs after preprocessing: + [chosen_1, rejected_1, chosen_2, rejected_2, ...] + + The collate function will reorder to: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] """ dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Use DPOProcessor to convert dataset to standard format - # Adjust processor based on your dataset format + # Use DPOProcessor with interleaved output format + # This creates alternating chosen/rejected pairs that can be properly encoded dataset.map(DPOProcessor( system='You are a helpful, harmless, and honest assistant.', chosen_key='chosen', rejected_key='rejected', prompt_key='prompt', + output_format='interleaved', # Output: [chosen_1, rejected_1, chosen_2, ...] )) - # Encode both chosen and rejected trajectories + # Encode the interleaved trajectories dataset.encode() return dataset -def collate_preference_batch(batch: List[Dict[str, Any]]) -> Dict[str, List]: - """Collate preference pairs into DPO batch format. 
+def collate_preference_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Collate interleaved preference pairs into DPO batch format. + + Input: [chosen_1, rejected_1, chosen_2, rejected_2, ...] (interleaved) + Output: [chosen_1, chosen_2, ..., rejected_1, rejected_2, ...] (grouped) - DPO loss expects: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + DPO loss expects: first half chosen, second half rejected. """ + if not batch: + return batch + + # Extract alternating chosen/rejected chosen_samples = [] rejected_samples = [] - for item in batch: - if 'chosen' in item and 'rejected' in item: - chosen_samples.append(item['chosen']) - rejected_samples.append(item['rejected']) - else: - # Assume alternating format if not explicitly separated + for i, item in enumerate(batch): + if i % 2 == 0: # Even indices are chosen chosen_samples.append(item) + else: # Odd indices are rejected + rejected_samples.append(item) # Concatenate: all chosen first, then all rejected return chosen_samples + rejected_samples @@ -209,7 +217,9 @@ def main(): logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') # ── DataLoader Setup ────────────────────────────────────────────────────── - GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + # Since dataset is interleaved (chosen, rejected, chosen, rejected, ...), + # we need batch_size * 2 samples to get BATCH_SIZE preference pairs + GLOBAL_BATCH_SIZE = BATCH_SIZE * 2 * GRADIENT_ACCUMULATION_STEPS dataloader = DataLoader( dataset=create_dpo_dataset, batch_size=GLOBAL_BATCH_SIZE, @@ -239,10 +249,12 @@ def main(): ref_logps = ref_outputs.get('logps') # Forward-backward pass with DPO loss + # micro_batch_size must be even to maintain chosen/rejected pairing + actual_micro_batch = MICRO_BATCH_SIZE * 2 # Convert pairs to samples policy_model.forward_backward( inputs=preference_batch, ref_logps=ref_logps, - micro_batch_size=MICRO_BATCH_SIZE, + micro_batch_size=actual_micro_batch, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 0f3e364a..d8f5b207 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -8,8 +8,6 @@ """ from typing import TYPE_CHECKING, Dict, List, Optional, Union -import numpy as np - from twinkle.data_format import LossOutput from twinkle.kernel import selective_log_softmax from twinkle.loss.base import Loss @@ -640,10 +638,12 @@ def __call__( rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) # Odds ratio: log(odds_chosen / odds_rejected) - # log(p/(1-p)) ≈ log(p) - log(1-p) ≈ log(p) + p for small p - # Simplified: log_odds = avg_logp (since p is small) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps).clamp(max=1-1e-7)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps).clamp(max=1-1e-7)) + # log_odds = log(p/(1-p)) = log(p) - log(1-p) + # Use numerically stable computation + prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1-1e-7) + prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1-1e-7) + log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) + log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index c951caab..0c2ab769 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -23,10 +23,12 @@ 
class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output: Each sample generates TWO Trajectories: - - First: chosen response trajectory - - Second: rejected response trajectory - The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + Output: Each sample is expanded into TWO rows in the dataset: + - Row 2i: chosen response trajectory + - Row 2i+1: rejected response trajectory + + The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n], + which should be handled by a custom collate function or DataLoader. Args: system: Optional system prompt to prepend. @@ -34,6 +36,9 @@ class DPOProcessor(Preprocessor): rejected_key: Key for rejected response (default: 'rejected'). prompt_key: Key for prompt/question (default: 'prompt'). messages_key: Key for conversation messages (default: 'messages'). + output_format: How to structure the output: + - 'interleaved': [chosen_1, rejected_1, chosen_2, rejected_2, ...] + - 'paired': {'chosen': [Traj1, ...], 'rejected': [Traj1, ...]} """ def __init__( @@ -43,12 +48,14 @@ def __init__( rejected_key: str = 'rejected', prompt_key: str = 'prompt', messages_key: str = 'messages', + output_format: str = 'interleaved', ): self.system = system self.chosen_key = chosen_key self.rejected_key = rejected_key self.prompt_key = prompt_key self.messages_key = messages_key + self.output_format = output_format def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: """Parse response into list of Messages.""" @@ -125,13 +132,28 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Process batched data into paired trajectories. - Note: Output maintains separate 'chosen' and 'rejected' columns. - The DataLoader/collator should handle pairing them appropriately - for the DPO loss (concatenating chosen batch + rejected batch). + Output format depends on self.output_format: + - 'interleaved': Returns standard Trajectory column with alternating + chosen/rejected pairs for proper encoding. The DataLoader should use + a custom collate function to reorder into [chosen..., rejected...]. + - 'paired': Returns separate 'chosen' and 'rejected' columns (requires + special handling in DataLoader). """ rows = self.map_col_to_row(rows) processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + + if self.output_format == 'interleaved': + # Flatten to interleaved format: [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
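+            # (collate_preference_batch in the cookbook later regroups these
+            # into [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n])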
+ # This allows standard Dataset.encode() to work + trajectories = [] + for pair in processed: + trajectories.append(pair['chosen']) + trajectories.append(pair['rejected']) + # Return as standard trajectory column + return {'messages': trajectories} + else: + # Paired format: separate columns + return self.map_row_to_col(processed) class HHRLHFProcessor(Preprocessor): From f5d5961503284f85c7531dabf899a939dac28beb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:09:43 +0800 Subject: [PATCH 03/31] fix --- src/twinkle/loss/dpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index d8f5b207..3cd3d631 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union from twinkle.data_format import LossOutput -from twinkle.kernel import selective_log_softmax +from twinkle.utils.torch_utils import selective_log_softmax from twinkle.loss.base import Loss if TYPE_CHECKING: From 3cf03cd3d0f35bd8a7ec9a084b8921ba3e282965 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:31:33 +0800 Subject: [PATCH 04/31] fix --- cookbook/rl/dpo.py | 120 ++++++---- src/twinkle/loss/dpo.py | 331 +++++++++------------------ src/twinkle/preprocessor/__init__.py | 3 +- src/twinkle/preprocessor/dpo.py | 291 ++++++++++++++--------- 4 files changed, 368 insertions(+), 377 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 0a776f29..6edb917d 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -5,8 +5,9 @@ Pipeline: 1. Load preference dataset with chosen/rejected pairs. - 2. Compute reference model log probabilities (frozen). - 3. Train policy model using DPO loss. + 2. Encode chosen and rejected separately. + 3. Compute reference model log probabilities (frozen). + 4. Train policy model using DPO loss. Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ @@ -19,16 +20,20 @@ DataLoader RefModel (frozen) PolicyModel (trainable) (ref GPUs) (policy GPUs) +DPO Trajectory format: + - messages: List[Message] - chosen response + - extend_message: [('rejected_messages', List[Message])] - rejected response + For SimPO/ORPO variants that don't require a reference model, set USE_REFERENCE_MODEL=0 to skip reference model computation. 
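+
+Illustrative Trajectory as built by the preprocessor (message contents abridged):
+    {'messages':       [user_msg, chosen_assistant_msg],
+     'extend_message': [('rejected_messages', [user_msg, rejected_assistant_msg])]}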
Environment variables (all optional): MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) - DATASET_ID – (default: ms://argilla/ultrafeedback-binarized-preferences-cleaned) + DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) USE_REFERENCE_MODEL – Whether to use reference model (default: 1) - BATCH_SIZE – global batch size (pairs) (default: 8) + BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) LR – learning rate (default: 5e-6) @@ -46,11 +51,12 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel -from twinkle.preprocessor import DPOProcessor +from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -58,7 +64,7 @@ # ── Configuration ───────────────────────────────────────────────────────────── MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://argilla/ultrafeedback-binarized-preferences-cleaned') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) @@ -72,7 +78,7 @@ USE_REFERENCE_MODEL = False BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # Must be even (chosen + rejected) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -88,51 +94,75 @@ def create_dpo_dataset(): """Create preference dataset for DPO training. - The dataset will contain interleaved chosen/rejected pairs after preprocessing: - [chosen_1, rejected_1, chosen_2, rejected_2, ...] + Uses shareAI/DPO-zh-en-emoji dataset: + - answer_zh: chosen response (Chinese) + - answer_en: rejected response (English) + + Output Trajectory format: + - messages: chosen response + - extend_message: [('rejected_messages', rejected response)] - The collate function will reorder to: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + Note: We do NOT call dataset.encode() here. Encoding is done in + prepare_dpo_batch() to properly handle both chosen and rejected. 
""" dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Use DPOProcessor with interleaved output format - # This creates alternating chosen/rejected pairs that can be properly encoded - dataset.map(DPOProcessor( - system='You are a helpful, harmless, and honest assistant.', - chosen_key='chosen', - rejected_key='rejected', + # Use EmojiDPOProcessor for shareAI/DPO-zh-en-emoji dataset + # answer_zh -> chosen (messages), answer_en -> rejected (extend_message) + dataset.map(EmojiDPOProcessor( + system='You are a helpful assistant.', + chosen_key='answer_zh', + rejected_key='answer_en', prompt_key='prompt', - output_format='interleaved', # Output: [chosen_1, rejected_1, chosen_2, ...] )) - # Encode the interleaved trajectories - dataset.encode() + # Do NOT encode here - encoding is done in prepare_dpo_batch + # to preserve extend_message for rejected encoding return dataset -def collate_preference_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Collate interleaved preference pairs into DPO batch format. +def prepare_dpo_batch( + batch: List[Dict[str, Any]], + template: Template, +) -> List[Dict[str, Any]]: + """Prepare DPO batch: encode both chosen and rejected. - Input: [chosen_1, rejected_1, chosen_2, rejected_2, ...] (interleaved) - Output: [chosen_1, chosen_2, ..., rejected_1, rejected_2, ...] (grouped) + Args: + batch: List of raw Trajectory dicts with messages (chosen) and + extend_message containing ('rejected_messages', rejected) - DPO loss expects: first half chosen, second half rejected. + Returns: + List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + where each item is an encoded InputFeature dict. """ - if not batch: - return batch - - # Extract alternating chosen/rejected chosen_samples = [] rejected_samples = [] - for i, item in enumerate(batch): - if i % 2 == 0: # Even indices are chosen - chosen_samples.append(item) - else: # Odd indices are rejected - rejected_samples.append(item) - - # Concatenate: all chosen first, then all rejected + for item in batch: + # Get messages (chosen) and encode + messages = item.get('messages', []) + chosen_trajectory = Trajectory(messages=messages) + chosen_encoded = template.encode(chosen_trajectory) + chosen_samples.append(dict(chosen_encoded)) + + # Get rejected from extend_message and encode + extend_message = item.get('extend_message', []) + rejected_messages = None + for key, msgs in extend_message: + if key == 'rejected_messages': + rejected_messages = msgs + break + + if rejected_messages: + rejected_trajectory = Trajectory(messages=rejected_messages) + rejected_encoded = template.encode(rejected_trajectory) + rejected_samples.append(dict(rejected_encoded)) + else: + # Fallback: use chosen (should not happen with proper preprocessing) + rejected_samples.append(dict(chosen_encoded)) + + # Return [chosen..., rejected...] 
return chosen_samples + rejected_samples @@ -201,6 +231,9 @@ def main(): policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) + # Get template for encoding rejected messages + template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) + # ── Reference Model Setup (if needed) ───────────────────────────────────── ref_model = None if USE_REFERENCE_MODEL and not reference_free: @@ -217,9 +250,7 @@ def main(): logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') # ── DataLoader Setup ────────────────────────────────────────────────────── - # Since dataset is interleaved (chosen, rejected, chosen, rejected, ...), - # we need batch_size * 2 samples to get BATCH_SIZE preference pairs - GLOBAL_BATCH_SIZE = BATCH_SIZE * 2 * GRADIENT_ACCUMULATION_STEPS + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS dataloader = DataLoader( dataset=create_dpo_dataset, batch_size=GLOBAL_BATCH_SIZE, @@ -238,21 +269,22 @@ def main(): if optim_step >= MAX_STEPS: break - # Collate preference pairs: [chosen..., rejected...] - preference_batch = collate_preference_batch(batch if isinstance(batch, list) else [batch]) + # Prepare DPO batch: [chosen..., rejected...] + batch_list = batch if isinstance(batch, list) else [batch] + dpo_batch = prepare_dpo_batch(batch_list, template) # Compute reference log probabilities if using reference model ref_logps = None if ref_model is not None: with torch.no_grad(): - ref_outputs = ref_model.forward_only(inputs=preference_batch) + ref_outputs = ref_model.forward_only(inputs=dpo_batch) ref_logps = ref_outputs.get('logps') # Forward-backward pass with DPO loss - # micro_batch_size must be even to maintain chosen/rejected pairing - actual_micro_batch = MICRO_BATCH_SIZE * 2 # Convert pairs to samples + # micro_batch_size should be even to maintain chosen/rejected pairing + actual_micro_batch = MICRO_BATCH_SIZE * 2 policy_model.forward_backward( - inputs=preference_batch, + inputs=dpo_batch, ref_logps=ref_logps, micro_batch_size=actual_micro_batch, ) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 3cd3d631..e727193a 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -16,7 +16,84 @@ import torch -class DPOLoss(Loss): +class PreferenceLossBase(Loss): + """Base class for preference optimization losses with shared utilities.""" + + def __init__(self, ignore_index: int = -100): + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. 
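+
+        Positions whose label equals ignore_index are temporarily rewritten to
+        token id 0 for the lookup and excluded from the loss via the loss mask.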
+ + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + return selective_log_softmax(logits, masked_labels) + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized (average) log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def _get_logps_from_outputs( + self, + outputs: Dict, + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Extract or compute log probabilities from model outputs.""" + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + return logps + + def _split_chosen_rejected( + self, + tensor: 'torch.Tensor', + ) -> tuple: + """Split tensor into chosen (first half) and rejected (second half).""" + half = tensor.shape[0] // 2 + return tensor[:half], tensor[half:] + + +class DPOLoss(PreferenceLossBase): """Direct Preference Optimization (DPO) Loss. DPO directly optimizes the policy using preference data without explicit reward modeling. @@ -48,49 +125,12 @@ def __init__( reference_free: bool = False, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.label_smoothing = label_smoothing - self.ignore_index = ignore_index self.loss_type = loss_type self.reference_free = reference_free - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits. - - Args: - logits: [batch, seq_len, vocab_size] model logits - labels: [batch, seq_len] target token ids - - Returns: - logps: [batch, seq_len] per-token log probabilities - """ - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_sequence_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute sequence-level log probabilities by summing valid token logps. 
- - Args: - per_token_logps: [batch, seq_len] per-token log probabilities - labels: [batch, seq_len] labels for computing mask - - Returns: - seq_logps: [batch] sequence-level log probabilities - """ - loss_mask = (labels != self.ignore_index).float() - return (per_token_logps * loss_mask).sum(dim=-1) - def _pad_and_align_logps( self, logps: Union['torch.Tensor', List[List[float]]], @@ -240,25 +280,15 @@ def __call__( batch_size = labels.shape[0] assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" - half_batch = batch_size // 2 # Get log probabilities from outputs - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) - + logps = self._get_logps_from_outputs(outputs, labels) device = logps.device dtype = logps.dtype # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute sequence-level log probs for policy policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) @@ -275,8 +305,7 @@ def __call__( ref_logps_aligned = self._pad_and_align_logps( ref_logps, labels.shape, loss_mask, device, dtype ) - ref_chosen = ref_logps_aligned[:half_batch] - ref_rejected = ref_logps_aligned[half_batch:] + ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) elif self.reference_free: @@ -300,7 +329,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class SimPOLoss(Loss): +class SimPOLoss(PreferenceLossBase): """SimPO (Simple Preference Optimization) Loss. SimPO is a simpler variant of DPO that doesn't require a reference model. @@ -323,40 +352,9 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.gamma = gamma - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_length_normalized_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute length-normalized sequence log probabilities. - - Args: - per_token_logps: [batch, seq_len] per-token log probabilities - labels: [batch, seq_len] labels for computing mask - - Returns: - normalized_logps: [batch] length-normalized log probabilities - """ - loss_mask = (labels != self.ignore_index).float() - seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) - seq_logps = (per_token_logps * loss_mask).sum(dim=-1) - return seq_logps / seq_lengths def __call__( self, @@ -364,16 +362,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute SimPO loss. - - Args: - inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. 
- Batch: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with SimPO loss. - """ + """Compute SimPO loss.""" import torch import torch.nn.functional as F @@ -384,28 +373,18 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even (chosen + rejected pairs)" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute length-normalized log probs - chosen_rewards = self._compute_length_normalized_logps(chosen_logps, chosen_labels) - rejected_rewards = self._compute_length_normalized_logps(rejected_logps, rejected_labels) + chosen_rewards = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_avg_logps(rejected_logps, rejected_labels) # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma @@ -414,7 +393,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class CPOLoss(Loss): +class CPOLoss(PreferenceLossBase): """CPO (Contrastive Preference Optimization) Loss. CPO adds a behavior cloning term to preference optimization. @@ -436,40 +415,9 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.bc_coef = bc_coef - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_sequence_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute sequence-level log probabilities.""" - loss_mask = (labels != self.ignore_index).float() - return (per_token_logps * loss_mask).sum(dim=-1) - - def _compute_nll_loss( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute negative log likelihood loss for chosen responses.""" - loss_mask = (labels != self.ignore_index).float() - nll = -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) - return nll def __call__( self, @@ -477,15 +425,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute CPO loss. - - Args: - inputs: Dict containing 'labels' [batch, seq_len]. - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with CPO loss. 
- """ + """Compute CPO loss.""" import torch import torch.nn.functional as F @@ -496,24 +436,14 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute sequence-level log probs chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) @@ -532,7 +462,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class ORPOLoss(Loss): +class ORPOLoss(PreferenceLossBase): """ORPO (Odds Ratio Preference Optimization) Loss. ORPO combines SFT and preference alignment in a single objective using odds ratios. @@ -552,39 +482,8 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.lambda_orpo = lambda_orpo - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_avg_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute average log probabilities over valid tokens.""" - loss_mask = (labels != self.ignore_index).float() - seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) - return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths - - def _compute_nll_loss( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute negative log likelihood loss.""" - loss_mask = (labels != self.ignore_index).float() - return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) def __call__( self, @@ -592,15 +491,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute ORPO loss. - - Args: - inputs: Dict containing 'labels' [batch, seq_len]. - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with ORPO loss. 
- """ + """Compute ORPO loss.""" import torch import torch.nn.functional as F @@ -611,24 +502,14 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # SFT loss on chosen sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 18dae667..6d9f6dd7 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor -from .dpo import DPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, UltraFeedbackProcessor +from .dpo import (DPOProcessor, EmojiDPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, + UltraFeedbackKTOProcessor, UltraFeedbackProcessor) from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index 0c2ab769..a81b5f57 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -4,6 +4,10 @@ These preprocessors convert various preference dataset formats into the standard Trajectory format required by Twinkle for DPO training. + +DPO Trajectory format: + - messages: List[Message] - chosen response messages + - extend_message: [('rejected_messages', List[Message])] - rejected response messages """ from typing import Any, Dict, List, Optional, Union @@ -23,12 +27,9 @@ class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output: Each sample is expanded into TWO rows in the dataset: - - Row 2i: chosen response trajectory - - Row 2i+1: rejected response trajectory - - The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n], - which should be handled by a custom collate function or DataLoader. + Output Trajectory format: + - messages: chosen response (prompt + chosen assistant message) + - extend_message: [('rejected_messages', rejected_messages)] Args: system: Optional system prompt to prepend. @@ -36,9 +37,6 @@ class DPOProcessor(Preprocessor): rejected_key: Key for rejected response (default: 'rejected'). prompt_key: Key for prompt/question (default: 'prompt'). messages_key: Key for conversation messages (default: 'messages'). - output_format: How to structure the output: - - 'interleaved': [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
- - 'paired': {'chosen': [Traj1, ...], 'rejected': [Traj1, ...]} """ def __init__( @@ -48,14 +46,12 @@ def __init__( rejected_key: str = 'rejected', prompt_key: str = 'prompt', messages_key: str = 'messages', - output_format: str = 'interleaved', ): self.system = system self.chosen_key = chosen_key self.rejected_key = rejected_key self.prompt_key = prompt_key self.messages_key = messages_key - self.output_format = output_format def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: """Parse response into list of Messages.""" @@ -103,57 +99,38 @@ def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process a single row into chosen and rejected trajectories. + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row into a DPO Trajectory. Returns: - Dict with 'chosen' and 'rejected' Trajectory objects. + Trajectory with chosen in messages and rejected in extend_message. """ # Build prompt messages prompt_messages = self._build_prompt_messages(row) # Get chosen response chosen_raw = row.get(self.chosen_key, '') - chosen_messages = self._parse_response(chosen_raw) + chosen_response = self._parse_response(chosen_raw) # Get rejected response rejected_raw = row.get(self.rejected_key, '') - rejected_messages = self._parse_response(rejected_raw) + rejected_response = self._parse_response(rejected_raw) - # Build full trajectories - chosen_trajectory = Trajectory(messages=prompt_messages + chosen_messages) - rejected_trajectory = Trajectory(messages=prompt_messages + rejected_messages) + # Build full message lists + chosen_messages = prompt_messages + chosen_response + rejected_messages = prompt_messages + rejected_response - return { - 'chosen': chosen_trajectory, - 'rejected': rejected_trajectory, - } + # Return Trajectory with rejected in extend_message + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Process batched data into paired trajectories. - - Output format depends on self.output_format: - - 'interleaved': Returns standard Trajectory column with alternating - chosen/rejected pairs for proper encoding. The DataLoader should use - a custom collate function to reorder into [chosen..., rejected...]. - - 'paired': Returns separate 'chosen' and 'rejected' columns (requires - special handling in DataLoader). - """ + """Process batched data into DPO trajectories.""" rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - - if self.output_format == 'interleaved': - # Flatten to interleaved format: [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
- # This allows standard Dataset.encode() to work - trajectories = [] - for pair in processed: - trajectories.append(pair['chosen']) - trajectories.append(pair['rejected']) - # Return as standard trajectory column - return {'messages': trajectories} - else: - # Paired format: separate columns - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class HHRLHFProcessor(Preprocessor): @@ -180,9 +157,8 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: for i, part in enumerate(parts): if i == 0 and not part.startswith('Human: '): if part.strip(): - # Initial text before first Human marker if part.startswith('Human: '): - part = part[7:] # Remove "Human: " prefix + part = part[7:] messages.append(Message(role='user', content=part.strip())) continue @@ -196,7 +172,7 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process HH-RLHF format row.""" chosen_text = row.get('chosen', '') rejected_text = row.get('rejected', '') @@ -204,15 +180,15 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: chosen_messages = self._parse_hh_conversation(chosen_text) rejected_messages = self._parse_hh_conversation(rejected_text) - return { - 'chosen': Trajectory(messages=chosen_messages), - 'rejected': Trajectory(messages=rejected_messages), - } + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class UltraFeedbackProcessor(Preprocessor): @@ -244,13 +220,13 @@ def __init__( self.response_key = response_key self.score_key = score_key - def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: + def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: """Process UltraFeedback format row.""" instruction = row.get(self.instruction_key, '') completions = row.get(self.completions_key, []) if len(completions) < 2: - return None # Need at least 2 completions for preference + return None # Sort by score scored_completions = [ @@ -267,29 +243,29 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: rejected_response = scored_completions[-1][1] # Build messages - messages = [] + prompt_messages = [] if self.system: - messages.append(Message(role='system', content=self.system)) - messages.append(Message(role='user', content=instruction)) + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=instruction)) - chosen_trajectory = Trajectory( - messages=messages + [Message(role='assistant', content=chosen_response)] - ) - rejected_trajectory = Trajectory( - messages=messages + [Message(role='assistant', content=rejected_response)] - ) + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - return { - 'chosen': chosen_trajectory, - 'rejected': rejected_trajectory, - } + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', 
rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows if self.preprocess(row) is not None] - if not processed: - return {'chosen': [], 'rejected': []} - return self.map_row_to_col(processed) + trajectories = [] + for row in rows: + result = self.preprocess(row) + if result is not None: + trajectories.append(result) + if not trajectories: + return {'messages': []} + return {'messages': trajectories} class ShareGPTDPOProcessor(Preprocessor): @@ -324,51 +300,48 @@ def _parse_sharegpt_message(self, msg: Dict) -> Message: content = msg.get('value', '') or msg.get('content', '') return Message(role=role, content=content) - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process ShareGPT DPO format row.""" conversations = row.get('conversations', []) - # Build prompt messages (excluding last assistant turn if present) - messages = [] + # Build prompt messages + prompt_messages = [] if self.system: - messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='system', content=self.system)) for msg in conversations: - messages.append(self._parse_sharegpt_message(msg)) + prompt_messages.append(self._parse_sharegpt_message(msg)) - # Remove last message if it's assistant (will be replaced by chosen/rejected) - if messages and messages[-1].role == 'assistant': - messages = messages[:-1] + # Remove last message if it's assistant (will be replaced) + if prompt_messages and prompt_messages[-1]['role'] == 'assistant': + prompt_messages = prompt_messages[:-1] # Get chosen and rejected chosen_msg = row.get('chosen', {}) rejected_msg = row.get('rejected', {}) if isinstance(chosen_msg, dict): - chosen_response = Message( - role='assistant', - content=chosen_msg.get('value', '') or chosen_msg.get('content', '') - ) + chosen_content = chosen_msg.get('value', '') or chosen_msg.get('content', '') else: - chosen_response = Message(role='assistant', content=str(chosen_msg)) + chosen_content = str(chosen_msg) if isinstance(rejected_msg, dict): - rejected_response = Message( - role='assistant', - content=rejected_msg.get('value', '') or rejected_msg.get('content', '') - ) + rejected_content = rejected_msg.get('value', '') or rejected_msg.get('content', '') else: - rejected_response = Message(role='assistant', content=str(rejected_msg)) + rejected_content = str(rejected_msg) - return { - 'chosen': Trajectory(messages=messages + [chosen_response]), - 'rejected': Trajectory(messages=messages + [rejected_response]), - } + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class IntelOrcaDPOProcessor(Preprocessor): @@ -386,24 +359,128 @@ class IntelOrcaDPOProcessor(Preprocessor): def __init__(self, default_system: Optional[str] = None): self.default_system = default_system - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def 
preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process Intel ORCA DPO format row.""" system = row.get('system', self.default_system) question = row.get('question', '') chosen = row.get('chosen', '') rejected = row.get('rejected', '') - messages = [] + prompt_messages = [] if system: - messages.append(Message(role='system', content=system)) - messages.append(Message(role='user', content=question)) + prompt_messages.append(Message(role='system', content=system)) + prompt_messages.append(Message(role='user', content=question)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} - return { - 'chosen': Trajectory(messages=messages + [Message(role='assistant', content=chosen)]), - 'rejected': Trajectory(messages=messages + [Message(role='assistant', content=rejected)]), + +class EmojiDPOProcessor(Preprocessor): + """Preprocessor for shareAI/DPO-zh-en-emoji dataset format. + + Dataset format: + { + 'prompt': str, + 'answer_zh': str, # chosen response (Chinese) + 'answer_en': str, # rejected response (English) } + Output Trajectory format: + - messages: prompt + chosen (answer_zh) + - extend_message: [('rejected_messages', prompt + rejected (answer_en))] + + Args: + system: Optional system prompt. + chosen_key: Key for chosen response (default: 'answer_zh'). + rejected_key: Key for rejected response (default: 'answer_en'). + prompt_key: Key for prompt (default: 'prompt'). + """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'answer_zh', + rejected_key: str = 'answer_en', + prompt_key: str = 'prompt', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row.""" + prompt = row.get(self.prompt_key, '') + chosen = row.get(self.chosen_key, '') + rejected = row.get(self.rejected_key, '') + + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=prompt)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} + + +class UltraFeedbackKTOProcessor(Preprocessor): + """Preprocessor for ultrafeedback-binarized-preferences-cleaned-kto dataset. + + Dataset format: + { + 'prompt': str, + 'completion': str, + 'label': bool, # True for chosen, False for rejected + } + + For KTO training, we need (prompt, completion, label) format. + The label is stored in user_data. + + Args: + system: Optional system prompt. 
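+
+        Example row (illustrative): {'prompt': 'Q', 'completion': 'A', 'label': True}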
+ """ + + def __init__(self, system: Optional[str] = None): + self.system = system + + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row for KTO.""" + prompt = row.get('prompt', '') + completion = row.get('completion', '') + label = row.get('label', True) + + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + messages.append(Message(role='user', content=prompt)) + messages.append(Message(role='assistant', content=completion)) + + return Trajectory( + messages=messages, + user_data=[('kto_label', label)] + ) + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} From bcdad646f7e6fdcb6a8be8d5fe46b00a8504b93b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:38:27 +0800 Subject: [PATCH 05/31] fix --- cookbook/rl/dpo.py | 79 ++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 6edb917d..1e3d1cd2 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -41,6 +41,12 @@ LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) SAVE_STEPS – checkpoint save interval (default: 100) MAX_LENGTH – max sequence length (default: 2048) + + Dataset field mapping (for custom datasets): + PROMPT_KEY – key for prompt field (default: 'prompt') + CHOSEN_KEY – key for chosen response (default: 'answer_zh') + REJECTED_KEY – key for rejected response (default: 'answer_en') + SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.') """ import os @@ -51,12 +57,11 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.data_format import Trajectory +from twinkle.data_format import Message, Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel -from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -91,6 +96,13 @@ # ── Dataset ─────────────────────────────────────────────────────────────────── +# Dataset field configuration for shareAI/DPO-zh-en-emoji +PROMPT_KEY = os.environ.get('PROMPT_KEY', 'prompt') +CHOSEN_KEY = os.environ.get('CHOSEN_KEY', 'answer_zh') +REJECTED_KEY = os.environ.get('REJECTED_KEY', 'answer_en') +SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') + + def create_dpo_dataset(): """Create preference dataset for DPO training. @@ -98,27 +110,12 @@ def create_dpo_dataset(): - answer_zh: chosen response (Chinese) - answer_en: rejected response (English) - Output Trajectory format: - - messages: chosen response - - extend_message: [('rejected_messages', rejected response)] - - Note: We do NOT call dataset.encode() here. Encoding is done in - prepare_dpo_batch() to properly handle both chosen and rejected. + Returns raw dataset without preprocessing - preprocessing is done + in prepare_dpo_batch() to avoid PyArrow serialization issues. 
""" dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - - # Use EmojiDPOProcessor for shareAI/DPO-zh-en-emoji dataset - # answer_zh -> chosen (messages), answer_en -> rejected (extend_message) - dataset.map(EmojiDPOProcessor( - system='You are a helpful assistant.', - chosen_key='answer_zh', - rejected_key='answer_en', - prompt_key='prompt', - )) - - # Do NOT encode here - encoding is done in prepare_dpo_batch - # to preserve extend_message for rejected encoding + # Do NOT apply preprocessor here - raw data will be processed in prepare_dpo_batch return dataset @@ -126,11 +123,10 @@ def prepare_dpo_batch( batch: List[Dict[str, Any]], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: encode both chosen and rejected. + """Prepare DPO batch: build trajectories and encode both chosen and rejected. Args: - batch: List of raw Trajectory dicts with messages (chosen) and - extend_message containing ('rejected_messages', rejected) + batch: List of raw data dicts from dataset (e.g., {prompt, answer_zh, answer_en}) Returns: List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] @@ -140,27 +136,28 @@ def prepare_dpo_batch( rejected_samples = [] for item in batch: - # Get messages (chosen) and encode - messages = item.get('messages', []) - chosen_trajectory = Trajectory(messages=messages) + # Build messages from raw data + prompt = item.get(PROMPT_KEY, '') + chosen_response = item.get(CHOSEN_KEY, '') + rejected_response = item.get(REJECTED_KEY, '') + + # Build prompt messages + prompt_messages = [] + if SYSTEM_PROMPT: + prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT)) + prompt_messages.append(Message(role='user', content=prompt)) + + # Build chosen trajectory and encode + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] + chosen_trajectory = Trajectory(messages=chosen_messages) chosen_encoded = template.encode(chosen_trajectory) chosen_samples.append(dict(chosen_encoded)) - # Get rejected from extend_message and encode - extend_message = item.get('extend_message', []) - rejected_messages = None - for key, msgs in extend_message: - if key == 'rejected_messages': - rejected_messages = msgs - break - - if rejected_messages: - rejected_trajectory = Trajectory(messages=rejected_messages) - rejected_encoded = template.encode(rejected_trajectory) - rejected_samples.append(dict(rejected_encoded)) - else: - # Fallback: use chosen (should not happen with proper preprocessing) - rejected_samples.append(dict(chosen_encoded)) + # Build rejected trajectory and encode + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] + rejected_trajectory = Trajectory(messages=rejected_messages) + rejected_encoded = template.encode(rejected_trajectory) + rejected_samples.append(dict(rejected_encoded)) # Return [chosen..., rejected...] 
return chosen_samples + rejected_samples From ee3602cde993cfbfd4758a9530e7619782ef4caa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:44:37 +0800 Subject: [PATCH 06/31] fix --- cookbook/rl/dpo.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 1e3d1cd2..6f6db1ac 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -132,8 +132,8 @@ def prepare_dpo_batch( List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] where each item is an encoded InputFeature dict. """ - chosen_samples = [] - rejected_samples = [] + chosen_trajectories = [] + rejected_trajectories = [] for item in batch: # Build messages from raw data @@ -147,17 +147,20 @@ def prepare_dpo_batch( prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT)) prompt_messages.append(Message(role='user', content=prompt)) - # Build chosen trajectory and encode + # Build chosen and rejected trajectories chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] - chosen_trajectory = Trajectory(messages=chosen_messages) - chosen_encoded = template.encode(chosen_trajectory) - chosen_samples.append(dict(chosen_encoded)) - - # Build rejected trajectory and encode rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - rejected_trajectory = Trajectory(messages=rejected_messages) - rejected_encoded = template.encode(rejected_trajectory) - rejected_samples.append(dict(rejected_encoded)) + + chosen_trajectories.append(Trajectory(messages=chosen_messages)) + rejected_trajectories.append(Trajectory(messages=rejected_messages)) + + # Batch encode all trajectories (properly handles multimodal preprocessing) + chosen_encoded = template.batch_encode(chosen_trajectories) + rejected_encoded = template.batch_encode(rejected_trajectories) + + # Convert to list of dicts + chosen_samples = [dict(enc) for enc in chosen_encoded] + rejected_samples = [dict(enc) for enc in rejected_encoded] # Return [chosen..., rejected...] return chosen_samples + rejected_samples From bebbe780c5407d7383fdf53b13952582e6dedc7b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 10:53:02 +0800 Subject: [PATCH 07/31] wip --- cookbook/rl/dpo.py | 101 ++++++++++++++++++++-------------------- src/twinkle/loss/dpo.py | 17 +++++-- 2 files changed, 65 insertions(+), 53 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 6f6db1ac..5850f8c7 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -25,14 +25,13 @@ - extend_message: [('rejected_messages', List[Message])] - rejected response For SimPO/ORPO variants that don't require a reference model, -set USE_REFERENCE_MODEL=0 to skip reference model computation. +set REF_MODEL_GPUS=0 to skip reference model computation. 
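+For example (illustrative): LOSS_TYPE=simpo REF_MODEL_GPUS=0 python cookbook/rl/dpo.py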
Environment variables (all optional): MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) - USE_REFERENCE_MODEL – Whether to use reference model (default: 1) BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) @@ -62,6 +61,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel +from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -73,14 +73,7 @@ MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) -USE_REFERENCE_MODEL = bool(int(os.environ.get('USE_REFERENCE_MODEL', 1))) - -# Adjust total GPUs based on whether reference model is used -if USE_REFERENCE_MODEL and REF_MODEL_GPUS > 0: - NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -else: - NUM_GPUS = MODEL_GPUS - USE_REFERENCE_MODEL = False +NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) @@ -92,30 +85,18 @@ SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) ADAPTER_NAME = 'default' - - -# ── Dataset ─────────────────────────────────────────────────────────────────── - -# Dataset field configuration for shareAI/DPO-zh-en-emoji -PROMPT_KEY = os.environ.get('PROMPT_KEY', 'prompt') -CHOSEN_KEY = os.environ.get('CHOSEN_KEY', 'answer_zh') -REJECTED_KEY = os.environ.get('REJECTED_KEY', 'answer_en') SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') def create_dpo_dataset(): - """Create preference dataset for DPO training. - - Uses shareAI/DPO-zh-en-emoji dataset: - - answer_zh: chosen response (Chinese) - - answer_en: rejected response (English) - - Returns raw dataset without preprocessing - preprocessing is done - in prepare_dpo_batch() to avoid PyArrow serialization issues. - """ - dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) + dataset = Dataset(DatasetMeta(DATASET_ID)) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Do NOT apply preprocessor here - raw data will be processed in prepare_dpo_batch + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': SYSTEM_PROMPT, + } + ) return dataset @@ -130,7 +111,6 @@ def prepare_dpo_batch( Returns: List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] - where each item is an encoded InputFeature dict. 
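+        (chosen_i and rejected_i are paired by index inside the loss)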
""" chosen_trajectories = [] rejected_trajectories = [] @@ -189,15 +169,10 @@ def create_loss(loss_type: str, beta: float, reference_free: bool = False): def main(): # Set up device groups - if USE_REFERENCE_MODEL: - device_groups = [ - DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), - DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), - ] - else: - device_groups = [ - DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), - ] + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) @@ -226,7 +201,7 @@ def main(): reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] # Set up loss function - loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=not USE_REFERENCE_MODEL) + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False) policy_model.set_loss(loss_fn) policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) @@ -234,9 +209,9 @@ def main(): # Get template for encoding rejected messages template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) - # ── Reference Model Setup (if needed) ───────────────────────────────────── + # ── Reference Model Setup ───────────────────────────────────────────────── ref_model = None - if USE_REFERENCE_MODEL and not reference_free: + if not reference_free: ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS) ref_model = TransformersModel( model_id=MODEL_ID, @@ -261,8 +236,7 @@ def main(): optim_step = 0 logger.info(get_device_placement()) - logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, ' - f'use_ref_model={USE_REFERENCE_MODEL}') + logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}') # ── Training Loop ───────────────────────────────────────────────────────── for batch in dataloader: @@ -274,19 +248,46 @@ def main(): dpo_batch = prepare_dpo_batch(batch_list, template) # Compute reference log probabilities if using reference model - ref_logps = None + # We compute sequence-level logps here to avoid alignment issues with micro-batching + ref_chosen_logps = None + ref_rejected_logps = None if ref_model is not None: with torch.no_grad(): ref_outputs = ref_model.forward_only(inputs=dpo_batch) - ref_logps = ref_outputs.get('logps') + ref_logps = ref_outputs.get('logps') # [batch, seq_len] + if ref_logps is not None: + # Get labels and pad to same length for stacking + label_tensors = [torch.as_tensor(s['labels']) for s in dpo_batch] + max_len = max(t.shape[0] for t in label_tensors) + # Pad labels with -100 (ignore_index) to max length + padded_labels = [] + for t in label_tensors: + if t.shape[0] < max_len: + pad_size = max_len - t.shape[0] + t = torch.cat([torch.full((pad_size,), -100, dtype=t.dtype), t]) + padded_labels.append(t) + ref_labels = torch.stack(padded_labels) + if ref_labels.device != ref_logps.device: + ref_labels = ref_labels.to(ref_logps.device) + # Align sequence lengths if needed + if ref_logps.shape[1] != ref_labels.shape[1]: + min_len = min(ref_logps.shape[1], ref_labels.shape[1]) + ref_logps = ref_logps[:, -min_len:] + ref_labels = ref_labels[:, -min_len:] + # Compute sequence-level logps (sum 
of valid token logps) + loss_mask = (ref_labels != -100).float() + seq_logps = (ref_logps * loss_mask).sum(dim=-1) # [batch] + + # Split into chosen and rejected + half = seq_logps.shape[0] // 2 + ref_chosen_logps = seq_logps[:half] + ref_rejected_logps = seq_logps[half:] # Forward-backward pass with DPO loss - # micro_batch_size should be even to maintain chosen/rejected pairing - actual_micro_batch = MICRO_BATCH_SIZE * 2 policy_model.forward_backward( inputs=dpo_batch, - ref_logps=ref_logps, - micro_batch_size=actual_micro_batch, + ref_chosen_logps=ref_chosen_logps, + ref_rejected_logps=ref_rejected_logps, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index e727193a..7f89d446 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -154,12 +154,23 @@ def _pad_and_align_logps( import torch if torch.is_tensor(logps): - if logps.shape == target_shape: - return logps.to(device=device, dtype=dtype) - elif logps.dim() == 1: + if logps.dim() == 1: logps = logps.unsqueeze(0) if logps.shape == target_shape: return logps.to(device=device, dtype=dtype) + # Handle tensor with different sequence length - align to target shape + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate: take the last target_seq_len tokens (response part) + return logps[:, -target_seq_len:] + else: + # Pad: add zeros at the beginning + padded = torch.zeros(target_shape, device=device, dtype=dtype) + padded[:, -src_seq_len:] = logps + return padded # Handle ragged list input if isinstance(logps, (list, tuple)): From 3f8d1a3c4393aa0a5b666dbd6dea8dd6da3c0548 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 13:44:22 +0800 Subject: [PATCH 08/31] wip --- cookbook/rl/dpo.py | 53 +++++------- src/twinkle/data_format/trajectory.py | 2 +- src/twinkle/preprocessor/dpo.py | 47 +++++----- src/twinkle/template/base.py | 120 ++++++++++++++++++++------ 4 files changed, 140 insertions(+), 82 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 5850f8c7..e7720714 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -56,7 +56,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.data_format import Message, Trajectory +from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss @@ -101,13 +101,15 @@ def create_dpo_dataset(): def prepare_dpo_batch( - batch: List[Dict[str, Any]], + batch: List[Trajectory], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: build trajectories and encode both chosen and rejected. + """Prepare DPO batch: encode both chosen and rejected from preprocessed Trajectories. 
Args:
-        batch: List of raw data dicts from dataset (e.g., {prompt, answer_zh, answer_en})
+        batch: List of Trajectory objects with:
+            - messages: chosen response messages
+            - extend_message: [rejected_messages]
 
     Returns:
         List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n]
     """
     chosen_trajectories = []
     rejected_trajectories = []
 
-    for item in batch:
-        # Build messages from raw data
-        prompt = item.get(PROMPT_KEY, '')
-        chosen_response = item.get(CHOSEN_KEY, '')
-        rejected_response = item.get(REJECTED_KEY, '')
-
-        # Build prompt messages
-        prompt_messages = []
-        if SYSTEM_PROMPT:
-            prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT))
-        prompt_messages.append(Message(role='user', content=prompt))
-
-        # Build chosen and rejected trajectories
-        chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)]
-        rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)]
-
-        chosen_trajectories.append(Trajectory(messages=chosen_messages))
+    for traj in batch:
+        chosen_trajectories.append(Trajectory(messages=traj['messages']))
+        rejected_messages = traj['extend_message'][0]
         rejected_trajectories.append(Trajectory(messages=rejected_messages))
 
-    # Batch encode all trajectories (properly handles multimodal preprocessing)
+    # Batch encode all trajectories
     chosen_encoded = template.batch_encode(chosen_trajectories)
     rejected_encoded = template.batch_encode(rejected_trajectories)
 
@@ -175,7 +163,16 @@ def main():
         ]
 
     policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
-    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
+
+    # ── DataLoader Setup ──────────────────────────────────────────────────────
+    dataloader = DataLoader(
+        dataset=create_dpo_dataset,
+        batch_size=BATCH_SIZE,
+        min_batch_size=BATCH_SIZE,
+        device_mesh=policy_mesh,
+    )
+    length = len(dataloader)
 
     # ── Policy Model Setup ────────────────────────────────────────────────────
     lora_config = LoraConfig(
@@ -224,16 +221,6 @@ def main():
     else:
         logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')
 
-    # ── DataLoader Setup ──────────────────────────────────────────────────────
-    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
-    dataloader = DataLoader(
-        dataset=create_dpo_dataset,
-        batch_size=GLOBAL_BATCH_SIZE,
-        min_batch_size=GLOBAL_BATCH_SIZE,
-        device_mesh=policy_mesh,
-        remote_group='policy',
-    )
-
     optim_step = 0
     logger.info(get_device_placement())
     logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')
diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py
index c7742d75..d6910eaf 100644
--- a/src/twinkle/data_format/trajectory.py
+++ b/src/twinkle/data_format/trajectory.py
@@ -13,6 +13,6 @@
 
 class Trajectory(TypedDict, total=False):
     messages: List[Message]
-    extend_message: List[Tuple[str, List[Message]]]
+    extend_message: List[List[Message]]
     tools: List[Tool]
     user_data: List[Tuple[str, Any]]
diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py
index a81b5f57..7c35b676 100644
--- a/src/twinkle/preprocessor/dpo.py
+++ b/src/twinkle/preprocessor/dpo.py
@@ -123,14 +123,15 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory:
 
         # Return Trajectory with rejected in extend_message
return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Process batched data into DPO trajectories.""" rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class HHRLHFProcessor(Preprocessor): @@ -182,13 +183,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class UltraFeedbackProcessor(Preprocessor): @@ -253,7 +255,7 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: @@ -264,8 +266,9 @@ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: if result is not None: trajectories.append(result) if not trajectories: - return {'messages': []} - return {'messages': trajectories} + return {} + rows = self.map_row_to_col(trajectories) + return rows class ShareGPTDPOProcessor(Preprocessor): @@ -335,13 +338,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class IntelOrcaDPOProcessor(Preprocessor): @@ -376,13 +380,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class EmojiDPOProcessor(Preprocessor): @@ -434,13 +439,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class UltraFeedbackKTOProcessor(Preprocessor): @@ -482,5 +488,6 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: 
rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 167a459d..cd783a9d 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,7 +179,7 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for (_, messages) in trajectory.get('extend_message', []): + for messages in trajectory.get('extend_message', []): if messages and messages[0]['role'] == 'user': messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] @@ -209,43 +209,79 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: extra_messages = trajectory.get('extend_message', []) if extra_messages: result = [] - for key, extra_message in trajectory.get('extend_message', []): - result.append((key, _extract_reasoning_content(extra_message))) + for extra_message in trajectory.get('extend_message', []): + result.append(_extract_reasoning_content(extra_message)) trajectory['extend_message'] = result return [trajectory] + def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: + """Truncate input_ids and labels in a single InputFeature.""" + length = len(feature['input_ids']) + if length <= self.max_length: + return feature + if strategy == 'raise': + raise ValueError(f'Input length {length} exceeds max_length {self.max_length}') + result = dict(feature) + if strategy == 'left': + result['input_ids'] = result['input_ids'][-self.max_length:] + if 'labels' in result: + result['labels'] = result['labels'][-self.max_length:] + elif strategy == 'right': + result['input_ids'] = result['input_ids'][:self.max_length] + if 'labels' in result: + result['labels'] = result['labels'][:self.max_length] + return InputFeature(**result) + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: - if self.max_length and len(input_feature['input_ids']) > self.max_length: - if self.truncation_strategy == 'raise': - raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} ' - f'exceeds the maximum length({self.max_length})') - elif self.truncation_strategy == 'left': - return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})] - elif self.truncation_strategy == 'right': - return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})] - else: # split - result = [] - total_length = len(input_feature['input_ids']) - for start in range(0, total_length, self.max_length): - end = min(start + self.max_length, total_length) - result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()})) - return result - else: + if not self.max_length: return [input_feature] + strategy = self.truncation_strategy + + # Split strategy: doesn't support extend_message + if strategy == 'split': + if input_feature.get('extend_message'): + raise ValueError('Split strategy does not support extend_message.') + results = [] + for start in range(0, len(input_feature['input_ids']), self.max_length): + end = min(start + self.max_length, len(input_feature['input_ids'])) + feat = 
dict(input_feature) + feat['input_ids'] = feat['input_ids'][start:end] + if 'labels' in feat: + feat['labels'] = feat['labels'][start:end] + results.append(InputFeature(**feat)) + return results + + # left/right/raise: apply to main and extend_message + result = self._truncate_feature(input_feature, strategy) + if input_feature.get('extend_message'): + result['extend_message'] = [self._truncate_feature(f, strategy) for f in input_feature['extend_message']] + return [result] + + def _add_attention_to_feature(self, feature: InputFeature) -> InputFeature: + """Add attention fields to a single InputFeature.""" + input_ids = feature['input_ids'] + feature['attention_mask'] = np.ones_like(input_ids) + feature['position_ids'] = np.arange(len(input_ids)) + feature['length'] = len(input_ids) + return feature + def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: - input_ids = input_feature['input_ids'] - input_feature['attention_mask'] = np.ones_like(input_ids) - input_feature['position_ids'] = np.arange(len(input_ids)) - input_feature['length'] = len(input_ids) + self._add_attention_to_feature(input_feature) + if input_feature.get('extend_message'): + for f in input_feature['extend_message']: + self._add_attention_to_feature(f) return [input_feature] def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) + if input_feature.get('extend_message'): + for f in input_feature['extend_message']: + f['labels'] = np.roll(f['labels'], -1, axis=-1) return [input_feature] - def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - messages = trajectory['messages'] + def _process_mm_messages(self, messages: List) -> List: + """Process multimodal content in a list of messages.""" new_messages = [] for message in messages: message = copy(message) @@ -265,8 +301,16 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: new_messages.append( transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, self.audio_placeholder, self.is_mm)) + return new_messages + + def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: + trajectory['messages'] = self._process_mm_messages(trajectory['messages']) + if trajectory.get('extend_message'): + new_extend_message = [] + for msgs in trajectory['extend_message']: + new_extend_message.append(self._process_mm_messages(msgs)) + trajectory['extend_message'] = new_extend_message - trajectory['messages'] = new_messages return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -283,7 +327,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs - def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + """Encode a single trajectory's messages into InputFeature.""" if self.use_chat_template: if add_generation_prompt: # For inference: just get input_ids with generation prompt, no labels needed @@ -306,11 +351,30 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> input_ids = self.tokenizer.encode(text) encoded = {} labels = deepcopy(input_ids) - return InputFeature( + + input_feature = InputFeature( input_ids=np.array(input_ids), labels=np.array(labels), **encoded, ) + 
trajectory.update(input_feature) + return trajectory + + def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + # Encode main messages + result = self._encode_messages(trajectory, add_generation_prompt) + + # Encode extend_message (e.g., rejected messages in DPO) + if trajectory.get('extend_message'): + encoded_extend = [] + for msgs in trajectory['extend_message']: + # Create a temporary trajectory with the extended messages + ext_trajectory = Trajectory(messages=msgs) + ext_feature = self._encode_messages(ext_trajectory, add_generation_prompt) + encoded_extend.append(ext_feature) + result['extend_message'] = encoded_extend + + return result @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): From d1f223fbc9db55362e3e12363a07139957293f1d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 15:01:31 +0800 Subject: [PATCH 09/31] wip --- cookbook/rl/dpo.py | 48 +++--- .../Components/Data Format/Trajectory.md | 6 +- .../Trajectory.md" | 6 +- src/twinkle/data_format/trajectory.py | 1 - src/twinkle/dataset/base.py | 2 +- src/twinkle/preprocessor/dpo.py | 147 ++++++++-------- src/twinkle/template/base.py | 157 ++++++++++-------- 7 files changed, 199 insertions(+), 168 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index e7720714..a9937aa2 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -5,7 +5,7 @@ Pipeline: 1. Load preference dataset with chosen/rejected pairs. - 2. Encode chosen and rejected separately. + 2. Encode positive and negative separately. 3. Compute reference model log probabilities (frozen). 4. Train policy model using DPO loss. @@ -20,9 +20,9 @@ DataLoader RefModel (frozen) PolicyModel (trainable) (ref GPUs) (policy GPUs) -DPO Trajectory format: - - messages: List[Message] - chosen response - - extend_message: [('rejected_messages', List[Message])] - rejected response +DPO data format (after preprocessing): + - positive: List[Trajectory] - chosen responses + - negative: List[Trajectory] - rejected responses For SimPO/ORPO variants that don't require a reference model, set REF_MODEL_GPUS=0 to skip reference model computation. @@ -89,6 +89,7 @@ def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" dataset = Dataset(DatasetMeta(DATASET_ID)) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( @@ -97,41 +98,33 @@ def create_dpo_dataset(): 'system': SYSTEM_PROMPT, } ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode() return dataset def prepare_dpo_batch( - batch: List[Trajectory], + batch: Dict[str, List[Any]], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: encode both chosen and rejected from preprocessed Trajectories. + """Prepare DPO batch: convert encoded batch to list format for training. 
Args: - batch: List of Trajectory objects with: - - messages: chosen response messages - - extend_message: [('rejected_messages', rejected_messages)] + batch: Dict with 'positive' and 'negative' keys, each containing List[InputFeature] Returns: - List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + List organized as [positive_1, ..., positive_n, negative_1, ..., negative_n] """ - chosen_trajectories = [] - rejected_trajectories = [] - - for traj in batch: - chosen_trajectories.append(Trajectory(messages=traj.messages)) - rejected_messages = [m[1] for m in traj['extend_messages'] if m[0] == 'rejected_messages'][0] - rejected_trajectories.append(Trajectory(messages=rejected_messages)) - - # Batch encode all trajectories - chosen_encoded = template.batch_encode(chosen_trajectories) - rejected_encoded = template.batch_encode(rejected_trajectories) + positive_features = batch.get('positive', []) + negative_features = batch.get('negative', []) # Convert to list of dicts - chosen_samples = [dict(enc) for enc in chosen_encoded] - rejected_samples = [dict(enc) for enc in rejected_encoded] + positive_samples = [dict(f) for f in positive_features] + negative_samples = [dict(f) for f in negative_features] - # Return [chosen..., rejected...] - return chosen_samples + rejected_samples + # Return [positive..., negative...] + return positive_samples + negative_samples # ── Loss Factory ────────────────────────────────────────────────────────────── @@ -230,9 +223,8 @@ def main(): if optim_step >= MAX_STEPS: break - # Prepare DPO batch: [chosen..., rejected...] - batch_list = batch if isinstance(batch, list) else [batch] - dpo_batch = prepare_dpo_batch(batch_list, template) + # batch is Dict[str, List[Trajectory]] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch, template) # Compute reference log probabilities if using reference model # We compute sequence-level logps here to avoid alignment issues with micro-batching diff --git a/docs/source_en/Components/Data Format/Trajectory.md b/docs/source_en/Components/Data Format/Trajectory.md index d0c14aec..0efc6b6e 100644 --- a/docs/source_en/Components/Data Format/Trajectory.md +++ b/docs/source_en/Components/Data Format/Trajectory.md @@ -5,12 +5,14 @@ The raw data structure input to Template after dataset ETL is `Trajectory` (traj ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: A list of Message messages, representing the multi-turn conversations actually conducted by the model, usually alternating between `user` and `assistant`. -- extend_message: In training such as DPO and PPO, unusable trajectories or low-score trajectories are usually needed, which will be placed in extend_message - tools: A list of all available tools for the model in this call +- user_data: User-defined data, such as labels in KTO training + +For preference alignment training like DPO, preprocessors return `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format. Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. The format conversion goes from the original dataset to Trajectory, and then to InputFeature. 
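The positive/negative contract documented above is easiest to see with a concrete row. Below is a minimal sketch; the `Message` and `Trajectory` constructors are the twinkle `data_format` types from this patch, while the prompt and response strings are invented for illustration.

```python
from twinkle.data_format import Message, Trajectory

# A shared prompt followed by two candidate completions (strings are illustrative).
prompt = [
    Message(role='system', content='You are a helpful assistant.'),
    Message(role='user', content='Say hello with an emoji.'),
]
chosen = prompt + [Message(role='assistant', content='Hello! 😊')]
rejected = prompt + [Message(role='assistant', content='Hello.')]

# Per-row output of a DPO preprocessor's preprocess():
row = {
    'positive': Trajectory(messages=chosen),
    'negative': Trajectory(messages=rejected),
}

# The batched __call__() stacks such rows into the columnar form the dataset stores:
batch = {'positive': [row['positive']], 'negative': [row['negative']]}
```

The columnar `{'positive': [...], 'negative': [...]}` dict is the same shape that `batch_encode` recognizes and encodes key by key.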
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" index f7ed4f12..5281999c 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" @@ -5,12 +5,14 @@ ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: Message消息的列表,代表模型实际进行的多轮对话,通常是`user`和`assistant`交替出现。 -- extend_message: 在DPO、PPO等训练中通常需要不可用轨迹,或低分轨迹,该轨迹会放在extend_message中 - tools: 模型在本次调用中的所有可用工具列表 +- user_data: 用户自定义数据,如KTO训练中的label + +对于DPO等偏好对齐训练,预处理器返回`{'positive': List[Trajectory], 'negative': List[Trajectory]}`格式。 Trajectory是twinkle中所有数据集预处理输出,模板输入的标准接口。格式转换为由原始数据集转换为Trajectory,再到InputFeature。 diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index d6910eaf..51a21fc5 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -13,6 +13,5 @@ class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[List[Message]] tools: List[Tool] user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 00adc984..fc14a030 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -87,7 +87,7 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): with processing_lock('dataset'): # use a default lock because encode is to all datasets self.dataset = self.dataset.map(encode_fn, - **kwargs).filter(lambda batch: [len(x) > 0 for x in batch['input_ids']], + **kwargs).filter(lambda batch: [True] * len(next(iter(batch.values()))) if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) @remote_function() diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index 7c35b676..0a03c4ad 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -3,11 +3,11 @@ DPO (Direct Preference Optimization) Data Preprocessors. These preprocessors convert various preference dataset formats into the standard -Trajectory format required by Twinkle for DPO training. +format required by Twinkle for DPO training. -DPO Trajectory format: - - messages: List[Message] - chosen response messages - - extend_message: [('rejected_messages', List[Message])] - rejected response messages +DPO output format: + - positive: Trajectory - chosen response trajectory + - negative: Trajectory - rejected response trajectory """ from typing import Any, Dict, List, Optional, Union @@ -18,7 +18,7 @@ class DPOProcessor(Preprocessor): """Generic DPO preference data preprocessor. - Converts preference data with chosen/rejected pairs into Trajectory format. + Converts preference data with chosen/rejected pairs into positive/negative Trajectories. Supports multiple common dataset formats. Expected input format (one of): @@ -27,9 +27,9 @@ class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. 
{'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output Trajectory format: - - messages: chosen response (prompt + chosen assistant message) - - extend_message: [('rejected_messages', rejected_messages)] + Output format: + - positive: Trajectory with chosen response + - negative: Trajectory with rejected response Args: system: Optional system prompt to prepend. @@ -99,11 +99,11 @@ def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Trajectory: - """Process a single row into a DPO Trajectory. + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row into positive/negative Trajectories. Returns: - Trajectory with chosen in messages and rejected in extend_message. + Dict with 'positive' and 'negative' Trajectory. """ # Build prompt messages prompt_messages = self._build_prompt_messages(row) @@ -120,18 +120,22 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + chosen_response rejected_messages = prompt_messages + rejected_response - # Return Trajectory with rejected in extend_message - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Process batched data into DPO trajectories.""" + """Process batched data into DPO format.""" rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + # Collect all positive and negative trajectories + positive_list = [r['positive'] for r in results] + negative_list = [r['negative'] for r in results] + return { + 'positive': positive_list, + 'negative': negative_list, + } class HHRLHFProcessor(Preprocessor): @@ -173,7 +177,7 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process HH-RLHF format row.""" chosen_text = row.get('chosen', '') rejected_text = row.get('rejected', '') @@ -181,16 +185,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = self._parse_hh_conversation(chosen_text) rejected_messages = self._parse_hh_conversation(rejected_text) - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class UltraFeedbackProcessor(Preprocessor): @@ -222,7 +228,7 @@ def __init__( self.response_key = response_key self.score_key = score_key - def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: """Process UltraFeedback format row.""" instruction = row.get(self.instruction_key, '') completions = row.get(self.completions_key, []) @@ 
-253,22 +259,21 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [] - for row in rows: - result = self.preprocess(row) - if result is not None: - trajectories.append(result) - if not trajectories: + results = [self.preprocess(row) for row in rows] + results = [r for r in results if r is not None] + if not results: return {} - rows = self.map_row_to_col(trajectories) - return rows + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class ShareGPTDPOProcessor(Preprocessor): @@ -303,7 +308,7 @@ def _parse_sharegpt_message(self, msg: Dict) -> Message: content = msg.get('value', '') or msg.get('content', '') return Message(role=role, content=content) - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process ShareGPT DPO format row.""" conversations = row.get('conversations', []) @@ -336,16 +341,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class IntelOrcaDPOProcessor(Preprocessor): @@ -363,7 +370,7 @@ class IntelOrcaDPOProcessor(Preprocessor): def __init__(self, default_system: Optional[str] = None): self.default_system = default_system - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process Intel ORCA DPO format row.""" system = row.get('system', self.default_system) question = row.get('question', '') @@ -378,16 +385,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in 
results], + } class EmojiDPOProcessor(Preprocessor): @@ -400,9 +409,9 @@ class EmojiDPOProcessor(Preprocessor): 'answer_en': str, # rejected response (English) } - Output Trajectory format: - - messages: prompt + chosen (answer_zh) - - extend_message: [('rejected_messages', prompt + rejected (answer_en))] + Output format: + - positive: Trajectory with chosen (answer_zh) + - negative: Trajectory with rejected (answer_en) Args: system: Optional system prompt. @@ -423,7 +432,7 @@ def __init__( self.rejected_key = rejected_key self.prompt_key = prompt_key - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process a single row.""" prompt = row.get(self.prompt_key, '') chosen = row.get(self.chosen_key, '') @@ -437,16 +446,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class UltraFeedbackKTOProcessor(Preprocessor): diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index cd783a9d..31c04b6a 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,9 +179,6 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for messages in trajectory.get('extend_message', []): - if messages and messages[0]['role'] == 'user': - messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] def _to_standard_reasoning_content(self, trajectory: Trajectory) -> List[Trajectory]: @@ -206,12 +203,6 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: return result trajectory['messages'] = _extract_reasoning_content(trajectory['messages']) - extra_messages = trajectory.get('extend_message', []) - if extra_messages: - result = [] - for extra_message in trajectory.get('extend_message', []): - result.append(_extract_reasoning_content(extra_message)) - trajectory['extend_message'] = result return [trajectory] def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: @@ -238,10 +229,8 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: strategy = self.truncation_strategy - # Split strategy: doesn't support extend_message + # Split strategy if strategy == 'split': - if input_feature.get('extend_message'): - raise ValueError('Split strategy does not support extend_message.') results = [] for start in range(0, len(input_feature['input_ids']), self.max_length): end = min(start + self.max_length, len(input_feature['input_ids'])) @@ -252,32 +241,18 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: 
results.append(InputFeature(**feat)) return results - # left/right/raise: apply to main and extend_message - result = self._truncate_feature(input_feature, strategy) - if input_feature.get('extend_message'): - result['extend_message'] = [self._truncate_feature(f, strategy) for f in input_feature['extend_message']] - return [result] - - def _add_attention_to_feature(self, feature: InputFeature) -> InputFeature: - """Add attention fields to a single InputFeature.""" - input_ids = feature['input_ids'] - feature['attention_mask'] = np.ones_like(input_ids) - feature['position_ids'] = np.arange(len(input_ids)) - feature['length'] = len(input_ids) - return feature + # left/right/raise + return [self._truncate_feature(input_feature, strategy)] def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: - self._add_attention_to_feature(input_feature) - if input_feature.get('extend_message'): - for f in input_feature['extend_message']: - self._add_attention_to_feature(f) + input_ids = input_feature['input_ids'] + input_feature['attention_mask'] = np.ones_like(input_ids) + input_feature['position_ids'] = np.arange(len(input_ids)) + input_feature['length'] = len(input_ids) return [input_feature] def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) - if input_feature.get('extend_message'): - for f in input_feature['extend_message']: - f['labels'] = np.roll(f['labels'], -1, axis=-1) return [input_feature] def _process_mm_messages(self, messages: List) -> List: @@ -305,12 +280,6 @@ def _process_mm_messages(self, messages: List) -> List: def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: trajectory['messages'] = self._process_mm_messages(trajectory['messages']) - if trajectory.get('extend_message'): - new_extend_message = [] - for msgs in trajectory['extend_message']: - new_extend_message.append(self._process_mm_messages(msgs)) - trajectory['extend_message'] = new_extend_message - return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -361,20 +330,7 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = return trajectory def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: - # Encode main messages - result = self._encode_messages(trajectory, add_generation_prompt) - - # Encode extend_message (e.g., rejected messages in DPO) - if trajectory.get('extend_message'): - encoded_extend = [] - for msgs in trajectory['extend_message']: - # Create a temporary trajectory with the extended messages - ext_trajectory = Trajectory(messages=msgs) - ext_feature = self._encode_messages(ext_trajectory, add_generation_prompt) - encoded_extend.append(ext_feature) - result['extend_message'] = encoded_extend - - return result + return self._encode_messages(trajectory, add_generation_prompt) @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): @@ -402,21 +358,90 @@ def map_row_to_col(rows: List[Union[Dict[str, Any], InputFeature]]) -> Dict[str, return columns - def batch_encode(self, - trajectories: Union[Dict[str, Any], List[Trajectory]], - add_generation_prompt: bool = False) -> List[InputFeature]: - output = [] - _transfer = False + def _is_trajectory(self, obj: Any) -> bool: + """Check if an object is a Trajectory (has 'messages' key).""" + return isinstance(obj, Mapping) and 'messages' in obj + + def _is_trajectory_dict(self, obj: Any) 
-> bool: + """Check if obj is Dict[str, Trajectory] - all values are Trajectories.""" + if not isinstance(obj, Mapping) or not obj: + return False + return all(self._is_trajectory(v) for v in obj.values()) + + def _get_trajectory_keys(self, obj: Mapping) -> List[str]: + """Get keys in a dict whose values are Trajectories.""" + return [k for k, v in obj.items() if self._is_trajectory(v)] + + def _is_columnar_format(self, obj: Any) -> bool: + """Check if obj is columnar format: Dict[str, List[Any]] but NOT a Trajectory.""" + if not isinstance(obj, Mapping) or not obj: + return False + # Trajectory has 'messages' key with list of Message dicts - not columnar + if self._is_trajectory(obj): + return False + # Dict[str, Trajectory] - not columnar + if self._is_trajectory_dict(obj): + return False + # Check if all values are non-empty lists of same length + first_val = next(iter(obj.values())) + if not isinstance(first_val, list) or len(first_val) == 0: + return False + length = len(first_val) + return all(isinstance(v, list) and len(v) == length for v in obj.values()) + + def batch_encode( + self, + trajectories: Union[Dict[str, Any], List[Trajectory], Trajectory], + add_generation_prompt: bool = False, + ) -> Union[Dict[str, Any], List[InputFeature], InputFeature]: + """Encode trajectories into InputFeatures. + + Supports three input formats: + 1. Trajectory -> InputFeature + 2. List[Trajectory] -> List[InputFeature] + 3. Dict containing Trajectories -> Dict with Trajectories encoded + + Also handles columnar format (Dict[str, List]) by converting to rows first. + """ + # Handle columnar format: convert to rows first + if self._is_columnar_format(trajectories): + rows = self.map_col_to_row(trajectories) + encoded = self.batch_encode(rows, add_generation_prompt=add_generation_prompt) + if isinstance(encoded, list) and encoded: + return self.map_row_to_col(encoded) + return encoded + + # Case 1: Single Trajectory + if self._is_trajectory(trajectories) and not self._is_trajectory_dict(trajectories): + processed = self._invoke_pre_pipeline([trajectories]) + output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] + output = self._invoke_post_pipeline(output) + return output[0] if len(output) == 1 else output + + # Case 2: List (Trajectory or Dict containing Trajectories) + if isinstance(trajectories, list): + if not trajectories: + return [] + first = trajectories[0] + if isinstance(first, Mapping) and self._get_trajectory_keys(first): + # List of dicts containing Trajectories + return [self.batch_encode(row, add_generation_prompt=add_generation_prompt) for row in trajectories] + else: + # List[Trajectory] + processed = self._invoke_pre_pipeline(trajectories) + output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] + return self._invoke_post_pipeline(output) + + # Case 3: Dict containing Trajectories (encode only Trajectory values) if isinstance(trajectories, Mapping): - _transfer = True - trajectories = self.map_col_to_row(trajectories) - trajectories = self._invoke_pre_pipeline(trajectories) - for trajectory in trajectories: - output.append(self.encode(trajectory, add_generation_prompt=add_generation_prompt)) - output = self._invoke_post_pipeline(output) - if _transfer: - output = self.map_row_to_col(output) - return output + traj_keys = self._get_trajectory_keys(trajectories) + if traj_keys: + result = dict(trajectories) # Copy non-trajectory keys + for key in traj_keys: + result[key] = self.batch_encode(trajectories[key], 
add_generation_prompt=add_generation_prompt) + return result + + raise ValueError(f'Unsupported input type: {type(trajectories)}') def check(self, trajectory: Trajectory) -> Optional[Trajectory]: encoded = None From 3a25caa7ad20d55066f0e77a098cd3704c0beabc Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 17:01:22 +0800 Subject: [PATCH 10/31] wip --- cookbook/rl/dpo.py | 91 ++++++++++------------------ src/twinkle/loss/dpo.py | 91 ++++++++++++++-------------- src/twinkle/metric/loss.py | 1 - src/twinkle/template/base.py | 114 +++++++++++++++-------------------- 4 files changed, 124 insertions(+), 173 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index a9937aa2..00311c63 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -51,7 +51,6 @@ import os from typing import Any, Dict, List, Optional -import torch from peft import LoraConfig import twinkle @@ -63,7 +62,6 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor -from twinkle.template import Template logger = get_logger() @@ -75,8 +73,8 @@ REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -100,31 +98,38 @@ def create_dpo_dataset(): ) # DPO preprocessor returns {'positive': [...], 'negative': [...]} # batch_encode handles this format automatically - dataset.encode() + dataset.encode(load_from_cache_file=True) return dataset -def prepare_dpo_batch( - batch: Dict[str, List[Any]], - template: Template, -) -> List[Dict[str, Any]]: - """Prepare DPO batch: convert encoded batch to list format for training. +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. Args: - batch: Dict with 'positive' and 'negative' keys, each containing List[InputFeature] + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) Returns: - List organized as [positive_1, ..., positive_n, negative_1, ..., negative_n] + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. """ - positive_features = batch.get('positive', []) - negative_features = batch.get('negative', []) + result = [] - # Convert to list of dicts - positive_samples = [dict(f) for f in positive_features] - negative_samples = [dict(f) for f in negative_features] + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} - # Return [positive..., negative...] 
- return positive_samples + negative_samples + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result # ── Loss Factory ────────────────────────────────────────────────────────────── @@ -196,9 +201,6 @@ def main(): policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) - # Get template for encoding rejected messages - template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) - # ── Reference Model Setup ───────────────────────────────────────────────── ref_model = None if not reference_free: @@ -223,50 +225,19 @@ def main(): if optim_step >= MAX_STEPS: break - # batch is Dict[str, List[Trajectory]] with 'positive' and 'negative' keys - dpo_batch = prepare_dpo_batch(batch, template) + # batch is List[Dict] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch) - # Compute reference log probabilities if using reference model - # We compute sequence-level logps here to avoid alignment issues with micro-batching - ref_chosen_logps = None - ref_rejected_logps = None + # Get reference outputs (lazy - not collected to driver) + ref_outputs = None if ref_model is not None: - with torch.no_grad(): - ref_outputs = ref_model.forward_only(inputs=dpo_batch) - ref_logps = ref_outputs.get('logps') # [batch, seq_len] - if ref_logps is not None: - # Get labels and pad to same length for stacking - label_tensors = [torch.as_tensor(s['labels']) for s in dpo_batch] - max_len = max(t.shape[0] for t in label_tensors) - # Pad labels with -100 (ignore_index) to max length - padded_labels = [] - for t in label_tensors: - if t.shape[0] < max_len: - pad_size = max_len - t.shape[0] - t = torch.cat([torch.full((pad_size,), -100, dtype=t.dtype), t]) - padded_labels.append(t) - ref_labels = torch.stack(padded_labels) - if ref_labels.device != ref_logps.device: - ref_labels = ref_labels.to(ref_logps.device) - # Align sequence lengths if needed - if ref_logps.shape[1] != ref_labels.shape[1]: - min_len = min(ref_logps.shape[1], ref_labels.shape[1]) - ref_logps = ref_logps[:, -min_len:] - ref_labels = ref_labels[:, -min_len:] - # Compute sequence-level logps (sum of valid token logps) - loss_mask = (ref_labels != -100).float() - seq_logps = (ref_logps * loss_mask).sum(dim=-1) # [batch] - - # Split into chosen and rejected - half = seq_logps.shape[0] // 2 - ref_chosen_logps = seq_logps[:half] - ref_rejected_logps = seq_logps[half:] + ref_outputs = ref_model.forward_only(inputs=dpo_batch) # Forward-backward pass with DPO loss + # ref_outputs is passed to loss which extracts logps internally policy_model.forward_backward( inputs=dpo_batch, - ref_chosen_logps=ref_chosen_logps, - ref_rejected_logps=ref_rejected_logps, + ref_outputs=ref_outputs, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 7f89d446..52fc28b0 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -88,9 +88,13 @@ def _split_chosen_rejected( self, tensor: 'torch.Tensor', ) -> tuple: - """Split tensor into chosen (first half) and rejected (second half).""" - half = tensor.shape[0] // 2 - return tensor[:half], tensor[half:] + """Split interleaved tensor into chosen and rejected. 
+ + Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + # Even indices = chosen (positive), odd indices = rejected (negative) + return tensor[0::2], tensor[1::2] class DPOLoss(PreferenceLossBase): @@ -131,20 +135,18 @@ def __init__( self.loss_type = loss_type self.reference_free = reference_free - def _pad_and_align_logps( + def _align_logps( self, - logps: Union['torch.Tensor', List[List[float]]], + logps: 'torch.Tensor', target_shape: tuple, - loss_mask: 'torch.Tensor', device: 'torch.device', dtype: 'torch.dtype', ) -> 'torch.Tensor': - """Pad and align log probabilities to target shape. + """Align log probabilities to target shape. Args: - logps: Input log probabilities (tensor or ragged list) + logps: Input log probabilities tensor target_shape: Target (batch, seq_len) shape - loss_mask: Boolean mask for valid positions device: Target device dtype: Target dtype @@ -153,40 +155,32 @@ def _pad_and_align_logps( """ import torch - if torch.is_tensor(logps): - if logps.dim() == 1: - logps = logps.unsqueeze(0) - if logps.shape == target_shape: - return logps.to(device=device, dtype=dtype) - # Handle tensor with different sequence length - align to target shape - if logps.dim() == 2 and logps.shape[0] == target_shape[0]: - batch_size, target_seq_len = target_shape - src_seq_len = logps.shape[1] - logps = logps.to(device=device, dtype=dtype) - if src_seq_len > target_seq_len: - # Truncate: take the last target_seq_len tokens (response part) - return logps[:, -target_seq_len:] - else: - # Pad: add zeros at the beginning - padded = torch.zeros(target_shape, device=device, dtype=dtype) - padded[:, -src_seq_len:] = logps - return padded - - # Handle ragged list input - if isinstance(logps, (list, tuple)): - batch_size, seq_len = target_shape - padded = torch.zeros(target_shape, device=device, dtype=dtype) - for i, row in enumerate(logps): - if row is None: - continue - row_t = torch.as_tensor(row, device=device, dtype=dtype) - valid_positions = loss_mask[i].nonzero(as_tuple=True)[0] - length = min(len(row_t), len(valid_positions)) - if length > 0: - padded[i, valid_positions[:length]] = row_t[:length] - return padded - - return logps.to(device=device, dtype=dtype) + if not torch.is_tensor(logps): + raise TypeError(f'Expected torch.Tensor, got {type(logps)}') + + if logps.dim() == 1: + logps = logps.unsqueeze(0) + + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle tensor with different sequence length + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate right (keep left part) - may happen in Ray result merging + return logps[:, :target_seq_len] + else: + raise ValueError( + f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). ' + f'This should not happen when both models process the same batch.' 
+ ) + + raise ValueError( + f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}' + ) def _compute_dpo_loss( self, @@ -254,6 +248,7 @@ def __call__( inputs: Dict, outputs: Dict, *, + ref_outputs: Optional[Dict] = None, ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None, ref_chosen_logps: Optional['torch.Tensor'] = None, ref_rejected_logps: Optional['torch.Tensor'] = None, @@ -271,6 +266,7 @@ def __call__( outputs: Dict containing either: - 'logps': [batch, seq_len] pre-computed log probs, OR - 'logits': [batch, seq_len, vocab] from which logps will be computed + ref_outputs: Dict from reference model forward, containing 'logps'. ref_logps: [batch, seq_len] or List[List[float]] reference model log probs. Can also be provided as separate ref_chosen_logps and ref_rejected_logps. ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen. @@ -282,6 +278,10 @@ def __call__( """ import torch + # Extract ref_logps from ref_outputs if provided + if ref_outputs is not None and ref_logps is None: + ref_logps = ref_outputs.get('logps') + labels = inputs.get('labels') assert labels is not None, "inputs must contain 'labels'" if not torch.is_tensor(labels): @@ -312,9 +312,8 @@ def __call__( reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) elif ref_logps is not None: # Per-token reference log probs provided, need to align and sum - loss_mask = (labels != self.ignore_index).bool() - ref_logps_aligned = self._pad_and_align_logps( - ref_logps, labels.shape, loss_mask, device, dtype + ref_logps_aligned = self._align_logps( + ref_logps, labels.shape, device, dtype ) ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 52f50fdd..7c4a8b93 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -52,7 +52,6 @@ def calculate(self): 'grad_norm': self.grad_norm, 'num_tokens': self.num_tokens }] - all_results = self.gather_results(local_results) total_loss = sum(r['loss'] for r in all_results) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 31c04b6a..a3e53056 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -362,86 +362,68 @@ def _is_trajectory(self, obj: Any) -> bool: """Check if an object is a Trajectory (has 'messages' key).""" return isinstance(obj, Mapping) and 'messages' in obj - def _is_trajectory_dict(self, obj: Any) -> bool: - """Check if obj is Dict[str, Trajectory] - all values are Trajectories.""" - if not isinstance(obj, Mapping) or not obj: - return False - return all(self._is_trajectory(v) for v in obj.values()) - - def _get_trajectory_keys(self, obj: Mapping) -> List[str]: - """Get keys in a dict whose values are Trajectories.""" - return [k for k, v in obj.items() if self._is_trajectory(v)] - - def _is_columnar_format(self, obj: Any) -> bool: - """Check if obj is columnar format: Dict[str, List[Any]] but NOT a Trajectory.""" - if not isinstance(obj, Mapping) or not obj: - return False - # Trajectory has 'messages' key with list of Message dicts - not columnar - if self._is_trajectory(obj): - return False - # Dict[str, Trajectory] - not columnar - if self._is_trajectory_dict(obj): - return False - # Check if all values are non-empty lists of same length - first_val = next(iter(obj.values())) - if not isinstance(first_val, list) or len(first_val) == 0: - return False 
- length = len(first_val) - return all(isinstance(v, list) and len(v) == length for v in obj.values()) + def _get_trajectory_keys(self, columnar: Mapping) -> List[str]: + """Get keys whose values are lists of Trajectories in columnar format.""" + keys = [] + for k, v in columnar.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + return keys def batch_encode( self, - trajectories: Union[Dict[str, Any], List[Trajectory], Trajectory], + trajectories: Union[Dict[str, Any], List[Trajectory]], add_generation_prompt: bool = False, - ) -> Union[Dict[str, Any], List[InputFeature], InputFeature]: + ) -> Union[Dict[str, Any], List[InputFeature]]: """Encode trajectories into InputFeatures. - Supports three input formats: - 1. Trajectory -> InputFeature - 2. List[Trajectory] -> List[InputFeature] - 3. Dict containing Trajectories -> Dict with Trajectories encoded + Args: + trajectories: Either List[Trajectory] or columnar Dict[str, List]. + For DPO, columnar format with 'positive'/'negative' keys containing + List[Trajectory] is supported. + add_generation_prompt: Whether to add generation prompt. - Also handles columnar format (Dict[str, List]) by converting to rows first. + Returns: + List[InputFeature] or columnar Dict[str, List[InputFeature]]. """ - # Handle columnar format: convert to rows first - if self._is_columnar_format(trajectories): - rows = self.map_col_to_row(trajectories) - encoded = self.batch_encode(rows, add_generation_prompt=add_generation_prompt) - if isinstance(encoded, list) and encoded: - return self.map_row_to_col(encoded) - return encoded - - # Case 1: Single Trajectory - if self._is_trajectory(trajectories) and not self._is_trajectory_dict(trajectories): - processed = self._invoke_pre_pipeline([trajectories]) - output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] - output = self._invoke_post_pipeline(output) - return output[0] if len(output) == 1 else output - - # Case 2: List (Trajectory or Dict containing Trajectories) - if isinstance(trajectories, list): - if not trajectories: - return [] - first = trajectories[0] - if isinstance(first, Mapping) and self._get_trajectory_keys(first): - # List of dicts containing Trajectories - return [self.batch_encode(row, add_generation_prompt=add_generation_prompt) for row in trajectories] - else: - # List[Trajectory] - processed = self._invoke_pre_pipeline(trajectories) - output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] - return self._invoke_post_pipeline(output) + _transfer = False - # Case 3: Dict containing Trajectories (encode only Trajectory values) if isinstance(trajectories, Mapping): + _transfer = True + # Check if it has trajectory list columns (DPO format) traj_keys = self._get_trajectory_keys(trajectories) if traj_keys: - result = dict(trajectories) # Copy non-trajectory keys - for key in traj_keys: - result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) + # DPO format: encode each trajectory list separately, keep other columns + result = {} + for key in trajectories: + if key in traj_keys: + # Encode this trajectory list + result[key] = self.batch_encode( + trajectories[key], add_generation_prompt=add_generation_prompt + ) + else: + # Keep non-trajectory columns as-is + result[key] = trajectories[key] return result + else: + # Standard columnar format + trajectories = self.map_col_to_row(trajectories) + + # Process List[Trajectory] + trajectories = 
self._invoke_pre_pipeline(trajectories) - raise ValueError(f'Unsupported input type: {type(trajectories)}') + # Use thread pool for parallel encoding + from concurrent.futures import ThreadPoolExecutor + from functools import partial + encode_fn = partial(self.encode, add_generation_prompt=add_generation_prompt) + with ThreadPoolExecutor() as executor: + output = list(executor.map(encode_fn, trajectories)) + + output = self._invoke_post_pipeline(output) + + if _transfer: + output = self.map_row_to_col(output) + return output def check(self, trajectory: Trajectory) -> Optional[Trajectory]: encoded = None From 0cf1ac38d5aebba51d7a0f4753e814037432f048 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 20:08:33 +0800 Subject: [PATCH 11/31] wip --- cookbook/rl/dpo.py | 25 +-- src/twinkle/loss/dpo.py | 19 +- src/twinkle/metric/__init__.py | 1 + src/twinkle/metric/dpo.py | 177 ++++++++++++++++++ src/twinkle/metric/loss.py | 4 +- .../model/transformers/transformers.py | 7 +- 6 files changed, 217 insertions(+), 16 deletions(-) create mode 100644 src/twinkle/metric/dpo.py diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 00311c63..7c5ffe8a 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -59,6 +59,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss +from twinkle.metric import DPOMetric from twinkle.model import TransformersModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor @@ -66,7 +67,7 @@ logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) @@ -75,12 +76,13 @@ BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -LEARNING_RATE = float(os.environ.get('LR', 5e-6)) +LEARNING_RATE = float(os.environ.get('LR', 5e-5)) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1)) # SFT loss weight for regularization LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) ADAPTER_NAME = 'default' SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') @@ -88,7 +90,7 @@ def create_dpo_dataset(): """Create DPO dataset with positive/negative format.""" - dataset = Dataset(DatasetMeta(DATASET_ID)) + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000))) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( EmojiDPOProcessor, @@ -134,7 +136,7 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # ── Loss Factory ────────────────────────────────────────────────────────────── -def create_loss(loss_type: str, beta: float, reference_free: bool = False): +def 
create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False): """Create the appropriate loss function based on configuration.""" if loss_type == 'simpo': return SimPOLoss(beta=beta, gamma=0.5) @@ -148,6 +150,7 @@ def create_loss(loss_type: str, beta: float, reference_free: bool = False): beta=beta, loss_type=loss_type, reference_free=reference_free, + sft_weight=sft_weight, ) @@ -174,10 +177,7 @@ def main(): # ── Policy Model Setup ──────────────────────────────────────────────────── lora_config = LoraConfig( - target_modules=[ - 'q_proj', 'k_proj', 'v_proj', 'o_proj', - 'gate_proj', 'up_proj', 'down_proj', - ], + target_modules='all-linear', r=16, lora_alpha=32, lora_dropout=0.05, @@ -195,9 +195,10 @@ def main(): # Determine if we need reference model based on loss type reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] - # Set up loss function - loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False) + # Set up loss function and metrics + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False) policy_model.set_loss(loss_fn) + policy_model.add_metric(DPOMetric, beta=DPO_BETA) policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 52fc28b0..0ad69ed9 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -118,6 +118,7 @@ class DPOLoss(PreferenceLossBase): ignore_index: Index to ignore in labels (default: -100). loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid'). reference_free: Whether to use reference-free DPO (default: False). + sft_weight: Weight for SFT loss on chosen responses to prevent likelihood displacement (default: 0.0). 
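+
+    When sft_weight > 0, the returned loss is
+        loss = dpo_loss + sft_weight * nll(chosen),
+    i.e. an auxiliary negative log-likelihood term over the chosen responses
+    (similar in spirit to RPO-style regularization) that keeps the chosen
+    likelihood from collapsing while the preference margin grows.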
""" def __init__( @@ -127,6 +128,7 @@ def __init__( ignore_index: int = -100, loss_type: str = 'sigmoid', reference_free: bool = False, + sft_weight: float = 0.0, **kwargs, ): super().__init__(ignore_index=ignore_index) @@ -134,6 +136,7 @@ def __init__( self.label_smoothing = label_smoothing self.loss_type = loss_type self.reference_free = reference_free + self.sft_weight = sft_weight def _align_logps( self, @@ -329,14 +332,26 @@ def __call__( ) # Compute DPO loss - loss = self._compute_dpo_loss( + dpo_loss = self._compute_dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, ) - return LossOutput(loss=loss, num_tokens=0) + # Add SFT loss on chosen responses to prevent likelihood displacement + if self.sft_weight > 0: + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + loss = dpo_loss + self.sft_weight * sft_loss + else: + loss = dpo_loss + + # Return sample count for gradient normalization (not token count) + # DPO loss is already per-sample mean, so we just count samples for accumulation + import torch + num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device) + + return LossOutput(loss=loss, num_tokens=num_samples) class SimPOLoss(PreferenceLossBase): diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index 739c7a0d..59d5bbeb 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -2,5 +2,6 @@ from .accuracy import Accuracy from .base import Metric from .completion_and_reward import CompletionRewardMetric +from .dpo import DPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py new file mode 100644 index 00000000..c3f3e6cf --- /dev/null +++ b/src/twinkle/metric/dpo.py @@ -0,0 +1,177 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO-specific metrics for preference optimization training.""" +from typing import List, Union + +from twinkle.data_format import InputFeature, ModelOutput +from .base import Metric + + +class DPOMetric(Metric): + """Metrics for DPO (Direct Preference Optimization) training. + + Computes TRL-style metrics: + - logps/chosen: Average sequence-level log prob of chosen responses + - logps/rejected: Average sequence-level log prob of rejected responses + - rewards/chosen: β * (policy_chosen - ref_chosen) + - rewards/rejected: β * (policy_rejected - ref_rejected) + - rewards/margins: chosen_reward - rejected_reward + - rewards/accuracies: Percentage where chosen_reward > rejected_reward + + Args: + device_mesh: The device mesh + process_group: The process group to collect data from + ignore_index: Label index to ignore (default: -100) + beta: DPO beta parameter for reward scaling (default: 0.1) + """ + + def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs): + super().__init__(device_mesh, process_group, **kwargs) + self.ignore_index = ignore_index + self.beta = beta + self.reset() + + def _compute_sequence_logps(self, per_token_logps, labels): + """Compute sequence-level log probs by summing valid token logps.""" + import torch + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _split_chosen_rejected(self, tensor): + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] 
(interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + return tensor[0::2], tensor[1::2] + + def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + """Accumulate DPO metrics from model outputs. + + Expects: + - outputs['logps']: [batch, seq_len] per-token log probabilities + - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens + - kwargs['ref_outputs']: Optional reference model outputs with 'logps' + """ + import torch + + logps = outputs.get('logps') + if logps is None: + return + + # Get labels from inputs + if isinstance(inputs, list): + # Stack labels from list of inputs + labels_list = [torch.as_tensor(inp['labels']) for inp in inputs] + max_len = max(l.shape[0] for l in labels_list) + padded = [] + for l in labels_list: + if l.shape[0] < max_len: + pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype) + l = torch.cat([pad, l]) + padded.append(l) + labels = torch.stack(padded) + else: + labels = torch.as_tensor(inputs['labels']) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + # Ensure logps and labels have same device + if logps.device != labels.device: + labels = labels.to(logps.device) + + # Align sequence lengths if needed (truncate right) + if logps.shape[1] != labels.shape[1]: + min_len = min(logps.shape[1], labels.shape[1]) + logps = logps[:, :min_len] + labels = labels[:, :min_len] + + # Compute sequence-level logps + seq_logps = self._compute_sequence_logps(logps, labels) + + # Split into chosen and rejected (interleaved format) + chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps) + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + + # Accumulate policy logps + self.total_chosen_logps += chosen_logps.sum().item() + self.total_rejected_logps += rejected_logps.sum().item() + + # Compute rewards if ref_outputs available + ref_outputs = kwargs.get('ref_outputs') + if ref_outputs is not None: + ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + # Align ref_logps + if ref_logps.device != labels.device: + ref_logps = ref_logps.to(labels.device) + if ref_logps.shape[1] != labels.shape[1]: + min_len = min(ref_logps.shape[1], labels.shape[1]) + ref_logps = ref_logps[:, :min_len] + + ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) + ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) + + # Compute rewards: β * (policy - ref) + chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps) + rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps) + + self.total_chosen_rewards += chosen_rewards.sum().item() + self.total_rejected_rewards += rejected_rewards.sum().item() + margins = chosen_rewards - rejected_rewards + self.total_reward_margin += margins.sum().item() + self.total_reward_correct += (margins > 0).sum().item() + self.has_rewards = True + + self.total_count += chosen_logps.shape[0] + + def reset(self): + """Reset all accumulated values.""" + self.total_chosen_logps = 0.0 + self.total_rejected_logps = 0.0 + self.total_chosen_rewards = 0.0 + self.total_rejected_rewards = 0.0 + self.total_reward_margin = 0.0 + self.total_reward_correct = 0 + self.total_count = 0 + self.has_rewards = False + + def calculate(self): + """Calculate and return aggregated metrics.""" + local_results = [{ + 'chosen_logps': self.total_chosen_logps, + 'rejected_logps': self.total_rejected_logps, + 'chosen_rewards': 
self.total_chosen_rewards, + 'rejected_rewards': self.total_rejected_rewards, + 'reward_margin': self.total_reward_margin, + 'reward_correct': self.total_reward_correct, + 'count': self.total_count, + 'has_rewards': self.has_rewards, + }] + all_results = self.gather_results(local_results) + + total_chosen_logps = sum(r['chosen_logps'] for r in all_results) + total_rejected_logps = sum(r['rejected_logps'] for r in all_results) + total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results) + total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results) + total_reward_margin = sum(r['reward_margin'] for r in all_results) + total_reward_correct = sum(r['reward_correct'] for r in all_results) + total_count = sum(r['count'] for r in all_results) + has_rewards = any(r['has_rewards'] for r in all_results) + + self.reset() + + if total_count == 0: + return {} + + results = { + 'logps/chosen': f'{total_chosen_logps / total_count:.2f}', + 'logps/rejected': f'{total_rejected_logps / total_count:.2f}', + } + + if has_rewards: + results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}' + results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}' + results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}' + results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%' + + return results diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 7c4a8b93..8f4ad0c9 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -60,8 +60,10 @@ def calculate(self): num_tokens = sum(r['num_tokens'] for r in all_results) if num_tokens > 0: avg_loss = total_loss / num_tokens - else: + elif total_count > 0: avg_loss = total_loss / total_count + else: + avg_loss = 0.0 self.reset() results = {} if avg_loss is not None: diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 520aaf9f..cdafbf59 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -121,6 +121,8 @@ def accumulate_metrics(self, is_training): metrics = self.train_metrics else: metrics = self.eval_metrics + # Get stored forward_kwargs from previous forward + forward_kwargs = getattr(self, 'forward_kwargs', None) or {} if len(metrics) > 0 and self.inputs is not None and self.outputs is not None: for metric in metrics: metric.accumulate( @@ -130,7 +132,8 @@ def accumulate_metrics(self, is_training): step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, grad_norm=self._last_grad_norm, - loss_reduction=getattr(self.loss_instance, 'reduction', 'mean')) + loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'), + **forward_kwargs) def calculate_metrics(self, is_training): self.accumulate_metrics(is_training) @@ -405,6 +408,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec inputs['labels'] = labels optimizer_config.inputs = inputs optimizer_config.outputs = outputs + optimizer_config.forward_kwargs = kwargs # Store for next metric accumulation optimizer_config.loss_value = outputs.get('aux_loss', 0) if labels is not None: loss_mask = (labels != -100).bool() @@ -1086,6 +1090,7 @@ def set_grad_scaler(self, **kwargs): grad_scaler_config.update(kwargs) optimizer_config.scaler = GradScaler(**grad_scaler_config) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): """Add an eval metric From 
8c662f047519342f0c8eafed205557b97bd0c69a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 22:03:07 +0800 Subject: [PATCH 12/31] wip --- cookbook/rl/dpo.py | 10 +++++----- src/twinkle/loss/dpo.py | 10 ++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 7c5ffe8a..f3454c2f 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -76,11 +76,10 @@ BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) -LEARNING_RATE = float(os.environ.get('LR', 5e-5)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4)) +LEARNING_RATE = float(os.environ.get('LR', 5e-6)) # TRL default for DPO is 5e-7 to 5e-6 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) -SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1)) # SFT loss weight for regularization +SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) @@ -90,7 +89,7 @@ def create_dpo_dataset(): """Create DPO dataset with positive/negative format.""" - dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000))) + dataset = Dataset(DatasetMeta(DATASET_ID)) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( EmojiDPOProcessor, @@ -188,6 +187,7 @@ def main(): device_mesh=policy_mesh, remote_group='policy', ) + MAX_STEPS = len(dataloader) policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 0ad69ed9..ce533053 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -346,12 +346,10 @@ def __call__( else: loss = dpo_loss - # Return sample count for gradient normalization (not token count) - # DPO loss is already per-sample mean, so we just count samples for accumulation - import torch - num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device) - - return LossOutput(loss=loss, num_tokens=num_samples) + # Return 0 to skip gradient normalization by num_tokens + # DPO loss is already per-sample mean, unlike SFT which sums per-token loss + # When num_tokens=0, normalize_and_clip_grad_norm defaults to 1 (no division) + return LossOutput(loss=loss, num_tokens=0) class SimPOLoss(PreferenceLossBase): From aa860993c493b80e49d3339d32057b91151d8a3f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 23:38:48 +0800 Subject: [PATCH 13/31] wip --- cookbook/transformers/fsdp2.py | 2 +- src/twinkle/model/megatron/megatron.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 10d75df6..a8c8bb87 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor # Construct a device_mesh, dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2) +device_mesh = 
DeviceMesh.from_sizes(dp_size=8) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4b37973c..ae2ae59f 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -245,7 +245,7 @@ def __init__( def _construct_default_optimizer_group(self): return MegatronOptimizerGroup( - loss_instance=CrossEntropyLoss(), + loss_instance=CrossEntropyLoss(reduction='sum'), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh, framework='megatron'), _device_mesh=self.device_mesh, From 6bdaaca88c924b564ce12317e136207ecea6e12c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 28 Mar 2026 00:11:18 +0800 Subject: [PATCH 14/31] wip --- src/twinkle/model/megatron/megatron.py | 6 ++++++ src/twinkle/model/transformers/transformers.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index ae2ae59f..7c131210 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -479,6 +479,12 @@ def post_loss_function(output_tensor, inputs, logps): losses = result['loss'] counts = result['num_tokens'] if not counts: + # Later will gather this value, so it becomes: + # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) = gradient_accumulation_steps * world_size + # Then, grad will divided by this value: + # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens + # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size ) = avg_per_token_grad counts = torch.tensor(1, device=losses.device) return self.strategy.reduce_loss(losses, counts, output_tensor, logps) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index cdafbf59..a6be74e8 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -500,7 +500,18 @@ def calculate_loss(self, **kwargs): loss_value = result['loss'] counts = result['num_tokens'] if not counts: - counts = torch.tensor(0, device=loss_value.device) + counts = torch.tensor(1, device=loss_value.device) + # Later will gather this value, so it becomes: + # 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size ) = gradient_accumulation_steps + # Then, grad will divided by this value: + # 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size) + # = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size) + # = global_sum_grad/global_num_tokens + # 2. 
PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps + # = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) / gradient_accumulation_steps + # = global_per_token_grad / dp_world_size = avg_per_token_grad + counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] optimizer_config.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: From c9b4f2841a2a8bd1067fdb552ed4f2f7203671b7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 10:20:40 +0800 Subject: [PATCH 15/31] wip --- cookbook/rl/dpo.py | 17 +- cookbook/rl/dpo_lora.py | 222 ++++++++++++++++++ cookbook/transformers/fsdp2.py | 56 ++++- src/twinkle/metric/dpo.py | 39 ++- src/twinkle/model/megatron/megatron.py | 59 +++-- .../model/megatron/multi_lora_megatron.py | 8 +- src/twinkle/model/multi_lora.py | 26 +- .../transformers/multi_lora_transformers.py | 8 +- .../model/transformers/transformers.py | 50 ++-- 9 files changed, 413 insertions(+), 72 deletions(-) create mode 100644 cookbook/rl/dpo_lora.py diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index f3454c2f..cd86653b 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -74,14 +74,14 @@ REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4)) -LEARNING_RATE = float(os.environ.get('LR', 5e-6)) # TRL default for DPO is 5e-7 to 5e-6 +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) +LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # TRL default for DPO is 5e-7 to 5e-6 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200)) +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) ADAPTER_NAME = 'default' SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') @@ -89,7 +89,7 @@ def create_dpo_dataset(): """Create DPO dataset with positive/negative format.""" - dataset = Dataset(DatasetMeta(DATASET_ID)) + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(30000))) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( EmojiDPOProcessor, @@ -223,9 +223,6 @@ def main(): # ── Training Loop ───────────────────────────────────────────────────────── for batch in dataloader: - if optim_step >= MAX_STEPS: - break - # batch is List[Dict] with 'positive' and 'negative' keys dpo_batch = prepare_dpo_batch(batch) @@ -246,7 +243,7 @@ def main(): optim_step += 1 # Logging - if optim_step % 10 == 0: + if optim_step % 1 == 0: metrics = policy_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py new file mode 100644 index 00000000..47f80e83 --- /dev/null +++ b/cookbook/rl/dpo_lora.py @@ -0,0 +1,222 @@ +"""DPO (Direct Preference Optimization) 
Training with LoRA (Single GPU Group). + +LoRA-based DPO training: uses the base model (without LoRA adapter) as reference +model by calling forward_only with adapter_name=''. This eliminates the need for +a separate reference model GPU group. + +Pipeline: + 1. Load preference dataset with chosen/rejected pairs. + 2. Encode positive and negative separately. + 3. Compute reference model log probabilities using base model (adapter_name=''). + 4. Train policy model (with LoRA adapter) using DPO loss. + +Architecture (Ray - Single Group): + ┌─────────────────────────────────────────────────────────────────┐ + │ Driver (CPU) │ + │ dataloader ──► batched preference pairs │ + │ policy_model.forward_only(adapter_name='') ──► reference logps│ + │ policy_model.forward_backward() ──► DPO loss + gradient │ + └─────────────────────────────────────────────────────────────────┘ + │ + PolicyModel (with LoRA adapter) + - forward_only(adapter_name='') → base model inference (reference) + - forward_backward() → LoRA adapter training (policy) + +DPO data format (after preprocessing): + - positive: List[Trajectory] - chosen responses + - negative: List[Trajectory] - rejected responses + +Environment variables (all optional): + MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) + DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) + MODEL_GPUS – GPUs for policy model (default: 4) + BATCH_SIZE – global batch size (preference pairs) (default: 8) + MICRO_BATCH_SIZE – per-device micro batch size (default: 2) + MAX_STEPS – total optimization steps (default: 1000) + LR – learning rate (default: 1e-4) + DPO_BETA – DPO temperature parameter (default: 0.1) + LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid) + SAVE_STEPS – checkpoint save interval (default: 100) + MAX_LENGTH – max sequence length (default: 2048) + + Dataset field mapping (for custom datasets): + PROMPT_KEY – key for prompt field (default: 'prompt') + CHOSEN_KEY – key for chosen response (default: 'answer_zh') + REJECTED_KEY – key for rejected response (default: 'answer_en') + SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.') +""" + +import os +from typing import Any, Dict, List, Optional + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import DPOLoss +from twinkle.metric import DPOMetric +from twinkle.model import TransformersModel +from twinkle.preprocessor import EmojiDPOProcessor +from twinkle.processor import InputProcessor + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) +LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) +DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization +LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # 
sigmoid, hinge, ipo +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) +ADAPTER_NAME = 'default' +SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') + + +def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(30000))) + dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': SYSTEM_PROMPT, + } + ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode(load_from_cache_file=True) + return dataset + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. + + Args: + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) + + Returns: + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. + """ + result = [] + + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result + + +# ── Main Training Loop ──────────────────────────────────────────────────────── + +def main(): + # Set up device groups - only one group for LoRA training + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + ] + + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups) + + # ── DataLoader Setup ────────────────────────────────────────────────────── + dataloader = DataLoader( + dataset=create_dpo_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=policy_mesh, + ) + length = len(dataloader) + + # ── Policy Model Setup with LoRA ────────────────────────────────────────── + lora_config = LoraConfig( + target_modules='all-linear', + r=16, + lora_alpha=32, + lora_dropout=0.05, + ) + + policy_model = TransformersModel( + model_id=MODEL_ID, + device_mesh=policy_mesh, + remote_group='policy', + ) + MAX_STEPS = len(dataloader) + policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + + # Set up loss function and metrics + loss_fn = DPOLoss( + beta=DPO_BETA, + loss_type=LOSS_TYPE, + reference_free=False, # We use base model as reference via disable_lora=True + sft_weight=SFT_WEIGHT, + ) + policy_model.set_loss(loss_fn) + policy_model.add_metric(DPOMetric, beta=DPO_BETA) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) + + optim_step = 0 + logger.info(get_device_placement()) + 
logger.info(f'Starting LoRA DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}') + logger.info(f'Using base model (disable_lora=True) as reference model') + + # ── Training Loop ───────────────────────────────────────────────────────── + for batch in dataloader: + # batch is List[Dict] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch) + + # Get reference outputs using base model (without LoRA adapter) + # disable_lora=True tells the model to skip LoRA and use base weights + ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True) + + # Forward-backward pass with DPO loss (using LoRA adapter) + # ref_outputs is passed to loss which extracts logps internally + policy_model.forward_backward( + inputs=dpo_batch, + ref_outputs=ref_outputs, + ) + + # Gradient clipping and optimizer step + policy_model.clip_grad_and_step() + optim_step += 1 + + # Logging + if optim_step % 1 == 0: + metrics = policy_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') + + # Checkpointing + if optim_step % SAVE_STEPS == 0: + policy_model.save(f'dpo-lora-checkpoint-{optim_step}') + + # ── Save Final Checkpoint ───────────────────────────────────────────────── + logger.info(f'Training completed. Total steps: {optim_step}') + policy_model.save('dpo-lora-final-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index a8c8bb87..9aa94137 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -7,7 +7,8 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.data_format import Message, Trajectory +from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor # Construct a device_mesh, dp=2 device_mesh = DeviceMesh.from_sizes(dp_size=8) @@ -20,7 +21,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -31,19 +32,55 @@ def eval(model): return metrics +class EmojiDPOProcessor(Preprocessor): + def __init__( + self, + system = 'You are a helpful assistant.', + chosen_key: str = 'answer_zh', + rejected_key: str = 'answer_en', + prompt_key: str = 'prompt', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows + + def preprocess(self, row): + """Process a single row.""" + prompt = row.get(self.prompt_key, '') + chosen = row.get(self.chosen_key, '') + rejected = row.get(self.rejected_key, '') + + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=prompt)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return 
Trajectory(messages=chosen_messages)
+
+
 def train():
     # 1000 samples
-    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+    dataset = Dataset(dataset_meta=DatasetMeta('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'))
     # Set template to prepare encoding
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
     # Preprocess the dataset to standard format
-    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+    dataset.map(EmojiDPOProcessor)
     # Encode dataset
     dataset.encode()
     # Global batch size = 8, for GPUs, so 1 sample per GPU
     dataloader = DataLoader(dataset=dataset, batch_size=8)
     # Use a TransformersModel
-    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
+    model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
     model.model._no_split_modules = {'Qwen3_5DecoderLayer'}
 
     lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
@@ -72,13 +109,6 @@ def train():
         # Print metric
         metric = model.calculate_metric(is_training=True)
         logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
-        if step > 0 and step % 40 == 0:
-            metrics = eval(model)
-            logger.info(f'Eval metric: {metrics}')
-            metrics['step'] = step
-            if loss_metric > float(metrics['loss']):
-                model.save(f'checkpoint-{step}')
-                loss_metric = float(metrics['loss'])
 
     model.save(f'last-checkpoint')
 
diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py
index c3f3e6cf..ceb3c651 100644
--- a/src/twinkle/metric/dpo.py
+++ b/src/twinkle/metric/dpo.py
@@ -36,6 +36,33 @@ def _compute_sequence_logps(self, per_token_logps, labels):
         loss_mask = (labels != self.ignore_index).float()
         return (per_token_logps * loss_mask).sum(dim=-1)
 
+    def _align_logps(self, logps, target_shape, device, dtype):
+        """Align per-token logps to target shape by padding or truncating.
+
+        Args:
+            logps: [batch, seq_len] tensor to align
+            target_shape: Target shape (batch, target_seq_len)
+            device: Target device
+            dtype: Target dtype
+
+        Returns:
+            Aligned tensor with shape matching target_shape
+        """
+        import torch
+        logps = logps.to(device=device, dtype=dtype)
+        batch_size, src_len = logps.shape
+        _, target_len = target_shape
+
+        if src_len == target_len:
+            return logps
+        elif src_len < target_len:
+            raise ValueError(
+                f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). '
+                f'This should not happen when both models process the same batch.'
+            )
+        else:
+            return logps[:, :target_len]
+
     def _split_chosen_rejected(self, tensor):
         """Split interleaved tensor into chosen and rejected.
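The interleaved [pos_1, neg_1, pos_2, neg_2, ...] layout used throughout these metrics exists so that contiguous data-parallel slicing never separates a preference pair. A small standalone illustration (hypothetical values, not part of the patch):

    import torch

    batch = torch.tensor([1.0, -1.0, 2.0, -2.0])  # [pos_1, neg_1, pos_2, neg_2]
    chosen, rejected = batch[0::2], batch[1::2]   # tensor([1., 2.]), tensor([-1., -2.])

    # Contiguous slicing across 2 DP ranks keeps each pair intact:
    rank0, rank1 = batch[:2], batch[2:]           # one complete pair per rank
    # A [chosen..., rejected...] layout would hand rank0 only chosen samples,
    # leaving neither rank able to form a single preference pair.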
@@ -101,12 +129,11 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M if ref_outputs is not None: ref_logps = ref_outputs.get('logps') if ref_logps is not None: - # Align ref_logps - if ref_logps.device != labels.device: - ref_logps = ref_logps.to(labels.device) - if ref_logps.shape[1] != labels.shape[1]: - min_len = min(ref_logps.shape[1], labels.shape[1]) - ref_logps = ref_logps[:, :min_len] + # Align ref_logps to match labels shape (handles different seq lengths) + # breakpoint() + ref_logps = self._align_logps( + ref_logps, labels.shape, labels.device, logps.dtype + ) ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 7c131210..bc195e59 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -51,6 +51,9 @@ class MegatronOptimizerGroup: outputs: ModelOutput = None loss_instance: Loss = None loss_value: Any = None + eval_inputs: List[InputFeature] = None + eval_outputs: ModelOutput = None + eval_loss_value: Any = None template: Template = None processor: InputProcessor = None gradient_accumulation_steps: int = 1 @@ -93,13 +96,17 @@ def _get_lr(self): def accumulate_metrics(self, is_training): if is_training: metrics = self.train_metrics + inputs = self.inputs + outputs = self.outputs else: metrics = self.eval_metrics - if len(metrics) > 0 and self.inputs is not None and self.outputs is not None: + inputs = self.eval_inputs + outputs = self.eval_outputs + if len(metrics) > 0 and inputs is not None and outputs is not None: for metric in metrics: metric.accumulate( - self.inputs, - self.outputs, + inputs, + outputs, lr=self._get_lr(), step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, @@ -405,6 +412,7 @@ def forward_backward(self, from megatron.core.pipeline_parallel import get_forward_backward_func adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + disable_lora = kwargs.pop('disable_lora', False) temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) return_logits = kwargs.pop('return_logits', False) @@ -521,17 +529,33 @@ def forward_step_func(data_iterator, model): self._accumulate_metric(optimizer_config, is_training=not forward_only) - # Run forward-backward with Megatron's scheduler - # Megatron handles all communication internally using proper process groups - losses = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iter, - model=self.model, - num_microbatches=len(inputs), - seq_length=seq_length, - micro_batch_size=micro_batch_size, - forward_only=forward_only, - ) + # Handle disable_lora for base model inference (e.g., reference in DPO) + def _set_disable_adapters(model, value: bool): + if isinstance(model, list): + for m in model: + if isinstance(m, PeftModel): + m.disable_adapters = value + elif isinstance(model, PeftModel): + model.disable_adapters = value + + if disable_lora: + _set_disable_adapters(self.model, True) + + try: + # Run forward-backward with Megatron's scheduler + # Megatron handles all communication internally using proper process groups + losses = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iter, + model=self.model, + num_microbatches=len(inputs), + seq_length=seq_length, + micro_batch_size=micro_batch_size, + 
forward_only=forward_only, + ) + finally: + if disable_lora: + _set_disable_adapters(self.model, False) # Extract loss from results (only last PP stage returns non-empty) loss = torch.tensor(0.0).to(Platform.get_local_device()) @@ -577,7 +601,6 @@ def forward_step_func(data_iterator, model): dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group) - optimizer_config.inputs = inputs if logps and len({_logps.shape[1] for _logps in logps}) == 1: logps = torch.cat(logps, dim=0) if logits and len({_logits.shape[1] for _logits in logits}) == 1: @@ -586,7 +609,11 @@ def forward_step_func(data_iterator, model): loss = loss.detach().cpu().float().numpy() if not return_logits: logits = None - if not forward_only: + if forward_only: + optimizer_config.eval_inputs = inputs + optimizer_config.eval_outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + else: + optimizer_config.inputs = inputs optimizer_config.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) return ModelOutput(logits=logits, loss=loss, logps=logps) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index b674ef46..50afbdfd 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -118,13 +118,15 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T Args: inputs: Model inputs. - **kwargs: Additional arguments. + **kwargs: Additional arguments including disable_lora. Returns: Model outputs. """ - self._check_adapter_valid(kwargs.get('adapter_name')) - with self.multi_adapter.adapter(kwargs.get('adapter_name')): + adapter_name = kwargs.get('adapter_name') + disable_lora = kwargs.get('disable_lora', False) + self._check_adapter_valid(adapter_name) + with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora): return super().forward_only(inputs=inputs, **kwargs) @remote_function(dispatch='slice_dp', collect='mean', sync=True) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 9780973b..75f74cc1 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -45,17 +45,21 @@ def _get_available_lora(self) -> Optional[LoraTenant]: def _count_available_loras(self): return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None]) - def activate_adapter(self, tenant_adapter_name: str): + def activate_adapter(self, tenant_adapter_name: str, call_enable=False): if not self.has_lora(tenant_adapter_name): raise ValueError(f'Adapter {tenant_adapter_name} does not exist') adapter_name = self.find_lora_by_tenant(tenant_adapter_name).adapter_name if isinstance(self.module, list): for _module in self.module: - # _module.enable_adapter_layers() + if call_enable: + # This will cost time + _module.enable_adapter_layers() if _module.active_adapter != adapter_name: _module.set_adapter(adapter_name) else: - # self.module.enable_adapter_layers() + if call_enable: + # This will cost time + self.module.enable_adapter_layers() if self.module.active_adapter != adapter_name: self.module.set_adapter(adapter_name) @@ -67,10 +71,20 @@ def deactivate_adapter(self): self.module.disable_adapter_layers() @contextmanager - def adapter(self, tenant_adapter_name: str): + def adapter(self, tenant_adapter_name: str, disable_lora: bool = False): self.activate_adapter(tenant_adapter_name) - yield 
self.find_lora_by_tenant(tenant_adapter_name).adapter_name
-        # self.deactivate_adapter()
+        if disable_lora:
+            # Temporarily disable all adapters while keeping optimizer_group active
+            with self._disable_lora_context(tenant_adapter_name):
+                yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name
+        else:
+            yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name
+
+    @contextmanager
+    def _disable_lora_context(self, tenant_adapter_name):
+        self.deactivate_adapter()
+        yield
+        self.activate_adapter(tenant_adapter_name, call_enable=True)
 
     @contextmanager
     def save_context(self, tenant_adapter_name: str):
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index 0900f52b..34e81e83 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -106,8 +106,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory,
 
     @remote_function(dispatch='slice_dp', collect='flatten')
     def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
-        self._check_adapter_valid(kwargs.get('adapter_name'))
-        optimizer_config = self.optimizer_group[kwargs.get('adapter_name')]
+        adapter_name = kwargs.get('adapter_name')
+        disable_lora = kwargs.get('disable_lora', False)
+        self._check_adapter_valid(adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
         if (isinstance(inputs, dict)
                 and self._not_encoded(inputs)) or (isinstance(inputs, list) and self._not_encoded(inputs[0])):
             # Trajectory or List[Trajectory]
@@ -117,7 +119,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
                 inputs = [inputs]
             inputs = optimizer_config.template.batch_encode(inputs)  # noqa
         self.multi_adapter.check_length(inputs)
-        with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+        with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora):
             return super().forward_only(inputs=inputs, **kwargs)
 
     @remote_function()
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index a6be74e8..bc919fbd 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -41,16 +41,26 @@ from twinkle.utils.grad_clip import normalize_and_clip_grad_norm
 
 
+@dataclass
+class TrainStatus:
+    inputs: List[InputFeature] = None
+    outputs: ModelOutput = None
+    loss_value: Any = None
+    num_tokens: int = 0
+    metrics: List[Metric] = field(default_factory=list)
+
+
 @dataclass
 class OptimizerGroup:
     adapter_name: str = None
     adapter_config: PeftConfig = None
     optimizer: Optimizer = None
     lr_scheduler: LRScheduler = None
-    inputs: List[InputFeature] = None
-    outputs: ModelOutput = None
+
+    eval_inputs: List[InputFeature] = None
+    eval_outputs: ModelOutput = None
+    eval_loss_value: Any = None
     loss_instance: Loss = CrossEntropyLoss
-    loss_value: Any = None
     template: Template = None
     processor: InputProcessor = None
     scaler: GradScaler = None
@@ -58,8 +68,8 @@ class OptimizerGroup:
     scaler_has_nan: bool = False
     gradient_accumulation_steps: int = 1
     cur_step: int = 0
-    num_tokens: int = 0
-    train_metrics: List[Metric] = field(default_factory=list)
+
+    train_metrics: List[Metric] = field(default_factory=list)
     eval_metrics: List[Metric] = field(default_factory=list)
     checkpoint_engine: CheckpointEngine = None
     _dp_group = None
@@ -119,15 +129,19 @@ def accumulate_metrics(self, is_training):
         self._ensure_dp_group()
         if is_training:
             metrics = 
self.train_metrics
+            inputs = self.inputs
+            outputs = self.outputs
         else:
             metrics = self.eval_metrics
+            inputs = self.eval_inputs
+            outputs = self.eval_outputs
         # Get stored forward_kwargs from previous forward
         forward_kwargs = getattr(self, 'forward_kwargs', None) or {}
-        if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
+        if len(metrics) > 0 and inputs is not None and outputs is not None:
             for metric in metrics:
                 metric.accumulate(
-                    self.inputs,
-                    self.outputs,
+                    inputs,
+                    outputs,
                     lr=self._get_lr(),
                     step=self.cur_step - 1,
                     gradient_accumulation_steps=self.gradient_accumulation_steps,
@@ -281,7 +295,7 @@ def _ensure_sp_strategy(self) -> None:
         )
 
     def _get_default_group(self):
-        """Get the only group has optimizer, else return the default one"""
+        """Get the only group, else return the default one"""
         if len(self.optimizer_group) == 1:
             return next(iter(self.optimizer_group))
         return self.active_group
@@ -431,10 +445,12 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
             inputs: The model inputs. Can be an encoded batch, or a list of `Trajectory`
             **kwargs:
                 adapter_name: Lora adapter name.
+                disable_lora: If True, disable LoRA and use base model for inference.
         Returns:
             The output of the model forward.
         """
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+        disable_lora = kwargs.pop('disable_lora', False)
         temperature = float(kwargs.pop('temperature', 1.0))
         return_logits = kwargs.pop('return_logits', False)
         optimizer_config = self.optimizer_group[adapter_name]
@@ -458,13 +474,17 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
             inputs = self.sp_strategy.preprocess_inputs(inputs)
         labels = inputs.pop('labels', None)
         optimizer_config.accumulate_metrics(False)
-        outputs = self.model(**inputs)
+        if disable_lora and isinstance(self.model, PeftModel):
+            with self.model.disable_adapter():
+                outputs = self.model(**inputs)
+        else:
+            outputs = self.model(**inputs)
         if self.sp_strategy is not None and labels is None:
             outputs = self.sp_strategy.postprocess_outputs(outputs)
         inputs['labels'] = labels
-        optimizer_config.inputs = inputs
-        optimizer_config.outputs = outputs
-        optimizer_config.loss_value = outputs.get('aux_loss', 0)
+        optimizer_config.eval_inputs = inputs
+        optimizer_config.eval_outputs = outputs
+        optimizer_config.eval_loss_value = outputs.get('aux_loss', 0)
         if labels is not None:
             loss_mask = (labels != -100).bool()
             masked_labels = labels.clone()
@@ -1031,8 +1051,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
         else:
             unwrapped_model.add_adapter(adapter_name, config)
 
-        self.optimizer_group[adapter_name] = self.optimizer_group.pop(_default_adapter_name,
-                                                                      self._construct_default_optimizer_group())
+        self.optimizer_group[adapter_name] = self._construct_default_optimizer_group()
         self.optimizer_group[adapter_name].adapter_name = adapter_name
         self.optimizer_group[adapter_name].adapter_config = config
         _gas_default = kwargs.get('gradient_accumulation_steps', 1)

From c7a1465e9a35d611f5ced8287528bae4885120b9 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sun, 29 Mar
2026 10:45:47 +0800 Subject: [PATCH 16/31] wip --- src/twinkle/model/megatron/megatron.py | 80 ++---------- src/twinkle/model/optimizer_group.py | 86 +++++++++++++ .../model/transformers/transformers.py | 120 +++++------------- 3 files changed, 134 insertions(+), 152 deletions(-) create mode 100644 src/twinkle/model/optimizer_group.py diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index bc195e59..15a2326e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -25,6 +25,7 @@ from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, torch_util from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.data_format import InputFeature, ModelOutput, Trajectory +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict from twinkle.loss import CrossEntropyLoss, Loss @@ -38,53 +39,28 @@ @dataclass -class MegatronOptimizerGroup: +class MegatronOptimizerGroup(BaseOptimizerGroup): """Optimizer group for Megatron training. Similar to OptimizerGroup but adapted for Megatron's distributed training. """ - adapter_name: str = None - adapter_config: Any = None - optimizer: Optimizer = None - lr_scheduler: LRScheduler = None - inputs: List[InputFeature] = None - outputs: ModelOutput = None - loss_instance: Loss = None - loss_value: Any = None - eval_inputs: List[InputFeature] = None - eval_outputs: ModelOutput = None - eval_loss_value: Any = None - template: Template = None - processor: InputProcessor = None - gradient_accumulation_steps: int = 1 - cur_step: int = 0 - _dp_group = None - train_metrics: List[Metric] = field(default_factory=list) - eval_metrics: List[Metric] = field(default_factory=list) - _device_mesh: DeviceMesh = None - # Megatron optimizer specific fields - _last_grad_norm: float = 0.0 + # Megatron-specific fields _last_step_success: bool = True - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: - if gradient_accumulation_steps is None: - gradient_accumulation_steps = self.gradient_accumulation_steps - else: - self.gradient_accumulation_steps = gradient_accumulation_steps - return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 - def __post_init__(self): if self._device_mesh.data_world_size > 1: self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp']) - self.train_metrics = [ + train_metrics = [ LossMetric(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.train_status = TrainStatus(metrics=train_metrics) - self.eval_metrics = [ + eval_metrics = [ LossMetric(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.eval_status = TrainStatus(metrics=eval_metrics) def _get_lr(self): _lrs = [] @@ -93,36 +69,6 @@ def _get_lr(self): _lrs.append(param_group.get('lr', _default_lr)) return _lrs - def accumulate_metrics(self, is_training): - if is_training: - metrics = self.train_metrics - inputs = self.inputs - outputs = self.outputs - else: - metrics = self.eval_metrics - inputs = self.eval_inputs - outputs = self.eval_outputs - if len(metrics) > 0 and inputs is not None and outputs is not None: - for metric in metrics: - metric.accumulate( - inputs, - outputs, - lr=self._get_lr(), - step=self.cur_step - 1, - gradient_accumulation_steps=self.gradient_accumulation_steps, - grad_norm=self._last_grad_norm) - - def 
calculate_metrics(self, is_training): - self.accumulate_metrics(is_training) - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - results = {} - for metric in metrics: - results.update(metric.calculate()) - return results - _default_adapter_name = '' @@ -610,11 +556,11 @@ def _set_disable_adapters(model, value: bool): if not return_logits: logits = None if forward_only: - optimizer_config.eval_inputs = inputs - optimizer_config.eval_outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.eval_status.inputs = inputs + optimizer_config.eval_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) else: - optimizer_config.inputs = inputs - optimizer_config.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.train_status.inputs = inputs + optimizer_config.train_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) return ModelOutput(logits=logits, loss=loss, logps=logps) @remote_function(dispatch='all') @@ -760,9 +706,9 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] kwargs['device_mesh'] = self.device_mesh kwargs['process_group'] = optimizer_config._dp_group if is_training is None or is_training is True: - optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.train_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) if not is_training: - optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.eval_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) @remote_function(dispatch='all') def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs): diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py new file mode 100644 index 00000000..1a23d568 --- /dev/null +++ b/src/twinkle/model/optimizer_group.py @@ -0,0 +1,86 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from twinkle import DeviceMesh +from twinkle.data_format import InputFeature, ModelOutput +from twinkle.loss import Loss +from twinkle.metric import Metric +from twinkle.processor import InputProcessor +from twinkle.template import Template + + +@dataclass +class TrainStatus: + """Status for training or evaluation. + + Encapsulates inputs, outputs, loss, token count and metrics for a training/eval step. + """ + inputs: List[InputFeature] = None + outputs: ModelOutput = None + loss_value: Any = None + num_tokens: int = 0 + metrics: List[Metric] = field(default_factory=list) + forward_kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class BaseOptimizerGroup: + """Base optimizer group with common fields for training.
+ + Subclasses: OptimizerGroup (Transformers), MegatronOptimizerGroup (Megatron) + """ + adapter_name: str = None + adapter_config: Any = None + optimizer: Optimizer = None + lr_scheduler: LRScheduler = None + loss_instance: Loss = None + train_status: TrainStatus = None + eval_status: TrainStatus = None + template: Template = None + processor: InputProcessor = None + gradient_accumulation_steps: int = 1 + cur_step: int = 0 + _dp_group: Any = None + _device_mesh: DeviceMesh = None + _last_grad_norm: float = 0.0 + + def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: + if gradient_accumulation_steps is None: + gradient_accumulation_steps = self.gradient_accumulation_steps + else: + self.gradient_accumulation_steps = gradient_accumulation_steps + return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 + + def _get_lr(self): + """Get learning rates from optimizer. Override in subclass.""" + return [] + + def accumulate_metrics(self, is_training): + """Accumulate metrics for train/eval status. Override in subclass if needed.""" + status = self.train_status if is_training else self.eval_status + if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None: + for metric in status.metrics: + metric.accumulate( + status.inputs, + status.outputs, + lr=self._get_lr(), + step=self.cur_step - 1, + gradient_accumulation_steps=self.gradient_accumulation_steps, + grad_norm=self._last_grad_norm, + **status.forward_kwargs) + + def calculate_metrics(self, is_training): + """Calculate and return metrics.""" + self.accumulate_metrics(is_training) + status = self.train_status if is_training else self.eval_status + results = {} + for metric in status.metrics: + results.update(metric.calculate()) + status.inputs = None + status.outputs = None + return results + diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index bc919fbd..5c1f6498 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -31,6 +31,7 @@ from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import Accuracy, LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.model.transformers.moe import apply_expert_parallel from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy from twinkle.patch import Patch, apply_patch @@ -42,63 +43,33 @@ @dataclass -class TrainStatus: - inputs: List[InputFeature] = None - outputs: ModelOutput = None - loss_value: Any = None - num_tokens: int = 0 - metrics: List[Metric] = field(default_factory=list) - - -@dataclass -class OptimizerGroup: - adapter_name: str = None +class OptimizerGroup(BaseOptimizerGroup): + """Optimizer group for Transformers training.""" adapter_config: PeftConfig = None - optimizer: Optimizer = None - lr_scheduler: LRScheduler = None - - eval_inputs: List[InputFeature] = None - eval_outputs: ModelOutput = None - eval_loss_value: Any = None loss_instance: Loss = CrossEntropyLoss - template: Template = None - processor: InputProcessor = None scaler: GradScaler = None - _last_grad_norm: float = 0.0 scaler_has_nan: bool = False - gradient_accumulation_steps: int = 1 - cur_step: int = 0 - - train_metrics: List[Metric] = field(default_factory=list) - eval_metrics: List[Metric] = field(default_factory=list) checkpoint_engine: CheckpointEngine = None - _dp_group = None - _device_mesh: DeviceMesh = None
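# A minimal sketch of the pattern this refactor introduces (illustrative
# stand-in names, not the real twinkle classes): one status container per
# mode, so a reference/eval forward pass can no longer clobber the training
# inputs/outputs that a pending calculate_loss()/backward() still needs.
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class _Status:  # stands in for TrainStatus
    inputs: Any = None
    outputs: Any = None
    metrics: List[Any] = field(default_factory=list)

class _Group:  # stands in for OptimizerGroup / MegatronOptimizerGroup
    def __init__(self):
        self.train_status = _Status()
        self.eval_status = _Status()

    def store(self, inputs, outputs, *, is_training: bool):
        # Route state to the mode-specific container instead of shared fields.
        status = self.train_status if is_training else self.eval_status
        status.inputs, status.outputs = inputs, outputs

g = _Group()
g.store(['train-batch'], {'loss': 0.7}, is_training=True)
g.store(['ref-batch'], {'logps': []}, is_training=False)  # reference/eval pass
assert g.train_status.inputs == ['train-batch']  # training state left intact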
_handler: Any = None - def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: - if gradient_accumulation_steps is None: - gradient_accumulation_steps = self.gradient_accumulation_steps - else: - self.gradient_accumulation_steps = gradient_accumulation_steps - return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 - def __post_init__(self): self._ensure_dp_group() self._build_metrics() def _build_metrics(self): - self.train_metrics = [ - LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'), + train_metrics = [ + LossMetric(self._device_mesh, self._dp_group), Accuracy(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.train_status = TrainStatus(metrics=train_metrics) - self.eval_metrics = [ - LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'), + eval_metrics = [ + LossMetric(self._device_mesh, self._dp_group), Accuracy(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), ] + self.eval_status = TrainStatus(metrics=eval_metrics) def _ensure_dp_group(self): if self._dp_group is not None or self._device_mesh is None: @@ -127,40 +98,18 @@ def _get_lr(self): def accumulate_metrics(self, is_training): self._ensure_dp_group() - if is_training: - metrics = self.train_metrics - inputs = self.inputs - outputs = self.outputs - else: - metrics = self.eval_metrics - inputs = self.eval_inputs - outputs = self.eval_outputs - # Get stored forward_kwargs from previous forward - forward_kwargs = getattr(self, 'forward_kwargs', None) or {} - if len(metrics) > 0 and inputs is not None and outputs is not None: - for metric in metrics: + status = self.train_status if is_training else self.eval_status + if len(status.metrics) > 0 and status.inputs is not None and status.outputs is not None: + for metric in status.metrics: metric.accumulate( - inputs, - outputs, + status.inputs, + status.outputs, lr=self._get_lr(), step=self.cur_step - 1, gradient_accumulation_steps=self.gradient_accumulation_steps, grad_norm=self._last_grad_norm, loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'), - **forward_kwargs) - - def calculate_metrics(self, is_training): - self.accumulate_metrics(is_training) - if is_training: - metrics = self.train_metrics - else: - metrics = self.eval_metrics - results = {} - for metric in metrics: - results.update(metric.calculate()) - self.inputs = None - self.outputs = None - return results + **status.forward_kwargs) _default_adapter_name = '' @@ -420,10 +369,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec if self.sp_strategy is not None and labels is None: outputs = self.sp_strategy.postprocess_outputs(outputs) inputs['labels'] = labels - optimizer_config.inputs = inputs - optimizer_config.outputs = outputs - optimizer_config.forward_kwargs = kwargs # Store for next metric accumulation - optimizer_config.loss_value = outputs.get('aux_loss', 0) + optimizer_config.train_status.inputs = inputs + optimizer_config.train_status.outputs = outputs + optimizer_config.train_status.forward_kwargs = kwargs + optimizer_config.train_status.loss_value = outputs.get('aux_loss', 0) if labels is not None: loss_mask = (labels != -100).bool() masked_labels = labels.clone() @@ -483,9 +432,10 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T if self.sp_strategy is not None and labels is None: outputs = self.sp_strategy.postprocess_outputs(outputs) inputs['labels'] = labels - 
optimizer_config.eval_inputs = inputs - optimizer_config.eval_outputs = outputs - optimizer_config.eval_loss_value = outputs.get('aux_loss', 0) + optimizer_config.eval_status.inputs = inputs + optimizer_config.eval_status.outputs = outputs + optimizer_config.eval_status.forward_kwargs = kwargs + optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0) if labels is not None: loss_mask = (labels != -100).bool() masked_labels = labels.clone() @@ -514,8 +464,8 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance assert isinstance(loss_instance, Loss), 'Set a loss_instance before calculating loss' - inputs = optimizer_config.inputs - outputs = optimizer_config.outputs + inputs = optimizer_config.train_status.inputs + outputs = optimizer_config.train_status.outputs assert inputs is not None and outputs is not None, 'Cannot calculate loss of empty inputs and outputs' result = loss_instance(inputs, outputs, **kwargs) loss_value = result['loss'] @@ -534,15 +484,15 @@ def calculate_loss(self, **kwargs): # = global_per_token_grad / dp_world_size = avg_per_token_grad counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] - optimizer_config.num_tokens += counts.item() + optimizer_config.train_status.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: reduction = getattr(loss_instance, 'reduction', None) if reduction is not None: self.sp_strategy.sp_config['loss_reduction'] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) - optimizer_config.loss_value += loss_value - outputs['loss'] = optimizer_config.loss_value - return optimizer_config.loss_value.item() + optimizer_config.train_status.loss_value += loss_value + outputs['loss'] = optimizer_config.train_status.loss_value + return optimizer_config.train_status.loss_value.item() @remote_function() def backward(self, **kwargs): @@ -555,7 +505,7 @@ def backward(self, **kwargs): """ adapter_name = kwargs.pop('adapter_name', self._get_default_group()) optimizer_config = self.optimizer_group[adapter_name] - loss_value = optimizer_config.loss_value + loss_value = optimizer_config.train_status.loss_value assert loss_value is not None, 'Do forwarding and calculating loss before backward' scaler = optimizer_config.scaler if scaler is None and self.mixed_precision == 'fp16': @@ -567,7 +517,7 @@ def backward(self, **kwargs): else: loss_value.backward() optimizer_config.cur_step += 1 - optimizer_config.loss_value = None + optimizer_config.train_status.loss_value = None @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -618,7 +568,7 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): scaler.unscale_(optimizer) optimizer_config._ensure_dp_group() - num_tokens = optimizer_config.num_tokens + num_tokens = optimizer_config.train_status.num_tokens num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group) num_tokens = sum(num_tokens) parameters = list(self._get_trainable_parameters(adapter_name).values()) @@ -635,7 +585,7 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): **ep_clip_kwargs, ) optimizer_config._last_grad_norm = grad_norm - optimizer_config.num_tokens = 0 + optimizer_config.train_status.num_tokens = 0 return 
grad_norm @remote_function(dispatch='all') @@ -1137,9 +1087,9 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] kwargs['device_mesh'] = self.device_mesh kwargs['process_group'] = optimizer_config._dp_group if is_training is None or is_training is True: - optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.train_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) if not is_training: - optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) + optimizer_config.eval_status.metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs)) def _get_nb_trainable_parameters(self, adapter_name, model): return PeftModel.get_nb_trainable_parameters(model) From 1c8bcd238c4abd81e64fe738b559d13747cdd69d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 13:04:02 +0800 Subject: [PATCH 17/31] wip --- cookbook/rl/dpo.py | 4 +-- cookbook/rl/dpo_lora.py | 6 ++-- src/twinkle/infra/collectors.py | 2 +- src/twinkle/metric/dpo.py | 1 - .../model/transformers/transformers.py | 6 +++- src/twinkle/utils/device_mesh.py | 34 +++++++++++++++++++ 6 files changed, 45 insertions(+), 8 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index cd86653b..93f3c3b2 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -70,8 +70,8 @@ MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2)) +REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 2)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 47f80e83..be61957b 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -68,10 +68,10 @@ MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index af4d6d6e..55ae73fc 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -19,7 +19,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) for d in outputs: all_keys.update(d.keys()) - outputs = [r for i, r in enumerate(outputs) if i in device_mesh.get_pp_last_ranks()] + outputs = [r for i, r in enumerate(outputs) if i in device_mesh.get_collect_ranks()] result = {} for key in all_keys: values = [d[key] for d in outputs if key in d] diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 
ceb3c651..4816d4e9 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -56,7 +56,6 @@ def _align_logps(self, logps, target_shape, device, dtype): if src_len == target_len: return logps elif src_len < target_len: - breakpoint() raise ValueError( f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). ' f'This should not happen when both models process the same batch.' diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5c1f6498..90654f23 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -40,6 +40,9 @@ from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.utils import get_logger + +logger = get_logger() @dataclass @@ -188,6 +191,7 @@ def __init__( } self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name self.active_group = _default_adapter_name + # breakpoint() def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None) @@ -415,7 +419,6 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T if isinstance(inputs, dict): inputs = [inputs] inputs = optimizer_config.template.batch_encode(inputs) # noqa - breakpoint() with torch.no_grad(): processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' @@ -443,6 +446,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T logits = outputs['logits'] logits.div_(temperature) outputs['logps'] = selective_log_softmax(logits, masked_labels) + logger.info(f'logps: {outputs["logps"].sum()}') outputs = copy(outputs) outputs['past_key_values'] = None if not return_logits: diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py index 9f5aa9e7..d95d5da8 100644 --- a/src/twinkle/utils/device_mesh.py +++ b/src/twinkle/utils/device_mesh.py @@ -462,6 +462,40 @@ def get_pp_last_ranks(self) -> Optional[list[int]]: pp_world_size = self.pp_world_size or 1 return self.get_pp_stage_ranks(pp_world_size - 1) + def get_collect_ranks(self) -> list[int]: + """Get ranks for collecting slice_dp dispatch results. + + For slice_dp dispatch, data is split by data_rank (DP/FSDP dimensions). + Within each data_rank group, outputs are identical (after all-gather for TP, etc.), + so we only need one representative per data_rank. + + Returns: One rank per data_rank, preferring PP last stage. 
+ """ + data_ws = self.data_world_size + collect_ranks = [] + + # For each data_rank, find a representative global rank + for data_rank in range(data_ws): + # Find all global ranks that map to this data_rank + candidates = [ + r for r in self.mesh.flatten().tolist() + if self.get_data_rank_from_global_rank(r) == data_rank + ] + if not candidates: + continue + + # Prefer PP last stage if PP exists + pp_last = self.get_pp_last_ranks() + if pp_last: + pp_candidates = [r for r in candidates if r in pp_last] + if pp_candidates: + candidates = pp_candidates + + # Take the smallest rank as representative + collect_ranks.append(min(candidates)) + + return sorted(collect_ranks) + def has_dim(self, dim_name: str) -> bool: if self.mesh_dim_names is None: return False From f4fe545348a8a18a7786c0c04101d4b327f2af65 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 14:02:15 +0800 Subject: [PATCH 18/31] wip --- cookbook/rl/dpo.py | 4 ++-- src/twinkle/metric/dpo.py | 12 ++++++++++++ src/twinkle/model/megatron/megatron.py | 1 + src/twinkle/model/transformers/transformers.py | 9 +++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 93f3c3b2..6fa90e91 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -74,8 +74,8 @@ REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 2)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # TRL default for DPO is 5e-7 to 5e-6 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 4816d4e9..7f21e47f 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -137,6 +137,10 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) + # Accumulate ref logps + self.total_ref_chosen_logps += ref_chosen_logps.sum().item() + self.total_ref_rejected_logps += ref_rejected_logps.sum().item() + # Compute rewards: β * (policy - ref) chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps) rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps) @@ -154,6 +158,8 @@ def reset(self): """Reset all accumulated values.""" self.total_chosen_logps = 0.0 self.total_rejected_logps = 0.0 + self.total_ref_chosen_logps = 0.0 + self.total_ref_rejected_logps = 0.0 self.total_chosen_rewards = 0.0 self.total_rejected_rewards = 0.0 self.total_reward_margin = 0.0 @@ -166,6 +172,8 @@ def calculate(self): local_results = [{ 'chosen_logps': self.total_chosen_logps, 'rejected_logps': self.total_rejected_logps, + 'ref_chosen_logps': self.total_ref_chosen_logps, + 'ref_rejected_logps': self.total_ref_rejected_logps, 'chosen_rewards': self.total_chosen_rewards, 'rejected_rewards': self.total_rejected_rewards, 'reward_margin': self.total_reward_margin, @@ -177,6 +185,8 @@ def calculate(self): total_chosen_logps = sum(r['chosen_logps'] for r in all_results) total_rejected_logps = sum(r['rejected_logps'] for r in all_results) + total_ref_chosen_logps = 
sum(r['ref_chosen_logps'] for r in all_results) + total_ref_rejected_logps = sum(r['ref_rejected_logps'] for r in all_results) total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results) total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results) total_reward_margin = sum(r['reward_margin'] for r in all_results) @@ -195,6 +205,8 @@ def calculate(self): } if has_rewards: + results['logps/ref_chosen'] = f'{total_ref_chosen_logps / total_count:.2f}' + results['logps/ref_rejected'] = f'{total_ref_rejected_logps / total_count:.2f}' results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}' results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}' results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}' diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 15a2326e..b120fc2b 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -477,6 +477,7 @@ def forward_step_func(data_iterator, model): # Handle disable_lora for base model inference (e.g., reference in DPO) def _set_disable_adapters(model, value: bool): + model = self.strategy.unwrap_model(model) if isinstance(model, list): for m in model: if isinstance(m, PeftModel): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 90654f23..767cda93 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -40,9 +40,6 @@ from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm -from twinkle.utils import get_logger - -logger = get_logger() @dataclass @@ -427,8 +424,9 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T inputs = self.sp_strategy.preprocess_inputs(inputs) labels = inputs.pop('labels', None) optimizer_config.accumulate_metrics(False) - if disable_lora and isinstance(self.model, PeftModel): - with self.model.disable_adapter(): + unwrapped_model = self.strategy.unwrap_model(self.model) + if disable_lora and isinstance(unwrapped_model, PeftModel): + with unwrapped_model.disable_adapter(): outputs = self.model(**inputs) else: outputs = self.model(**inputs) @@ -446,7 +444,6 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T logits = outputs['logits'] logits.div_(temperature) outputs['logps'] = selective_log_softmax(logits, masked_labels) - logger.info(f'logps: {outputs["logps"].sum()}') outputs = copy(outputs) outputs['past_key_values'] = None if not return_logits: From ed00c1b94ea9f84be4d82565520aab8620c22165 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 15:21:25 +0800 Subject: [PATCH 19/31] wip --- cookbook/rl/dpo_lora.py | 11 +-- src/twinkle/loss/dpo.py | 5 +- src/twinkle/metric/dpo.py | 22 ++---- src/twinkle/model/megatron/megatron.py | 68 ++++++++----------- .../model/megatron/multi_lora_megatron.py | 1 + .../model/transformers/transformers.py | 1 - 6 files changed, 44 insertions(+), 64 deletions(-) diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index be61957b..765b3ded 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -58,7 +58,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import DPOLoss from twinkle.metric import DPOMetric -from twinkle.model import TransformersModel +from twinkle.model import 
MegatronModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor @@ -157,15 +157,15 @@ def main(): lora_dropout=0.05, ) - policy_model = TransformersModel( + policy_model = MegatronModel( model_id=MODEL_ID, device_mesh=policy_mesh, remote_group='policy', ) MAX_STEPS = len(dataloader) policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) - policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) # Set up loss function and metrics loss_fn = DPOLoss( @@ -191,13 +191,14 @@ def main(): # Get reference outputs using base model (without LoRA adapter) # disable_lora=True tells the model to skip LoRA and use base weights - ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True) + ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True) # Forward-backward pass with DPO loss (using LoRA adapter) # ref_outputs is passed to loss which extracts logps internally policy_model.forward_backward( inputs=dpo_batch, ref_outputs=ref_outputs, + micro_batch_size=2, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index ce533053..610beb13 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -326,10 +326,7 @@ def __call__( reference_chosen_logps = torch.zeros_like(policy_chosen_logps) reference_rejected_logps = torch.zeros_like(policy_rejected_logps) else: - raise ValueError( - "ref_logps or (ref_chosen_logps, ref_rejected_logps) must be provided " - "unless reference_free=True" - ) + return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) # Compute DPO loss dpo_loss = self._compute_dpo_loss( diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 7f21e47f..404ff706 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -80,27 +80,18 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M - kwargs['ref_outputs']: Optional reference model outputs with 'logps' """ import torch - logps = outputs.get('logps') if logps is None: return # Get labels from inputs if isinstance(inputs, list): - # Stack labels from list of inputs - labels_list = [torch.as_tensor(inp['labels']) for inp in inputs] - max_len = max(l.shape[0] for l in labels_list) - padded = [] - for l in labels_list: - if l.shape[0] < max_len: - pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype) - l = torch.cat([pad, l]) - padded.append(l) - labels = torch.stack(padded) - else: - labels = torch.as_tensor(inputs['labels']) - if labels.dim() == 1: - labels = labels.unsqueeze(0) + assert len(inputs) == 1 + inputs = inputs[0] + + labels = torch.as_tensor(inputs['labels']) + if labels.dim() == 1: + labels = labels.unsqueeze(0) # Ensure logps and labels have same device if logps.device != labels.device: @@ -129,7 +120,6 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_logps = ref_outputs.get('logps') if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) - # breakpoint() ref_logps = self._align_logps( ref_logps, labels.shape, labels.device, logps.dtype ) diff --git 
a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index b120fc2b..4c44243c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -447,7 +447,13 @@ def post_loss_function(output_tensor, inputs, logps): def forward_step_func(data_iterator, model): batch = next(data_iterator) labels = batch.pop('labels', None) - output_tensor = model(**batch) + # Handle disable_lora for base model inference (e.g., reference in DPO) + unwrapped_model = self.strategy.unwrap_model([model])[0] + if disable_lora and isinstance(unwrapped_model, PeftModel): + with unwrapped_model.disable_adapter(): + output_tensor = model(**batch) + else: + output_tensor = model(**batch) batch['labels'] = labels logps = None if labels is not None and mpu.is_pipeline_last_stage(): @@ -475,34 +481,17 @@ def forward_step_func(data_iterator, model): self._accumulate_metric(optimizer_config, is_training=not forward_only) - # Handle disable_lora for base model inference (e.g., reference in DPO) - def _set_disable_adapters(model, value: bool): - model = self.strategy.unwrap_model(model) - if isinstance(model, list): - for m in model: - if isinstance(m, PeftModel): - m.disable_adapters = value - elif isinstance(model, PeftModel): - model.disable_adapters = value - - if disable_lora: - _set_disable_adapters(self.model, True) - - try: - # Run forward-backward with Megatron's scheduler - # Megatron handles all communication internally using proper process groups - losses = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iter, - model=self.model, - num_microbatches=len(inputs), - seq_length=seq_length, - micro_batch_size=micro_batch_size, - forward_only=forward_only, - ) - finally: - if disable_lora: - _set_disable_adapters(self.model, False) + # Run forward-backward with Megatron's scheduler + # Megatron handles all communication internally using proper process groups + losses = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iter, + model=self.model, + num_microbatches=len(inputs), + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=forward_only, + ) # Extract loss from results (only last PP stage returns non-empty) loss = torch.tensor(0.0).to(Platform.get_local_device()) @@ -559,9 +548,11 @@ def _set_disable_adapters(model, value: bool): if forward_only: optimizer_config.eval_status.inputs = inputs optimizer_config.eval_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.eval_status.forward_kwargs = kwargs else: optimizer_config.train_status.inputs = inputs optimizer_config.train_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps) + optimizer_config.train_status.forward_kwargs = kwargs return ModelOutput(logits=logits, loss=loss, logps=logps) @remote_function(dispatch='all') @@ -692,6 +683,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature optimizer_config = self.optimizer_group[adapter_name] optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): """Add an eval metric @@ -773,16 +765,16 @@ def _create_megatron_optimizer(self, **kwargs): opt_config = OptimizerConfig( optimizer='adam', lr=lr, - min_lr=kwargs.get('min_lr', 0.0), - weight_decay=kwargs.get('weight_decay', 0.01), - adam_beta1=kwargs.get('adam_beta1', 0.9), - 
adam_beta2=kwargs.get('adam_beta2', 0.999), - adam_eps=kwargs.get('adam_eps', 1e-8), - clip_grad=kwargs.get('clip_grad', 1.0), - bf16=kwargs.get('bf16', True), + min_lr=kwargs.pop('min_lr', 0.0), + weight_decay=kwargs.pop('weight_decay', 0.01), + adam_beta1=kwargs.pop('adam_beta1', 0.9), + adam_beta2=kwargs.pop('adam_beta2', 0.999), + adam_eps=kwargs.pop('adam_eps', 1e-8), + clip_grad=kwargs.pop('clip_grad', 1.0), + bf16=kwargs.pop('bf16', True), use_distributed_optimizer=use_distributed_optimizer, - overlap_param_gather=kwargs.get('overlap_param_gather', False), - log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), + overlap_param_gather=kwargs.pop('overlap_param_gather', False), + log_num_zeros_in_grad=kwargs.pop('log_num_zeros_in_grad', False), **kwargs, ) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 50afbdfd..627eaa90 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -264,6 +264,7 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable self._check_adapter_valid(kwargs.get('adapter_name')) super().set_processor(processor_cls, **kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().add_metric(metric_cls, is_training, **kwargs) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 767cda93..22e3538d 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -188,7 +188,6 @@ def __init__( } self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name self.active_group = _default_adapter_name - # breakpoint() def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None) From bebe60ea4a0ffa2b466af345efc1b6e76215b8a9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 16:52:26 +0800 Subject: [PATCH 20/31] wip --- cookbook/rl/dpo_lora.py | 34 ++++++++++--------- src/twinkle/infra/collectors.py | 2 +- src/twinkle/model/megatron/megatron.py | 4 ++- .../model/megatron/multi_lora_megatron.py | 2 +- src/twinkle/model/multi_lora.py | 4 +-- .../transformers/multi_lora_transformers.py | 6 ++-- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 765b3ded..914e91ea 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -58,7 +58,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import DPOLoss from twinkle.metric import DPOMetric -from twinkle.model import MegatronModel +from twinkle.model import MultiLoraMegatronModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor @@ -68,7 +68,7 @@ MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) @@ -137,7 +137,7 @@ def main(): DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), ] - 
policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, pp_size=2, cp_size=2, tp_size=2) twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups) # ── DataLoader Setup ────────────────────────────────────────────────────── @@ -157,15 +157,17 @@ def main(): lora_dropout=0.05, ) - policy_model = MegatronModel( + policy_model = MultiLoraMegatronModel( model_id=MODEL_ID, device_mesh=policy_mesh, remote_group='policy', ) MAX_STEPS = len(dataloader) policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01) - policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) + # policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) + # policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, adapter_name=ADAPTER_NAME) + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) + policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME) # Set up loss function and metrics loss_fn = DPOLoss( @@ -174,10 +176,10 @@ def main(): reference_free=False, # We use base model as reference via disable_lora=True sft_weight=SFT_WEIGHT, ) - policy_model.set_loss(loss_fn) - policy_model.add_metric(DPOMetric, beta=DPO_BETA) - policy_model.set_processor(InputProcessor) - policy_model.set_template('Template', model_id=MODEL_ID) + policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME) + policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME) + policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME) + policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME) optim_step = 0 logger.info(get_device_placement()) @@ -191,32 +193,32 @@ def main(): # Get reference outputs using base model (without LoRA adapter) # disable_lora=True tells the model to skip LoRA and use base weights - ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True) - + ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True, adapter_name=ADAPTER_NAME) # Forward-backward pass with DPO loss (using LoRA adapter) # ref_outputs is passed to loss which extracts logps internally policy_model.forward_backward( inputs=dpo_batch, ref_outputs=ref_outputs, micro_batch_size=2, + adapter_name=ADAPTER_NAME ) # Gradient clipping and optimizer step - policy_model.clip_grad_and_step() + policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME) optim_step += 1 # Logging if optim_step % 1 == 0: - metrics = policy_model.calculate_metric(is_training=True) + metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: - policy_model.save(f'dpo-lora-checkpoint-{optim_step}') + policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME) # ── Save Final Checkpoint ───────────────────────────────────────────────── logger.info(f'Training completed. 
Total steps: {optim_step}') - policy_model.save('dpo-lora-final-checkpoint') + policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME) if __name__ == '__main__': diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 55ae73fc..0e4e6c35 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -61,7 +61,7 @@ def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -20 raise ValueError('Empty tensor list') if len(tensors) == 1: - return tensors[0].unsqueeze(0) + return tensors[0] max_ndim = max(t.ndim for t in tensors) expanded_tensors = [] diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4c44243c..d6bcebc4 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -400,7 +400,9 @@ def forward_backward(self, seq_length = original_seq_length + (divisor - original_seq_length % divisor) else: seq_length = original_seq_length - + + if 'ref_outputs' in kwargs: + breakpoint() num_microbatches = len(inputs) loss_extra_kwargs_per_mb = [] if num_microbatches <= 1: diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 627eaa90..fc35b2b0 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -129,7 +129,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora): return super().forward_only(inputs=inputs, **kwargs) - @remote_function(dispatch='slice_dp', collect='mean', sync=True) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 75f74cc1..cebde214 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -205,7 +205,7 @@ def _linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: lora_A_keys = self.lora_A.keys() for active_adapter in self.active_adapters: - if active_adapter not in lora_A_keys: + if active_adapter not in lora_A_keys or self.disable_adapters: continue _lora = _self.find_lora(active_adapter) target_modules = _lora.tenant_config.target_modules @@ -238,7 +238,7 @@ def _embedding_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: lora_embedding_A_keys = self.lora_embedding_A.keys() for active_adapter in self.active_adapters: - if active_adapter not in lora_embedding_A_keys: + if active_adapter not in lora_embedding_A_keys or self.disable_adapters: continue _lora = self.find_lora(active_adapter) target_modules = _lora.tenant_config.target_modules diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 34e81e83..9fe52c8d 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union from twinkle import DeviceMesh, remote_class, remote_function, template +from twinkle.infra import collect_tensor_dict from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.loss import Loss @@ -88,7 +89,7 @@ def 
unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): def _lazy_wrap_model(self): pass - @remote_function(dispatch='slice_dp', collect='mean') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) optimizer_config = self.optimizer_group[kwargs.get('adapter_name')] @@ -104,7 +105,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, with self.multi_adapter.adapter(kwargs.get('adapter_name')): return super().forward(inputs=inputs, **kwargs) - @remote_function(dispatch='slice_dp', collect='flatten') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): adapter_name = kwargs.get('adapter_name') disable_lora = kwargs.get('disable_lora', False) @@ -246,6 +247,7 @@ def set_grad_scaler(self, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().set_grad_scaler(**kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) super().add_metric(metric_cls, is_training, **kwargs) From 8fc2bb71208065de672dfb840ea1bd67f585cd4d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 17:28:08 +0800 Subject: [PATCH 21/31] wip --- src/twinkle/infra/collectors.py | 36 ++----------------- src/twinkle/loss/dpo.py | 1 - src/twinkle/metric/dpo.py | 14 ++++++-- src/twinkle/model/megatron/megatron.py | 49 +++++++++++++++++++------- src/twinkle/utils/__init__.py | 2 +- src/twinkle/utils/torch_utils.py | 33 +++++++++++++++++ 6 files changed, 83 insertions(+), 52 deletions(-) diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index 0e4e6c35..fbe4f17f 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List from twinkle import DeviceMesh +from twinkle.utils import pad_and_stack_tensors if TYPE_CHECKING: import torch @@ -39,7 +40,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) result[key] = merged elif isinstance(first_value, torch.Tensor): - result[key] = _pad_and_stack_tensors(values) + result[key] = pad_and_stack_tensors(values) elif isinstance(first_value, dict): result[key] = collect_tensor_dict(values) @@ -53,36 +54,3 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) if 'loss' in result and len(result['loss']) > 1: result['loss'] = np.mean(result['loss']) return result - - -def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor': - import torch - if not tensors: - raise ValueError('Empty tensor list') - - if len(tensors) == 1: - return tensors[0] - - max_ndim = max(t.ndim for t in tensors) - expanded_tensors = [] - for t in tensors: - while t.ndim < max_ndim: - t = t.unsqueeze(0) - expanded_tensors.append(t) - - max_shape = [] - for dim in range(max_ndim): - max_shape.append(max(t.shape[dim] for t in expanded_tensors)) - - padded_tensors = [] - for t in expanded_tensors: - if list(t.shape) == max_shape: - padded_tensors.append(t) - else: - pad_params = [] - for dim in range(max_ndim - 1, -1, -1): - pad_params.extend([0, max_shape[dim] - t.shape[dim]]) - padded = torch.nn.functional.pad(t, 
pad_params, value=pad_value) - padded_tensors.append(padded) - - return torch.cat(padded_tensors, dim=0) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 610beb13..ada318c6 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -284,7 +284,6 @@ def __call__( # Extract ref_logps from ref_outputs if provided if ref_outputs is not None and ref_logps is None: ref_logps = ref_outputs.get('logps') - labels = inputs.get('labels') assert labels is not None, "inputs must contain 'labels'" if not torch.is_tensor(labels): diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 404ff706..bd6bd5d0 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -3,6 +3,7 @@ from typing import List, Union from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import pad_and_stack_tensors from .base import Metric @@ -81,13 +82,20 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M """ import torch logps = outputs.get('logps') - if logps is None: + if logps is None or len(logps) == 0: return + + if isinstance(logps, list) and logps: + logps = pad_and_stack_tensors(logps) # Get labels from inputs if isinstance(inputs, list): - assert len(inputs) == 1 - inputs = inputs[0] + labels = [input['labels'] for input in inputs] + if len(labels) == 1: + labels = labels[0] + else: + labels = pad_and_stack_tensors(labels) + inputs = {'labels': labels} labels = torch.as_tensor(inputs['labels']) if labels.dim() == 1: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index d6bcebc4..cd6c9f28 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -250,6 +250,38 @@ def _not_encoded(inputs): assert isinstance(inputs, dict) return 'input_ids' not in inputs and 'input_embedding' not in inputs + @staticmethod + def _slice_value_for_microbatch(value, mb_start: int, mb_end: int, micro_batch_size: int): + """Recursively slice a value for microbatch processing. + + Handles nested dicts (e.g., ref_outputs: {"logps": tensor}) by recursively + slicing internal tensors. + + Args: + value: The value to slice (tensor, ndarray, list, dict, or scalar) + mb_start: Start index of the microbatch + mb_end: End index of the microbatch + micro_batch_size: Size of each microbatch + + Returns: + Sliced value with the same structure + """ + if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size: + return value[mb_start:mb_end] + elif isinstance(value, dict): + # Recursively slice dict values (e.g., ref_outputs: {"logps": tensor}) + return { + k: MegatronModel._slice_value_for_microbatch(v, mb_start, mb_end, micro_batch_size) + for k, v in value.items() + } + else: + # Scalars, small tensors, or non-sliceable values pass through as-is + return value + def _postprocess_tensor_cp(self, tensor): """All-gather and reconstruct full sequence from CP-split tensor. 
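# A small usage sketch of the _slice_value_for_microbatch helper added above
# (illustrative values; assumes the twinkle package is importable): slicing
# per-microbatch forward kwargs, including the nested ref_outputs dict that
# the DPO reference pass produces. With a global batch of 4 and
# micro_batch_size=2, tensors are sliced along the batch dimension while
# scalars pass through unchanged.
import torch
from twinkle.model import MegatronModel

full_kwargs = {
    'ref_outputs': {'logps': torch.zeros(4, 16)},  # nested dict of tensors
    'temperature': 1.0,                            # scalar, passed through
}
mb0 = {
    key: MegatronModel._slice_value_for_microbatch(value, 0, 2, 2)
    for key, value in full_kwargs.items()
}
assert mb0['ref_outputs']['logps'].shape == (2, 16)  # first microbatch slice
assert mb0['temperature'] == 1.0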
@@ -401,8 +433,6 @@ def forward_backward(self, else: seq_length = original_seq_length - if 'ref_outputs' in kwargs: - breakpoint() num_microbatches = len(inputs) loss_extra_kwargs_per_mb = [] if num_microbatches <= 1: @@ -411,17 +441,10 @@ def forward_backward(self, for mb_idx in range(num_microbatches): mb_start = mb_idx * micro_batch_size mb_end = mb_start + micro_batch_size - mb_kwargs = {} - for key, value in kwargs.items(): - if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size: - mb_kwargs[key] = value[mb_start:mb_end] - else: - # Scalars, small tensors, or non-sliceable values pass through as-is - mb_kwargs[key] = value + mb_kwargs = { + key: self._slice_value_for_microbatch(value, mb_start, mb_end, micro_batch_size) + for key, value in kwargs.items() + } loss_extra_kwargs_per_mb.append(mb_kwargs) _mb_counter = [0] # mutable counter for closure diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 1b018773..5d543371 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,7 +10,7 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device +from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, pad_and_stack_tensors from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 4a721a2a..1c4d8462 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -190,3 +190,36 @@ def stateless_init_process_group( communicator = Communicator(pg, device=device) return communicator + + +def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor': + import torch + if not tensors: + raise ValueError('Empty tensor list') + + if len(tensors) == 1: + return tensors[0] + + max_ndim = max(t.ndim for t in tensors) + expanded_tensors = [] + for t in tensors: + while t.ndim < max_ndim: + t = t.unsqueeze(0) + expanded_tensors.append(t) + + max_shape = [] + for dim in range(max_ndim): + max_shape.append(max(t.shape[dim] for t in expanded_tensors)) + + padded_tensors = [] + for t in expanded_tensors: + if list(t.shape) == max_shape: + padded_tensors.append(t) + else: + pad_params = [] + for dim in range(max_ndim - 1, -1, -1): + pad_params.extend([0, max_shape[dim] - t.shape[dim]]) + padded = torch.nn.functional.pad(t, pad_params, value=pad_value) + padded_tensors.append(padded) + + return torch.cat(padded_tensors, dim=0) \ No newline at end of file From 20fde359a47ba4cfe2e4633f16a44e186f479fcb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 18:11:03 +0800 Subject: [PATCH 22/31] wip --- cookbook/rl/dpo.sh | 84 ------- cookbook/rl/dpo_lora.py | 25 +- cookbook/transformers/fsdp2.py | 58 ++--- src/twinkle/metric/dpo.py | 2 +- 
src/twinkle/preprocessor/dpo.py | 428 -------------------------------- 5 files changed, 27 insertions(+), 570 deletions(-) delete mode 100644 cookbook/rl/dpo.sh diff --git a/cookbook/rl/dpo.sh b/cookbook/rl/dpo.sh deleted file mode 100644 index 65206839..00000000 --- a/cookbook/rl/dpo.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -# DPO Training Script for Ray Mode -# -# This script launches DPO (Direct Preference Optimization) training using Ray -# for distributed training across multiple GPUs. -# -# Usage: -# ./dpo.sh # Default settings (8 GPUs: 4 policy + 4 ref) -# ./dpo.sh simpo # Use SimPO (no reference model needed) -# ./dpo.sh orpo # Use ORPO (no reference model needed) -# -# Environment variables can be set to customize training: -# MODEL_ID - Model to train (default: ms://Qwen/Qwen3.5-4B) -# DATASET_ID - Preference dataset (default: UltraFeedback) -# MODEL_GPUS - GPUs for policy model (default: 4) -# REF_MODEL_GPUS - GPUs for reference model (default: 4) -# USE_REFERENCE_MODEL - Use reference model (default: 1) -# BATCH_SIZE - Global batch size (default: 8) -# MAX_STEPS - Training steps (default: 1000) -# LR - Learning rate (default: 5e-6) -# DPO_BETA - DPO beta parameter (default: 0.1) -# LOSS_TYPE - Loss variant: sigmoid/hinge/ipo/simpo/orpo/cpo (default: sigmoid) - -set -e - -# Parse command line argument for loss type -LOSS_TYPE_ARG=${1:-sigmoid} - -# Set default environment variables if not already set -export MODEL_ID=${MODEL_ID:-"ms://Qwen/Qwen3.5-4B"} -export DATASET_ID=${DATASET_ID:-"ms://argilla/ultrafeedback-binarized-preferences-cleaned"} -export MODEL_GPUS=${MODEL_GPUS:-4} -export BATCH_SIZE=${BATCH_SIZE:-8} -export MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2} -export MAX_STEPS=${MAX_STEPS:-1000} -export LR=${LR:-5e-6} -export DPO_BETA=${DPO_BETA:-0.1} -export SAVE_STEPS=${SAVE_STEPS:-100} -export MAX_LENGTH=${MAX_LENGTH:-2048} - -# Set loss type from argument or environment -export LOSS_TYPE=${LOSS_TYPE:-$LOSS_TYPE_ARG} - -# Reference-free losses don't need reference model -if [[ "$LOSS_TYPE" == "simpo" || "$LOSS_TYPE" == "orpo" || "$LOSS_TYPE" == "cpo" ]]; then - export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-0} - export REF_MODEL_GPUS=${REF_MODEL_GPUS:-0} - echo "Using $LOSS_TYPE loss (reference-free)" -else - export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-1} - export REF_MODEL_GPUS=${REF_MODEL_GPUS:-4} - echo "Using $LOSS_TYPE loss with reference model" -fi - -# Calculate total GPUs -if [[ "$USE_REFERENCE_MODEL" == "1" && "$REF_MODEL_GPUS" -gt 0 ]]; then - TOTAL_GPUS=$((MODEL_GPUS + REF_MODEL_GPUS)) -else - TOTAL_GPUS=$MODEL_GPUS -fi - -echo "==========================================" -echo "DPO Training Configuration" -echo "==========================================" -echo "Model: $MODEL_ID" -echo "Dataset: $DATASET_ID" -echo "Loss Type: $LOSS_TYPE" -echo "DPO Beta: $DPO_BETA" -echo "Policy GPUs: $MODEL_GPUS" -echo "Reference GPUs: $REF_MODEL_GPUS" -echo "Total GPUs: $TOTAL_GPUS" -echo "Batch Size: $BATCH_SIZE" -echo "Micro Batch Size: $MICRO_BATCH_SIZE" -echo "Max Steps: $MAX_STEPS" -echo "Learning Rate: $LR" -echo "Max Length: $MAX_LENGTH" -echo "Save Steps: $SAVE_STEPS" -echo "==========================================" - -# Get script directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -# Run training -python "$SCRIPT_DIR/dpo.py" diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 914e91ea..513f7093 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -58,20 +58,20 @@ from twinkle.dataset import 
Dataset, DatasetMeta from twinkle.loss import DPOLoss from twinkle.metric import DPOMetric -from twinkle.model import MultiLoraMegatronModel +from twinkle.model import MegatronModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) @@ -85,7 +85,7 @@ def create_dpo_dataset(): """Create DPO dataset with positive/negative format.""" - dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(30000))) + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(6000))) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( EmojiDPOProcessor, @@ -137,7 +137,7 @@ def main(): DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), ] - policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, pp_size=2, cp_size=2, tp_size=2) + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2) twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups) # ── DataLoader Setup ────────────────────────────────────────────────────── @@ -152,20 +152,18 @@ def main(): # ── Policy Model Setup with LoRA ────────────────────────────────────────── lora_config = LoraConfig( target_modules='all-linear', - r=16, + r=8, lora_alpha=32, lora_dropout=0.05, ) - policy_model = MultiLoraMegatronModel( + policy_model = MegatronModel( model_id=MODEL_ID, device_mesh=policy_mesh, remote_group='policy', ) MAX_STEPS = len(dataloader) policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - # policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) - # policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, adapter_name=ADAPTER_NAME) policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME) @@ -205,16 +203,17 @@ def main(): # Gradient clipping and optimizer step policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME) - optim_step += 1 # Logging - if optim_step % 1 == 0: + if optim_step % 16 == 0: metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME) + + optim_step += 1 # ── Save Final Checkpoint ───────────────────────────────────────────────── 
logger.info(f'Training completed. Total steps: {optim_step}') diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 9aa94137..10d75df6 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -7,11 +7,10 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -from twinkle.data_format import Message, Trajectory -from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor +from twinkle.preprocessor import SelfCognitionProcessor # Construct a device_mesh, dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=8) +device_mesh = DeviceMesh.from_sizes(dp_size=2) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) @@ -21,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -32,55 +31,19 @@ def eval(model): return metrics -class EmojiDPOProcessor(Preprocessor): - def __init__( - self, - system = 'You are a helpful assistant.', - chosen_key: str = 'answer_zh', - rejected_key: str = 'answer_en', - prompt_key: str = 'prompt', - ): - self.system = system - self.chosen_key = chosen_key - self.rejected_key = rejected_key - self.prompt_key = prompt_key - - def __call__(self, rows): - rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows - - def preprocess(self, row): - """Process a single row.""" - prompt = row.get(self.prompt_key, '') - chosen = row.get(self.chosen_key, '') - rejected = row.get(self.rejected_key, '') - - prompt_messages = [] - if self.system: - prompt_messages.append(Message(role='system', content=self.system)) - prompt_messages.append(Message(role='user', content=prompt)) - - chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] - rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - - return Trajectory(messages=chosen_messages) - - def train(): # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')) + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format - dataset.map(EmojiDPOProcessor) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset dataset.encode() # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct') + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') @@ -109,6 +72,13 @@ def train(): # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + if step > 0 and step % 40 == 0: + metrics = 
eval(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = step + if loss_metric > float(metrics['loss']): + model.save(f'checkpoint-{step}') + loss_metric = float(metrics['loss']) model.save(f'last-checkpoint') diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index bd6bd5d0..4d8513d0 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -209,5 +209,5 @@ def calculate(self): results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}' results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}' results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%' - + self.reset() return results diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index 0a03c4ad..c299d74e 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -15,390 +15,6 @@ from .base import Preprocessor -class DPOProcessor(Preprocessor): - """Generic DPO preference data preprocessor. - - Converts preference data with chosen/rejected pairs into positive/negative Trajectories. - Supports multiple common dataset formats. - - Expected input format (one of): - 1. {'prompt': str, 'chosen': str, 'rejected': str} - 2. {'prompt': str, 'chosen': List[Message], 'rejected': List[Message]} - 3. {'messages': List[Message], 'chosen': str, 'rejected': str} - 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) - - Output format: - - positive: Trajectory with chosen response - - negative: Trajectory with rejected response - - Args: - system: Optional system prompt to prepend. - chosen_key: Key for chosen response (default: 'chosen'). - rejected_key: Key for rejected response (default: 'rejected'). - prompt_key: Key for prompt/question (default: 'prompt'). - messages_key: Key for conversation messages (default: 'messages'). 
- """ - - def __init__( - self, - system: Optional[str] = None, - chosen_key: str = 'chosen', - rejected_key: str = 'rejected', - prompt_key: str = 'prompt', - messages_key: str = 'messages', - ): - self.system = system - self.chosen_key = chosen_key - self.rejected_key = rejected_key - self.prompt_key = prompt_key - self.messages_key = messages_key - - def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: - """Parse response into list of Messages.""" - if isinstance(response, str): - return [Message(role='assistant', content=response)] - elif isinstance(response, list): - messages = [] - for msg in response: - if isinstance(msg, Message): - messages.append(msg) - elif isinstance(msg, dict): - messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', ''))) - return messages - return [Message(role='assistant', content=str(response))] - - def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: - """Build prompt messages from row data.""" - messages = [] - - # Add system message if provided - if self.system: - messages.append(Message(role='system', content=self.system)) - - # Check for messages field (conversation format) - if self.messages_key in row and row[self.messages_key]: - raw_messages = row[self.messages_key] - for msg in raw_messages: - if isinstance(msg, Message): - messages.append(msg) - elif isinstance(msg, dict): - messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) - return messages - - # Check for prompt field - if self.prompt_key in row and row[self.prompt_key]: - prompt = row[self.prompt_key] - if isinstance(prompt, str): - messages.append(Message(role='user', content=prompt)) - elif isinstance(prompt, list): - for msg in prompt: - if isinstance(msg, Message): - messages.append(msg) - elif isinstance(msg, dict): - messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) - - return messages - - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process a single row into positive/negative Trajectories. - - Returns: - Dict with 'positive' and 'negative' Trajectory. - """ - # Build prompt messages - prompt_messages = self._build_prompt_messages(row) - - # Get chosen response - chosen_raw = row.get(self.chosen_key, '') - chosen_response = self._parse_response(chosen_raw) - - # Get rejected response - rejected_raw = row.get(self.rejected_key, '') - rejected_response = self._parse_response(rejected_raw) - - # Build full message lists - chosen_messages = prompt_messages + chosen_response - rejected_messages = prompt_messages + rejected_response - - return { - 'positive': Trajectory(messages=chosen_messages), - 'negative': Trajectory(messages=rejected_messages), - } - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Process batched data into DPO format.""" - rows = self.map_col_to_row(rows) - results = [self.preprocess(row) for row in rows] - # Collect all positive and negative trajectories - positive_list = [r['positive'] for r in results] - negative_list = [r['negative'] for r in results] - return { - 'positive': positive_list, - 'negative': negative_list, - } - - -class HHRLHFProcessor(Preprocessor): - """Preprocessor for Anthropic HH-RLHF dataset format. - - HH-RLHF format: - {'chosen': "Human: ... Assistant: ...", 'rejected': "Human: ... Assistant: ..."} - - The conversations use "Human:" and "Assistant:" prefixes. 
- """ - - def __init__(self, system: Optional[str] = None): - self.system = system - - def _parse_hh_conversation(self, text: str) -> List[Message]: - """Parse HH-RLHF style conversation text into Messages.""" - messages = [] - - if self.system: - messages.append(Message(role='system', content=self.system)) - - # Split by Human/Assistant markers - parts = text.split('\n\nHuman: ') - for i, part in enumerate(parts): - if i == 0 and not part.startswith('Human: '): - if part.strip(): - if part.startswith('Human: '): - part = part[7:] - messages.append(Message(role='user', content=part.strip())) - continue - - # Split Human and Assistant parts - if '\n\nAssistant: ' in part: - human_part, assistant_part = part.split('\n\nAssistant: ', 1) - messages.append(Message(role='user', content=human_part.strip())) - messages.append(Message(role='assistant', content=assistant_part.strip())) - else: - messages.append(Message(role='user', content=part.strip())) - - return messages - - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process HH-RLHF format row.""" - chosen_text = row.get('chosen', '') - rejected_text = row.get('rejected', '') - - chosen_messages = self._parse_hh_conversation(chosen_text) - rejected_messages = self._parse_hh_conversation(rejected_text) - - return { - 'positive': Trajectory(messages=chosen_messages), - 'negative': Trajectory(messages=rejected_messages), - } - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - rows = self.map_col_to_row(rows) - results = [self.preprocess(row) for row in rows] - return { - 'positive': [r['positive'] for r in results], - 'negative': [r['negative'] for r in results], - } - - -class UltraFeedbackProcessor(Preprocessor): - """Preprocessor for UltraFeedback dataset format. - - UltraFeedback format: - { - 'instruction': str, - 'completions': [ - {'response': str, 'overall_score': float, ...}, - ... - ] - } - - Selects highest and lowest scored completions as chosen/rejected. 
- """ - - def __init__( - self, - system: Optional[str] = None, - instruction_key: str = 'instruction', - completions_key: str = 'completions', - response_key: str = 'response', - score_key: str = 'overall_score', - ): - self.system = system - self.instruction_key = instruction_key - self.completions_key = completions_key - self.response_key = response_key - self.score_key = score_key - - def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: - """Process UltraFeedback format row.""" - instruction = row.get(self.instruction_key, '') - completions = row.get(self.completions_key, []) - - if len(completions) < 2: - return None - - # Sort by score - scored_completions = [ - (c.get(self.score_key, 0), c.get(self.response_key, '')) - for c in completions - if c.get(self.response_key) - ] - - if len(scored_completions) < 2: - return None - - scored_completions.sort(key=lambda x: x[0], reverse=True) - chosen_response = scored_completions[0][1] - rejected_response = scored_completions[-1][1] - - # Build messages - prompt_messages = [] - if self.system: - prompt_messages.append(Message(role='system', content=self.system)) - prompt_messages.append(Message(role='user', content=instruction)) - - chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] - rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - - return { - 'positive': Trajectory(messages=chosen_messages), - 'negative': Trajectory(messages=rejected_messages), - } - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - rows = self.map_col_to_row(rows) - results = [self.preprocess(row) for row in rows] - results = [r for r in results if r is not None] - if not results: - return {} - return { - 'positive': [r['positive'] for r in results], - 'negative': [r['negative'] for r in results], - } - - -class ShareGPTDPOProcessor(Preprocessor): - """Preprocessor for ShareGPT-style DPO datasets. - - Expected format: - { - 'conversations': [ - {'from': 'human', 'value': '...'}, - {'from': 'gpt', 'value': '...'}, - ... 
- ], - 'chosen': {'from': 'gpt', 'value': '...'}, - 'rejected': {'from': 'gpt', 'value': '...'} - } - """ - - ROLE_MAPPING = { - 'human': 'user', - 'gpt': 'assistant', - 'system': 'system', - 'user': 'user', - 'assistant': 'assistant', - } - - def __init__(self, system: Optional[str] = None): - self.system = system - - def _parse_sharegpt_message(self, msg: Dict) -> Message: - """Parse ShareGPT format message.""" - role = self.ROLE_MAPPING.get(msg.get('from', ''), 'user') - content = msg.get('value', '') or msg.get('content', '') - return Message(role=role, content=content) - - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process ShareGPT DPO format row.""" - conversations = row.get('conversations', []) - - # Build prompt messages - prompt_messages = [] - if self.system: - prompt_messages.append(Message(role='system', content=self.system)) - - for msg in conversations: - prompt_messages.append(self._parse_sharegpt_message(msg)) - - # Remove last message if it's assistant (will be replaced) - if prompt_messages and prompt_messages[-1]['role'] == 'assistant': - prompt_messages = prompt_messages[:-1] - - # Get chosen and rejected - chosen_msg = row.get('chosen', {}) - rejected_msg = row.get('rejected', {}) - - if isinstance(chosen_msg, dict): - chosen_content = chosen_msg.get('value', '') or chosen_msg.get('content', '') - else: - chosen_content = str(chosen_msg) - - if isinstance(rejected_msg, dict): - rejected_content = rejected_msg.get('value', '') or rejected_msg.get('content', '') - else: - rejected_content = str(rejected_msg) - - chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] - rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] - - return { - 'positive': Trajectory(messages=chosen_messages), - 'negative': Trajectory(messages=rejected_messages), - } - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - rows = self.map_col_to_row(rows) - results = [self.preprocess(row) for row in rows] - return { - 'positive': [r['positive'] for r in results], - 'negative': [r['negative'] for r in results], - } - - -class IntelOrcaDPOProcessor(Preprocessor): - """Preprocessor for Intel ORCA DPO dataset format. 
- - Expected format: - { - 'system': str, - 'question': str, - 'chosen': str, - 'rejected': str - } - """ - - def __init__(self, default_system: Optional[str] = None): - self.default_system = default_system - - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process Intel ORCA DPO format row.""" - system = row.get('system', self.default_system) - question = row.get('question', '') - chosen = row.get('chosen', '') - rejected = row.get('rejected', '') - - prompt_messages = [] - if system: - prompt_messages.append(Message(role='system', content=system)) - prompt_messages.append(Message(role='user', content=question)) - - chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] - rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - - return { - 'positive': Trajectory(messages=chosen_messages), - 'negative': Trajectory(messages=rejected_messages), - } - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - rows = self.map_col_to_row(rows) - results = [self.preprocess(row) for row in rows] - return { - 'positive': [r['positive'] for r in results], - 'negative': [r['negative'] for r in results], - } - - class EmojiDPOProcessor(Preprocessor): """Preprocessor for shareAI/DPO-zh-en-emoji dataset format. @@ -458,47 +74,3 @@ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: 'positive': [r['positive'] for r in results], 'negative': [r['negative'] for r in results], } - - -class UltraFeedbackKTOProcessor(Preprocessor): - """Preprocessor for ultrafeedback-binarized-preferences-cleaned-kto dataset. - - Dataset format: - { - 'prompt': str, - 'completion': str, - 'label': bool, # True for chosen, False for rejected - } - - For KTO training, we need (prompt, completion, label) format. - The label is stored in user_data. - - Args: - system: Optional system prompt. - """ - - def __init__(self, system: Optional[str] = None): - self.system = system - - def preprocess(self, row: Dict[str, Any]) -> Trajectory: - """Process a single row for KTO.""" - prompt = row.get('prompt', '') - completion = row.get('completion', '') - label = row.get('label', True) - - messages = [] - if self.system: - messages.append(Message(role='system', content=self.system)) - messages.append(Message(role='user', content=prompt)) - messages.append(Message(role='assistant', content=completion)) - - return Trajectory( - messages=messages, - user_data=[('kto_label', label)] - ) - - def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows From 91cad8062e688c95b149ac2a886873096c2d30b8 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 18:26:27 +0800 Subject: [PATCH 23/31] wip --- cookbook/rl/{dpo.py => dpo_full.py} | 84 +++++++++++++------------ cookbook/rl/dpo_lora.py | 86 +++++++++++++++----------- src/twinkle/model/megatron/megatron.py | 3 +- 3 files changed, 96 insertions(+), 77 deletions(-) rename cookbook/rl/{dpo.py => dpo_full.py} (81%) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo_full.py similarity index 81% rename from cookbook/rl/dpo.py rename to cookbook/rl/dpo_full.py index 6fa90e91..c0ed0f13 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo_full.py @@ -1,13 +1,15 @@ -"""DPO (Direct Preference Optimization) Training via Ray. +"""DPO (Direct Preference Optimization) Full-Parameter Training via Ray. 
Off-policy preference alignment: trains the model to prefer chosen responses over rejected responses using preference data, without explicit reward modeling. +Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag. + Pipeline: 1. Load preference dataset with chosen/rejected pairs. 2. Encode positive and negative separately. 3. Compute reference model log probabilities (frozen). - 4. Train policy model using DPO loss. + 4. Train policy model using DPO loss (full-parameter, no LoRA). Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ @@ -28,14 +30,15 @@ set REF_MODEL_GPUS=0 to skip reference model computation. Environment variables (all optional): - MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) + USE_MEGATRON – Use Megatron backend (default: 0, use Transformers) + MODEL_ID – (default: ms://Qwen/Qwen3-4B) DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) - LR – learning rate (default: 5e-6) + LR – learning rate (default: 1e-5) DPO_BETA – DPO temperature parameter (default: 0.1) LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) SAVE_STEPS – checkpoint save interval (default: 100) @@ -49,9 +52,7 @@ """ import os -from typing import Any, Dict, List, Optional - -from peft import LoraConfig +from typing import Any, Dict, List import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger @@ -60,30 +61,29 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.metric import DPOMetric -from twinkle.model import TransformersModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct') +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2)) -REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 2)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2)) # Number of preference pairs +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) -LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # TRL default for DPO is 5e-7 to 5e-6 +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) -ADAPTER_NAME = 'default' SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful 
assistant.') @@ -162,7 +162,20 @@ def main(): DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), ] - policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + # Configure device mesh based on backend + if USE_MEGATRON: + # Megatron: dp=2, pp=2 for each model + from twinkle.model import MegatronModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=2, pp_size=2) + ModelClass = MegatronModel + else: + # Transformers: fsdp=2, dp=2 for each model + from twinkle.model import TransformersModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=2, dp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, fsdp_size=2, dp_size=2) + ModelClass = TransformersModel + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) # ── DataLoader Setup ────────────────────────────────────────────────────── @@ -172,31 +185,29 @@ def main(): min_batch_size=BATCH_SIZE, device_mesh=policy_mesh, ) - length = len(dataloader) # ── Policy Model Setup ──────────────────────────────────────────────────── - lora_config = LoraConfig( - target_modules='all-linear', - r=16, - lora_alpha=32, - lora_dropout=0.05, - ) - - policy_model = TransformersModel( + policy_model = ModelClass( model_id=MODEL_ID, device_mesh=policy_mesh, remote_group='policy', ) MAX_STEPS = len(dataloader) - policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) - policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) # Determine if we need reference model based on loss type reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] # Set up loss function and metrics loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False) + + # Configure optimizer based on backend (full-parameter training) + if USE_MEGATRON: + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) + else: + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + policy_model.set_loss(loss_fn) policy_model.add_metric(DPOMetric, beta=DPO_BETA) policy_model.set_processor(InputProcessor) @@ -205,8 +216,7 @@ def main(): # ── Reference Model Setup ───────────────────────────────────────────────── ref_model = None if not reference_free: - ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS) - ref_model = TransformersModel( + ref_model = ModelClass( model_id=MODEL_ID, device_mesh=ref_mesh, remote_group='reference', @@ -218,8 +228,9 @@ def main(): logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') optim_step = 0 + backend_name = 'Megatron' if USE_MEGATRON else 'Transformers' logger.info(get_device_placement()) - logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}') + logger.info(f'Starting DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}') # ── Training Loop ───────────────────────────────────────────────────────── for batch in dataloader: @@ -232,20 +243,15 @@ def main(): ref_outputs = ref_model.forward_only(inputs=dpo_batch) # Forward-backward pass with DPO loss - # 
ref_outputs is passed to loss which extracts logps internally - policy_model.forward_backward( - inputs=dpo_batch, - ref_outputs=ref_outputs, - ) - - # Gradient clipping and optimizer step + policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs) policy_model.clip_grad_and_step() + optim_step += 1 # Logging - if optim_step % 1 == 0: + if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: metrics = policy_model.calculate_metric(is_training=True) - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 513f7093..e10559a4 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -1,25 +1,27 @@ """DPO (Direct Preference Optimization) Training with LoRA (Single GPU Group). LoRA-based DPO training: uses the base model (without LoRA adapter) as reference -model by calling forward_only with adapter_name=''. This eliminates the need for +model by calling forward_only with disable_lora=True. This eliminates the need for a separate reference model GPU group. +Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag. + Pipeline: 1. Load preference dataset with chosen/rejected pairs. 2. Encode positive and negative separately. - 3. Compute reference model log probabilities using base model (adapter_name=''). + 3. Compute reference model log probabilities using base model (disable_lora=True). 4. Train policy model (with LoRA adapter) using DPO loss. Architecture (Ray - Single Group): ┌─────────────────────────────────────────────────────────────────┐ │ Driver (CPU) │ │ dataloader ──► batched preference pairs │ - │ policy_model.forward_only(adapter_name='') ──► reference logps│ + │ policy_model.forward_only(disable_lora=True) ──► ref logps │ │ policy_model.forward_backward() ──► DPO loss + gradient │ └─────────────────────────────────────────────────────────────────┘ │ PolicyModel (with LoRA adapter) - - forward_only(adapter_name='') → base model inference (reference) + - forward_only(disable_lora=True) → base model inference (reference) - forward_backward() → LoRA adapter training (policy) DPO data format (after preprocessing): @@ -27,9 +29,10 @@ - negative: List[Trajectory] - rejected responses Environment variables (all optional): - MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct) + USE_MEGATRON – Use Megatron backend (default: 0, use Transformers) + MODEL_ID – (default: ms://Qwen/Qwen3-4B) DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) - MODEL_GPUS – GPUs for policy model (default: 4) + MODEL_GPUS – GPUs for policy model (default: 8) BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) @@ -58,20 +61,20 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import DPOLoss from twinkle.metric import DPOMetric -from twinkle.model import MegatronModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 
8)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) @@ -137,8 +140,19 @@ def main(): DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), ] - policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2) - twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups) + # Configure device mesh based on backend + if USE_MEGATRON: + # Megatron: dp=4, pp=2 + from twinkle.model import MegatronModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2) + ModelClass = MegatronModel + else: + # Transformers: fsdp=4, dp=2 + from twinkle.model import TransformersModel + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=4, dp_size=2) + ModelClass = TransformersModel + + twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) # ── DataLoader Setup ────────────────────────────────────────────────────── dataloader = DataLoader( @@ -147,7 +161,6 @@ def main(): min_batch_size=BATCH_SIZE, device_mesh=policy_mesh, ) - length = len(dataloader) # ── Policy Model Setup with LoRA ────────────────────────────────────────── lora_config = LoraConfig( @@ -157,15 +170,21 @@ def main(): lora_dropout=0.05, ) - policy_model = MegatronModel( + policy_model = ModelClass( model_id=MODEL_ID, device_mesh=policy_mesh, remote_group='policy', ) MAX_STEPS = len(dataloader) policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) - policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) - policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME) + + # Configure optimizer based on backend + if USE_MEGATRON: + policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME) + policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME) + else: + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) # Set up loss function and metrics loss_fn = DPOLoss( @@ -174,14 +193,16 @@ def main(): reference_free=False, # We use base model as reference via disable_lora=True sft_weight=SFT_WEIGHT, ) - policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME) - policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME) - policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME) - policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME) + + policy_model.set_loss(loss_fn) + policy_model.add_metric(DPOMetric, beta=DPO_BETA) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) optim_step = 0 + backend_name = 'Megatron' if USE_MEGATRON else 'Transformers' logger.info(get_device_placement()) - logger.info(f'Starting LoRA DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}') + logger.info(f'Starting LoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}') 
logger.info(f'Using base model (disable_lora=True) as reference model') # ── Training Loop ───────────────────────────────────────────────────────── @@ -191,33 +212,24 @@ def main(): # Get reference outputs using base model (without LoRA adapter) # disable_lora=True tells the model to skip LoRA and use base weights - ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True, adapter_name=ADAPTER_NAME) - # Forward-backward pass with DPO loss (using LoRA adapter) - # ref_outputs is passed to loss which extracts logps internally - policy_model.forward_backward( - inputs=dpo_batch, - ref_outputs=ref_outputs, - micro_batch_size=2, - adapter_name=ADAPTER_NAME - ) - - # Gradient clipping and optimizer step - policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME) + ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True) + policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs) + policy_model.clip_grad_and_step() + + optim_step += 1 # Logging - if optim_step % 16 == 0: - metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) + if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = policy_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: - policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME) - - optim_step += 1 + policy_model.save(f'dpo-lora-checkpoint-{optim_step}') # ── Save Final Checkpoint ───────────────────────────────────────────────── logger.info(f'Training completed. Total steps: {optim_step}') - policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME) + policy_model.save('dpo-lora-final-checkpoint') if __name__ == '__main__': diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index cd6c9f28..73341028 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -410,7 +410,8 @@ def forward_backward(self, assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' if micro_batch_size is None: - micro_batch_size = 1 + # Compatible with DPO + micro_batch_size = 2 inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths) # Get parallelism settings for sequence padding and splitting From c75e43f20da273c0ce5442dd5b0bf47ba0deba7c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 18:33:53 +0800 Subject: [PATCH 24/31] wip --- README.md | 1 + README_ZH.md | 1 + cookbook/rl/dpo_lora.py | 2 +- cookbook/transformers/fsdp2.py | 4 ++-- src/twinkle/preprocessor/__init__.py | 3 +-- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 046425d7..460ba92e 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Or use ModelScope's [official image](https://www.modelscope.cn/docs/intro/enviro ## Changelog +- 🎉2026-03-28 Support DPO training with both Transformers and Megatron backends. See [dpo_full.py](cookbook/rl/dpo_full.py) and [dpo_lora.py](cookbook/rl/dpo_lora.py). - 🎉2026-03-24 Twinkle Web site is now live at https://modelscope.github.io/twinkle-web/ - 🎉2026-03-19 Support GKD training ,please refer to this [cookbook](cookbook/rl/gkd_on_policy.py). - 🎉2026-02-13 Initial version of Twinkle✨ released, including SFT/PT/RL support for text models. 
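The two cookbooks referenced in the changelog entry above are driven entirely by environment variables rather than CLI flags. A minimal launch sketch (assuming an 8-GPU host and the defaults documented in the cookbook docstrings in this series):

# Full-parameter DPO, Transformers backend (default: 4 policy + 4 reference GPUs)
python cookbook/rl/dpo_full.py

# Full-parameter DPO on the Megatron backend
USE_MEGATRON=1 python cookbook/rl/dpo_full.py

# LoRA DPO on a single GPU group: forward_only(disable_lora=True) lets the
# frozen base weights serve as the reference model, so no second group is needed
python cookbook/rl/dpo_lora.py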
diff --git a/README_ZH.md b/README_ZH.md index 5b4d1191..3f817272 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -91,6 +91,7 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl ## 更新日志 +🎉2026-03-28 支持 DPO 训练,同时支持 Transformers 和 Megatron 后端。参考 [dpo_full.py](cookbook/rl/dpo_full.py) 和 [dpo_lora.py](cookbook/rl/dpo_lora.py)。 🎉2026-03-24 Twinkle 站点上线,访问地址 https://modelscope.github.io/twinkle-web/ 🎉2026-03-19 支持GKD蒸馏能力,参考[cookbook](cookbook/rl/gkd_on_policy.py)。 🎉2026-02-13 Twinkle✨ 初始版本发布,支持文本模型的SFT/PT/RL训练。我们还通过兼容Tinker的API,在魔搭社区上提供了无服务器训练功能。 diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index e10559a4..31ee8b9d 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -149,7 +149,7 @@ def main(): else: # Transformers: fsdp=4, dp=2 from twinkle.model import TransformersModel - policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=4, dp_size=2) + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, fsdp_size=2) ModelClass = TransformersModel twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 10d75df6..2ad55139 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -9,8 +9,8 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2) +# Construct a device_mesh, fsdp_size=2, dp=4 +device_mesh = DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 6d9f6dd7..3294e623 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import DataFilter, Preprocessor -from .dpo import (DPOProcessor, EmojiDPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, - UltraFeedbackKTOProcessor, UltraFeedbackProcessor) +from .dpo import EmojiDPOProcessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) From c2cb1dd5212333913f737b5f6dee1ba1e54b00a4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 19:02:48 +0800 Subject: [PATCH 25/31] wip --- cookbook/rl/dpo_full.py | 8 ++++---- cookbook/transformers/fsdp2.py | 6 +++--- src/twinkle/model/megatron/megatron.py | 2 ++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo_full.py index c0ed0f13..34cf1392 100644 --- a/cookbook/rl/dpo_full.py +++ b/cookbook/rl/dpo_full.py @@ -67,7 +67,7 @@ logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 1)) MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') @@ -89,7 +89,7 @@ def create_dpo_dataset(): """Create DPO dataset with positive/negative format.""" - dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(30000))) + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(6000))) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( EmojiDPOProcessor, @@ -167,13 +167,13 @@ def main(): # Megatron: dp=2, pp=2 for each model from twinkle.model import MegatronModel policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=2) - ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=2, pp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=4) ModelClass = MegatronModel else: # Transformers: fsdp=2, dp=2 for each model from twinkle.model import TransformersModel policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=2, dp_size=2) - ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, fsdp_size=2, dp_size=2) + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=4) ModelClass = TransformersModel twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 2ad55139..025c1743 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -68,7 +68,7 @@ def train(): 
model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + if step % 1 == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 73341028..8d8e8627 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -64,6 +64,8 @@ def __post_init__(self): def _get_lr(self): _lrs = [] + if self.optimizer is None: + return _lrs _default_lr = self.optimizer.chained_optimizers[0].config.lr for param_group in self.optimizer.param_groups: _lrs.append(param_group.get('lr', _default_lr)) From f2e26dd22532017484f3490d902fadb41437d855 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 19:43:28 +0800 Subject: [PATCH 26/31] fix --- cookbook/rl/dpo_full.py | 2 +- cookbook/rl/dpo_lora.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo_full.py index 34cf1392..cf55079c 100644 --- a/cookbook/rl/dpo_full.py +++ b/cookbook/rl/dpo_full.py @@ -67,7 +67,7 @@ logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 1)) +USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0)) MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 31ee8b9d..eea3b4fb 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -147,9 +147,10 @@ def main(): policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2) ModelClass = MegatronModel else: - # Transformers: fsdp=4, dp=2 + # Transformers: dp_size=8 + # FSDP2 forward_only & forward has problems with `with unwrapped_model.disable_adapter()` from twinkle.model import TransformersModel - policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, fsdp_size=2) + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=8) ModelClass = TransformersModel twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups) From 92896001bc516fe71bc4a8588fb63e383a48d436 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 19:47:00 +0800 Subject: [PATCH 27/31] lint --- src/twinkle/dataset/base.py | 6 ++-- src/twinkle/infra/collectors.py | 5 +-- src/twinkle/loss/dpo.py | 32 ++++++++----------- src/twinkle/metric/dpo.py | 12 +++---- src/twinkle/model/megatron/megatron.py | 10 +++--- src/twinkle/model/optimizer_group.py | 4 +-- .../transformers/multi_lora_transformers.py | 2 +- .../model/transformers/transformers.py | 9 ++++-- src/twinkle/template/base.py | 6 ++-- src/twinkle/utils/__init__.py | 3 +- src/twinkle/utils/device_mesh.py | 3 +- src/twinkle/utils/torch_utils.py | 2 +- 12 files changed, 41 insertions(+), 53 deletions(-) diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index fc14a030..6448cc2b 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -86,9 +86,9 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt) with processing_lock('dataset'): # use a default lock because encode is to all datasets - self.dataset = self.dataset.map(encode_fn, - **kwargs).filter(lambda 
batch: [True] * len(next(iter(batch.values()))) if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], - **kwargs) + self.dataset = self.dataset.map(encode_fn, **kwargs).filter( + lambda batch: [True] * len(next(iter(batch.values()))) + if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) @remote_function() def check(self, **kwargs): diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index fbe4f17f..aaa60819 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -1,12 +1,9 @@ import numpy as np -from typing import TYPE_CHECKING, Any, Dict, List +from typing import Any, Dict, List from twinkle import DeviceMesh from twinkle.utils import pad_and_stack_tensors -if TYPE_CHECKING: - import torch - def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) -> Dict[str, Any]: import torch diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index ada318c6..44d81f82 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union from twinkle.data_format import LossOutput -from twinkle.utils.torch_utils import selective_log_softmax from twinkle.loss.base import Loss +from twinkle.utils.torch_utils import selective_log_softmax if TYPE_CHECKING: import torch @@ -176,14 +176,10 @@ def _align_logps( # Truncate right (keep left part) - may happen in Ray result merging return logps[:, :target_seq_len] else: - raise ValueError( - f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). ' - f'This should not happen when both models process the same batch.' - ) + raise ValueError(f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). 
' + f'This should not happen when both models process the same batch.') - raise ValueError( - f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}' - ) + raise ValueError(f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}') def _compute_dpo_loss( self, @@ -227,7 +223,7 @@ def _compute_dpo_loss( elif self.loss_type == 'ipo': # IPO (Identity Preference Optimization) loss # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback" - losses = (logits - 1 / (2 * self.beta)) ** 2 + losses = (logits - 1 / (2 * self.beta))**2 elif self.loss_type == 'kto_pair': # KTO pair loss (simplified version) chosen_logratios_scaled = self.beta * chosen_logratios @@ -236,7 +232,7 @@ def _compute_dpo_loss( rejected_losses = F.sigmoid(rejected_logratios_scaled) losses = chosen_losses + rejected_losses else: - raise ValueError(f"Unknown loss_type: {self.loss_type}") + raise ValueError(f'Unknown loss_type: {self.loss_type}') # Apply label smoothing if specified if self.label_smoothing > 0: @@ -292,7 +288,7 @@ def __call__( labels = labels.unsqueeze(0) batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + assert batch_size % 2 == 0, 'Batch size must be even (chosen + rejected pairs)' # Get log probabilities from outputs logps = self._get_logps_from_outputs(outputs, labels) @@ -314,9 +310,7 @@ def __call__( reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) elif ref_logps is not None: # Per-token reference log probs provided, need to align and sum - ref_logps_aligned = self._align_logps( - ref_logps, labels.shape, device, dtype - ) + ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype) ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) @@ -392,7 +386,7 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - assert labels.shape[0] % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + assert labels.shape[0] % 2 == 0, 'Batch size must be even (chosen + rejected pairs)' # Get log probabilities logps = self._get_logps_from_outputs(outputs, labels) @@ -455,7 +449,7 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - assert labels.shape[0] % 2 == 0, "Batch size must be even" + assert labels.shape[0] % 2 == 0, 'Batch size must be even' # Get log probabilities logps = self._get_logps_from_outputs(outputs, labels) @@ -521,7 +515,7 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - assert labels.shape[0] % 2 == 0, "Batch size must be even" + assert labels.shape[0] % 2 == 0, 'Batch size must be even' # Get log probabilities logps = self._get_logps_from_outputs(outputs, labels) @@ -540,8 +534,8 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) # Use numerically stable computation - prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1-1e-7) - prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1-1e-7) + prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1 - 1e-7) + prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1 - 1e-7) log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) diff --git 
a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 4d8513d0..5ce61410 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -57,10 +57,8 @@ def _align_logps(self, logps, target_shape, device, dtype): if src_len == target_len: return logps elif src_len < target_len: - raise ValueError( - f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). ' - f'This should not happen when both models process the same batch.' - ) + raise ValueError(f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). ' + f'This should not happen when both models process the same batch.') else: return logps[:, :target_len] @@ -84,7 +82,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M logps = outputs.get('logps') if logps is None or len(logps) == 0: return - + if isinstance(logps, list) and logps: logps = pad_and_stack_tensors(logps) @@ -128,9 +126,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_logps = ref_outputs.get('logps') if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) - ref_logps = self._align_logps( - ref_logps, labels.shape, labels.device, logps.dtype - ) + ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 8d8e8627..36da85c1 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -25,12 +25,12 @@ from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, torch_util from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin from twinkle.data_format import InputFeature, ModelOutput, Trajectory -from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict from twinkle.loss import CrossEntropyLoss, Loss from twinkle.metric import LossMetric, Metric, TrainMetric from twinkle.model.base import TwinkleModel +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template @@ -435,7 +435,7 @@ def forward_backward(self, seq_length = original_seq_length + (divisor - original_seq_length % divisor) else: seq_length = original_seq_length - + num_microbatches = len(inputs) loss_extra_kwargs_per_mb = [] if num_microbatches <= 1: @@ -463,10 +463,12 @@ def post_loss_function(output_tensor, inputs, logps): if not counts: # Later will gather this value, so it becomes: # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens - # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) = gradient_accumulation_steps * world_size + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) + # = gradient_accumulation_steps * world_size # Then, grad will divided by this value: # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens - # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size ) = avg_per_token_grad + # 2. 
PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps))
+                #    / (gradient_accumulation_steps * world_size ) = avg_per_token_grad
                 counts = torch.tensor(1, device=losses.device)
             return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py
index 1a23d568..2052f7d6 100644
--- a/src/twinkle/model/optimizer_group.py
+++ b/src/twinkle/model/optimizer_group.py
@@ -1,9 +1,8 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
-
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
+from typing import Any, Dict, List, Optional
 
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature, ModelOutput
@@ -83,4 +82,3 @@ def calculate_metrics(self, is_training):
         status.inputs = None
         status.outputs = None
         return results
-
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index 9fe52c8d..77a25b18 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -8,9 +8,9 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
 
 from twinkle import DeviceMesh, remote_class, remote_function, template
-from twinkle.infra import collect_tensor_dict
 from twinkle.data_format import InputFeature, Trajectory
 from twinkle.hub import HubOperation
+from twinkle.infra import collect_tensor_dict
 from twinkle.loss import Loss
 from twinkle.metric import Metric
 from twinkle.processor import InputProcessor
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 22e3538d..6097ffe2 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -474,13 +474,16 @@ def calculate_loss(self, **kwargs):
                 counts = torch.tensor(1, device=loss_value.device)
             # This value will be gathered later, so it becomes:
             # 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size
-            # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size ) = gradient_accumulation_steps
+            # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size )
+            #    = gradient_accumulation_steps
             # Then, the grad will be divided by this value:
             # 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size)
             #    = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size)
             #    = global_sum_grad/global_num_tokens
-            # 2. PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps
-            #    = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) / gradient_accumulation_steps
+            # 2. 
PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) + # / gradient_accumulation_steps + # = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) + # / gradient_accumulation_steps # = global_per_token_grad / dp_world_size = avg_per_token_grad counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index a3e53056..acc17d2e 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -320,7 +320,7 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = input_ids = self.tokenizer.encode(text) encoded = {} labels = deepcopy(input_ids) - + input_feature = InputFeature( input_ids=np.array(input_ids), labels=np.array(labels), @@ -398,9 +398,7 @@ def batch_encode( for key in trajectories: if key in traj_keys: # Encode this trajectory list - result[key] = self.batch_encode( - trajectories[key], add_generation_prompt=add_generation_prompt - ) + result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) else: # Keep non-trajectory columns as-is result[key] = trajectories[key] diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 5d543371..ce91cbf8 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,7 +10,8 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, pad_and_stack_tensors +from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax, + stateless_init_process_group, to_device) from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py index d95d5da8..4393fed7 100644 --- a/src/twinkle/utils/device_mesh.py +++ b/src/twinkle/utils/device_mesh.py @@ -478,8 +478,7 @@ def get_collect_ranks(self) -> list[int]: for data_rank in range(data_ws): # Find all global ranks that map to this data_rank candidates = [ - r for r in self.mesh.flatten().tolist() - if self.get_data_rank_from_global_rank(r) == data_rank + r for r in self.mesh.flatten().tolist() if self.get_data_rank_from_global_rank(r) == data_rank ] if not candidates: continue diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 1c4d8462..34506d6f 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -222,4 +222,4 @@ def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200 padded = torch.nn.functional.pad(t, pad_params, value=pad_value) padded_tensors.append(padded) - return torch.cat(padded_tensors, dim=0) \ No newline at end of file + return torch.cat(padded_tensors, dim=0) From 78bd1473e003298fbae21ffa82809ebb4dbc6328 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 19:58:57 +0800 Subject: [PATCH 28/31] fix --- cookbook/rl/dpo_full.py | 10 +--------- cookbook/rl/dpo_lora.py | 8 -------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo_full.py index 
cf55079c..8c3e5a6f 100644 --- a/cookbook/rl/dpo_full.py +++ b/cookbook/rl/dpo_full.py @@ -36,19 +36,12 @@ MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) BATCH_SIZE – global batch size (preference pairs) (default: 8) - MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) LR – learning rate (default: 1e-5) DPO_BETA – DPO temperature parameter (default: 0.1) LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) SAVE_STEPS – checkpoint save interval (default: 100) MAX_LENGTH – max sequence length (default: 2048) - - Dataset field mapping (for custom datasets): - PROMPT_KEY – key for prompt field (default: 'prompt') - CHOSEN_KEY – key for chosen response (default: 'answer_zh') - REJECTED_KEY – key for rejected response (default: 'answer_en') - SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.') """ import os @@ -76,7 +69,6 @@ NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-5)) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) @@ -198,7 +190,7 @@ def main(): reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] # Set up loss function and metrics - loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False) + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=reference_free) # Configure optimizer based on backend (full-parameter training) if USE_MEGATRON: diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index eea3b4fb..861e72a4 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -34,19 +34,12 @@ DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 8) BATCH_SIZE – global batch size (preference pairs) (default: 8) - MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) LR – learning rate (default: 1e-4) DPO_BETA – DPO temperature parameter (default: 0.1) LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid) SAVE_STEPS – checkpoint save interval (default: 100) MAX_LENGTH – max sequence length (default: 2048) - - Dataset field mapping (for custom datasets): - PROMPT_KEY – key for prompt field (default: 'prompt') - CHOSEN_KEY – key for chosen response (default: 'answer_zh') - REJECTED_KEY – key for rejected response (default: 'answer_en') - SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.') """ import os @@ -74,7 +67,6 @@ MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2)) LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4) DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) From deaf96a2fc54cb2a4c7fbaa13a008aedddb7f7bc Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 20:00:18 +0800 Subject: [PATCH 29/31] fix --- cookbook/transformers/fsdp2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 025c1743..7b6bd2a8 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,9 +1,8 @@ -import os from peft import LoraConfig from tqdm import tqdm import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel @@ -68,7 +67,7 @@ def train(): model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 1 == 0: + if step % 20 == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') From 30b6411bef7a134f0fb7530285550d51efb1f250 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 29 Mar 2026 20:03:45 +0800 Subject: [PATCH 30/31] fix ga step --- src/twinkle/model/optimizer_group.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 2052f7d6..5fdb89f7 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -52,7 +52,8 @@ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> boo gradient_accumulation_steps = self.gradient_accumulation_steps else: self.gradient_accumulation_steps = gradient_accumulation_steps - return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1 + return gradient_accumulation_steps == 1 or ((self.cur_step - 1) % gradient_accumulation_steps == 0 + and self.cur_step > 1) def _get_lr(self): """Get learning rates from optimizer. Override in subclass.""" From fb873912d2377a179bd27200a8ec921991849c0d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 13:55:37 +0800 Subject: [PATCH 31/31] fix --- README.md | 4 ++-- README_ZH.md | 4 ++-- src/twinkle/model/megatron/megatron.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 460ba92e..cd7eccfd 100644 --- a/README.md +++ b/README.md @@ -138,8 +138,8 @@ supported on Twinkle✨ framework. | qwen3 series | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | 0.6B/1.7B/4B/8B/14B | transformers>=4.51 | ✔ | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) | | | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen3-32B) | 0.6B/1.7B/4B/8B/14B/32B | transformers>=4.51 | ✔ | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | | qwen3_moe series | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | 30B-A3B/A3B-Base,235B-A22B | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) | -| qwen3.5 moe series | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. | transformers>=5.20 | ✔ | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | -| qwen3.5 series | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.20 | ✔ | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | +| qwen3.5 moe series | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. 
| transformers>=5.2.0 | ✔ | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | +| qwen3.5 series | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.2.0 | ✔ | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | | qwen2 series | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | 0.5B/1.5B/7B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | | | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | 0.5B/1.5B/7B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) | | | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | 0.5B/1.5B/3B/7B/14B/32B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | diff --git a/README_ZH.md b/README_ZH.md index 3f817272..e404508a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -121,8 +121,8 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl | qwen3 全系列 | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | 0.6B/1.7B/4B/8B/14B | transformers>=4.51 | ✔ | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) | | | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen3-32B) | 0.6B/1.7B/4B/8B/14B/32B | transformers>=4.51 | ✔ | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) | | qwen3_moe 全系列 | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | 30B-A3B/A3B-Base,235B-A22B | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) | -| qwen3.5 moe 全系列 | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. | transformers>=5.20 | ✔ | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | -| qwen3.5 全系列 | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.20 | ✔ | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | +| qwen3.5 moe 全系列 | [Qwen/Qwen3.5-35B-A3B](https://www.modelscope.cn/models/Qwen/Qwen3.5-35B-A3B) | 35B-A3B,122B-A10B, etc. 
| transformers>=5.2.0 | ✔ | [Qwen/Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B) | +| qwen3.5 全系列 | [Qwen/Qwen3.5-9B](https://www.modelscope.cn/models/Qwen/Qwen3.5-9B) | 2B ~ 27B | transformers>=5.2.0 | ✔ | [Qwen/Qwen3.5-9B](https://huggingface.co/Qwen/Qwen3.5-9B) | | qwen2 全系列 | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | 0.5B/1.5B/7B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | | | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | 0.5B/1.5B/7B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) | | | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | 0.5B/1.5B/3B/7B/14B/32B/72B | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 36da85c1..d5923f86 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -413,7 +413,7 @@ def forward_backward(self, if micro_batch_size is None: # Compatible with DPO - micro_batch_size = 2 + micro_batch_size = min(2, len(inputs)) inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths) # Get parallelism settings for sequence padding and splitting
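
Several hunks in this series reduce per-token log probabilities to one scalar per sequence (via a `_compute_sequence_logps` helper) before splitting the chosen and rejected halves. The sketch below restates that shape in isolation; the signature and the -100 masking convention are assumptions for illustration, not the library's exact API.

    import torch

    def compute_sequence_logps(per_token_logps: torch.Tensor,
                               labels: torch.Tensor,
                               ignore_index: int = -100) -> torch.Tensor:
        # Keep only supervised positions, then sum token log probs per sequence.
        mask = (labels != ignore_index).to(per_token_logps.dtype)
        return (per_token_logps * mask).sum(dim=-1)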
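For readers tracing the DPO loss hunks, this is the standard sigmoid DPO objective on sequence-level log probs, given the [chosen_1..chosen_n, rejected_1..rejected_n] batch layout the collate function produces. It is a minimal standalone sketch, not the twinkle `DPOLoss` class, which additionally handles label smoothing, reference-free mode, per-token alignment, and the other loss variants.

    import torch
    import torch.nn.functional as F

    def dpo_sigmoid_loss(policy_logps: torch.Tensor,
                         ref_logps: torch.Tensor,
                         beta: float = 0.1) -> torch.Tensor:
        n = policy_logps.shape[0] // 2  # first half chosen, second half rejected
        pi_logratios = policy_logps[:n] - policy_logps[n:]
        ref_logratios = ref_logps[:n] - ref_logps[n:]
        logits = pi_logratios - ref_logratios
        # -log(sigmoid(beta * logits)): pushes chosen above rejected relative to the reference.
        return -F.logsigmoid(beta * logits).mean()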
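The do_grad_sync change in PATCH 30/31 can be checked in isolation: with gradient_accumulation_steps == 1 the old modulo test returned False on the very first step, so that case must short-circuit to True. Below is a self-contained restatement of the predicate, assuming cur_step is 1-indexed and incremented once per micro-batch; the free function is illustrative only, not the class method itself.

    def do_grad_sync(cur_step: int, gradient_accumulation_steps: int) -> bool:
        return gradient_accumulation_steps == 1 or (
            (cur_step - 1) % gradient_accumulation_steps == 0 and cur_step > 1)

    # ga=1 syncs every step; ga=2 syncs on steps 3, 5, ...
    assert [do_grad_sync(s, 1) for s in range(1, 4)] == [True, True, True]
    assert [do_grad_sync(s, 2) for s in range(1, 6)] == [False, False, True, False, True]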