diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index fade2957d5..828bec5bd5 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -476,7 +476,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 use_split_accumulator=wgrad_gemm_use_split_accumulator,
                 accumulate=(
                     accumulate_wgrad_into_param_main_grad
-                    if not getattr(weights[0], "overwrite_main_grad", False)
+                    if not getattr(origin_weights[0], "overwrite_main_grad", False)
                     else False
                 ),
             )
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index d775dc3e8e..245dfa9928 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -864,7 +864,7 @@ def backward(
             "quantization_params": ctx.grad_weight_quantizer,
             "accumulate": (
                 accumulate_wgrad_into_param_main_grad
-                if not getattr(weight, "overwrite_main_grad", False)
+                if not getattr(origin_weight, "overwrite_main_grad", False)
                 else False
             ),
             "layout": "NT",
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 037fb6c858..1d44d48512 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1224,7 +1224,7 @@ def backward(
             "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
             "accumulate": (
                 accumulate_wgrad_into_param_main_grad
-                if not getattr(fc1_weight, "overwrite_main_grad", False)
+                if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
                 else False
             ),
             "layout": "NT",
@@ -1471,7 +1471,7 @@ def fc2_wgrad_gemm(
             "quantization_params": ctx.fc1_grad_weight_quantizer,
             "accumulate": (
                 accumulate_wgrad_into_param_main_grad
-                if not getattr(fc2_weight, "overwrite_main_grad", False)
+                if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
                 else False
             ),
             "layout": "NT",