From a0d4199670270e3ee5b74dc36f20116a6cf2ff76 Mon Sep 17 00:00:00 2001 From: jianbinc Date: Sun, 15 Mar 2026 04:10:55 -0700 Subject: [PATCH 1/2] Fix the retrieval of overwrite_main_grad: fetch it from the original weight instead of the fp8 (fp4) weight, since the fp8 (fp4) weight may not inherit the required attributes during creation. Signed-off-by: jianbinc --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fade2957d5..828bec5bd5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -476,7 +476,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) + if not getattr(origin_weights[0], "overwrite_main_grad", False) else False ), ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d775dc3e8e..245dfa9928 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -864,7 +864,7 @@ def backward( "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(weight, "overwrite_main_grad", False) + if not getattr(origin_weight, "overwrite_main_grad", False) else False ), "layout": "NT", diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 037fb6c858..4862b6c621 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1224,7 +1224,7 @@ def backward( "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(fc1_weight, "overwrite_main_grad", False) + if not getattr(origin_fc1_weight, "overwrite_main_grad", False) else False ), "layout": "NT", @@ -1471,7 +1471,7 @@ def fc2_wgrad_gemm( "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(fc2_weight, "overwrite_main_grad", False) + if not getattr(origin_fc2_weight, "overwrite_main_grad", False) else False ), "layout": "NT", From d8285e11e09713f55f97e51f06f58f07edacd491 Mon Sep 17 00:00:00 2001 From: jianbinc Date: Mon, 16 Mar 2026 10:17:16 -0700 Subject: [PATCH 2/2] Correct weight names in layernorm_mlp.py: read overwrite_main_grad from origin_fc2_weight for the FC2 wgrad GEMM and from origin_fc1_weight for the FC1 wgrad GEMM Signed-off-by: jianbinc --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4862b6c621..1d44d48512 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1224,7 +1224,7 @@ def backward( "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(origin_fc1_weight, "overwrite_main_grad", False) + if not getattr(origin_fc2_weight, "overwrite_main_grad", False) else False ), "layout": "NT", @@ -1471,7 +1471,7 @@ def fc2_wgrad_gemm( "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(origin_fc2_weight, "overwrite_main_grad", False) + if not getattr(origin_fc1_weight, "overwrite_main_grad", False) else False ), "layout": "NT",