From 22bac5d39f4b770a0568c5c4af18baf5c002ef79 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 13 Mar 2026 12:48:02 +0100 Subject: [PATCH 1/6] Improve torch.compile behavior around FP8 autocast. Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes. Signed-off-by: Pawel Gadzinski --- .../pytorch/test_global_dataclass_mutation.py | 155 +++++++++ tests/pytorch/test_global_dict_mutation.py | 140 ++++++++ tests/pytorch/test_torch_compile.py | 304 +++++++++++++++++ .../dot_product_attention.py | 4 +- transformer_engine/pytorch/distributed.py | 4 +- transformer_engine/pytorch/module/base.py | 12 +- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 4 +- transformer_engine/pytorch/ops/op.py | 8 +- transformer_engine/pytorch/quantization.py | 322 +++++++++++------- 11 files changed, 818 insertions(+), 143 deletions(-) create mode 100644 tests/pytorch/test_global_dataclass_mutation.py create mode 100644 tests/pytorch/test_global_dict_mutation.py create mode 100644 tests/pytorch/test_torch_compile.py diff --git a/tests/pytorch/test_global_dataclass_mutation.py b/tests/pytorch/test_global_dataclass_mutation.py new file mode 100644 index 0000000000..fb10a6d4a0 --- /dev/null +++ b/tests/pytorch/test_global_dataclass_mutation.py @@ -0,0 +1,155 @@ +""" +Experiment: can torch.compile handle mutation of a global dataclass? + +Analogous to the global-dict experiment, but uses a dataclass instance +stored as a module-level global instead of a plain dict. + +Parts: + 1. Read a field from the global dataclass, check if recompilation happens + when the field value changes. + 2. Write a Python scalar to a dataclass field inside a compiled function. + 3. Write a Tensor to a dataclass field inside a compiled function. +""" + +from dataclasses import dataclass, field +from typing import Optional + +import torch + + +# --------------------------------------------------------------------------- +# Global dataclass +# --------------------------------------------------------------------------- + +@dataclass +class State: + scale: float = 1.0 + result: Optional[int] = None + tensor_val: Optional[torch.Tensor] = None + + +GLOBAL_STATE = State() + + +# --------------------------------------------------------------------------- +# Functions that access / mutate the global dataclass +# --------------------------------------------------------------------------- + + +def fn_read_dataclass(x: torch.Tensor) -> torch.Tensor: + """Read scale from the global dataclass and multiply x by it.""" + return x * GLOBAL_STATE.scale + + +def fn_write_scalar(x: torch.Tensor, value: int) -> torch.Tensor: + """Write a Python scalar to the global dataclass, return x unchanged.""" + GLOBAL_STATE.result = value + return x + 0 + + +def fn_write_tensor(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Write a Tensor to the global dataclass, return x unchanged.""" + GLOBAL_STATE.tensor_val = t + return x + 0 + + +# --------------------------------------------------------------------------- +# Compiled versions +# --------------------------------------------------------------------------- +compiled_read = torch.compile(fn_read_dataclass, fullgraph=False) +compiled_write_scalar = torch.compile(fn_write_scalar, fullgraph=False) +compiled_write_tensor = torch.compile(fn_write_tensor, fullgraph=False) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def reset(): + global GLOBAL_STATE + GLOBAL_STATE = State() + + +def unique_graphs(): + return torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") + + +# --------------------------------------------------------------------------- +# Experiment +# --------------------------------------------------------------------------- + +def run(): + print("=" * 60) + print("Experiment: torch.compile + global dataclass mutation") + print("=" * 60) + + x = torch.tensor([2.0], device="cpu") + + # ----------------------------------------------------------------------- + # Part 1 – reading a field from the global dataclass + # ----------------------------------------------------------------------- + print("\n--- Part 1: reading a field from a global dataclass ---") + reset() + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + + GLOBAL_STATE.scale = 3.0 + y1 = compiled_read(x) + g1 = unique_graphs() + print(f" GLOBAL_STATE.scale = 3.0 → compiled_read(x) = {y1.item()} (expected {x.item() * 3.0})") + print(f" unique_graphs after 1st call: {g1}") + + GLOBAL_STATE.scale = 5.0 + y2 = compiled_read(x) + g2 = unique_graphs() + print(f" GLOBAL_STATE.scale = 5.0 → compiled_read(x) = {y2.item()} (expected {x.item() * 5.0})") + print(f" unique_graphs after 2nd call: {g2}") + + if g2 != g1: + print(f" NOTE: Dynamo recompiled (graphs: {g1} -> {g2})") + else: + print(f" NOTE: Dynamo did NOT recompile") + + if abs(y2.item() - x.item() * 5.0) < 1e-6: + print(" PASS: result reflects updated dataclass field") + else: + print(" FAIL: result does NOT reflect updated field (guard baked-in old value)") + + # ----------------------------------------------------------------------- + # Part 2 – writing a Python scalar to the dataclass + # ----------------------------------------------------------------------- + print("\n--- Part 2: writing a Python scalar to a dataclass field ---") + reset() + torch._dynamo.reset() + + print(f" GLOBAL_STATE.result before call: {GLOBAL_STATE.result}") + compiled_write_scalar(x, 42) + print(f" GLOBAL_STATE.result after call: {GLOBAL_STATE.result}") + + if GLOBAL_STATE.result == 42: + print(" PASS: dataclass field mutation (scalar) is visible after compiled call") + else: + print(" FAIL: dataclass field mutation (scalar) NOT visible") + + # ----------------------------------------------------------------------- + # Part 3 – writing a Tensor to the dataclass + # ----------------------------------------------------------------------- + print("\n--- Part 3: writing a Tensor to a dataclass field ---") + reset() + torch._dynamo.reset() + + t = torch.tensor(99.0) + print(f" GLOBAL_STATE.tensor_val before call: {GLOBAL_STATE.tensor_val}") + compiled_write_tensor(x, t) + print(f" GLOBAL_STATE.tensor_val after call: {GLOBAL_STATE.tensor_val}") + + if GLOBAL_STATE.tensor_val is not None: + print(" PASS: dataclass field mutation (Tensor) is visible after compiled call") + else: + print(" FAIL: dataclass field mutation (Tensor) NOT visible") + + print("\nDone.") + + +if __name__ == "__main__": + run() diff --git a/tests/pytorch/test_global_dict_mutation.py b/tests/pytorch/test_global_dict_mutation.py new file mode 100644 index 0000000000..27c8e9d702 --- /dev/null +++ b/tests/pytorch/test_global_dict_mutation.py @@ -0,0 +1,140 @@ +""" +Experiment: can torch.compile handle mutation of a global dictionary? + +We test two scenarios: + 1. A compiled function that reads from a global dict. + 2. A compiled function that writes (mutates) a global dict. + +In both cases we check whether recompilation or graph breaks occur, and +whether the results are numerically correct. +""" + +import torch + +# --------------------------------------------------------------------------- +# Global state +# --------------------------------------------------------------------------- +GLOBAL_DICT: dict = {} + +# --------------------------------------------------------------------------- +# Functions that access / mutate the global dict +# --------------------------------------------------------------------------- + + +def fn_read_global(x: torch.Tensor) -> torch.Tensor: + """Read a scale factor stored in a global dict and multiply x by it.""" + scale = GLOBAL_DICT.get("scale", 1.0) + return x * scale + + +def fn_write_global(x: torch.Tensor, key: str, value) -> torch.Tensor: + """Write a value into the global dict, then return x unchanged.""" + GLOBAL_DICT[key] = value + return x + 0 # trivial op so there is a tensor output + + +# --------------------------------------------------------------------------- +# Compiled versions +# --------------------------------------------------------------------------- +compiled_read = torch.compile(fn_read_global, fullgraph=False) +compiled_write = torch.compile(fn_write_global, fullgraph=False) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def reset(): + global GLOBAL_DICT + GLOBAL_DICT = {} + + +def count_recompilations(fn): + """Return the number of frames that have been compiled so far.""" + # torch._dynamo.explain() gives per-call stats; we use the simpler + # guard cache size as a proxy. + try: + return torch._dynamo.utils.counters["stats"]["unique_graphs"] + except Exception: + return None + + +# --------------------------------------------------------------------------- +# Experiment +# --------------------------------------------------------------------------- + +def run(): + print("=" * 60) + print("Experiment: torch.compile + global dict mutation") + print("=" * 60) + + x = torch.tensor([2.0], device="cpu") + + # ----------------------------------------------------------------------- + # Part 1 – reading from the global dict + # ----------------------------------------------------------------------- + print("\n--- Part 1: reading from a global dict ---") + reset() + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + + GLOBAL_DICT["scale"] = 3.0 + y1 = compiled_read(x) + graphs_after_first = torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") + print(f" GLOBAL_DICT = {GLOBAL_DICT}") + print(f" compiled_read(x) = {y1.item()} (expected {x.item() * 3.0})") + print(f" unique_graphs after 1st call: {graphs_after_first}") + + # Change the dict value and call again – should Dynamo pick up the change? + GLOBAL_DICT["scale"] = 5.0 + y2 = compiled_read(x) + graphs_after_second = torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") + print(f" After mutating scale to 5.0:") + print(f" compiled_read(x) = {y2.item()} (expected {x.item() * 5.0})") + print(f" unique_graphs after 2nd call: {graphs_after_second}") + + if graphs_after_second != graphs_after_first: + print(f" NOTE: Dynamo recompiled (graphs: {graphs_after_first} -> {graphs_after_second})") + else: + print(f" NOTE: Dynamo did NOT recompile (same graph count)") + + if abs(y2.item() - x.item() * 5.0) < 1e-6: + print(" PASS: result reflects updated dict value") + else: + print(" FAIL: result does NOT reflect updated dict value (guard baked-in old value)") + + # ----------------------------------------------------------------------- + # Part 2 – writing / mutating the global dict inside the compiled fn + # ----------------------------------------------------------------------- + print("\n--- Part 2: writing into a global dict ---") + reset() + torch._dynamo.reset() + + print(f" GLOBAL_DICT before call: {GLOBAL_DICT}") + compiled_write(x, "result", 42) + print(f" GLOBAL_DICT after call: {GLOBAL_DICT}") + + if GLOBAL_DICT.get("result") == 42: + print(" PASS: global dict mutation is visible after compiled call") + else: + print(" FAIL: global dict mutation is NOT visible (side-effect was dropped)") + + # ----------------------------------------------------------------------- + # Part 3 – mutation of value that is a Tensor + # ----------------------------------------------------------------------- + print("\n--- Part 3: storing a Tensor into the global dict ---") + reset() + torch._dynamo.reset() + + compiled_write(x, "tensor_val", torch.tensor(99.0)) + print(f" GLOBAL_DICT after call: {GLOBAL_DICT}") + if "tensor_val" in GLOBAL_DICT: + print(" PASS: tensor stored in global dict is visible") + else: + print(" FAIL: tensor NOT stored") + + print("\nDone.") + + +if __name__ == "__main__": + run() diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py new file mode 100644 index 0000000000..20140c0a0c --- /dev/null +++ b/tests/pytorch/test_torch_compile.py @@ -0,0 +1,304 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.common import recipe +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear + + +# --------------------------------------------------------------------------- +# ToyLinear – minimal TE module backed by BasicLinear functional ops +# --------------------------------------------------------------------------- + +# Global list of ToyLinear instances. Each module registers itself here on +# construction; the custom op identifies which module to use via an integer +# index so that no Python object ever enters the compiled graph. +_toy_modules: list["ToyLinear"] = [] + + +class ToyLinear(TransformerEngineBaseModule): + """Minimal TE-compatible linear module used for torch.compile tests. + + Inherits TransformerEngineBaseModule so that prepare_forward / end_forward + and the FP8 metadata machinery work exactly as in production modules. + The actual compute is delegated to BasicLinear._functional_forward / + _functional_backward via the opaque custom op below. + """ + + def __init__( + self, + in_features: int, + out_features: int, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter( + torch.empty(out_features, in_features, dtype=dtype, device=device) + ) + torch.nn.init.normal_(self.weight) + + # Register in the global list and remember the index. + self._module_idx = len(_toy_modules) + _toy_modules.append(self) + + # -- required abstract overrides ----------------------------------------- + + def _get_weight_tensors(self): + return [self.weight] + + def _get_weight_quantizers(self): + # Weight quantizer: use FP8 scaling when FP8 is enabled. + if not self.fp8 and not self.fp8_calibration: + return [None] + weight_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_q.internal = True + return [weight_q] + + # -- quantizer helpers (mirrors what Linear._get_quantizers does) --------- + + def get_forward_quantizers(self): + """Return (input_q, weight_q) for the forward GEMM.""" + if not self.fp8: + return None, None + input_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_q.internal = True + input_q.optimize_for_gemm = True + (weight_q,) = self._get_weight_quantizers() + return input_q, weight_q + + def get_backward_quantizers(self): + """Return (grad_output_q, grad_input_q) for the backward GEMMs.""" + if not self.fp8: + return None, None + grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_q.internal = True + grad_output_q.optimize_for_gemm = True + return grad_output_q, None + + # -- forward ------------------------------------------------------------- + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + inp = self.prepare_forward(inp, num_gemms=1) + try: + return _toy_linear_fwd_op(inp, self.weight, self._module_idx) + finally: + self.end_forward() + + +# --------------------------------------------------------------------------- +# Opaque custom op (torch.library) +# --------------------------------------------------------------------------- + +@torch.library.custom_op("test_te::toy_linear", mutates_args=[]) +def _toy_linear_fwd_op( + inp: torch.Tensor, weight: torch.Tensor, module_idx: int +) -> torch.Tensor: + """Forward GEMM wrapped as an opaque custom op.""" + module = _toy_modules[module_idx] + input_q, weight_q = module.get_forward_quantizers() + out, _, _ = BasicLinear._functional_forward( + input=inp, + weight=weight, + dtype=inp.dtype, + input_quantizer=input_q, + weight_quantizer=weight_q, + ) + return out + + +@_toy_linear_fwd_op.register_fake +def _(inp: torch.Tensor, weight: torch.Tensor, module_idx: int) -> torch.Tensor: + """Abstract implementation for shape inference under torch.compile.""" + return inp @ weight.T + + +def _toy_linear_setup_context(ctx, inputs, output): + inp, weight, module_idx = inputs + ctx.save_for_backward(inp, weight) + ctx.module_idx = module_idx + + +@torch.library.custom_op("test_te::toy_linear_backward", mutates_args=[]) +def _toy_linear_bwd_op( + grad_output: torch.Tensor, inp: torch.Tensor, weight: torch.Tensor, module_idx: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Opaque backward op wrapping BasicLinear._functional_backward.""" + module = _toy_modules[module_idx] + grad_output_q, grad_input_q = module.get_backward_quantizers() + dx, dw = BasicLinear._functional_backward( + grad_output=grad_output, + input=inp, + weight=weight, + grad_output_quantizer=grad_output_q, + grad_input_quantizer=grad_input_q, + ) + return dx, dw + + +@_toy_linear_bwd_op.register_fake +def _(grad_output: torch.Tensor, inp: torch.Tensor, weight: torch.Tensor, module_idx: int): + """Abstract backward implementation for shape inference under torch.compile.""" + return torch.empty_like(inp), torch.empty_like(weight) + + +def _toy_linear_backward(ctx, grad_output: torch.Tensor): + """Backward: dispatch to opaque custom op so TE backward is not traced.""" + inp, weight = ctx.saved_tensors + dx, dw = _toy_linear_bwd_op(grad_output, inp, weight, ctx.module_idx) + return dx, dw, None # None for module_idx gradient + + +torch.library.register_autograd( + "test_te::toy_linear", + _toy_linear_backward, + setup_context=_toy_linear_setup_context, +) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +_fp8_available = te.is_fp8_available() +_mxfp8_available = te.is_mxfp8_available() +_fp8_block_scaling_available = te.is_fp8_block_scaling_available() + +# Each entry is (fp8_recipe, fp8_enabled). +# For the "no_fp8" variant, enabled=False so autocast is a no-op, but we still +# pass a real pre-created recipe object so that get_default_fp8_recipe() is +# never called during compilation (which would assert-fail inside torch.compile). +_recipes = [ + pytest.param(recipe.DelayedScaling(), False, id="no_fp8"), + pytest.param( + recipe.Float8CurrentScaling(), + True, + id="float8_current_scaling", + marks=pytest.mark.skipif(not _fp8_available, reason="FP8 not supported"), + ), + pytest.param( + recipe.MXFP8BlockScaling(), + True, + id="mxfp8_block_scaling", + marks=pytest.mark.skipif(not _mxfp8_available, reason="MXFP8 not supported"), + ), + pytest.param( + recipe.Float8BlockScaling(), + True, + id="float8_block_scaling", + marks=pytest.mark.skipif( + not _fp8_block_scaling_available, reason="FP8 block scaling not supported" + ), + ), +] + + +@pytest.mark.parametrize("fp8_recipe,fp8_enabled", _recipes) +def test_autocast(fp8_recipe, fp8_enabled): + """Test that ToyLinear inside te.autocast compiles without graph breaks. + + fullgraph=True makes torch.compile raise an error if any graph break occurs. + Parametrized over all supported recipes. The no_fp8 variant uses + enabled=False so the autocast is a no-op, but still passes a real + pre-created recipe object to avoid calling get_default_fp8_recipe() during + compilation. + """ + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + model = ToyLinear(32, 64, device=device, dtype=dtype) + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe, enabled=fp8_enabled): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + + +@pytest.mark.skipif(not _fp8_available, reason="FP8 not supported") +def test_autocast_delayed_scaling_unsupported(): + """DelayedScaling should fail with a clear error under torch.compile.""" + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + model = ToyLinear(32, 64, device=device, dtype=dtype) + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + fp8_recipe = recipe.DelayedScaling() + + def fn(inp): + with te.autocast(recipe=fp8_recipe, enabled=True): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + with pytest.raises(RuntimeError, match="DelayedScaling is not supported under torch.compile"): + compiled(inp) + + +@pytest.mark.skipif(not te.is_fp8_available(), reason="FP8 not supported on this GPU") +def test_autocast_nested(): + """Test sequential model with different FP8 recipes and nested te.autocast. + + Layout: + with autocast(Float8CurrentScaling): # outer + out = m0(inp) + with autocast(Float8CurrentScaling): # nested inside outer + out = m1(out) + with autocast(Float8CurrentScaling): # separate, after the nested pair + out = m2(out) + + fullgraph=True makes torch.compile raise an error if any graph break occurs. + """ + global _toy_modules + _toy_modules = [] + + dtype = torch.bfloat16 + device = "cuda" + + m0 = ToyLinear(32, 32, device=device, dtype=dtype) + m1 = ToyLinear(32, 32, device=device, dtype=dtype) + m2 = ToyLinear(32, 32, device=device, dtype=dtype) + + # Use distinct recipe objects so nested/separate autocast contexts use + # different identities under torch.compile. + recipe_current0 = recipe.Float8CurrentScaling() + recipe_current1 = recipe.Float8CurrentScaling() + recipe_current2 = recipe.Float8CurrentScaling() + + inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=recipe_current0): # outer + out = m0(inp) + with te.autocast(recipe=recipe_current1): # nested inside outer + out = m1(out) + + with te.autocast(recipe=recipe_current2): # separate, after nested pair + out = m2(out) + return out + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..c65c047510 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -704,7 +704,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + FP8GlobalStateManager.get_autocast_arguments()[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) @@ -736,7 +736,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + FP8GlobalStateManager.get_autocast_arguments()[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..c194b66fe1 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -255,10 +255,10 @@ def __enter__(self): if self.activation_recompute and not self.recompute_phase: activation_recompute_forward._is_first_fp8_module.append( - FP8GlobalStateManager.IS_FIRST_FP8_MODULE + FP8GlobalStateManager.peek_is_first_fp8_module() ) if self.activation_recompute and self.recompute_phase: - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = ( + FP8GlobalStateManager.set_is_first_fp8_module( activation_recompute_forward._is_first_fp8_module.pop(0) ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..c7266ad23b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -693,14 +693,14 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> FP8GlobalStateManager.get_buffer_info() ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - if buffer_key in FP8GlobalStateManager.global_amax_buffer: + if buffer_key in FP8GlobalStateManager.get_global_amax_buffer(): assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer + buffer_key in FP8GlobalStateManager.get_global_amax_history_buffer() ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ - meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + FP8GlobalStateManager.get_global_amax_buffer()[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history[0] + ) + FP8GlobalStateManager.get_global_amax_history_buffer()[buffer_key][pos] = ( self.fp8_meta[meta_key].amax_history ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..20615aa5f6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -516,10 +516,10 @@ def forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) ctx.wgrad_store = wgrad_store ctx.debug = debug diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4532ea60e7..8129e9b2e2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -832,10 +832,10 @@ def _forward( if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase() or is_recomputation: - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) ctx.wgrad_store = wgrad_store if is_recomputation: # return the recomputed tensors diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..28710eee14 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -480,10 +480,10 @@ def forward( ctx.owns_input = saved_inputmat is not inp if ctx.fp8 and requires_grad(inp, weight, bias): - _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) ctx.wgrad_store = wgrad_store # ------------------------------------------------------ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 54b3f00117..f79eab35bf 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -322,14 +322,14 @@ def reset_recipe_state( pos, buffer_key = self._fp8_metas[mode][ FP8GlobalStateManager.get_buffer_info() ] - if buffer_key in FP8GlobalStateManager.global_amax_buffer: + if buffer_key in FP8GlobalStateManager.get_global_amax_buffer(): assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer + buffer_key in FP8GlobalStateManager.get_global_amax_history_buffer() ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + FP8GlobalStateManager.get_global_amax_buffer()[buffer_key][pos] = ( recipe_state.amax_history[0] ) - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][ + FP8GlobalStateManager.get_global_amax_history_buffer()[buffer_key][ pos ] = recipe_state.amax_history diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..7e6bdcd5b6 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -7,9 +7,9 @@ import abc import itertools -import functools import warnings import os +from dataclasses import dataclass, field from contextlib import contextmanager from collections import deque from typing import Callable, List, Optional, Dict, Any, Tuple, Union @@ -44,8 +44,13 @@ ] -@functools.lru_cache(maxsize=None) -def check_fp8_support() -> Tuple[bool, str]: +_FP8_SUPPORT: Optional[Tuple[bool, str]] = None +_MXFP8_SUPPORT: Optional[Tuple[bool, str]] = None +_NVFP4_SUPPORT: Optional[Tuple[bool, str]] = None +_FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None + + +def _compute_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" @@ -58,8 +63,7 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -@functools.lru_cache(maxsize=None) -def check_mxfp8_support() -> Tuple[bool, str]: +def _compute_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (12, 0): return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." @@ -68,16 +72,14 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." -@functools.lru_cache(maxsize=None) -def check_nvfp4_support() -> Tuple[bool, str]: +def _compute_nvfp4_support() -> Tuple[bool, str]: """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for NVFP4 execution." -@functools.lru_cache(maxsize=None) -def check_fp8_block_scaling_support() -> Tuple[bool, str]: +def _compute_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" @@ -87,8 +89,45 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +def check_fp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available.""" + global _FP8_SUPPORT + if _FP8_SUPPORT is None: + _FP8_SUPPORT = _compute_fp8_support() + return _FP8_SUPPORT + + +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if MXFP8 support is available.""" + global _MXFP8_SUPPORT + if _MXFP8_SUPPORT is None: + _MXFP8_SUPPORT = _compute_mxfp8_support() + return _MXFP8_SUPPORT + + +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + global _NVFP4_SUPPORT + if _NVFP4_SUPPORT is None: + _NVFP4_SUPPORT = _compute_nvfp4_support() + return _NVFP4_SUPPORT + + +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available.""" + global _FP8_BLOCK_SCALING_SUPPORT + if _FP8_BLOCK_SCALING_SUPPORT is None: + _FP8_BLOCK_SCALING_SUPPORT = _compute_fp8_block_scaling_support() + return _FP8_BLOCK_SCALING_SUPPORT + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" + if torch.compiler.is_compiling() and isinstance(recipe, DelayedScaling): + raise RuntimeError( + "DelayedScaling is not supported under torch.compile yet. " + "Use Float8CurrentScaling, MXFP8BlockScaling, or Float8BlockScaling instead." + ) recipe_supported = True unsupported_reason = "" if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): @@ -102,6 +141,10 @@ def check_recipe_support(recipe: Recipe) -> None: def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + assert not torch.compiler.is_compiling(), ( + "get_default_fp8_recipe() must not be called during torch.compile tracing. " + "Pass an explicit recipe to te.autocast() instead of relying on the default." + ) if check_mxfp8_support()[0]: return MXFP8BlockScaling() if get_device_compute_capability() >= (12, 0): @@ -231,71 +274,68 @@ def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, s return check_nvfp4_support()[0] +@dataclass(slots=True) +class FP8GlobalState: + """Mutable process-global FP8 state stored on an instance. + + Using an instance avoids class-level `setattr(type, ...)` writes, which + `torch.compile` cannot trace in fullgraph mode. + """ + + fp8_enabled: bool = False + fp8_calibration: bool = False + fp8_recipe: Optional[Recipe] = None + fp8_distributed_group: Optional[dist_group_type] = None + fp8_parameters: bool = False + high_precision_init_val: bool = False + is_first_fp8_module: bool = False + fp8_graph_capturing: bool = False + autocast_depth: int = 0 + global_amax_buffer: Dict[str, list] = field(default_factory=dict) + global_amax_history_buffer: Dict[str, list] = field(default_factory=dict) + global_scale_buffer: Dict[str, list] = field(default_factory=dict) + fp8_tensors_recompute_buffer: list = field(default_factory=list) + fp8_available: Optional[bool] = None + reason_for_no_fp8: str = "" + autocast_arguments: Dict[Any, Tuple[Recipe, Optional[dist_group_type]]] = field( + default_factory=dict + ) + skip_fp8_weight_update_tensor: Optional[torch.Tensor] = None + mxfp8_available: Optional[bool] = None + reason_for_no_mxfp8: str = "" + fp8_block_scaling_available: Optional[bool] = None + reason_for_no_fp8_block_scaling: str = "" + nvfp4_available: Optional[bool] = None + reason_for_no_nvfp4: str = "" + + +_fp8_state = FP8GlobalState() + + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. """ - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - nvfp4_available = None - reason_for_no_nvfp4 = "" - @classmethod def reset(cls) -> None: """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" + global _fp8_state + _fp8_state = FP8GlobalState() @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) + if _fp8_state.skip_fp8_weight_update_tensor is None: + _fp8_state.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device="cuda" + ) + _fp8_state.skip_fp8_weight_update_tensor.fill_(skip) @classmethod def get_skip_fp8_weight_update_tensor(cls) -> None: """`skip_fp8_weight_update_tensor` getter.""" - return cls.skip_fp8_weight_update_tensor + return _fp8_state.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -397,75 +437,92 @@ def add_fp8_tensors_to_global_buffer( key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + if key not in _fp8_state.global_amax_buffer: + _fp8_state.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + _fp8_state.global_amax_history_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].amax_history + ] + _fp8_state.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( + _fp8_state.global_amax_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history[0] + ) + _fp8_state.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + _fp8_state.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + fp8_meta[index_in_buffer].append(len(_fp8_state.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" - return cls.FP8_ENABLED + return _fp8_state.fp8_enabled @classmethod def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" - return cls.FP8_CALIBRATION + return _fp8_state.fp8_calibration @classmethod def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS + return _fp8_state.fp8_parameters @classmethod def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL + return _fp8_state.high_precision_init_val @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + if torch.compiler.is_compiling(): + assert not _fp8_state.fp8_graph_capturing + return False + return _fp8_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple times from within the same `autocast` context. """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False + tmp = _fp8_state.is_first_fp8_module + _fp8_state.is_first_fp8_module = False return tmp + @classmethod + def peek_is_first_fp8_module(cls) -> bool: + """Return the current first-module flag without consuming it.""" + return _fp8_state.is_first_fp8_module + + @classmethod + def set_is_first_fp8_module(cls, value: bool) -> None: + """Set the first-module flag.""" + _fp8_state.is_first_fp8_module = value + @classmethod def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE + if _fp8_state.fp8_recipe is not None: + return _fp8_state.fp8_recipe return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP + return _fp8_state.fp8_distributed_group @classmethod def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, + _fp8_state.fp8_enabled, + _fp8_state.fp8_calibration, + _fp8_state.fp8_recipe, + _fp8_state.fp8_distributed_group, + _fp8_state.is_first_fp8_module, + _fp8_state.fp8_graph_capturing, ) @classmethod @@ -474,14 +531,29 @@ def set_autocast_state( ) -> None: """FP8 autocast state setter""" ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, + _fp8_state.fp8_enabled, + _fp8_state.fp8_calibration, + _fp8_state.fp8_recipe, + _fp8_state.fp8_distributed_group, + _fp8_state.is_first_fp8_module, + _fp8_state.fp8_graph_capturing, ) = fp8_state + @classmethod + def get_global_amax_buffer(cls) -> Dict[str, list]: + """Return the global amax buffer.""" + return _fp8_state.global_amax_buffer + + @classmethod + def get_global_amax_history_buffer(cls) -> Dict[str, list]: + """Return the global amax history buffer.""" + return _fp8_state.global_amax_history_buffer + + @classmethod + def get_autocast_arguments(cls) -> Dict[Any, Tuple[Recipe, Optional[dist_group_type]]]: + """Return autocast arguments keyed by autocast identity.""" + return _fp8_state.autocast_arguments + @staticmethod def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: """Reduce tensor across given group.""" @@ -500,7 +572,7 @@ def reduce_and_update_fp8_tensors( ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + for buffer_key, amax_buffer in _fp8_state.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -509,7 +581,7 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args and concat amaxes. - recipe, group = cls.autocast_arguments[autocast_key] + recipe, group = _fp8_state.autocast_arguments[autocast_key] contiguous_amax = torch.cat(amax_buffer) # Reduction. @@ -530,8 +602,8 @@ def reduce_and_update_fp8_tensors( if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], + _fp8_state.global_amax_history_buffer[buffer_key], + _fp8_state.global_scale_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -540,8 +612,8 @@ def reduce_and_update_fp8_tensors( split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], + _fp8_state.global_amax_history_buffer[buffer_key], + _fp8_state.global_scale_buffer[buffer_key], ): _amax_and_scale_update( amax_history, scale, get_fp8_max(recipe, forward), recipe @@ -555,9 +627,11 @@ def get_unique_autocast_key( ): """ For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. + Object identity is sufficient since autocast contexts never outlive a single + training session. """ - return f"{str(recipe)}:{hash(group)}" + return str((id(recipe) if recipe is not None else None, + id(group) if group is not None else None)) @classmethod def autocast_enter( @@ -572,17 +646,17 @@ def autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + _fp8_state.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph + _fp8_state.fp8_enabled = enabled + _fp8_state.fp8_calibration = calibrating + _fp8_state.fp8_recipe = fp8_recipe + _fp8_state.fp8_distributed_group = fp8_group + _fp8_state.fp8_graph_capturing = _graph - if cls.AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.AUTOCAST_DEPTH += 1 + if _fp8_state.autocast_depth == 0: + _fp8_state.is_first_fp8_module = True + _fp8_state.autocast_depth += 1 if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() @@ -600,11 +674,11 @@ def autocast_enter( @classmethod def autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" - cls.AUTOCAST_DEPTH -= 1 + _fp8_state.autocast_depth -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. - if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + if enabled and _fp8_state.autocast_depth == 0 and not _graph and torch.is_grad_enabled(): # delayed scaling only function, for other recipes (current scaling with any granularity), # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @@ -627,14 +701,14 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - ] if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) + _fp8_state.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] + if len(_fp8_state.fp8_tensors_recompute_buffer) == 0: + _fp8_state.fp8_tensors_recompute_buffer = [deque()] else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 + _fp8_state.fp8_tensors_recompute_buffer.append(deque()) + _fp8_state.fp8_tensors_recompute_buffer[-1].append(to_copy) + fp8_meta[buffer_position_key] = len(_fp8_state.fp8_tensors_recompute_buffer) - 1 @classmethod def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -651,7 +725,9 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() + stashed_fp8_meta = _fp8_state.fp8_tensors_recompute_buffer[ + fp8_meta[buffer_position_key] + ].popleft() # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) @@ -748,18 +824,18 @@ def quantized_model_init( This functionality is *EXPERIMENTAL*. """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val + _fp8_parameters = _fp8_state.fp8_parameters + _fp8_recipe = _fp8_state.fp8_recipe + _high_precision_init_val = _fp8_state.high_precision_init_val + _fp8_state.fp8_parameters = enabled + _fp8_state.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + _fp8_state.high_precision_init_val = preserve_high_precision_init_val try: yield finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val + _fp8_state.fp8_parameters = _fp8_parameters + _fp8_state.fp8_recipe = _fp8_recipe + _fp8_state.high_precision_init_val = _high_precision_init_val @contextmanager From 31a40762466ade2fe4339b64ff1a8399c989379a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 13 Mar 2026 13:44:02 +0100 Subject: [PATCH 2/6] Remove temporary global state experiment tests. Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file. Signed-off-by: Pawel Gadzinski --- .../pytorch/test_global_dataclass_mutation.py | 155 ------------------ tests/pytorch/test_global_dict_mutation.py | 140 ---------------- 2 files changed, 295 deletions(-) delete mode 100644 tests/pytorch/test_global_dataclass_mutation.py delete mode 100644 tests/pytorch/test_global_dict_mutation.py diff --git a/tests/pytorch/test_global_dataclass_mutation.py b/tests/pytorch/test_global_dataclass_mutation.py deleted file mode 100644 index fb10a6d4a0..0000000000 --- a/tests/pytorch/test_global_dataclass_mutation.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Experiment: can torch.compile handle mutation of a global dataclass? - -Analogous to the global-dict experiment, but uses a dataclass instance -stored as a module-level global instead of a plain dict. - -Parts: - 1. Read a field from the global dataclass, check if recompilation happens - when the field value changes. - 2. Write a Python scalar to a dataclass field inside a compiled function. - 3. Write a Tensor to a dataclass field inside a compiled function. -""" - -from dataclasses import dataclass, field -from typing import Optional - -import torch - - -# --------------------------------------------------------------------------- -# Global dataclass -# --------------------------------------------------------------------------- - -@dataclass -class State: - scale: float = 1.0 - result: Optional[int] = None - tensor_val: Optional[torch.Tensor] = None - - -GLOBAL_STATE = State() - - -# --------------------------------------------------------------------------- -# Functions that access / mutate the global dataclass -# --------------------------------------------------------------------------- - - -def fn_read_dataclass(x: torch.Tensor) -> torch.Tensor: - """Read scale from the global dataclass and multiply x by it.""" - return x * GLOBAL_STATE.scale - - -def fn_write_scalar(x: torch.Tensor, value: int) -> torch.Tensor: - """Write a Python scalar to the global dataclass, return x unchanged.""" - GLOBAL_STATE.result = value - return x + 0 - - -def fn_write_tensor(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - """Write a Tensor to the global dataclass, return x unchanged.""" - GLOBAL_STATE.tensor_val = t - return x + 0 - - -# --------------------------------------------------------------------------- -# Compiled versions -# --------------------------------------------------------------------------- -compiled_read = torch.compile(fn_read_dataclass, fullgraph=False) -compiled_write_scalar = torch.compile(fn_write_scalar, fullgraph=False) -compiled_write_tensor = torch.compile(fn_write_tensor, fullgraph=False) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def reset(): - global GLOBAL_STATE - GLOBAL_STATE = State() - - -def unique_graphs(): - return torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") - - -# --------------------------------------------------------------------------- -# Experiment -# --------------------------------------------------------------------------- - -def run(): - print("=" * 60) - print("Experiment: torch.compile + global dataclass mutation") - print("=" * 60) - - x = torch.tensor([2.0], device="cpu") - - # ----------------------------------------------------------------------- - # Part 1 – reading a field from the global dataclass - # ----------------------------------------------------------------------- - print("\n--- Part 1: reading a field from a global dataclass ---") - reset() - torch._dynamo.reset() - torch._dynamo.utils.counters.clear() - - GLOBAL_STATE.scale = 3.0 - y1 = compiled_read(x) - g1 = unique_graphs() - print(f" GLOBAL_STATE.scale = 3.0 → compiled_read(x) = {y1.item()} (expected {x.item() * 3.0})") - print(f" unique_graphs after 1st call: {g1}") - - GLOBAL_STATE.scale = 5.0 - y2 = compiled_read(x) - g2 = unique_graphs() - print(f" GLOBAL_STATE.scale = 5.0 → compiled_read(x) = {y2.item()} (expected {x.item() * 5.0})") - print(f" unique_graphs after 2nd call: {g2}") - - if g2 != g1: - print(f" NOTE: Dynamo recompiled (graphs: {g1} -> {g2})") - else: - print(f" NOTE: Dynamo did NOT recompile") - - if abs(y2.item() - x.item() * 5.0) < 1e-6: - print(" PASS: result reflects updated dataclass field") - else: - print(" FAIL: result does NOT reflect updated field (guard baked-in old value)") - - # ----------------------------------------------------------------------- - # Part 2 – writing a Python scalar to the dataclass - # ----------------------------------------------------------------------- - print("\n--- Part 2: writing a Python scalar to a dataclass field ---") - reset() - torch._dynamo.reset() - - print(f" GLOBAL_STATE.result before call: {GLOBAL_STATE.result}") - compiled_write_scalar(x, 42) - print(f" GLOBAL_STATE.result after call: {GLOBAL_STATE.result}") - - if GLOBAL_STATE.result == 42: - print(" PASS: dataclass field mutation (scalar) is visible after compiled call") - else: - print(" FAIL: dataclass field mutation (scalar) NOT visible") - - # ----------------------------------------------------------------------- - # Part 3 – writing a Tensor to the dataclass - # ----------------------------------------------------------------------- - print("\n--- Part 3: writing a Tensor to a dataclass field ---") - reset() - torch._dynamo.reset() - - t = torch.tensor(99.0) - print(f" GLOBAL_STATE.tensor_val before call: {GLOBAL_STATE.tensor_val}") - compiled_write_tensor(x, t) - print(f" GLOBAL_STATE.tensor_val after call: {GLOBAL_STATE.tensor_val}") - - if GLOBAL_STATE.tensor_val is not None: - print(" PASS: dataclass field mutation (Tensor) is visible after compiled call") - else: - print(" FAIL: dataclass field mutation (Tensor) NOT visible") - - print("\nDone.") - - -if __name__ == "__main__": - run() diff --git a/tests/pytorch/test_global_dict_mutation.py b/tests/pytorch/test_global_dict_mutation.py deleted file mode 100644 index 27c8e9d702..0000000000 --- a/tests/pytorch/test_global_dict_mutation.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -Experiment: can torch.compile handle mutation of a global dictionary? - -We test two scenarios: - 1. A compiled function that reads from a global dict. - 2. A compiled function that writes (mutates) a global dict. - -In both cases we check whether recompilation or graph breaks occur, and -whether the results are numerically correct. -""" - -import torch - -# --------------------------------------------------------------------------- -# Global state -# --------------------------------------------------------------------------- -GLOBAL_DICT: dict = {} - -# --------------------------------------------------------------------------- -# Functions that access / mutate the global dict -# --------------------------------------------------------------------------- - - -def fn_read_global(x: torch.Tensor) -> torch.Tensor: - """Read a scale factor stored in a global dict and multiply x by it.""" - scale = GLOBAL_DICT.get("scale", 1.0) - return x * scale - - -def fn_write_global(x: torch.Tensor, key: str, value) -> torch.Tensor: - """Write a value into the global dict, then return x unchanged.""" - GLOBAL_DICT[key] = value - return x + 0 # trivial op so there is a tensor output - - -# --------------------------------------------------------------------------- -# Compiled versions -# --------------------------------------------------------------------------- -compiled_read = torch.compile(fn_read_global, fullgraph=False) -compiled_write = torch.compile(fn_write_global, fullgraph=False) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def reset(): - global GLOBAL_DICT - GLOBAL_DICT = {} - - -def count_recompilations(fn): - """Return the number of frames that have been compiled so far.""" - # torch._dynamo.explain() gives per-call stats; we use the simpler - # guard cache size as a proxy. - try: - return torch._dynamo.utils.counters["stats"]["unique_graphs"] - except Exception: - return None - - -# --------------------------------------------------------------------------- -# Experiment -# --------------------------------------------------------------------------- - -def run(): - print("=" * 60) - print("Experiment: torch.compile + global dict mutation") - print("=" * 60) - - x = torch.tensor([2.0], device="cpu") - - # ----------------------------------------------------------------------- - # Part 1 – reading from the global dict - # ----------------------------------------------------------------------- - print("\n--- Part 1: reading from a global dict ---") - reset() - torch._dynamo.reset() - torch._dynamo.utils.counters.clear() - - GLOBAL_DICT["scale"] = 3.0 - y1 = compiled_read(x) - graphs_after_first = torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") - print(f" GLOBAL_DICT = {GLOBAL_DICT}") - print(f" compiled_read(x) = {y1.item()} (expected {x.item() * 3.0})") - print(f" unique_graphs after 1st call: {graphs_after_first}") - - # Change the dict value and call again – should Dynamo pick up the change? - GLOBAL_DICT["scale"] = 5.0 - y2 = compiled_read(x) - graphs_after_second = torch._dynamo.utils.counters["stats"].get("unique_graphs", "?") - print(f" After mutating scale to 5.0:") - print(f" compiled_read(x) = {y2.item()} (expected {x.item() * 5.0})") - print(f" unique_graphs after 2nd call: {graphs_after_second}") - - if graphs_after_second != graphs_after_first: - print(f" NOTE: Dynamo recompiled (graphs: {graphs_after_first} -> {graphs_after_second})") - else: - print(f" NOTE: Dynamo did NOT recompile (same graph count)") - - if abs(y2.item() - x.item() * 5.0) < 1e-6: - print(" PASS: result reflects updated dict value") - else: - print(" FAIL: result does NOT reflect updated dict value (guard baked-in old value)") - - # ----------------------------------------------------------------------- - # Part 2 – writing / mutating the global dict inside the compiled fn - # ----------------------------------------------------------------------- - print("\n--- Part 2: writing into a global dict ---") - reset() - torch._dynamo.reset() - - print(f" GLOBAL_DICT before call: {GLOBAL_DICT}") - compiled_write(x, "result", 42) - print(f" GLOBAL_DICT after call: {GLOBAL_DICT}") - - if GLOBAL_DICT.get("result") == 42: - print(" PASS: global dict mutation is visible after compiled call") - else: - print(" FAIL: global dict mutation is NOT visible (side-effect was dropped)") - - # ----------------------------------------------------------------------- - # Part 3 – mutation of value that is a Tensor - # ----------------------------------------------------------------------- - print("\n--- Part 3: storing a Tensor into the global dict ---") - reset() - torch._dynamo.reset() - - compiled_write(x, "tensor_val", torch.tensor(99.0)) - print(f" GLOBAL_DICT after call: {GLOBAL_DICT}") - if "tensor_val" in GLOBAL_DICT: - print(" PASS: tensor stored in global dict is visible") - else: - print(" FAIL: tensor NOT stored") - - print("\nDone.") - - -if __name__ == "__main__": - run() From b5d46fdf2e2bbde4609597d39fa9175345284173 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 13 Mar 2026 13:59:58 +0100 Subject: [PATCH 3/6] Clean up FP8 global state naming. Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/quantization.py | 167 +++++++++++---------- 1 file changed, 91 insertions(+), 76 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 7e6bdcd5b6..2c00460a83 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -89,6 +89,7 @@ def _compute_fp8_block_scaling_support() -> Tuple[bool, str]: ) +@torch.compiler.assume_constant_result def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available.""" global _FP8_SUPPORT @@ -97,6 +98,7 @@ def check_fp8_support() -> Tuple[bool, str]: return _FP8_SUPPORT +@torch.compiler.assume_constant_result def check_mxfp8_support() -> Tuple[bool, str]: """Return if MXFP8 support is available.""" global _MXFP8_SUPPORT @@ -105,6 +107,7 @@ def check_mxfp8_support() -> Tuple[bool, str]: return _MXFP8_SUPPORT +@torch.compiler.assume_constant_result def check_nvfp4_support() -> Tuple[bool, str]: """Return if NVFP4 support is available.""" global _NVFP4_SUPPORT @@ -113,6 +116,7 @@ def check_nvfp4_support() -> Tuple[bool, str]: return _NVFP4_SUPPORT +@torch.compiler.assume_constant_result def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available.""" global _FP8_BLOCK_SCALING_SUPPORT @@ -309,7 +313,7 @@ class FP8GlobalState: reason_for_no_nvfp4: str = "" -_fp8_state = FP8GlobalState() +_FP8_GLOBAL_STATE = FP8GlobalState() class FP8GlobalStateManager: @@ -320,22 +324,22 @@ class FP8GlobalStateManager: @classmethod def reset(cls) -> None: """Reset the global state""" - global _fp8_state - _fp8_state = FP8GlobalState() + global _FP8_GLOBAL_STATE + _FP8_GLOBAL_STATE = FP8GlobalState() @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: """`skip_fp8_weight_update_tensor` inplace setter.""" - if _fp8_state.skip_fp8_weight_update_tensor is None: - _fp8_state.skip_fp8_weight_update_tensor = torch.empty( + if _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor is None: + _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor = torch.empty( 1, dtype=torch.float32, device="cuda" ) - _fp8_state.skip_fp8_weight_update_tensor.fill_(skip) + _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor.fill_(skip) @classmethod def get_skip_fp8_weight_update_tensor(cls) -> None: """`skip_fp8_weight_update_tensor` getter.""" - return _fp8_state.skip_fp8_weight_update_tensor + return _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -437,92 +441,96 @@ def add_fp8_tensors_to_global_buffer( key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - if key not in _fp8_state.global_amax_buffer: - _fp8_state.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - _fp8_state.global_amax_history_buffer[key] = [ + if key not in _FP8_GLOBAL_STATE.global_amax_buffer: + _FP8_GLOBAL_STATE.global_amax_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].amax_history[0] + ] + _FP8_GLOBAL_STATE.global_amax_history_buffer[key] = [ fp8_meta[fp8_meta_tensor_key].amax_history ] - _fp8_state.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + _FP8_GLOBAL_STATE.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] else: - _fp8_state.global_amax_buffer[key].append( + _FP8_GLOBAL_STATE.global_amax_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history[0] ) - _fp8_state.global_amax_history_buffer[key].append( + _FP8_GLOBAL_STATE.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) - _fp8_state.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(_fp8_state.global_amax_buffer[key]) - 1) + _FP8_GLOBAL_STATE.global_scale_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].scale + ) + fp8_meta[index_in_buffer].append(len(_FP8_GLOBAL_STATE.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" - return _fp8_state.fp8_enabled + return _FP8_GLOBAL_STATE.fp8_enabled @classmethod def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" - return _fp8_state.fp8_calibration + return _FP8_GLOBAL_STATE.fp8_calibration @classmethod def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" - return _fp8_state.fp8_parameters + return _FP8_GLOBAL_STATE.fp8_parameters @classmethod def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" - return _fp8_state.high_precision_init_val + return _FP8_GLOBAL_STATE.high_precision_init_val @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" if torch.compiler.is_compiling(): - assert not _fp8_state.fp8_graph_capturing + assert not _FP8_GLOBAL_STATE.fp8_graph_capturing return False - return _fp8_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() + return _FP8_GLOBAL_STATE.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple times from within the same `autocast` context. """ - tmp = _fp8_state.is_first_fp8_module - _fp8_state.is_first_fp8_module = False + tmp = _FP8_GLOBAL_STATE.is_first_fp8_module + _FP8_GLOBAL_STATE.is_first_fp8_module = False return tmp @classmethod def peek_is_first_fp8_module(cls) -> bool: """Return the current first-module flag without consuming it.""" - return _fp8_state.is_first_fp8_module + return _FP8_GLOBAL_STATE.is_first_fp8_module @classmethod def set_is_first_fp8_module(cls, value: bool) -> None: """Set the first-module flag.""" - _fp8_state.is_first_fp8_module = value + _FP8_GLOBAL_STATE.is_first_fp8_module = value @classmethod def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" - if _fp8_state.fp8_recipe is not None: - return _fp8_state.fp8_recipe + if _FP8_GLOBAL_STATE.fp8_recipe is not None: + return _FP8_GLOBAL_STATE.fp8_recipe return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: """Return the fp8 group for scale/amax comm""" - return _fp8_state.fp8_distributed_group + return _FP8_GLOBAL_STATE.fp8_distributed_group @classmethod def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( - _fp8_state.fp8_enabled, - _fp8_state.fp8_calibration, - _fp8_state.fp8_recipe, - _fp8_state.fp8_distributed_group, - _fp8_state.is_first_fp8_module, - _fp8_state.fp8_graph_capturing, + _FP8_GLOBAL_STATE.fp8_enabled, + _FP8_GLOBAL_STATE.fp8_calibration, + _FP8_GLOBAL_STATE.fp8_recipe, + _FP8_GLOBAL_STATE.fp8_distributed_group, + _FP8_GLOBAL_STATE.is_first_fp8_module, + _FP8_GLOBAL_STATE.fp8_graph_capturing, ) @classmethod @@ -531,28 +539,28 @@ def set_autocast_state( ) -> None: """FP8 autocast state setter""" ( - _fp8_state.fp8_enabled, - _fp8_state.fp8_calibration, - _fp8_state.fp8_recipe, - _fp8_state.fp8_distributed_group, - _fp8_state.is_first_fp8_module, - _fp8_state.fp8_graph_capturing, + _FP8_GLOBAL_STATE.fp8_enabled, + _FP8_GLOBAL_STATE.fp8_calibration, + _FP8_GLOBAL_STATE.fp8_recipe, + _FP8_GLOBAL_STATE.fp8_distributed_group, + _FP8_GLOBAL_STATE.is_first_fp8_module, + _FP8_GLOBAL_STATE.fp8_graph_capturing, ) = fp8_state @classmethod def get_global_amax_buffer(cls) -> Dict[str, list]: """Return the global amax buffer.""" - return _fp8_state.global_amax_buffer + return _FP8_GLOBAL_STATE.global_amax_buffer @classmethod def get_global_amax_history_buffer(cls) -> Dict[str, list]: """Return the global amax history buffer.""" - return _fp8_state.global_amax_history_buffer + return _FP8_GLOBAL_STATE.global_amax_history_buffer @classmethod def get_autocast_arguments(cls) -> Dict[Any, Tuple[Recipe, Optional[dist_group_type]]]: """Return autocast arguments keyed by autocast identity.""" - return _fp8_state.autocast_arguments + return _FP8_GLOBAL_STATE.autocast_arguments @staticmethod def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: @@ -572,7 +580,7 @@ def reduce_and_update_fp8_tensors( ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in _fp8_state.global_amax_buffer.items(): + for buffer_key, amax_buffer in _FP8_GLOBAL_STATE.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -581,7 +589,7 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args and concat amaxes. - recipe, group = _fp8_state.autocast_arguments[autocast_key] + recipe, group = _FP8_GLOBAL_STATE.autocast_arguments[autocast_key] contiguous_amax = torch.cat(amax_buffer) # Reduction. @@ -602,8 +610,8 @@ def reduce_and_update_fp8_tensors( if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, - _fp8_state.global_amax_history_buffer[buffer_key], - _fp8_state.global_scale_buffer[buffer_key], + _FP8_GLOBAL_STATE.global_amax_history_buffer[buffer_key], + _FP8_GLOBAL_STATE.global_scale_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -612,8 +620,8 @@ def reduce_and_update_fp8_tensors( split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) for amax_history, scale in zip( - _fp8_state.global_amax_history_buffer[buffer_key], - _fp8_state.global_scale_buffer[buffer_key], + _FP8_GLOBAL_STATE.global_amax_history_buffer[buffer_key], + _FP8_GLOBAL_STATE.global_scale_buffer[buffer_key], ): _amax_and_scale_update( amax_history, scale, get_fp8_max(recipe, forward), recipe @@ -646,17 +654,17 @@ def autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - _fp8_state.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + _FP8_GLOBAL_STATE.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - _fp8_state.fp8_enabled = enabled - _fp8_state.fp8_calibration = calibrating - _fp8_state.fp8_recipe = fp8_recipe - _fp8_state.fp8_distributed_group = fp8_group - _fp8_state.fp8_graph_capturing = _graph + _FP8_GLOBAL_STATE.fp8_enabled = enabled + _FP8_GLOBAL_STATE.fp8_calibration = calibrating + _FP8_GLOBAL_STATE.fp8_recipe = fp8_recipe + _FP8_GLOBAL_STATE.fp8_distributed_group = fp8_group + _FP8_GLOBAL_STATE.fp8_graph_capturing = _graph - if _fp8_state.autocast_depth == 0: - _fp8_state.is_first_fp8_module = True - _fp8_state.autocast_depth += 1 + if _FP8_GLOBAL_STATE.autocast_depth == 0: + _FP8_GLOBAL_STATE.is_first_fp8_module = True + _FP8_GLOBAL_STATE.autocast_depth += 1 if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() @@ -674,11 +682,16 @@ def autocast_enter( @classmethod def autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" - _fp8_state.autocast_depth -= 1 + _FP8_GLOBAL_STATE.autocast_depth -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. - if enabled and _fp8_state.autocast_depth == 0 and not _graph and torch.is_grad_enabled(): + if ( + enabled + and _FP8_GLOBAL_STATE.autocast_depth == 0 + and not _graph + and torch.is_grad_enabled() + ): # delayed scaling only function, for other recipes (current scaling with any granularity), # this is noop for other recipes because cls.global_amax_buffer is empty list cls.reduce_and_update_fp8_tensors(forward=True) @@ -701,14 +714,16 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - ] if buffer_position_key in fp8_meta: - _fp8_state.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) + _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[ + fp8_meta[buffer_position_key] + ].append(to_copy) else: - if len(_fp8_state.fp8_tensors_recompute_buffer) == 0: - _fp8_state.fp8_tensors_recompute_buffer = [deque()] + if len(_FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer) == 0: + _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer = [deque()] else: - _fp8_state.fp8_tensors_recompute_buffer.append(deque()) - _fp8_state.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(_fp8_state.fp8_tensors_recompute_buffer) - 1 + _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer.append(deque()) + _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[-1].append(to_copy) + fp8_meta[buffer_position_key] = len(_FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer) - 1 @classmethod def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -725,7 +740,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = _fp8_state.fp8_tensors_recompute_buffer[ + stashed_fp8_meta = _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[ fp8_meta[buffer_position_key] ].popleft() @@ -824,18 +839,18 @@ def quantized_model_init( This functionality is *EXPERIMENTAL*. """ - _fp8_parameters = _fp8_state.fp8_parameters - _fp8_recipe = _fp8_state.fp8_recipe - _high_precision_init_val = _fp8_state.high_precision_init_val - _fp8_state.fp8_parameters = enabled - _fp8_state.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - _fp8_state.high_precision_init_val = preserve_high_precision_init_val + _fp8_parameters = _FP8_GLOBAL_STATE.fp8_parameters + _fp8_recipe = _FP8_GLOBAL_STATE.fp8_recipe + _high_precision_init_val = _FP8_GLOBAL_STATE.high_precision_init_val + _FP8_GLOBAL_STATE.fp8_parameters = enabled + _FP8_GLOBAL_STATE.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + _FP8_GLOBAL_STATE.high_precision_init_val = preserve_high_precision_init_val try: yield finally: - _fp8_state.fp8_parameters = _fp8_parameters - _fp8_state.fp8_recipe = _fp8_recipe - _fp8_state.high_precision_init_val = _high_precision_init_val + _FP8_GLOBAL_STATE.fp8_parameters = _fp8_parameters + _FP8_GLOBAL_STATE.fp8_recipe = _fp8_recipe + _FP8_GLOBAL_STATE.high_precision_init_val = _high_precision_init_val @contextmanager From ab3659adbde6cd3e7c9d591991341ec54398dadb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 13 Mar 2026 14:24:11 +0100 Subject: [PATCH 4/6] Minimize FP8 global state diff. Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state. Signed-off-by: Pawel Gadzinski --- .../dot_product_attention.py | 4 +- transformer_engine/pytorch/distributed.py | 4 +- transformer_engine/pytorch/graph.py | 10 +- transformer_engine/pytorch/module/base.py | 18 +- .../pytorch/module/layernorm_linear.py | 8 +- .../pytorch/module/layernorm_mlp.py | 42 +++- transformer_engine/pytorch/module/linear.py | 8 +- transformer_engine/pytorch/ops/op.py | 18 +- transformer_engine/pytorch/quantization.py | 203 +++++++----------- 9 files changed, 158 insertions(+), 157 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index c65c047510..337da758d4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -704,7 +704,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.get_autocast_arguments()[autocast_key] = ( + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) @@ -736,7 +736,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: autocast_key = FP8GlobalStateManager.get_unique_autocast_key( fp8_recipe_dpa, fp8_group ) - FP8GlobalStateManager.get_autocast_arguments()[autocast_key] = ( + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( fp8_recipe_dpa, fp8_group, ) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c194b66fe1..f85ec3b653 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -255,10 +255,10 @@ def __enter__(self): if self.activation_recompute and not self.recompute_phase: activation_recompute_forward._is_first_fp8_module.append( - FP8GlobalStateManager.peek_is_first_fp8_module() + FP8GlobalStateManager.quantization_state.is_first_fp8_module ) if self.activation_recompute and self.recompute_phase: - FP8GlobalStateManager.set_is_first_fp8_module( + FP8GlobalStateManager.quantization_state.is_first_fp8_module = ( activation_recompute_forward._is_first_fp8_module.pop(0) ) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f4b1fb23ae..63f7f49de8 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -293,7 +293,11 @@ def _make_graphed_callables( if cache_quantized_params: # Initialize flag that controls FP8 weight updates - FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + if FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor is None: + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device="cuda" + ) + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_(False) # Check callables for c in callables: @@ -788,7 +792,9 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: - FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_( + skip_fp8_weight_update + ) # Copy values from new tensors into static tensors for i in range(len_user_args): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c7266ad23b..7c67bb5c93 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -693,16 +693,22 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> FP8GlobalStateManager.get_buffer_info() ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - if buffer_key in FP8GlobalStateManager.get_global_amax_buffer(): + if ( + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_buffer + ): assert ( - buffer_key in FP8GlobalStateManager.get_global_amax_history_buffer() + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.get_global_amax_buffer()[buffer_key][pos] = ( + FP8GlobalStateManager.quantization_state.global_amax_buffer[ + buffer_key + ][pos] = ( self.fp8_meta[meta_key].amax_history[0] ) - FP8GlobalStateManager.get_global_amax_history_buffer()[buffer_key][pos] = ( - self.fp8_meta[meta_key].amax_history - ) + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ + buffer_key + ][pos] = self.fp8_meta[meta_key].amax_history def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 20615aa5f6..2dd0337ea9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -516,10 +516,10 @@ def forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): - _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug @@ -1490,7 +1490,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8129e9b2e2..156d7fd9c7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -242,7 +242,12 @@ def _forward( if checkpoint: # save the state of autocast and quantizers for recomputation ctx.autocast_state = ( - FP8GlobalStateManager.get_autocast_state() + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, ) # to restore autocast state during recomputation if ( fp8 @@ -832,10 +837,12 @@ def _forward( if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): - _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase() or is_recomputation: - FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) + FP8GlobalStateManager.quantization_state.is_first_fp8_module = ( + _first_fp8_module + ) ctx.wgrad_store = wgrad_store if is_recomputation: # return the recomputed tensors @@ -910,9 +917,21 @@ def _recompute(ctx): # backward is not in autocast context, so we set the state here # we also have to set the quantizer states to what they were before the forward pass (only relevant for DelayedScaling recipe) final_autocast_state = ( - FP8GlobalStateManager.get_autocast_state() + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, ) # get current autocast state - FP8GlobalStateManager.set_autocast_state(ctx.autocast_state) # set old autocast state + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = ctx.autocast_state # set old autocast state if ( ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" @@ -935,7 +954,14 @@ def _recompute(ctx): tuple(ctx.other_args.values()), ) - FP8GlobalStateManager.set_autocast_state(final_autocast_state) # restore autocast state + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = final_autocast_state # restore autocast state if ( ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" @@ -2046,7 +2072,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 28710eee14..2e572cbc46 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -480,10 +480,10 @@ def forward( ctx.owns_input = saved_inputmat is not inp if ctx.fp8 and requires_grad(inp, weight, bias): - _first_fp8_module = FP8GlobalStateManager.peek_is_first_fp8_module() + _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): - FP8GlobalStateManager.set_is_first_fp8_module(_first_fp8_module) + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store # ------------------------------------------------------ @@ -1377,7 +1377,9 @@ def forward( debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) else: skip_fp8_weight_update = None if skip_fp8_weight_update is not None: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index f79eab35bf..c80953c069 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -322,16 +322,22 @@ def reset_recipe_state( pos, buffer_key = self._fp8_metas[mode][ FP8GlobalStateManager.get_buffer_info() ] - if buffer_key in FP8GlobalStateManager.get_global_amax_buffer(): + if ( + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_buffer + ): assert ( - buffer_key in FP8GlobalStateManager.get_global_amax_history_buffer() + buffer_key + in FP8GlobalStateManager.quantization_state.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.get_global_amax_buffer()[buffer_key][pos] = ( + FP8GlobalStateManager.quantization_state.global_amax_buffer[ + buffer_key + ][pos] = ( recipe_state.amax_history[0] ) - FP8GlobalStateManager.get_global_amax_history_buffer()[buffer_key][ - pos - ] = recipe_state.amax_history + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ + buffer_key + ][pos] = recipe_state.amax_history # Add meta tensors to global buffer to participate in reduction for mode in ("forward", "backward"): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 2c00460a83..43832717ae 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -312,34 +312,17 @@ class FP8GlobalState: nvfp4_available: Optional[bool] = None reason_for_no_nvfp4: str = "" - -_FP8_GLOBAL_STATE = FP8GlobalState() - - class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. """ + quantization_state = FP8GlobalState() + @classmethod def reset(cls) -> None: """Reset the global state""" - global _FP8_GLOBAL_STATE - _FP8_GLOBAL_STATE = FP8GlobalState() - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor is None: - _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor = torch.empty( - 1, dtype=torch.float32, device="cuda" - ) - _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return _FP8_GLOBAL_STATE.skip_fp8_weight_update_tensor + cls.quantization_state = FP8GlobalState() @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -441,126 +424,79 @@ def add_fp8_tensors_to_global_buffer( key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - if key not in _FP8_GLOBAL_STATE.global_amax_buffer: - _FP8_GLOBAL_STATE.global_amax_buffer[key] = [ + if key not in FP8GlobalStateManager.quantization_state.global_amax_buffer: + FP8GlobalStateManager.quantization_state.global_amax_buffer[key] = [ fp8_meta[fp8_meta_tensor_key].amax_history[0] ] - _FP8_GLOBAL_STATE.global_amax_history_buffer[key] = [ + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[key] = [ fp8_meta[fp8_meta_tensor_key].amax_history ] - _FP8_GLOBAL_STATE.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + FP8GlobalStateManager.quantization_state.global_scale_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key].scale + ] else: - _FP8_GLOBAL_STATE.global_amax_buffer[key].append( + FP8GlobalStateManager.quantization_state.global_amax_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history[0] ) - _FP8_GLOBAL_STATE.global_amax_history_buffer[key].append( + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) - _FP8_GLOBAL_STATE.global_scale_buffer[key].append( + FP8GlobalStateManager.quantization_state.global_scale_buffer[key].append( fp8_meta[fp8_meta_tensor_key].scale ) - fp8_meta[index_in_buffer].append(len(_FP8_GLOBAL_STATE.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append( + len(FP8GlobalStateManager.quantization_state.global_amax_buffer[key]) - 1 + ) fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" - return _FP8_GLOBAL_STATE.fp8_enabled + return cls.quantization_state.fp8_enabled @classmethod def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" - return _FP8_GLOBAL_STATE.fp8_calibration + return cls.quantization_state.fp8_calibration @classmethod def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" - return _FP8_GLOBAL_STATE.fp8_parameters + return cls.quantization_state.fp8_parameters @classmethod def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" - return _FP8_GLOBAL_STATE.high_precision_init_val + return cls.quantization_state.high_precision_init_val @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" if torch.compiler.is_compiling(): - assert not _FP8_GLOBAL_STATE.fp8_graph_capturing + assert not cls.quantization_state.fp8_graph_capturing return False - return _FP8_GLOBAL_STATE.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() + return cls.quantization_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple times from within the same `autocast` context. """ - tmp = _FP8_GLOBAL_STATE.is_first_fp8_module - _FP8_GLOBAL_STATE.is_first_fp8_module = False + tmp = cls.quantization_state.is_first_fp8_module + cls.quantization_state.is_first_fp8_module = False return tmp - @classmethod - def peek_is_first_fp8_module(cls) -> bool: - """Return the current first-module flag without consuming it.""" - return _FP8_GLOBAL_STATE.is_first_fp8_module - - @classmethod - def set_is_first_fp8_module(cls, value: bool) -> None: - """Set the first-module flag.""" - _FP8_GLOBAL_STATE.is_first_fp8_module = value - @classmethod def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" - if _FP8_GLOBAL_STATE.fp8_recipe is not None: - return _FP8_GLOBAL_STATE.fp8_recipe + if cls.quantization_state.fp8_recipe is not None: + return cls.quantization_state.fp8_recipe return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: """Return the fp8 group for scale/amax comm""" - return _FP8_GLOBAL_STATE.fp8_distributed_group - - @classmethod - def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - _FP8_GLOBAL_STATE.fp8_enabled, - _FP8_GLOBAL_STATE.fp8_calibration, - _FP8_GLOBAL_STATE.fp8_recipe, - _FP8_GLOBAL_STATE.fp8_distributed_group, - _FP8_GLOBAL_STATE.is_first_fp8_module, - _FP8_GLOBAL_STATE.fp8_graph_capturing, - ) - - @classmethod - def set_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - _FP8_GLOBAL_STATE.fp8_enabled, - _FP8_GLOBAL_STATE.fp8_calibration, - _FP8_GLOBAL_STATE.fp8_recipe, - _FP8_GLOBAL_STATE.fp8_distributed_group, - _FP8_GLOBAL_STATE.is_first_fp8_module, - _FP8_GLOBAL_STATE.fp8_graph_capturing, - ) = fp8_state - - @classmethod - def get_global_amax_buffer(cls) -> Dict[str, list]: - """Return the global amax buffer.""" - return _FP8_GLOBAL_STATE.global_amax_buffer - - @classmethod - def get_global_amax_history_buffer(cls) -> Dict[str, list]: - """Return the global amax history buffer.""" - return _FP8_GLOBAL_STATE.global_amax_history_buffer - - @classmethod - def get_autocast_arguments(cls) -> Dict[Any, Tuple[Recipe, Optional[dist_group_type]]]: - """Return autocast arguments keyed by autocast identity.""" - return _FP8_GLOBAL_STATE.autocast_arguments + return cls.quantization_state.fp8_distributed_group @staticmethod def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: @@ -580,7 +516,7 @@ def reduce_and_update_fp8_tensors( ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in _FP8_GLOBAL_STATE.global_amax_buffer.items(): + for buffer_key, amax_buffer in FP8GlobalStateManager.quantization_state.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -589,7 +525,7 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args and concat amaxes. - recipe, group = _FP8_GLOBAL_STATE.autocast_arguments[autocast_key] + recipe, group = FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] contiguous_amax = torch.cat(amax_buffer) # Reduction. @@ -610,8 +546,8 @@ def reduce_and_update_fp8_tensors( if not unfused_update: tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, - _FP8_GLOBAL_STATE.global_amax_history_buffer[buffer_key], - _FP8_GLOBAL_STATE.global_scale_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_scale_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -620,8 +556,8 @@ def reduce_and_update_fp8_tensors( split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) for amax_history, scale in zip( - _FP8_GLOBAL_STATE.global_amax_history_buffer[buffer_key], - _FP8_GLOBAL_STATE.global_scale_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_amax_history_buffer[buffer_key], + FP8GlobalStateManager.quantization_state.global_scale_buffer[buffer_key], ): _amax_and_scale_update( amax_history, scale, get_fp8_max(recipe, forward), recipe @@ -654,17 +590,17 @@ def autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - _FP8_GLOBAL_STATE.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - _FP8_GLOBAL_STATE.fp8_enabled = enabled - _FP8_GLOBAL_STATE.fp8_calibration = calibrating - _FP8_GLOBAL_STATE.fp8_recipe = fp8_recipe - _FP8_GLOBAL_STATE.fp8_distributed_group = fp8_group - _FP8_GLOBAL_STATE.fp8_graph_capturing = _graph + FP8GlobalStateManager.quantization_state.fp8_enabled = enabled + FP8GlobalStateManager.quantization_state.fp8_calibration = calibrating + FP8GlobalStateManager.quantization_state.fp8_recipe = fp8_recipe + FP8GlobalStateManager.quantization_state.fp8_distributed_group = fp8_group + FP8GlobalStateManager.quantization_state.fp8_graph_capturing = _graph - if _FP8_GLOBAL_STATE.autocast_depth == 0: - _FP8_GLOBAL_STATE.is_first_fp8_module = True - _FP8_GLOBAL_STATE.autocast_depth += 1 + if FP8GlobalStateManager.quantization_state.autocast_depth == 0: + FP8GlobalStateManager.quantization_state.is_first_fp8_module = True + FP8GlobalStateManager.quantization_state.autocast_depth += 1 if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() @@ -682,13 +618,13 @@ def autocast_enter( @classmethod def autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" - _FP8_GLOBAL_STATE.autocast_depth -= 1 + FP8GlobalStateManager.quantization_state.autocast_depth -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. if ( enabled - and _FP8_GLOBAL_STATE.autocast_depth == 0 + and FP8GlobalStateManager.quantization_state.autocast_depth == 0 and not _graph and torch.is_grad_enabled() ): @@ -714,16 +650,18 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - ] if buffer_position_key in fp8_meta: - _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[ + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[ fp8_meta[buffer_position_key] ].append(to_copy) else: - if len(_FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer) == 0: - _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer = [deque()] + if len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) == 0: + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer = [deque()] else: - _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer.append(deque()) - _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(_FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer) - 1 + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer.append(deque()) + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[-1].append(to_copy) + fp8_meta[buffer_position_key] = ( + len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) - 1 + ) @classmethod def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -740,7 +678,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = _FP8_GLOBAL_STATE.fp8_tensors_recompute_buffer[ + stashed_fp8_meta = FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[ fp8_meta[buffer_position_key] ].popleft() @@ -758,7 +696,6 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - @contextmanager def fp8_model_init( enabled: bool = True, @@ -839,18 +776,18 @@ def quantized_model_init( This functionality is *EXPERIMENTAL*. """ - _fp8_parameters = _FP8_GLOBAL_STATE.fp8_parameters - _fp8_recipe = _FP8_GLOBAL_STATE.fp8_recipe - _high_precision_init_val = _FP8_GLOBAL_STATE.high_precision_init_val - _FP8_GLOBAL_STATE.fp8_parameters = enabled - _FP8_GLOBAL_STATE.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - _FP8_GLOBAL_STATE.high_precision_init_val = preserve_high_precision_init_val + _fp8_parameters = FP8GlobalStateManager.quantization_state.fp8_parameters + _fp8_recipe = FP8GlobalStateManager.quantization_state.fp8_recipe + _high_precision_init_val = FP8GlobalStateManager.quantization_state.high_precision_init_val + FP8GlobalStateManager.quantization_state.fp8_parameters = enabled + FP8GlobalStateManager.quantization_state.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + FP8GlobalStateManager.quantization_state.high_precision_init_val = preserve_high_precision_init_val try: yield finally: - _FP8_GLOBAL_STATE.fp8_parameters = _fp8_parameters - _FP8_GLOBAL_STATE.fp8_recipe = _fp8_recipe - _FP8_GLOBAL_STATE.high_precision_init_val = _high_precision_init_val + FP8GlobalStateManager.quantization_state.fp8_parameters = _fp8_parameters + FP8GlobalStateManager.quantization_state.fp8_recipe = _fp8_recipe + FP8GlobalStateManager.quantization_state.high_precision_init_val = _high_precision_init_val @contextmanager @@ -937,7 +874,14 @@ def autocast( check_recipe_support(recipe) # Save current state so we always restore it on exit. - fp8_state = FP8GlobalStateManager.get_autocast_state() + fp8_state = ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) FP8GlobalStateManager.autocast_enter( enabled=enabled, @@ -949,7 +893,14 @@ def autocast( try: yield finally: - FP8GlobalStateManager.set_autocast_state(fp8_state) + ( + FP8GlobalStateManager.quantization_state.fp8_enabled, + FP8GlobalStateManager.quantization_state.fp8_calibration, + FP8GlobalStateManager.quantization_state.fp8_recipe, + FP8GlobalStateManager.quantization_state.fp8_distributed_group, + FP8GlobalStateManager.quantization_state.is_first_fp8_module, + FP8GlobalStateManager.quantization_state.fp8_graph_capturing, + ) = fp8_state FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) From 75ab460662e41d22fcd6c9bd49790a8592c026a0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 13 Mar 2026 15:04:00 +0100 Subject: [PATCH 5/6] Remove unused FP8 state fields. Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/quantization.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 43832717ae..59fba31471 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -299,18 +299,10 @@ class FP8GlobalState: global_amax_history_buffer: Dict[str, list] = field(default_factory=dict) global_scale_buffer: Dict[str, list] = field(default_factory=dict) fp8_tensors_recompute_buffer: list = field(default_factory=list) - fp8_available: Optional[bool] = None - reason_for_no_fp8: str = "" autocast_arguments: Dict[Any, Tuple[Recipe, Optional[dist_group_type]]] = field( default_factory=dict ) skip_fp8_weight_update_tensor: Optional[torch.Tensor] = None - mxfp8_available: Optional[bool] = None - reason_for_no_mxfp8: str = "" - fp8_block_scaling_available: Optional[bool] = None - reason_for_no_fp8_block_scaling: str = "" - nvfp4_available: Optional[bool] = None - reason_for_no_nvfp4: str = "" class FP8GlobalStateManager: """Class to keep track of and manipulate the global From 16ea6e7fb9d200ccc8ca38af2a4659be72b45941 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:08:03 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_torch_compile.py | 11 +++-- transformer_engine/pytorch/module/base.py | 6 +-- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/ops/op.py | 8 ++-- transformer_engine/pytorch/quantization.py | 41 ++++++++++++++----- 5 files changed, 43 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 20140c0a0c..07d27f4010 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -98,10 +98,9 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: # Opaque custom op (torch.library) # --------------------------------------------------------------------------- + @torch.library.custom_op("test_te::toy_linear", mutates_args=[]) -def _toy_linear_fwd_op( - inp: torch.Tensor, weight: torch.Tensor, module_idx: int -) -> torch.Tensor: +def _toy_linear_fwd_op(inp: torch.Tensor, weight: torch.Tensor, module_idx: int) -> torch.Tensor: """Forward GEMM wrapped as an opaque custom op.""" module = _toy_modules[module_idx] input_q, weight_q = module.get_forward_quantizers() @@ -288,12 +287,12 @@ def test_autocast_nested(): inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True) def fn(inp): - with te.autocast(recipe=recipe_current0): # outer + with te.autocast(recipe=recipe_current0): # outer out = m0(inp) - with te.autocast(recipe=recipe_current1): # nested inside outer + with te.autocast(recipe=recipe_current1): # nested inside outer out = m1(out) - with te.autocast(recipe=recipe_current2): # separate, after nested pair + with te.autocast(recipe=recipe_current2): # separate, after nested pair out = m2(out) return out diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7e66272c94..542c96565c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -739,9 +739,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ buffer_key ][pos] = self.fp8_meta[meta_key].amax_history - FP8GlobalStateManager.quantization_state.global_amax_buffer[ - buffer_key - ][pos] = self.fp8_meta[meta_key].amax_history[0] + FP8GlobalStateManager.quantization_state.global_amax_buffer[buffer_key][ + pos + ] = self.fp8_meta[meta_key].amax_history[0] def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b94b74aa36..54f6b90682 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -839,9 +839,7 @@ def _forward( _first_fp8_module = FP8GlobalStateManager.quantization_state.is_first_fp8_module ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase() or is_recomputation: - FP8GlobalStateManager.quantization_state.is_first_fp8_module = ( - _first_fp8_module - ) + FP8GlobalStateManager.quantization_state.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store if is_recomputation: # return the recomputed tensors diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index c80953c069..330666ba3a 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -330,11 +330,9 @@ def reset_recipe_state( buffer_key in FP8GlobalStateManager.quantization_state.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.quantization_state.global_amax_buffer[ - buffer_key - ][pos] = ( - recipe_state.amax_history[0] - ) + FP8GlobalStateManager.quantization_state.global_amax_buffer[buffer_key][ + pos + ] = recipe_state.amax_history[0] FP8GlobalStateManager.quantization_state.global_amax_history_buffer[ buffer_key ][pos] = recipe_state.amax_history diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 99c9c28dff..b769577b22 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -305,6 +305,7 @@ class FP8GlobalState: ) skip_fp8_weight_update_tensor: Optional[torch.Tensor] = None + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. @@ -468,7 +469,9 @@ def fp8_graph_capturing(cls) -> bool: if torch.compiler.is_compiling(): assert not cls.quantization_state.fp8_graph_capturing return False - return cls.quantization_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() + return ( + cls.quantization_state.fp8_graph_capturing or torch.cuda.is_current_stream_capturing() + ) @classmethod def is_first_fp8_module(cls): @@ -509,7 +512,10 @@ def reduce_and_update_fp8_tensors( ) -> None: """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in FP8GlobalStateManager.quantization_state.global_amax_buffer.items(): + for ( + buffer_key, + amax_buffer, + ) in FP8GlobalStateManager.quantization_state.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -518,7 +524,9 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args and concat amaxes. - recipe, group = FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] + recipe, group = FP8GlobalStateManager.quantization_state.autocast_arguments[ + autocast_key + ] contiguous_amax = torch.cat(amax_buffer) # Reduction. @@ -567,8 +575,9 @@ def get_unique_autocast_key( Object identity is sufficient since autocast contexts never outlive a single training session. """ - return str((id(recipe) if recipe is not None else None, - id(group) if group is not None else None)) + return str( + (id(recipe) if recipe is not None else None, id(group) if group is not None else None) + ) @classmethod def autocast_enter( @@ -583,7 +592,10 @@ def autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + FP8GlobalStateManager.quantization_state.autocast_arguments[autocast_key] = ( + fp8_recipe, + fp8_group, + ) FP8GlobalStateManager.quantization_state.fp8_enabled = enabled FP8GlobalStateManager.quantization_state.fp8_calibration = calibrating @@ -650,8 +662,12 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - if len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) == 0: FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer = [deque()] else: - FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer.append(deque()) - FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[-1].append(to_copy) + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer.append( + deque() + ) + FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer[-1].append( + to_copy + ) fp8_meta[buffer_position_key] = ( len(FP8GlobalStateManager.quantization_state.fp8_tensors_recompute_buffer) - 1 ) @@ -689,6 +705,7 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + @contextmanager def fp8_model_init( enabled: bool = True, @@ -773,8 +790,12 @@ def quantized_model_init( _fp8_recipe = FP8GlobalStateManager.quantization_state.fp8_recipe _high_precision_init_val = FP8GlobalStateManager.quantization_state.high_precision_init_val FP8GlobalStateManager.quantization_state.fp8_parameters = enabled - FP8GlobalStateManager.quantization_state.fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.quantization_state.high_precision_init_val = preserve_high_precision_init_val + FP8GlobalStateManager.quantization_state.fp8_recipe = ( + get_default_fp8_recipe() if recipe is None else recipe + ) + FP8GlobalStateManager.quantization_state.high_precision_init_val = ( + preserve_high_precision_init_val + ) try: yield finally: