CPU Optimizations for FP8 #2559
base: main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR implements CPU-side optimizations for FP8 quantized tensor operations. The main changes include caching of frequently accessed tensor attributes (requires_grad, dtype, shape, is_cuda), direct use of the Python C API when constructing quantized tensors from C++, and caching of the non-TN FP8 GEMM support check.
The optimizations are well-targeted at reducing Python/C++ boundary overhead. However, several critical memory-safety and correctness issues from previous review threads remain unaddressed and must be fixed before merging.
Confidence Score: 1/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant Quantizer as Float8Quantizer (C++)
participant PyAPI as Python C API
participant Tensor as QuantizedTensor
User->>Linear: forward(input, weight)
Note over Linear: Cache requires_grad checks
Linear->>Linear: inp_requires_grad = inp.requires_grad
Linear->>Linear: weight_requires_grad = weight.requires_grad
Linear->>Quantizer: create_tensor(shape, dtype)
Note over Quantizer: Cache GEMM support check
Quantizer->>Quantizer: is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported()
Note over Quantizer: Direct C API (bypassing pybind11)
Quantizer->>Quantizer: stride = stride_from_shape(shape)
Quantizer->>PyAPI: PyDict_New(), PyTuple_New(0)
Quantizer->>PyAPI: PyDict_SetItemString(kwargs, ...)
Quantizer->>PyAPI: PyObject_Call(Float8TensorClass, args, kwargs)
Quantizer->>PyAPI: Py_DECREF(kwargs), Py_DECREF(args)
PyAPI-->>Tensor: Float8Tensor instance
Note over Tensor: Cached properties initialized
Tensor->>Tensor: _dtype = dtype
Tensor->>Tensor: _requires_grad = requires_grad
Tensor-->>User: Quantized tensor
User->>Tensor: tensor.dtype
Note over Tensor: Return cached _dtype (no PyObject lookup)
Tensor-->>User: dtype
User->>Tensor: tensor.shape
Note over Tensor: Return _data.shape (cached property)
Tensor-->>User: shape
Additional Comments (3)
- transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link) logic: Critical logical error: `||` should be `&&`. This condition will always be `true`, since a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs. (A minimal sketch of this bug pattern follows this list.)
- transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link) style: commented-out code for the `requires_grad` caching optimization - consider removing the dead code entirely. Is this code planned to be implemented later, or should it be removed? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- transformer_engine/pytorch/module/linear.py, line 484 (link) logic: Logical error: this condition should use OR (`||`) not AND (`&&`). The original logic checked whether ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias, which may be None. Should the FP8 condition check if any tensor requires gradients (OR logic) rather than all tensors (AND logic)?
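A minimal sketch of the guard bug flagged in the first comment above, using placeholder enum and function names rather than the actual util.cpp code: with `||` the early-return branch is taken for every input, whereas `&&` rejects only genuinely unsupported modes.

```cpp
#include <optional>

// Placeholder types and names for illustration only (not the real util.cpp code).
enum class ScalingMode { kPerTensor, kBlockwise };

// Buggy guard: no mode can equal both constants at once, so the condition is
// always true and the function always returns nullopt.
std::optional<int> lookup_buggy(ScalingMode mode) {
  if (mode != ScalingMode::kPerTensor || mode != ScalingMode::kBlockwise) {
    return std::nullopt;
  }
  return 0;
}

// Intended guard: return nullopt only for modes that match neither supported value.
std::optional<int> lookup_fixed(ScalingMode mode) {
  if (mode != ScalingMode::kPerTensor && mode != ScalingMode::kBlockwise) {
    return std::nullopt;
  }
  return 0;
}
```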
10 files reviewed, 3 comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches `requires_grad`, `dtype`, `shape`, and `is_cuda` attribute accesses to avoid expensive PyObject lookups on custom tensors
- Reorders attribute checks in `get_tensor_device()` to prioritize internal quantized tensor attributes
- Makes `num_devices` static in `nvte_is_non_tn_fp8_gemm_supported()` to cache the device count (a hedged sketch of this caching pattern follows the list of critical issues below)
- Stores GEMM support check results in local variables to avoid redundant function calls
Critical Issues Found:
- Variable redeclaration error in `cublaslt_gemm.cu` (line 224) will prevent compilation
- Logic bug in `linear.py` (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
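As referenced in the key-changes list above, the `num_devices` change relies on function-local statics. A hedged sketch of that caching pattern, with hypothetical `query_device_count()` and `probe_device()` helpers standing in for the real Transformer Engine internals:

```cpp
#include <mutex>
#include <vector>

// Hypothetical helpers standing in for the real Transformer Engine internals.
int query_device_count();
int probe_device(int device_id);

int is_feature_supported(int device_id) {
  // Function-local statics are initialized exactly once, so repeated calls do
  // not re-query the driver for the device count.
  static int num_devices = query_device_count();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);
  // Each device is probed at most once; later calls return the cached result.
  std::call_once(flags[device_id],
                 [&] { cache[device_id] = probe_device(device_id); });
  return cache[device_id];
}
```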
Confidence Score: 0/5
- This PR cannot be merged due to a compilation error and a critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to `transformer_engine/common/gemm/cublaslt_gemm.cu` (compilation error) and `transformer_engine/pytorch/module/linear.py` (logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant Quantizer as Quantizer/QuantizedTensor
participant GEMM as GEMM Operations
participant CPP as C++ Extensions
Note over Linear,CPP: Performance Optimization Flow
User->>Linear: forward(input, weight, bias)
Note over Linear: Cache requires_grad checks
Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
Linear->>Quantizer: Check if quantized tensor
alt QuantizedTensor
Note over Quantizer: Use cached dtype property
Quantizer->>Quantizer: return self._dtype
Note over Quantizer: Use cached shape/is_cuda
Quantizer->>Quantizer: return self._data.shape
else Regular Tensor
Quantizer->>Linear: Standard attribute access
end
Linear->>CPP: get_tensor_device(tensor)
Note over CPP: Reordered attribute checks
CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
CPP-->>Linear: device_index
Linear->>GEMM: Configure GEMM parameters
Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
Note over CPP: Static num_devices cached
CPP-->>GEMM: support_flag
GEMM->>GEMM: Store in local variable
GEMM->>GEMM: Execute optimized GEMM
GEMM-->>Linear: output
Note over Linear: FP8 State Management
alt FP8 enabled and requires_grad check
Linear->>Linear: Update FP8 tensors<br/>based on cached flags
end
Linear-->>User: output
Additional Comments (2)
/te-ci L1 pytorch
/te-ci L1 pytorch
8 files reviewed, 8 comments
| ret.A_scale_inv = A.scale_inv.dptr; | ||
| ret.lda = is_A_transposed ? k : m; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move this declaration before line 125 (Configure A matrix) | |
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); | |
| Py_DECREF(kwargs); |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.
This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | |
| // Remove this line - device parameter is not accepted | |
| // PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
| PyTypeObject *NVFP4TensorPythonClass = nullptr; | ||
| PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; | ||
| PyTypeObject *NVFP4QuantizerClass = nullptr; | ||
| bool is_extension_initialized = false; |
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:
| bool is_extension_initialized = false; | |
| bool is_extension_initialized = false; | |
| std::once_flag extension_init_flag; |
And in init_extension():
void init_extension() {
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
}
| static int num_devices = transformer_engine::cuda::num_devices(); | ||
| static std::vector<int> cache(num_devices, -1); | ||
| static std::vector<std::once_flag> flags(num_devices); |
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| int device_id = transformer_engine::cuda::current_device(); | |
| NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count"); |
4 files reviewed, 4 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration |
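A hedged sketch of the declare-once pattern the comment above asks for; the structure mirrors the quoted diff lines, but this is not the actual cublaslt_gemm.cu function:

```cpp
// Assumed to be the existing TE query (declared in the real headers); it is
// called once here and the result is reused for both matrix configurations.
extern "C" int nvte_is_non_tn_fp8_gemm_supported();

void configure_layouts(bool is_A_transposed, bool is_B_transposed) {
  const int is_non_tn_supported = nvte_is_non_tn_fp8_gemm_supported();
  if (!is_non_tn_supported && !is_A_transposed) {
    // ... adjust the A-matrix layout ...
  }
  if (!is_non_tn_supported && is_B_transposed) {
    // ... adjust the B-matrix layout ...
  }
}
```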
| py::object scale_inv_py = py::cast(scale_inv); | ||
| // Initialize scale-inverse tensor | ||
| if (!scale_inv) { | ||
| scale_inv = at::reciprocal(scale); | ||
| } |
Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.
The fix is to move line 146 after the initialization block:
| py::object scale_inv_py = py::cast(scale_inv); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object scale_inv_py = py::cast(*scale_inv); |
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
13 files reviewed, 13 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Reuse is_nvte_non_tn_fp8_gemm_supported from line 132 |
| } | ||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | ||
|
|
||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); |
[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Move scale_inv_py creation after initialization (line 150) |
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| Py_DECREF(args); |
| PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), |
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
12 files reviewed, 1 comment
| return self._data.shape | ||
| if self._transpose is not None: | ||
| transpose_shape = self._transpose.shape | ||
| return tuple(transpose_shape[1:]) + (transpose_shape[0],) |
style: Returns tuple instead of torch.Size when computed from transpose, while self._data.shape returns torch.Size. Consider wrapping with torch.Size() for type consistency:
| return tuple(transpose_shape[1:]) + (transpose_shape[0],) | |
| return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
3 similar comments
/te-ci L1 pytorch
/te-ci L1 pytorch
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations
/te-ci L1 pytorch
/te-ci pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci pytorch
/te-ci pytorch
… at::empty Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci pytorch
/te-ci L1 pytorch
/te-ci pytorch
/te-ci L1 pytorch
| Return the high precision data type of the tensor | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since dtype for a tensor is never | ||
| change after creation, we cache it in a member variable and return |
fix this comment. dtype can get changed when custom_tensor.data is modified using set
15 files reviewed, 15 comments
| # Calculate stride from shape if not provided. When creating this object from | ||
| # C++ code, we provide the stride computed from shape in C++ to avoid the | ||
| # PyobjectVectorCall overhead of calling _stride_from_shape from C++ to Python. |
cached _dtype and _requires_grad are only initialized here in __new__, but QuantizedTensor can be created through other paths (unpickling, tensor ops, FSDP state loading) that may bypass this initialization. The property getters include fallback logic with hasattr() checks, but this adds overhead to the optimization.
ensure all creation paths properly initialize these cached attributes, or document that the fallback path is intentional
| @dtype.setter | ||
| def dtype(self, value: torch.dtype) -> None: |
the dtype setter allows changing dtype after creation with only a warning. For quantized tensors, changing _dtype without re-quantizing the underlying FP8/MXFP8/NVFP4 data creates a dangerous mismatch between the cached dtype and actual data representation.
consider raising an error instead of warning, or trigger re-quantization
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
| if (result == nullptr) { | ||
| PyErr_Print(); | ||
| } | ||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
critical memory leaks and exception safety issues:
- `PyTuple_New(0)` creates a tuple with refcount=1, but `Py_DECREF` only happens after the call. If `PyObject_Call` fails and returns NULL, the program terminates via `NVTE_CHECK` with the refs properly released, but this is a fatal path
- No NULL checks after `PyDict_New()` or `PyTuple_New()` - if allocation fails, subsequent operations will crash
- Exception safety: if any `py::cast()` throws between allocation and cleanup, `kwargs` and `args` leak
- `py::cast(this->dtype).ptr()` creates a temporary object that's destroyed after the statement, but `PyDict_SetItemString` should increment the refcount. This pattern is fragile
use pybind11's py::dict and py::tuple for RAII-based memory management:
py::dict kwargs;
kwargs["data"] = data_py;
kwargs["fp8_scale_inv"] = scale_inv_py;
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["data_transpose"] = transpose_py;
kwargs["quantizer"] = this->quantizer;
py::object result = py::reinterpret_steal<py::object>(
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
py::tuple().ptr(), kwargs.ptr()));
NVTE_CHECK(result.ptr() != nullptr, "Failed to create Float8TensorStorage instance");
this same pattern repeats throughout the file at lines 214-234, 382-404, and others
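Since the same construction pattern recurs at many call sites, one option (a sketch under the same pybind11 assumption as the suggestion above, not the actual TE code) is to factor it into a small helper and fix the refcount handling in one place:

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Calls a Python class (or any callable) with keyword arguments only. The empty
// args tuple and the kwargs dict are RAII-managed by pybind11, so nothing needs a
// manual Py_DECREF and an exception cannot leak references.
inline py::object call_with_kwargs(PyTypeObject* cls, const py::dict& kwargs) {
  py::tuple args;  // empty positional-argument tuple
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(cls), args.ptr(), kwargs.ptr());
  if (result == nullptr) {
    throw py::error_already_set();  // surface the pending Python error to the caller
  }
  return py::reinterpret_steal<py::object>(result);
}
```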
| at::Tensor empty_cuda(std::vector<int64_t> sizes_vector, const at::TensorOptions& opts) { | ||
| // at::empty under the hood calls at::native::empty_cuda, however it introduces a device guard | ||
| // which is called again in at::native::empty_cuda anyway. So directly calling that to avoid the | ||
| // double device guard overhead. | ||
| auto sizes = at::ArrayRef<int64_t>(sizes_vector); | ||
| auto device_opt = opts.device_opt(); | ||
| c10::optional<at::ScalarType> dtype_opt = c10::nullopt; | ||
| if (opts.dtype_opt().has_value()) { | ||
| dtype_opt = opts.dtype_opt()->toScalarType(); | ||
| } | ||
| return at::native::empty_cuda(sizes, dtype_opt, opts.layout_opt(), device_opt, | ||
| opts.pinned_memory_opt(), opts.memory_format_opt()); |
the empty_cuda wrapper bypasses the standard at::empty dispatch by calling at::native::empty_cuda directly to avoid double device guard overhead. However, this skips the dispatch mechanism that handles tensor subclasses, hooks, and custom allocators.
verify this doesn't break compatibility with PyTorch features
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): |
reordering optimizes for quantized tensors but creates performance regression for regular tensors. Before, regular torch.Tensor objects checked device immediately (1 attribute check). Now they check 4 non-existent quantized attributes via hasattr() first.
if get_tensor_device() is frequently called with regular tensors, consider adding an early type check:
if type(tensor).__name__ == 'Tensor':
return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
| static std::unordered_map<std::string, void *> symbol_cache; | ||
| static std::mutex cache_mutex; | ||
| FuncT *func; | ||
|
|
||
| { | ||
| std::lock_guard<std::mutex> lock(cache_mutex); | ||
| auto it = symbol_cache.find(symbol); | ||
| if (it == symbol_cache.end()) { | ||
| void *ptr = get_symbol(symbol); | ||
| symbol_cache[symbol] = ptr; | ||
| func = reinterpret_cast<FuncT *>(ptr); | ||
| } else { | ||
| func = reinterpret_cast<FuncT *>(it->second); | ||
| } | ||
| } |
excellent optimization - caching CUDA driver symbols with proper mutex synchronization avoids repeated get_symbol() calls. The mutex correctly protects both reads and writes to symbol_cache from concurrent threads
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling | ||
| if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) { | ||
| is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | ||
| } |
caching nvte_is_non_tn_fp8_gemm_supported() result avoids redundant calls throughout the GEMM configuration for both A and B matrices. Clean optimization with proper scoping
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| inp_requires_grad = inp.requires_grad | ||
| weight_requires_grad = weight.requires_grad | ||
| bias_requires_grad = bias.requires_grad if bias is not None else False |
caching requires_grad checks early avoids redundant attribute lookups throughout the forward pass (used at lines 147, 260, 385, 454, 480, 482, 485). The OR logic at line 485 is semantically equivalent to the original requires_grad(inp, weight, bias) helper
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
|
||
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| if self._data is not None: | ||
| return self._data.shape | ||
| if self._transpose is not None: | ||
| transpose_shape = self._transpose.shape | ||
| return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) | ||
| raise RuntimeError("Both data and transpose are None") | ||
|
|
||
| @property | ||
| def is_cuda(self): | ||
| """Return whether the tensor is on a CUDA device.""" | ||
| if self._data is not None: | ||
| return self._data.is_cuda | ||
| if self._transpose is not None: | ||
| return self._transpose.is_cuda |
the cached shape and is_cuda properties include proper error handling with RuntimeError when both _data and _transpose are None. This prevents AttributeError in edge cases and provides clear error messages
| dtype = inp.dtype | ||
| # All checks after this have already been performed once, thus skip |
caching inp.dtype early avoids a second attribute lookup at line 933 (originally line 936). Clean micro-optimization with no functional change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Additional Comments (1)
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: