
Conversation

Collaborator

@Oleg-Goncharov Oleg-Goncharov commented Jan 12, 2026

Description

This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added MXFP8 cast kernel for grouped tensors
  • Added a test suite

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch 4 times, most recently from e6bf02a to fc2a53f on January 15, 2026 16:15
@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Jan 15, 2026
@ptrendx ptrendx linked an issue Jan 16, 2026 that may be closed by this pull request
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 74a7917 to 88cf1b2 on January 21, 2026 17:00
pre-commit-ci bot and others added 6 commits January 21, 2026 17:00
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 7c4fda7 to 39bb24f on January 22, 2026 18:12
@Oleg-Goncharov Oleg-Goncharov marked this pull request as ready for review January 24, 2026 00:53
Contributor

greptile-apps bot commented Jan 24, 2026

Greptile Overview

Greptile Summary

Added MXFP8 quantization kernel for grouped tensors with support for various shape configurations and activation fusion.

The PR implements a new CUDA kernel that performs MXFP8 quantization on grouped tensors, supporting:

  • Four shape representations: same dimensions, varying first/last dimensions, or varying both
  • Rowwise and/or columnwise scaling with 32-element MXFP8 blocks
  • Fused activation functions (GeLU, ReLU, SiLU, etc.) and bias gradients
  • TMA (Tensor Memory Accelerator) for efficient memory transfers on Blackwell architecture
  • Binary search for tensor ID lookup in variable-dimension cases
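
The 32-element block scaling mentioned above can be sketched as a host-side reference; the helper name, the rounding rule, and the structure are illustrative assumptions for exposition, not the kernel's actual implementation:

```cpp
#include <cassert>
#include <cmath>

// Illustrative sketch (not the kernel code): each 32-element block shares one
// power-of-two (E8M0) scale, chosen so the block's largest magnitude lands
// within the FP8 E4M3 range (max normal value 448).
constexpr int kBlockSize = 32;
constexpr float kE4M3Max = 448.0f;

// Hypothetical helper: derive the per-block decode scale from the block amax.
float e8m0_block_scale(const float* block) {
  float amax = 0.0f;
  for (int i = 0; i < kBlockSize; ++i) {
    amax = std::fmax(amax, std::fabs(block[i]));
  }
  if (amax == 0.0f) return 1.0f;  // all-zero block: identity scale
  // E8M0 stores only an exponent, so round the ratio down to a power of two.
  const int exp = static_cast<int>(std::floor(std::log2(amax / kE4M3Max)));
  return std::ldexp(1.0f, exp);
}
```

Dividing each element by this scale before the FP8 cast keeps the block amax at or below the E4M3 maximum; the quantized block is stored alongside its 8-bit scale exponent.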

Previous review comments have identified several issues that should be addressed:

  • Uninitialized variables (scaling_type at line 743, shape_rep at line 752) are left unset if the validation checks fail
  • Commented-out code at lines 104-105 creates ambiguity about fallthrough behavior
  • Potential binary-search underflow when current_offset < offsets_ptr[0]
  • Multiple typos ("gropued" instead of "grouped") in the header documentation

The implementation is architecturally sound with comprehensive test coverage, but the noted edge cases should be resolved to ensure robustness.

Confidence Score: 3/5

  • This PR introduces a complex new kernel with several edge cases and previously identified issues that need verification
  • The implementation is sophisticated and includes comprehensive tests, but previous review threads identified legitimate concerns about uninitialized variables, commented code ambiguity, and potential binary search edge cases that should be verified before merge
  • Pay close attention to transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh for the uninitialized variable issues and binary search logic

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh New MXFP8 grouped tensor quantization kernel with binary search logic, potential uninitialized variables, and commented code
transformer_engine/common/include/transformer_engine/cast.h API declarations for grouped tensor MXFP8 quantization, typos already noted in previous reviews
tests/cpp/operator/test_cast_mxfp8_grouped.cu Comprehensive test suite for grouped MXFP8 quantization covering various shape configurations and activation functions

Sequence Diagram

sequenceDiagram
    participant User
    participant API as nvte_group_quantize
    participant Dispatch as group_quantize_fwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant TMA as TMA Descriptors

    User->>API: nvte_group_quantize(input, output, stream)
    API->>Dispatch: group_quantize_fwd_helper<IS_ACT, Empty, nullptr>
    Dispatch->>Dispatch: Convert GroupedTensor pointers
    Dispatch->>Dispatch: Check scaling_mode == NVTE_MXFP8_1D_SCALING
    
    alt Single tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        Dispatch->>Kernel: Launch with static tensor maps
        Kernel->>Kernel: Process chunks directly
    else Multiple tensors (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        Dispatch->>TMA: update_tma_descriptors<<<num_tensors>>>
        TMA->>TMA: modify_base_tensor_map for each tensor
        TMA->>TMA: Store descriptors in g_tensor_maps_*
        Dispatch->>Kernel: Launch with dynamic tensor maps
        Kernel->>Kernel: Binary search to find current tensor_id
        Kernel->>Kernel: fence_acquire_tensormap
    end
    
    Kernel->>Kernel: Load data via TMA (2D bulk copy)
    Kernel->>Kernel: Compute MXFP8 scales (rowwise/colwise)
    Kernel->>Kernel: Quantize to FP8 with E8M0 scaling
    Kernel->>Kernel: Store via TMA bulk write
    Kernel-->>User: Quantized grouped tensor
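
The branch in the diagram above (static versus dynamically updated tensor maps) can be sketched as a small dispatch rule; the enum values follow the names used in the diagram, while the predicate itself is an illustrative assumption:

```cpp
#include <cassert>

// Shape representations named in the summary and sequence diagram above.
enum class ShapeRep { SAME_BOTH_DIMS, VARYING_FIRST_DIM, VARYING_LAST_DIM, VARYING_BOTH_DIMS };

// Hypothetical dispatch rule: when the last dimension varies between member
// tensors, each tensor needs its own TMA descriptor, so the launcher first
// runs a descriptor-update kernel and the main kernel binary-searches the
// offsets array to map a chunk back to its tensor_id.
bool needs_dynamic_tensor_maps(ShapeRep rep) {
  return rep == ShapeRep::VARYING_LAST_DIM || rep == ShapeRep::VARYING_BOTH_DIMS;
}
```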

Contributor

@greptile-apps greptile-apps bot left a comment
10 files reviewed, 6 comments


Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

10 files reviewed, 1 comment


Oleg-Goncharov and others added 2 commits January 26, 2026 17:27
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments


Comment on lines +85 to +95
while (low < hi) {
  const size_t mid = low + (hi - low) / 2;
  const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);

  if (mid_offset <= current_offset) {
    low = mid + 1;
  } else {
    hi = mid;
  }
}
return low - 1;

Binary search can return low - 1 = -1 (underflow to SIZE_MAX) when current_offset < offsets_ptr[0]. Verify that current_offset is always >= offsets_ptr[0], or add a bounds check to prevent the underflow.
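
One way to implement the suggested bounds check is the sketch below; the find_tensor_id wrapper, its signature, and the int64_t offsets type are hypothetical, with the loop body taken from the quoted snippet:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Sketch of the suggested fix: find the index i such that
// offsets[i] <= current_offset < offsets[i + 1], clamping the result so an
// offset smaller than offsets_ptr[0] cannot make `low - 1` wrap to SIZE_MAX.
size_t find_tensor_id(const int64_t* offsets_ptr, size_t num_offsets,
                      size_t current_offset) {
  size_t low = 0;
  size_t hi = num_offsets;
  while (low < hi) {
    const size_t mid = low + (hi - low) / 2;
    const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
    if (mid_offset <= current_offset) {
      low = mid + 1;
    } else {
      hi = mid;
    }
  }
  // Guard against underflow when current_offset < offsets_ptr[0].
  return (low == 0) ? 0 : low - 1;
}
```

If the launcher guarantees current_offset >= offsets_ptr[0], the guard is dead code and a debug assertion would document the invariant instead.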

Oleg-Goncharov and others added 2 commits January 26, 2026 18:42
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments



Labels

enhancement New feature or request MoE

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Quantization support for GroupedTensor: MXFP8

2 participants