1 change: 1 addition & 0 deletions transformer_engine/pytorch/cpu_offload.py
```diff
@@ -46,6 +46,7 @@ def mark_activation_offload(*tensors):
 def mark_not_offload(*tensors: torch.Tensor):
     """Marks tensors to prevent them from being offloaded."""
     if NVTE_CPU_OFFLOAD_V1:
+        v1_code_path.mark_not_offload(*tensors)
         return
 
     tensors, tensor_obj = prepare_for_saving(*tensors)
```
17 changes: 16 additions & 1 deletion transformer_engine/pytorch/cpu_offload_v1.py
```diff
@@ -10,7 +10,11 @@
 import torch
 
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-from .quantized_tensor import QuantizedTensorStorage
+from .quantized_tensor import (
+    QuantizedTensorStorage,
+    prepare_for_saving,
+    restore_from_saved,
+)
 from .tensor.float8_tensor import Float8Tensor
 
 __all__ = ["get_cpu_offload_context"]
@@ -45,6 +49,17 @@ def is_cpu_offload_enabled() -> bool:
     return CPUOffloadEnabled
 
 
+def mark_not_offload(*tensors: torch.Tensor):
+    """Marks tensors to prevent them from being offloaded."""
+    tensors, tensor_obj = prepare_for_saving(*tensors)
+
+    for tensor in tensors:
+        if tensor is not None:
+            setattr(tensor, "_TE_do_not_offload", True)
+
+    restore_from_saved(tensor_obj, tensors)
```
Contributor comment on lines +52 to +60:
P1: _TE_do_not_offload is never checked in the V1 path, so the function is a no-op

The V1 code path decides whether to offload a tensor via tensor_need_offloading_checker_activations, which checks only hasattr(tensor, "activation_offloading") (see cpu_offload_v1.py, lines 738–739). The _TE_do_not_offload attribute set here is never consulted anywhere in the V1 path, so calling this function after mark_activation_offload will not prevent offloading.

To confirm: the only place _TE_do_not_offload is read in the whole codebase is cpu_offload.py line 463 (OffloadableLayerState._check_if_offload), which belongs to the non-V1 code path.

For mark_not_offload to actually work in V1, the implementation should mirror how mark_activation_offload operates, i.e. remove (or otherwise gate on) the activation_offloading attribute:

```python
def mark_not_offload(*tensors: torch.Tensor):
    """Marks tensors to prevent them from being offloaded."""
    for tensor in tensors:
        if tensor is None:
            continue
        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
            if hasattr(tensor, "activation_offloading"):
                del tensor.activation_offloading
        else:
            data_tensors = tensor.get_data_tensors()
            for t in data_tensors:
                if t is not None and hasattr(t, "activation_offloading"):
                    del t.activation_offloading
```

Alternatively, tensor_need_offloading_checker_activations could be updated to also gate on _TE_do_not_offload, but that would be a more invasive change.
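The no-op behaviour described above can be demonstrated without Transformer Engine or torch. In this hedged sketch, FakeTensor is a hypothetical stand-in for torch.Tensor (which likewise accepts arbitrary Python attributes), needs_offload mirrors the hasattr check in tensor_need_offloading_checker_activations, and the two mark_not_offload variants model the PR's current implementation versus the suggested fix:

```python
# Simplified model of the V1 gating logic. All names here are illustrative
# stand-ins, not the real Transformer Engine API.
class FakeTensor:
    """Stand-in for torch.Tensor; accepts arbitrary attributes."""


def needs_offload(tensor):
    # Mirrors tensor_need_offloading_checker_activations: a tensor is
    # offloaded iff it carries the "activation_offloading" attribute.
    return hasattr(tensor, "activation_offloading")


def mark_activation_offload(*tensors):
    # Essence of the real marker: set the attribute the checker looks for.
    for tensor in tensors:
        tensor.activation_offloading = True


def mark_not_offload_noop(*tensors):
    # The PR's current V1 implementation: sets an attribute nobody reads.
    for tensor in tensors:
        if tensor is not None:
            tensor._TE_do_not_offload = True


def mark_not_offload_fixed(*tensors):
    # The suggested fix: delete the marker so the hasattr check fails.
    for tensor in tensors:
        if tensor is not None and hasattr(tensor, "activation_offloading"):
            del tensor.activation_offloading


t = FakeTensor()
mark_activation_offload(t)
mark_not_offload_noop(t)
print(needs_offload(t))   # True: the tensor would still be offloaded

mark_not_offload_fixed(t)
print(needs_offload(t))   # False: the checker now skips this tensor
```

The sketch makes the comment's point concrete: only removing (or gating on) the attribute the checker actually reads changes the offloading decision.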



```diff
 def is_current_layer_offloaded() -> bool:
     """Check if current layers is being offloaded."""
     return CPUOffloadedLayer
```