diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py
index d0b314a64f..b5c3ca083c 100644
--- a/transformer_engine/pytorch/cpu_offload.py
+++ b/transformer_engine/pytorch/cpu_offload.py
@@ -46,6 +46,7 @@ def mark_activation_offload(*tensors):
 def mark_not_offload(*tensors: torch.Tensor):
     """Marks tensors to prevent them from being offloaded."""
     if NVTE_CPU_OFFLOAD_V1:
+        v1_code_path.mark_not_offload(*tensors)
         return
 
     tensors, tensor_obj = prepare_for_saving(*tensors)
diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py
index f92c436941..9bc94bfbb6 100644
--- a/transformer_engine/pytorch/cpu_offload_v1.py
+++ b/transformer_engine/pytorch/cpu_offload_v1.py
@@ -10,7 +10,11 @@ import torch
 
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 
-from .quantized_tensor import QuantizedTensorStorage
+from .quantized_tensor import (
+    QuantizedTensorStorage,
+    prepare_for_saving,
+    restore_from_saved,
+)
 from .tensor.float8_tensor import Float8Tensor
 
 __all__ = ["get_cpu_offload_context"]
@@ -45,6 +49,17 @@ def is_cpu_offload_enabled() -> bool:
     return CPUOffloadEnabled
 
 
+def mark_not_offload(*tensors: torch.Tensor):
+    """Marks tensors to prevent them from being offloaded."""
+    tensors, tensor_obj = prepare_for_saving(*tensors)
+
+    for tensor in tensors:
+        if tensor is not None:
+            setattr(tensor, "_TE_do_not_offload", True)
+
+    restore_from_saved(tensor_obj, tensors)
+
+
 def is_current_layer_offloaded() -> bool:
     """Check if current layers is being offloaded."""
     return CPUOffloadedLayer