[PyTorch] Add an API restore from function context to ensure tensors are detached #2772
Conversation
Greptile Summary: This PR introduces `restore_from_func_ctx`, a new API that restores saved tensors from the autograd function context and detaches them from it afterwards.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Backward as backward(ctx)
    participant RRFC as restore_from_func_ctx
    participant RFS as restore_from_saved
    participant Ctx as ctx
    Backward->>RRFC: "restore_from_func_ctx(ctx)"
    RRFC->>Ctx: "read ctx.tensor_objects"
    RRFC->>Ctx: "read ctx.saved_tensors"
    RRFC->>RFS: "restore_from_saved(tensor_objects, saved_tensors)"
    RFS-->>RRFC: "restored tensor list"
    RRFC->>Ctx: "ctx.tensor_objects = None"
    RRFC-->>Backward: "restored tensor list"
```
Last reviewed commit: "Merge branch 'main' ..."
```python
if not hasattr(ctx, "tensor_objects"):
    raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
```
hasattr guard misses the None case after first call
After the first successful call to `restore_from_func_ctx`, `ctx.tensor_objects` is set to `None`. If the function is ever called a second time on the same context (e.g., due to programming error or future refactoring), `hasattr(ctx, "tensor_objects")` will still return `True` (the attribute exists, it's just `None`), and the code will proceed to call `restore_from_saved(None, ctx.saved_tensors)`. This causes an unhelpful `TypeError: 'NoneType' object is not iterable` deep inside `restore_from_saved` rather than a clear `AttributeError` here.
The guard should also check for None:
```diff
-if not hasattr(ctx, "tensor_objects"):
-    raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
+if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
+    raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
```
```python
        list[Optional[torch.Tensor]],
    ]
):
    """Recombine the tensor data and metadata during backward pass."""
```
Docstring omits key behavioral difference from `restore_from_saved`
The docstring for `restore_from_func_ctx` says only "Recombine the tensor data and metadata during backward pass," which is identical in meaning to `restore_from_saved`. The defining (and intentional) behavior of this new function, nullifying `ctx.tensor_objects` after the restore so memory can be freed, is not documented here, even though it's the primary motivation for adding the new API.
| """Recombine the tensor data and metadata during backward pass.""" | |
| """Recombine the tensor data and metadata during backward pass. | |
| Unlike `restore_from_saved`, this function deletes `ctx.tensor_objects` | |
| after restoring (by setting it to None), which allows the reference to | |
| the tensor objects to be released and the underlying memory to be freed | |
| at the end of the current iteration rather than when the function context | |
| is destroyed. | |
| """ |
Force-pushed 5a03f56 to 4471dcb ("…d from ctx", Signed-off-by: Kaining Zhong <kainingz@nvidia.com>), then 4471dcb to 3001bbd.
/te-ci pytorch L0 L1
Description
In `quantized_tensor.py` we only had `restore_from_saved`; it is typically called at the start of `backward` to restore the tensors saved in `forward`. However, after calling this API it is common to forget to detach `tensor_objects` from `ctx`. The function context then keeps the tensor objects alive, and the tensors (along with their allocated memory) can only be released at the next iteration, when the context is destroyed (see #2750).
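The leaky pattern described above can be sketched with plain-Python stand-ins. This is a hedged illustration: `FuncCtx`, `forward_save`, and `backward_leaky` are hypothetical names, and `restore_from_saved` is a trivial stand-in for the real helper in `quantized_tensor.py`.

```python
def restore_from_saved(tensor_objects, saved_tensors):
    # Stand-in for the real helper in quantized_tensor.py (assumption).
    return list(zip(tensor_objects, saved_tensors))

class FuncCtx:
    """Stand-in for the autograd function context."""
    pass

def forward_save(ctx, names):
    # Hypothetical forward: split each "tensor" into a metadata object
    # (kept on ctx) and saved data (kept via save_for_backward in real code).
    ctx.tensor_objects = [f"meta({n})" for n in names]
    ctx.saved_tensors = [f"data({n})" for n in names]

def backward_leaky(ctx):
    restored = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    # Easy to forget: ctx.tensor_objects = None
    # Without it, ctx keeps the tensor objects (and their memory) alive
    # until the function context itself is destroyed next iteration.
    return restored

ctx = FuncCtx()
forward_save(ctx, ["w"])
backward_leaky(ctx)
print(ctx.tensor_objects)  # still alive: ['meta(w)']
```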
By adding this new API that restores from the context directly (and discouraging the use of `restore_from_saved` when restoring from a function context), the reference to the tensor objects is deleted after restoring, ensuring the memory is freed.
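The memory effect of deleting that reference can be demonstrated with `weakref` on plain-Python stand-ins (`TensorObj` and `Ctx` are hypothetical classes; in the real PR these are tensors on an autograd context):

```python
import weakref

class TensorObj:
    """Stand-in for a tensor metadata object holding memory."""
    pass

class Ctx:
    """Stand-in for the autograd function context."""
    pass

ctx = Ctx()
obj = TensorObj()
ctx.tensor_objects = [obj]
ref = weakref.ref(obj)
del obj  # user code drops its own reference after backward

# Leaky pattern: ctx still holds the object, so it cannot be freed.
assert ref() is not None

# restore_from_func_ctx pattern: nullify the ctx attribute.
ctx.tensor_objects = None
assert ref() is None  # freed immediately, not at context destruction
```

This mirrors the memory-trace observation below: with the new API the allocation is released at the end of the current iteration instead of when the context is destroyed.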
For example, with FP8 quantization and `GroupedLinear`, notice the time at which the selected memory segment is released.
Test script:
test.py
Type of change
Changes

Please list the changes introduced in this PR:

- Add `restore_from_func_ctx`, which deletes the tensor objects for users after restoring the saved tensors.

Checklist: