diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py new file mode 100644 index 0000000000..07d27f4010 --- /dev/null +++ b/tests/pytorch/test_torch_compile.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.common import recipe +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear + + +# --------------------------------------------------------------------------- +# ToyLinear – minimal TE module backed by BasicLinear functional ops +# --------------------------------------------------------------------------- + +# Global list of ToyLinear instances. Each module registers itself here on +# construction; the custom op identifies which module to use via an integer +# index so that no Python object ever enters the compiled graph. +_toy_modules: list["ToyLinear"] = [] + + +class ToyLinear(TransformerEngineBaseModule): + """Minimal TE-compatible linear module used for torch.compile tests. + + Inherits TransformerEngineBaseModule so that prepare_forward / end_forward + and the FP8 metadata machinery work exactly as in production modules. + The actual compute is delegated to BasicLinear._functional_forward / + _functional_backward via the opaque custom op below. + """ + + def __init__( + self, + in_features: int, + out_features: int, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter( + torch.empty(out_features, in_features, dtype=dtype, device=device) + ) + torch.nn.init.normal_(self.weight) + + # Register in the global list and remember the index. + self._module_idx = len(_toy_modules) + _toy_modules.append(self) + + # -- required abstract overrides ----------------------------------------- + + def _get_weight_tensors(self): + return [self.weight] + + def _get_weight_quantizers(self): + # Weight quantizer: use FP8 scaling when FP8 is enabled. + if not self.fp8 and not self.fp8_calibration: + return [None] + weight_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_q.internal = True + return [weight_q] + + # -- quantizer helpers (mirrors what Linear._get_quantizers does) --------- + + def get_forward_quantizers(self): + """Return (input_q, weight_q) for the forward GEMM.""" + if not self.fp8: + return None, None + input_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_q.internal = True + input_q.optimize_for_gemm = True + (weight_q,) = self._get_weight_quantizers() + return input_q, weight_q + + def get_backward_quantizers(self): + """Return (grad_output_q, grad_input_q) for the backward GEMMs.""" + if not self.fp8: + return None, None + grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_q.internal = True + grad_output_q.optimize_for_gemm = True + return grad_output_q, None + + # -- forward ------------------------------------------------------------- + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + inp = self.prepare_forward(inp, num_gemms=1) + try: + return _toy_linear_fwd_op(inp, self.weight, self._module_idx) + finally: + self.end_forward() + + +# --------------------------------------------------------------------------- +# Opaque custom op (torch.library) +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("test_te::toy_linear", mutates_args=[]) +def _toy_linear_fwd_op(inp: torch.Tensor, weight: torch.Tensor, module_idx: int) -> torch.Tensor: + """Forward GEMM wrapped as an opaque custom op.""" + module = _toy_modules[module_idx] + input_q, weight_q = module.get_forward_quantizers() + out, _, _ = BasicLinear._functional_forward( + input=inp, + weight=weight, + dtype=inp.dtype, + input_quantizer=input_q, + weight_quantizer=weight_q, + ) + return out + + +@_toy_linear_fwd_op.register_fake +def _(inp: torch.Tensor, weight: torch.Tensor, module_idx: int) -> torch.Tensor: + """Abstract implementation for shape inference under torch.compile.""" + return inp @ weight.T + + +def _toy_linear_setup_context(ctx, inputs, output): + inp, weight, module_idx = inputs + ctx.save_for_backward(inp, weight) + ctx.module_idx = module_idx + + +@torch.library.custom_op("test_te::toy_linear_backward", mutates_args=[]) +def _toy_linear_bwd_op( + grad_output: torch.Tensor, inp: torch.Tensor, weight: torch.Tensor, module_idx: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Opaque backward op wrapping BasicLinear._functional_backward.""" + module = _toy_modules[module_idx] + grad_output_q, grad_input_q = module.get_backward_quantizers() + dx, dw = BasicLinear._functional_backward( + grad_output=grad_output, + input=inp, + weight=weight, + grad_output_quantizer=grad_output_q, + grad_input_quantizer=grad_input_q, + ) + return dx, dw + + +@_toy_linear_bwd_op.register_fake +def _(grad_output: torch.Tensor, inp: torch.Tensor, weight: torch.Tensor, module_idx: int): + """Abstract backward implementation for shape inference under torch.compile.""" + return torch.empty_like(inp), torch.empty_like(weight) + + +def _toy_linear_backward(ctx, grad_output: torch.Tensor): + """Backward: dispatch to opaque custom op so TE backward is not traced.""" + inp, weight = ctx.saved_tensors + dx, dw = _toy_linear_bwd_op(grad_output, inp, weight, ctx.module_idx) + return dx, dw, None # None for module_idx gradient + + +torch.library.register_autograd( + "test_te::toy_linear", + _toy_linear_backward, + setup_context=_toy_linear_setup_context, +) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +_fp8_available = te.is_fp8_available() +_mxfp8_available = te.is_mxfp8_available() +_fp8_block_scaling_available = te.is_fp8_block_scaling_available() + +# Each entry is (fp8_recipe, fp8_enabled). +# For the "no_fp8" variant, enabled=False so autocast is a no-op, but we still +# pass a real pre-created recipe object so that get_default_fp8_recipe() is +# never called during compilation (which would assert-fail inside torch.compile). +_recipes = [ + pytest.param(recipe.DelayedScaling(), False, id="no_fp8"), + pytest.param( + recipe.Float8CurrentScaling(), + True, + id="float8_current_scaling", + marks=pytest.mark.skipif(not _fp8_available, reason="FP8 not supported"), + ), + pytest.param( + recipe.MXFP8BlockScaling(), + True, + id="mxfp8_block_scaling", + marks=pytest.mark.skipif(not _mxfp8_available, reason="MXFP8 not supported"), + ), + pytest.param( + recipe.Float8BlockScaling(), + True, + id="float8_block_scaling", + marks=pytest.mark.skipif( + not _fp8_block_scaling_available, reason="FP8 block scaling not supported" + ), + ), +] + + +@pytest.mark.parametrize("fp8_recipe,fp8_enabled", _recipes) +def test_autocast(fp8_recipe, fp8_enabled): + """Test that ToyLinear inside te.autocast compiles without graph breaks. + + fullgraph=True makes torch.compile raise an error if any graph break occurs. + Parametrized over all supported recipes. The no_fp8 variant uses + enabled=False so the autocast is a no-op, but still passes a real + pre-created recipe object to avoid calling get_default_fp8_recipe() during + compilation. + """ + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + model = ToyLinear(32, 64, device=device, dtype=dtype) + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe, enabled=fp8_enabled): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + + +@pytest.mark.skipif(not _fp8_available, reason="FP8 not supported") +def test_autocast_delayed_scaling_unsupported(): + """DelayedScaling should fail with a clear error under torch.compile.""" + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + model = ToyLinear(32, 64, device=device, dtype=dtype) + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + fp8_recipe = recipe.DelayedScaling() + + def fn(inp): + with te.autocast(recipe=fp8_recipe, enabled=True): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + with pytest.raises(RuntimeError, match="DelayedScaling is not supported under torch.compile"): + compiled(inp) + + +@pytest.mark.skipif(not te.is_fp8_available(), reason="FP8 not supported on this GPU") +def test_autocast_nested(): + """Test sequential model with different FP8 recipes and nested te.autocast. + + Layout: + with autocast(Float8CurrentScaling): # outer + out = m0(inp) + with autocast(Float8CurrentScaling): # nested inside outer + out = m1(out) + with autocast(Float8CurrentScaling): # separate, after the nested pair + out = m2(out) + + fullgraph=True makes torch.compile raise an error if any graph break occurs. + """ + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + m0 = ToyLinear(32, 32, device=device, dtype=dtype) + m1 = ToyLinear(32, 32, device=device, dtype=dtype) + m2 = ToyLinear(32, 32, device=device, dtype=dtype) + + # Use distinct recipe objects so nested/separate autocast contexts use + # different identities under torch.compile. + recipe_current0 = recipe.Float8CurrentScaling() + recipe_current1 = recipe.Float8CurrentScaling() + recipe_current2 = recipe.Float8CurrentScaling() + + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=recipe_current0): # outer + out = m0(inp) + with te.autocast(recipe=recipe_current1): # nested inside outer + out = m1(out) + + with te.autocast(recipe=recipe_current2): # separate, after nested pair + out = m2(out) + return out + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..588c708e10 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -704,7 +704,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) @@ -736,7 +736,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index b80e58fe20..d3b460e948 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -263,10 +263,10 @@ def __enter__(self): if self.activation_recompute and not self.recompute_phase: activation_recompute_forward._is_first_fp8_module.append( - FP8GlobalStateManager.IS_FIRST_FP8_MODULE + FP8GlobalStateManager.quantization_state.is_first_fp8_module ) if self.activation_recompute and self.recompute_phase: - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = ( + FP8GlobalStateManager.quantization_state.is_first_fp8_module = ( activation_recompute_forward._is_first_fp8_module.pop(0) ) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 86b8a4acf4..6e8675076c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -324,7 +324,11 @@ def _make_graphed_callables( if cache_quantized_params: # Initialize flag that controls FP8 weight updates - FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + if FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor is None: + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device="cuda" + ) + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_(False) # Check callables for c in callables: @@ -836,7 +840,9 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: - FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_( + skip_fp8_weight_update + ) ctx.cuda_graph_stream = cuda_graph_stream ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..542c96565c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -726,19 +726,22 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> FP8GlobalStateManager.get_buffer_info() ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - if buffer_key in FP8GlobalStateManager.global_amax_buffer: - if buffer_key not in FP8GlobalStateManager.global_amax_history_buffer: + if buffer_key in FP8GlobalStateManager.quantization_state.global_amax_buffer: + if ( + buffer_key + not in FP8GlobalStateManager.quantization_state.global_amax_history_buffer + ): raise RuntimeError( "TE internal error during amax history change: " f"buffer_key '{buffer_key}' found in global_amax_buffer " "but missing from global_amax_history_buffer" ) - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ - meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( - self.fp8_meta[meta_key].amax_history - ) + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ + buffer_key + ][pos] = self.fp8_meta[meta_key].amax_history + FP8GlobalStateManager.quantization_state.global_amax_buffer[buffer_key][ + pos + ] = self.fp8_meta[meta_key].amax_history[0] def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d775dc3e8e..b9e838b347 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -515,10 +515,10 @@ def forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug @@ -1493,7 +1493,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 037fb6c858..54f6b90682 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -241,7 +241,12 @@ def _forward( if checkpoint: # save the state of autocast and quantizers for recomputation ctx.autocast_state = ( - FP8GlobalStateManager.get_autocast_state() + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, ) # to restore autocast state during recomputation if ( fp8 @@ -831,10 +836,10 @@ def _forward( if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase() or is_recomputation: - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store if is_recomputation: # return the recomputed tensors @@ -909,9 +914,21 @@ def _recompute(ctx): # backward is not in autocast context, so we set the state here # we also have to set the quantizer states to what they were before the forward pass (only relevant for DelayedScaling recipe) final_autocast_state = ( - FP8GlobalStateManager.get_autocast_state() + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, ) # get current autocast state - FP8GlobalStateManager.set_autocast_state(ctx.autocast_state) # set old autocast state + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = ctx.autocast_state # set old autocast state if ( ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" @@ -934,7 +951,14 @@ def _recompute(ctx): tuple(ctx.other_args.values()), ) - FP8GlobalStateManager.set_autocast_state(final_autocast_state) # restore autocast state + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = final_autocast_state # restore autocast state if ( ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" @@ -2045,7 +2069,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1e3eadc405..5221a31469 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -479,10 +479,10 @@ def forward( ctx.owns_input = saved_inputmat is not inp if ctx.fp8 and requires_grad(inp, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store # ------------------------------------------------------ @@ -1376,7 +1376,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 54b3f00117..330666ba3a 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -322,16 +322,20 @@ def reset_recipe_state( pos, buffer_key = self._fp8_metas[mode][ FP8GlobalStateManager.get_buffer_info() ] - if buffer_key in FP8GlobalStateManager.global_amax_buffer: + if ( + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_buffer + ): assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( - recipe_state.amax_history[0] - ) - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][ + FP8GlobalStateManager.quantization_state.global_amax_buffer[buffer_key][ pos - ] = recipe_state.amax_history + ] = recipe_state.amax_history[0] + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ + buffer_key + ][pos] = recipe_state.amax_history # Add meta tensors to global buffer to participate in reduction for mode in ("forward", "backward"): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 47e6d5c8dc..b769577b22 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -7,9 +7,9 @@ import abc import itertools -import functools import warnings import os +from dataclasses import dataclass, field from contextlib import contextmanager from collections import deque from typing import Callable, List, Optional, Dict, Any, Tuple, Union @@ -44,8 +44,13 @@ ] -@functools.lru_cache(maxsize=None) -def check_fp8_support() -> Tuple[bool, str]: +_FP8_SUPPORT: Optional[Tuple[bool, str]] = None +_MXFP8_SUPPORT: Optional[Tuple[bool, str]] = None +_NVFP4_SUPPORT: Optional[Tuple[bool, str]] = None +_FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None + + +def _compute_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" @@ -58,8 +63,7 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -@functools.lru_cache(maxsize=None) -def check_mxfp8_support() -> Tuple[bool, str]: +def _compute_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (12, 0): return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." @@ -68,16 +72,14 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." -@functools.lru_cache(maxsize=None) -def check_nvfp4_support() -> Tuple[bool, str]: +def _compute_nvfp4_support() -> Tuple[bool, str]: """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for NVFP4 execution." -@functools.lru_cache(maxsize=None) -def check_fp8_block_scaling_support() -> Tuple[bool, str]: +def _compute_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" @@ -87,8 +89,49 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +@torch.compiler.assume_constant_result +def check_fp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available.""" + global _FP8_SUPPORT + if _FP8_SUPPORT is None: + _FP8_SUPPORT = _compute_fp8_support() + return _FP8_SUPPORT + + +@torch.compiler.assume_constant_result +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if MXFP8 support is available.""" + global _MXFP8_SUPPORT + if _MXFP8_SUPPORT is None: + _MXFP8_SUPPORT = _compute_mxfp8_support() + return _MXFP8_SUPPORT + + +@torch.compiler.assume_constant_result +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + global _NVFP4_SUPPORT + if _NVFP4_SUPPORT is None: + _NVFP4_SUPPORT = _compute_nvfp4_support() + return _NVFP4_SUPPORT + + +@torch.compiler.assume_constant_result +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available.""" + global _FP8_BLOCK_SCALING_SUPPORT + if _FP8_BLOCK_SCALING_SUPPORT is None: + _FP8_BLOCK_SCALING_SUPPORT = _compute_fp8_block_scaling_support() + return _FP8_BLOCK_SCALING_SUPPORT + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" + if torch.compiler.is_compiling() and isinstance(recipe, DelayedScaling): + raise RuntimeError( + "DelayedScaling is not supported under torch.compile yet. " + "Use Float8CurrentScaling, MXFP8BlockScaling, or Float8BlockScaling instead." + ) recipe_supported = True unsupported_reason = "" if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): @@ -103,6 +146,10 @@ def check_recipe_support(recipe: Recipe) -> None: def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + assert not torch.compiler.is_compiling(), ( + "get_default_fp8_recipe() must not be called during torch.compile tracing. " + "Pass an explicit recipe to te.autocast() instead of relying on the default." + ) if check_mxfp8_support()[0]: return MXFP8BlockScaling() if get_device_compute_capability() >= (12, 0): @@ -232,71 +279,44 @@ def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, s return check_nvfp4_support()[0] +@dataclass(slots=True) +class FP8GlobalState: + """Mutable process-global FP8 state stored on an instance. + + Using an instance avoids class-level `setattr(type, ...)` writes, which + `torch.compile` cannot trace in fullgraph mode. + """ + + fp8_enabled: bool = False + fp8_calibration: bool = False + fp8_recipe: Optional[Recipe] = None + fp8_distributed_group: Optional[dist_group_type] = None + fp8_parameters: bool = False + high_precision_init_val: bool = False + is_first_fp8_module: bool = False + fp8_graph_capturing: bool = False + autocast_depth: int = 0 + global_amax_buffer: Dict[str, list] = field(default_factory=dict) + global_amax_history_buffer: Dict[str, list] = field(default_factory=dict) + global_scale_buffer: Dict[str, list] = field(default_factory=dict) + fp8_tensors_recompute_buffer: list = field(default_factory=list) + autocast_arguments: Dict[Any, Tuple[Recipe, Optional[dist_group_type]]] = field( + default_factory=dict + ) + skip_fp8_weight_update_tensor: Optional[torch.Tensor] = None + + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. """ - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - nvfp4_available = None - reason_for_no_nvfp4 = "" + quantization_state = FP8GlobalState() @classmethod def reset(cls) -> None: """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return cls.skip_fp8_weight_update_tensor + cls.quantization_state = FP8GlobalState() @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -398,90 +418,81 @@ def add_fp8_tensors_to_global_buffer( key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + if key not in FP8GlobalStateManager.quantization_state.global_amax_buffer: + FP8GlobalStateManager.quantization_state.global_amax_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].amax_history[0] + ] + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].amax_history + ] + FP8GlobalStateManager.quantization_state.global_scale_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].scale + ] else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( + FP8GlobalStateManager.quantization_state.global_amax_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history[0] + ) + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + FP8GlobalStateManager.quantization_state.global_scale_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].scale + ) + fp8_meta[index_in_buffer].append( + len(FP8GlobalStateManager.quantization_state.global_amax_buffer[key]) - 1 + ) fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" - return cls.FP8_ENABLED + return cls.quantization_state.fp8_enabled @classmethod def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" - return cls.FP8_CALIBRATION + return cls.quantization_state.fp8_calibration @classmethod def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS + return cls.quantization_state.fp8_parameters @classmethod def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL + return cls.quantization_state.high_precision_init_val @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + if torch.compiler.is_compiling(): + assert not cls.quantization_state.fp8_graph_capturing + return False + return ( + cls.quantization_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() + ) @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple times from within the same `autocast` context. """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False + tmp = cls.quantization_state.is_first_fp8_module + cls.quantization_state.is_first_fp8_module = False return tmp @classmethod def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE + if cls.quantization_state.fp8_recipe is not None: + return cls.quantization_state.fp8_recipe return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP - - @classmethod - def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) - - @classmethod - def set_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) = fp8_state + return cls.quantization_state.fp8_distributed_group @staticmethod def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: @@ -501,7 +512,10 @@ def reduce_and_update_fp8_tensors( ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + for ( + buffer_key, + amax_buffer, + ) in FP8GlobalStateManager.quantization_state.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -510,7 +524,9 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args and concat amaxes. - recipe, group = cls.autocast_arguments[autocast_key] + recipe, group = FP8GlobalStateManager.quantization_state.autocast_arguments[ + autocast_key + ] contiguous_amax = torch.cat(amax_buffer) # Reduction. @@ -531,8 +547,8 @@ def reduce_and_update_fp8_tensors( if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_scale_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -541,8 +557,8 @@ def reduce_and_update_fp8_tensors( split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_scale_buffer[buffer_key], ): _amax_and_scale_update( amax_history, scale, get_fp8_max(recipe, forward), recipe @@ -556,9 +572,12 @@ def get_unique_autocast_key( ): """ For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. + Object identity is sufficient since autocast contexts never outlive a single + training session. """ - return f"{str(recipe)}:{hash(group)}" + return str( + (id(recipe) if recipe is not None else None, id(group) if group is not None else None) + ) @classmethod def autocast_enter( @@ -573,17 +592,20 @@ def autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( + fp8_recipe, + fp8_group, + ) - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph + FP8GlobalStateManager.quantization_state.fp8_enabled = enabled + FP8GlobalStateManager.quantization_state.fp8_calibration = calibrating + FP8GlobalStateManager.quantization_state.fp8_recipe = fp8_recipe + FP8GlobalStateManager.quantization_state.fp8_distributed_group = fp8_group + FP8GlobalStateManager.quantization_state.fp8_graph_capturing = _graph - if cls.AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.AUTOCAST_DEPTH += 1 + if FP8GlobalStateManager.quantization_state.autocast_depth == 0: + FP8GlobalStateManager.quantization_state.is_first_fp8_module = True + FP8GlobalStateManager.quantization_state.autocast_depth += 1 if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() @@ -601,11 +623,16 @@ def autocast_enter( @classmethod def autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" - cls.AUTOCAST_DEPTH -= 1 + FP8GlobalStateManager.quantization_state.autocast_depth -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. - if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + if ( + enabled + and FP8GlobalStateManager.quantization_state.autocast_depth == 0 + and not _graph + and torch.is_grad_enabled() + ): # delayed scaling only function, for other recipes (current scaling with any granularity), # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @@ -628,14 +655,22 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - ] if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[ + fp8_meta[buffer_position_key] + ].append(to_copy) else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] + if len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) == 0: + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer = [deque()] else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer.append( + deque() + ) + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[-1].append( + to_copy + ) + fp8_meta[buffer_position_key] = ( + len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) - 1 + ) @classmethod def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -652,7 +687,9 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() + stashed_fp8_meta = FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[ + fp8_meta[buffer_position_key] + ].popleft() # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) @@ -749,18 +786,22 @@ def quantized_model_init( This functionality is *EXPERIMENTAL*. """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val + _fp8_parameters = FP8GlobalStateManager.quantization_state.fp8_parameters + _fp8_recipe = FP8GlobalStateManager.quantization_state.fp8_recipe + _high_precision_init_val = FP8GlobalStateManager.quantization_state.high_precision_init_val + FP8GlobalStateManager.quantization_state.fp8_parameters = enabled + FP8GlobalStateManager.quantization_state.fp8_recipe = ( + get_default_fp8_recipe() if recipe is None else recipe + ) + FP8GlobalStateManager.quantization_state.high_precision_init_val = ( + preserve_high_precision_init_val + ) try: yield finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val + FP8GlobalStateManager.quantization_state.fp8_parameters = _fp8_parameters + FP8GlobalStateManager.quantization_state.fp8_recipe = _fp8_recipe + FP8GlobalStateManager.quantization_state.high_precision_init_val = _high_precision_init_val @contextmanager @@ -847,7 +888,14 @@ def autocast( check_recipe_support(recipe) # Save current state so we always restore it on exit. - fp8_state = FP8GlobalStateManager.get_autocast_state() + fp8_state = ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) FP8GlobalStateManager.autocast_enter( enabled=enabled, @@ -859,7 +907,14 @@ def autocast( try: yield finally: - FP8GlobalStateManager.set_autocast_state(fp8_state) + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = fp8_state FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)