[Common] MXFP8 kernel for grouped tensors #2586
Conversation
Greptile Summary

Adds an MXFP8 quantization kernel for grouped tensors, with support for various shape configurations and activation fusion. The PR implements a new CUDA kernel that performs MXFP8 quantization on grouped tensors.

Previous review comments identified several issues that should be addressed. The implementation is architecturally sound and has comprehensive test coverage, but the noted edge cases should be resolved to ensure robustness.

Confidence Score: 3/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as nvte_group_quantize
    participant Dispatch as group_quantize_fwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant TMA as TMA Descriptors

    User->>API: nvte_group_quantize(input, output, stream)
    API->>Dispatch: group_quantize_fwd_helper<IS_ACT, Empty, nullptr>
    Dispatch->>Dispatch: Convert GroupedTensor pointers
    Dispatch->>Dispatch: Check scaling_mode == NVTE_MXFP8_1D_SCALING
    alt Single tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        Dispatch->>Kernel: Launch with static tensor maps
        Kernel->>Kernel: Process chunks directly
    else Multiple tensors (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        Dispatch->>TMA: update_tma_descriptors<<<num_tensors>>>
        TMA->>TMA: modify_base_tensor_map for each tensor
        TMA->>TMA: Store descriptors in g_tensor_maps_*
        Dispatch->>Kernel: Launch with dynamic tensor maps
        Kernel->>Kernel: Binary search to find current tensor_id
        Kernel->>Kernel: fence_acquire_tensormap
    end
    Kernel->>Kernel: Load data via TMA (2D bulk copy)
    Kernel->>Kernel: Compute MXFP8 scales (rowwise/colwise)
    Kernel->>Kernel: Quantize to FP8 with E8M0 scaling
    Kernel->>Kernel: Store via TMA bulk write
    Kernel-->>User: Quantized grouped tensor
```
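For reference, here is a minimal host-side sketch of the per-block scaling step the diagram refers to ("Compute MXFP8 scales" and "Quantize to FP8 with E8M0 scaling"), assuming 32-element scaling blocks and E4M3 outputs. The function and variable names are illustrative, not the kernel's actual identifiers, and the exact rounding of the scale exponent may differ from the PR's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative sketch only: one 32-element MXFP8 scaling block, E4M3 output range.
// The real kernel operates on shared-memory tiles loaded via TMA.
constexpr int kBlockSize = 32;
constexpr float kE4M3Max = 448.0f;  // largest finite E4M3 magnitude

void quantize_block(const float (&in)[kBlockSize], uint8_t &scale_e8m0,
                    float (&scaled)[kBlockSize]) {
  // 1. Absolute maximum of the block.
  float amax = 0.0f;
  for (float v : in) amax = std::max(amax, std::fabs(v));

  // 2. Power-of-two scale chosen so amax / 2^exp fits into E4M3.
  const int exp = (amax > 0.0f)
                      ? static_cast<int>(std::ceil(std::log2(amax / kE4M3Max)))
                      : -127;

  // 3. Encode the exponent as biased E8M0 (bias 127; 0xFF is reserved for NaN).
  scale_e8m0 = static_cast<uint8_t>(std::clamp(exp + 127, 0, 254));

  // 4. Apply the scale; the real kernel then casts each value to FP8 (E4M3).
  const float scale = std::ldexp(1.0f, exp);
  for (int i = 0; i < kBlockSize; ++i) scaled[i] = in[i] / scale;
}
```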
10 files reviewed, 6 comments
```cpp
while (low < hi) {
  const size_t mid = low + (hi - low) / 2;
  const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);

  if (mid_offset <= current_offset) {
    low = mid + 1;
  } else {
    hi = mid;
  }
}
return low - 1;
```
Binary search can return `low - 1 = -1` (which wraps to `SIZE_MAX`, since `low` is a `size_t`) when `current_offset < offsets_ptr[0]`. Verify that `current_offset` is always `>= offsets_ptr[0]`, or add a bounds check to prevent the underflow.
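A minimal sketch of the guarded lookup the comment suggests, assuming `offsets_ptr` holds cumulative chunk offsets with one entry per tensor; the function name and the surrounding initialization are hypothetical, not the PR's actual code.

```cpp
#include <cstdint>

// Hypothetical wrapper around the search quoted above; only the final guard is
// the reviewer's suggested change, the rest is assumed context.
__device__ size_t find_tensor_id(const int64_t *offsets_ptr, const size_t num_tensors,
                                 const size_t current_offset) {
  size_t low = 0;
  size_t hi = num_tensors;
  while (low < hi) {
    const size_t mid = low + (hi - low) / 2;
    const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
    if (mid_offset <= current_offset) {
      low = mid + 1;
    } else {
      hi = mid;
    }
  }
  // Guard: if current_offset < offsets_ptr[0], low is still 0 and
  // low - 1 would wrap to SIZE_MAX. Clamp to tensor 0 instead.
  return (low == 0) ? 0 : low - 1;
}
```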
Description
This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.
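For context, a minimal sketch of the layout that a "grouped tensor" implies here: several 2D tensors packed into one contiguous buffer, with a per-tensor offset table (the same table the kernel binary-searches to map a chunk back to its tensor). The struct and field names are illustrative assumptions, not the types used in the PR.

```cpp
#include <cstdint>
#include <vector>

// Illustrative layout only; names are hypothetical.
struct GroupedBufferSketch {
  std::vector<float> data;       // all member tensors concatenated row-major
  std::vector<int64_t> offsets;  // offsets[i] = element offset of tensor i in data
  std::vector<int64_t> rows;     // first dimension of tensor i (may vary per tensor)
  std::vector<int64_t> cols;     // last dimension of tensor i (may vary per tensor)
};
```

In the sequence diagram's terms, SAME_BOTH_DIMS, VARYING_FIRST_DIM, VARYING_LAST_DIM, and VARYING_BOTH_DIMS presumably describe which of these per-tensor dimensions differ across the group.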
Fixes # (issue)
Type of change
Changes
Checklist: