From 8fb999179f490efa8185e749c01e8048c2636a50 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Mon, 30 Mar 2026 16:43:44 +0800
Subject: [PATCH 01/18] wip

---
 pyproject.toml                                |    2 +-
 src/twinkle/model/megatron/args.py            |  692 -------
 src/twinkle/model/megatron/megatron.py        |  150 +-
 src/twinkle/model/megatron/model/__init__.py  |    4 -
 src/twinkle/model/megatron/model/constant.py  |   39 -
 .../model/megatron/model/gpt_bridge.py        | 1651 -----------------
 src/twinkle/model/megatron/model/gpt_model.py |  465 -----
 .../model/megatron/model/gpts/__init__.py     |   14 -
 .../model/megatron/model/gpts/qwen3_next.py   |  512 -----
 .../model/megatron/model/mm_gpt_model.py      |  136 --
 .../model/megatron/model/mm_gpts/__init__.py  |    2 -
 .../model/megatron/model/mm_gpts/qwen.py      |  121 --
 .../model/megatron/model/mm_gpts/qwen3_5.py   |  174 --
 .../model/megatron/model/mm_gpts/qwen3_vl.py  |  450 -----
 .../model/megatron/model/mm_gpts/utils.py     |   83 -
 src/twinkle/model/megatron/model/register.py  |   98 -
 src/twinkle/model/megatron/model/rope.py      |  175 --
 .../model/megatron/strategy/megatron.py       |   77 +-
 src/twinkle/model/megatron/tuners/__init__.py |    8 -
 src/twinkle/model/megatron/tuners/lora.py     |  583 ------
 src/twinkle/model/megatron/tuners/utils.py    |  206 --
 src/twinkle/model/megatron/utils/__init__.py  |    2 -
 src/twinkle/model/megatron/utils/config.py    |  201 --
 src/twinkle/model/megatron/utils/utils.py     |   32 -
 src/twinkle/utils/__init__.py                 |    2 +-
 src/twinkle/utils/torch_utils.py              |   31 +-
 26 files changed, 102 insertions(+), 5808 deletions(-)
 delete mode 100644 src/twinkle/model/megatron/args.py
 delete mode 100644 src/twinkle/model/megatron/model/__init__.py
 delete mode 100644 src/twinkle/model/megatron/model/constant.py
 delete mode 100644 src/twinkle/model/megatron/model/gpt_bridge.py
 delete mode 100644 src/twinkle/model/megatron/model/gpt_model.py
 delete mode 100644 src/twinkle/model/megatron/model/gpts/__init__.py
 delete mode 100644 src/twinkle/model/megatron/model/gpts/qwen3_next.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpt_model.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpts/__init__.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpts/qwen.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py
 delete mode 100644 src/twinkle/model/megatron/model/mm_gpts/utils.py
 delete mode 100644 src/twinkle/model/megatron/model/register.py
 delete mode 100644 src/twinkle/model/megatron/model/rope.py
 delete mode 100644 src/twinkle/model/megatron/tuners/__init__.py
 delete mode 100644 src/twinkle/model/megatron/tuners/lora.py
 delete mode 100644 src/twinkle/model/megatron/tuners/utils.py
 delete mode 100644 src/twinkle/model/megatron/utils/__init__.py
 delete mode 100644 src/twinkle/model/megatron/utils/config.py
 delete mode 100644 src/twinkle/model/megatron/utils/utils.py

diff --git a/pyproject.toml b/pyproject.toml
index 0e9640c2..9e6a321e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ transformers = [
     "torchvision",
 ]
 kernels = ["kernels"]
-megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"]
+megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridge"]
 vllm = ["vllm>=0.11"]
 ray = ["ray[serve]"]
 tinker = ["tinker==0.14.0"]
diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py
deleted file mode 100644
index eacc10db..00000000
--- a/src/twinkle/model/megatron/args.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# Copyright (c) ModelScope Contributors.
All rights reserved. -import inspect -import torch -import torch.nn as nn -from dataclasses import dataclass, field -from types import SimpleNamespace -from typing import Any, Dict, List, Literal, Optional - -from twinkle import DeviceMesh -from twinkle.utils import exists -from .utils import convert_hf_config - -# Global args storage -_GLOBAL_ARGS: Optional['TwinkleMegatronArgs'] = None - - -def get_args() -> 'TwinkleMegatronArgs': - """Get the global TwinkleMegatronArgs instance. - - This function is designed to be a drop-in replacement for megatron's get_args(). - If TwinkleMegatronArgs has not been set, it will try to use megatron's get_args() as fallback. - - Returns: - TwinkleMegatronArgs instance or megatron args. - - Raises: - RuntimeError: If args have not been initialized. - """ - if _GLOBAL_ARGS is not None: - return _GLOBAL_ARGS - - raise RuntimeError('Twinkle args have not been initialized. ') - - -def set_args(args: 'TwinkleMegatronArgs') -> None: - """Set the global TwinkleMegatronArgs instance.""" - global _GLOBAL_ARGS - _GLOBAL_ARGS = args - - -def clear_args() -> None: - """Clear the global args.""" - global _GLOBAL_ARGS - _GLOBAL_ARGS = None - - -@dataclass -class TwinkleMegatronArgs: - """Lightweight args class compatible with Megatron's args. - - This class provides a unified configuration system for both model creation - and weight conversion. It stores a reference to the original HuggingFace config - and implements __getattr__ to fallback to hf_config for missing attributes. - - Attributes: - _hf_config: The original HuggingFace config object (stored but not a dataclass field). - """ - _model: Optional[List[nn.Module]] = None - # ========================================================================= - # Model architecture (from HF config) - # ========================================================================= - hidden_size: int = 4096 - num_attention_heads: int = 32 - num_key_value_heads: Optional[int] = None - num_layers: int = 32 - ffn_hidden_size: int = 11008 - vocab_size: Optional[int] = None - padded_vocab_size: Optional[int] = None - kv_channels: Optional[int] = None # head_dim - variable_seq_lengths: bool = True - - # ========================================================================= - # Parallelism settings - # ========================================================================= - device_mesh: DeviceMesh = None - sequence_parallel: bool = False - - # ========================================================================= - # RoPE settings - # ========================================================================= - rotary_base: int = 10000 # rope_theta in HF config - rotary_percent: float = 1.0 - max_position_embeddings: int = 4096 - original_max_position_embeddings: Optional[int] = None - rope_scaling: Optional[Dict[str, Any]] = None - partial_rotary_factor: Optional[float] = None # For partial RoPE - rope_interleaved: bool = False # mrope_interleaved in Swift - - # ========================================================================= - # Model settings - # ========================================================================= - model_dir: str = '' - hf_model_type: str = 'qwen2' - is_multimodal: bool = False - - # ========================================================================= - # Bias settings (used by bridge for weight conversion) - # ========================================================================= - add_qkv_bias: bool = False - add_bias_linear: bool = False - qk_layernorm: bool = False - 
tie_word_embeddings: bool = False - - # ========================================================================= - # MoE settings (used by bridge for weight conversion) - # ========================================================================= - num_experts: int = 0 - num_experts_per_tok: int = 2 - shared_expert_intermediate_size: int = 0 - moe_router_enable_expert_bias: bool = False - - # ========================================================================= - # Training/inference settings - # ========================================================================= - params_dtype: torch.dtype = torch.bfloat16 - task_type: str = 'causal_lm' # not used for now - num_labels: int = 2 - - # ========================================================================= - # Attention settings - # ========================================================================= - attn_impl: str = 'flash_attn' - attention_backend: str = 'flash' - - # ========================================================================= - # MTP (Multi-Token Prediction) settings - # ========================================================================= - mtp_num_layers: int = 0 - - # ========================================================================= - # MLA (Multi-Latent Attention) settings - for DeepSeek-V2/V3 style models - # ========================================================================= - multi_latent_attention: bool = False - q_lora_rank: Optional[int] = None - - # ========================================================================= - # LoRA/PEFT settings - # ========================================================================= - merge_lora: bool = False - target_modules: List[str] = field(default_factory=list) - - # ========================================================================= - # FP8 quantization settings - # ========================================================================= - fp8: Optional[str] = None - fp8_recipe: str = 'delayed' - fp8_param_gather: bool = False - - # ========================================================================= - # Activation checkpointing settings - # ========================================================================= - recompute_granularity: Literal['selective', 'full', 'none'] = 'selective' - recompute_modules: List[str] = field(default_factory=lambda: ['core_attn']) - recompute_method: Optional[Literal['uniform', 'block']] = None - recompute_num_layers: Optional[int] = None - # ========================================================================= - # Additional settings - # ========================================================================= - untie_embeddings_and_output_weights: bool = True - max_shard_size: str = '5GB' - use_cpu_initialization: bool = False - - def __post_init__(self): - # Initialize _hf_config as None (will be set by from_hf_config) - object.__setattr__(self, '_hf_config', None) - object.__setattr__(self, '_text_config', None) - - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - if self.kv_channels is None: - self.kv_channels = self.hidden_size // self.num_attention_heads - if self.attention_backend is None: - self.attention_backend = SimpleNamespace(name='flash') - - def __getattr__(self, name: str) -> Any: - """Fallback to hf_config for missing attributes. - - This allows seamless access to HuggingFace config attributes that - weren't explicitly copied to TwinkleMegatronArgs. 
- """ - # Avoid infinite recursion for special attributes - if name.startswith('_'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - # Try to get from hf_config - hf_config = object.__getattribute__(self, '_hf_config') - if hf_config is not None: - # First try direct access - if hasattr(hf_config, name): - return getattr(hf_config, name) - - # For multimodal models, try text_config - text_config = object.__getattribute__(self, '_text_config') - if text_config is not None and hasattr(text_config, name): - return getattr(text_config, name) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}' " - f'and it was not found in hf_config either.') - - @property - def tensor_model_parallel_size(self) -> int: - return self.device_mesh.tp_world_size or 1 - - @property - def tp_size(self) -> int: - return self.device_mesh.tp_world_size or 1 - - @property - def pipeline_model_parallel_size(self) -> int: - return self.device_mesh.pp_world_size or 1 - - @property - def pp_size(self) -> int: - return self.device_mesh.pp_world_size or 1 - - @property - def context_parallel_size(self) -> int: - return self.device_mesh.cp_world_size or 1 - - @property - def cp_size(self) -> int: - return self.device_mesh.cp_world_size or 1 - - @property - def expert_model_parallel_size(self) -> int: - return self.device_mesh.ep_size or 1 - - @property - def ep_size(self) -> int: - return self.device_mesh.ep_size or 1 - - @property - def expert_tensor_parallel_size(self) -> int: - return self.device_mesh.etp_world_size - - @property - def etp_size(self) -> int: - return self.expert_tensor_parallel_size - - @property - def virtual_pipeline_model_parallel_size(self) -> int: - return self.device_mesh.vpp_size - - @property - def vpp_size(self) -> int: - return self.device_mesh.vpp_size - - @property - def order(self) -> str: - return self.device_mesh.order - - @property - def head_dim(self) -> int: - return self.kv_channels - - @property - def intermediate_size(self) -> int: - return self.ffn_hidden_size - - @property - def moe_shared_expert_intermediate_size(self) -> int: - return self.shared_expert_intermediate_size - - @property - def num_query_groups(self) -> int: - """Alias for num_key_value_heads (Megatron naming).""" - return self.num_key_value_heads - - @property - def group_query_attention(self) -> bool: - """Whether the model uses grouped query attention (GQA).""" - return self.num_key_value_heads != self.num_attention_heads - - @property - def torch_dtype(self) -> torch.dtype: - return self.params_dtype - - @property - def hf_config(self) -> Any: - """Get the original HuggingFace config.""" - return object.__getattribute__(self, '_hf_config') - - @property - def text_config(self) -> Any: - """Get the text config (for multimodal models).""" - return object.__getattribute__(self, '_text_config') - - @classmethod - def from_hf_config( - cls, - hf_config: Any, - model_dir: str = '', - device_mesh: DeviceMesh = None, - params_dtype: torch.dtype = torch.bfloat16, - sequence_parallel: bool = False, - task_type: str = 'causal_lm', - padded_vocab_size: Optional[int] = None, - **kwargs, - ) -> 'TwinkleMegatronArgs': - """Create TwinkleMegatronArgs from a HuggingFace model config. - - This method handles both regular LLM configs and multimodal configs - where parameters may be in nested sub-configs (e.g., text_config). - - The original hf_config is stored and can be accessed via args.hf_config - or through attribute fallback (__getattr__). 
- """ - # Handle multimodal configs with nested text_config - text_config = hf_config - if hasattr(hf_config, 'text_config') and hf_config.text_config is not None: - text_config = hf_config.text_config - - vocab_size = getattr(text_config, 'vocab_size') - assert vocab_size is not None, 'detect vocab_size in hf config failed' - if padded_vocab_size is None: - if device_mesh.tp_world_size > 1: - divisor = device_mesh.tp_world_size * 128 - padded_vocab_size = ((vocab_size + divisor - 1) // divisor) * divisor - else: - padded_vocab_size = vocab_size - - num_attention_heads = getattr(text_config, 'num_attention_heads', 32) - num_key_value_heads = getattr(text_config, 'num_key_value_heads', num_attention_heads) - hidden_size = getattr(text_config, 'hidden_size', 4096) - - # Get kv_channels (head_dim) - kv_channels = getattr(text_config, 'head_dim', None) - if kv_channels is None: - kv_channels = hidden_size // num_attention_heads - - # Get rope_scaling - rope_scaling = getattr(text_config, 'rope_scaling', None) - - model_type = getattr(hf_config, 'model_type', 'qwen2') - - # Detect multimodal model from the registered MegatronModelMeta - from .model.register import get_megatron_model_meta - model_meta = get_megatron_model_meta(model_type) - is_multimodal = model_meta.is_multimodal if model_meta is not None else False - - # Determine QKV bias - if hasattr(text_config, 'attention_bias'): - add_qkv_bias = text_config.attention_bias - elif model_type in ('qwen2', 'qwen2_5', 'qwen2_vl', 'qwen2_5_vl'): - add_qkv_bias = True - else: - add_qkv_bias = False - - # Determine QK layernorm - qk_layernorm = (getattr(text_config, 'qk_layernorm', False) or getattr(text_config, 'use_qk_norm', False)) - # MoE config - num_experts = ( - getattr(text_config, 'num_experts', 0) or getattr(text_config, 'n_routed_experts', 0) - or getattr(text_config, 'num_local_experts', 0) or 0) - num_experts_per_tok = ( - getattr(text_config, 'num_experts_per_tok', 2) or getattr(text_config, 'moe_topk', 2) or 2) - shared_expert_size = getattr(text_config, 'shared_expert_intermediate_size', 0) or 0 - - # MLA config (for DeepSeek-V2/V3 style models) - q_lora_rank = getattr(text_config, 'q_lora_rank', None) - multi_latent_attention = q_lora_rank is not None - - # Create instance - instance = cls( - # Model architecture - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - num_layers=getattr(text_config, 'num_hidden_layers', 32), - ffn_hidden_size=getattr(text_config, 'intermediate_size', 11008), - vocab_size=vocab_size, - padded_vocab_size=padded_vocab_size, - kv_channels=kv_channels, - # Parallelism - device_mesh=device_mesh, - sequence_parallel=sequence_parallel, - # RoPE - rotary_base=int(getattr(text_config, 'rope_theta', 10000)), - rotary_percent=1.0, - max_position_embeddings=getattr(text_config, 'max_position_embeddings', 4096), - original_max_position_embeddings=getattr(text_config, 'original_max_position_embeddings', None), - rope_scaling=rope_scaling, - # Model settings - model_dir=model_dir, - hf_model_type=model_type, - is_multimodal=is_multimodal, - # Bias settings - add_qkv_bias=add_qkv_bias, - add_bias_linear=getattr(text_config, 'mlp_bias', False), - qk_layernorm=qk_layernorm, - tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), - # MoE settings - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - shared_expert_intermediate_size=shared_expert_size, - # MLA settings - multi_latent_attention=multi_latent_attention, - 
q_lora_rank=q_lora_rank, - # Training - params_dtype=params_dtype, - task_type=task_type, - # Attention - attn_impl='flash_attn', - attention_backend='flash', - # Other - untie_embeddings_and_output_weights=not getattr(hf_config, 'tie_word_embeddings', False), - **kwargs, - ) - - # Store the original hf_config for attribute fallback - object.__setattr__(instance, '_hf_config', hf_config) - object.__setattr__(instance, '_text_config', text_config if text_config is not hf_config else None) - - # Apply convert_hf_config results to instance (like swift's init_model_args) - # This ensures derived values like qk_layernorm are correctly set - mg_config = convert_hf_config(hf_config) - for k, v in mg_config.items(): - if not hasattr(instance, k): - continue - current_value = getattr(instance, k) - if current_value is None: - object.__setattr__(instance, k, v) - elif current_value is False and isinstance(v, bool) and v: - # update false - object.__setattr__(instance, k, v) - - return instance - - def create_model(self, ) -> List[nn.Module]: - """Create Megatron GPT model from HuggingFace config. - - Args: - hf_config: HuggingFace model configuration. - padded_vocab_size: Padded vocabulary size. - - Returns: - Megatron GPT model. - """ - if self._model is not None: - return self._model - from megatron.core import parallel_state as mpu - from megatron.core.transformer import TransformerConfig - from megatron.core.transformer.enums import AttnBackend - - from .model.gpt_model import GPTModel - from .model.register import get_megatron_model_meta - hf_config = self.hf_config - padded_vocab_size = self.padded_vocab_size - # Convert HF config to Megatron config - mg_config_dict = convert_hf_config(hf_config) - - # Get registered model class (for multimodal models like Qwen3-VL) - model_meta = get_megatron_model_meta(self.hf_model_type) - ModelClass = model_meta.model_cls if model_meta else GPTModel - - # Build TransformerConfig - num_attention_heads = mg_config_dict['num_attention_heads'] - num_query_groups = mg_config_dict.get('num_query_groups', num_attention_heads) - num_layers = mg_config_dict['num_layers'] - - # Configure activation recomputation - recompute_method = self.recompute_method - recompute_num_layers = self.recompute_num_layers - - # Auto-configure for 'full' recomputation if not specified - if self.recompute_granularity == 'full': - if recompute_method is None: - recompute_method = 'uniform' - if recompute_num_layers is None: - # Recompute all layers for maximum memory savings - recompute_num_layers = num_layers // self.pp_size - - # Create finalize_model_grads function for DP gradient synchronization - # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config. - # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models. 
- from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads - - def finalize_model_grads_for_lora(model, *args, **kwargs): - from megatron.core.distributed import DistributedDataParallel as MegatronDDP - from peft import PeftModel as _PeftModel - - # Check if model is DDP-wrapped (has ddp_config) - # Need to unwrap PeftModel to check the underlying model - def _get_base_model(m): - if isinstance(m, _PeftModel): - return _get_base_model(m.base_model.model) - return m - - base_model = _get_base_model(model[0]) - if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'ddp_config'): - # Use native implementation for DDP models - return _native_finalize_model_grads(model, *args, **kwargs) - - return - - # MoE configuration - num_experts = mg_config_dict.get('num_experts', 0) or 0 - moe_ffn_hidden_size = mg_config_dict.get('moe_ffn_hidden_size') - moe_router_topk = mg_config_dict.get('moe_router_topk', 2) or 2 - moe_shared_expert_intermediate_size = mg_config_dict.get('moe_shared_expert_intermediate_size') - - # Build MoE-related kwargs - moe_kwargs = {} - if num_experts > 0: - moe_kwargs.update({ - 'num_moe_experts': - num_experts, - 'moe_router_topk': - moe_router_topk, - 'moe_router_load_balancing_type': - mg_config_dict.get('moe_router_load_balancing_type', 'aux_loss'), - # MoE performance optimizations - 'moe_token_dispatcher_type': - mg_config_dict.get('moe_token_dispatcher_type', - 'alltoall'), # 'alltoall' is more efficient than 'allgather' - 'moe_grouped_gemm': - mg_config_dict.get('moe_grouped_gemm', - True), # Enable for better performance (requires grouped_gemm package) - 'moe_aux_loss_coeff': - mg_config_dict.get('moe_aux_loss_coeff', 0.0), # Auxiliary load balancing loss coefficient - }) - - # FFN hidden size for MoE - if moe_ffn_hidden_size: - moe_kwargs['moe_ffn_hidden_size'] = moe_ffn_hidden_size - - # Shared expert configuration - if moe_shared_expert_intermediate_size: - moe_kwargs['moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size - - # Router score function (sigmoid for Qwen3, softmax for others) - if mg_config_dict.get('moe_router_score_function'): - moe_kwargs['moe_router_score_function'] = mg_config_dict['moe_router_score_function'] - - # Expert bias for sigmoid router - if mg_config_dict.get('moe_router_enable_expert_bias'): - moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict['moe_router_enable_expert_bias'] - - # Sequence parallel requires TP > 1 - # Auto-enable for MoE with TP > 1 (required by Megatron) - use_sequence_parallel = self.sequence_parallel and self.tp_size > 1 - if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel: - use_sequence_parallel = True - # Sync the flag back so that callers (e.g. padding logic in - # megatron.py) see the auto-enabled value. 
- self.sequence_parallel = True - if self.device_mesh is not None: - self.device_mesh.sequence_parallel = True - - # For MoE models, ffn_hidden_size should be moe_ffn_hidden_size if not specified - ffn_hidden_size = mg_config_dict.get('ffn_hidden_size') - if ffn_hidden_size is None: - ffn_hidden_size = moe_ffn_hidden_size or (4 * mg_config_dict['hidden_size']) - - # For models with non-standard head dimensions (like Qwen3-30B-A3B) - kv_channels = mg_config_dict.get('kv_channels') - - # Activation function for SwiGLU (required by Megatron when gated_linear_unit=True) - use_swiglu = mg_config_dict.get('swiglu', True) - activation_func = torch.nn.functional.silu if use_swiglu else torch.nn.functional.gelu - - # Enable bias_activation_fusion for SwiGLU - # Note: Only works with TransformerEngine and no bias in linear layers - has_bias = not mg_config_dict.get('disable_bias_linear', True) - bias_activation_fusion = use_swiglu and not has_bias - if 'moe_token_dispatcher_type' not in moe_kwargs: - moe_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' - config = TransformerConfig( - num_layers=num_layers, - hidden_size=mg_config_dict['hidden_size'], - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - kv_channels=kv_channels, - ffn_hidden_size=ffn_hidden_size, - tensor_model_parallel_size=self.tp_size, - pipeline_model_parallel_size=self.pp_size, - context_parallel_size=self.cp_size, - expert_model_parallel_size=self.ep_size, - virtual_pipeline_model_parallel_size=self.vpp_size, - sequence_parallel=use_sequence_parallel, - params_dtype=self.params_dtype, - fp16=self.params_dtype == torch.float16, - bf16=self.params_dtype == torch.bfloat16, - pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism - use_cpu_initialization=self.use_cpu_initialization, - add_qkv_bias=self.add_qkv_bias, - variable_seq_lengths=self.variable_seq_lengths, - add_bias_linear=not mg_config_dict.get('disable_bias_linear', True), - gated_linear_unit=use_swiglu, - activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise - bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance - normalization='RMSNorm', - layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - hidden_dropout=0.0, - attention_dropout=0.0, - # Performance optimizations - masked_softmax_fusion=True, # Fused attention softmax - bias_dropout_fusion=True, # Fused bias + dropout - apply_rope_fusion=True, # Fused RoPE application - attention_softmax_in_fp32=True, # Numerical stability - attention_backend=AttnBackend.flash, # FlashAttention for speed - # Activation recomputation for memory efficiency - recompute_granularity=self.recompute_granularity, - recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, - recompute_method=recompute_method, - recompute_num_layers=recompute_num_layers, - # Critical: Set finalize_model_grads_func for DP gradient synchronization - # Uses custom wrapper that handles both DDP and PEFT/LoRA models - finalize_model_grads_func=finalize_model_grads_for_lora, - calculate_per_token_loss=True, - # MoE configuration - **moe_kwargs, - ) - if exists('megatron_core>=0.13'): - config.expert_tensor_parallel_size = self.etp_size - - if mg_config_dict.get('use_shared_expert_gate'): - config.moe_use_shared_expert_gate = True - if mg_config_dict.get('rotary_interleaved'): - config.rotary_interleaved = True - 
partial_rotary_factor = mg_config_dict.get('partial_rotary_factor') - if partial_rotary_factor is not None: - config.rotary_percent = partial_rotary_factor - config.apply_rope_fusion = False - mrope_section = mg_config_dict.get('mrope_section') - if mrope_section is not None: - config.mrope_section = mrope_section - if mg_config_dict.get('mrope_interleaved'): - config.mrope_interleaved = True - - self.config = config - - # Delegate model-specific config & layer spec construction to the loader - loader = model_meta.loader() if model_meta else None - if loader is not None: - loader.post_config(config, self, mg_config_dict) - layer_spec = loader.get_layer_spec(config, self, mg_config_dict) - else: - from .model.register import MegatronModelLoader - default_loader = MegatronModelLoader() - default_loader.post_config(config, self, mg_config_dict) - layer_spec = default_loader.get_layer_spec(config, self, mg_config_dict) - - # Create model - max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) - rotary_base = mg_config_dict.get('rotary_base', 10000) - position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') - extra_init_args = {} - if hasattr(hf_config, - 'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling: - extra_init_args = {'seq_len_interpolation_factor': hf_config.rope_scaling['factor']} - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - if vpp_size is not None and vpp_size > 1: - model = [] - has_vp_stage = inspect.signature(mpu.is_pipeline_first_stage).parameters.get('vp_stage', None) is not None - for i in range(vpp_size): - mpu.set_virtual_pipeline_model_parallel_rank(i) - extra_kwargs = {} if not has_vp_stage else {'ignore_virtual': False, 'vp_stage': i} - if has_vp_stage: - extra_init_args['vp_stage'] = i - _model = ModelClass( - config=config, - transformer_layer_spec=layer_spec, - vocab_size=padded_vocab_size, - max_sequence_length=max_seq_length, - pre_process=mpu.is_pipeline_first_stage(**extra_kwargs), - post_process=mpu.is_pipeline_last_stage(**extra_kwargs), - parallel_output=True, - share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type=position_embedding_type, - rotary_base=rotary_base, - **extra_init_args) - model.append(_model) - mpu.set_virtual_pipeline_model_parallel_rank(0) - else: - model = ModelClass( - config=config, - transformer_layer_spec=layer_spec, - vocab_size=padded_vocab_size, - max_sequence_length=max_seq_length, - pre_process=mpu.is_pipeline_first_stage(), - post_process=mpu.is_pipeline_last_stage(), - parallel_output=True, - share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type=position_embedding_type, - rotary_base=rotary_base, - **extra_init_args, - ) - model = [model] - self._model = model - return model diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index d5923f86..cdc45246 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1,23 +1,23 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
 import asyncio
-import inspect
 import json
 import logging
-import numpy as np
 import os
 import random
 import re
 import threading
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, Union
+
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from dataclasses import dataclass, field
 from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
 from peft.tuners.lora import Linear as LoraLinear
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
-from transformers import AutoConfig, PretrainedConfig
-from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, Union
+from transformers import PretrainedConfig
 
 import twinkle
 import twinkle.metric
@@ -34,7 +34,7 @@
 from twinkle.patch import Patch, apply_patch
 from twinkle.processor import InputProcessor
 from twinkle.template import Template
-from twinkle.utils import construct_class, exists, selective_log_softmax
+from twinkle.utils import construct_class, selective_log_softmax
 
 from .strategy import MegatronStrategy
@@ -74,29 +74,6 @@ def _get_lr(self):
 
 _default_adapter_name = ''
 
-_BASE_LAYER_SUFFIXES = [
-    '.q_proj.weight',
-    '.q_proj.bias',
-    '.k_proj.weight',
-    '.k_proj.bias',
-    '.v_proj.weight',
-    '.v_proj.bias',
-    '.o_proj.weight',
-    '.o_proj.bias',
-    '.gate_proj.weight',
-    '.up_proj.weight',
-    '.down_proj.weight',
-    '.mlp.gate.weight',
-    '.mlp.gate.bias',
-    '.mlp.gate.e_score_correction_bias',
-    '.in_proj_qkv.weight',
-    '.in_proj_z.weight',
-    '.in_proj_a.weight',
-    '.in_proj_b.weight',
-    '.out_proj.weight',
-    '.conv1d.weight',
-]
-
 
 def _add_base_layer_suffix(params):
     """Insert ``.base_layer.`` before the final attribute for LoRA-target modules.
@@ -141,7 +118,7 @@ def __init__(
         **kwargs,
     ):
         requires('megatron_core')
-        from .args import TwinkleMegatronArgs, get_args, set_args
+        requires('mcore_bridge')
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         nn.Module.__init__(self)
         from twinkle.patch.megatron_peft import MegatronPeft
@@ -149,45 +126,24 @@
         self.model_id = model_id
         self.device_mesh = device_mesh
         self.mixed_precision = mixed_precision
-        self._model_path = HubOperation.download_model(model_id)
-        self.hf_config = config or AutoConfig.from_pretrained(self._model_path)
         self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
-
-        self._seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42))
         self._default_tokenizer = None
         self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True)
         self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False)
         torch_util.set_device()
+        self._try_init_process_group()
-        self.strategy = MegatronStrategy(self.device_mesh, mixed_precision=mixed_precision, **kwargs)
-
-        # Determine params_dtype and activation checkpointing kwargs
-        params_dtype = torch.bfloat16
-        if self.mixed_precision == 'fp16':
-            params_dtype = torch.float16
-        elif self.mixed_precision == 'no':
-            params_dtype = torch.float32
-
-        ac_kwargs = {
+        kwargs.update({
             'recompute_granularity': recompute_granularity,
             'recompute_modules': recompute_modules,
             'recompute_method': recompute_method,
             'recompute_num_layers': recompute_num_layers,
-        }
-
-        # Initialize TwinkleMegatronArgs BEFORE creating the model
-        args = TwinkleMegatronArgs.from_hf_config(
-            self.hf_config,
-            model_dir=self._model_path,
-            device_mesh=self.device_mesh,
-            params_dtype=params_dtype,
-            sequence_parallel=self.strategy.sequence_parallel,
-            **ac_kwargs,
-        )
-        set_args(args)
-        self._initialized = False
-        self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs)
+        })
+        seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42))
+        self.strategy = MegatronStrategy(self._model_path, self.device_mesh, mixed_precision=mixed_precision,
+                                         seed=seed, **kwargs)
+        self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights)
         self._model_wrapped = False
 
         # This correctly handles vocab sharding in Tensor Parallelism
@@ -206,30 +162,6 @@ def _construct_default_optimizer_group(self):
             _device_mesh=self.device_mesh,
         )
 
-    def _create_megatron_model(
-        self,
-        load_weights: bool = True,
-        **kwargs,
-    ) -> List[nn.Module]:
-        from .args import get_args
-        args = get_args()
-        self.initialize(**kwargs)
-
-        model = args.create_model()
-        if load_weights:
-            bridge = self._bridge
-            for _model in model:
-                bridge.load_weights(_model, args.model_dir)
-
-        if dist.is_initialized():
-            dist.barrier()
-
-        _models = []
-        for _model in model:
-            _model = self._move_model_to_gpu(_model)
-            _models.append(_model)
-        return _models
-
     @staticmethod
     def _move_model_to_gpu(model: nn.Module) -> nn.Module:
         model = model.to(Platform.get_local_device())
@@ -809,7 +741,6 @@ def _create_megatron_optimizer(self, **kwargs):
         )
 
         # Ensure each model chunk has ddp_config attached (required by Megatron optimizer)
-        from megatron.core.distributed import DistributedDataParallelConfig
         model_chunks = self.model
         for model_chunk in model_chunks:
             assert hasattr(model_chunk, 'ddp_config')
@@ -1241,7 +1172,7 @@ def _read_iteration(tracker_path: str) -> int:
 
     def _merge_lora_adapters(self, adapter_name: str = 'default'):
         """Merge LoRA adapters into base model weights."""
-        from .tuners.lora import LoraParallelLinear
+        from mcore_bridge import LoraParallelLinear
        with torch.no_grad():
             for model in self.strategy.unwrap_model(self.model):
                 for module in model.modules():
@@ -1250,7 +1181,7 @@
 
     def _unmerge_lora_adapters(self):
         """Unmerge LoRA adapters to restore training state."""
-        from .tuners.lora import LoraParallelLinear
+        from mcore_bridge import LoraParallelLinear
         with torch.no_grad():
             for model in self.strategy.unwrap_model(self.model):
                 for module in model.modules():
@@ -1506,52 +1437,9 @@ def get_train_configs(self, **kwargs):
 
         return expr
 
-    def initialize(self, **kwargs) -> None:
-        if self._initialized:
-            return
-
-        from megatron.core import parallel_state
-        from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-
-        from .args import get_args
-        self._try_init_process_group()
-        args = get_args()
-        init_kwargs = {
-            'tensor_model_parallel_size': args.tensor_model_parallel_size,
-            'pipeline_model_parallel_size': args.pipeline_model_parallel_size,
-            'context_parallel_size': args.context_parallel_size,
-            'virtual_pipeline_model_parallel_size': args.virtual_pipeline_model_parallel_size,
-            'expert_model_parallel_size': args.expert_model_parallel_size,
-        }
-
-        if args.order:
-            init_kwargs['order'] = args.order
-
-        if exists('megatron_core>=0.13'):
-            init_kwargs['expert_tensor_parallel_size'] = args.expert_tensor_parallel_size
-
-        # Filter out kwargs that are not valid for initialize_model_parallel
-        # Dynamically check the signature to exclude unsupported parameters
-        valid_params = set(inspect.signature(parallel_state.initialize_model_parallel).parameters.keys())
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
-        init_kwargs.update(filtered_kwargs)
-        parallel_state.initialize_model_parallel(**init_kwargs)
-        model_parallel_cuda_manual_seed(self._seed)
-
-        self._parallel_state = parallel_state
-        self._initialized = True
-
     @property
     def _bridge(self) -> 'GPTBridge':
-        if not hasattr(self, '_bridge_instance'):
-            from .args import get_args
-            from .model import get_megatron_model_meta
-            args = get_args()
-            megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
-            assert megatron_model_meta is not None, f'Model: {args.hf_model_type} is not supported.'
-            self._bridge_instance = megatron_model_meta.bridge_cls()
-
-        return self._bridge_instance
+        return self.strategy.config.bridge
 
     # ── Checkpoint Engine (from CheckpointEngineMixin) ──────────────────
     # prepare_checkpoint_engine, init_checkpoint_process_group, and
@@ -1583,9 +1471,7 @@ def send_weights(
         # Trim any tensor whose dim-0 equals padded_vocab_size back to
         # org_vocab_size — this is shape-based, not name-based, so it works
         # regardless of the model architecture's naming convention.
-        from .args import get_args
-        args = get_args()
-        org_vocab_size = getattr(self.hf_config, 'vocab_size', args.padded_vocab_size)
+        org_vocab_size = getattr(self.hf_config, 'vocab_size', self.strategy.config.padded_vocab_size)
         _padded_vocab_size = args.padded_vocab_size
 
         def _trim_vocab(name, tensor):
diff --git a/src/twinkle/model/megatron/model/__init__.py b/src/twinkle/model/megatron/model/__init__.py
deleted file mode 100644
index c61acef9..00000000
--- a/src/twinkle/model/megatron/model/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from . import gpts, mm_gpts
-from .constant import MegatronModelType
-from .gpt_bridge import GPTBridge
-from .register import MegatronModelLoader, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
diff --git a/src/twinkle/model/megatron/model/constant.py b/src/twinkle/model/megatron/model/constant.py
deleted file mode 100644
index 968186ac..00000000
--- a/src/twinkle/model/megatron/model/constant.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-
-
-# LLMModelType/MLLMModelType: model_type attribute in model config
-class LLMModelType:
-    qwen2 = 'qwen2'
-    qwen2_moe = 'qwen2_moe'
-    qwen3 = 'qwen3'
-    qwen3_moe = 'qwen3_moe'
-
-
-class MLLMModelType:
-    qwen2_vl = 'qwen2_vl'
-    qwen2_5_vl = 'qwen2_5_vl'
-    qwen3_vl = 'qwen3_vl'
-    qwen3_vl_moe = 'qwen3_vl_moe'
-    qwen3_5 = 'qwen3_5'
-    qwen3_5_moe = 'qwen3_5_moe'
-
-
-class ModelType(LLMModelType, MLLMModelType):
-    pass
-
-
-# LLMMegatronModelType/MLLMMegatronModelType: megatron model architecture type
-class LLMMegatronModelType:
-    gpt = 'gpt'
-
-
-class MLLMMegatronModelType:
-    qwen2_vl = 'qwen2_vl'
-    qwen2_5_vl = 'qwen2_5_vl'
-    qwen3_vl = 'qwen3_vl'
-    qwen3_5 = 'qwen3_5'
-    qwen3_5_moe = 'qwen3_5_moe'
-
-
-class MegatronModelType(LLMMegatronModelType, MLLMMegatronModelType):
-    pass
diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py
deleted file mode 100644
index daea5c90..00000000
--- a/src/twinkle/model/megatron/model/gpt_bridge.py
+++ /dev/null
@@ -1,1651 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-# Reference: swift/swift/megatron/model/gpt_bridge.py
-
-import math
-import os
-import re
-import shutil
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-import transformers
-from copy import copy
-from packaging import version
-from peft.utils import ModulesToSaveWrapper
-from tqdm import tqdm
-from transformers import AutoConfig, AutoProcessor, AutoTokenizer
-from transformers.modeling_utils import PreTrainedModel, custom_object_save
-from typing import Callable, List, Optional, Union
-
-from twinkle.hub import HubOperation
-from twinkle.model.megatron.args import get_args  # Use twinkle's get_args
-from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger,
-                           get_modules_to_not_convert, is_last_rank, requires)
-
-logger = get_logger()
-
-
-# Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225
-class GPTBridge:
-    fp8_block_size = 128
-    hf_layers_prefix = 'model.layers'
-    hf_mtp_prefix = 'model.layers'
-    hf_embed_key = 'model.embed_tokens.weight'
-    hf_final_layernorm_key = 'model.norm.weight'
-    hf_lm_head_key = 'lm_head.weight'
-    hf_score_key = 'score.weight'
-    hf_state_dict_mapping = {}
-
-    def __init__(self, disable_tqmd: bool = False):
-        from .register import get_megatron_model_meta
-        requires('megatron_core')
-        import megatron.core as megatron_core
-        from megatron.core import mpu
-
-        from ..tuners import LoraParallelLinear
-        self.megatron_core = megatron_core
-        self.mpu = mpu
-        self.LoraParallelLinear = LoraParallelLinear
-        self.args = get_args()
-        self.disable_tqmd = disable_tqmd or not is_last_rank()
-        self._target_device = None
-        self._only_last_rank = False
-        self._peft_target_modules = set()
-        self._peft_modules_to_save = set()
-        self._is_peft_format = False
-        self._adapter_name = 'default'
-        self._init_meta_hf_model()
-        # Get HF layers if model was loaded,
otherwise None - self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix) if self.hf_model is not None else None - self.module_mapping = {} - self.mcore_013 = version.parse(self.megatron_core.__version__) >= version.parse('0.13.0rc0') - self.mcore_014 = version.parse(self.megatron_core.__version__) >= version.parse('0.14.0rc0') - megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type) - if self.args.is_multimodal and megatron_model_meta.visual_cls is not None: - self.module_mapping = megatron_model_meta.visual_cls.module_mapping - self.tp_size = self.args.tensor_model_parallel_size - self.pp_size = self.args.pipeline_model_parallel_size - self.etp_size = self.args.expert_tensor_parallel_size - self.ep_size = self.args.expert_model_parallel_size - - self.tp_group = self.mpu.get_tensor_model_parallel_group() - self.pp_group = self.mpu.get_pipeline_model_parallel_group() - self.etp_group = self.mpu.get_expert_tensor_parallel_group() - self.ep_group = self.mpu.get_expert_model_parallel_group() - self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev') - self.tp_rank = self.mpu.get_tensor_model_parallel_rank() - self.pp_rank = self.mpu.get_pipeline_model_parallel_rank() - self.etp_rank = self.mpu.get_expert_tensor_parallel_rank() - self.ep_rank = self.mpu.get_expert_model_parallel_rank() - - self._fp8_quantizer = None - self.mxfp4_quantizer = MxFp4Dequantizer() - - dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size - expert_decoder_rank_generator = self.mpu.RankGenerator( - tp=self.etp_size, - ep=self.ep_size, - dp=dp_size, - pp=self.pp_size, - cp=1, - order='tp-cp-ep-dp-pp', - rank_offset=0, - ) - rank = dist.get_rank() - for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'): - group = self.mpu.create_group( - ranks, - group_desc='EP-PP-GROUP', - ) - if rank in ranks: - self.ep_pp_size = self.ep_size * self.pp_size - self.ep_pp_group = group - self.ep_pp_rank = dist.get_rank(group) - - def get_hf_mlp_prefix(self, layer_idx): - if hasattr(self.hf_layers[layer_idx], 'feed_forward'): - return 'feed_forward' - else: - return 'mlp' - - def _get_hf_mlp(self, layer_idx): - return getattr(self.hf_layers[layer_idx], self.get_hf_mlp_prefix(layer_idx)) - - def _get_transpose(self): - if self.args.hf_model_type in {'qwen3_vl_moe', 'gpt_oss', 'llama4'}: - return True - else: - return False - - def _init_meta_hf_model(self): - import copy - - from .register import get_megatron_model_meta - - model_dir = self.args.model_dir - model_type = self.args.hf_model_type - - # Get the correct AutoModel class from MegatronModelMeta - megatron_model_meta = get_megatron_model_meta(model_type) - auto_model_cls = megatron_model_meta.auto_model_cls if megatron_model_meta else None - if auto_model_cls is None: - from transformers import AutoModelForCausalLM - auto_model_cls = AutoModelForCausalLM - - # Load config first - config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) - config.torch_dtype = self.args.params_dtype - - with torch.device('meta'): - origin_dtype = torch.get_default_dtype() - torch.set_default_dtype(self.args.params_dtype) - config_copy = copy.deepcopy(config) - # Auto classes have from_config, concrete model classes have _from_config - if hasattr(auto_model_cls, 'from_config'): - self.hf_model = auto_model_cls.from_config(config_copy, trust_remote_code=True) - else: - self.hf_model = auto_model_cls._from_config(config_copy) - torch.set_default_dtype(origin_dtype) - - if 
os.path.exists(os.path.join(model_dir, 'preprocessor_config.json')): - auto_tokenizer_cls = AutoProcessor - else: - auto_tokenizer_cls = AutoTokenizer - - self.processor = auto_tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True) - - def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: - if mg_key is None: - return - # ColumnLinear - dim0_keys = { - 'word_embeddings', - 'linear_qkv', - # mla - 'linear_q_proj', - 'linear_q_up_proj', - 'linear_kv_up_proj', - # mtp - 'eh_proj', - } - if self.args.task_type == 'causal_lm': - dim0_keys.add('output_layer') - if not self.mcore_014: - # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72 - dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'}) - # RowLinear - dim1_keys = {'linear_proj', 'linear_fc2'} - if 'lora_A' not in mg_key and 'lora_B' not in mg_key: - key, suffix = mg_key.rsplit('.', 2)[-2:] - if suffix == 'layer_norm_weight': - return - elif mg_key == 'core_attention.softmax_offset': - return 0 - elif key in dim0_keys: - return 0 - elif key in {'linear_fc1'} | dim1_keys and suffix != 'bias': - # linear_fc1 shape [2, X, Y] - return 1 - else: - mg_key_splited = mg_key.rsplit('.', 3) - key, lora_name = mg_key_splited[:2] - if lora_name == 'lora_A': - if key in dim1_keys: - return 1 - elif lora_name == 'lora_B': - if key in dim0_keys: - return 0 - elif key in {'linear_fc1'}: - return 1 - - def _split_tp(self, hf_weight, tp_dim, is_expert): - tp_size = self.etp_size if is_expert else self.tp_size - tp_rank = self.etp_rank if is_expert else self.tp_rank - if tp_dim is not None and tp_size > 1: - tensor = hf_weight.chunk(tp_size, dim=tp_dim)[tp_rank] - else: - tensor = hf_weight - return tensor - - def _set_weight( - self, - mg_param: Union[torch.Tensor, List[torch.Tensor]], - hf_weight: torch.Tensor, - mg_key: str, - offset: float = 0, - is_expert: bool = False, - *, - hf_scale_inv: Optional[torch.Tensor] = None, - ): - # tp/etp - tp_dim = self._get_tp_split_dim(mg_key) - tensor = self._split_tp(hf_weight, tp_dim, is_expert) - del hf_weight - if not isinstance(mg_param, (list, tuple)): - mg_param = [mg_param] - if hf_scale_inv is not None: - hf_scale_inv = self._split_tp(hf_scale_inv, tp_dim, is_expert) - hf_scale_inv = hf_scale_inv.chunk(len(mg_param), dim=0) - if offset: - assert hf_scale_inv is None, f'mg_key: {mg_key}' - tensor = tensor + offset - tensor_list = tensor.chunk(len(mg_param), dim=0) - for i, param in enumerate(mg_param): - tensor = tensor_list[i].reshape(*param.shape) - if self._is_fp8_param(param): - if hf_scale_inv is None: - param.data.copy_(tensor) - param._high_precision_init_val.copy_(tensor) - else: - tensor = tensor.view(torch.uint8) - param._rowwise_data.data.copy_(tensor) - self._copy_scale_inv(param, hf_scale_inv[i]) - del param.get_high_precision_init_val - else: - if hf_scale_inv is not None: - fp8_tensor = self.fp8_quantizer.make_empty(tensor.shape) - fp8_tensor._rowwise_data.copy_(tensor.view(torch.uint8)) - self._copy_scale_inv(fp8_tensor, hf_scale_inv[i]) - tensor = fp8_tensor - param.data.copy_(tensor) - - @staticmethod - def _copy_scale_inv(tensor, scale_inv): - scale_inv = scale_inv.reshape(-1, scale_inv.shape[-1]) - if scale_inv.shape[-1] < tensor._rowwise_scale_inv.shape[-1]: - scale_inv = torch.concat([ - scale_inv, - scale_inv.new_zeros((scale_inv.shape[0], tensor._rowwise_scale_inv.shape[-1] - scale_inv.shape[1])) - ], - dim=-1) - tensor._rowwise_scale_inv.data.copy_(scale_inv) - - @property - def fp8_quantizer(self): - if 
self._fp8_quantizer is None: - from transformer_engine.pytorch import Float8BlockQuantizer - from transformer_engine_torch import DType as TE_DType - self._fp8_quantizer = Float8BlockQuantizer(TE_DType.kFloat8E4M3, rowwise=True, columnwise=True) - return self._fp8_quantizer - - @staticmethod - def _is_fp8_param(param): - try: - from transformer_engine.pytorch import Float8BlockwiseQTensor - return isinstance(param, Float8BlockwiseQTensor) - except ImportError: - return False - - def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool): - if to_mcore: - if mg_module is None: - return {} - hf_state_dict = {k: v.load() for k, v in self._remove_prefix(hf_state_dict, hf_prefix).items()} - if self._is_peft_format: - new_state_dict = {} - for k, v in hf_state_dict.items(): - k = k.replace('.lora_A.', f'.lora_A.{self._adapter_name}.') - k = k.replace('.lora_B.', f'.lora_B.{self._adapter_name}.') - k = k.replace('.modules_to_save.', f'.modules_to_save.{self._adapter_name}.') - new_state_dict[k] = v - hf_state_dict = new_state_dict - incompatible_keys = mg_module.load_state_dict(hf_state_dict, strict=False) - missing_keys = incompatible_keys.missing_keys - if self._is_peft_format: - missing_keys = [ - k for k in incompatible_keys.missing_keys - if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k - ] - assert len(missing_keys) == 0, f'incompatible_keys.missing_keys: {missing_keys}' - return {} - else: - hf_state_dict = None if mg_module is None else mg_module.state_dict() - if hf_state_dict is not None: - new_state_dict = {} - for k, v in hf_state_dict.items(): - if self._is_peft_format: - if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k: - k = k.replace(f'{self._adapter_name}.', '') - if '.lora_A.' in k: - module_name = k.split('.lora_A.')[0].rsplit('.', 1)[-1] - self._peft_target_modules.add(module_name) - new_state_dict[k] = v - else: - if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' 
in k: - continue - k = k.replace('base_layer.', '') - k = k.replace(f'modules_to_save.{self._adapter_name}.', '') - new_state_dict[k] = v - hf_state_dict = new_state_dict - if self.pp_size > 1: - src_rank = torch.tensor([0 if hf_state_dict is None else self.pp_rank], - dtype=torch.int64, - device='cuda') - dist.all_reduce(src_rank, group=self.pp_group) - src_rank = dist.get_global_rank(self.pp_group, src_rank.item()) - meta_data = [None] if hf_state_dict is None else [list(hf_state_dict.keys())] - dist.broadcast_object_list(meta_data, src=src_rank, group=self.pp_group) - if meta_data[0] is None: - return {} - hf_state_dict = hf_state_dict or {k: None for k in meta_data[0]} - for k, v in hf_state_dict.items(): - v, _ = self._get_weight(v, None) - hf_state_dict[k] = v - elif hf_state_dict is None: - return {} - else: - if self._target_device is not None: - for k, v in hf_state_dict.items(): - hf_state_dict[k] = v.to(self._target_device) - return self._add_prefix(hf_state_dict, hf_prefix) - - def _all_gather_tp(self, tensor, tp_dim, is_expert): - tensor = None if tensor is None else tensor.to('cuda') - tp_size = self.etp_size if is_expert else self.tp_size - tp_group = self.etp_group if is_expert else self.tp_group - if tensor is not None and tp_dim is not None and tp_size > 1: - if tp_dim == 0: - # save memory - tensor_shape = list(tensor.shape) - tensor_shape[0] *= tp_size - output = tensor.new_empty(tensor_shape) - dist.all_gather_into_tensor( - output, - tensor, - group=tp_group, - ) - tensor = output - else: - output = [torch.empty_like(tensor) for _ in range(tp_size)] - dist.all_gather( - output, - tensor, - group=tp_group, - ) - tensor = torch.cat(output, dim=tp_dim) - del output - return tensor - - def _broadcast_ep_pp(self, tensor, is_expert): - pp_group = self.ep_pp_group if is_expert else self.pp_group - pp_size = self.ep_pp_size if is_expert else self.pp_size - pp_rank = self.ep_pp_rank if is_expert else self.pp_rank - # pp/ep - if pp_size > 1: - src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda') - dist.all_reduce(src_rank, group=pp_group) - src_rank = dist.get_global_rank(pp_group, src_rank.item()) - meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') - dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3, torch.uint8: 4} - dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} - if tensor is None: - dist.broadcast(meta_data, src=src_rank, group=pp_group) - assert meta_data[0].item() > 0, f'meta_data: {meta_data}' - shape = meta_data[1:1 + meta_data[0]].tolist() - dtype = dtype_mapping_r[meta_data[-1].item()] - tensor = torch.empty(shape, device='cuda', dtype=dtype) - dist.broadcast(tensor, src=src_rank, group=pp_group) - else: - meta_data[0] = tensor.ndim - meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') - meta_data[-1] = dtype_mapping[tensor.dtype] - dist.broadcast(meta_data, src=src_rank, group=pp_group) - dist.broadcast(tensor, src=src_rank, group=pp_group) - return tensor - - def _get_weight( - self, - mg_weight: Union[torch.Tensor, List[torch.Tensor]], - mg_key: Optional[str], - offset: float = 0, - is_expert: bool = False, - ): - # tp/etp - mg_scale_inv = None - tensor = mg_weight - if tensor is not None: - if not isinstance(tensor, (list, tuple)): - tensor = [tensor] - if self._is_fp8_param(tensor[0]): - mg_scale_inv = [t._rowwise_scale_inv for t in tensor] - tensor = [t._rowwise_data for t in tensor] - del mg_weight - if tensor is not 
None: - assert isinstance(tensor, (list, tuple)), f'mg_key: {mg_key}' - tensor = torch.concat(tensor, dim=0) - if mg_scale_inv is not None: - mg_scale_inv = torch.concat(mg_scale_inv, dim=0) - num_local_experts = self.args.num_experts // self.ep_size if is_expert else 1 - tp_dim = self._get_tp_split_dim(mg_key) - is_linear_fc1 = (mg_key is not None and mg_key.split('.', 1)[0] == 'linear_fc1' and tp_dim is not None) - if tensor is not None and is_linear_fc1: - tensor = tensor.view(num_local_experts * 2, -1, tensor.shape[-1]) - if mg_scale_inv is not None: - mg_scale_inv = mg_scale_inv.view(num_local_experts * 2, -1, mg_scale_inv.shape[-1]) - - tensor = self._all_gather_tp(tensor, tp_dim, is_expert) - tensor = self._broadcast_ep_pp(tensor, is_expert) - if tensor.dtype == torch.uint8: - mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert) - mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert) - tensor = tensor.view(torch.float8_e4m3fn) - mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous() - assert tensor is not None, f'mg_key: {mg_key}' - if offset: - assert mg_scale_inv is None, f'mg_key: {mg_key}' - tensor = tensor + offset - if self._target_device is not None: - tensor = tensor.to(device=self._target_device) - if mg_scale_inv is not None: - mg_scale_inv = mg_scale_inv.to(device=self._target_device) - if self._only_last_rank and not is_last_rank(): - tensor = None - mg_scale_inv = None - if is_expert and tensor is not None: - if mg_key.endswith('bias'): - tensor = tensor.view(num_local_experts, -1) - else: - tensor = tensor.view(num_local_experts, -1, tensor.shape[-1]) - if mg_scale_inv is not None: - mg_scale_inv = mg_scale_inv.view(num_local_experts, -1, mg_scale_inv.shape[-1]) - return tensor, mg_scale_inv - - def _set_state_dict(self, - mg_module, - mg_key: str, - hf_state_dict, - hf_key: str, - to_mcore: bool, - *, - offset: float = 0, - is_expert: bool = False): - module_key, param_key = mg_key.rsplit('.', 1) - if '.' 
in hf_key: - hf_module_key, hf_param_key = hf_key.rsplit('.', 1) - else: - hf_module_key, hf_param_key = hf_key, None - sub_module = deep_getattr(mg_module, module_key) - is_lora = isinstance(sub_module, self.LoraParallelLinear) - is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper) - if not to_mcore: - state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda') - if is_expert and self.ep_pp_size > 1: - dist.all_reduce(state, group=self.ep_pp_group) - elif not is_expert and self.pp_size > 1: - dist.all_reduce(state, group=self.pp_group) - is_lora, is_modules_to_save = state - if is_lora and self._is_peft_format and param_key != 'layer_norm_weight': - if to_mcore: - lora_A_key = f'{module_key}.lora_A.{self._adapter_name}.{param_key}' - lora_B_key = f'{module_key}.lora_B.{self._adapter_name}.{param_key}' - mg_lora_A = deep_getattr(mg_module, f'{lora_A_key}') - mg_lora_B = deep_getattr(mg_module, f'{lora_B_key}') - hf_lora_A = hf_state_dict[f'{hf_module_key}.lora_A.{hf_param_key}'].load() - hf_lora_B = hf_state_dict[f'{hf_module_key}.lora_B.{hf_param_key}'].load() - self._set_weight(mg_lora_A, hf_lora_A, lora_A_key, offset, is_expert) - self._set_weight(mg_lora_B, hf_lora_B, lora_B_key, offset, is_expert) - else: - lora_A_key = f'{module_key}.lora_A.{self._adapter_name}.{param_key}' - lora_B_key = f'{module_key}.lora_B.{self._adapter_name}.{param_key}' - lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') - lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') - hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}' - hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}' - lora_A, _ = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) - lora_B, _ = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert) - if lora_A is not None: - self._peft_target_modules.add(hf_module_key) - hf_state_dict[hf_lora_A_key] = lora_A - hf_state_dict[hf_lora_B_key] = lora_B - elif not self._is_peft_format or is_modules_to_save: - if is_lora: - mg_param = deep_getattr(sub_module, f'base_layer.{param_key}') - else: - mg_param = deep_getattr(sub_module, param_key) - if to_mcore: - assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' - hf_weight = hf_state_dict[hf_key].load() - if module_key in {'embedding.word_embeddings', 'output_layer' - } and hf_weight.shape[0] < self.args.padded_vocab_size: - hf_weight = F.pad(hf_weight, (0, 0, 0, self.args.padded_vocab_size - hf_weight.shape[0])) - hf_scale_inv = None - if f'{hf_key}_scale_inv' in hf_state_dict: - hf_scale_inv = hf_state_dict[f'{hf_key}_scale_inv'].load() - self._set_weight(mg_param, hf_weight, mg_key, offset, is_expert, hf_scale_inv=hf_scale_inv) - else: - if is_modules_to_save: - self._peft_modules_to_save.add(hf_module_key) - weight, scale_inv = self._get_weight(None if mg_param is None else mg_param.data, mg_key, offset, - is_expert) - if weight is not None: - hf_state_dict[hf_key] = weight - if scale_inv is not None: - hf_state_dict[f'{hf_key}_scale_inv'] = scale_inv - - @staticmethod - def _remove_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - - @staticmethod - def _add_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {f'{prefix}{k}': v for k, v in state_dict.items()} - - @staticmethod - def _filter_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {k: v for k, v in state_dict.items() if 
k.startswith(prefix)} - - @staticmethod - def _is_moe(state_dict): - for k, v in state_dict.items(): - if 'experts.' in k: - return True - return False - - def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - hf_attn = self.hf_layers[layer_idx].self_attn - args = self.args - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - hidden_size_block = args.hidden_size // self.fp8_block_size - if to_mcore: - if isinstance(mg_attn.linear_qkv, self.LoraParallelLinear): - lora_A = hf_state_dict['q_proj.lora_A.weight'].load() - assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( - lora_A == hf_state_dict['v_proj.lora_A.weight'].load() - ).all(), 'Need to ensure QKV\'s lora_A are consistent' - q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load() - lora_B = torch.cat([ - q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), - ], - dim=1).reshape((-1, q_lora_B.shape[-1])) - self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A, - 'linear_qkv.lora_A.weight') - self._set_weight(mg_attn.linear_qkv.lora_B[self._adapter_name].weight, lora_B, - 'linear_qkv.lora_B.weight') - elif not self._is_peft_format: - linear_qkv_weight = torch.cat([ - hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size)) - qkv_scale_inv = None - if 'q_proj.weight_scale_inv' in hf_state_dict: - qkv_scale_inv = torch.cat([ - hf_state_dict['q_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['k_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['v_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - ], - dim=1).reshape((-1, hidden_size_block)) - self._set_weight( - mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv) - else: - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - q_block = q_dim // self.fp8_block_size - kv_block = kv_dim // self.fp8_block_size - is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, - self.LoraParallelLinear) and self._is_peft_format - is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') - if self.pp_size > 1: - dist.all_reduce(is_lora, group=self.pp_group) - if is_lora: - lora_A, _ = self._get_weight( - None if mg_attn is None else mg_attn.linear_qkv.lora_A[self._adapter_name].weight.data, - f'linear_qkv.lora_A.{self._adapter_name}.weight') - lora_B, _ = self._get_weight( - None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data, - f'linear_qkv.lora_B.{self._adapter_name}.weight') - if lora_A is not None: - self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) - for key in ['q_proj', 'k_proj', 'v_proj']: - hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() - lora_B = 
lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) - hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() - hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, - q_dim:-kv_dim, :].reshape(-1, - lora_B.shape[-1]).clone() - hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone() - elif not self._is_peft_format: - mg_attn_weight, scale_inv = self._get_weight( - None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') - if mg_attn_weight is not None: - mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) - hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, - q_dim:-kv_dim, :].reshape(-1, - args.hidden_size).clone() - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, - args.hidden_size).clone() - if scale_inv is not None: - scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block)) - hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape( - -1, hidden_size_block).clone() - hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:-kv_block, :].reshape( - -1, hidden_size_block).clone() - hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape( - -1, hidden_size_block).clone() - del mg_attn_weight - self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) - if args.add_bias_linear: - self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, 'o_proj.bias', to_mcore) - - # Copy bias - if (args.add_bias_linear or args.add_qkv_bias) and not self._is_peft_format: - if to_mcore: - linear_qkv_bias = torch.cat([ - hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), - ], - dim=1).reshape(-1) - self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') - else: - mg_attn_bias, _ = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data, - 'linear_qkv.bias') - if mg_attn_bias is not None: - mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) - hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() - if getattr(args, 'softmax_type', 'vanilla') == 'learnable': - self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore) - if args.qk_layernorm: - self._set_qk_layernorm(mg_attn, hf_attn, hf_state_dict, to_mcore) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _set_qk_layernorm(self, mg_attn, hf_attn, hf_state_dict, to_mcore): - hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' - hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' - self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore) - self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore) - - def get_e_score_correction_bias_key(self, hf_mlp): - if hasattr(hf_mlp, 'moe_statics'): - hf_bias_key = 
'moe_statics.e_score_correction_bias' - else: - hf_bias_key = 'gate.e_score_correction_bias' - return hf_bias_key - - def _set_moe_state( - self, - mg_mlp, - hf_state_dict, - hf_prefix: str, - layer_idx: int, - to_mcore: bool, - ): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - args = self.args - hf_mlp = self._get_hf_mlp(layer_idx) - if hasattr(hf_mlp, 'router'): - hf_gate_key = 'router.weight' - elif hasattr(hf_mlp.gate, 'wg'): - hf_gate_key = 'gate.wg.weight' - else: - hf_gate_key = 'gate.weight' - self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) - if args.add_bias_linear: - self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) - if getattr(args, 'moe_router_enable_expert_bias', False): - hf_bias_key = self.get_e_score_correction_bias_key(hf_mlp) - self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, hf_bias_key, to_mcore) - - if getattr(args, 'moe_shared_expert_intermediate_size', False): - for key in ['shared_expert', 'shared_experts', 'shared_mlp']: - if hasattr(hf_mlp, key): - hf_shared_expert_prefix = f'{key}.' - shared_expert = getattr(hf_mlp, key) - hf_state_dict.update( - self._set_mlp_state( - None if mg_mlp is None else mg_mlp.shared_experts, - hf_state_dict, - hf_shared_expert_prefix, - layer_idx, - to_mcore, - hf_mlp=shared_expert)) - if hasattr(hf_mlp, 'shared_expert_gate'): - self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', - to_mcore) - for ep_rank in range(self.ep_size): - mg_experts = None if mg_mlp is None else mg_mlp.experts - expert_available = ep_rank == self.ep_rank - if not expert_available: - if to_mcore: - continue - else: - mg_experts = None - hf_state_dict.update( - self._set_mlp_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank=ep_rank)) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _get_hf_grouped(self): - if self.args.hf_model_type in { - 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', - 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', - 'qwen3_5_moe' - }: - return False, False - return None, None - - def _set_mlp_state(self, - mg_mlp, - hf_state_dict, - hf_prefix: str, - layer_idx: int, - to_mcore: bool, - ep_rank: Optional[int] = None, - hf_mlp=None): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - if hf_mlp is None: - hf_mlp = self._get_hf_mlp(layer_idx) - is_expert = ep_rank is not None - num_local_experts = 1 - hf_grouped = False - args = self.args - if is_expert: - hf_mlp = hf_mlp.experts - # When converting to_mcore, hf_grouped is determined by default from the hf_state_dict condition. - # When converting to_hf, it is determined by default from the hf_mlp condition. 
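As an aside, here is a minimal, self-contained sketch of the key-layout heuristic the comment above describes; the checkpoint keys are hypothetical and only illustrate the two layouts this code distinguishes:

import re

# Per-expert HF layout (hypothetical keys): each expert i ships its own 'i.down_proj.weight'.
per_expert_keys = ['0.down_proj.weight', '0.gate_proj.weight', '1.down_proj.weight']
# Grouped HF layout (hypothetical keys): all experts are packed into single stacked tensors.
grouped_keys = ['gate_up_proj', 'down_proj']

def looks_grouped(keys):
    # Mirrors the to_mcore branch below: grouped iff no per-expert down_proj key matches.
    return not any(re.match(r'\d+\.down_proj', k) for k in keys)

assert looks_grouped(grouped_keys) and not looks_grouped(per_expert_keys)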
- if to_mcore: - pattern = r'\d+\.down_proj' - hf_grouped = not any(re.match(pattern, k) is not None for k in hf_state_dict.keys()) - else: - hf_grouped = not hasattr(hf_mlp, '__len__') - if hasattr(hf_mlp, '__len__'): - hf_mlp = hf_mlp[0] - num_local_experts = args.num_experts // self.ep_size - if to_mcore: - is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys()) - else: - is_gate_up = hasattr(hf_mlp, 'gate_up_proj') - # transformers 5.0 compatibility - if self.is_transformers_5 and not to_mcore and is_expert: - _hf_grouped, _is_gate_up = self._get_hf_grouped() - if _hf_grouped is not None: - hf_grouped = _hf_grouped - if _is_gate_up is not None: - is_gate_up = _is_gate_up - - need_transpose = True - if self.is_transformers_5 and hf_grouped: - need_transpose = self._get_transpose() - - if hf_grouped and not to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - elif not to_mcore: - hf_state_dict = {} - - # linear_fc1 - if to_mcore: - has_scale_inv = any('_scale_inv' in k for k in hf_state_dict.keys()) - if isinstance(mg_mlp.linear_fc1, self.LoraParallelLinear): - mg_lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name] - mg_lora_B = [getattr(mg_lora_B, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_lora_B.weight - if is_gate_up: - if is_expert: - lora_A = torch.stack([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.lora_A.weight'].load() - for i in range(num_local_experts) - ]) - lora_B = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.lora_B.weight'].load() - for i in range(num_local_experts) - ]) - else: - lora_A = hf_state_dict['gate_up_proj.lora_A.weight'].load() - lora_B = hf_state_dict['gate_up_proj.lora_B.weight'].load() - else: - if is_expert: - lora_A = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_proj.lora_A.weight'].load() - for i in range(num_local_experts) - ]) - up_lora_A = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.up_proj.lora_A.weight'].load() - for i in range(num_local_experts) - ]) - weight_list = [] - for i in range(num_local_experts): - gate_lora_B = hf_state_dict[ - f'{i + ep_rank * num_local_experts}.gate_proj.lora_B.weight'].load() - up_lora_B = hf_state_dict[f'{i + ep_rank * num_local_experts}.up_proj.lora_B.weight'].load() - weight_list.append(torch.stack([gate_lora_B, up_lora_B], dim=0)) - lora_B = torch.concat(weight_list, dim=0) - else: - lora_A = hf_state_dict['gate_proj.lora_A.weight'].load() - up_lora_A = hf_state_dict['up_proj.lora_A.weight'].load() - gate_lora_B = hf_state_dict['gate_proj.lora_B.weight'].load() - up_lora_B = hf_state_dict['up_proj.lora_B.weight'].load() - lora_B = torch.stack([gate_lora_B, up_lora_B], dim=0) - assert ( - lora_A == up_lora_A).all(), 'Need to ensure lora_A consistency between gate_proj and up_proj' - mg_lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name] - mg_lora_A = [getattr(mg_lora_A, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_lora_A.weight - self._set_weight( - mg_lora_A, lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert) - self._set_weight( - mg_lora_B, lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert) - elif not self._is_peft_format: - fc1_weight = [getattr(mg_mlp.linear_fc1, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc1.weight - fc1_bias = None - if args.add_bias_linear: - assert is_expert and not has_scale_inv, 'not support' # TODO 
- fc1_bias = [getattr(mg_mlp.linear_fc1, f'bias{i}') for i in range(num_local_experts)] - gate_up_scale_inv = None - if is_gate_up: - if is_expert: - if hf_grouped: - if 'gate_up_proj_blocks' in hf_state_dict: - blocks = hf_state_dict['gate_up_proj_blocks'].load() - scales = hf_state_dict['gate_up_proj_scales'].load() - gate_up_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) - else: - gate_up_proj_weight = hf_state_dict['gate_up_proj'].load() - if need_transpose: - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) - - gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts] - if has_scale_inv: - gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load() - if need_transpose: - gate_up_scale_inv = gate_up_scale_inv.transpose(1, 2) - gate_up_scale_inv = gate_up_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts] - if fc1_bias is not None: - gate_up_proj_bias = hf_state_dict['gate_up_proj_bias'].load() - gate_up_proj_bias = gate_up_proj_bias[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts] - if args.hf_model_type == 'gpt_oss': - gate_proj_weight = gate_up_proj_weight[:, ::2] - up_proj_weight = gate_up_proj_weight[:, 1::2] - gate_proj_bias, up_proj_bias = gate_up_proj_bias[:, ::2], gate_up_proj_bias[:, 1::2] - gate_up_proj_weight = torch.concat([gate_proj_weight, up_proj_weight], dim=1) - gate_up_proj_bias = torch.concat([gate_proj_bias, up_proj_bias], dim=1) - del gate_proj_weight, up_proj_weight, gate_proj_bias, up_proj_bias - else: - gate_up_proj_weight = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight'].load() - for i in range(num_local_experts) - ], - dim=0) - if has_scale_inv: - gate_up_scale_inv = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight_scale_inv']. 
- load() for i in range(num_local_experts) - ], - dim=0) - - gate_up_proj_weight = gate_up_proj_weight.reshape(num_local_experts * 2, -1, - gate_up_proj_weight.shape[-1]) - if has_scale_inv: - gate_up_scale_inv = gate_up_scale_inv.reshape(num_local_experts * 2, -1, - gate_up_scale_inv.shape[-1]) - else: - gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load() - gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1]) - if has_scale_inv: - gate_up_scale_inv = hf_state_dict['gate_up_proj.weight_scale_inv'].load() - gate_up_scale_inv = gate_up_scale_inv.view(2, -1, gate_up_scale_inv.shape[-1]) - else: - if is_expert: - weight_list = [] - start_idx = ep_rank * num_local_experts - for i in range(num_local_experts): - gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load() - up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load() - weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0)) - gate_up_proj_weight = torch.concat(weight_list, dim=0) - if has_scale_inv: - scale_inv_list = [] - for i in range(num_local_experts): - gate_scale_inv = hf_state_dict[f'{start_idx + i}.gate_proj.weight_scale_inv'].load() - up_scale_inv = hf_state_dict[f'{start_idx + i}.up_proj.weight_scale_inv'].load() - scale_inv_list.append(torch.stack([gate_scale_inv, up_scale_inv], dim=0)) - gate_up_scale_inv = torch.concat(scale_inv_list, dim=0) - del weight_list - else: - gate_proj_weight = hf_state_dict['gate_proj.weight'].load() - up_proj_weight = hf_state_dict['up_proj.weight'].load() - gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0) - if has_scale_inv: - gate_scale_inv = hf_state_dict['gate_proj.weight_scale_inv'].load() - up_scale_inv = hf_state_dict['up_proj.weight_scale_inv'].load() - gate_up_scale_inv = torch.stack([gate_scale_inv, up_scale_inv], dim=0) - self._set_weight( - fc1_weight, - gate_up_proj_weight, - 'linear_fc1.weight', - is_expert=is_expert, - hf_scale_inv=gate_up_scale_inv) - if fc1_bias is not None: - self._set_weight( - fc1_bias, gate_up_proj_bias, 'linear_fc1.bias', is_expert=is_expert, hf_scale_inv=None) - else: - is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, - self.LoraParallelLinear) and self._is_peft_format - is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') - if is_expert and self.ep_pp_size > 1: - dist.all_reduce(is_lora, group=self.ep_pp_group) - elif not is_expert and self.pp_size > 1: - dist.all_reduce(is_lora, group=self.pp_group) - if is_lora: - if hf_grouped: - raise ValueError('Since this model\'s transformers and megatron have different expert ' - 'weight organization methods, LoRA weight conversion is not supported. 
' - 'You can solve this issue by setting `--merge_lora true`.') - if mg_mlp is None: - lora_A = None - lora_B = None - else: - if is_expert: - lora_A = [ - getattr(mg_mlp.linear_fc1.lora_A[self._adapter_name], f'weight{i}') - for i in range(num_local_experts) - ] - lora_B = [ - getattr(mg_mlp.linear_fc1.lora_B[self._adapter_name], f'weight{i}') - for i in range(num_local_experts) - ] - else: - lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name].weight - lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name].weight - lora_A, _ = self._get_weight( - lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert) - lora_B, _ = self._get_weight( - lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert) - if lora_A is not None: - if is_gate_up: - self._peft_target_modules.update({'gate_up_proj'}) - if is_expert: - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone() - hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone() - - else: - hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone() - hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone() - else: - self._peft_target_modules.update({'gate_proj', 'up_proj'}) - if is_expert: - lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1]) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone() - hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone() - hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone() - hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone() - else: - lora_B = lora_B.view(2, -1, lora_B.shape[-1]) - hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone() - hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone() - hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone() - hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone() - elif not self._is_peft_format: - fc1_bias = None - if mg_mlp is None: - fc1_weight = None - else: - if is_expert: - linear_fc1 = mg_mlp.linear_fc1 - if isinstance(linear_fc1, self.LoraParallelLinear): - linear_fc1 = linear_fc1.base_layer - fc1_weight = [getattr(linear_fc1, f'weight{i}') for i in range(num_local_experts)] - if args.add_bias_linear: - fc1_bias = [getattr(linear_fc1, f'bias{i}') for i in range(num_local_experts)] - else: - fc1_weight = mg_mlp.linear_fc1.weight - gate_up_proj_weight, scale_inv = self._get_weight(fc1_weight, 'linear_fc1.weight', is_expert=is_expert) - gate_up_proj_bias = None - if args.add_bias_linear: - gate_up_proj_bias, _ = self._get_weight(fc1_bias, 'linear_fc1.bias', is_expert=is_expert) - del fc1_weight - if gate_up_proj_weight is not None: - if is_gate_up: - if is_expert: - if hf_grouped: - if need_transpose: - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) - if 'gate_up_proj' in hf_state_dict: - gate_up_proj_weight = torch.concat( - [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) - is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_experts - if args.hf_model_type == 'gpt_oss' and is_last_ckpt: - gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2) - new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight) - new_gate_up_proj_weight[..., ::2] = gate_proj_weight - new_gate_up_proj_weight[..., 1::2] = up_proj_weight - gate_up_proj_weight = new_gate_up_proj_weight - del 
new_gate_up_proj_weight, gate_proj_weight, up_proj_weight - hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone() - if scale_inv is not None: - if need_transpose: - scale_inv = scale_inv.transpose(1, 2) - if 'gate_up_proj_scale_inv' in hf_state_dict: - scale_inv = torch.concat([hf_state_dict['gate_up_proj_scale_inv'], scale_inv], - dim=0) - hf_state_dict['gate_up_proj_scale_inv'] = scale_inv.clone() - - if gate_up_proj_bias is not None: - if 'gate_up_proj_bias' in hf_state_dict: - gate_up_proj_bias = torch.concat( - [hf_state_dict['gate_up_proj_bias'], gate_up_proj_bias], dim=0) - if args.hf_model_type == 'gpt_oss' and is_last_ckpt: - gate_proj_bias, up_proj_bias = gate_up_proj_bias.chunk(2, dim=1) - new_gate_up_proj_bias = torch.empty_like(gate_up_proj_bias) - new_gate_up_proj_bias[:, ::2] = gate_proj_bias - new_gate_up_proj_bias[:, 1::2] = up_proj_bias - gate_up_proj_bias = new_gate_up_proj_bias - del new_gate_up_proj_bias, gate_proj_bias, up_proj_bias - hf_state_dict['gate_up_proj_bias'] = gate_up_proj_bias.clone() - else: - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone() - if scale_inv is not None: - hf_state_dict[f'{hf_i}.gate_up_proj.weight_scale_inv'] = scale_inv[i].clone() - del gate_up_proj_weight - else: - gate_up_proj_weight = gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]) - hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.clone() - if scale_inv is not None: - scale_inv = scale_inv.view(-1, scale_inv.shape[-1]) - hf_state_dict['gate_up_proj.weight_scale_inv'] = scale_inv.clone() - else: - if is_expert: - gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1, - gate_up_proj_weight.shape[-1]) - if scale_inv is not None: - scale_inv = scale_inv.view(num_local_experts, 2, -1, scale_inv.shape[-1]) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone() - hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone() - if scale_inv is not None: - hf_state_dict[f'{hf_i}.gate_proj.weight_scale_inv'] = scale_inv[i][0].clone() - hf_state_dict[f'{hf_i}.up_proj.weight_scale_inv'] = scale_inv[i][1].clone() - del gate_up_proj_weight - else: - gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1]) - hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone() - hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone() - if scale_inv is not None: - scale_inv = scale_inv.view(2, -1, scale_inv.shape[-1]) - hf_state_dict['gate_proj.weight_scale_inv'] = scale_inv[0].clone() - hf_state_dict['up_proj.weight_scale_inv'] = scale_inv[1].clone() - - # linear_fc2 - if is_expert: - if to_mcore: - if isinstance(mg_mlp.linear_fc2, self.LoraParallelLinear): - mg_lora_A = mg_mlp.linear_fc2.lora_A[self._adapter_name] - mg_lora_A = [getattr(mg_lora_A, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_lora_A.weight - mg_lora_B = mg_mlp.linear_fc2.lora_B[self._adapter_name] - mg_lora_B = [getattr(mg_lora_B, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_lora_B.weight - lora_A = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.lora_A.weight'].load() - for i in range(num_local_experts) - ], - dim=0) - lora_B = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.lora_B.weight'].load() - for i in 
range(num_local_experts) - ], - dim=0) - self._set_weight( - mg_lora_A, lora_A, f'linear_fc2.lora_A.{self._adapter_name}.weight', is_expert=is_expert) - self._set_weight( - mg_lora_B, lora_B, f'linear_fc2.lora_B.{self._adapter_name}.weight', is_expert=is_expert) - elif not self._is_peft_format: - fc2_weight = [getattr(mg_mlp.linear_fc2, f'weight{i}') - for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc2.weight - fc2_bias = None - if args.add_bias_linear: - fc2_bias = [getattr(mg_mlp.linear_fc2, f'bias{i}') for i in range(num_local_experts)] - down_scale_inv = None - if hf_grouped: - if 'down_proj_blocks' in hf_state_dict: - blocks = hf_state_dict['down_proj_blocks'].load() - scales = hf_state_dict['down_proj_scales'].load() - down_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) - else: - down_proj_weight = hf_state_dict['down_proj'].load() - if need_transpose: - down_proj_weight = down_proj_weight.transpose(1, 2) - down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts].reshape( - -1, down_proj_weight.shape[-1]) - if has_scale_inv: - down_scale_inv = hf_state_dict['down_proj_scale_inv'].load() - if need_transpose: - down_scale_inv = down_scale_inv.transpose(1, 2) - down_scale_inv = down_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts].reshape(-1, down_scale_inv.shape[-1]) - if fc2_bias is not None: - down_proj_bias = hf_state_dict['down_proj_bias'].load() - down_proj_bias = down_proj_bias[ep_rank * num_local_experts:(ep_rank + 1) - * num_local_experts] - else: - down_proj_weight = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() - for i in range(num_local_experts) - ], - dim=0) - if has_scale_inv: - down_scale_inv = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight_scale_inv'].load() - for i in range(num_local_experts) - ], - dim=0) - self._set_weight( - fc2_weight, - down_proj_weight, - 'linear_fc2.weight', - is_expert=is_expert, - hf_scale_inv=down_scale_inv) - if fc2_bias is not None: - self._set_weight( - fc2_bias, down_proj_bias, 'linear_fc2.bias', is_expert=is_expert, hf_scale_inv=None) - else: - is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2, - self.LoraParallelLinear) and self._is_peft_format - is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') - if is_expert and self.ep_pp_size > 1: - dist.all_reduce(is_lora, group=self.ep_pp_group) - elif not is_expert and self.pp_size > 1: - dist.all_reduce(is_lora, group=self.pp_group) - if is_lora: - if hf_grouped: - raise ValueError('Since this model\'s transformers and megatron have different expert ' - 'weight organization methods, LoRA weight conversion is not supported. 
' - 'You can solve this issue by setting `--merge_lora true`.') - if mg_mlp is None: - lora_A = None - lora_B = None - else: - lora_A = [ - getattr(mg_mlp.linear_fc2.lora_A[self._adapter_name], f'weight{i}') - for i in range(num_local_experts) - ] - lora_B = [ - getattr(mg_mlp.linear_fc2.lora_B[self._adapter_name], f'weight{i}') - for i in range(num_local_experts) - ] - lora_A, _ = self._get_weight( - lora_A, f'linear_fc2.lora_A.{self._adapter_name}.weight', is_expert=is_expert) - lora_B, _ = self._get_weight( - lora_B, f'linear_fc2.lora_B.{self._adapter_name}.weight', is_expert=is_expert) - if lora_A is not None: - self._peft_target_modules.update({'down_proj'}) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone() - hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone() - elif not self._is_peft_format: - fc2_bias = None - if mg_mlp is None: - fc2_weight = None - else: - linear_fc2 = mg_mlp.linear_fc2 - if isinstance(linear_fc2, self.LoraParallelLinear): - linear_fc2 = linear_fc2.base_layer - fc2_weight = [getattr(linear_fc2, f'weight{i}') for i in range(num_local_experts)] - if args.add_bias_linear: - fc2_bias = [getattr(linear_fc2, f'bias{i}') for i in range(num_local_experts)] - down_proj_weight, scale_inv = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert) - if args.add_bias_linear: - down_proj_bias, _ = self._get_weight(fc2_bias, 'linear_fc2.bias', is_expert=is_expert) - del fc2_weight, fc2_bias - if down_proj_weight is not None: - if hf_grouped: - if need_transpose: - down_proj_weight = down_proj_weight.transpose(1, 2) - if 'down_proj' in hf_state_dict: - down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0) - hf_state_dict['down_proj'] = down_proj_weight.clone() - if scale_inv is not None: - if need_transpose: - scale_inv = scale_inv.transpose(1, 2) - if 'down_proj_scale_inv' in hf_state_dict: - scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0) - hf_state_dict['down_proj_scale_inv'] = scale_inv.clone() - if args.add_bias_linear: - if 'down_proj_bias' in hf_state_dict: - down_proj_bias = torch.concat([hf_state_dict['down_proj_bias'], down_proj_bias], - dim=0) - hf_state_dict['down_proj_bias'] = down_proj_bias.clone() - else: - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone() - if scale_inv is not None: - hf_state_dict[f'{hf_i}.down_proj.weight_scale_inv'] = scale_inv[i].clone() - else: - self._set_state_dict( - mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _set_mla_attn_state( - self, - mg_attn, - hf_state_dict, - hf_prefix: str, - layer_idx: int, - to_mcore: bool, - ): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) - if self.args.q_lora_rank is None: - self._set_state_dict(mg_attn, 'linear_q_proj.weight', hf_state_dict, 'q_proj.weight', to_mcore) - else: - self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'q_a_proj.weight', to_mcore) - self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'q_b_proj.weight', 
to_mcore) - self._set_state_dict(mg_attn, 'linear_kv_down_proj.weight', hf_state_dict, 'kv_a_proj_with_mqa.weight', - to_mcore) - self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore) - if self.args.qk_layernorm: - if self.args.q_lora_rank is not None: - self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict, - 'q_a_layernorm.weight', to_mcore) - self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight', - to_mcore) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): - mg_attn = None if mg_layer is None else mg_layer.self_attention - if self.args.multi_latent_attention: - hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) - else: - hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict, - 'input_layernorm.weight', to_mcore) - return hf_state_dict - - def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): - hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx) - hf_mlp = self._get_hf_mlp(layer_idx) - is_moe = self._is_moe(hf_mlp.state_dict()) - mg_mlp = None if mg_layer is None else mg_layer.mlp - if is_moe: - hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', - to_mcore) - else: - hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore)) - if self.args.hf_model_type == 'qwen3_5': - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) - else: - self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) - return hf_state_dict - - def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): - hf_prefix = f'{hf_prefix}{layer_idx}.' 
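For intuition, a small sketch of the _remove_prefix/_add_prefix round trip that this per-layer conversion relies on; the keys and prefix are hypothetical, chosen to resemble the hf_layers_prefix convention:

# Scope a full state dict down to one decoder layer, then restore the prefix on the way out.
full_sd = {'model.layers.3.input_layernorm.weight': 'w', 'model.norm.weight': 'n'}
prefix = 'model.layers.3.'
scoped = {k[len(prefix):]: v for k, v in full_sd.items() if k.startswith(prefix)}  # _remove_prefix
assert scoped == {'input_layernorm.weight': 'w'}
restored = {f'{prefix}{k}': v for k, v in scoped.items()}  # _add_prefix
assert 'model.layers.3.input_layernorm.weight' in restored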
- if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) - if self.args.is_multimodal: - for prefix, mg_prefix in self.module_mapping.items(): - mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') - hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore)) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - else: - hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model - if self.args.untie_embeddings_and_output_weights: - if not to_mcore or self.args.task_type == 'causal_lm': - hf_lm_head_key = self.hf_lm_head_key - if not to_mcore and self.args.task_type == 'seq_cls': - hf_lm_head_key = self.hf_score_key - self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore) - elif to_mcore and lm_model.output_layer.weight is not None: - self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore) - self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key, - to_mcore) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - return hf_state_dict - - def _convert_hf_state_dict(self, hf_state_dict, to_mcore): - res = {} - for k, v in hf_state_dict.items(): - for old_key, new_key in self.hf_state_dict_mapping.items(): - if not to_mcore: - old_key, new_key = new_key, old_key - if k.startswith(old_key): - k = k.replace(old_key, new_key) - break - res[k] = v - return res - - def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqdm_desc: str = 'Converting: '): - if to_mcore: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) - else: - hf_state_dict = {} - mg_models = iter(mg_models) - mg_model = next(mg_models) - if self.mcore_013: - is_pp_first_stage = self.mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) - is_pp_last_stage = self.mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) - else: - is_pp_first_stage = self.mpu.is_pipeline_first_stage() - is_pp_last_stage = self.mpu.is_pipeline_last_stage() - if not to_mcore or is_pp_first_stage: - hf_state_dict.update(self._convert_pre_process(mg_model, hf_state_dict, '', to_mcore)) - if to_mcore: - yield - else: - yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - hf_state_dict = {} - layer_idx = 0 - prog_bar = 
tqdm(range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd) - while layer_idx < self.args.num_layers: - lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model - if len(lm_model.decoder.layers) > 0: - start_idx = lm_model.decoder.layers[0].layer_number - 1 - mg_layer_available = (start_idx <= layer_idx < lm_model.decoder.layers[-1].layer_number) - else: - mg_layer_available = False - if mg_layer_available: - mg_layer = lm_model.decoder.layers[layer_idx - start_idx] - else: - if to_mcore: - layer_idx += 1 - prog_bar.update() - continue - else: - mg_layer = None - if not to_mcore and self.pp_size > 1: - has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda') - dist.all_reduce(has_model, group=self.pp_group) - if not has_model: - mg_model = next(mg_models) # compat vpp - continue - res = self._set_layer_state(mg_layer, hf_state_dict, f'{self.hf_layers_prefix}.', layer_idx, to_mcore) - layer_idx += 1 - prog_bar.update() - if to_mcore: - yield - else: - yield from list(self._add_prefix(res, hf_prefix).items()) - hf_state_dict = {} - - if (not to_mcore or is_pp_last_stage) and self.args.mtp_num_layers: - lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model - if to_mcore and self.pp_rank > 0: - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, - to_mcore) - layer_idx = 0 - while layer_idx < self.args.mtp_num_layers: - res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) - layer_idx += 1 - if to_mcore: - yield - else: - yield from list(self._add_prefix(res, hf_prefix).items()) - hf_state_dict = {} - if not to_mcore or is_pp_last_stage: - hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) - if to_mcore: - yield - else: - hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) - yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - - def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): - for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: - self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) - self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) - - def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): - mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None - if self.hf_mtp_prefix == self.hf_layers_prefix: - hf_layer_idx = layer_idx + self.args.num_layers - else: - hf_layer_idx = layer_idx - hf_prefix = f'{hf_prefix}{hf_layer_idx}.' 
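A quick worked example of the index mapping above, assuming (hypothetically) a DeepSeek-V3-style checkpoint where the MTP layers share the decoder's layers prefix and therefore land after the regular stack:

num_layers, mtp_layer_idx = 61, 0                   # hypothetical sizes
hf_layers_prefix = hf_mtp_prefix = 'model.layers'   # the shared-prefix case
hf_layer_idx = mtp_layer_idx + num_layers           # MTP weights live after the decoder stack
assert f'{hf_mtp_prefix}.{hf_layer_idx}.' == 'model.layers.61.'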
- if to_mcore: - origin_hf_state_dict = hf_state_dict - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - if len(hf_state_dict) == 0: - logger.info_if( - f'MTP Layer {mtp_layer.layer_number} safetensors weights not found, ' - 'this part will be randomly initialized.', - cond=is_last_rank()) - for param in mtp_layer.parameters(): - if param.ndim == 2: - mtp_layer.config.init_method(param.data) - return {} - else: - origin_hf_state_dict = {} - hf_state_dict = {} - self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) - transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer - if not to_mcore and not self.args.hf_model_type.startswith(('qwen3_next', 'qwen3_5')): - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', - to_mcore) - if self.args.untie_embeddings_and_output_weights: - self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', - to_mcore) - hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore)) - if to_mcore: - hf_state_dict = {} - else: - hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) - hf_state_dict.update(origin_hf_state_dict) - return hf_state_dict - - def load_weights(self, - mg_model, - hf_model_dir: str, - is_peft_format: bool = False, - adapter_name: str = 'default', - lora_converter=None): - self._is_peft_format = is_peft_format - self._adapter_name = adapter_name - hf_model_dir = HubOperation.download_model(hf_model_dir) - with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: - state_dict = loader.get_state_dict() - _state_dict = {} - for key, value in state_dict.items(): - if lora_converter is not None: - key, value = lora_converter(key, value) - _state_dict[key] = value - state_dict = _state_dict - hf_prefix = 'base_model.model.' if is_peft_format else '' - list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: ')) - - def export_weights(self, - mg_models, - target_device=None, - only_last_rank: bool = False, - is_peft_format: bool = False, - adapter_name: str = 'default', - tqdm_desc: str = 'Exporting: '): - self._target_device = target_device - self._only_last_rank = only_last_rank - self._is_peft_format = is_peft_format - self._adapter_name = adapter_name - self._peft_target_modules = set() - self._peft_modules_to_save = set() - hf_prefix = 'base_model.model.' 
if is_peft_format else '' - with torch.no_grad(): - yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc) - - def save_weights(self, - mg_models, - output_dir: str, - is_peft_format: bool = False, - adapter_name: str = 'default', - lora_converter: Callable = None) -> None: - """Save the mg_model checkpoint in HF format""" - torch.cuda.empty_cache() - saver = StreamingSafetensorSaver( - save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) - for k, v in self.export_weights( - mg_models, - target_device='cpu', - only_last_rank=True, - is_peft_format=is_peft_format, - adapter_name=adapter_name, - tqdm_desc='Saving: '): - if lora_converter is not None: - k, v = lora_converter(k, v, adapter_name) - if k is not None and v is not None: - saver.add_tensor(k, v) - saver.finalize() - args = self.args - if is_last_rank(): - if is_peft_format: - peft_config = copy(mg_models[0].peft_config[self._adapter_name]) - if args.task_type == 'seq_cls': - peft_config.task_type = 'SEQ_CLS' - peft_config.target_modules = self._peft_target_modules - peft_config.modules_to_save = self._peft_modules_to_save - peft_config.save_pretrained(output_dir) - else: - if args.mtp_num_layers: - self.hf_model.config.num_nextn_predict_layers = args.mtp_num_layers - self.hf_model.config.vocab_size = args.padded_vocab_size - if args.fp8 is not None and args.fp8_recipe == 'blockwise' and args.fp8_param_gather: - if getattr(self.hf_model.config, 'quantization_config', None) is None: - from transformers.utils.quantization_config import FineGrainedFP8Config - modules_to_not_convert = get_modules_to_not_convert(self.hf_model) - self.hf_model.config.quantization_config = FineGrainedFP8Config( - modules_to_not_convert=modules_to_not_convert) - elif hasattr(self.hf_model.config, 'quantization_config'): - del self.hf_model.config.quantization_config - self.hf_model.config.save_pretrained(output_dir) - if getattr(self.hf_model, '_auto_class') is not None: - try: - custom_object_save(self.hf_model, output_dir, config=self.hf_model.config) - except FileNotFoundError as e: - logger.error(f'custom_object_save Error: {e}') - GPTBridge.save_checkpoint( - None, - self.processor, - output_dir, - model_dirs=[args.model_dir], - ) - logger.info_if(f'Successfully saved `safetensors` model weights in `{output_dir}`.', cond=is_last_rank()) - dist.barrier() # Ensure all weights are saved completely - - @staticmethod - def save_checkpoint(model: Optional[PreTrainedModel], - processor, - output_dir: str, - *, - safe_serialization: bool = True, - max_shard_size: Union[int, str] = '5GB', - model_dirs: List[str] = None, - additional_saved_files: Optional[List[str]] = None) -> None: - if model is not None: - if model.__class__.__name__ != 'SentenceTransformer': - model.save_pretrained(output_dir, safe_serialization=safe_serialization, max_shard_size=max_shard_size) - else: - model.save_pretrained(output_dir, safe_serialization=safe_serialization) - # copy sentencetransformers files - from twinkle.utils import copy_files_by_pattern - copy_files_by_pattern(model.model_dir, output_dir, '*.py') - copy_files_by_pattern(model.model_dir, output_dir, '*.json') - processor.save_pretrained(output_dir) - - if model_dirs is None: - model_dirs = [] - else: - model_dirs = model_dirs.copy() - if model and model.model_dir and model.model_dir not in model_dirs: - model_dirs.append(model.model_dir) - for src_file in (additional_saved_files or []) + ['preprocessor_config.json', 'args.json']: - tgt_path = 
os.path.join(output_dir, src_file) - if os.path.exists(tgt_path) and src_file == 'args.json': - continue - for model_dir in model_dirs: - src_path: str = os.path.join(model_dir, src_file) - if os.path.isfile(src_path): - shutil.copy(src_path, tgt_path) - break - elif os.path.isdir(src_path): - shutil.copytree(src_path, tgt_path) - break - - -class MultimodalGPTBridge(GPTBridge): - hf_layers_prefix = 'model.language_model.layers' - hf_embed_key = 'model.language_model.embed_tokens.weight' - hf_final_layernorm_key = 'model.language_model.norm.weight' diff --git a/src/twinkle/model/megatron/model/gpt_model.py b/src/twinkle/model/megatron/model/gpt_model.py deleted file mode 100644 index 85e3f251..00000000 --- a/src/twinkle/model/megatron/model/gpt_model.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import megatron.core -import torch -from collections import OrderedDict -from copy import deepcopy -from megatron.core import mpu -from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.extensions.transformer_engine import TELinear -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.gpt import GPTModel as McoreGPTModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import WrappedTensor, deprecate_inference_params -from packaging import version -from typing import Any, Dict, Literal, Optional, Tuple - -from twinkle import get_logger -from twinkle.model.megatron.args import get_args -from twinkle.model.megatron.utils import split_cp_inputs -from .rope import dynamic_rope_update, get_rope_inv_freq - -logger = get_logger() - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - - -class OutputLayerLinear(TELinear): - - def forward(self, hidden_states, *args, **kwargs): - args = get_args() - if args.sequence_parallel and args.tensor_model_parallel_size > 1: - hidden_states = gather_from_sequence_parallel_region(hidden_states) - return super().forward(hidden_states) - - def sharded_state_dict( - self, - prefix: str = '', - sharded_offsets: Tuple[Tuple[int, int, int]] = (), - metadata: Optional[dict] = None, - ) -> ShardedStateDict: - res = super().sharded_state_dict(prefix, sharded_offsets, metadata) - for k, v in res.items(): - if k.endswith('._extra_state'): - if v.data is not None and v.data.numel() == 0: - v.data = None - return res - - -class GPTModel(McoreGPTModel): - - def __init__( - self, - config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'learned_absolute', - rotary_percent: float = 1.0, - rotary_base: int = 10000, - hf_rope_scaling: Dict[str, Any] = None, - 
rope_scaling: bool = False, - rope_scaling_factor: float = 8.0, - scatter_embedding_sequence_parallel: bool = True, - seq_len_interpolation_factor: Optional[float] = None, - mtp_block_spec: Optional[ModuleSpec] = None, - vp_stage: Optional[int] = None, - ): - if config.multi_latent_attention and config.rope_type == 'yarn': - config.rope_type = 'rope' # use transformers implementation - if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn': - # softmax_scale - config.mscale = hf_rope_scaling['mscale'] - config.mscale_all_dim = hf_rope_scaling['mscale_all_dim'] - config.rotary_scaling_factor = hf_rope_scaling['factor'] - self.hf_rope_scaling = hf_rope_scaling - if mcore_013: - kwargs = {'vp_stage': vp_stage} - else: - self.vp_stage = vp_stage - assert vp_stage is None, 'megatron-core==0.12 does not support vp_stage' - kwargs = {} - super().__init__( - config, - transformer_layer_spec, - vocab_size, - max_sequence_length, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=fp16_lm_cross_entropy, - parallel_output=parallel_output, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type=position_embedding_type, - rotary_percent=rotary_percent, - rotary_base=rotary_base, - rope_scaling=rope_scaling, - rope_scaling_factor=rope_scaling_factor, - scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, - seq_len_interpolation_factor=seq_len_interpolation_factor, - mtp_block_spec=mtp_block_spec, - **kwargs, - ) - if config.multi_latent_attention: - self.rotary_pos_emb = RotaryEmbedding( - kv_channels=config.qk_pos_emb_head_dim, - rotary_percent=rotary_percent, - rotary_interleaved=config.rotary_interleaved, - seq_len_interpolation_factor=seq_len_interpolation_factor, - rotary_base=rotary_base, - rope_scaling=rope_scaling, - rope_scaling_factor=rope_scaling_factor, - use_cpu_initialization=config.use_cpu_initialization, - ) - # save memory - for i in range(len(self.decoder.layers)): - if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): - del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. - new_inv_freq, self.attention_scaling = get_rope_inv_freq() - self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - # remove seq_cls here - - if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion: - config.apply_rope_fusion = False - if self.attention_scaling != 1: - warning_string = 'attention_scaling' - else: - warning_string = 'mrope' - logger.warning(f'`apply_rope_fusion` does not support `{warning_string}`. 
' - f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') - if self.attention_scaling != 1: - self._patch_apply_rotary_pos_emb() - if getattr(self, 'mtp', None) is not None: - for layer in self.mtp.layers: - attention = layer.transformer_layer.self_attention - attention.config = deepcopy(attention.config) - attention.config.apply_rope_fusion = False - - def _patch_apply_rotary_pos_emb(self): - from megatron.core.transformer import attention - origin_apply_rotary_pos_emb = attention.apply_rotary_pos_emb - - def apply_rotary_pos_emb(*args, **kwargs): - kwargs['mscale'] = self.attention_scaling - return origin_apply_rotary_pos_emb(*args, **kwargs) - - attention.apply_rotary_pos_emb = apply_rotary_pos_emb - attention.origin_apply_rotary_pos_emb = origin_apply_rotary_pos_emb - - def _preprocess( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - decoder_input: torch.Tensor = None, - inference_context: BaseInferenceContext = None, - packed_seq_params: PackedSeqParams = None, - ): - """Preprocesses inputs for the transformer decoder. - - Applies embeddings to input tokens, or uses `decoder_input` from a previous - pipeline stage. Also sets up rotary positional embeddings. - """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - in_inference_mode = inference_context is not None and not self.training - - # Decoder embedding. - if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: - # fix LoRA incompatibility with gradient checkpointing - decoder_input = decoder_input.requires_grad_(True) - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - rotary_pos_cos = None - rotary_pos_sin = None - if self.position_embedding_type in {'rope', 'mrope'}: - if not self.training and self.config.flash_decode and inference_context: - assert (inference_context.is_static_batching() - ), 'GPTModel currently only supports static inference batching.' - # Flash decoding uses precomputed cos and sin for RoPE - rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( - inference_context.max_sequence_length, - self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), - ) - else: - rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, - decoder_input, self.config, packed_seq_params) - if self.hf_rope_scaling is not None: - attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) - if attention_scaling is not None and attention_scaling != self.attention_scaling: - raise ValueError('Currently does not support changing attention_scaling during training. 
'
-                                         f'args.attention_scaling: {self.attention_scaling}, '
-                                         f'current_attention_scaling: {attention_scaling}.')
-                packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
-                if self.position_embedding_type == 'mrope':
-                    mrope_position_ids = position_ids
-                    if mrope_position_ids.dim() == 2:
-                        mrope_position_ids = mrope_position_ids.unsqueeze(0).expand(3, -1, -1)
-                    rotary_pos_emb = self.rotary_pos_emb(
-                        mrope_position_ids,
-                        mrope_section=self.mrope_section,
-                    )
-                else:
-                    rotary_pos_emb = self.rotary_pos_emb(
-                        rotary_seq_len,
-                        packed_seq=packed_seq,
-                    )
-                    if packed_seq and not self.config.apply_rope_fusion:
-                        assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
-                        rotary_pos_emb = rotary_pos_emb[position_ids[0]]
-
-        if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
-                                   or self.config.flash_decode) and rotary_pos_cos is not None
-                and inference_context.is_static_batching()):
-            current_batch_size = input_ids.shape[0]
-            sequence_len_offset = torch.tensor(
-                [inference_context.sequence_len_offset] * current_batch_size,
-                dtype=torch.int32,
-                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
-            )
-        else:
-            sequence_len_offset = None
-
-        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
-        # reference held by this caller function, enabling early garbage collection for
-        # inference. Skip wrapping if decoder_input is logged after decoder completion.
-        if in_inference_mode and not has_config_logger_enabled(self.config):
-            decoder_input = WrappedTensor(decoder_input)
-
-        return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
-
-    # Code borrowed from NVIDIA/Megatron-LM
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        attention_mask: torch.Tensor = None,
-        decoder_input: torch.Tensor = None,
-        labels: torch.Tensor = None,
-        inference_context: BaseInferenceContext = None,
-        packed_seq_params: PackedSeqParams = None,
-        extra_block_kwargs: dict = None,
-        runtime_gather_output: Optional[bool] = None,
-        *,
-        inference_params: Optional[BaseInferenceContext] = None,
-        loss_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        """Forward function of the GPT model. This function passes the input tensors
-        through the embedding layer, then the decoder, and finally into the
-        post-processing layer (optional).
-
-        It either returns the loss values if labels are given, or the final hidden units.
-
-        Args:
-            runtime_gather_output (bool): Gather output at runtime. Default None means
-                `parallel_output` arg in the constructor will be used.
-        """
-
-        inference_context = deprecate_inference_params(inference_context, inference_params)
-
-        decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
-            self._preprocess(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                decoder_input=decoder_input,
-                inference_context=inference_context,
-                packed_seq_params=packed_seq_params,
-            ))
-        # Run decoder.
- hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - **(extra_block_kwargs or {}), - **kwargs, - ) - - # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 - return self._postprocess( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - ) - - def _postprocess( - self, - hidden_states, - input_ids, - position_ids, - labels, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - loss_mask=None, - decoder_input=None, - attention_mask=None, - inference_params=None, - packed_seq_params=None, - sequence_len_offset=None, - runtime_gather_output=None, - extra_block_kwargs=None, - inference_context=None, - ): - """Postprocesses decoder hidden states to generate logits or compute loss. - - Applies Multi-Token Prediction if enabled, generates output logits through - the output layer, and computes language model loss when labels are provided. - """ - if not self.post_process: - return hidden_states - in_inference_mode = inference_context is not None and not self.training - if in_inference_mode: - assert runtime_gather_output, 'Inference must always gather TP logits' - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - - if self.mtp_process: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - embedding=self.embedding, - **(extra_block_kwargs or {}), - ) - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - - if labels is not None: - mtp_labels = labels.clone() - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - if packed_seq_params is None: - loss_mask = torch.ones_like(mtp_labels) - else: - loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1])) - cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None - for mtp_layer_number in range(self.config.mtp_num_layers): - # output - mtp_logits, _ = self.output_layer( - hidden_states_list[mtp_layer_number + 1], - weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) - # Calc loss for the current Multi-Token Prediction (MTP) layers. 
- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) - if cu_seqlens is None: - loss_mask_, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group) - else: - loss_mask[:, cu_seqlens[:-1]] = 0 - loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1) - if mpu.get_context_parallel_world_size() > 1: - loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1) - else: - loss_mask_ = loss_mask.clone() - mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) - mtp_loss = loss_mask_ * mtp_loss - num_tokens = loss_mask_.sum() - if self.training: - # after moving loss logging to loss_func in pretrain_gpt.py - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - mtp_layer_number, - self.config.mtp_num_layers, - avg_group=mpu.get_data_parallel_group(with_context_parallel=True), - ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) - else: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) - sequence_parallel_override = False - if in_inference_mode and inference_context.materialize_only_last_token_logits: - if inference_context.is_static_batching(): - hidden_states = hidden_states[-1:, :, :] - else: - if self.output_layer.sequence_parallel: - # Perform the sequence parallel gather here instead of after the output layer - # because we need to slice the last token logits from the full view of the - # packed logits across all requests. - # TODO(ksanthanam): Make the equivalent change in the `MambaModel` code after - # merging in !3722. - hidden_states = gather_from_sequence_parallel_region(hidden_states, group=self.pg_collection.tp) - self.output_layer.sequence_parallel = False - sequence_parallel_override = True - - # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden - # state ([B, H]) → unsqueeze back to [1, B, H] - # (so that the output layer, which expects S×B×H, receives only the final token) - hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) - - logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) - - # Restore sequence parallel execution to the output layer if necessary. - if sequence_parallel_override: - assert (in_inference_mode and inference_context.is_dynamic_batching() - and inference_context.materialize_only_last_token_logits) - self.output_layer.sequence_parallel = True - - if has_config_logger_enabled(self.config): - payload = OrderedDict({ - 'input_ids': input_ids, - 'position_ids': position_ids, - 'attention_mask': attention_mask, - 'decoder_input': decoder_input, - 'logits': logits, - }) - log_config_to_disk(self.config, payload, prefix='input_and_logits') - - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss - - def get_input_tensor(self): - return self.decoder.input_tensor diff --git a/src/twinkle/model/megatron/model/gpts/__init__.py b/src/twinkle/model/megatron/model/gpts/__init__.py deleted file mode 100644 index 6c11171b..00000000 --- a/src/twinkle/model/megatron/model/gpts/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from ..constant import MegatronModelType, ModelType -from ..register import MegatronModelMeta, register_megatron_model - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.gpt, - [ - ModelType.qwen2, - ModelType.qwen3, - ModelType.qwen2_moe, - ModelType.qwen3_moe, - ], - )) diff --git a/src/twinkle/model/megatron/model/gpts/qwen3_next.py b/src/twinkle/model/megatron/model/gpts/qwen3_next.py deleted file mode 100644 index 7ae0c943..00000000 --- a/src/twinkle/model/megatron/model/gpts/qwen3_next.py +++ /dev/null @@ -1,512 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -# Reference: swift/swift/megatron/model/gpts/qwen3_next.py -# Qwen3-Next / Qwen3.5 series model support for Megatron - -import megatron.core -import torch -from copy import deepcopy -from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, _get_extra_te_kwargs -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec -from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region, - reduce_scatter_to_sequence_parallel_region) -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import deprecate_inference_params, is_fa_min_version -from packaging import version -from typing import List, Optional, Tuple, Union - -from twinkle import get_logger -from twinkle.model.megatron.args import get_args -from twinkle.model.megatron.model.register import MegatronModelLoader - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') -mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') -try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward - from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache - HAVE_FA3 = True -except Exception: - HAVE_FA3 = False - -try: - from einops import rearrange -except ImportError: - rearrange = None - -try: - import transformer_engine # pylint: disable=unused-import - HAVE_TE = True - from megatron.core.extensions.transformer_engine import SplitAlongDim -except ImportError: - HAVE_TE = False - SplitAlongDim = None - -logger = get_logger() - - -class Qwen3NextRMSNorm(torch.nn.Module): - """ - Zero-Centered RMSNorm for Qwen3-Next/Qwen3.5. - Uses (1 + weight) scaling to match HuggingFace implementation exactly. - This eliminates the need for +1/-1 offset during weight conversion. 
-    """
-
-    def __init__(self, config, hidden_size: int, eps: float = 1e-5):
-        super().__init__()
-        self.config = config
-        self.eps = eps
-        self.weight = torch.nn.Parameter(torch.zeros(hidden_size))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, hidden_states):
-        output = self._norm(hidden_states.float())
-        output = output * (1.0 + self.weight.float())
-        return output.type_as(hidden_states)
-
-
-class Qwen3NextSelfAttention(SelfAttention):
-    """Full attention with output gate for Qwen3-Next/Qwen3.5 models.
-
-    QKV projection produces [Q_heads, gate_heads, K_heads, V_heads] where
-    Q and gate are interleaved: Q0, gate0, Q1, gate1, ...
-    """
-
-    def __init__(self, config, submodules: SelfAttentionSubmodules, *args, **kwargs):
-        super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs)
-        kwargs_pg = {}
-        if mcore_015:
-            kwargs_pg['tp_group'] = self.pg_collection.tp
-        elif mcore_013:
-            kwargs_pg['tp_group'] = self.model_comm_pgs.tp
-        self.linear_qkv = build_module(
-            submodules.linear_qkv,
-            self.config.hidden_size,
-            2 * self.query_projection_size + 2 * self.kv_projection_size,
-            config=self.config,
-            init_method=self.config.init_method,
-            gather_output=False,
-            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
-            skip_bias_add=False,
-            is_expert=False,
-            tp_comm_buffer_name='qkv',
-            **kwargs_pg,
-        )
-
-        if submodules.q_layernorm is not None:
-            self.q_layernorm = build_module(
-                submodules.q_layernorm,
-                hidden_size=self.hidden_size_per_attention_head,
-                config=self.config,
-                eps=self.config.layernorm_epsilon,
-            )
-        else:
-            self.q_layernorm = None
-
-        if submodules.k_layernorm is not None:
-            self.k_layernorm = build_module(
-                submodules.k_layernorm,
-                hidden_size=self.hidden_size_per_attention_head,
-                config=self.config,
-                eps=self.config.layernorm_epsilon,
-            )
-        else:
-            self.k_layernorm = None
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        inference_context: Optional[BaseInferenceContext] = None,
-        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
-        rotary_pos_cos: Optional[torch.Tensor] = None,
-        rotary_pos_sin: Optional[torch.Tensor] = None,
-        attention_bias: Optional[torch.Tensor] = None,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-        sequence_len_offset: Optional[int] = None,
-        *,
-        inference_params: Optional[BaseInferenceContext] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        try:
-            from megatron.core.utils import nvtx_range_pop, nvtx_range_push
-        except ImportError:
-
-            def nvtx_range_pop(*args, **kwargs):
-                return
-
-            def nvtx_range_push(*args, **kwargs):
-                return
-
-        if hasattr(self.config, 'no_rope_freq'):
-            no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False)
-            if no_rope:
-                rotary_pos_emb = None
-
-        inference_context = deprecate_inference_params(inference_context, inference_params)
-
-        if inference_context and inference_context.is_dynamic_batching():
-            assert HAVE_FA3 or is_fa_min_version(
-                '2.7.3'), 'flash attn version v2.7.3 and above is required for dynamic batching.'
- - if self.config.flash_decode and not self.training and inference_context is not None: - rotary_pos_emb = None - else: - assert rotary_pos_cos is None and rotary_pos_sin is None - - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb, ) * 2 - - nvtx_range_push(suffix='qkv') - query, key, value, gate = self.get_query_key_value_tensors(hidden_states, key_value_states) - nvtx_range_pop(suffix='qkv') - - in_decode_mode = (inference_context is not None and inference_context.is_decode_only() and not self.training) - - nvtx_range_push(suffix='adjust_key_value') - if in_decode_mode and self.config.flash_decode: - assert self.layer_number in inference_context.key_value_memory_dict - assert inference_context.sequence_len_offset is not None - inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] - output = self.flash_decode( - sequence_len_offset=sequence_len_offset, - query_layer=query, - key_layer=key, - value_layer=value, - inference_key_memory=inference_key_memory, - inference_value_memory=inference_value_memory, - rotary_cos=rotary_pos_cos, - rotary_sin=rotary_pos_sin, - rotary_interleaved=self.config.rotary_interleaved, - ) - out = output.transpose(0, 1).contiguous() - context_layer = out.view(out.size(0), out.size(1), -1) - output, bias = self.linear_proj(context_layer) - return output, bias - - if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()): - raise ValueError('CUDA graphs must use flash decode with static batching!') - - result = self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - ) - if mcore_013: - query, key, value, rotary_pos_emb, attn_mask_type, block_table = result - else: - query, key, value, rotary_pos_emb, attn_mask_type = result - - if packed_seq_params is not None: - query = query.squeeze(1) - key = key.squeeze(1) - value = value.squeeze(1) - nvtx_range_pop(suffix='adjust_key_value') - - kwargs_cp = {} - if mcore_015: - kwargs_cp['cp_group'] = self.pg_collection.cp - elif mcore_013: - kwargs_cp['cp_group'] = self.model_comm_pgs.cp - nvtx_range_push(suffix='rotary_pos_emb') - if rotary_pos_emb is not None and not self.config.flash_decode: - q_pos_emb, k_pos_emb = rotary_pos_emb - - if packed_seq_params is not None: - cu_seqlens_q = ( - packed_seq_params.cu_seqlens_q_padded - if packed_seq_params.cu_seqlens_q_padded is not None else packed_seq_params.cu_seqlens_q) - cu_seqlens_kv = ( - packed_seq_params.cu_seqlens_kv_padded - if packed_seq_params.cu_seqlens_kv_padded is not None else packed_seq_params.cu_seqlens_kv) - else: - cu_seqlens_q = cu_seqlens_kv = None - - if q_pos_emb is not None: - if inference_context is None or inference_context.is_static_batching(): - query = apply_rotary_pos_emb( - query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, **kwargs_cp) - else: - query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, - **kwargs_cp) - if k_pos_emb is not None: - key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv, **kwargs_cp) - nvtx_range_pop(suffix='rotary_pos_emb') - - nvtx_range_push(suffix='core_attention') - if self.checkpoint_core_attention and self.training: - core_attn_out = self._checkpointed_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - 
attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - else: - if inference_context is None or inference_context.is_static_batching(): - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - else: - q, k, v = (query, key, value) - cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() - cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = (inference_context.cu_kv_lengths()) - core_attn_out = self.flash_decode_and_prefill( - q, - k, - v, - max_seqlen_q, - max_seqlen_k, - cu_query_lengths, - cu_kv_lengths, - kv_lengths, - kv_lengths_decode_only, - block_table, - ) - core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') - - if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) - nvtx_range_pop(suffix='core_attention') - - core_attn_out = core_attn_out * torch.sigmoid(gate.reshape_as(core_attn_out)) - nvtx_range_push(suffix='linear_proj') - output, bias = self.linear_proj(core_attn_out) - nvtx_range_pop(suffix='linear_proj') - - return output, bias - - def get_query_key_value_tensors(self, hidden_states, key_value_states=None): - mixed_qkv, _ = self.linear_qkv(hidden_states) - - new_tensor_shape = mixed_qkv.size()[:-1] + ( - self.num_query_groups_per_partition, - ((self.num_attention_heads_per_partition // self.num_query_groups_per_partition * 2 + 2) - * self.hidden_size_per_attention_head), - ) - mixed_qkv = mixed_qkv.view(*new_tensor_shape) - split_arg_list = [ - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head * 2), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ] - - if SplitAlongDim is not None: - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) - else: - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) - - query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) - query, gate = query[:, :, ::2], query[:, :, 1::2] - if self.q_layernorm is not None: - query = self.q_layernorm(query) - if self.k_layernorm is not None: - key = self.k_layernorm(key) - - if self.config.test_mode: - self.run_realtime_tests() - - return query, key, value, gate - - -def _gated_delta_net_forward(self, hidden_states: torch.Tensor, **kwargs): - """Shared forward logic for all GatedDeltaNet variants.""" - args = get_args() - if args.sequence_parallel and args.tensor_model_parallel_size > 1: - hidden_states = gather_from_sequence_parallel_region(hidden_states) - seq_len = hidden_states.shape[0] - packed_seq_params = kwargs.get('packed_seq_params') - thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if thd_format and not getattr(args, 'packing', False): - new_hidden_states = hidden_states.new_zeros( - (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item(), hidden_states.shape[-1])) - attention_mask = hidden_states.new_zeros((packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item()), - dtype=torch.bool) - cu_seqlens_q = packed_seq_params.cu_seqlens_q - for i in range(packed_seq_params.num_samples): - start, end = cu_seqlens_q[i], cu_seqlens_q[i + 1] - attention_mask[i, :end - start] = True - new_hidden_states[i, :end - start] = hidden_states[start:end, 0] - hidden_states = new_hidden_states - else: - 
hidden_states = hidden_states.transpose(0, 1) - attention_mask = kwargs.get('attention_mask') - if attention_mask is not None: - attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 - res = super(type(self), self).forward(hidden_states=hidden_states, attention_mask=attention_mask) - if thd_format and not getattr(args, 'packing', False): - res = res[attention_mask][:, None] - res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])]) - else: - res = res.transpose(0, 1).contiguous() - if args.sequence_parallel and args.tensor_model_parallel_size > 1: - res = reduce_scatter_to_sequence_parallel_region(res) / args.tensor_model_parallel_size - return res, None - - -def _gated_delta_net_init(self, hf_cls, config, submodules, layer_number, **kwargs): - """Shared __init__ logic for all GatedDeltaNet variants.""" - assert config.context_parallel_size == 1, 'Qwen3-Next/Qwen3.5 currently does not support context parallel.' - hf_cls.__init__(self, config, layer_number) - self.config = config - extra_kwargs = _get_extra_te_kwargs(config) - self.to(dtype=extra_kwargs['params_dtype'], device=extra_kwargs['device']) - - -try: - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedDeltaNet as _Qwen3_5MoeGatedDeltaNet -except ImportError: - _Qwen3_5MoeGatedDeltaNet = object - -try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet as _Qwen3NextGatedDeltaNet -except ImportError: - _Qwen3NextGatedDeltaNet = object - - -class Qwen3NextGatedDeltaNet(_HuggingFaceModule, _Qwen3NextGatedDeltaNet): - """GatedDeltaNet for linear attention layers in Qwen3-Next models.""" - - def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): - assert _Qwen3NextGatedDeltaNet is not object, 'please update the `transformers` version.' - _gated_delta_net_init(self, _Qwen3NextGatedDeltaNet, config, submodules, layer_number, **kwargs) - - def forward(self, hidden_states: torch.Tensor, **kwargs): - return _gated_delta_net_forward(self, hidden_states, **kwargs) - - -class Qwen3_5MoeGatedDeltaNet(_HuggingFaceModule, _Qwen3_5MoeGatedDeltaNet): - """GatedDeltaNet for Qwen3.5-MoE linear attention layers.""" - - def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): - assert _Qwen3_5MoeGatedDeltaNet is not object, 'please update the `transformers` version.' - _gated_delta_net_init(self, _Qwen3_5MoeGatedDeltaNet, config, submodules, layer_number, **kwargs) - - def forward(self, hidden_states: torch.Tensor, **kwargs): - return _gated_delta_net_forward(self, hidden_states, **kwargs) - - -def get_local_layer_specs(config, layer_specs, vp_stage=None): - """Get the layer specs for layers assigned to this pipeline stage. - - Mirrors swift.megatron.utils.get_local_layer_specs for distributing - heterogeneous layer specs across pipeline stages. - """ - from megatron.core import mpu - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - if pp_size <= 1: - return layer_specs - num_layers = len(layer_specs) - layers_per_stage = num_layers // pp_size - remainder = num_layers % pp_size - start = pp_rank * layers_per_stage + min(pp_rank, remainder) - if pp_rank < remainder: - layers_per_stage += 1 - return layer_specs[start:start + layers_per_stage] - - -def get_qwen3_next_layer_spec(config, args, gated_delta_net_cls): - """Build the heterogeneous transformer layer specs for Qwen3-Next/Qwen3.5. 
- - Returns a TransformerBlockSubmodules with per-layer specs matching - the model's layer_types (linear_attention / full_attention). - """ - config.hetereogenous_dist_checkpoint = True - config.hidden_act = 'silu' - config.rms_norm_eps = config.layernorm_epsilon - config.dtype = args.params_dtype - - layer_norm_impl = Qwen3NextRMSNorm - kwargs = {'use_kitchen': config.use_kitchen} if hasattr(config, 'use_kitchen') and mcore_013 else {} - moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=getattr(config, 'moe_grouped_gemm', True), - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=getattr(config, 'moe_use_legacy_grouped_gemm', False), - **kwargs, - ) - layer_specs = [] - for layer_type in config.layer_types: - layer_spec = deepcopy(moe_layer_spec) - if layer_type == 'linear_attention': - layer_spec.submodules.self_attention.module = gated_delta_net_cls - elif layer_type == 'full_attention': - layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear - layer_spec.submodules.self_attention.module = Qwen3NextSelfAttention - # Replace ALL layernorms with Qwen3NextRMSNorm (Zero-Centered) - layer_spec.submodules.input_layernorm = layer_norm_impl - if hasattr(layer_spec.submodules, 'pre_mlp_layernorm'): - layer_spec.submodules.pre_mlp_layernorm = layer_norm_impl - # qwen3.5 dense - if args.hf_model_type == 'qwen3_5': - layer_spec.submodules.mlp.submodules.linear_fc1 = TEColumnParallelLinear - # Replace qk_layernorm if present - if hasattr(layer_spec.submodules.self_attention.submodules, 'q_layernorm'): - layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl - if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): - layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl - if (getattr(config, 'moe_use_shared_expert_gate', False) and hasattr(layer_spec.submodules, 'mlp') - and hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts')): - layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - layer_specs.append(layer_spec) - - local_layer_specs = get_local_layer_specs(config, layer_specs) - block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) - - return block_spec - - -def get_qwen3_next_mtp_block_spec(config, transformer_layer_spec, **kwargs): - """Build MTP block spec with Qwen3NextRMSNorm.""" - mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=True, **kwargs) - for layer_spec in mtp_block_spec.layer_specs: - layer_spec.submodules.enorm = Qwen3NextRMSNorm - layer_spec.submodules.hnorm = Qwen3NextRMSNorm - layer_spec.submodules.layer_norm = Qwen3NextRMSNorm - return mtp_block_spec - - -class Qwen3NextLoader(MegatronModelLoader): - """Loader for Qwen3-Next models with heterogeneous linear/full attention layers.""" - gated_delta_net = Qwen3NextGatedDeltaNet - - def post_config(self, config, args, mg_config_dict): - layer_types = mg_config_dict.get('layer_types') - if layer_types is not None: - config.layer_types = layer_types - for attr in ('linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', - 'linear_value_head_dim', 'linear_conv_kernel_dim'): - val = mg_config_dict.get(attr) - if val is not None: - setattr(config, attr, val) - - def get_layer_spec(self, config, args, mg_config_dict): - return get_qwen3_next_layer_spec(config, args, 
self.gated_delta_net) - - def get_mtp_block_spec(self, config, layer_spec, **kwargs): - return get_qwen3_next_mtp_block_spec(config, layer_spec, **kwargs) diff --git a/src/twinkle/model/megatron/model/mm_gpt_model.py b/src/twinkle/model/megatron/model/mm_gpt_model.py deleted file mode 100644 index 83a86ef5..00000000 --- a/src/twinkle/model/megatron/model/mm_gpt_model.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import megatron.core -import torch -from contextlib import contextmanager -from megatron.core import InferenceParams, mpu -from megatron.core.enums import ModelType -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from packaging import version - -from twinkle.model.megatron.args import get_args -from .gpt_model import GPTModel - -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - - -class MultimodalGPTModel(MegatronModule): - - def __init__(self, - config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - *args, - **kwargs): - from .register import get_megatron_model_meta - super().__init__(config) - # Required by Megatron's forward_backward scheduling - self.model_type = ModelType.encoder_or_decoder - self.pre_process = pre_process - self.post_process = post_process - self.language_model = GPTModel(config, transformer_layer_spec, vocab_size, max_sequence_length, pre_process, - post_process, *args, **kwargs) - self.vp_stage = self.language_model.vp_stage - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - args = get_args() - self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type) - self.visual = None - if args.mtp_num_layers: - raise ValueError('MTP currently does not support multimodal models.') - if pre_process and self.megatron_model_meta.visual_cls is not None: - self.visual = self.megatron_model_meta.visual_cls(config) - - @contextmanager - def _patch_word_embeddings(self, kwargs): - origin_forward = VocabParallelEmbedding.forward - - def forward(_self, input_): - from twinkle.model.megatron.utils import split_cp_inputs - reduce_scatter_embeddings = _self.reduce_scatter_embeddings - _self.reduce_scatter_embeddings = False - input_ = torch.masked_fill(input_, input_ < 0, 0) - res = origin_forward(_self, input_) - _self.reduce_scatter_embeddings = reduce_scatter_embeddings - packed_seq_params = kwargs.get('packed_seq_params') - if self.visual is not None: - res = self.visual.get_inputs_embeds(res, **kwargs) - kwargs.clear() - if isinstance(res, dict): - # compat dict - inputs_embeds = res.pop('inputs_embeds') - kwargs.update(res) - res = inputs_embeds - cp_size = mpu.get_context_parallel_world_size() - if cp_size > 1: - # Pad embedding sequence to be divisible by 2 * cp_size - # This is required for the load-balanced CP split algorithm - seq_dim = 1 # res shape: [batch, seq, hidden] - seq_len = res.shape[seq_dim] - divisor = 2 * cp_size - if seq_len % divisor != 0: - pad_len = divisor - (seq_len % divisor) - # Pad with zeros on the sequence dimension - # res shape: [batch, seq, hidden], pad the seq 
dimension - res = torch.nn.functional.pad(res, (0, 0, 0, pad_len), value=0) - res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), seq_dim) - if reduce_scatter_embeddings: - res = res.transpose(0, 1).contiguous() - group_kwargs = {'group': _self.tp_group} if mcore_013 else {} - tp_size = mpu.get_tensor_model_parallel_world_size() - res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / tp_size - return res - - VocabParallelEmbedding.forward = forward - try: - yield - finally: - VocabParallelEmbedding.forward = origin_forward - - # Code borrowed from NVIDIA/Megatron-LM - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor = None, - decoder_input: torch.Tensor = None, - labels: torch.Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - **kwargs, - ) -> torch.Tensor: - if decoder_input is not None: - pass - elif self.pre_process: - kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params}) - with self._patch_word_embeddings(kwargs): - decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - kwargs = {} - return self.language_model( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - decoder_input=decoder_input, - labels=labels, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **kwargs, - ) - - def set_input_tensor(self, input_tensor: torch.Tensor) -> None: - return self.language_model.set_input_tensor(input_tensor) - - def get_input_tensor(self): - return self.language_model.get_input_tensor() - - def shared_embedding_or_output_weight(self) -> torch.Tensor: - return self.language_model.shared_embedding_or_output_weight() diff --git a/src/twinkle/model/megatron/model/mm_gpts/__init__.py b/src/twinkle/model/megatron/model/mm_gpts/__init__.py deleted file mode 100644 index 30f10d89..00000000 --- a/src/twinkle/model/megatron/model/mm_gpts/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from . import qwen, qwen3_5, qwen3_vl, utils diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen.py b/src/twinkle/model/megatron/model/mm_gpts/qwen.py deleted file mode 100644 index 267a1216..00000000 --- a/src/twinkle/model/megatron/model/mm_gpts/qwen.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-import torch -from PIL import Image -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration - -from twinkle.utils.torch_utils import to_device -from ..constant import MegatronModelType, ModelType -from ..gpt_bridge import MultimodalGPTBridge -from ..register import MegatronModelMeta, register_megatron_model -from .utils import HuggingFaceModule - - -class Qwen2_5VL_Vit(HuggingFaceModule): - module_mapping = {'model.visual': 'visual'} - _vision_tower = ['visual'] - _aligner = ['visual.merger'] - version = 'v2_5' - - def __init__(self, config): - if self.version == 'v2_5': - try: - from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel - except ImportError: - from transformers.models.qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel - ignore_init_model_cls = Qwen2_5_VLTextModel - elif self.version == 'v2': - try: - from transformers.models.qwen2_vl import Qwen2VLTextModel - except ImportError: - from transformers.models.qwen2_vl import Qwen2VLModel as Qwen2VLTextModel - ignore_init_model_cls = Qwen2VLTextModel - super().__init__(config, ignore_init_model_cls) - - def get_inputs_embeds(self, inputs_embeds, **kwargs): - return self._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) - - def _get_inputs_embeds_hf(self, inputs_embeds, inputs, visual, processor, config): - # mimic the behavior of Template._get_inputs_embeds_hf in swift - input_ids = inputs['input_ids'] - pixel_values = inputs.get('pixel_values') - pixel_values_videos = inputs.get('pixel_values_videos') - image_grid_thw = inputs.get('image_grid_thw') - video_grid_thw = inputs.get('video_grid_thw') - dtype = visual.dtype - if pixel_values is None and pixel_values_videos is None: # plain-text - images = [Image.new('RGB', (32, 32), (0, 0, 0))] - media_inputs = processor.image_processor(images=images, return_tensors='pt') - media_inputs = to_device(media_inputs, input_ids.device) - pixel_values = media_inputs['pixel_values'].type(dtype) - image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) - inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. 
- else: - if pixel_values is None: - pixel_values_mixed = pixel_values_videos - grid_thw = video_grid_thw - elif pixel_values_videos is None: - pixel_values_mixed = pixel_values - grid_thw = image_grid_thw - else: - pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) - grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) - pixel_values_mixed = pixel_values_mixed.type(dtype) - mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) - if pixel_values is None: - image_embeds = None - video_embeds = mixed_embeds - elif pixel_values_videos is None: - image_embeds = mixed_embeds - video_embeds = None - else: - merge_length = processor.image_processor.merge_size**2 - image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() - image_embeds = mixed_embeds[:image_tokens] - video_embeds = mixed_embeds[image_tokens:] - - if image_embeds is not None: - image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - image_mask = image_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if video_embeds is not None: - video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - video_mask = video_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - return inputs_embeds - - -class Qwen2_5VLBridge(MultimodalGPTBridge): - # Compatible with older versions of transformers - hf_state_dict_mapping = { - 'model.layers': 'model.language_model.layers', - 'model.embed_tokens': 'model.language_model.embed_tokens', - 'model.norm': 'model.language_model.norm', - 'visual': 'model.visual', - } - - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen2_5_vl, [ - ModelType.qwen2_5_vl, - ], - bridge_cls=Qwen2_5VLBridge, - visual_cls=Qwen2_5VL_Vit, - auto_model_cls=Qwen2_5_VLForConditionalGeneration)) - - -class Qwen2VL_Vit(Qwen2_5VL_Vit): - version = 'v2' - - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen2_vl, [ - ModelType.qwen2_vl, - ], - bridge_cls=Qwen2_5VLBridge, - visual_cls=Qwen2VL_Vit, - auto_model_cls=Qwen2VLForConditionalGeneration)) diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py deleted file mode 100644 index dbffd992..00000000 --- a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -# Reference: swift/swift/megatron/model/mm_gpts/qwen3_5.py -# Qwen3.5 / Qwen3.5-MoE multimodal model support for Megatron - -import torch -from PIL import Image - -from twinkle.model.megatron.args import get_args -from twinkle.utils.torch_utils import to_device -from ..constant import MegatronModelType, ModelType -from ..gpt_bridge import GPTBridge, MultimodalGPTBridge -from ..gpts.qwen3_next import Qwen3_5MoeGatedDeltaNet, Qwen3NextLoader -from ..register import MegatronModelMeta, register_megatron_model -from .utils import HuggingFaceModule - - -class Qwen3_5Vit(HuggingFaceModule): - """Vision module for Qwen3.5 / Qwen3.5-MoE models. - - Maps 'model.visual' from HF model to 'visual' in Megatron, - with merger as aligner. 
- """ - module_mapping = {'model.visual': 'visual'} - _vision_tower = ['visual'] - _aligner = ['visual.merger'] - - def __init__(self, config): - try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel - except ImportError: - Qwen3_5TextModel = None - try: - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel - except ImportError: - Qwen3_5MoeTextModel = None - ignore_cls = [c for c in [Qwen3_5TextModel, Qwen3_5MoeTextModel] if c is not None] - super().__init__(config, ignore_cls) - - def get_inputs_embeds(self, inputs_embeds, **kwargs): - return self._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) - - def _get_inputs_embeds_hf(self, inputs_embeds, inputs, visual, processor, config): - input_ids = inputs['input_ids'] - pixel_values = inputs.get('pixel_values') - pixel_values_videos = inputs.get('pixel_values_videos') - image_grid_thw = inputs.get('image_grid_thw') - video_grid_thw = inputs.get('video_grid_thw') - dtype = visual.dtype - if pixel_values is None and pixel_values_videos is None: - images = [Image.new('RGB', (32, 32), (0, 0, 0))] - media_inputs = processor.image_processor(images=images, return_tensors='pt') - media_inputs = to_device(media_inputs, input_ids.device) - pixel_values = media_inputs['pixel_values'].type(dtype) - image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) - if hasattr(image_embeds, 'pooler_output'): - image_embeds = image_embeds.pooler_output - inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. - else: - if pixel_values is None: - pixel_values_mixed = pixel_values_videos - grid_thw = video_grid_thw - elif pixel_values_videos is None: - pixel_values_mixed = pixel_values - grid_thw = image_grid_thw - else: - pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) - grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) - pixel_values_mixed = pixel_values_mixed.type(dtype) - mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) - if hasattr(mixed_embeds, 'pooler_output'): - mixed_embeds = mixed_embeds.pooler_output - if pixel_values is None: - image_embeds = None - video_embeds = mixed_embeds - elif pixel_values_videos is None: - image_embeds = mixed_embeds - video_embeds = None - else: - merge_length = processor.image_processor.merge_size**2 - image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() - image_embeds = mixed_embeds[:image_tokens] - video_embeds = mixed_embeds[image_tokens:] - - if image_embeds is not None: - image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - image_mask = image_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if video_embeds is not None: - video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - video_mask = video_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - return inputs_embeds - - -class Qwen3_5Bridge(MultimodalGPTBridge): - """Bridge for Qwen3.5 multimodal models. - - Uses language_model prefix for the LLM backbone since Qwen3.5 has a - multimodal architecture with model.language_model.layers structure. 
- - Overrides _set_layer_attn to handle the mixed linear/full attention - architecture specific to Qwen3-Next/Qwen3.5. - """ - hf_layers_prefix = 'model.language_model.layers' - hf_embed_key = 'model.language_model.embed_tokens.weight' - hf_final_layernorm_key = 'model.language_model.norm.weight' - hf_mtp_prefix = 'mtp.layers' - - def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): - args = self.args - layer_types = getattr(args, 'layer_types', None) - if layer_types is None: - return super()._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore) - - layer_type = layer_types[layer_idx] if 0 <= layer_idx < len(layer_types) else 'full_attention' - mg_attn = None if mg_layer is None else mg_layer.self_attention - if layer_type == 'linear_attention': - hf_state_dict.update(self._set_module(mg_attn, hf_state_dict, 'linear_attn.', to_mcore)) - elif layer_type == 'full_attention': - hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) - return hf_state_dict - - def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): - hf_state_dict = self._remove_prefix(origin_hf_state_dict, 'mtp.') - for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'], - ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): - self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) - self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) - if not to_mcore: - origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) - - -try: - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForConditionalGeneration -except ImportError: - Qwen3_5MoeForConditionalGeneration = None - -try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration -except ImportError: - Qwen3_5ForConditionalGeneration = None - - -class Qwen3_5MoeLoader(Qwen3NextLoader): - gated_delta_net = Qwen3_5MoeGatedDeltaNet - - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen3_5_moe, - [ - ModelType.qwen3_5_moe, - ], - bridge_cls=Qwen3_5Bridge, - visual_cls=Qwen3_5Vit, - auto_model_cls=Qwen3_5MoeForConditionalGeneration, - loader=Qwen3_5MoeLoader, - )) - -register_megatron_model( - MegatronModelMeta( - MegatronModelType.qwen3_5, - [ - ModelType.qwen3_5, - ], - bridge_cls=Qwen3_5Bridge, - visual_cls=Qwen3_5Vit, - auto_model_cls=Qwen3_5ForConditionalGeneration, - loader=Qwen3_5MoeLoader, - )) diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py deleted file mode 100644 index 365f4e8f..00000000 --- a/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py +++ /dev/null @@ -1,450 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-# Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py - -import torch -from contextlib import nullcontext -from megatron.core import parallel_state, tensor_parallel -from megatron.core.enums import Fp8Recipe -from megatron.core.fp8_utils import get_fp8_context -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.gpt import gpt_model -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor -from PIL import Image -from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration -from typing import List, Optional, Union - -from twinkle.model.megatron.args import get_args -from twinkle.model.megatron.model.constant import MegatronModelType, ModelType -from twinkle.model.megatron.model.gpt_bridge import GPTBridge, MultimodalGPTBridge -from twinkle.model.megatron.model.mm_gpt_model import MultimodalGPTModel -from twinkle.utils import to_device -from ..register import MegatronModelMeta, register_megatron_model -from .utils import HuggingFaceModule - -te_checkpoint = None - -try: - import transformer_engine.pytorch as te # pylint: disable=unused-import - HAVE_TE = True -except ImportError: - HAVE_TE = False - -if HAVE_TE: - from megatron.core.extensions.transformer_engine import te_checkpoint - - -class Qwen3Omni_Vit(HuggingFaceModule): - module_mapping = {'thinker': 'thinker', 'talker': 'talker', 'code2wav': 'code2wav'} - _vision_tower = ['thinker.audio_tower', 'thinker.visual'] - _aligner = [ - 'thinker.audio_tower.proj1', 'thinker.audio_tower.proj2', 'thinker.visual.merger', 'thinker.visual.merger_list' - ] - _generator = ['talker', 'code2wav'] - - def __init__(self, config): - from transformers.models.qwen3_omni_moe import Qwen3OmniMoeThinkerTextModel - super().__init__(config, [Qwen3OmniMoeThinkerTextModel]) - - def prepare_model(self, hf_model): - del self.thinker.model - del self.thinker.lm_head - - @staticmethod - def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config): - from twinkle.model.megatron.utils import split_cp_inputs - input_ids = inputs['input_ids'] - packed_seq_params = inputs.get('packed_seq_params') - pixel_values = inputs.get('pixel_values') - pixel_values_videos = inputs.get('pixel_values_videos') - image_grid_thw = inputs.get('image_grid_thw') - video_grid_thw = inputs.get('video_grid_thw') - dtype = visual.dtype - if pixel_values is None and pixel_values_videos is None: # plain-text - images = [Image.new('RGB', (32, 32), (0, 0, 0))] - media_inputs = processor.image_processor(images=images, return_tensors='pt') - media_inputs = to_device(media_inputs, input_ids.device) - pixel_values = media_inputs['pixel_values'].type(dtype) - image_embeds, deepstack_visual_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) - deepstack_visual_embeds = torch.stack(deepstack_visual_embeds, dim=0) - inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. 
- visual_pos_masks = None - else: - if pixel_values is None: - pixel_values_mixed = pixel_values_videos - grid_thw = video_grid_thw - elif pixel_values_videos is None: - pixel_values_mixed = pixel_values - grid_thw = image_grid_thw - else: - pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) - grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) - pixel_values_mixed = pixel_values_mixed.type(dtype) - mixed_embeds, deepstack_visual_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) - if pixel_values is None: - image_embeds = None - video_embeds = mixed_embeds - elif pixel_values_videos is None: - image_embeds = mixed_embeds - video_embeds = None - else: - merge_length = processor.image_processor.merge_size**2 - image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() - image_embeds = mixed_embeds[:image_tokens] - video_embeds = mixed_embeds[image_tokens:] - - image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - if image_embeds is not None: - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - image_mask = image_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if video_embeds is not None: - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - video_mask = video_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - image_mask, video_mask = image_mask[..., 0], video_mask[..., 0] - visual_pos_masks = image_mask | video_mask - if image_embeds is not None and video_embeds is not None: - deepstack_image_embeds = [tensor[:image_tokens] for tensor in deepstack_visual_embeds] - deepstack_video_embeds = [tensor[image_tokens:] for tensor in deepstack_visual_embeds] - deepstack_visual_embeds = [] - image_mask_joint = image_mask[visual_pos_masks] - video_mask_joint = video_mask[visual_pos_masks] - for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): - embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) - embed_joint[image_mask_joint, :] = img_embed - embed_joint[video_mask_joint, :] = vid_embed - deepstack_visual_embeds.append(embed_joint) - - deepstack_visual_embeds = torch.stack(deepstack_visual_embeds, dim=0) - visual_pos_masks = visual_pos_masks.transpose(0, 1) - # compat cp - args = get_args() - if args.context_parallel_size > 1: - device = visual_pos_masks.device - cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device) - cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device) - cu_seqlens = getattr(packed_seq_params, 'cu_seqlens_q', None) - cp_mask = split_cp_inputs(cp_mask, cu_seqlens, 0) - visual_pos_masks = split_cp_inputs(visual_pos_masks, cu_seqlens, 0) - deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]] - # compat sp - tp_world_size = parallel_state.get_tensor_model_parallel_world_size() - tp_rank = parallel_state.get_tensor_model_parallel_rank() - if args.sequence_parallel and tp_world_size > 1: - visual_pos_masks = visual_pos_masks.view(tp_world_size, -1, *visual_pos_masks.shape[1:]) - mask_tokens = visual_pos_masks.sum(dim=(1, 2)).tolist() - visual_start = 0 if tp_rank == 0 else sum(mask_tokens[:tp_rank]) - visual_end = visual_start + mask_tokens[tp_rank] - visual_pos_masks = 
visual_pos_masks[tp_rank] - deepstack_visual_embeds = deepstack_visual_embeds[:, visual_start:visual_end] - return { - 'inputs_embeds': inputs_embeds, - 'visual_pos_masks': visual_pos_masks, - 'deepstack_visual_embeds': deepstack_visual_embeds - } - - def get_inputs_embeds(self, inputs_embeds, **kwargs): - """Merge Qwen-Omni vision features into embeddings with audio support. - - Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:149-169 - """ - input_ids = kwargs['input_ids'] - visual = self.thinker.visual - config = self.model_config.thinker_config - res = self._get_inputs_embeds(inputs_embeds, kwargs, visual, self.processor, config) - inputs_embeds = res['inputs_embeds'] - input_features = kwargs.get('input_features') - feature_attention_mask = kwargs.get('feature_attention_mask') - - if input_features is None: - input_features = input_ids.new_zeros([1, 128, 128], dtype=self.thinker.audio_tower.dtype) - feature_attention_mask = input_ids.new_ones([1, 128], dtype=torch.bool) - audio_embeds = self.thinker.get_audio_features(input_features, feature_attention_mask) - inputs_embeds = inputs_embeds + audio_embeds.mean() * 0. - else: - audio_embeds = self.thinker.get_audio_features(input_features, feature_attention_mask) - audio_mask = (input_ids == config.audio_token_id).unsqueeze(-1).expand_as(inputs_embeds) - audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds) - res['inputs_embeds'] = inputs_embeds - return res - - -class Qwen3VLTransformerBlock(gpt_model.TransformerBlock): - """TransformerBlock with deepstack visual feature injection for Qwen3-VL. - - Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:172-444 - """ - - def _checkpointed_forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - context: torch.Tensor, - context_mask: torch.Tensor, - rotary_pos_emb: torch.Tensor, - attention_bias: torch.Tensor, - packed_seq_params: PackedSeqParams, - use_inner_fp8_context: bool, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds): - for index in range(start, end): - layer = self._get_layer(index) - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - inference_context=None, - packed_seq_params=packed_seq_params, - ) - # Add visual features to the hidden states of first several layers - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - return hidden_states, context - - return custom_forward - - def checkpoint_handler(forward_func): - """Determines whether to use te_checkpoint or tensor_parallel.checkpoint.""" - if self.config.fp8: - return te_checkpoint( - forward_func, - self.config.distribute_saved_activations, - 
tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - else: - return tensor_parallel.checkpoint( - forward_func, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - - if self.config.recompute_method == 'uniform': - layer_idx = 0 - while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states, context = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers)) - layer_idx += self.config.recompute_num_layers - - elif self.config.recompute_method == 'block': - recompute_skip_num_layers = 0 - for layer_idx in range(self.num_layers_per_pipeline_rank): - if self.config.fp8 and not hidden_states.requires_grad: - recompute_skip_num_layers += 1 - if (layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) - else: - hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, - context_mask, rotary_pos_emb, - visual_pos_masks, deepstack_visual_embeds) - else: - raise ValueError('Invalid activation recompute method.') - - return hidden_states - - def forward( - self, - hidden_states: Union[torch.Tensor, WrappedTensor], - attention_mask: Optional[torch.Tensor], - context: Optional[torch.Tensor] = None, - context_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - rotary_pos_cos: Optional[torch.Tensor] = None, - rotary_pos_sin: Optional[torch.Tensor] = None, - attention_bias: Optional[torch.Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[torch.Tensor] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """Forward pass through the transformer block with deepstack support. 
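The two recompute strategies above differ only in which layer ranges get wrapped: `uniform` re-runs fixed-size chunks of layers, while `block` checkpoints only the first `recompute_num_layers` layers per pipeline rank. A single-GPU sketch of the uniform chunking, with `torch.utils.checkpoint.checkpoint` standing in for Megatron's `tensor_parallel.checkpoint` (the layer stack and chunk size are illustrative):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # stand-in for tensor_parallel.checkpoint

def uniform_recompute_forward(layers: nn.ModuleList, hidden: torch.Tensor, chunk: int) -> torch.Tensor:
    """Run `chunk` layers at a time under activation checkpointing."""
    def run(start: int, end: int):
        def fwd(h):
            for i in range(start, end):
                h = layers[i](h)
            return h
        return fwd

    idx = 0
    while idx < len(layers):
        # Only activations at chunk boundaries are stored; the rest are
        # recomputed during the backward pass.
        hidden = checkpoint(run(idx, min(idx + chunk, len(layers))), hidden, use_reentrant=False)
        idx += chunk
    return hidden

layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))
x = torch.randn(4, 16, requires_grad=True)
uniform_recompute_forward(layers, x, chunk=2).sum().backward()
```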
- - Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:285-434 - """ - if deepstack_visual_embeds is not None: - assert len(deepstack_visual_embeds) <= len( - self.layers), (f'len(deepstack_visual_embeds): {len(deepstack_visual_embeds)}, ' - f'len(self.layers): {len(self.layers)}.') - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Delete the obsolete reference to the initial input tensor if necessary - if isinstance(hidden_states, WrappedTensor): - hidden_states = hidden_states.unwrap() - - if not self.pre_process: - hidden_states = self.input_tensor - - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() - - with rng_context, outer_fp8_context: - if self.config.recompute_granularity == 'full' and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - use_inner_fp8_context=use_inner_fp8_context, - visual_pos_masks=visual_pos_masks, - deepstack_visual_embeds=deepstack_visual_embeds, - ) - else: - for l_no, layer in enumerate(self.layers): - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with self.offload_context, inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) - # Add visual features to the hidden states of first several layers - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - - if (torch.is_grad_enabled() and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) - - # Final layer norm - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: - hidden_states = hidden_states.clone() - - return hidden_states - - def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, - visual_embeds: torch.Tensor): - """Inject visual features into hidden states at visual token positions. 
-
-        Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:436-444
-        """
-        if visual_pos_masks is None:
-            return hidden_states + visual_embeds.mean() * 0
-        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
-        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
-        local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
-        hidden_states[visual_pos_masks, :] = local_this
-        return hidden_states
-
-
-class Qwen3VLGPTModel(MultimodalGPTModel):
-    """Qwen3-VL GPT model with deepstack visual feature injection.
-
-    Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:447-457
-    """
-
-    def _patch_transformer_block(self):
-        if hasattr(gpt_model, 'OriginTransformerBlock'):
-            return
-        gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock
-        gpt_model.TransformerBlock = Qwen3VLTransformerBlock
-
-    def __init__(self, *args, **kwargs):
-        self._patch_transformer_block()
-        super().__init__(*args, **kwargs)
-
-
-class Qwen3OmniBridge(GPTBridge):
-    # TODO: qwen3-omni support
-    hf_layers_prefix = 'thinker.model.layers'
-    hf_embed_key = 'thinker.model.embed_tokens.weight'
-    hf_final_layernorm_key = 'thinker.model.norm.weight'
-    hf_lm_head_key = 'thinker.lm_head.weight'
-    hf_score_key = 'thinker.score.weight'
-
-
-class Qwen3VL_Vit(HuggingFaceModule):
-    module_mapping = {'model.visual': 'visual'}
-    _vision_tower = ['visual']
-    _aligner = ['visual.merger', 'visual.deepstack_merger_list']
-
-    def __init__(self, config):
-        from transformers.models.qwen3_vl import Qwen3VLTextModel
-        from transformers.models.qwen3_vl_moe import Qwen3VLMoeTextModel
-        super().__init__(config, [Qwen3VLTextModel, Qwen3VLMoeTextModel])
-
-    def get_inputs_embeds(self, inputs_embeds, **kwargs):
-        return Qwen3Omni_Vit._get_inputs_embeds(inputs_embeds, kwargs, self.visual, self.processor, self.model_config)
-
-
-register_megatron_model(
-    MegatronModelMeta(
-        MegatronModelType.qwen3_vl, [
-            ModelType.qwen3_vl,
-            ModelType.qwen3_vl_moe,
-        ],
-        model_cls=Qwen3VLGPTModel,
-        bridge_cls=MultimodalGPTBridge,
-        visual_cls=Qwen3VL_Vit,
-        auto_model_cls=Qwen3VLForConditionalGeneration))
diff --git a/src/twinkle/model/megatron/model/mm_gpts/utils.py b/src/twinkle/model/megatron/model/mm_gpts/utils.py
deleted file mode 100644
index 96f689a4..00000000
--- a/src/twinkle/model/megatron/model/mm_gpts/utils.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
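The `_deepstack_process` helper above adds per-layer visual embeddings only at visual-token positions; the `visual_embeds.mean() * 0` branch keeps the vision tower in the autograd graph on text-only batches so gradient reduction stays in sync across ranks. A self-contained sketch of that masked-injection pattern (shapes illustrative):

```python
import torch
from typing import Optional

def deepstack_inject(hidden: torch.Tensor, mask: Optional[torch.Tensor],
                     visual: torch.Tensor) -> torch.Tensor:
    """Add one row of `visual` at every position where `mask` is True."""
    if mask is None:
        # Text-only batch: a zero-valued contribution keeps the vision
        # parameters in the autograd graph without changing the output.
        return hidden + visual.mean() * 0
    out = hidden.clone()
    out[mask, :] = out[mask, :] + visual.to(hidden.dtype)
    return out

hidden = torch.zeros(6, 4)                                  # [seq, hidden]
mask = torch.tensor([0, 1, 1, 0, 0, 1], dtype=torch.bool)   # 3 visual tokens
visual = torch.ones(int(mask.sum()), 4)                     # one row per visual token
print(deepstack_inject(hidden, mask, visual)[mask])         # three rows of ones
```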
-# Reference: swift/swift/megatron/model/mm_gpts/utils.py -import torch -from abc import ABC, abstractmethod -from contextlib import contextmanager -from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule -from transformers import PreTrainedModel -from transformers.utils import ContextManagers - -from twinkle.model.megatron.args import get_args -from twinkle.utils import deep_getattr - - -@contextmanager -def patch_hf_initialize_weight(): - - _origin_initialize_weight = PreTrainedModel._initialize_weights - - def _initialize_weight(self, *args, **kwargs): - return - - PreTrainedModel._initialize_weights = _initialize_weight - try: - yield - finally: - PreTrainedModel._initialize_weights = _origin_initialize_weight - - -@contextmanager -def patch_device_map_meta(model_cls): - __origin_init__ = model_cls.__init__ - - def __init__(self, *args, **kwargs): - with torch.device('meta'): - __origin_init__(self, *args, **kwargs) - - model_cls.__init__ = __init__ - - try: - yield - finally: - model_cls.__init__ = __origin_init__ - - -class HuggingFaceModule(_HuggingFaceModule, ABC): - module_mapping = {} # hf -> mcore - - def __init__(self, config, ignore_init_model_cls=None): - super().__init__(config) - args = get_args() - attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn' - # Handle both enum and string attention_backend - attn_backend = args.attention_backend - is_flash = (getattr(attn_backend, 'name', attn_backend) == 'flash' if attn_backend else False) - kwargs = {'attn_impl': attn_impl} if is_flash else {} - ignore_init_model_cls = ignore_init_model_cls or [] - if not isinstance(ignore_init_model_cls, list): - ignore_init_model_cls = [ignore_init_model_cls] - context_list = [patch_device_map_meta(model_cls) for model_cls in ignore_init_model_cls] - context_list.append(patch_hf_initialize_weight()) - kwargs['model_type'] = args.hf_model_type - from transformers import AutoModel, AutoProcessor - - from ..register import get_megatron_model_meta - megatron_model_meta = get_megatron_model_meta(args.hf_model_type) - auto_model_cls = megatron_model_meta.auto_model_cls if megatron_model_meta else AutoModel - with ContextManagers(context_list): - model = auto_model_cls.from_pretrained(args.model_dir, torch_dtype=args.torch_dtype, trust_remote_code=True) - self.processor = AutoProcessor.from_pretrained(args.model_dir, trust_remote_code=True) - - self.model_config = model.config - for hf_prefix, mg_prefix in self.module_mapping.items(): - setattr(self, mg_prefix, deep_getattr(model, hf_prefix)) - self._hf_model = [model] - self.prepare_model(model) - self.to('cuda') - - def prepare_model(self, hf_model): - pass - - @abstractmethod - def get_inputs_embeds(self, inputs_embeds, **kwargs): - pass diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py deleted file mode 100644 index 07dfd82a..00000000 --- a/src/twinkle/model/megatron/model/register.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
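Both `patch_hf_initialize_weight` and `patch_device_map_meta` above follow the same save/patch/restore shape. A generic, dependency-light sketch of the meta-device variant, which skips weight allocation when the parameters will be overwritten immediately afterwards (the toy `nn.Linear` target is illustrative):

```python
import torch
import torch.nn as nn
from contextlib import contextmanager

@contextmanager
def init_on_meta(model_cls):
    """Temporarily construct `model_cls` on the meta device (no weight storage)."""
    origin_init = model_cls.__init__

    def patched_init(self, *args, **kwargs):
        with torch.device('meta'):
            origin_init(self, *args, **kwargs)

    model_cls.__init__ = patched_init
    try:
        yield
    finally:
        model_cls.__init__ = origin_init

with init_on_meta(nn.Linear):
    layer = nn.Linear(4096, 4096)
print(layer.weight.device)  # meta: nothing was allocated; real weights load later
```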
-import torch.nn as nn -from dataclasses import dataclass -from typing import List, Optional, Type - -from .constant import MLLMMegatronModelType - -MEGATRON_MODEL_MAPPING = {} - - -@dataclass -class MegatronModelMeta: - megatron_model_type: str - model_types: List[str] - - is_multimodal: bool = False - bridge_cls: Optional[Type] = None - model_cls: Optional[Type[nn.Module]] = None - visual_cls: Optional[Type[nn.Module]] = None - auto_model_cls: Optional[Type] = None - loader: Optional[Type['MegatronModelLoader']] = None - - def __post_init__(self): - if self.megatron_model_type in MLLMMegatronModelType.__dict__: - self.is_multimodal = True - if self.bridge_cls is None: - from .gpt_bridge import GPTBridge, MultimodalGPTBridge - self.bridge_cls = MultimodalGPTBridge if self.is_multimodal else GPTBridge - if self.model_cls is None: - from .gpt_model import GPTModel - from .mm_gpt_model import MultimodalGPTModel - self.model_cls = MultimodalGPTModel if self.is_multimodal else GPTModel - if self.auto_model_cls is None: - from transformers import AutoModel, AutoModelForCausalLM - self.auto_model_cls = AutoModel if self.is_multimodal else AutoModelForCausalLM - if self.loader is None: - self.loader = MegatronModelLoader - - -class MegatronModelLoader: - """Default loader that builds TransformerConfig + layer specs for a model. - - Subclass this to customize layer spec construction (e.g. heterogeneous - attention types, custom layer norms). Register the subclass via - ``MegatronModelMeta(loader=MyLoader)``. - """ - - def get_layer_spec(self, config, args, mg_config_dict): - """Build a transformer layer spec from *config* (``TransformerConfig``). - - The default implementation delegates to Megatron-Core's - ``get_gpt_layer_with_transformer_engine_spec``. - - Returns: - A ``ModuleSpec`` or ``TransformerBlockSubmodules`` instance. - """ - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec - num_experts = mg_config_dict.get('num_experts') or None - return get_gpt_layer_with_transformer_engine_spec( - num_experts=num_experts, - moe_grouped_gemm=num_experts is not None, - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - ) - - def get_mtp_block_spec(self, config, layer_spec, **kwargs): - """Build MTP block spec. Override for custom layer norms etc.""" - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec - return get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=True, **kwargs) - - def post_config(self, config, args, mg_config_dict): - """Hook called after TransformerConfig is created but before layer specs. - - Use this to set model-specific config attributes (e.g. ``layer_types``, - ``moe_use_shared_expert_gate``). 
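As a usage illustration of the registry this patch removes: a loader subclass overrides `post_config` and is attached via `MegatronModelMeta(loader=...)`. Everything below other than the registry symbols (`my_model`, the attribute touched in `post_config`) is made up for the sketch, and the import points at the module deleted in this very diff:

```python
# Sketch of the (now removed) registration API; non-registry names are hypothetical.
from twinkle.model.megatron.model.register import (MegatronModelLoader, MegatronModelMeta,
                                                   register_megatron_model)


class MyLoader(MegatronModelLoader):

    def post_config(self, config, args, mg_config_dict):
        # Runs after TransformerConfig creation, before layer specs are built.
        config.qk_layernorm = mg_config_dict.get('qk_layernorm', False)


register_megatron_model(MegatronModelMeta(
    'my_model',        # megatron_model_type (hypothetical)
    ['my_model'],      # HF model_types served by this meta
    loader=MyLoader,
))
```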
- """ - pass - - -def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): - megatron_model_type = megatron_model_meta.megatron_model_type - if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: - raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.') - MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta - - -_MODEL_META_MAPPING = None - - -def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]: - global _MODEL_META_MAPPING - if _MODEL_META_MAPPING is None: - _MODEL_META_MAPPING = {} - for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items(): - for _model_type in megatron_model_meta.model_types: - _MODEL_META_MAPPING[_model_type] = k - if model_type not in _MODEL_META_MAPPING: - return - return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]] diff --git a/src/twinkle/model/megatron/model/rope.py b/src/twinkle/model/megatron/model/rope.py deleted file mode 100644 index d23759c9..00000000 --- a/src/twinkle/model/megatron/model/rope.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import torch -import transformers -from packaging import version -from transformers import PretrainedConfig -from typing import Any, Dict, Optional, Tuple - -from twinkle.model.megatron.args import get_args - - -class DummyConfig: - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - -def _get_dummy_config(args): - dummy_config = DummyConfig( - rope_scaling=args.rope_scaling, - rope_theta=args.rotary_base, - max_position_embeddings=args.max_position_embeddings, - head_dim=args.qk_pos_emb_head_dim if args.multi_latent_attention else args.kv_channels, - hidden_size=args.hidden_size, - num_attention_heads=args.num_attention_heads, - ) - original_max_position_embeddings = args.original_max_position_embeddings or ( - args.rope_scaling or {}).get('original_max_position_embeddings') - if original_max_position_embeddings is not None: - dummy_config.original_max_position_embeddings = original_max_position_embeddings - if args.partial_rotary_factor is not None: - dummy_config.partial_rotary_factor = args.partial_rotary_factor - return dummy_config - - -EXTENDED_ROPE_INIT_FUNCTIONS = {} - - -# copy from transformers # compat transformers==5.0 -def _compute_default_rope_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional['torch.device'] = None, - seq_len: Optional[int] = None, -) -> Tuple['torch.Tensor', float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for - the first fraction of the head_dim. Defaults to 1.0. 
- device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - base = config.rope_theta - partial_rotary_factor = getattr(config, 'partial_rotary_factor', 1.0) - head_dim = getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - return inv_freq, attention_factor - - -if version.parse(transformers.__version__) >= version.parse('5.0.0.dev'): - EXTENDED_ROPE_INIT_FUNCTIONS['default'] = _compute_default_rope_parameters - - -def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): - if rope_scaling is None: - return 'default' - rope_type = rope_scaling['rope_type'] - if rope_type == 'dynamic' and rope_scaling.get('alpha') is not None: - rope_type = 'dynamic_alpha' - return rope_type - - -def get_rope_inv_freq(seq_len=None): - from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS - args = get_args() - ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) - dummy_config = _get_dummy_config(args) - rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(args.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len) - if attention_scaling is None: - attention_scaling = 1. - return inv_freq, attention_scaling - - -# borrowed from huggingface/transformers -def longrope_frequency_update(args, model, inv_freq, seq_len: int): - if args.original_max_position_embeddings is not None: - original_max_position_embeddings = args.original_max_position_embeddings - else: - original_max_position_embeddings = args.max_position_embeddings - - if not hasattr(model, 'long_inv_freq'): - model.long_inv_freq, _ = get_rope_inv_freq(seq_len=original_max_position_embeddings + 1) - model.original_inv_freq = inv_freq.clone() - - if seq_len > original_max_position_embeddings: - inv_freq.data.copy_(model.long_inv_freq) - else: - inv_freq.data.copy_(model.original_inv_freq) - - -# borrowed from huggingface/transformers -def dynamic_frequency_update(args, model, inv_freq, seq_len: int): - if not hasattr(model, 'max_seq_len_cached'): - model.max_seq_len_cached = args.max_position_embeddings - model.original_max_seq_len = args.max_position_embeddings - model.original_inv_freq = inv_freq.clone() - attention_scaling = None - if seq_len > model.max_seq_len_cached: # growth - new_inv_freq, attention_scaling = get_rope_inv_freq(seq_len=seq_len) - inv_freq.data.copy_(new_inv_freq) - model.max_seq_len_cached = seq_len - - if seq_len < model.original_max_seq_len and model.max_seq_len_cached > model.original_max_seq_len: # reset - inv_freq.data.copy_(model.original_inv_freq) - model.max_seq_len_cached = model.original_max_seq_len - return attention_scaling - - -def dynamic_rope_update(model, inv_freq, seq_len: int): - args = get_args() - rope_type = _get_rope_type(args.rope_scaling) - attention_scaling = None - if rope_type == 'dynamic': - attention_scaling = dynamic_frequency_update(args, model, inv_freq, seq_len) - elif rope_type == 'longrope': - attention_scaling = 
longrope_frequency_update(args, model, inv_freq, seq_len) - return attention_scaling - - -def _compute_dynamic_alpha_ntk_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional['torch.device'] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> tuple['torch.Tensor', float]: - # Code borrowed from Tencent-Hunyuan/Hunyuan-A13B-Instruct - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else 1.0 - head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - alpha = config.rope_scaling['alpha'] - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - base = base * alpha**(dim / (dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - return inv_freq, attention_factor - - -EXTENDED_ROPE_INIT_FUNCTIONS['dynamic_alpha'] = _compute_dynamic_alpha_ntk_parameters diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 362e65ae..017d3cf5 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from typing import List, Literal, Optional, Dict, Any + import torch import torch.nn as nn -from typing import List, Literal, Optional from twinkle import DeviceMesh @@ -10,16 +11,43 @@ class MegatronStrategy: def __init__( self, + model_dir, device_mesh: Optional[DeviceMesh] = None, use_distributed_optimizer: bool = True, mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', - params_dtype: Optional[str] = None, + seed: int = 42, **kwargs, ): + from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer self.mixed_precision = mixed_precision + self.model_dir = model_dir + self._seed = seed + # Determine params_dtype and activation checkpointing kwargs + params_dtype = torch.bfloat16 + if self.mixed_precision == 'fp16': + params_dtype = torch.float16 + elif self.mixed_precision == 'no': + params_dtype = torch.float32 self._params_dtype = params_dtype + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + model_parallel_cuda_manual_seed(self._seed) + + parallel_kwargs = { + 'tensor_model_parallel_size': self.device_mesh.tp_world_size or 1, + 'pipeline_model_parallel_size': self.device_mesh.pp_world_size or 1, + 'context_parallel_size': self.device_mesh.cp_world_size or 1, + 'expert_model_parallel_size': self.device_mesh.ep_size or 1, + 'expert_tensor_parallel_size': self.device_mesh.etp_world_size or 1, + 'virtual_pipeline_model_parallel_size': self.device_mesh.vpp_size or 1, + } + self._initialized = True + mpu.initialize_model_parallel( + order=self.device_mesh.order, + **parallel_kwargs, + ) + self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) @property def sequence_parallel(self) -> bool: @@ -144,33 +172,32 @@ def reduce_loss(self, local_loss, local_count, logits, logps): def get_model_config( self, - hidden_size: int, - num_attention_heads: int, - num_layers: int, - ffn_hidden_size: Optional[int] = None, - num_query_groups: Optional[int] = None, - num_experts: Optional[int] = None, - moe_router_topk: int = 2, + model_dir: str, + parallel_kwargs: Dict[str, Any], **kwargs, ): - from megatron.core.transformer 
import TransformerConfig - - config = TransformerConfig( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups or num_attention_heads, - ffn_hidden_size=ffn_hidden_size or 4 * hidden_size, + from mcore_bridge import ModelConfig, hf_to_mcore_config + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + config_kwargs = hf_to_mcore_config(hf_config) + config_kwargs.update(kwargs) + config = ModelConfig( use_cpu_initialization=True, params_dtype=self.params_type, - tensor_model_parallel_size=self.device_mesh.tp_world_size or 1, - pipeline_model_parallel_size=self.device_mesh.pp_world_size or 1, - context_parallel_size=self.device_mesh.cp_world_size or 1, - expert_model_parallel_size=self.device_mesh.ep_size or 1, sequence_parallel=self.sequence_parallel, - num_moe_experts=num_experts, - moe_router_topk=moe_router_topk, - **kwargs, + **parallel_kwargs, + **config_kwargs, ) - return config + + def create_megatron_model( + self, + load_weights: bool = True, + ) -> List[nn.Module]: + from mcore_bridge import get_mcore_model + mg_models = get_mcore_model(self.config) + if load_weights: + # Load weights + bridge = self.config.bridge + bridge.load_weights(mg_models, self.model_dir) + return mg_models diff --git a/src/twinkle/model/megatron/tuners/__init__.py b/src/twinkle/model/megatron/tuners/__init__.py deleted file mode 100644 index 2112a613..00000000 --- a/src/twinkle/model/megatron/tuners/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - -from .lora import LoraParallelLinear, dispatch_megatron - -__all__ = [ - 'LoraParallelLinear', - 'dispatch_megatron', -] diff --git a/src/twinkle/model/megatron/tuners/lora.py b/src/twinkle/model/megatron/tuners/lora.py deleted file mode 100644 index b80c906d..00000000 --- a/src/twinkle/model/megatron/tuners/lora.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
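The rewritten strategy above delegates config translation and weight loading to `mcore_bridge`. An end-to-end sketch of that flow; the `ModelConfig`, `hf_to_mcore_config`, `get_mcore_model`, and `bridge.load_weights` names are taken from the diff itself and should be treated as assumptions about that package, and the checkpoint path is a placeholder:

```python
# Sketch only: assumes mcore_bridge exposes the names used in the new code above.
from transformers import AutoConfig
from mcore_bridge import ModelConfig, get_mcore_model, hf_to_mcore_config

model_dir = '/path/to/hf/checkpoint'  # placeholder
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

config = ModelConfig(
    use_cpu_initialization=True,
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    **hf_to_mcore_config(hf_config),
)
mg_models = get_mcore_model(config)               # build the mcore module(s)
config.bridge.load_weights(mg_models, model_dir)  # copy HF weights in
```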
-"""Megatron-compatible LoRA implementation with Tensor Parallel support.""" -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import warnings -from contextlib import contextmanager, nullcontext -from peft.tuners.lora import model -from peft.tuners.lora.layer import LoraLayer -from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge -from peft.utils.other import transpose -from transformers.utils import is_torch_npu_available -from typing import Any, List, Optional, Tuple - -from twinkle import Platform, exists, requires - -if exists('megatron_core'): - from megatron.core import parallel_state - from megatron.core.dist_checkpointing.mapping import ShardedStateDict - from megatron.core.extensions.transformer_engine import (TEColumnParallelGroupedLinear, TEColumnParallelLinear, - TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear, - TERowParallelGroupedLinear, TERowParallelLinear) - from megatron.core.parallel_state import get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size - from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region - from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name - from megatron.core.transformer.mlp import apply_swiglu_sharded_factory - from megatron.core.transformer.module import MegatronModule - from megatron.core.transformer.moe.router import TopKRouter -else: - # raise an error - requires('megatron_core') - - -class LoraParallelLinear(MegatronModule, LoraLayer): - """LoRA layer compatible with Megatron Tensor Parallel Linear layers. - - This class wraps Megatron's parallel linear layers (TELinear, TEColumnParallelLinear, - TERowParallelLinear, etc.) and adds LoRA adapters that are correctly sharded - across tensor parallel ranks. - """ - - def __init__( - self, - base_layer, - adapter_name: str, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - fan_in_fan_out: bool = False, - init_lora_weights: bool = True, - use_rslora: bool = False, - use_dora: bool = False, - lora_bias: bool = False, - **kwargs, - ): - """Initialize LoraParallelLinear. - - Args: - base_layer: The Megatron parallel linear layer to wrap. - adapter_name: Name of the LoRA adapter. - r: LoRA rank. - lora_alpha: LoRA alpha scaling factor. - lora_dropout: Dropout probability for LoRA layers. - fan_in_fan_out: Whether the layer uses fan-in/fan-out convention. - init_lora_weights: Whether to initialize LoRA weights. - use_rslora: Use rank-stabilized LoRA scaling. - use_dora: Use DoRA (not supported yet). - lora_bias: Whether to add bias to LoRA layers. 
- """ - config = base_layer.config - super().__init__(config=config) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - LoraLayer.__init__(self, base_layer=base_layer) - - if use_dora: - raise ValueError(f'{self.__class__.__name__} does not support DoRA yet, please set it to False') - - self.is_parallel_a = isinstance(base_layer, (TERowParallelLinear, TERowParallelGroupedLinear)) - self.is_grouped = isinstance(base_layer, TEGroupedLinear) - self.fan_in_fan_out = fan_in_fan_out - self._active_adapter = adapter_name - self.is_expert = getattr(base_layer, 'is_expert', False) - self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False) - - if self.is_expert: - self.tp_size = get_expert_tensor_parallel_world_size() - if self.tp_size > 1: - raise ValueError('Currently, LoRA does not support ETP.') # TODO: init/all-reduce - else: - self.tp_size = get_tensor_model_parallel_world_size() - - self.update_layer( - adapter_name, - r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - init_lora_weights=init_lora_weights, - use_rslora=use_rslora, - lora_bias=lora_bias, - ) - - self.is_target_conv_1d_layer = False - - def update_layer(self, adapter_name: str, r: int, *, lora_alpha: int, lora_dropout: float, init_lora_weights: bool, - use_rslora: bool, lora_bias: bool, **kwargs): - """Update LoRA layer with new adapter configuration. - - Args: - adapter_name: Name of the adapter. - r: LoRA rank. - lora_alpha: LoRA alpha scaling factor. - lora_dropout: Dropout probability. - init_lora_weights: Whether to initialize weights. - use_rslora: Use rank-stabilized LoRA. - lora_bias: Whether to add bias. - """ - if r <= 0: - raise ValueError(f'`r` should be a positive integer value but the value passed is {r}') - - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout[adapter_name] = lora_dropout_layer - - # Build LoRA A and B matrices with proper parallelism - kwargs = { - 'skip_bias_add': False, - 'init_method': self.config.init_method, - 'config': self.config, - 'is_expert': self.is_expert, - } - - if isinstance(self.base_layer, TopKRouter): - # Router layer - no parallelism needed - router_shape = self.base_layer.weight.shape - lora_a = TELinear( - input_size=router_shape[1], - output_size=r, - bias=lora_bias, - parallel_mode=None, - skip_weight_param_allocation=False, - **kwargs, - ) - lora_b = TELinear( - input_size=r, - output_size=router_shape[0], - bias=lora_bias, - parallel_mode=None, - skip_weight_param_allocation=False, - **kwargs, - ) - elif self.is_parallel_a: - # Row parallel layer - LoRA A is parallel, LoRA B is not - in_features = self.in_features * self.tp_size - if self.is_grouped: - lora_a = TERowParallelGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=in_features, - output_size=r, - bias=False, - **kwargs, - ) - lora_b = TEGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=r, - output_size=self.out_features, - bias=lora_bias, - parallel_mode=None, - **kwargs, - ) - else: - lora_a = TERowParallelLinear( - input_size=in_features, - output_size=r, - bias=False, - input_is_parallel=True, - **kwargs, - ) - lora_b = TELinear( - input_size=r, - output_size=self.out_features, - bias=lora_bias, - parallel_mode=None, - skip_weight_param_allocation=False, - **kwargs, - ) - lora_a.parallel_mode = self.base_layer.parallel_mode - else: - # Column parallel layer - LoRA A is not 
parallel, LoRA B is parallel - if is_torch_npu_available(): - out_features = self.out_features - else: - out_features = self.out_features * self.tp_size - if self.is_grouped: - lora_a = TEGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=self.in_features, - output_size=r, - bias=lora_bias, - parallel_mode=None, - **kwargs) - lora_b = TEColumnParallelGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=r, - output_size=out_features, - bias=lora_bias, - **kwargs, - ) - else: - if is_torch_npu_available(): - lora_a = nn.Linear( - in_features=self.in_features, - out_features=r, - bias=lora_bias, - ) - else: - lora_a = TELinear( - input_size=self.in_features, - output_size=r, - bias=lora_bias, - parallel_mode=None, - skip_weight_param_allocation=False, - **kwargs) - lora_b = TEColumnParallelLinear( - input_size=r, - output_size=out_features, - bias=lora_bias, - gather_output=False, - **kwargs, - ) - lora_b.parallel_mode = self.base_layer.parallel_mode - for lora in [lora_a, lora_b]: - if getattr(lora, 'parallel_mode', None) is None and hasattr(lora, 'weight'): # TODO: experts - if isinstance(self.base_layer, TopKRouter): - sequence_parallel = self.base_layer.weight.sequence_parallel - else: - sequence_parallel = self.sequence_parallel - lora.weight.sequence_parallel = sequence_parallel - self.lora_A[adapter_name] = lora_a - self.lora_B[adapter_name] = lora_b - - if hasattr(self, 'lora_bias'): - self.lora_bias[adapter_name] = lora_bias - - if use_rslora: - self.scaling[adapter_name] = lora_alpha / (r**0.5) - else: - self.scaling[adapter_name] = lora_alpha / r - - if init_lora_weights: - self.reset_lora_parameters(adapter_name, init_lora_weights) - - self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) - - def _get_rng_context(self, lora): - if self.is_expert: - rng_context = get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()) - elif getattr(lora, 'parallel_mode', None) is None: - rng_context = nullcontext() - else: - rng_context = get_cuda_rng_tracker().fork() - return rng_context - - def reset_lora_parameters(self, adapter_name: str, init_lora_weights: bool): - """Reset LoRA parameters to initial values. - - Args: - adapter_name: Name of the adapter. - init_lora_weights: Initialization method. 
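`update_layer` also picks the scaling, `lora_alpha / sqrt(r)` with rsLoRA and `lora_alpha / r` otherwise, while initialization draws A from Kaiming-uniform (or a gaussian with std `1/r`) and zeroes B, so the merged delta `B @ A` starts at exactly zero. A dependency-light sketch of the init scheme:

```python
import math
import torch
import torch.nn as nn

def reset_lora(weight_a: torch.Tensor, weight_b: torch.Tensor, r: int,
               init: str = 'default') -> None:
    """A: Kaiming-uniform (default) or gaussian with std 1/r; B: always zeros."""
    if init == 'default':
        nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5))
    elif init == 'gaussian':
        nn.init.normal_(weight_a, std=1 / r)
    else:
        raise ValueError(f'Unknown initialization {init=}')
    nn.init.zeros_(weight_b)

a, b = torch.empty(8, 64), torch.empty(128, 8)
reset_lora(a, b, r=8)
print((b @ a).abs().max().item())  # 0.0 -- the merged delta starts at zero
```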
- """ - if init_lora_weights is False: - return - - if adapter_name in self.lora_A.keys(): - lora_a = self.lora_A[adapter_name] - lora_b = self.lora_B[adapter_name] - - if isinstance(lora_a, TEGroupedLinear): - weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)] - else: - weights_a = [lora_a.weight] - - if isinstance(lora_b, TEGroupedLinear): - weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)] - else: - weights_b = [lora_b.weight] - - with self._get_rng_context(lora_a): - for weight_a in weights_a: - if init_lora_weights is True: - # initialize A the same way as the default for nn.Linear and B to zero - # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 - nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5)) - elif init_lora_weights.lower() == 'gaussian': - nn.init.normal_(weight_a, std=1 / self.r[adapter_name]) - else: - raise ValueError(f'Unknown initialization {init_lora_weights=}') - - for weight_b in weights_b: - nn.init.zeros_(weight_b) - - if adapter_name in self.lora_embedding_A.keys(): - nn.init.zeros_(self.lora_embedding_A[adapter_name]) - nn.init.normal_(self.lora_embedding_B[adapter_name]) - - @contextmanager - def _patch_router_gating(self): - """Context manager to patch router gating with LoRA.""" - origin_gating = self.base_layer.__class__.gating - - def gating(_self, x): - result = origin_gating(_self, x) - for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): - continue - lora_A = self.lora_A[active_adapter] - lora_B = self.lora_B[active_adapter] - dropout = self.lora_dropout[active_adapter] - scaling = self.scaling[active_adapter] - x = x.to(result.dtype) - - lora_result = F.linear(dropout(x), lora_A.weight.to(result.dtype)) - if isinstance(lora_result, tuple): - lora_result = lora_result[0] - lora_result = F.linear(lora_result, lora_B.weight.to(result.dtype)) - if isinstance(lora_result, tuple): - lora_result = lora_result[0] - lora_result = lora_result * scaling - - result = result + lora_result - return result - - self.base_layer.__class__.gating = gating - try: - yield - finally: - self.base_layer.__class__.gating = origin_gating - - def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): - """Forward pass with LoRA adaptation. - - Args: - x: Input tensor. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - Tuple of (output tensor, bias). 
- """ - previous_dtype = x.dtype - if self.disable_adapters and self.merged: - self.unmerge() - - if isinstance(self.base_layer, TELayerNormColumnParallelLinear): - if self.disable_adapters or self.merged: - self.base_layer.return_layernorm_output = False - result, bias = self.base_layer(x, *args, **kwargs) - else: - self.base_layer.return_layernorm_output = True - if is_torch_npu_available(): - result, bias = self.base_layer(x, *args, **kwargs) - else: - (result, x), bias = self.base_layer(x, *args, **kwargs) - elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)): - result, bias = self.base_layer(x, *args, **kwargs) - elif isinstance(self.base_layer, TopKRouter): - with self._patch_router_gating(): - result, bias = self.base_layer(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}') - - if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged: - for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): - continue - - lora_A = self.lora_A[active_adapter] - lora_B = self.lora_B[active_adapter] - dropout = self.lora_dropout[active_adapter] - scaling = self.scaling[active_adapter] - dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype - x = x.to(dtype) - - lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A( - dropout(x)) - if isinstance(lora_result, tuple): - lora_result = lora_result[0] - - lora_result = lora_B(lora_result, *args, **kwargs) if isinstance( - lora_B, TEGroupedLinear) else lora_B(lora_result) - if isinstance(lora_result, tuple): - lora_result = lora_result[0] - lora_result = lora_result * scaling - result = result + lora_result - - result = result.to(previous_dtype) - return result, bias - - def sharded_state_dict( - self, - prefix: str = '', - sharded_offsets: Tuple[Tuple[int, int, int]] = (), - metadata: Optional[dict] = None, - ) -> ShardedStateDict: - """Get sharded state dict for distributed checkpointing. - - Args: - prefix: Key prefix. - sharded_offsets: Sharding offsets. - metadata: Additional metadata. - - Returns: - Sharded state dictionary. - """ - - from .multi_lora import tuners_sharded_state_dict - sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata) - - if prefix.endswith('linear_fc1.'): - if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit: - num_global_experts = (parallel_state.get_expert_model_parallel_world_size() * self.base_layer.num_gemms) - local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.base_layer.num_gemms) - ep_axis = len(sharded_offsets) - for i in range(self.base_layer.num_gemms): - new_sharded_offsets = ( - *sharded_offsets, - (ep_axis, local_expert_indices_offset + i, num_global_experts), - ) - for k in (f'{prefix}base_layer.weight{i}', f'{prefix}base_layer.bias{i}'): - if k in sharded_state_dict: - sharded_state_dict[k] = apply_swiglu_sharded_factory(sharded_state_dict[k], - new_sharded_offsets) - else: - for k, v in sharded_state_dict.items(): - if k in [f'{prefix}base_layer.weight', f'{prefix}base_layer.bias']: - sharded_state_dict[k] = apply_swiglu_sharded_factory(sharded_state_dict[k], sharded_offsets) - return sharded_state_dict - - def get_delta_weights(self, adapter: str) -> List[torch.Tensor]: - """Compute the delta weight for the given adapter. - - Args: - adapter: The name of the adapter. 
- - Returns: - List of delta weight tensors. - """ - lora_A = self.lora_A[adapter] - lora_B = self.lora_B[adapter] - - if self.is_grouped: - weight_A = [getattr(lora_A, f'weight{i}') for i in range(lora_A.num_gemms)] - weight_B = [getattr(lora_B, f'weight{i}') for i in range(lora_B.num_gemms)] - else: - weight_A = [self.lora_A[adapter].weight] - weight_B = [self.lora_B[adapter].weight] - - output_tensor = [] - assert len(weight_A) == len(weight_B) - - for i in range(len(weight_B)): - output_tensor.append(transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter]) - - return output_tensor - - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: - """Merge the active adapter weights into the base weights. - - Args: - safe_merge: If True, check for NaNs before merging. - adapter_names: List of adapter names to merge. - """ - adapter_names = check_adapters_to_merge(self, adapter_names) - if not adapter_names: - return - - base_layer = self.get_base_layer() - origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device - - if origin_device.type == 'cpu': - self.to(device=Platform.get_local_device()) - - for active_adapter in adapter_names: - if active_adapter in self.lora_A.keys(): - if self.is_grouped: - orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] - else: - orig_weights = [base_layer.weight] - - if safe_merge: - orig_weights = [weight.data.clone() for weight in orig_weights] - delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): - orig_weight += delta_weight - if not all(torch.isfinite(orig_weights[i]).all() for i in range(len(orig_weights))): - raise ValueError( - f'NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken') - if self.is_grouped: - for i in range(base_layer.num_gemms): - weight = getattr(base_layer, f'weight{i}') - weight.data = orig_weights[i] - else: - base_layer.weight.data = orig_weights[0] - else: - delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): - orig_weight.data += delta_weight - - self.merged_adapters.append(active_adapter) - - if origin_device.type == 'cpu': - self.to(device=origin_device) - - def unmerge(self) -> None: - """Unmerge all merged adapter weights from the base weights.""" - if not self.merged: - return - - base_layer = self.get_base_layer() - origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device - - if origin_device.type == 'cpu': - self.to(device=Platform.get_local_device()) - - for active_adapter in self.merged_adapters: - if active_adapter in self.lora_A.keys(): - if self.is_grouped: - orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] - else: - orig_weights = [base_layer.weight] - - delta_weights = self.get_delta_weights(active_adapter) - for orig_weight, delta_weight in zip(orig_weights, delta_weights): - orig_weight.data -= delta_weight - - self.merged_adapters = [] - - if origin_device.type == 'cpu': - self.to(device=origin_device) - - -def dispatch_megatron( - target: torch.nn.Module, - adapter_name: str, - lora_config, - **kwargs: Any, -) -> Optional[torch.nn.Module]: - """Dispatch function to replace Megatron linear layers with LoRA layers. - - Args: - target: The target module to potentially replace. - adapter_name: Name of the LoRA adapter. 
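`merge` adds the `get_delta_weights` output into the base weight in place, and `unmerge` subtracts it back, so a merge/unmerge round trip must restore the original weight up to floating-point error. A numeric check of that invariant (with `fan_in_fan_out=False`, the delta is just `B @ A` times the scaling):

```python
import torch

out_f, in_f, r, scaling = 16, 32, 4, 2.0
w = torch.randn(out_f, in_f)   # base weight
a = torch.randn(r, in_f)       # lora_A
b = torch.randn(out_f, r)      # lora_B

w_orig = w.clone()
delta = (b @ a) * scaling      # get_delta_weights, no transpose needed here
w += delta                     # merge
w -= delta                     # unmerge
assert torch.allclose(w, w_orig, atol=1e-5)
```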
- lora_config: LoRA configuration. - **kwargs: Additional arguments for LoraParallelLinear. - - Returns: - LoraParallelLinear if target is a compatible layer, None otherwise. - """ - new_module = None - - if isinstance(target, BaseTunerLayer): - target_base_layer = target.get_base_layer() - else: - target_base_layer = target - - linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, TopKRouter) - if isinstance(target_base_layer, linear_cls): - new_module = LoraParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) - - return new_module - - -# Register dispatch function with PEFT -model.dispatch_megatron = dispatch_megatron diff --git a/src/twinkle/model/megatron/tuners/utils.py b/src/twinkle/model/megatron/tuners/utils.py deleted file mode 100644 index e97ab462..00000000 --- a/src/twinkle/model/megatron/tuners/utils.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Utility functions for Megatron-Core integration.""" -import torch.nn as nn -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple - - -def find_layers(model: nn.Module, cond_fn) -> List[str]: - """Find all layers in model matching condition function. - - - - Args: - model: The model to search. - cond_fn: Callable(name, module) -> bool. - - Returns: - List of matching layer names. - """ - result = [] - for name, module in model.named_modules(): - if cond_fn(name, module): - result.append(name) - return result - - -def find_all_linears(model: nn.Module) -> List[str]: - """Find all linear layers suitable for LoRA in a Megatron model. - - - - Args: - model: The Megatron model. - - Returns: - List of layer names suitable for LoRA. - """ - from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear - - def _cond(name: str, module: nn.Module) -> bool: - if name == 'output_layer' or 'lora' in name: - return False - if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)): - return True - return False - - return find_layers(model, _cond) - - -def find_router(model: nn.Module) -> List[str]: - """Find all MoE router layers in a Megatron model. - - - - Args: - model: The Megatron model. - - Returns: - List of router layer names. - """ - from megatron.core.transformer.moe.router import TopKRouter - return find_layers(model, lambda name, module: isinstance(module, TopKRouter) and 'lora' not in name) - - -def find_embedding(model: nn.Module) -> List[str]: - """Find all embedding layers in a Megatron model. - - - - Args: - model: The Megatron model. - - Returns: - List of embedding layer names. - """ - from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding - return find_layers(model, lambda name, module: isinstance(module, LanguageModelEmbedding) and 'lora' not in name) - - -def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str]: - """Expand target module specifications to actual module names. - - - - Args: - model: The Megatron model. - target_modules: List of target module specs, may include 'all-linear', etc. - - Returns: - Expanded list of target module names. 
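`get_target_modules` below replaces each `all-*` alias with concrete module names and dedupes the result. A dependency-light sketch of the `all-linear` case, where the filter mirrors `find_all_linears` and plain `nn.Linear` stands in for the TE layer classes:

```python
import torch.nn as nn

def expand_target_modules(specs: list[str], named_modules: dict) -> list[str]:
    """Resolve 'all-linear' style aliases into concrete module names."""
    result = [s for s in specs if s != 'all-linear']
    if 'all-linear' in specs:
        result += [
            name for name, mod in named_modules.items()
            if isinstance(mod, nn.Linear) and name != 'output_layer' and 'lora' not in name
        ]
    return sorted(set(result))

model = nn.ModuleDict({'q_proj': nn.Linear(4, 4), 'norm': nn.LayerNorm(4)})
print(expand_target_modules(['all-linear'], dict(model.named_modules())))  # ['q_proj']
```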
- """ - result = target_modules.copy() - if 'all-linear' in result: - result.remove('all-linear') - result += find_all_linears(model) - if 'all-embedding' in result: - result.remove('all-embedding') - result += find_embedding(model) - if 'all-router' in result: - result.remove('all-router') - result += find_router(model) - return list(set(result)) - - -def set_linear_is_expert(model: nn.Module): - """Mark expert linear layers in MoE models. - - Args: - model: The Megatron model. - """ - from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear - for name, module in model.named_modules(): - if '.local_experts.' in name and isinstance(module, (TELinear, TELayerNormColumnParallelLinear)): - module.is_expert = True - elif isinstance(module, TEGroupedLinear): - module.is_expert = True - - -@contextmanager -def patch_deepcopy(): - """Context manager to handle tp_group in deepcopy operations. - - - - WHY THIS IS NECESSARY: - ---------------------- - Megatron-Core's TransformerEngine linear layers (TELinear, TEColumnParallelLinear, etc.) - store a reference to their tensor parallel process group in the `tp_group` attribute. - - When PEFT's get_peft_model() is called, it internally uses copy.deepcopy() to create - copies of certain modules. However, torch.distributed.ProcessGroup objects cannot be - pickled or deepcopied because: - - 1. ProcessGroup objects contain native CUDA/NCCL handles that are process-specific - 2. These handles cannot be serialized and recreated in a different memory context - 3. Attempting to deepcopy them raises: "RuntimeError: Cannot pickle ProcessGroup" - - This patch temporarily sets tp_group to None during deepcopy, then restores it - after the copy is complete. This allows PEFT to work with Megatron modules while - preserving the correct process group references. - - USAGE: - ------ - ```python - with patch_deepcopy(): - model = get_peft_model(megatron_model, lora_config) - ``` - - Without this patch, the above code would fail with a pickling error. - """ - import copy - _origin_deepcopy = copy.deepcopy - - def new_deepcopy(x, *args, **kwargs): - if getattr(x, 'tp_group', None) is not None: - origin_tp_group = x.tp_group - x.tp_group = None - res = _origin_deepcopy(x, *args, **kwargs) - x.tp_group = origin_tp_group - res.tp_group = origin_tp_group - return res - else: - return _origin_deepcopy(x, *args, **kwargs) - - copy.deepcopy = new_deepcopy - try: - yield - finally: - copy.deepcopy = _origin_deepcopy - - -def tuners_sharded_state_dict( - module: nn.Module, - prefix: str = '', - sharded_offsets: Tuple[Tuple[int, int, int]] = (), - metadata: Optional[dict] = None, -) -> Dict[str, Any]: - """Generate sharded state dict for PEFT tuners. - - - - Args: - module: The module to generate state dict for. - prefix: Key prefix. - sharded_offsets: Sharding offsets for distributed checkpointing. - metadata: Additional metadata. - - Returns: - Sharded state dictionary. 
- """ - from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default - sharded_state_dict = {} - # Save parameters - module._save_to_state_dict(sharded_state_dict, '', keep_vars=True) - sharded_state_dict = make_sharded_tensors_for_checkpoint( - sharded_state_dict, prefix, sharded_offsets=sharded_offsets) - # Recurse into submodules - for name, child in module.named_children(): - if 'Dict' in child.__class__.__name__: - modules = child.named_children() - else: - modules = [(None, child)] - for n, m in modules: - _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.' - sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata)) - return sharded_state_dict diff --git a/src/twinkle/model/megatron/utils/__init__.py b/src/twinkle/model/megatron/utils/__init__.py deleted file mode 100644 index a81db2cc..00000000 --- a/src/twinkle/model/megatron/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .config import convert_hf_config -from .utils import split_cp_inputs diff --git a/src/twinkle/model/megatron/utils/config.py b/src/twinkle/model/megatron/utils/config.py deleted file mode 100644 index eca0edbb..00000000 --- a/src/twinkle/model/megatron/utils/config.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from typing import Any, Dict - -config_mapping = { - 'num_layers': ['num_hidden_layers'], - 'hidden_size': ['hidden_size'], - 'mlp_ffn_hidden_size': ['intermediate_size_mlp'], - 'ffn_hidden_size': ['intermediate_size'], - 'num_attention_heads': ['num_attention_heads'], - 'num_query_groups': ['num_key_value_heads'], - 'max_position_embeddings': ['max_position_embeddings'], - 'norm_epsilon': ['rms_norm_eps'], - 'rotary_base': ['rope_theta'], - 'padded_vocab_size': ['vocab_size'], - 'attention_dropout': ['attention_dropout'], - 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], - 'swiglu': ['hidden_act'], - 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], - 'disable_bias_linear': ['mlp_bias'], - 'kv_channels': ['head_dim', 'v_head_dim'], - 'architectures': ['architectures'], - 'hf_model_type': ['model_type'], # TODO: check - # moe - 'moe_ffn_hidden_size': ['moe_intermediate_size'], - 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], - 'moe_router_num_groups': ['n_group'], - 'moe_router_group_topk': ['topk_group'], - 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], - 'moe_router_pre_softmax': ['norm_topk_prob'], - # deepseek - 'q_lora_rank': ['q_lora_rank'], - 'kv_lora_rank': ['kv_lora_rank'], - 'moe_router_score_function': ['scoring_func'], - 'moe_router_bias_update_rate': ['aux_loss_alpha'], - 'qk_head_dim': ['qk_nope_head_dim'], - 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], - 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], - 'qk_layernorm': ['use_qk_norm'], - # qwen3_next / qwen3_5 - 'linear_num_value_heads': ['linear_num_value_heads'], - 'linear_num_key_heads': ['linear_num_key_heads'], - 'linear_key_head_dim': ['linear_key_head_dim'], - 'linear_value_head_dim': ['linear_value_head_dim'], - 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], - 'full_attention_interval': ['full_attention_interval'], - # other - 'original_max_position_embeddings': ['original_max_position_embeddings'], - 'partial_rotary_factor': ['partial_rotary_factor'], - 'first_k_dense_replace': 
['first_k_dense_replace', 'moe_layer_start_index'], - 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], - 'window_size': ['sliding_window'], - 'layer_types': ['layer_types'], -} - - -def _convert_config(config, _internal_call=False) -> Dict[str, Any]: - megatron_config = {} - for k, hf_keys in config_mapping.items(): - for hf_k in hf_keys: - if hasattr(config, hf_k): - hf_v = getattr(config, hf_k) - if hf_v is None: - continue - if k == 'rotary_base': - megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: - megatron_config[k] = not hf_v - elif k == 'swiglu': - if hf_v == 'silu': - megatron_config[k] = True - else: - if k == 'kv_lora_rank': - megatron_config['multi_latent_attention'] = True - elif k == 'hf_model_type': - if _internal_call: - k = 'llm_model_type' - megatron_config[k] = hf_v - break - for key in ['text_config', 'llm_config', 'thinker_config']: - if hasattr(config, key): - megatron_config.update(_convert_config(getattr(config, key), _internal_call=True)) - # compat llama3 - if getattr(config, 'rope_scaling', None) is not None: - if isinstance(config.rope_scaling, int): - megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}, - elif isinstance(config.rope_scaling, dict): - megatron_config['rope_scaling'] = config.rope_scaling - return megatron_config - - -def convert_hf_config(config) -> Dict[str, Any]: - res = _convert_config(config) - hf_model_type = res.get('hf_model_type') - llm_model_type = res.get('llm_model_type') or hf_model_type - res['llm_model_type'] = llm_model_type - - first_k_dense_replace = res.pop('first_k_dense_replace', None) - n_shared_experts = res.pop('n_shared_experts', None) - layer_types = res.pop('layer_types', None) - mlp_ffn_hidden_size = res.pop('mlp_ffn_hidden_size', None) - interleave_moe_layer_step = res.pop('interleave_moe_layer_step', None) - window_size = res.pop('window_size', None) - rope_scaling = res.get('rope_scaling') or {} - if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next'} or hf_model_type in { - 'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe', 'qwen3_5', 'qwen3_5_moe' - }: - res['qk_layernorm'] = True - if llm_model_type in {'qwen2_moe', 'qwen3_moe', 'qwen3_next' - } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe', 'qwen3_5_moe'}: - res.pop('ffn_hidden_size', None) - if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe': - res['use_shared_expert_gate'] = True - if llm_model_type in { - 'deepseek', - 'deepseek_v2', - 'deepseek_v3', - 'dots1', - } or hf_model_type == 'kimi_vl': - if llm_model_type != 'deepseek': - res['qk_layernorm'] = True - res['moe_router_load_balancing_type'] = 'seq_aux_loss' - res.pop('num_query_groups', None) # https://github.com/NVIDIA/Megatron-LM/issues/1475 - if llm_model_type == 'dots1': - res['moe_router_score_function'] = 'sigmoid' - elif llm_model_type == 'hunyuan': - # Since HunYuan’s attention applies RoPE before using q/k_layernorm, - # which is incompatible with megatron-core, support is not provided here. 
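Note that several entries in `config_mapping` invert the HF flag rather than copy it: `tie_word_embeddings` becomes `untie_embeddings_and_output_weights`, `mlp_bias` becomes `disable_bias_linear`, and `hidden_act == 'silu'` becomes `swiglu=True`. A reduced sketch of that inversion logic:

```python
def map_hf_flags(hf: dict) -> dict:
    """Boolean HF flags whose Megatron counterparts carry the opposite sense."""
    mg = {}
    if 'tie_word_embeddings' in hf:
        mg['untie_embeddings_and_output_weights'] = not hf['tie_word_embeddings']
    if 'mlp_bias' in hf:
        mg['disable_bias_linear'] = not hf['mlp_bias']
    if hf.get('hidden_act') == 'silu':
        mg['swiglu'] = True
    return mg

print(map_hf_flags({'tie_word_embeddings': True, 'mlp_bias': False, 'hidden_act': 'silu'}))
# {'untie_embeddings_and_output_weights': False, 'disable_bias_linear': True, 'swiglu': True}
```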
- res['n_shared_experts'] = n_shared_experts - for key in ['moe_ffn_hidden_size', 'n_shared_experts', 'moe_router_topk']: - val = res.get(key) - if isinstance(val, list) and val and min(val) == max(val): - res[key] = val[0] - n_shared_experts = res.pop('n_shared_experts') - elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: - res['rotary_interleaved'] = True - elif llm_model_type == 'gpt_oss': - res['disable_bias_linear'] = False - res['no_bias_dropout_fusion'] = True - res['softmax_type'] = 'learnable' - res['swiglu'] = False - res['quick_geglu'] = True - res['activation_func_clamp_value'] = 7 - res['glu_linear_offset'] = 1 - res['window_size'] = f'{window_size},0' - if layer_types is None: - res['window_attn_skip_freq'] = '2' - else: - window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types]) - res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]' - elif llm_model_type in {'glm4_moe', 'glm4_moe_lite'} or hf_model_type == 'glm4v_moe': - res['moe_router_score_function'] = 'sigmoid' - if llm_model_type == 'glm4_moe_lite': - res['qk_layernorm'] = True - res.pop('num_query_groups', None) - elif llm_model_type == 'qwen3_next' or hf_model_type in {'qwen3_5', 'qwen3_5_moe'}: - full_attention_interval = res.pop('full_attention_interval', 4) - num_layers = res['num_layers'] - res['layer_types'] = [ - 'full_attention' if (i + 1) % full_attention_interval == 0 else 'linear_attention' - for i in range(num_layers) - ] - elif llm_model_type == 'minimax_m2': - res['add_qkv_bias'] = False - elif llm_model_type == 'llama4': - qk_layernorm = res.pop('qk_layernorm', False) - if qk_layernorm: - res['qk_l2_norm'] = True - res['no_rope_freq'] = 4 - res['moe_apply_probs_on_input'] = True - res['rotary_interleaved'] = True - res['moe_router_score_function'] = 'sigmoid' - res['moe_ffn_hidden_size'] = res['ffn_hidden_size'] - res['ffn_hidden_size'] = mlp_ffn_hidden_size - res['moe_router_enable_expert_bias'] = False - res['moe_shared_expert_intermediate_size'] = res['moe_ffn_hidden_size'] - if interleave_moe_layer_step > 1: - moe_layer_freq = [ - '1' if i % interleave_moe_layer_step == (interleave_moe_layer_step - 1) else '0' - for i in range(res['num_layers']) - ] - res['moe_layer_freq'] = f"[{','.join(moe_layer_freq)}]" - elif hf_model_type == 'glm4v': - res['rotary_interleaved'] = True - if 'partial_rotary_factor' not in res and 'partial_rotary_factor' in rope_scaling: - res['partial_rotary_factor'] = rope_scaling['partial_rotary_factor'] - if 'rotary_base' not in res and 'rope_theta' in rope_scaling: - res['rotary_base'] = rope_scaling['rope_theta'] - if rope_scaling.get('mrope_section') is not None: - res['position_embedding_type'] = 'mrope' - res['mrope_section'] = rope_scaling['mrope_section'] - mrope_interleaved = rope_scaling.get('mrope_interleaved', False) or rope_scaling.get('interleaved', False) - res['mrope_interleaved'] = mrope_interleaved - - if first_k_dense_replace is not None: - res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' - if res.get('moe_router_score_function', 'softmax') == 'sigmoid' and 'moe_router_enable_expert_bias' not in res: - res['moe_router_enable_expert_bias'] = True - if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: - res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size'] - return res diff --git a/src/twinkle/model/megatron/utils/utils.py b/src/twinkle/model/megatron/utils/utils.py deleted file 
mode 100644 index 3d2b9b31..00000000 --- a/src/twinkle/model/megatron/utils/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Reference: swift/swift/megatron/trainers/utils.py -""" -import torch -from typing import Optional - -from twinkle import requires - - -def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int): - requires('megatron_core') - from megatron.core import mpu - if dim < 0: - dim = (dim + inputs.ndim) % inputs.ndim - new_inputs = [] - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)): - if cu_seqlens is None: - val = inputs - else: - slices = [slice(None)] * inputs.ndim - slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) - val = inputs[tuple(slices)] - view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) - val = val.view(view_shape) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', - pin_memory=True).cuda(non_blocking=True) - val = val.index_select(dim, index) - view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) - new_inputs.append(val.view(view_shape)) - return torch.cat(new_inputs, dim=dim) diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index ce91cbf8..982707d9 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -11,7 +11,7 @@ from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax, - stateless_init_process_group, to_device) + stateless_init_process_group, to_device, split_cp_inputs) from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 34506d6f..5d00cc7d 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,7 +1,11 @@ import socket from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Union +from typing import Optional +from typing import TYPE_CHECKING, Any, List, Mapping, Union +import torch + +from twinkle import requires from .network import is_valid_ipv6_address if TYPE_CHECKING: @@ -223,3 +227,28 @@ def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200 padded_tensors.append(padded) return torch.cat(padded_tensors, dim=0) + + +def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int): + requires('megatron_core') + from megatron.core import mpu + if dim < 0: + dim = (dim + inputs.ndim) % inputs.ndim + new_inputs = [] + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)): + if cu_seqlens is None: + val = inputs + else: + slices = [slice(None)] * inputs.ndim + slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) + val = inputs[tuple(slices)] + view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) + val = val.view(view_shape) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', + pin_memory=True).cuda(non_blocking=True) + val = 
val.index_select(dim, index) + view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) + new_inputs.append(val.view(view_shape)) + return torch.cat(new_inputs, dim=dim) From 35efd818245bcd8a02521b27289af6b787e2f6cd Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 17:15:48 +0800 Subject: [PATCH 02/18] wip --- src/twinkle/checkpoint_engine/base.py | 4 +- src/twinkle/model/megatron/megatron.py | 131 ++++++------------ .../model/megatron/strategy/megatron.py | 37 ++--- 3 files changed, 62 insertions(+), 110 deletions(-) diff --git a/src/twinkle/checkpoint_engine/base.py b/src/twinkle/checkpoint_engine/base.py index 9cb95e08..0760f2ae 100644 --- a/src/twinkle/checkpoint_engine/base.py +++ b/src/twinkle/checkpoint_engine/base.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict, Optional if TYPE_CHECKING: import torch @@ -38,6 +38,8 @@ class CheckpointEngine(ABC): >>> engine.finalize() """ + rank: Optional[int] = None + @abstractmethod def prepare(self) -> dict[str, Any]: """Prepare the checkpoint engine before weight synchronization. diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index cdc45246..c8d04a60 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -13,6 +13,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from Cython.Compiler.Code import contextmanager from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model from peft.tuners.lora import Linear as LoraLinear from torch.optim import Optimizer @@ -75,32 +76,6 @@ def _get_lr(self): _default_adapter_name = '' -def _add_base_layer_suffix(params): - """Insert ``.base_layer.`` before the final attribute for LoRA-target modules. - - Converts plain HF names exported by the Megatron bridge into the format - expected by vLLM when ``enable_lora=True``:: - - model.layers.0.self_attn.q_proj.weight - -> model.layers.0.self_attn.q_proj.base_layer.weight - - Non-matching names are yielded unchanged. - - Args: - params: Iterable of ``(name, tensor)`` pairs. - - Yields: - ``(name, tensor)`` with ``.base_layer.`` inserted where needed. 
- """ - for name, param in params: - for suffix in _BASE_LAYER_SUFFIXES: - if name.endswith(suffix): - attr = suffix.rsplit('.', 1)[-1] # 'weight' or 'bias' - name = f'{name[:-len(attr)]}base_layer.{attr}' - break - yield name, param - - @remote_class(execute='all') class MegatronModel(TwinkleModel, nn.Module, CheckpointEngineMixin): @@ -1299,7 +1274,6 @@ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torc ) def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str, Dict[str, Any]], **kwargs): - from .tuners.utils import get_target_modules, patch_deepcopy, set_linear_is_expert assert adapter_name, 'Use a non-empty adapter_name' model = self.strategy.unwrap_model(self.model) if isinstance(config_or_dir, str): @@ -1328,8 +1302,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str expanded_modules = get_target_modules(_model, target_modules) config.target_modules = expanded_modules - with patch_deepcopy(): - _model = get_peft_model(_model, config, adapter_name=adapter_name) + _model = get_peft_model(_model, config, adapter_name=adapter_name) # setting average_gradients_across_tp_domain for m in _model.modules(): if isinstance(m, LoraLinear): @@ -1437,10 +1410,6 @@ def get_train_configs(self, **kwargs): return expr - @property - def _bridge(self) -> 'GPTBridge': - return self.strategy.config.bridge - # ── Checkpoint Engine (from CheckpointEngineMixin) ────────────────── # prepare_checkpoint_engine, init_checkpoint_process_group, and # finalize_checkpoint_engine are inherited from CheckpointEngineMixin. @@ -1453,54 +1422,52 @@ def _bridge(self) -> 'GPTBridge': # via NCCL; others consume the generator silently (rank=-1). @remote_function(dispatch='all', lazy_collect=True) - def send_weights( + async def send_weights( self, adapter_name: str = None, base_sync_done: bool = False, merge_and_sync: bool = False, + add_base_layer_path: bool = True, ): if adapter_name is None: adapter_name = self._get_default_group() engine = self._get_or_create_checkpoint_engine() - is_peft_format = (adapter_name != _default_adapter_name) - - # Megatron uses padded_vocab_size for TP alignment (rounded up to - # TP * 128). vLLM creates its embedding / lm_head from the original - # HF vocab_size, so weight_loader asserts shape[0] == org_vocab_size. - # Trim any tensor whose dim-0 equals padded_vocab_size back to - # org_vocab_size — this is shape-based, not name-based, so it works - # regardless of the model architecture's naming convention. 
- org_vocab_size = getattr(self.hf_config, 'vocab_size', self.strategy.config.padded_vocab_size) - _padded_vocab_size = args.padded_vocab_size - - def _trim_vocab(name, tensor): - if _padded_vocab_size != org_vocab_size and tensor.shape[0] == _padded_vocab_size: - tensor = tensor[:org_vocab_size] - return name, tensor + @contextmanager + def merge_lora(): + for _model in self.strategy.unwrap_model(self.model): + if isinstance(_model, PeftModel): + _model.merge_adapter() + yield + for _model in self.strategy.unwrap_model(self.model): + if isinstance(_model, PeftModel): + _model.unmerge_adapter() + + def _add_base_layer_suffix(params): + _BASE_LAYER_SUFFIXES = ['weight', 'bias'] + for name, param in params: + for suffix in _BASE_LAYER_SUFFIXES: + if name.endswith(suffix): + attr = suffix.rsplit('.', 1)[-1] # 'weight' or 'bias' + name = f'{name[:-len(attr)]}base_layer.{attr}' + break + yield name, param + is_peft_format = (adapter_name != _default_adapter_name) if base_sync_done and adapter_name: + # The first base model synchronization finished, and is lora training if merge_and_sync: - # LoRA Training and sync full model(merge_adapter) def weight_generator(): - for _model in self.strategy.unwrap_model(self.model): - if isinstance(_model, PeftModel): - _model.merge_adapter() - for name, tensor in self.get_hf_state_dict(adapter_name=''): - if name is None or tensor is None: - continue - # Skip LoRA-specific weights for base model sync - if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: - continue - yield _trim_vocab(name, tensor) - for _model in self.strategy.unwrap_model(self.model): - if isinstance(_model, PeftModel): - _model.unmerge_adapter() + with merge_lora(): + for name, tensor in self.get_hf_state_dict(adapter_name=''): + if name is None or tensor is None: + continue + # Skip LoRA-specific weights for base model sync + if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: + continue + yield name, tensor + else: - # ── LoRA-only mode ──────────────────────────────────────────── - # Export only LoRA adapter weights via the bridge. - # The bridge may also yield non-LoRA weights (e.g. embed_tokens - # for modules_to_save), filter to only lora_A/lora_B tensors. def weight_generator(): for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name): if name is None or tensor is None: @@ -1508,8 +1475,8 @@ def weight_generator(): if 'lora' not in name: continue yield name, tensor - else: + # Need to synchronize the base model # First full base-model sync. 
def _raw_weights(): for name, tensor in self.get_hf_state_dict(adapter_name=''): @@ -1518,10 +1485,10 @@ def _raw_weights(): # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue - yield _trim_vocab(name, tensor) + yield name, tensor def weight_generator(): - if is_peft_format and not merge_and_sync: + if is_peft_format and (not merge_and_sync) and add_base_layer_path: yield from _add_base_layer_suffix(_raw_weights()) else: yield from _raw_weights() @@ -1533,30 +1500,10 @@ def weight_generator(): pass return - async def _send(): - await engine.send_weights(weight_generator()) - - result_container = {'error': None} - - def _run(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_send()) - finally: - loop.close() - except Exception as e: - result_container['error'] = e - - thread = threading.Thread(target=_run) - thread.start() - thread.join() - if result_container['error'] is not None: - raise result_container['error'] + await engine.send_weights(weight_generator()) @remote_function(collect='first') - def get_peft_config_dict(self, adapter_name: str = None) -> dict: + def get_peft_config_dict(self, adapter_name: str = None) -> Optional[Dict[str, Any]]: """Return the PEFT config as a dict for vLLM's PEFTHelper. Used by CheckpointEngineManager for LoRA-only weight sync. diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 017d3cf5..bcd5520f 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -42,7 +42,6 @@ def __init__( 'expert_tensor_parallel_size': self.device_mesh.etp_world_size or 1, 'virtual_pipeline_model_parallel_size': self.device_mesh.vpp_size or 1, } - self._initialized = True mpu.initialize_model_parallel( order=self.device_mesh.order, **parallel_kwargs, @@ -54,6 +53,26 @@ def sequence_parallel(self) -> bool: """Read from device_mesh so auto-enable in args.py is visible.""" return getattr(self.device_mesh, 'sequence_parallel', False) + @property + def bridge(self): + return self.config.bridge + + @property + def params_type(self) -> torch.dtype: + if self._params_dtype is not None: + dtype_map = { + 'fp32': torch.float32, + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + } + return dtype_map.get(self._params_dtype, torch.bfloat16) + + if self.mixed_precision == 'bf16': + return torch.bfloat16 + elif self.mixed_precision == 'fp16': + return torch.float16 + return torch.float32 + def _check_device_mesh(self): from megatron.core import parallel_state as mpu @@ -78,22 +97,6 @@ def _check_device_mesh(self): if self.device_mesh.vpp_size is not None and self.device_mesh.vpp_size > 1: assert self.device_mesh.vpp_size == mpu.get_virtual_pipeline_model_parallel_world_size() - @property - def params_type(self) -> torch.dtype: - if self._params_dtype is not None: - dtype_map = { - 'fp32': torch.float32, - 'fp16': torch.float16, - 'bf16': torch.bfloat16, - } - return dtype_map.get(self._params_dtype, torch.bfloat16) - - if self.mixed_precision == 'bf16': - return torch.bfloat16 - elif self.mixed_precision == 'fp16': - return torch.float16 - return torch.float32 - def wrap_model( self, model: List[nn.Module], From 01e753539515cb87a6c54a7680e0f53e4f1e3b60 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 17:17:22 +0800 Subject: [PATCH 03/18] wip --- src/twinkle/model/megatron/megatron.py | 2 -- 1 file changed, 2 
deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index c8d04a60..a0bbf121 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1281,8 +1281,6 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str _models = [] for _model in model: - # Mark expert layers for MoE models - set_linear_is_expert(_model) if isinstance(config_or_dir, str): _model = PeftModel.from_pretrained( _model, config_or_dir, adapter_name=adapter_name, is_trainable=kwargs.get('is_trainable', True)) From 399960a8ac7ca2eed2b879dfa02ed3fabbfd707d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 17:29:16 +0800 Subject: [PATCH 04/18] wip --- src/twinkle/model/base.py | 47 +++++++++++++++++++++++++- src/twinkle/model/megatron/megatron.py | 18 ++-------- src/twinkle/model/multi_lora.py | 41 ++++++++++------------ 3 files changed, 66 insertions(+), 40 deletions(-) diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index bee6c37d..53716d31 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union, List from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput @@ -157,3 +157,48 @@ def _try_init_process_group(self): if backend in ('nccl', 'hccl'): init_kwargs['device_id'] = torch.device(Platform.get_local_device()) dist.init_process_group(**init_kwargs) + + @staticmethod + def get_target_modules(model: torch.nn.Module, target_modules: List[str]) -> List[str]: + import torch + + def find_layers(model: torch.nn.Module, cond_fn) -> List[str]: + result = [] + for name, module in model.named_modules(): + if cond_fn(name, module): + result.append(name) + return result + + def find_all_linears(model: torch.nn.Module) -> List[str]: + from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, \ + TELinear + + def _cond(name: str, module: torch.nn.Module) -> bool: + if name == 'output_layer' or 'lora' in name: + return False + if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, torch.nn.Linear)): + return True + return False + + return find_layers(model, _cond) + + def find_router(model: torch.nn.Module) -> List[str]: + from megatron.core.transformer.moe.router import TopKRouter + return find_layers(model, lambda name, module: isinstance(module, TopKRouter) and 'lora' not in name) + + def find_embedding(model: torch.nn.Module) -> List[str]: + from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding + return find_layers(model, + lambda name, module: isinstance(module, LanguageModelEmbedding) and 'lora' not in name) + + result = target_modules.copy() + if 'all-linear' in result: + result.remove('all-linear') + result += find_all_linears(model) + if 'all-embedding' in result: + result.remove('all-embedding') + result += find_embedding(model) + if 'all-router' in result: + result.remove('all-router') + result += find_router(model) + return list(set(result)) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index a0bbf121..e1ab1d4f 100644 --- a/src/twinkle/model/megatron/megatron.py +++ 
b/src/twinkle/model/megatron/megatron.py @@ -1297,29 +1297,17 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str else: target_modules = list(config.target_modules) - expanded_modules = get_target_modules(_model, target_modules) + expanded_modules = self.get_target_modules(_model, target_modules) config.target_modules = expanded_modules - _model = get_peft_model(_model, config, adapter_name=adapter_name) - # setting average_gradients_across_tp_domain - for m in _model.modules(): - if isinstance(m, LoraLinear): - # just check - # TODO untested code - from .args import get_args - args = get_args() - from .tuners import LoraParallelLinear - assert args.is_multimodal and not isinstance(m, LoraParallelLinear) - for p in m.parameters(): - if p.requires_grad: - p.average_gradients_across_tp_domain = True + _model = get_peft_model(_model, config, adapter_name=adapter_name) # noqa _models.append(_model) self.model = _models # Create optimizer group for adapter self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name - self.optimizer_group[adapter_name].adapter_config = config + self.optimizer_group[adapter_name].adapter_config = config # noqa self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) # Fix: use .processor instead of .tokenizer - Template class uses self.processor self._default_tokenizer = self.optimizer_group[adapter_name].template.processor diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index cebde214..37b647e6 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -6,7 +6,7 @@ from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora import Embedding, Linear, LoraLayer from types import MethodType -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Callable from twinkle import torch_util from twinkle.data_format import InputFeature @@ -191,7 +191,7 @@ def _patch_lora_forward(_self, name, base_layer: LoraLayer): # Megatron is an optional dependency; if megatron-core/megatron is missing, # we must not crash the entire service just because we try to import megatron modules. 
try: - from twinkle.model.megatron.tuners import LoraParallelLinear as _LoraParallelLinear + from mcore_bridge import LoraParallelLinear as _LoraParallelLinear except Exception: # noqa: broad-except _LoraParallelLinear = () @@ -394,32 +394,25 @@ def _patch_peft(_module): return _module def _patch_megatron(_module): - # Mark expert layers for MoE models - from .megatron.tuners.utils import set_linear_is_expert - set_linear_is_expert(_module) - # Expand target_modules (e.g., 'all-linear' -> actual module names) _config = deepcopy(config) + if isinstance(_module, PeftModel): + _module.add_adapter(lora_tenant.adapter_name, _config) + else: + # TODO first wrap needs parse target_modules, need to fix later + if _config.target_modules: + if isinstance(_config.target_modules, str): + target_modules = [_config.target_modules] + else: + target_modules = list(_config.target_modules) - from .megatron.tuners.utils import patch_deepcopy - with patch_deepcopy(): - if isinstance(_module, PeftModel): - _module.add_adapter(lora_tenant.adapter_name, _config) - else: - # TODO first wrap needs parse target_modules, need to fix later - if _config.target_modules: - if isinstance(_config.target_modules, str): - target_modules = [_config.target_modules] - else: - target_modules = list(_config.target_modules) - - from .megatron.tuners.utils import get_target_modules - _config.target_modules = get_target_modules(_module, target_modules) - _module = get_peft_model(_module, _config, lora_tenant.adapter_name) + from .base import TwinkleModel + _config.target_modules = TwinkleModel.get_target_modules(_module, target_modules) + _module = get_peft_model(_module, _config, lora_tenant.adapter_name) - for name, submodule in _module.named_modules(): - if isinstance(submodule, LoraLayer): - self._patch_lora_forward(name, submodule) + for name, submodule in _module.named_modules(): + if isinstance(submodule, LoraLayer): + self._patch_lora_forward(name, submodule) return _module if isinstance(module, list): From f2bd84633943c740bc960f3a3df28b60b3187719 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 17:40:17 +0800 Subject: [PATCH 05/18] fix --- src/twinkle/sampler/torch_sampler/__init__.py | 1 - .../sampler/torch_sampler/torch_sampler.py | 157 --------- .../torch_sampler/transformers_engine.py | 298 ------------------ .../sampler/vllm_sampler/vllm_sampler.py | 4 +- 4 files changed, 2 insertions(+), 458 deletions(-) delete mode 100644 src/twinkle/sampler/torch_sampler/__init__.py delete mode 100644 src/twinkle/sampler/torch_sampler/torch_sampler.py delete mode 100644 src/twinkle/sampler/torch_sampler/transformers_engine.py diff --git a/src/twinkle/sampler/torch_sampler/__init__.py b/src/twinkle/sampler/torch_sampler/__init__.py deleted file mode 100644 index ac1e5df5..00000000 --- a/src/twinkle/sampler/torch_sampler/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .torch_sampler import TorchSampler diff --git a/src/twinkle/sampler/torch_sampler/torch_sampler.py b/src/twinkle/sampler/torch_sampler/torch_sampler.py deleted file mode 100644 index 8c7643ac..00000000 --- a/src/twinkle/sampler/torch_sampler/torch_sampler.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-"""PyTorch native sampler using TransformersEngine.""" -import torch -from transformers import AutoModelForCausalLM, PreTrainedModel -from transformers.models.auto.auto_factory import _BaseAutoModelClass -from typing import Any, Dict, List, Optional, Type, Union - -from twinkle import DeviceMesh, remote_class, remote_function -from twinkle.data_format import InputFeature, Trajectory -from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams -from twinkle.hub import HubOperation -from twinkle.sampler.base import Sampler - - -@remote_class() -class TorchSampler(Sampler): - # not tested yet - """A PyTorch native sampler using TransformersEngine.""" - - def __init__(self, - model_id: str, - device_mesh: DeviceMesh = None, - torch_dtype: torch.dtype = torch.bfloat16, - trust_remote_code: bool = True, - model_cls: Optional[Union[Type[PreTrainedModel], str, - Type[_BaseAutoModelClass]]] = AutoModelForCausalLM, - **kwargs): - super().__init__() - model_id = HubOperation.download_model(model_id) - self.model_id = model_id - self.device_mesh = device_mesh - - if device_mesh is not None and getattr(device_mesh, 'device_type', None): - self.device = torch.device(device_mesh.device_type) - elif torch.cuda.is_available(): - self.device = torch.device('cuda') - elif hasattr(torch, 'npu') and torch.npu.is_available(): - self.device = torch.device('npu') - else: - self.device = torch.device('cpu') - - from .transformers_engine import TransformersEngine - self.engine = TransformersEngine( - model_id=model_id, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - model_cls=model_cls, - **kwargs) - self.model = self.engine.model - self.tokenizer = self.engine.tokenizer - - @remote_function() - def sample( - self, - inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None, - adapter_name: str = '', - ) -> List[SampleResponse]: - """Sample responses for given inputs. - - Args: - inputs: Either InputFeature(s) or Trajectory(s). - - InputFeature: Must contain 'input_ids'. - - Trajectory: Must contain 'messages'. Requires template to be set. - sampling_params: Sampling parameters. - adapter_name: Optional LoRA adapter name. - - Returns: - SampleResponse containing sampled sequences. 
- """ - if sampling_params is None: - sampling_params = SamplingParams() - elif isinstance(sampling_params, dict): - sampling_params = SamplingParams.from_dict(sampling_params) - - inputs_list = self._normalize_inputs(inputs) - - # Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic - is_trajectory = self._is_trajectory(inputs) - - if is_trajectory: - template = self.template - assert template is not None, \ - 'Use set_template to add a template when trying to input Trajectory' - encoded_inputs = [self.encode_trajectory(traj, adapter_name) for traj in inputs_list] - else: - encoded_inputs = inputs_list - - gen_kwargs = sampling_params.to_transformers(self.tokenizer) - gen_kwargs['return_dict_in_generate'] = True - gen_kwargs['output_scores'] = True - - all_sequences = [] - device = next(self.model.parameters()).device - - for feat in encoded_inputs: - input_ids = feat['input_ids'] - if hasattr(input_ids, 'tolist'): - input_ids = input_ids.tolist() - - input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device) - attention_mask = torch.ones_like(input_tensor) - - # Build model inputs including multimodal data - model_inputs = { - 'input_ids': input_tensor, - 'attention_mask': attention_mask, - } - - # Add extra inputs for multimodal models (pixel_values, image_grid_thw, etc.) - # These are typically produced by template.encode() for VLM models - extra_keys = [ - 'pixel_values', 'image_grid_thw', 'video_grid_thw', 'pixel_values_videos', 'second_per_grid_ts' - ] - for key in extra_keys: - if key in feat: - value = feat[key] - if hasattr(value, 'to'): - model_inputs[key] = value.to(device) - elif isinstance(value, (list, tuple)) and len(value) > 0: - # Handle list of tensors - if hasattr(value[0], 'to'): - model_inputs[key] = [v.to(device) for v in value] - else: - model_inputs[key] = value - else: - model_inputs[key] = value - - with torch.no_grad(): - outputs = self.model.generate(**model_inputs, **gen_kwargs) - - generated_ids = outputs.sequences - prompt_len = len(input_ids) - - gen_tokens = generated_ids[0][prompt_len:].tolist() - - seq_logprobs = None - # TODO: fix logprobs - if hasattr(outputs, 'scores') and outputs.scores: - seq_logprobs = [] - for k, score in enumerate(outputs.scores): - if k >= len(gen_tokens): - break - log_probs = torch.log_softmax(score[0], dim=-1) - seq_logprobs.append(log_probs[gen_tokens[k]].item()) - - stop_reason = 'length' - if gen_tokens and gen_tokens[-1] == self.tokenizer.eos_token_id: - stop_reason = 'stop' - - all_sequences.append(SampledSequence( - stop_reason=stop_reason, - tokens=gen_tokens, - logprobs=seq_logprobs, - )) - - return [SampleResponse(sequences=all_sequences)] diff --git a/src/twinkle/sampler/torch_sampler/transformers_engine.py b/src/twinkle/sampler/torch_sampler/transformers_engine.py deleted file mode 100644 index ee8ed97d..00000000 --- a/src/twinkle/sampler/torch_sampler/transformers_engine.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -TransformersEngine: A transformers-based inference engine. - -Uses HuggingFace transformers model.generate() for text generation. -Slower than vLLM but more compatible and easier to debug. 
-""" - -import hashlib -import json -import os -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel -from transformers.models.auto.auto_factory import _BaseAutoModelClass -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -from twinkle import get_logger -from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams -from twinkle.sampler.base_engine import BaseSamplerEngine - -logger = get_logger() - - -class TransformersEngine(BaseSamplerEngine): - # not tested yet - def __init__( - self, - model_id: str, - *, - torch_dtype: torch.dtype = torch.bfloat16, - device_map: str = 'auto', - trust_remote_code: bool = True, - enable_lora: bool = False, - max_lora_rank: int = 64, - model_kwargs: Optional[Dict[str, Any]] = None, - model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM, - ): - self._model_id = model_id - self.torch_dtype = torch_dtype - self.device_map = device_map - self.trust_remote_code = trust_remote_code - self.enable_lora = enable_lora - self.max_lora_rank = max_lora_rank - self._model_kwargs = model_kwargs or {} - - # Load model and tokenizer - self.model = model_cls.from_pretrained( - model_id, - torch_dtype=torch_dtype, - device_map=device_map, - trust_remote_code=trust_remote_code, - **self._model_kwargs, - ) - self.model.eval() - - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, - trust_remote_code=trust_remote_code, - ) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - - # LoRA adapter management - self._adapters: Dict[str, Dict[str, Any]] = {} - self._lora_weights_dir = os.path.join('/tmp/twinkle_lora', hashlib.md5(model_id.encode()).hexdigest()) - os.makedirs(self._lora_weights_dir, exist_ok=True) - - # Track current adapter - self._current_adapter: Optional[str] = None - - logger.info(f'TransformersEngine initialized: model={model_id}') - - @property - def model_id(self) -> str: - return self._model_id - - async def get_tokenizer(self): - return self.tokenizer - - def _convert_params(self, params: Optional[SamplingParams]) -> Dict[str, Any]: - """Convert SamplingParams to transformers generate kwargs.""" - if params is None: - params = SamplingParams() - - gen_kwargs = { - 'do_sample': params.temperature > 0, - 'temperature': max(params.temperature, 1e-7), - 'top_p': params.top_p, - 'pad_token_id': self.tokenizer.pad_token_id, - 'eos_token_id': self.tokenizer.eos_token_id, - } - - if params.max_tokens is not None: - gen_kwargs['max_new_tokens'] = params.max_tokens - else: - gen_kwargs['max_new_tokens'] = 2048 - - if params.seed is not None: - torch.manual_seed(params.seed) - - if params.top_k > 0: - gen_kwargs['top_k'] = params.top_k - - if params.repetition_penalty != 1.0: - gen_kwargs['repetition_penalty'] = params.repetition_penalty - - # Handle stop sequences - if params.stop: - if isinstance(params.stop, str): - stop_token_ids = self.tokenizer.encode(params.stop, add_special_tokens=False) - if stop_token_ids: - gen_kwargs['eos_token_id'] = [self.tokenizer.eos_token_id] + stop_token_ids - elif isinstance(params.stop, (list, tuple)): - if params.stop and isinstance(params.stop[0], int): - gen_kwargs['eos_token_id'] = [self.tokenizer.eos_token_id] + list(params.stop) - else: - all_stop_ids = [self.tokenizer.eos_token_id] - for s in params.stop: - ids = self.tokenizer.encode(s, add_special_tokens=False) - if ids: - all_stop_ids.extend(ids) - gen_kwargs['eos_token_id'] = 
all_stop_ids - - return gen_kwargs - - async def sample( - self, - prompt_token_ids: List[int], - sampling_params: Optional[SamplingParams] = None, - *, - num_samples: int = 1, - logprobs: bool = True, - include_prompt_logprobs: bool = False, - topk_prompt_logprobs: int = 0, - adapter_uri: Optional[str] = None, - request_id: Optional[str] = None, - images: Optional[List[Any]] = None, - videos: Optional[List[Any]] = None, - extra_model_inputs: Optional[Dict[str, Any]] = None, - ) -> SampleResponse: - """Sample completions using transformers generate().""" - - # Switch adapter if needed - if adapter_uri and self.enable_lora: - await self._load_adapter(adapter_uri) - - # Convert params - gen_kwargs = self._convert_params(sampling_params) - gen_kwargs['num_return_sequences'] = num_samples - gen_kwargs['return_dict_in_generate'] = True - - if logprobs or include_prompt_logprobs: - gen_kwargs['output_scores'] = True - - # Prepare input - device = next(self.model.parameters()).device - input_ids = torch.tensor([prompt_token_ids], dtype=torch.long, device=device) - attention_mask = torch.ones_like(input_ids) - - # Repeat for num_samples - if num_samples > 1: - input_ids = input_ids.repeat(num_samples, 1) - attention_mask = attention_mask.repeat(num_samples, 1) - - # Build model inputs - model_inputs = { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - } - - # Add extra model inputs for multimodal (pre-processed by template) - if extra_model_inputs: - for key, value in extra_model_inputs.items(): - if hasattr(value, 'to'): - model_inputs[key] = value.to(device) - else: - model_inputs[key] = value - - # Generate - with torch.no_grad(): - outputs = self.model.generate( - **model_inputs, - **gen_kwargs, - ) - - # Extract generated sequences - generated_ids = outputs.sequences - prompt_len = len(prompt_token_ids) - - sequences = [] - for i in range(num_samples): - gen_tokens = generated_ids[i][prompt_len:].tolist() - - # Compute logprobs if requested - seq_logprobs = None - if logprobs and hasattr(outputs, 'scores') and outputs.scores: - seq_logprobs = [] - for j, score in enumerate(outputs.scores): - if j >= len(gen_tokens): - break - log_probs = torch.log_softmax(score[i], dim=-1) - token_id = gen_tokens[j] - seq_logprobs.append(log_probs[token_id].item()) - - # Determine stop reason - stop_reason = 'length' - if gen_tokens and gen_tokens[-1] == self.tokenizer.eos_token_id: - stop_reason = 'stop' - - sequences.append(SampledSequence( - stop_reason=stop_reason, - tokens=gen_tokens, - logprobs=seq_logprobs, - )) - - # Compute prompt logprobs if requested - prompt_logprobs_result = None - topk_prompt_logprobs_result = None - if include_prompt_logprobs or topk_prompt_logprobs > 0: - prompt_logprobs_result, topk_prompt_logprobs_result = await self._compute_prompt_logprobs( - prompt_token_ids, - topk=topk_prompt_logprobs if topk_prompt_logprobs > 0 else 1, - ) - - return SampleResponse( - sequences=sequences, - prompt_logprobs=prompt_logprobs_result, - topk_prompt_logprobs=topk_prompt_logprobs_result if topk_prompt_logprobs > 0 else None, - ) - - async def _compute_prompt_logprobs( - self, - prompt_token_ids: List[int], - topk: int = 1, - ) -> Tuple[List[Optional[float]], List[Optional[List[Tuple[int, float]]]]]: - """Compute log probabilities for prompt tokens.""" - device = next(self.model.parameters()).device - input_ids = torch.tensor([prompt_token_ids], dtype=torch.long, device=device) - - with torch.no_grad(): - outputs = self.model(input_ids=input_ids) - logits = outputs.logits[0] # 
[seq_len, vocab] - - log_probs = torch.log_softmax(logits, dim=-1) - - prompt_logprobs: List[Optional[float]] = [None] # First token has no previous context - topk_logprobs: List[Optional[List[Tuple[int, float]]]] = [None] - - for i in range(1, len(prompt_token_ids)): - token_id = prompt_token_ids[i] - prev_logprobs = log_probs[i - 1] - - # Logprob for the actual token - prompt_logprobs.append(prev_logprobs[token_id].item()) - - # Top-k logprobs - topk_values, topk_indices = prev_logprobs.topk(topk) - topk_logprobs.append([(idx.item(), val.item()) for idx, val in zip(topk_indices, topk_values)]) - - return prompt_logprobs, topk_logprobs - - async def update_weights( - self, - weights: Dict[str, torch.Tensor], - adapter_name: Optional[str] = None, - ) -> None: - """Update model weights.""" - if adapter_name is None: - # Update base model weights - self.model.load_state_dict(weights, strict=False) - logger.info(f'Updated {len(weights)} base model weight tensors') - else: - # Update LoRA adapter weights - from peft import PeftModel - if isinstance(self.model, PeftModel): - adapter_state_dict = {} - for key, value in weights.items(): - if adapter_name in key: - adapter_state_dict[key] = value - if adapter_state_dict: - self.model.load_state_dict(adapter_state_dict, strict=False) - logger.info(f'Updated {len(adapter_state_dict)} adapter weights for {adapter_name}') - - async def save_weights_for_sampler( - self, - weights: Dict[str, torch.Tensor], - peft_config: Dict[str, Any], - ) -> str: - raise NotImplementedError - - async def _load_adapter(self, adapter_uri: str) -> None: - raise NotImplementedError - - async def sleep(self, **kwargs) -> None: - pass - - async def wake_up(self, **kwargs) -> None: - pass diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 915e012f..04b8d99f 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -376,7 +376,7 @@ def reset_prefix_cache(self): self._run_in_loop(self.engine.reset_prefix_cache()) @remote_function(dispatch='all', lazy_collect=True) - def receive_weights( + async def receive_weights( self, base_sync_done: bool = False, peft_config: dict = None, @@ -423,7 +423,7 @@ async def _receive_and_load(): # Base-model sync invalidates any previously synced LoRA. self.engine.invalidate_synced_lora() - self._run_in_loop(_receive_and_load()) + await _receive_and_load() def shutdown(self): """Gracefully shutdown the vLLM engine and background event loop. 
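
The `_compute_prompt_logprobs` helper deleted in the patch above uses the standard shift-by-one convention: the distribution predicted at position i-1 scores prompt token i, so the first token has no logprob. A self-contained sketch of that indexing on toy logits (no model involved):

import torch

prompt_token_ids = [3, 7, 1, 5]
vocab_size = 10
# Toy logits standing in for a real forward pass: logits[i] predicts token i+1.
logits = torch.randn(len(prompt_token_ids), vocab_size)
log_probs = torch.log_softmax(logits, dim=-1)

prompt_logprobs = [None]  # position 0 has no preceding context to score it
for i in range(1, len(prompt_token_ids)):
    # Token i is scored by the distribution predicted at position i - 1.
    prompt_logprobs.append(log_probs[i - 1, prompt_token_ids[i]].item())

assert len(prompt_logprobs) == len(prompt_token_ids)
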
From 096c19366d188e67b9b388c03ac7663022bd17fa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 18:18:12 +0800 Subject: [PATCH 06/18] wip --- cookbook/megatron/tp.py | 8 ++--- src/twinkle/model/base.py | 2 +- src/twinkle/model/megatron/megatron.py | 8 +---- .../model/megatron/strategy/megatron.py | 29 +++++++++++++++---- src/twinkle/utils/torch_utils.py | 8 ++--- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index ee457fe7..b09d1a60 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -9,7 +9,7 @@ from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor # Construct a device_mesh, tp=pp=cp=2, dp=1 -device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) +device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) @@ -19,7 +19,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=16) @@ -33,7 +33,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -41,7 +41,7 @@ def train(): # Global batch size = 1, dp_size = 1 dataloader = DataLoader(dataset=dataset, batch_size=16) # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') + model = MegatronModel(model_id='ms://Qwen/Qwen3-4B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 53716d31..19a05958 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -159,7 +159,7 @@ def _try_init_process_group(self): dist.init_process_group(**init_kwargs) @staticmethod - def get_target_modules(model: torch.nn.Module, target_modules: List[str]) -> List[str]: + def get_target_modules(model: 'torch.nn.Module', target_modules: List[str]) -> List[str]: import torch def find_layers(model: torch.nn.Module, cond_fn) -> List[str]: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index e1ab1d4f..3d4e6e82 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from Cython.Compiler.Code import contextmanager +from contextlib import contextmanager from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model from peft.tuners.lora import Linear as LoraLinear from torch.optim import Optimizer @@ -137,12 +137,6 @@ def _construct_default_optimizer_group(self): _device_mesh=self.device_mesh, ) - @staticmethod - def _move_model_to_gpu(model: nn.Module) -> nn.Module: - model = model.to(Platform.get_local_device()) - torch_util.synchronize() - return model - def _lazy_wrap_model(self): if not 
self._model_wrapped: self.model = self.strategy.wrap_model(self.model) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index bcd5520f..919fbf0f 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from twinkle import DeviceMesh +from twinkle import DeviceMesh, Platform, torch_util class MegatronStrategy: @@ -31,8 +31,6 @@ def __init__( elif self.mixed_precision == 'no': params_dtype = torch.float32 self._params_dtype = params_dtype - from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed - model_parallel_cuda_manual_seed(self._seed) parallel_kwargs = { 'tensor_model_parallel_size': self.device_mesh.tp_world_size or 1, @@ -40,12 +38,17 @@ def __init__( 'context_parallel_size': self.device_mesh.cp_world_size or 1, 'expert_model_parallel_size': self.device_mesh.ep_size or 1, 'expert_tensor_parallel_size': self.device_mesh.etp_world_size or 1, - 'virtual_pipeline_model_parallel_size': self.device_mesh.vpp_size or 1, + 'virtual_pipeline_model_parallel_size': self.device_mesh.vpp_size or None, } + if not self.device_mesh.vpp_size: + # non-interleave does not support overlap_p2p_comm + kwargs['overlap_p2p_comm'] = False mpu.initialize_model_parallel( order=self.device_mesh.order, **parallel_kwargs, ) + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + model_parallel_cuda_manual_seed(self._seed) self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) @property @@ -198,9 +201,25 @@ def create_megatron_model( load_weights: bool = True, ) -> List[nn.Module]: from mcore_bridge import get_mcore_model + import torch.distributed as dist mg_models = get_mcore_model(self.config) + + if dist.is_initialized(): + dist.barrier() + + _models = [] + for _model in mg_models: + _model = self._move_model_to_gpu(_model) + _models.append(_model) + if load_weights: # Load weights bridge = self.config.bridge bridge.load_weights(mg_models, self.model_dir) - return mg_models + return _models + + @staticmethod + def _move_model_to_gpu(model: nn.Module) -> nn.Module: + model = model.to(Platform.get_local_device()) + torch_util.synchronize() + return model \ No newline at end of file diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 5d00cc7d..07d1730f 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -2,10 +2,6 @@ from datetime import timedelta from typing import Optional from typing import TYPE_CHECKING, Any, List, Mapping, Union - -import torch - -from twinkle import requires from .network import is_valid_ipv6_address if TYPE_CHECKING: @@ -229,8 +225,8 @@ def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200 return torch.cat(padded_tensors, dim=0) -def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int): - requires('megatron_core') +def split_cp_inputs(inputs: 'torch.Tensor', cu_seqlens: Optional['torch.Tensor'], dim: int): + import torch from megatron.core import mpu if dim < 0: dim = (dim + inputs.ndim) % inputs.ndim From 53c19a73913bd98af880e3bcf74102b1d71966f8 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 30 Mar 2026 18:54:49 +0800 Subject: [PATCH 07/18] wip --- src/twinkle/model/megatron/megatron.py | 1 + .../model/megatron/strategy/megatron.py | 31 +++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff 
--git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 3d4e6e82..f434df2d 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -114,6 +114,7 @@ def __init__( 'recompute_modules': recompute_modules, 'recompute_method': recompute_method, 'recompute_num_layers': recompute_num_layers, + 'variable_seq_lengths': self.variable_seq_lengths, }) seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) self.strategy = MegatronStrategy(self._model_path, self.device_mesh, mixed_precision=mixed_precision, diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 919fbf0f..d4e6c914 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -16,6 +16,7 @@ def __init__( use_distributed_optimizer: bool = True, mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', seed: int = 42, + variable_seq_lengths: bool = False, **kwargs, ): from megatron.core import mpu @@ -23,7 +24,8 @@ def __init__( self.use_distributed_optimizer = use_distributed_optimizer self.mixed_precision = mixed_precision self.model_dir = model_dir - self._seed = seed + self.seed = seed + self.variable_seq_lengths = variable_seq_lengths # Determine params_dtype and activation checkpointing kwargs params_dtype = torch.bfloat16 if self.mixed_precision == 'fp16': @@ -48,7 +50,7 @@ def __init__( **parallel_kwargs, ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed - model_parallel_cuda_manual_seed(self._seed) + model_parallel_cuda_manual_seed(self.seed) self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) @property @@ -183,14 +185,39 @@ def get_model_config( **kwargs, ): from mcore_bridge import ModelConfig, hf_to_mcore_config + from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) config_kwargs = hf_to_mcore_config(hf_config) config_kwargs.update(kwargs) + if 'calculate_per_token_loss' not in config_kwargs: + config_kwargs['calculate_per_token_loss'] = True + + if 'moe_token_dispatcher_type' not in config_kwargs: + config_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' + + def finalize_model_grads_for_lora(model, *args, **kwargs): + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from peft import PeftModel as _PeftModel + + # Check if model is DDP-wrapped (has ddp_config) + # Need to unwrap PeftModel to check the underlying model + def _get_base_model(m): + if isinstance(m, _PeftModel): + return _get_base_model(m.base_model.model) + return m + + base_model = _get_base_model(model[0]) + if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'ddp_config'): + # Use native implementation for DDP models + return _native_finalize_model_grads(model, *args, **kwargs) + config = ModelConfig( use_cpu_initialization=True, params_dtype=self.params_type, sequence_parallel=self.sequence_parallel, + finalize_model_grads_func=finalize_model_grads_for_lora, + variable_seq_lengths=self.variable_seq_lengths, **parallel_kwargs, **config_kwargs, ) From d951d4758164f0a80d0ef31b9a4a9fbca591b771 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 10:58:54 +0800 Subject: [PATCH 08/18] wip --- src/twinkle/model/megatron/megatron.py | 14 +++++++------- 
src/twinkle/model/megatron/strategy/megatron.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index f434df2d..4daa319b 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -325,7 +325,7 @@ def forward_backward(self, if self.variable_seq_lengths: seq_length = 4096 else: - original_seq_length = inputs[0]['input_ids'].shape[1] + original_seq_length = inputs[0]['input_ids'].shape[1] * cp_size if cp_size > 1: divisor = 2 * cp_size elif self.strategy.sequence_parallel and self.device_mesh.tp_world_size > 1: @@ -388,7 +388,7 @@ def forward_step_func(data_iterator, model): output_tensor = model(**batch) batch['labels'] = labels logps = None - if labels is not None and mpu.is_pipeline_last_stage(): + if labels is not None and mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage): loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 @@ -877,7 +877,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): **kwargs, ) else: - bridge = self._bridge + bridge = self.self.strategy.bridge for _model in self.strategy.unwrap_model(self.model): bridge.load_weights( _model, @@ -1142,7 +1142,7 @@ def _read_iteration(tracker_path: str) -> int: def _merge_lora_adapters(self, adapter_name: str = 'default'): """Merge LoRA adapters into base model weights.""" - from mcore_bridge import LoraParallelLinear + from mcoreself.strategy.bridge import LoraParallelLinear with torch.no_grad(): for model in self.strategy.unwrap_model(self.model): for module in model.modules(): @@ -1151,7 +1151,7 @@ def _merge_lora_adapters(self, adapter_name: str = 'default'): def _unmerge_lora_adapters(self): """Unmerge LoRA adapters to restore training state.""" - from mcore_bridge import LoraParallelLinear + from mcoreself.strategy.bridge import LoraParallelLinear with torch.no_grad(): for model in self.strategy.unwrap_model(self.model): for module in model.modules(): @@ -1186,7 +1186,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non # Get the model (unwrap if DDP wrapped) model = self.strategy.unwrap_model(self.model) - self._bridge.save_weights( + self.self.strategy.bridge.save_weights( model, output_dir, is_peft_format=is_peft_format, adapter_name=adapter_name, lora_converter=lora_converter) # Save config on rank 0 only @@ -1259,7 +1259,7 @@ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torc ... 
print(f"{name}: {tensor.shape}") """ model = self.strategy.unwrap_model(self.model) - yield from self._bridge.export_weights( + yield from self.self.strategy.bridge.export_weights( model, target_device=None, # Keep on current device for IPC transfer only_last_rank=False, # All ranks participate in weight sync diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index d4e6c914..595d47e0 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -34,17 +34,24 @@ def __init__( params_dtype = torch.float32 self._params_dtype = params_dtype + vpp_size = self.device_mesh.vpp_size + if vpp_size in (0, 1): + vpp_size = None + parallel_kwargs = { 'tensor_model_parallel_size': self.device_mesh.tp_world_size or 1, 'pipeline_model_parallel_size': self.device_mesh.pp_world_size or 1, 'context_parallel_size': self.device_mesh.cp_world_size or 1, 'expert_model_parallel_size': self.device_mesh.ep_size or 1, 'expert_tensor_parallel_size': self.device_mesh.etp_world_size or 1, - 'virtual_pipeline_model_parallel_size': self.device_mesh.vpp_size or None, + 'virtual_pipeline_model_parallel_size': vpp_size, } - if not self.device_mesh.vpp_size: + if not vpp_size: # non-interleave does not support overlap_p2p_comm kwargs['overlap_p2p_comm'] = False + if 'overlap_p2p_comm' not in kwargs: + kwargs['overlap_p2p_comm'] = True + kwargs['batch_p2p_comm'] = not kwargs['overlap_p2p_comm'] mpu.initialize_model_parallel( order=self.device_mesh.order, **parallel_kwargs, From ad82a77a45a6c253817748147b25ffc1dce43d30 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 13:44:02 +0800 Subject: [PATCH 09/18] wip --- src/twinkle/infra/__init__.py | 150 ++++++++++++++----------- src/twinkle/model/megatron/megatron.py | 12 +- src/twinkle/sampler/__init__.py | 2 - 3 files changed, 91 insertions(+), 73 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index f25f0fa3..30999406 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -619,81 +619,101 @@ def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callabl def decorator(func: Callable[..., T1]) -> Callable[..., T1]: + def _handle_worker_redispatch(self, args, kwargs, device_mesh): + """Handle worker-side redispatch when args contain Ray ObjectRefs.""" + from ._ray import RayHelper + if RayHelper.has_ref(args, kwargs): + args, kwargs = RayHelper.do_get_and_collect(args, kwargs) + world_size = Platform.get_world_size() + rank = Platform.get_rank() + _workers_and_args = _dispatch_args( + _get_workers([None] * world_size, execute), dispatch, execute, device_mesh, args, kwargs) + _, args, kwargs = _workers_and_args[rank] + return args, kwargs + + def _handle_driver_dispatch(self, func_name, args, kwargs, device_mesh): + """Handle driver-side dispatch and result collection.""" + from ._ray import RayHelper + execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync + if RayHelper.has_ref(args, kwargs): + _workers_and_args = _dispatch_args( + _get_workers(self._actors, execute), 'all', execute, device_mesh, args, kwargs) + else: + _workers_and_args = _dispatch_args( + _get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) + + result = execute_method(func_name, _workers_and_args) + result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh) + _local_lazy_collect = _lazy_collect + + 
if func_name == '__iter__': + return self + + if func_name == '__len__': + import ray + return ray.get(result[0]) + + if func_name == '__next__': + import ray + for _res in result: + stop = ray.get(_res[1]) + if stop: + raise StopIteration() + result = [_res[0] for _res in result] + result_func._futures = result + + if lazy_collect is not None: + _local_lazy_collect = lazy_collect + if hasattr(self, '_lazy_collect'): + _local_lazy_collect = self._lazy_collect + return result_func if _local_lazy_collect else result_func() + + def _set_wrapper_attrs(wrapper_func): + """Set common attributes on wrapper function.""" + wrapper_func._execute = execute + wrapper_func._collect = collect + wrapper_func._dispatch = dispatch + wrapper_func._lazy_collect = _lazy_collect + wrapper_func._sync = sync + return wrapper_func + + def _run_func(func, self, args, kwargs): + """Run function, handling both sync and async functions.""" + result = func(self, *args, **kwargs) + if inspect.iscoroutine(result): + import asyncio + # Run coroutine in event loop + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is not None: + # Already in an event loop, create task + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, result) + return future.result() + else: + return asyncio.run(result) + return result + @functools.wraps(func) def wrapper(self, *args, **kwargs) -> T1: device_mesh = getattr(self, 'device_mesh', None) if _mode == 'local': - return func(self, *args, **kwargs) + return _run_func(func, self, args, kwargs) elif _mode == 'ray': check_unsafe(*args, **kwargs) if not hasattr(self, '_actors'): - # This is the worker - from ._ray import RayHelper - if RayHelper.has_ref(args, kwargs): - # In this case, driver dispatch is all, redispatch here - args, kwargs = RayHelper.do_get_and_collect(args, kwargs) - world_size = Platform.get_world_size() - rank = Platform.get_rank() - # Redispatch here - _workers_and_args = _dispatch_args( - _get_workers([None] * world_size, execute), dispatch, execute, device_mesh, args, kwargs) - _, args, kwargs = _workers_and_args[rank] - return func(self, *args, **kwargs) + # Worker side + args, kwargs = _handle_worker_redispatch(self, args, kwargs, device_mesh) + return _run_func(func, self, args, kwargs) else: - # This is the driver - from ._ray import RayHelper - execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync - if RayHelper.has_ref(args, kwargs): - # If has any object-ref, dispatch in worker, because we don't know the structure in the ref. - # for example, dataloader returns any data list. 
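# A sketch (helper name illustrative, not twinkle API) of the laziness policy
# implemented above: `result_func` is a deferred collect over the workers'
# futures, and whether the driver returns the callable or the materialized
# value is decided by, in rising priority: the closure default, the decorator's
# explicit `lazy_collect` argument, then an instance-level `_lazy_collect`.
def resolve_lazy(result_func, default, lazy_collect, instance):
    lazy = default
    if lazy_collect is not None:
        lazy = lazy_collect              # per-function decorator override
    if hasattr(instance, '_lazy_collect'):
        lazy = instance._lazy_collect    # instance attribute has highest priority
    return result_func if lazy else result_func()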
- _workers_and_args = _dispatch_args( - _get_workers(self._actors, execute), 'all', execute, device_mesh, args, kwargs) - else: - # dispatch now - _workers_and_args = _dispatch_args( - _get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) - - result = execute_method(func.__name__, _workers_and_args) - # This is a result future, call it to get the actual result - result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh) - _local_lazy_collect = _lazy_collect - if func.__name__ == '__iter__': - # return self - return self - - if func.__name__ == '__len__': - # Get the first result and ignore the `lazy_collect` - import ray - return ray.get(result[0]) - - if func.__name__ == '__next__': - import ray - for _res in result: - # raise when any worker raises StopIteration - stop = ray.get(_res[1]) - if stop: - raise StopIteration() - result = [_res[0] for _res in result] - result_func._futures = result - - if lazy_collect is not None: - # Maybe this function returns a small object - _local_lazy_collect = lazy_collect - if hasattr(self, '_lazy_collect'): - # _lazy_collect in class has the highest priority - # This is the unique case that an object ref contains another - # And this is user independent, only decided by the code. - _local_lazy_collect = self._lazy_collect - result = result_func if _local_lazy_collect else result_func() - return result + # Driver side - Ray handles remote execution + return _handle_driver_dispatch(self, func.__name__, args, kwargs, device_mesh) else: raise NotImplementedError(f'Unsupported mode {_mode}') - wrapper._execute = execute - wrapper._collect = collect - wrapper._dispatch = dispatch - wrapper._lazy_collect = _lazy_collect - wrapper._sync = sync - return wrapper + return _set_wrapper_attrs(wrapper) return decorator diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4daa319b..bbf817d2 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -877,7 +877,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): **kwargs, ) else: - bridge = self.self.strategy.bridge + bridge = self.strategy.bridge for _model in self.strategy.unwrap_model(self.model): bridge.load_weights( _model, @@ -1186,7 +1186,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non # Get the model (unwrap if DDP wrapped) model = self.strategy.unwrap_model(self.model) - self.self.strategy.bridge.save_weights( + self.strategy.bridge.save_weights( model, output_dir, is_peft_format=is_peft_format, adapter_name=adapter_name, lora_converter=lora_converter) # Save config on rank 0 only @@ -1259,12 +1259,12 @@ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torc ... 
print(f"{name}: {tensor.shape}") """ model = self.strategy.unwrap_model(self.model) - yield from self.self.strategy.bridge.export_weights( + yield from self.strategy.bridge.export_weights( model, target_device=None, # Keep on current device for IPC transfer - only_last_rank=False, # All ranks participate in weight sync - is_peft_format=bool(adapter_name), - adapter_name=adapter_name if adapter_name else None, + only_master_rank=False, # All ranks participate in weight sync + peft_format=bool(adapter_name), + # adapter_name=adapter_name if adapter_name else None, tqdm_desc='Weight sync: ', ) diff --git a/src/twinkle/sampler/__init__.py b/src/twinkle/sampler/__init__.py index 6bd9532b..67b14026 100644 --- a/src/twinkle/sampler/__init__.py +++ b/src/twinkle/sampler/__init__.py @@ -1,7 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from twinkle.sampler.torch_sampler.transformers_engine import TransformersEngine from twinkle.sampler.vllm_sampler.vllm_engine import VLLMEngine from .base import Sampler from .base_engine import BaseSamplerEngine -from .torch_sampler import TorchSampler from .vllm_sampler import vLLMSampler From 48cbf1301bd925a88f53ba67b49bf0069273b870 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 14:25:26 +0800 Subject: [PATCH 10/18] wip --- src/twinkle/infra/__init__.py | 150 ++++++++---------- src/twinkle/model/megatron/megatron.py | 27 +++- .../sampler/vllm_sampler/vllm_sampler.py | 4 +- 3 files changed, 91 insertions(+), 90 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 30999406..f25f0fa3 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -619,101 +619,81 @@ def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callabl def decorator(func: Callable[..., T1]) -> Callable[..., T1]: - def _handle_worker_redispatch(self, args, kwargs, device_mesh): - """Handle worker-side redispatch when args contain Ray ObjectRefs.""" - from ._ray import RayHelper - if RayHelper.has_ref(args, kwargs): - args, kwargs = RayHelper.do_get_and_collect(args, kwargs) - world_size = Platform.get_world_size() - rank = Platform.get_rank() - _workers_and_args = _dispatch_args( - _get_workers([None] * world_size, execute), dispatch, execute, device_mesh, args, kwargs) - _, args, kwargs = _workers_and_args[rank] - return args, kwargs - - def _handle_driver_dispatch(self, func_name, args, kwargs, device_mesh): - """Handle driver-side dispatch and result collection.""" - from ._ray import RayHelper - execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync - if RayHelper.has_ref(args, kwargs): - _workers_and_args = _dispatch_args( - _get_workers(self._actors, execute), 'all', execute, device_mesh, args, kwargs) - else: - _workers_and_args = _dispatch_args( - _get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) - - result = execute_method(func_name, _workers_and_args) - result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh) - _local_lazy_collect = _lazy_collect - - if func_name == '__iter__': - return self - - if func_name == '__len__': - import ray - return ray.get(result[0]) - - if func_name == '__next__': - import ray - for _res in result: - stop = ray.get(_res[1]) - if stop: - raise StopIteration() - result = [_res[0] for _res in result] - result_func._futures = result - - if lazy_collect is not None: - _local_lazy_collect = lazy_collect - if hasattr(self, 
'_lazy_collect'): - _local_lazy_collect = self._lazy_collect - return result_func if _local_lazy_collect else result_func() - - def _set_wrapper_attrs(wrapper_func): - """Set common attributes on wrapper function.""" - wrapper_func._execute = execute - wrapper_func._collect = collect - wrapper_func._dispatch = dispatch - wrapper_func._lazy_collect = _lazy_collect - wrapper_func._sync = sync - return wrapper_func - - def _run_func(func, self, args, kwargs): - """Run function, handling both sync and async functions.""" - result = func(self, *args, **kwargs) - if inspect.iscoroutine(result): - import asyncio - # Run coroutine in event loop - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - if loop is not None: - # Already in an event loop, create task - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, result) - return future.result() - else: - return asyncio.run(result) - return result - @functools.wraps(func) def wrapper(self, *args, **kwargs) -> T1: device_mesh = getattr(self, 'device_mesh', None) if _mode == 'local': - return _run_func(func, self, args, kwargs) + return func(self, *args, **kwargs) elif _mode == 'ray': check_unsafe(*args, **kwargs) if not hasattr(self, '_actors'): - # Worker side - args, kwargs = _handle_worker_redispatch(self, args, kwargs, device_mesh) - return _run_func(func, self, args, kwargs) + # This is the worker + from ._ray import RayHelper + if RayHelper.has_ref(args, kwargs): + # In this case, driver dispatch is all, redispatch here + args, kwargs = RayHelper.do_get_and_collect(args, kwargs) + world_size = Platform.get_world_size() + rank = Platform.get_rank() + # Redispatch here + _workers_and_args = _dispatch_args( + _get_workers([None] * world_size, execute), dispatch, execute, device_mesh, args, kwargs) + _, args, kwargs = _workers_and_args[rank] + return func(self, *args, **kwargs) else: - # Driver side - Ray handles remote execution - return _handle_driver_dispatch(self, func.__name__, args, kwargs, device_mesh) + # This is the driver + from ._ray import RayHelper + execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync + if RayHelper.has_ref(args, kwargs): + # If has any object-ref, dispatch in worker, because we don't know the structure in the ref. + # for example, dataloader returns any data list. 
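# Illustration of the check that gates the branch above (`RayHelper.has_ref` in
# the diff; the helper below is a hypothetical stand-in built on Ray's public
# ObjectRef type): a ref is opaque on the driver, so the call is broadcast to
# every worker and each worker materializes the refs before slicing its shard.
import ray

def contains_object_ref(args, kwargs):
    return any(isinstance(v, ray.ObjectRef) for v in (*args, *kwargs.values()))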
+ _workers_and_args = _dispatch_args( + _get_workers(self._actors, execute), 'all', execute, device_mesh, args, kwargs) + else: + # dispatch now + _workers_and_args = _dispatch_args( + _get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) + + result = execute_method(func.__name__, _workers_and_args) + # This is a result future, call it to get the actual result + result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh) + _local_lazy_collect = _lazy_collect + if func.__name__ == '__iter__': + # return self + return self + + if func.__name__ == '__len__': + # Get the first result and ignore the `lazy_collect` + import ray + return ray.get(result[0]) + + if func.__name__ == '__next__': + import ray + for _res in result: + # raise when any worker raises StopIteration + stop = ray.get(_res[1]) + if stop: + raise StopIteration() + result = [_res[0] for _res in result] + result_func._futures = result + + if lazy_collect is not None: + # Maybe this function returns a small object + _local_lazy_collect = lazy_collect + if hasattr(self, '_lazy_collect'): + # _lazy_collect in class has the highest priority + # This is the unique case that an object ref contains another + # And this is user independent, only decided by the code. + _local_lazy_collect = self._lazy_collect + result = result_func if _local_lazy_collect else result_func() + return result else: raise NotImplementedError(f'Unsupported mode {_mode}') - return _set_wrapper_attrs(wrapper) + wrapper._execute = execute + wrapper._collect = collect + wrapper._dispatch = dispatch + wrapper._lazy_collect = _lazy_collect + wrapper._sync = sync + return wrapper return decorator diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index bbf817d2..4f612920 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -325,7 +325,7 @@ def forward_backward(self, if self.variable_seq_lengths: seq_length = 4096 else: - original_seq_length = inputs[0]['input_ids'].shape[1] * cp_size + original_seq_length = inputs[0]['input_ids'].shape[1] * (cp_size or 1) if cp_size > 1: divisor = 2 * cp_size elif self.strategy.sequence_parallel and self.device_mesh.tp_world_size > 1: @@ -1403,7 +1403,7 @@ def get_train_configs(self, **kwargs): # via NCCL; others consume the generator silently (rank=-1). 
@remote_function(dispatch='all', lazy_collect=True) - async def send_weights( + def send_weights( self, adapter_name: str = None, base_sync_done: bool = False, @@ -1474,6 +1474,7 @@ def weight_generator(): else: yield from _raw_weights() + is_sender = (engine.rank is not None and engine.rank == 0) if not is_sender: @@ -1481,7 +1482,27 @@ def weight_generator(): pass return - await engine.send_weights(weight_generator()) + async def _send(): + await engine.send_weights(weight_generator()) + + result_container = {'error': None} + + def _run(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_send()) + finally: + loop.close() + except Exception as e: + result_container['error'] = e + + thread = threading.Thread(target=_run) + thread.start() + thread.join() + if result_container['error'] is not None: + raise result_container['error'] @remote_function(collect='first') def get_peft_config_dict(self, adapter_name: str = None) -> Optional[Dict[str, Any]]: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 04b8d99f..915e012f 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -376,7 +376,7 @@ def reset_prefix_cache(self): self._run_in_loop(self.engine.reset_prefix_cache()) @remote_function(dispatch='all', lazy_collect=True) - async def receive_weights( + def receive_weights( self, base_sync_done: bool = False, peft_config: dict = None, @@ -423,7 +423,7 @@ async def _receive_and_load(): # Base-model sync invalidates any previously synced LoRA. self.engine.invalidate_synced_lora() - await _receive_and_load() + self._run_in_loop(_receive_and_load()) def shutdown(self): """Gracefully shutdown the vLLM engine and background event loop. 
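The change above turns `send_weights` from a coroutine into a blocking call by driving the coroutine on a dedicated thread with its own event loop, so it can be invoked from the synchronous `@remote_function` path even when the caller's thread already hosts a running loop (where a plain `asyncio.run` would raise). A minimal self-contained sketch of that pattern, with an illustrative name that is not part of the twinkle API:

    import asyncio
    import threading

    def run_coroutine_blocking(coro):
        """Run `coro` to completion on a fresh event loop in its own thread,
        re-raising any exception in the calling thread."""
        result_container = {'error': None}

        def _run():
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    loop.run_until_complete(coro)
                finally:
                    loop.close()
            except Exception as e:
                result_container['error'] = e

        thread = threading.Thread(target=_run)
        thread.start()
        thread.join()
        if result_container['error'] is not None:
            raise result_container['error']

The sampler side takes the complementary route: `receive_weights` stays synchronous but hands its coroutine to the engine's persistent background loop via `_run_in_loop`, rather than spawning a throwaway thread per call.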
From 43ef29e58818423117acaeba4c9a9e9d0fb978fc Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 15:26:13 +0800 Subject: [PATCH 11/18] wip --- cookbook/megatron/tp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index b09d1a60..4ea2e13e 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -8,8 +8,8 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, tp=pp=cp=2, dp=1 -device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2) +# Construct a device_mesh, tp=pp=dp=2 +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) @@ -19,7 +19,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=16) @@ -33,7 +33,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -41,7 +41,7 @@ def train(): # Global batch size = 1, dp_size = 1 dataloader = DataLoader(dataset=dataset, batch_size=16) # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3-4B') + model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') From c3c762011773f006b8a8f2be5478eb7b77493a24 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 17:53:49 +0800 Subject: [PATCH 12/18] fix --- src/twinkle/model/megatron/megatron.py | 21 ++++- .../model/megatron/multi_lora_megatron.py | 83 ++++++++----------- .../model/megatron/strategy/megatron.py | 52 ++++++++++-- .../transformers/multi_lora_transformers.py | 1 + 4 files changed, 96 insertions(+), 61 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4f612920..639a818e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -83,6 +83,7 @@ def __init__( self, model_id: str, config: Optional[PretrainedConfig] = None, + ddp_config: Optional[Dict[str, Any]] = None, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', load_weights: bool = True, @@ -95,6 +96,7 @@ def __init__( requires('megatron_core') requires('mcore_bridge') os.environ['TOKENIZERS_PARALLELISM'] = 'true' + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' nn.Module.__init__(self) from twinkle.patch.megatron_peft import MegatronPeft @@ -117,11 +119,16 @@ def __init__( 'variable_seq_lengths': self.variable_seq_lengths, }) seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) - self.strategy = MegatronStrategy(self._model_path, self.device_mesh, mixed_precision=mixed_precision, + self.strategy = 
MegatronStrategy(self._model_path, + self.device_mesh, + mixed_precision=mixed_precision, + config=config, + ddp_config=ddp_config or {}, seed=seed, **kwargs) self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) self._model_wrapped = False + self._finish_config = False # This correctly handles vocab sharding in Tensor Parallelism self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { _default_adapter_name: self._construct_default_optimizer_group() @@ -143,6 +150,13 @@ def _lazy_wrap_model(self): self.model = self.strategy.wrap_model(self.model) self._model_wrapped = True + def _lazy_finish_param_config(self): + if self._finish_config: + return + self._finish_config = True + optimizer = self.optimizer_group[self._get_default_group()].optimizer + self.strategy.finish_param_config(self.model, optimizer) + def _get_default_group(self): """Get the only group has optimizer, else return the default one""" if len(self.optimizer_group) == 1: @@ -289,6 +303,7 @@ def forward_backward(self, Average loss value across all microbatches. """ self._lazy_wrap_model() + self._lazy_finish_param_config() from functools import partial from megatron.core import parallel_state as mpu from megatron.core.pipeline_parallel import get_forward_backward_func @@ -1142,7 +1157,7 @@ def _read_iteration(tracker_path: str) -> int: def _merge_lora_adapters(self, adapter_name: str = 'default'): """Merge LoRA adapters into base model weights.""" - from mcoreself.strategy.bridge import LoraParallelLinear + from mcore_bridge import LoraParallelLinear with torch.no_grad(): for model in self.strategy.unwrap_model(self.model): for module in model.modules(): @@ -1151,7 +1166,7 @@ def _merge_lora_adapters(self, adapter_name: str = 'default'): def _unmerge_lora_adapters(self): """Unmerge LoRA adapters to restore training state.""" - from mcoreself.strategy.bridge import LoraParallelLinear + from mcore_bridge import LoraParallelLinear with torch.no_grad(): for model in self.strategy.unwrap_model(self.model): for module in model.modules(): diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index fc35b2b0..0f454412 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -25,25 +25,26 @@ class MultiLoraMegatronModel(MegatronModel): def __init__( - self, - model_id: str, - config: Optional[PretrainedConfig] = None, - device_mesh: Optional[DeviceMesh] = None, - mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', - load_weights: bool = True, - recompute_granularity: Optional[str] = 'full', # Activation checkpointing - recompute_method: Optional[str] = 'uniform', - recompute_num_layers: Optional[int] = 1, - recompute_modules: Optional[list] = None, # Modules to recompute - max_loras: int = 5, - max_r: int = 32, - max_length: int = 8192, - **kwargs, + self, + model_id: str, + config: Optional[PretrainedConfig] = None, + ddp_config: Optional[Dict[str, Any]] = None, + device_mesh: Optional[DeviceMesh] = None, + mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', + load_weights: bool = True, + recompute_granularity: Optional[str] = 'full', # Activation checkpointing + recompute_method: Optional[str] = 'uniform', + recompute_num_layers: Optional[int] = 1, + recompute_modules: Optional[list] = None, # Modules to recompute + max_loras: int = 5, + max_r: int = 32, + max_length: int = 8192, + **kwargs, ): requires('megatron_core') + requires('mcore_bridge') 
os.environ['TOKENIZERS_PARALLELISM'] = 'true' os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - from .args import TwinkleMegatronArgs, set_args nn.Module.__init__(self) from twinkle.patch.megatron_peft import MegatronPeft @@ -52,56 +53,37 @@ def __init__( self.mixed_precision = mixed_precision self._model_path = HubOperation.download_model(model_id) - self.hf_config = config or AutoConfig.from_pretrained(self._model_path) self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id) - - self._seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) self._default_tokenizer = None self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True) self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False) self.optimizer_group = {} torch_util.set_device() + self._try_init_process_group() - self.strategy = MegatronStrategy( - self.device_mesh, - sequence_parallel=self.device_mesh.sequence_parallel, - mixed_precision=mixed_precision, - **kwargs) - - # Determine params_dtype and activation checkpointing kwargs - params_dtype = torch.bfloat16 - if self.mixed_precision == 'fp16': - params_dtype = torch.float16 - elif self.mixed_precision == 'no': - params_dtype = torch.float32 - - ac_kwargs = { + kwargs.update({ 'recompute_granularity': recompute_granularity, 'recompute_modules': recompute_modules, 'recompute_method': recompute_method, 'recompute_num_layers': recompute_num_layers, - } - - # Initialize TwinkleMegatronArgs BEFORE creating the model - args = TwinkleMegatronArgs.from_hf_config( - self.hf_config, - model_dir=self._model_path, - device_mesh=self.device_mesh, - params_dtype=params_dtype, - sequence_parallel=self.strategy.sequence_parallel, - **ac_kwargs, - ) - - set_args(args) - self._initialized = False - self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs) - + 'variable_seq_lengths': self.variable_seq_lengths, + }) + seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) + self.strategy = MegatronStrategy(self._model_path, + self.device_mesh, + mixed_precision=mixed_precision, + config=config, + ddp_config=ddp_config or {}, + seed=seed, **kwargs) + self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) MegatronPeft().__call__() self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length) self.model = self.multi_adapter.patch(self.model) self.model = self.strategy.wrap_model(self.model) - self._model_wrapped = True + self.strategy.finish_param_config(self.model, None) self.multi_adapter.save_initial_weights() + self._model_wrapped = True + self._finish_config = True # Active group for compatibility with single adapter self.active_group = None @@ -112,6 +94,9 @@ def _check_adapter_valid(self, adapter_name: str): def _lazy_wrap_model(self): pass + def _lazy_finish_param_config(self): + pass + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Forward pass without gradient computation. 
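Read together, the multi-LoRA constructor above performs eagerly what `MegatronModel` defers to `_lazy_wrap_model`/`_lazy_finish_param_config`. A condensed sketch of that ordering (method names are from the diff; since `optimizer` is `None`, `finish_param_config` leaves `grad_scale_func` unset, as the strategy diff that follows shows):

    def init_multi_lora_model(strategy, multi_adapter, load_weights=True):
        """Illustrative: the eager setup order used by MultiLoraMegatronModel."""
        model = strategy.create_megatron_model(load_weights)
        model = multi_adapter.patch(model)         # install shared multi-LoRA slots
        model = strategy.wrap_model(model)         # Megatron DDP / Float16Module
        strategy.finish_param_config(model, None)  # optimizer=None -> grad_scale_func=None
        multi_adapter.save_initial_weights()
        return model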
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 595d47e0..ea16d073 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +from transformers import PreTrainedConfig from twinkle import DeviceMesh, Platform, torch_util @@ -17,6 +18,8 @@ def __init__( mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', seed: int = 42, variable_seq_lengths: bool = False, + config: PreTrainedConfig = None, + ddp_config: Dict[str, Any] = None, **kwargs, ): from megatron.core import mpu @@ -26,6 +29,17 @@ def __init__( self.model_dir = model_dir self.seed = seed self.variable_seq_lengths = variable_seq_lengths + self.ddp_config = ddp_config or {} + + if 'overlap_grad_reduce' not in self.ddp_config: + self.ddp_config['overlap_grad_reduce'] = False + if 'overlap_param_gather' not in self.ddp_config: + self.ddp_config['overlap_param_gather'] = False + if 'align_param_gather' not in self.ddp_config: + self.ddp_config['align_param_gather'] = False + if 'grad_reduce_in_fp32' not in self.ddp_config: + self.ddp_config['grad_reduce_in_fp32'] = True + # Determine params_dtype and activation checkpointing kwargs params_dtype = torch.bfloat16 if self.mixed_precision == 'fp16': @@ -58,7 +72,10 @@ def __init__( ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed model_parallel_cuda_manual_seed(self.seed) - self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) + if config is None: + self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) + else: + self.config = config @property def sequence_parallel(self) -> bool: @@ -126,7 +143,7 @@ def wrap_model( return model self._check_device_mesh() - return self._wrap_with_megatron_ddp(model, use_distributed_optimizer) + return self._wrap_with_megatron_ddp(model, use_distributed_optimizer, self.ddp_config) def unwrap_model(self, model: List[nn.Module]) -> List[nn.Module]: from megatron.core.distributed import DistributedDataParallel as MegatronDDP @@ -140,10 +157,30 @@ def unwrap_model(self, model: List[nn.Module]) -> List[nn.Module]: _models.append(_model) return _models + def finish_param_config(self, model: List[nn.Module], optimizer: Any): + self.config.grad_scale_func = getattr(optimizer, 'scale_loss') if optimizer is not None else None + ddp_config = self.ddp_config + if ddp_config['overlap_grad_reduce']: + assert self.config.no_sync_func is None, ( + 'When overlap_grad_reduce is True, config.no_sync_func must be None; ' + 'a custom no_sync_func is not supported when overlapping grad-reduce' + ) + self.config.no_sync_func = [model_chunk.no_sync for model_chunk in model] # noqa + if len(model) == 1: + self.config.no_sync_func = self.config.no_sync_func[0] # noqa + self.config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] # noqa + if len(model) == 1: + self.config.grad_sync_func = self.config.grad_sync_func[0] # noqa + if ddp_config['overlap_param_gather'] and ddp_config['align_param_gather']: + self.config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] # noqa + if len(model) == 1: + self.config.param_sync_func = self.config.param_sync_func[0] # noqa + @staticmethod def _wrap_with_megatron_ddp( model: List[nn.Module], use_distributed_optimizer: bool, + ddp_config: Dict[str, Any], ) -> List[nn.Module]: from megatron.core.distributed import DistributedDataParallel as MegatronDDP from 
megatron.core.distributed import DistributedDataParallelConfig @@ -157,15 +194,13 @@ def _wrap_with_megatron_ddp( if not isinstance(model, Float16Module) and (config.fp16 or config.bf16): _model = Float16Module(config, _model) - ddp_config = DistributedDataParallelConfig( - grad_reduce_in_fp32=True, - overlap_grad_reduce=False, + ddp_config_cls = DistributedDataParallelConfig( + **ddp_config, use_distributed_optimizer=use_distributed_optimizer, ) - wrapped_model = MegatronDDP( config=config, - ddp_config=ddp_config, + ddp_config=ddp_config_cls, module=_model, ) @@ -219,7 +254,7 @@ def _get_base_model(m): # Use native implementation for DDP models return _native_finalize_model_grads(model, *args, **kwargs) - config = ModelConfig( + return ModelConfig( use_cpu_initialization=True, params_dtype=self.params_type, sequence_parallel=self.sequence_parallel, @@ -228,7 +263,6 @@ def _get_base_model(m): **parallel_kwargs, **config_kwargs, ) - return config def create_megatron_model( self, diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 77a25b18..f7573f41 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -36,6 +36,7 @@ def __init__( **kwargs): assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}' os.environ['TOKENIZERS_PARALLELISM'] = 'true' + self._try_init_process_group() super(PreTrainedModel, self).__init__() model_id = HubOperation.download_model(model_id) if isinstance(model_cls, str): From dec91c9ec0543b0e4a8e48bf491987ad898d974c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 18:23:51 +0800 Subject: [PATCH 13/18] lint code --- src/twinkle/checkpoint_engine/base.py | 2 +- src/twinkle/model/base.py | 47 +------ src/twinkle/model/megatron/megatron.py | 129 +++++++++--------- .../model/megatron/multi_lora_megatron.py | 49 ++++--- .../model/megatron/strategy/megatron.py | 37 +++-- src/twinkle/model/multi_lora.py | 2 +- src/twinkle/processor/base.py | 45 ++++++ src/twinkle/utils/__init__.py | 4 +- src/twinkle/utils/torch_utils.py | 4 +- 9 files changed, 163 insertions(+), 156 deletions(-) diff --git a/src/twinkle/checkpoint_engine/base.py b/src/twinkle/checkpoint_engine/base.py index 0760f2ae..f3a1d891 100644 --- a/src/twinkle/checkpoint_engine/base.py +++ b/src/twinkle/checkpoint_engine/base.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional, TypedDict if TYPE_CHECKING: import torch diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 19a05958..596f3c32 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput @@ -157,48 +157,3 @@ def _try_init_process_group(self): if backend in ('nccl', 'hccl'): init_kwargs['device_id'] = torch.device(Platform.get_local_device()) dist.init_process_group(**init_kwargs) - - @staticmethod - def get_target_modules(model: 'torch.nn.Module', target_modules: List[str]) -> List[str]: - import torch - - def find_layers(model: torch.nn.Module, cond_fn) -> List[str]: - result = [] - for name, module in model.named_modules(): - if cond_fn(name, module): - result.append(name) - return result - - def find_all_linears(model: torch.nn.Module) -> List[str]: - from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, \ - TELinear - - def _cond(name: str, module: torch.nn.Module) -> bool: - if name == 'output_layer' or 'lora' in name: - return False - if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, torch.nn.Linear)): - return True - return False - - return find_layers(model, _cond) - - def find_router(model: torch.nn.Module) -> List[str]: - from megatron.core.transformer.moe.router import TopKRouter - return find_layers(model, lambda name, module: isinstance(module, TopKRouter) and 'lora' not in name) - - def find_embedding(model: torch.nn.Module) -> List[str]: - from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding - return find_layers(model, - lambda name, module: isinstance(module, LanguageModelEmbedding) and 'lora' not in name) - - result = target_modules.copy() - if 'all-linear' in result: - result.remove('all-linear') - result += find_all_linears(model) - if 'all-embedding' in result: - result.remove('all-embedding') - result += find_embedding(model) - if 'all-router' in result: - result.remove('all-router') - result += find_router(model) - return list(set(result)) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 639a818e..c96005ee 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -2,23 +2,22 @@ import asyncio import json import logging +import numpy as np import os import random import re import threading -from dataclasses import dataclass -from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, Union - -import numpy as np import torch import torch.distributed as dist import torch.nn as nn from contextlib import contextmanager +from dataclasses import dataclass from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model from peft.tuners.lora import Linear as LoraLinear from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import PretrainedConfig +from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, Union import twinkle import twinkle.metric @@ -119,12 +118,19 @@ def __init__( 'variable_seq_lengths': self.variable_seq_lengths, }) seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) - self.strategy = MegatronStrategy(self._model_path, - self.device_mesh, - mixed_precision=mixed_precision, - config=config, - ddp_config=ddp_config or {}, - seed=seed, **kwargs) + if config is None: + from transformers 
import AutoConfig + self.hf_config = AutoConfig.from_pretrained(self._model_path, trust_remote_code=True) + else: + self.hf_config = config + self.strategy = MegatronStrategy( + self._model_path, + self.device_mesh, + mixed_precision=mixed_precision, + config=self.hf_config, + ddp_config=ddp_config or {}, + seed=seed, + **kwargs) self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) self._model_wrapped = False @@ -200,51 +206,6 @@ def _slice_value_for_microbatch(value, mb_start: int, mb_end: int, micro_batch_s # Scalars, small tensors, or non-sliceable values pass through as-is return value - def _postprocess_tensor_cp(self, tensor): - """All-gather and reconstruct full sequence from CP-split tensor. - - Uses load-balanced split pattern: each CP rank holds chunks [rank] and - [2*cp_size - rank - 1] from the original 2*cp_size chunks. - - Only the current rank's slice retains the original tensor (and its - gradient graph); other ranks' slices are plain copies. This means - backward through the reconstructed tensor only produces gradients for - the local chunk, naturally distributing the gradient across CP ranks - without extra scaling. - - Args: - tensor: [batch_size, seq_len/cp_size] CP-split tensor - - Returns: - [batch_size, full_seq_len] reconstructed full tensor - """ - from megatron.core import parallel_state as mpu - cp_size = mpu.get_context_parallel_world_size() - if cp_size <= 1: - return tensor - - cp_rank = mpu.get_context_parallel_rank() - cp_group = mpu.get_context_parallel_group() - - gathered = [torch.empty_like(tensor) for _ in range(cp_size)] - torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group) - gathered[cp_rank] = tensor - - batch_size = tensor.shape[0] - seq_len_per_cp = tensor.shape[1] - full_seq_len = seq_len_per_cp * cp_size - chunk_len = full_seq_len // (2 * cp_size) - half_len = seq_len_per_cp // 2 - - output = tensor.new_zeros(batch_size, full_seq_len) - for j in range(cp_size): - o = gathered[j] - output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len] - reverse_idx = 2 * cp_size - j - 1 - output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:] - - return output - @remote_function() def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') @@ -409,9 +370,8 @@ def forward_step_func(data_iterator, model): masked_labels[~loss_mask] = 0 output_tensor.div_(temperature) logps = selective_log_softmax(output_tensor, masked_labels) - if cp_size > 1: - logps = self._postprocess_tensor_cp(logps) - batch['labels'] = self._postprocess_tensor_cp(labels) + logps = processor.postprocess_tensor_cp(logps) + batch['labels'] = processor.postprocess_tensor_cp(labels) return output_tensor, partial(post_loss_function, inputs=batch, logps=logps) # Get Megatron's forward-backward function @@ -1279,7 +1239,6 @@ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torc target_device=None, # Keep on current device for IPC transfer only_master_rank=False, # All ranks participate in weight sync peft_format=bool(adapter_name), - # adapter_name=adapter_name if adapter_name else None, tqdm_desc='Weight sync: ', ) @@ -1310,14 +1269,14 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str expanded_modules = self.get_target_modules(_model, target_modules) config.target_modules = expanded_modules - _model = 
get_peft_model(_model, config, adapter_name=adapter_name) # noqa + _model = get_peft_model(_model, config, adapter_name=adapter_name) # noqa _models.append(_model) self.model = _models # Create optimizer group for adapter self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name - self.optimizer_group[adapter_name].adapter_config = config # noqa + self.optimizer_group[adapter_name].adapter_config = config # noqa self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) # Fix: use .processor instead of .tokenizer - Template class uses self.processor self._default_tokenizer = self.optimizer_group[adapter_name].template.processor @@ -1453,6 +1412,7 @@ def _add_base_layer_suffix(params): if base_sync_done and adapter_name: # The first base model synchronization finished, and is lora training if merge_and_sync: + def weight_generator(): with merge_lora(): for name, tensor in self.get_hf_state_dict(adapter_name=''): @@ -1464,6 +1424,7 @@ def weight_generator(): yield name, tensor else: + def weight_generator(): for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name): if name is None or tensor is None: @@ -1489,7 +1450,6 @@ def weight_generator(): else: yield from _raw_weights() - is_sender = (engine.rank is not None and engine.rank == 0) if not is_sender: @@ -1537,3 +1497,48 @@ def get_peft_config_dict(self, adapter_name: str = None) -> Optional[Dict[str, A if isinstance(config, dict): config = config.get(adapter_name, next(iter(config.values()))) return config.to_dict() if hasattr(config, 'to_dict') else dict(config) + + @staticmethod + def get_target_modules(model: 'torch.nn.Module', target_modules: List[str]) -> List[str]: + import torch + + def find_layers(model: torch.nn.Module, cond_fn) -> List[str]: + result = [] + for name, module in model.named_modules(): + if cond_fn(name, module): + result.append(name) + return result + + def find_all_linears(model: torch.nn.Module) -> List[str]: + from megatron.core.extensions.transformer_engine import (TEGroupedLinear, TELayerNormColumnParallelLinear, + TELinear) + + def _cond(name: str, module: torch.nn.Module) -> bool: + if name == 'output_layer' or 'lora' in name: + return False + if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, torch.nn.Linear)): + return True + return False + + return find_layers(model, _cond) + + def find_router(model: torch.nn.Module) -> List[str]: + from megatron.core.transformer.moe.router import TopKRouter + return find_layers(model, lambda name, module: isinstance(module, TopKRouter) and 'lora' not in name) + + def find_embedding(model: torch.nn.Module) -> List[str]: + from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding + return find_layers(model, + lambda name, module: isinstance(module, LanguageModelEmbedding) and 'lora' not in name) + + result = target_modules.copy() + if 'all-linear' in result: + result.remove('all-linear') + result += find_all_linears(model) + if 'all-embedding' in result: + result.remove('all-embedding') + result += find_embedding(model) + if 'all-router' in result: + result.remove('all-router') + result += find_router(model) + return list(set(result)) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 0f454412..eab301a3 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ 
b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os -import torch import torch.distributed as dist import torch.nn as nn from peft import LoraConfig @@ -25,21 +24,21 @@ class MultiLoraMegatronModel(MegatronModel): def __init__( - self, - model_id: str, - config: Optional[PretrainedConfig] = None, - ddp_config: Optional[Dict[str, Any]] = None, - device_mesh: Optional[DeviceMesh] = None, - mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', - load_weights: bool = True, - recompute_granularity: Optional[str] = 'full', # Activation checkpointing - recompute_method: Optional[str] = 'uniform', - recompute_num_layers: Optional[int] = 1, - recompute_modules: Optional[list] = None, # Modules to recompute - max_loras: int = 5, - max_r: int = 32, - max_length: int = 8192, - **kwargs, + self, + model_id: str, + config: Optional[PretrainedConfig] = None, + ddp_config: Optional[Dict[str, Any]] = None, + device_mesh: Optional[DeviceMesh] = None, + mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', + load_weights: bool = True, + recompute_granularity: Optional[str] = 'full', # Activation checkpointing + recompute_method: Optional[str] = 'uniform', + recompute_num_layers: Optional[int] = 1, + recompute_modules: Optional[list] = None, # Modules to recompute + max_loras: int = 5, + max_r: int = 32, + max_length: int = 8192, + **kwargs, ): requires('megatron_core') requires('mcore_bridge') @@ -69,12 +68,18 @@ def __init__( 'variable_seq_lengths': self.variable_seq_lengths, }) seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42)) - self.strategy = MegatronStrategy(self._model_path, - self.device_mesh, - mixed_precision=mixed_precision, - config=config, - ddp_config=ddp_config or {}, - seed=seed, **kwargs) + if config is None: + self.hf_config = AutoConfig.from_pretrained(self._model_path, trust_remote_code=True) + else: + self.hf_config = config + self.strategy = MegatronStrategy( + self._model_path, + self.device_mesh, + mixed_precision=mixed_precision, + config=self.hf_config, + ddp_config=ddp_config or {}, + seed=seed, + **kwargs) self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) MegatronPeft().__call__() self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index ea16d073..9341ac0b 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-from typing import List, Literal, Optional, Dict, Any - import torch import torch.nn as nn from transformers import PreTrainedConfig +from typing import Any, Dict, List, Literal, Optional from twinkle import DeviceMesh, Platform, torch_util @@ -30,7 +29,11 @@ def __init__( self.seed = seed self.variable_seq_lengths = variable_seq_lengths self.ddp_config = ddp_config or {} - + if config is None: + from transformers import AutoConfig + self.hf_config = AutoConfig.from_pretrained(self.model_dir, trust_remote_code=True) + else: + self.hf_config = config if 'overlap_grad_reduce' not in self.ddp_config: self.ddp_config['overlap_grad_reduce'] = False if 'overlap_param_gather' not in self.ddp_config: @@ -72,10 +75,7 @@ def __init__( ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed model_parallel_cuda_manual_seed(self.seed) - if config is None: - self.config = self.get_model_config(model_dir, parallel_kwargs, **kwargs) - else: - self.config = config + self.config = self.get_model_config(self.hf_config, parallel_kwargs, **kwargs) @property def sequence_parallel(self) -> bool: @@ -163,18 +163,17 @@ def finish_param_config(self, model: List[nn.Module], optimizer: Any): if ddp_config['overlap_grad_reduce']: assert self.config.no_sync_func is None, ( 'When overlap_grad_reduce is True, config.no_sync_func must be None; ' - 'a custom no_sync_func is not supported when overlapping grad-reduce' - ) - self.config.no_sync_func = [model_chunk.no_sync for model_chunk in model] # noqa + 'a custom no_sync_func is not supported when overlapping grad-reduce') + self.config.no_sync_func = [model_chunk.no_sync for model_chunk in model] # noqa if len(model) == 1: - self.config.no_sync_func = self.config.no_sync_func[0] # noqa - self.config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] # noqa + self.config.no_sync_func = self.config.no_sync_func[0] # noqa + self.config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] # noqa if len(model) == 1: - self.config.grad_sync_func = self.config.grad_sync_func[0] # noqa + self.config.grad_sync_func = self.config.grad_sync_func[0] # noqa if ddp_config['overlap_param_gather'] and ddp_config['align_param_gather']: - self.config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] # noqa + self.config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] # noqa if len(model) == 1: - self.config.param_sync_func = self.config.param_sync_func[0] # noqa + self.config.param_sync_func = self.config.param_sync_func[0] # noqa @staticmethod def _wrap_with_megatron_ddp( @@ -222,14 +221,12 @@ def reduce_loss(self, local_loss, local_count, logits, logps): def get_model_config( self, - model_dir: str, + hf_config: PreTrainedConfig, parallel_kwargs: Dict[str, Any], **kwargs, ): from mcore_bridge import ModelConfig, hf_to_mcore_config from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads - from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) config_kwargs = hf_to_mcore_config(hf_config) config_kwargs.update(kwargs) if 'calculate_per_token_loss' not in config_kwargs: @@ -268,8 +265,8 @@ def create_megatron_model( self, load_weights: bool = True, ) -> List[nn.Module]: - from mcore_bridge import get_mcore_model import torch.distributed as dist + from mcore_bridge import get_mcore_model mg_models = get_mcore_model(self.config) if dist.is_initialized(): @@ -290,4 +287,4 @@ def 
create_megatron_model( def _move_model_to_gpu(model: nn.Module) -> nn.Module: model = model.to(Platform.get_local_device()) torch_util.synchronize() - return model \ No newline at end of file + return model diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 37b647e6..3459c70f 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -6,7 +6,7 @@ from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora import Embedding, Linear, LoraLayer from types import MethodType -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union from twinkle import torch_util from twinkle.data_format import InputFeature diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 576db8cd..0d9f582e 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -408,3 +408,48 @@ def collate_fn(self, output[key] = res[key][i:i + micro_batch_size] outputs.append(output) return outputs + + def postprocess_tensor_cp(self, tensor): + """All-gather and reconstruct full sequence from CP-split tensor. + + Uses load-balanced split pattern: each CP rank holds chunks [rank] and + [2*cp_size - rank - 1] from the original 2*cp_size chunks. + + Only the current rank's slice retains the original tensor (and its + gradient graph); other ranks' slices are plain copies. This means + backward through the reconstructed tensor only produces gradients for + the local chunk, naturally distributing the gradient across CP ranks + without extra scaling. + + Args: + tensor: [batch_size, seq_len/cp_size] CP-split tensor + + Returns: + [batch_size, full_seq_len] reconstructed full tensor + """ + if self.device_mesh.cp_world_size <= 1: + return tensor + + from megatron.core import parallel_state as mpu + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + cp_group = mpu.get_context_parallel_group() + + gathered = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(gathered, tensor.contiguous(), group=cp_group) + gathered[cp_rank] = tensor + + batch_size = tensor.shape[0] + seq_len_per_cp = tensor.shape[1] + full_seq_len = seq_len_per_cp * cp_size + chunk_len = full_seq_len // (2 * cp_size) + half_len = seq_len_per_cp // 2 + + output = tensor.new_zeros(batch_size, full_seq_len) + for j in range(cp_size): + o = gathered[j] + output[:, j * chunk_len:(j + 1) * chunk_len] = o[:, :half_len] + reverse_idx = 2 * cp_size - j - 1 + output[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o[:, half_len:] + + return output diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 982707d9..cca7e63b 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -10,8 +10,8 @@ from .parallel import processing_lock from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax, - stateless_init_process_group, to_device, split_cp_inputs) +from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax, split_cp_inputs, + stateless_init_process_group, to_device) from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import 
copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 07d1730f..27020a23 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,7 +1,7 @@ import socket from datetime import timedelta -from typing import Optional -from typing import TYPE_CHECKING, Any, List, Mapping, Union +from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union + from .network import is_valid_ipv6_address if TYPE_CHECKING: From 0a1c34cdd9217395d5f2cb92585bf843a4972468 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 21:56:47 +0800 Subject: [PATCH 14/18] fix --- src/twinkle/model/multi_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 3459c70f..ff300d3d 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -406,8 +406,8 @@ def _patch_megatron(_module): else: target_modules = list(_config.target_modules) - from .base import TwinkleModel - _config.target_modules = TwinkleModel.get_target_modules(_module, target_modules) + from .megatron import MegatronModel + _config.target_modules = MegatronModel.get_target_modules(_module, target_modules) _module = get_peft_model(_module, _config, lora_tenant.adapter_name) for name, submodule in _module.named_modules(): From ab0b161edf99e1a8dd18b1b7de2a284f119ae6fa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 23:35:58 +0800 Subject: [PATCH 15/18] fix --- src/twinkle/model/megatron/megatron.py | 5 +++-- src/twinkle/model/megatron/multi_lora_megatron.py | 6 ++++++ src/twinkle/model/megatron/strategy/megatron.py | 3 +-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index c96005ee..b6d83597 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -130,6 +130,7 @@ def __init__( config=self.hf_config, ddp_config=ddp_config or {}, seed=seed, + use_distributed_optimizer=self.use_distributed_optimizer, **kwargs) self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) @@ -667,7 +668,7 @@ def _create_megatron_optimizer(self, **kwargs): # Build optimizer config lr = kwargs.pop('lr', 1e-4) - use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', False) + self.use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', self.use_distributed_optimizer) opt_config = OptimizerConfig( optimizer='adam', @@ -679,7 +680,7 @@ def _create_megatron_optimizer(self, **kwargs): adam_eps=kwargs.pop('adam_eps', 1e-8), clip_grad=kwargs.pop('clip_grad', 1.0), bf16=kwargs.pop('bf16', True), - use_distributed_optimizer=use_distributed_optimizer, + use_distributed_optimizer=self.use_distributed_optimizer, overlap_param_gather=kwargs.pop('overlap_param_gather', False), log_num_zeros_in_grad=kwargs.pop('log_num_zeros_in_grad', False), **kwargs, diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index eab301a3..a97a9676 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -79,6 +79,7 @@ def __init__( config=self.hf_config, ddp_config=ddp_config or {}, seed=seed, + use_distributed_optimizer=self.use_distributed_optimizer, **kwargs) self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) 
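# A single-process illustration of the load-balanced context-parallel layout
# that `postprocess_tensor_cp` (moved into processor/base.py in the lint patch
# above) reassembles: the sequence is cut into 2*cp_size chunks and rank j
# holds chunks [j] and [2*cp_size - 1 - j], pairing cheap and expensive
# causal-attention chunks on the same rank. Toy values only; no distributed
# setup is required to check the indexing.
import torch

cp_size = 2
full = torch.arange(8).unsqueeze(0)                  # [1, 8] toy sequence
chunks = full.chunk(2 * cp_size, dim=1)              # 4 chunks of length 2
# Local shard per rank: rank 0 -> chunks 0 and 3, rank 1 -> chunks 1 and 2
local = [torch.cat([chunks[j], chunks[2 * cp_size - 1 - j]], dim=1) for j in range(cp_size)]

out = torch.zeros_like(full)
chunk_len = full.shape[1] // (2 * cp_size)
half_len = local[0].shape[1] // 2
for j in range(cp_size):
    out[:, j * chunk_len:(j + 1) * chunk_len] = local[j][:, :half_len]
    reverse_idx = 2 * cp_size - j - 1
    out[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = local[j][:, half_len:]
assert torch.equal(out, full)                        # reconstruction matches the original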
MegatronPeft().__call__() @@ -254,6 +255,11 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable self._check_adapter_valid(kwargs.get('adapter_name')) super().set_processor(processor_cls, **kwargs) + @remote_function(dispatch='all') + def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs): + kwargs.pop('use_distributed_optimizer', None) + super().set_optimizer(optimizer_cls, **kwargs) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 9341ac0b..b9e66505 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -129,7 +129,6 @@ def _check_device_mesh(self): def wrap_model( self, model: List[nn.Module], - use_distributed_optimizer: bool = True, ) -> List[nn.Module]: if self.device_mesh.world_size <= 1: from megatron.core.distributed import DistributedDataParallelConfig @@ -143,7 +142,7 @@ def wrap_model( return model self._check_device_mesh() - return self._wrap_with_megatron_ddp(model, use_distributed_optimizer, self.ddp_config) + return self._wrap_with_megatron_ddp(model, self.use_distributed_optimizer, self.ddp_config) def unwrap_model(self, model: List[nn.Module]) -> List[nn.Module]: from megatron.core.distributed import DistributedDataParallel as MegatronDDP From d50465f59b0a35e8ed61e80ef8988fc52511bdb4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 23:47:17 +0800 Subject: [PATCH 16/18] fix --- src/twinkle/model/megatron/megatron.py | 2 +- src/twinkle/model/megatron/multi_lora_megatron.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index b6d83597..e1f2c392 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -630,7 +630,7 @@ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], * self._model_wrapped = True # Check if requesting Megatron distributed optimizer - if not optimizer_cls or optimizer_cls in ('MegatronDistributedOptimizer', 'default', 'Adam'): + if not optimizer_cls or optimizer_cls in ('MegatronOptimizer', 'default', 'Adam'): optimizer_config.optimizer = self._create_megatron_optimizer(**kwargs) # noqa else: raise NotImplementedError( diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index a97a9676..49126a62 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -257,7 +257,11 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable @remote_function(dispatch='all') def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs): + # Multi lora cannot config use_distributed_optimizer/loss_scale/mix_precision kwargs.pop('use_distributed_optimizer', None) + kwargs.pop('loss_scale', None) + kwargs['fp16'] = False + kwargs['bf16'] = True super().set_optimizer(optimizer_cls, **kwargs) @remote_function() From 08d2dafe40beae02c662a71091fcc377ae803cde Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 1 Apr 2026 13:45:27 +0800 Subject: [PATCH 17/18] wip --- src/twinkle/checkpoint_engine/manager.py | 25 ++++++++++++++++--- 
src/twinkle/model/megatron/megatron.py | 19 ++++++++------ .../model/megatron/multi_lora_megatron.py | 2 +- .../model/transformers/transformers.py | 1 + .../sampler/vllm_sampler/vllm_engine.py | 11 ++++++++ .../sampler/vllm_sampler/vllm_sampler.py | 4 +++ .../vllm_sampler/vllm_worker_extension.py | 3 +++ 7 files changed, 53 insertions(+), 12 deletions(-) diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index 882a7b5b..bcc16933 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py import time -from typing import Optional +from typing import Optional, List from twinkle import Platform, get_logger from .base import CheckpointEngine @@ -64,6 +64,7 @@ def __init__( # Cached peft_config dict for LoRA-only sync. # Fetched lazily from the model on first LoRA sync. self._peft_config: dict | None = None + self._model_keys: Optional[List[str]] = None @staticmethod def decide_backend_engine(platform: Optional[str] = None) -> 'CheckpointEngine': @@ -117,8 +118,26 @@ def sync_weights(self, merge_and_sync=True): if self._peft_config is None: self._peft_config = self.model.get_peft_config_dict() peft_config = self._peft_config - - model_result = self.model.send_weights(base_sync_done=self.base_sync_done, merge_and_sync=merge_and_sync) + + if self._model_keys is None: + if hasattr(self.sampler, 'get_state_keys'): + self._model_keys = self.sampler.get_state_keys() + + if self._model_keys is None: + self._model_keys = [] + + # vLLM may have grouped params + _STACKED_MAPPINGS = { + 'qkv_proj': ('q_proj', 'k_proj', 'v_proj'), + 'gate_up_proj': ('gate_proj', 'up_proj'), + } + for key in self._model_keys: + for merged, individuals in _STACKED_MAPPINGS.items(): + if merged in key: + for ind in individuals: + self._model_keys.append(key.replace(merged, ind)) + + model_result = self.model.send_weights(base_sync_done=self.base_sync_done, merge_and_sync=merge_and_sync, model_keys=self._model_keys) sampler_result = self.sampler.receive_weights(base_sync_done=self.base_sync_done, peft_config=peft_config) model_result() sampler_result() diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index e1f2c392..0e980c7a 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -858,7 +858,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): bridge.load_weights( _model, checkpoint_dir, - is_peft_format=(adapter_name != _default_adapter_name), + peft_format=(adapter_name != _default_adapter_name), ) if dist.is_initialized(): @@ -1163,7 +1163,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non model = self.strategy.unwrap_model(self.model) self.strategy.bridge.save_weights( - model, output_dir, is_peft_format=is_peft_format, adapter_name=adapter_name, lora_converter=lora_converter) + model, output_dir, peft_format=is_peft_format, lora_converter=lora_converter) # Save config on rank 0 only if dp_rank == 0: @@ -1384,6 +1384,7 @@ def send_weights( base_sync_done: bool = False, merge_and_sync: bool = False, add_base_layer_path: bool = True, + model_keys: List[str] = None, ): if adapter_name is None: adapter_name = self._get_default_group() @@ -1400,13 +1401,15 @@ def merge_lora(): _model.unmerge_adapter() def 
_add_base_layer_suffix(params): - _BASE_LAYER_SUFFIXES = ['weight', 'bias'] for name, param in params: - for suffix in _BASE_LAYER_SUFFIXES: - if name.endswith(suffix): - attr = suffix.rsplit('.', 1)[-1] # 'weight' or 'bias' - name = f'{name[:-len(attr)]}base_layer.{attr}' - break + if name.endswith('.weight'): + base_layer_name = f'{name[:-7]}.base_layer.weight' + if base_layer_name in model_keys or not model_keys: + name = base_layer_name + elif name.endswith('.bias'): + base_layer_name = f'{name[:-5]}.base_layer.bias' + if base_layer_name in model_keys or not model_keys: + name = base_layer_name yield name, param is_peft_format = (adapter_name != _default_adapter_name) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 49126a62..35c98bfe 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -207,7 +207,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): bridge.load_weights( _model, checkpoint_dir, - True, + peft_format=True, adapter_name=adapter_name, lora_converter=self.multi_adapter.load_lora_converter) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6097ffe2..88e08a2e 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1151,6 +1151,7 @@ def send_weights( adapter_name: str = None, base_sync_done: bool = False, merge_and_sync: bool = False, + **kwargs, ): if adapter_name is None: adapter_name = self._get_default_group() diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 911b11c3..b5e20430 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -449,6 +449,13 @@ async def wake_up(self, tags: Optional[List[str]] = None) -> None: async def reset_prefix_cache(self) -> None: await self.engine.reset_prefix_cache() + + async def get_state_keys(self) -> List[str]: + results = await self.engine.collective_rpc('get_state_keys') + all_keys = set() + for r in results: + all_keys.update(r) + return list(all_keys) async def update_weights( self, @@ -488,6 +495,7 @@ async def update_weights( async def _dict_iter(): for item in weights.items(): + breakpoint() yield item weight_aiter = _dict_iter() @@ -497,6 +505,7 @@ async def _dict_iter(): # sync generator / iterable async def _sync_iter(): for item in weights: + breakpoint() yield item weight_aiter = _sync_iter() @@ -575,6 +584,8 @@ async def _chain_first(): """Re-inject the peeked first tensor, then yield the rest.""" yield first_name, first_tensor async for item in weight_aiter: + if 'qkv_proj' in item[0]: + breakpoint() yield item offset = 0 diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 62f15630..960a521b 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -373,6 +373,10 @@ def wake_up(self, tags: List[str] = None) -> None: @remote_function(dispatch='all', collect='first') def reset_prefix_cache(self): self._run_in_loop(self.engine.reset_prefix_cache()) + + @remote_function(dispatch='all', collect='first') + def get_state_keys(self): + return self._run_in_loop(self.engine.get_state_keys()) @remote_function(dispatch='all', lazy_collect=True) def receive_weights( diff --git 
a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index 61920cd9..c18b2293 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -427,6 +427,9 @@ def _load_weights( self.model_runner.model.load_weights(converted) logger.info(f'Loaded {len(converted)} base weights') + + def get_state_keys(self): + return list(self.model_runner.model.state_dict().keys()) def _get_zmq_handle(self) -> str: """Get ZMQ handle for IPC communication.""" From fa6b46359e1e7774bdf4af77634b8588924b777a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 1 Apr 2026 14:47:13 +0800 Subject: [PATCH 18/18] wip --- src/twinkle/checkpoint_engine/manager.py | 2 +- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index bcc16933..93a74b69 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -128,7 +128,7 @@ def sync_weights(self, merge_and_sync=True): # vLLM may have grouped params _STACKED_MAPPINGS = { - 'qkv_proj': ('q_proj', 'k_proj', 'v_proj'), + 'qkv_proj': ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v'), 'gate_up_proj': ('gate_proj', 'up_proj'), } for key in self._model_keys: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index b5e20430..eb7fa0a1 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -495,7 +495,6 @@ async def update_weights( async def _dict_iter(): for item in weights.items(): - breakpoint() yield item weight_aiter = _dict_iter() @@ -505,7 +504,6 @@ async def _dict_iter(): # sync generator / iterable async def _sync_iter(): for item in weights: - breakpoint() yield item weight_aiter = _sync_iter() @@ -584,8 +582,6 @@ async def _chain_first(): """Re-inject the peeked first tensor, then yield the rest.""" yield first_name, first_tensor async for item in weight_aiter: - if 'qkv_proj' in item[0]: - breakpoint() yield item offset = 0
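
Notes on the series, with illustrative sketches (Python; every sketch below is a reconstruction under stated assumptions, not code from the tree).

PATCH 15/16 turn use_distributed_optimizer into a single piece of model-owned state instead of two independent knobs: __init__ forwards it into the strategy, _create_megatron_optimizer writes any late caller override back onto self.use_distributed_optimizer, and wrap_model loses its parameter and reads the attribute, so the Megatron DDP wrapping and the OptimizerConfig can no longer disagree. The multi-LoRA subclass additionally strips use_distributed_optimizer and loss_scale from set_optimizer kwargs and pins bf16, because per-adapter optimizers must share the one wrapping chosen at construction. The ownership pattern in miniature (FlagOwner is a toy stand-in, not a class from the tree):

    class FlagOwner:
        def __init__(self, use_distributed_optimizer: bool = True):
            # Single source of truth for the flag.
            self.use_distributed_optimizer = use_distributed_optimizer

        def create_optimizer(self, **kwargs) -> dict:
            # A caller override is written back, so both readers stay in sync.
            self.use_distributed_optimizer = kwargs.pop(
                'use_distributed_optimizer', self.use_distributed_optimizer)
            return {'use_distributed_optimizer': self.use_distributed_optimizer}

        def wrap_model(self) -> dict:
            # No flag parameter any more; the stored value decides the DDP config.
            return {'use_distributed_optimizer': self.use_distributed_optimizer}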
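PATCH 17/18 widen the trainer-side key list because vLLM fuses the attention and MLP projections into single parameters (q_proj/k_proj/v_proj into qkv_proj, gate_proj/up_proj into gate_up_proj), so a name from the sampler's state dict only matches trainer-side names after the fused token is expanded back out. A minimal standalone sketch of that expansion (expand_stacked_keys is a hypothetical helper; the in-tree code appends to self._model_keys in place rather than returning a new list):

    from typing import List

    _STACKED_MAPPINGS = {
        'qkv_proj': ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v'),
        'gate_up_proj': ('gate_proj', 'up_proj'),
    }

    def expand_stacked_keys(model_keys: List[str]) -> List[str]:
        # Iterate over a snapshot so the loop never walks the growing
        # result; the in-place version in the patch terminates anyway,
        # because every replacement removes the fused token.
        expanded = list(model_keys)
        for key in model_keys:
            for merged, individuals in _STACKED_MAPPINGS.items():
                if merged in key:
                    expanded.extend(key.replace(merged, ind) for ind in individuals)
        return expanded

For example, expand_stacked_keys(['model.layers.0.self_attn.qkv_proj.weight']) keeps the fused key and adds six per-projection variants; the bare 'q'/'k'/'v' entries arrived in PATCH 18 for models that name their projections without a '_proj' suffix.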
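The key list itself is fetched lazily in sync_weights (checkpoint_engine/manager.py): the first sync asks the sampler for its state-dict key inventory (guarded by hasattr, since only the vLLM sampler implements get_state_keys), falls back to an empty list, expands the fused names, and caches the result for every later sync. Condensed into a hypothetical helper that reuses the sketch above (_ensure_model_keys does not exist in the tree; sync_weights inlines this logic):

    def _ensure_model_keys(self) -> None:
        # One collective round-trip on the first sync; cached afterwards.
        if self._model_keys is not None:
            return
        keys: List[str] = []
        if hasattr(self.sampler, 'get_state_keys'):
            keys = self.sampler.get_state_keys() or []
        self._model_keys = expand_stacked_keys(keys)

Underneath, the inventory is assembled by a collective RPC: each vLLM worker returns list(self.model_runner.model.state_dict().keys()), and the engine unions the per-worker lists into one flat list, so tensor-parallel ranks that hold shards of the same parameter contribute its name once.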
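Finally, PATCH 17 rewrites _add_base_layer_suffix in megatron.py so the '.base_layer.' rename is applied only when the receiving side actually expects it: a sampler that loaded a PEFT-wrapped model names parameters 'foo.base_layer.weight', while one that loaded a vanilla model keeps 'foo.weight', and the expanded key list from above is what decides between the two. A standalone sketch of the generator (in the tree it is a closure inside send_weights; the set() conversion and the explicit None guard are additions here):

    from typing import Iterable, Iterator, List, Optional, Tuple

    import torch

    def add_base_layer_suffix(
        params: Iterable[Tuple[str, torch.Tensor]],
        model_keys: Optional[List[str]],
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        keys = set(model_keys or [])
        for name, param in params:
            for suffix in ('.weight', '.bias'):
                if name.endswith(suffix):
                    candidate = f'{name[:-len(suffix)]}.base_layer{suffix}'
                    # Rename when the sampler expects the base_layer form, or
                    # unconditionally when no key list is available (the
                    # pre-PATCH-17 behaviour).
                    if candidate in keys or not keys:
                        name = candidate
                    break
            yield name, param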