diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 60ade522e3..2eb307aa48 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -49,7 +49,7 @@
 from transformer_engine.pytorch.quantized_tensor import (
     Quantizer,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 
 _current_file = pathlib.Path(__file__).resolve()
@@ -2701,10 +2701,7 @@ def forward(
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         with torch.cuda.nvtx.range("_DPA"):
-            saved_tensors = ctx.saved_tensors
-            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
-                ctx.tensor_objects, saved_tensors
-            )
+            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx)
 
             proj_dgrad = ctx.dO_quantizer(grad_output)
             fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index cd18ca75ad..bbc1d7fab6 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -68,6 +68,7 @@
 from transformer_engine.pytorch.quantized_tensor import Quantizer
 from transformer_engine.pytorch.quantized_tensor import prepare_for_saving
 from transformer_engine.pytorch.quantized_tensor import restore_from_saved
+from transformer_engine.pytorch.quantized_tensor import restore_from_func_ctx
 from transformer_engine.pytorch.tensor import Float8Quantizer
 from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
 from transformer_engine.pytorch.tensor import MXFP8Quantizer
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index a6a8b0b26a..442366035a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -32,7 +32,7 @@
 from transformer_engine.pytorch.quantized_tensor import (
     QuantizedTensorStorage,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.constants import (
@@ -1477,7 +1477,7 @@ def backward(ctx, d_out, *_args):
             cu_seqlens_q_padded,
             cu_seqlens_kv_padded,
             *other_tensors,
-        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
+        ) = restore_from_func_ctx(ctx)
 
         aux_ctx_tensors = other_tensors
 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 10ba99595b..7d9eb0cb05 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -38,7 +38,7 @@
 
 from transformer_engine.pytorch.quantized_tensor import (
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 
 # Import attention utils
@@ -2085,7 +2085,7 @@ def backward(ctx, dout, *_args):
             cu_seqlens_q_padded,
             cu_seqlens_kv_padded,
             *other_tensors,
-        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
+        ) = restore_from_func_ctx(ctx)
         cu_seqlens_q_per_step = other_tensors[:cp_size]
         cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
         rng_states = other_tensors[cp_size * 2 : cp_size * 3]
@@ -3675,7 +3675,7 @@ def backward(ctx, dout, *_args):
             cu_seqlens_q_padded,
             cu_seqlens_kv_padded,
             *aux_ctx_tensors,
-        ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
+        ) = restore_from_func_ctx(ctx)
 
         qkv_format = ctx.qkv_format
         qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 30c1dbf408..0adda48e36 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -49,7 +49,7 @@
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 from ...debug.pytorch.debug_state import TEDebugState
@@ -316,7 +316,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
         with get_nvtx_range_context("_GroupedLinear_backward"):
-            saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
+            saved_tensors = restore_from_func_ctx(ctx)
             N = ctx.num_gemms
             inputmats = saved_tensors[:N]
             weights = saved_tensors[N : 2 * N]
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index d775dc3e8e..ed91bc1235 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -60,7 +60,7 @@
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 from ...debug.pytorch.debug_state import TEDebugState
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
@@ -546,7 +546,6 @@ def backward(
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
 
         with get_nvtx_range_context("_LayerNormLinear_backward"):
-            saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
                 weight,
@@ -556,11 +555,7 @@
                 ln_out,
                 mu,
                 rsigma,
-            ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
-
-            # Delete the references to tensor objects once they've been consumed
-            # by the `restore_from_saved` method to construct back the actual tensors.
-            ctx.tensor_objects = None
+            ) = restore_from_func_ctx(ctx)
 
             # Since main_grad can be modified inplace, it should not be a part of saved_tensors
             main_grad = (
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 037fb6c858..cc3dcc4064 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -80,7 +80,7 @@
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 from ..cpp_extensions import (
     general_gemm,
@@ -898,11 +898,7 @@ def forward(
 
 def _recompute(ctx):
     # pylint: disable=missing-function-docstring
-    saved_tensors = ctx.saved_tensors
-    tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
-    # Delete the references to tensor objects once they've been consumed
-    # by the `restore_from_saved` method to construct back the actual tensors.
-    ctx.tensor_objects = None
+    tensors = restore_from_func_ctx(ctx)
 
     if ctx.checkpoint:
         # do recomputation from the original args
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 1e3eadc405..ea921341a4 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -61,7 +61,7 @@
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
-    restore_from_saved,
+    restore_from_func_ctx,
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
@@ -501,15 +501,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
 
         with get_nvtx_range_context("_Linear_backward"):
-            saved_tensors = ctx.saved_tensors
             inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
-                restore_from_saved(ctx.tensor_objects, saved_tensors)
+                restore_from_func_ctx(ctx)
             )
-
-            # Delete the references to tensor objects once they've been consumed
-            # by the `restore_from_saved` method to construct back the actual tensors.
-            ctx.tensor_objects = None
 
             # Since main_grad can be modified inplace, it should not be a part of saved_tensors
             main_grad = (
                 ctx.main_grad_func()
diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index 80386db2d9..76606ec799 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -12,7 +12,7 @@
 import torch
 
 from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
-from ..quantized_tensor import prepare_for_saving, restore_from_saved
+from ..quantized_tensor import prepare_for_saving, restore_from_func_ctx
 from .op import (
     BasicOperation,
     FusibleOperation,
@@ -212,8 +212,7 @@ def backward(
         basic_op_ctxs = func_ctx.basic_op_ctxs
 
         # Restore saved tensors
-        saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
-        func_ctx.tensor_objects = None
+        saved_tensors = restore_from_func_ctx(func_ctx)
 
         # Unflatten list of saved tensors
         for ctx in basic_op_ctxs:
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index 807671e863..50f1e92422 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -159,7 +159,11 @@ def restore_from_saved(
         list[Optional[torch.Tensor]],
     ]
 ):
-    """Recombine the tensor data and metadata during backward pass."""
+    """Recombine the tensor data and metadata during backward pass.
+
+    Note: prefer `restore_from_func_ctx` when restoring tensors from a function
+    context; it detaches `ctx.tensor_objects` so that its memory can be freed.
+    """
     tensor_objects = []
     for tensor in tensors:
         if tensor is None or isinstance(tensor, torch.Tensor):
@@ -174,6 +178,28 @@
     return tensor_objects
 
 
+def restore_from_func_ctx(
+    ctx: torch.autograd.function.FunctionCtx, return_saved_tensors: bool = False
+) -> (
+    list[Optional[torch.Tensor | QuantizedTensorStorage]]
+    | tuple[
+        list[Optional[torch.Tensor | QuantizedTensorStorage]],
+        list[Optional[torch.Tensor]],
+    ]
+):
+    """Recombine the tensor data and metadata during backward pass and delete
+    the tensor objects attached to the function context."""
+    if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
+        raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
+    out = restore_from_saved(
+        ctx.tensor_objects, ctx.saved_tensors, return_saved_tensors=return_saved_tensors
+    )
+    # Delete the references to tensor objects once they've been consumed by the
+    # `restore_from_saved` method to construct back the actual tensors.
+    ctx.tensor_objects = None
+    return out
+
+
 class Quantizer(abc.ABC):
     """Builder class for quantized tensors.
 
diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py
index 5668056700..426c656d47 100644
--- a/transformer_engine/pytorch/tensor/__init__.py
+++ b/transformer_engine/pytorch/tensor/__init__.py
@@ -12,6 +12,7 @@
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
+    restore_from_func_ctx,
 )
 from .storage.float8_tensor_storage import Float8TensorStorage
 from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
@@ -46,6 +47,7 @@
     "GroupedTensor",
     "prepare_for_saving",
     "restore_from_saved",
+    "restore_from_func_ctx",
 ]
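For orientation, here is a minimal sketch of the save/restore pattern this patch standardizes. The `_ScaledMatmul` op, its tensors, and the gradient math are hypothetical illustrations, not code from the repository; `prepare_for_saving` and `restore_from_func_ctx` are the Transformer Engine helpers touched above, used under the assumption (consistent with the call sites in the diff) that `prepare_for_saving` returns a `(tensors_to_save, tensor_objects)` pair.

```python
import torch
from transformer_engine.pytorch import prepare_for_saving, restore_from_func_ctx


class _ScaledMatmul(torch.autograd.Function):  # hypothetical example op
    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        out = inp @ weight.t()
        # Split (possibly quantized) tensors into plain torch.Tensor data,
        # saved through the autograd machinery, and metadata objects kept on ctx.
        tensors_to_save, tensor_objects = prepare_for_saving(inp, weight)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Recombines ctx.saved_tensors with ctx.tensor_objects, then sets
        # ctx.tensor_objects = None so the references can be freed promptly.
        inp, weight = restore_from_func_ctx(ctx)
        return grad_output @ weight, grad_output.t() @ inp
```

The point of the new helper is that the detach-after-restore step, previously copy-pasted as `ctx.tensor_objects = None` at every call site (and missing from some of them), can no longer be forgotten.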