
[Question] kimi k2.5 sft int4 training error #1783

@yuexiuyawh

Description


Your Question

I'm attempting SFT training. I used the Megatron-Bridge code from https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2743 to load the HF-format Kimi K2.5 checkpoint, but training fails in the fused_attn part. What could be causing this?

The training parameters are as follows:
python3 train_async.py \
  --actor-num-nodes 32 --actor-num-gpus-per-node 8 \
  --hf-checkpoint /mnt/kimi-k2.5 --load /mnt/kimi-k2.5 \
  --save /mnt/260322-145650 --save-interval 1000 \
  --prompt-data /mnt/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT/qwen3_235b_2507_distill_110k.jsonl \
  --input-key messages --rollout-batch-size 64 --n-samples-per-prompt 1 --global-batch-size 64 \
  --tensor-model-parallel-size 8 --pipeline-model-parallel-size 8 --context-parallel-size 4 \
  --expert-model-parallel-size 32 --expert-tensor-parallel-size 1 \
  --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 \
  --max-tokens-per-gpu 16384 \
  --optimizer adam --lr 1e-5 --lr-decay-style constant --min-lr 1e-6 --lr-warmup-fraction 0.1 \
  --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 \
  --attention-dropout 0.0 --hidden-dropout 0.0 --attention-backend flash \
  --decoder-last-pipeline-num-layers 5 \
  --rollout-function-path slime.rollout.sft_rollout.generate_rollout \
  --num-epoch 1 --loss-type sft_loss --rollout-shuffle --sequence-parallel \
  --use-dynamic-batch-size --optimizer-cpu-offload --overlap-cpu-optimizer-d2h-h2d \
  --use-precision-aware-optimizer --accumulate-allreduce-grads-in-fp32 \
  --attention-softmax-in-fp32 --calculate-per-token-loss \
  --disable-compute-advantages-and-returns --debug-train-only \
  --freeze-vision-model --freeze-vision-projection \
  --megatron-to-hf-mode bridge --micro-batch-size 1 --disable-bias-linear \
  --num-layers 61 --hidden-size 7168 --ffn-hidden-size 18432 \
  --num-attention-heads 64 --kv-channels 64 --normalization RMSNorm \
  --position-embedding-type rope --norm-epsilon 1e-5 --swiglu \
  --untie-embeddings-and-output-weights --vocab-size 163840 \
  --multi-latent-attention --q-lora-rank 1536 --kv-lora-rank 512 \
  --qk-head-dim 128 --qk-pos-emb-head-dim 64 --v-head-dim 128 --qk-layernorm \
  --rotary-scaling-factor 64.0 --rotary-base 50000 --mscale 1.0 --mscale-all-dim 1.0 \
  --attention-softmax-in-fp32 --no-rope-fusion \
  --num-experts 384 \
  --moe-layer-freq [0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1] \
  --moe-ffn-hidden-size 2048 --moe-router-topk 8 --moe-shared-expert-intermediate-size 2048 \
  --moe-router-pre-softmax --moe-router-score-function sigmoid --moe-router-enable-expert-bias \
  --moe-router-load-balancing-type seq_aux_loss --moe-token-dispatcher-type alltoall \
  --moe-aux-loss-coeff 0 --moe-router-bias-update-rate 0 \
  --moe-router-group-topk 1 --moe-router-num-groups 1 --moe-grouped-gemm \
  --moe-router-topk-scaling-factor 2.827 --moe-router-dtype fp32 --moe-permute-fusion

The error log is as follows:
Traceback (most recent call last):
File "/tmp/ray/session_2026-03-22_14-56-34_952654_8784/runtime_resources/working_dir_files/_ray_pkg_32318a9a3e40baf3/src/slime-0.2.3/train_async.py", line 80, in
train(args)
File "/tmp/ray/session_2026-03-22_14-56-34_952654_8784/runtime_resources/working_dir_files/_ray_pkg_32318a9a3e40baf3/src/slime-0.2.3/train_async.py", line 48, in train
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
File "/usr/local/lib/python3.12/dist-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 2981, in get
values, debugger_breakpoint = worker.get_objects(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 1012, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AcceleratorError): ray::MegatronTrainRayActor.train() (pid=15345, ip=10.217.25.171, actor_id=1da0d4093917dd238330b9a602000000, repr=<slime.backends.megatron_utils.actor.MegatronTrainRayActor object at 0x7f338a54f440>)
File "/root/Megatron-LM/megatron/core/transformer/transformer_block.py", line 735, in forward
hidden_states, context = layer(
^^^^^^
File "/root/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 1044, in call
return super().call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/Megatron-LM/megatron/core/transformer/module.py", line 319, in call
return super().call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 475, in forward
hidden_states, context = self._forward_attention(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 549, in _forward_attention
attention_output_with_bias = self.self_attention(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/Megatron-LM/megatron/core/transformer/multi_latent_attention.py", line 301, in forward
core_attn_out = self.core_attention(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 1111, in forward
core_attn_out = super().forward(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py", line 196, in nonrecursive_disable_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py", line 1545, in forward
attn_out = self.fused_attention(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 1916, in forward
output = FusedAttnFunc.apply(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 581, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 1267, in forward
out, aux_ctx_tensors, *max_logit = fused_attn_fwd(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/fused_attn.py", line 297, in fused_attn_fwd
output_tensors = tex.fused_attn_fwd(
^^^^^^^^^^^^^^^^^^^
RuntimeError: /TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:534 in function fused_attn_arbitrary_seqlen_fwd_impl: cuDNN Error: execute(handle, plan->get_raw_desc(), variant_pack_descriptor.get_ptr()) failed with message: err 700 != CUDA_SUCCESS :: shimCuLaunchKernelEx(&config, func, kernelParams, nullptr)
at: err != CUDA_SUCCESS, and code: CUDNN_STATUS_EXECUTION_FAILED_CUDA_DRIVER. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
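
The error message itself suggests enabling cuDNN error logging. Since the training actors run as Ray workers, those variables have to reach every rank; below is a minimal sketch of one way to do that, assuming the job is launched through `ray.init` (the `runtime_env` wiring here is an assumption, not part of slime's documented setup):

```python
import ray

# Debug-only environment for the training actors:
# - CUDNN_LOGERR_DBG / CUDNN_LOGDEST_DBG: cuDNN error logging, as suggested by the error message.
# - CUDA_LAUNCH_BLOCKING=1: make kernel launches synchronous so the failing kernel is reported
#   at its real call site instead of a later cuDNN execute() call (err 700 is an asynchronous
#   "illegal memory access" CUDA error).
debug_env = {
    "CUDNN_LOGERR_DBG": "1",
    "CUDNN_LOGDEST_DBG": "stderr",
    "CUDA_LAUNCH_BLOCKING": "1",
}

# Assumption: env_vars in runtime_env are propagated to every Ray worker process,
# including the MegatronTrainRayActor ranks that hit the fused-attention kernel.
ray.init(runtime_env={"env_vars": debug_env})
```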

What I've Tried

I have tried both flash attention and fused attention; both fail during the attention computation.
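
For reference, Transformer Engine also allows the attention backend to be forced through environment variables, which can help confirm whether the crash is specific to the cuDNN fused-attention path. A minimal sketch, assuming the variables are set before Transformer Engine selects its backend (for example at the top of the training entry point, or via the Ray runtime_env shown above):

```python
import os

# Force Transformer Engine away from the cuDNN fused-attention backend that is crashing here.
# These should be set before the first attention forward pass so the backend choice sees them.
os.environ["NVTE_FUSED_ATTN"] = "0"    # disable the cuDNN fused-attention backend
os.environ["NVTE_FLASH_ATTN"] = "1"    # allow the FlashAttention backend
# os.environ["NVTE_UNFUSED_ATTN"] = "1"  # last-resort unfused PyTorch implementation

import transformer_engine.pytorch  # noqa: E402  (imported after the environment is set)
```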

Environment (if relevant)

  • slime version: 0.2.3
  • Python version: 3.12
  • PyTorch version:
  • CUDA/ROCm version:
  • GPU type and count: H100
  • OS:

Additional Context

No response
