Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from transformer_engine.pytorch.quantized_tensor import (
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)

_current_file = pathlib.Path(__file__).resolve()
Expand Down Expand Up @@ -2701,10 +2701,7 @@ def forward(
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors
)
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx)

proj_dgrad = ctx.dO_quantizer(grad_output)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from transformer_engine.pytorch.quantized_tensor import Quantizer
from transformer_engine.pytorch.quantized_tensor import prepare_for_saving
from transformer_engine.pytorch.quantized_tensor import restore_from_saved
from transformer_engine.pytorch.quantized_tensor import restore_from_func_ctx
from transformer_engine.pytorch.tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor import MXFP8Quantizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from transformer_engine.pytorch.quantized_tensor import (
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.constants import (
Expand Down Expand Up @@ -1477,7 +1477,7 @@ def backward(ctx, d_out, *_args):
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*other_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
) = restore_from_func_ctx(ctx)

aux_ctx_tensors = other_tensors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)

# Import attention utils
Expand Down Expand Up @@ -2085,7 +2085,7 @@ def backward(ctx, dout, *_args):
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*other_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
) = restore_from_func_ctx(ctx)
cu_seqlens_q_per_step = other_tensors[:cp_size]
cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
rng_states = other_tensors[cp_size * 2 : cp_size * 3]
Expand Down Expand Up @@ -3675,7 +3675,7 @@ def backward(ctx, dout, *_args):
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
*aux_ctx_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
) = restore_from_func_ctx(ctx)

qkv_format = ctx.qkv_format
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState
Expand Down Expand Up @@ -316,7 +316,7 @@ def forward(
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
saved_tensors = restore_from_func_ctx(ctx)
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
Expand Down
9 changes: 2 additions & 7 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.mxfp8_tensor import MXFP8Quantizer
Expand Down Expand Up @@ -546,7 +546,6 @@ def backward(
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

with get_nvtx_range_context("_LayerNormLinear_backward"):
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
weight,
Expand All @@ -556,11 +555,7 @@ def backward(
ln_out,
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)

# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
) = restore_from_func_ctx(ctx)

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
8 changes: 2 additions & 6 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from ..cpp_extensions import (
general_gemm,
Expand Down Expand Up @@ -898,11 +898,7 @@ def forward(
def _recompute(ctx):
# pylint: disable=missing-function-docstring

saved_tensors = ctx.saved_tensors
tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
tensors = restore_from_func_ctx(ctx)

if ctx.checkpoint: # do recomputation from the original args

Expand Down
9 changes: 2 additions & 7 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
Expand Down Expand Up @@ -501,15 +501,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

with get_nvtx_range_context("_Linear_backward"):
saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
restore_from_func_ctx(ctx)
)

# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
ctx.main_grad_func()
Expand Down
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/ops/fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch

from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from ..quantized_tensor import prepare_for_saving, restore_from_saved
from ..quantized_tensor import prepare_for_saving, restore_from_func_ctx
from .op import (
BasicOperation,
FusibleOperation,
Expand Down Expand Up @@ -212,8 +212,7 @@ def backward(
basic_op_ctxs = func_ctx.basic_op_ctxs

# Restore saved tensors
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
func_ctx.tensor_objects = None
saved_tensors = restore_from_func_ctx(func_ctx)

# Unflatten list of saved tensors
for ctx in basic_op_ctxs:
Expand Down
22 changes: 21 additions & 1 deletion transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def restore_from_saved(
list[Optional[torch.Tensor]],
]
):
"""Recombine the tensor data and metadata during backward pass."""
"""Recombine the tensor data and metadata during backward pass.
Note: please use `restore_from_func_ctx` instead if you are restoring tensors from a function context to make sure tensor_objects is detached and its memory can be freed
"""
tensor_objects = []
for tensor in tensors:
if tensor is None or isinstance(tensor, torch.Tensor):
Expand All @@ -174,6 +176,24 @@ def restore_from_saved(
return tensor_objects


def restore_from_func_ctx(
    ctx: torch.autograd.function.FunctionCtx, return_saved_tensors: bool = False
) -> (
    list[Optional[torch.Tensor | QuantizedTensorStorage]]
    | tuple[
        list[Optional[torch.Tensor | QuantizedTensorStorage]],
        list[Optional[torch.Tensor]],
    ]
):
    """Recombine tensor data and metadata during the backward pass.

    Reconstructs the tensors that were split apart for autograd saving by
    passing ``ctx.tensor_objects`` and ``ctx.saved_tensors`` to
    ``restore_from_saved``, then clears ``ctx.tensor_objects`` so the
    consumed references can be freed.

    Parameters
    ----------
    ctx: torch.autograd.function.FunctionCtx
        Autograd function context holding ``tensor_objects`` (attached during
        the forward pass) and ``saved_tensors``.
    return_saved_tensors: bool, default = False
        If ``True``, also return the raw ``ctx.saved_tensors`` alongside the
        restored tensor objects, mirroring ``restore_from_saved``.

    Raises
    ------
    AttributeError
        If ``ctx`` has no ``tensor_objects`` attribute or it is ``None``
        (e.g. the tensors were already restored once via this function).
    """
    if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
        raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
    out = restore_from_saved(
        ctx.tensor_objects, ctx.saved_tensors, return_saved_tensors=return_saved_tensors
    )
    # Drop the references now that `restore_from_saved` has reconstructed the
    # actual tensors, so the underlying memory can be freed promptly.
    ctx.tensor_objects = None
    return out


class Quantizer(abc.ABC):
"""Builder class for quantized tensors.

Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Quantizer,
prepare_for_saving,
restore_from_saved,
restore_from_func_ctx,
)
from .storage.float8_tensor_storage import Float8TensorStorage
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
Expand Down Expand Up @@ -46,6 +47,7 @@
"GroupedTensor",
"prepare_for_saving",
"restore_from_saved",
"restore_from_func_ctx",
]


Expand Down
Loading