From 0a1d9ae8b012285b0774d17d7ed354f4d4ccfe90 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 16 Mar 2026 22:53:54 -0700 Subject: [PATCH 1/2] add mark_not_offload() interface for cpu_offload_v1 Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/cpu_offload.py | 1 + transformer_engine/pytorch/cpu_offload_v1.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index d0b314a64f..b5c3ca083c 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -46,6 +46,7 @@ def mark_activation_offload(*tensors): def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" if NVTE_CPU_OFFLOAD_V1: + v1_code_path.mark_not_offload(*tensors) return tensors, tensor_obj = prepare_for_saving(*tensors) diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py index f92c436941..a92b531ac0 100644 --- a/transformer_engine/pytorch/cpu_offload_v1.py +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -10,7 +10,11 @@ import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .quantized_tensor import QuantizedTensorStorage +from .quantized_tensor import ( + QuantizedTensorStorage, + prepare_for_saving, + restore_from_saved, +) from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] @@ -44,6 +48,15 @@ def is_cpu_offload_enabled() -> bool: """Check if CPU offloading is currently enabled.""" return CPUOffloadEnabled +def mark_not_offload(*tensors: torch.Tensor): + """Marks tensors to prevent them from being offloaded.""" + tensors, tensor_obj = prepare_for_saving(*tensors) + + for tensor in tensors: + if tensor is not None: + setattr(tensor, "_TE_do_not_offload", True) + + restore_from_saved(tensor_obj, tensors) def is_current_layer_offloaded() -> bool: """Check if current layers is being offloaded.""" From 730d7a1430d3e50ff13139c4241bab7cc2344e41 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 05:59:33 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cpu_offload_v1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py index a92b531ac0..9bc94bfbb6 100644 --- a/transformer_engine/pytorch/cpu_offload_v1.py +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -48,6 +48,7 @@ def is_cpu_offload_enabled() -> bool: """Check if CPU offloading is currently enabled.""" return CPUOffloadEnabled + def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" tensors, tensor_obj = prepare_for_saving(*tensors) @@ -58,6 +59,7 @@ def mark_not_offload(*tensors: torch.Tensor): restore_from_saved(tensor_obj, tensors) + def is_current_layer_offloaded() -> bool: """Check if current layers is being offloaded.""" return CPUOffloadedLayer