
[Common] Persistent Grouped MXFP8 quantization kernel #2738

Open
Oleg-Goncharov wants to merge 56 commits into NVIDIA:main from
Oleg-Goncharov:pr_persistent_grouped_mxfp8_kernel

Conversation

Collaborator

@Oleg-Goncharov Oleg-Goncharov commented Mar 5, 2026

Description

This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added persistent kernel
  • Added TunableConfig structure to tune performance

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Mar 5, 2026
@Oleg-Goncharov Oleg-Goncharov requested a review from ptrendx March 5, 2026 16:18
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces a persistent grid-stride MXFP8 group quantization kernel that replaces the previous one-CTA-per-block strategy with a static scheduler where each physical CTA iterates over multiple work items, reducing launch overhead for large grouped tensors. It also adds a TunableConfig struct for tuning tile dimensions and persistence settings, a USE_FAST_MATH path using mixed-precision FMA PTX intrinsics for BF16/FP16 cast-only workloads, and promotes the ShapeRepresentation enum to a shared header to eliminate duplicate definitions.

Key changes:

  • Persistent kernel: Physical CTAs stride over a virtual work grid using static_next_block_id + static_block_stride, with double-buffered TMA input pipelining (BUFFS_NUM = 2) and per-job barrier reuse.
  • TunableConfig: Exposes CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK, PERSISTENT, and STATIC_PERSISTENT_BLOCKS_PER_SM as compile-time constants for future tuning.
  • USE_FAST_MATH: Uses fma.rn.f32.bf16/fma.rn.f32.f16 + cvt.rn.satfinite.e4m3x2/e5m2x2 PTX to eliminate the intermediate FP32 conversion round-trip in the BF16/FP16 cast-only path.
  • API: New nvte_group_quantize_v2 accepting a NVTEQuantizationConfig; the legacy nvte_group_quantize continues to work with a default config.
  • ShapeRepresentation consolidation: Moved to utils.cuh, removing duplicate definitions from core/common.cuh and hadamard_transform/graph_safe_group_hadamard_transform.cu.
  • Empty tensor support: The persistent work loop explicitly handles rows == 0 || cols == 0 tensors by skipping TMA operations and advancing the scheduler.
  • A missing static_assert(STAGES % BUFFS_NUM == 0) leaves a latent correctness hazard: if TunableConfig is ever changed such that STAGES is not divisible by BUFFS_NUM, the barrier parity tracking silently breaks for every job after the first one.
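As a rough host-side sketch of the static scheduling scheme described above (function and variable names here are illustrative, not from the PR), each physical CTA starts at its own block id and strides over the virtual work grid by the launch width, which is what `static_next_block_id` + `static_block_stride` amount to:

```cpp
#include <cstddef>
#include <vector>

// Host-side analogue of the static persistent scheduler: each "physical CTA"
// begins at its own id and strides over the virtual work grid by the total
// number of launched blocks, so all work items are covered exactly once.
std::vector<std::vector<size_t>> schedule_work(size_t launch_blocks,
                                               size_t total_work_blocks) {
  std::vector<std::vector<size_t>> per_cta(launch_blocks);
  for (size_t cta = 0; cta < launch_blocks; ++cta) {
    for (size_t work = cta; work < total_work_blocks; work += launch_blocks) {
      per_cta[cta].push_back(work);  // this CTA's sequence of virtual blocks
    }
  }
  return per_cta;
}
```

Because the assignment is a pure function of the CTA id, no inter-CTA communication or atomic work queue is needed, at the cost of possible load imbalance when per-job costs differ.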

Confidence Score: 3/5

  • Needs the barrier parity invariant enforced before merging; the rest of the logic appears sound.
  • The persistent scheduler logic, double-buffered TMA pipeline, and fast-math PTX paths are all well-structured, and the existing static_asserts cover several key invariants. However, the critical STAGES % BUFFS_NUM == 0 invariant that underpins correct barrier-parity reuse across persistent loop iterations is not asserted. With current defaults the condition holds, but it is one TunableConfig change away from producing silent stalls or data corruption. There are also minor concerns around TMA write-group drain ordering between persistent jobs and a redundant (but harmless) offset check for VARYING_FIRST_DIM in is_job_valid. The score also reflects that the kernel operates entirely at the PTX/TMA level on SM 10.0+ hardware, where subtle ordering bugs are hard to catch without hardware testing.
  • transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh — persistent barrier parity and TMA write-group drain ordering.

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Core of the PR: adds persistent grid-stride scheduler, TunableConfig, JobDescriptor/BlockDescriptor structs, and USE_FAST_MATH path. Missing static_assert(STAGES % BUFFS_NUM == 0) is a latent correctness bug for barrier parity reuse in persistent mode; potential missing drain of TMA write groups between persistent jobs.
transformer_engine/common/cast/dispatch/quantize.cuh Threads quant_config pointer through to group_quantize; both forward (fwd_helper) and backward (bwd_helper) paths updated. Legacy nvte_group_quantize path correctly passes a default-constructed config (safe fallback).
transformer_engine/common/util/ptx.cuh Adds alignas(4 * sizeof(T)) to FPx4 struct and four new mul_cvt_4x overloads (BF16/FP16 → FP8E4M3/FP8E5M2 via mixed-precision FMA PTX) used by the fast-math path.
transformer_engine/common/common.h Adds TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH and TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH macros to convert runtime enum values into compile-time template parameters, enabling the new kernel template specializations.
tests/cpp/operator/test_cast_mxfp8_grouped.cu Adds use_fast_math parameter to performTest and compute_ref, tests for empty tensors in the middle of groups, covers the new nvte_group_quantize_v2 API, and adds a name factory to produce readable test names.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_group_quantize_v2(input, output, quant_config, stream)"] --> B["group_quantize_fwd_helper"]
    B --> C{"NVTE_MXFP8_1D_SCALING?"}
    C -->|yes| D["mxfp8::group_quantize(…, quant_config, stream)"]
    C -->|no| Z["NVTE_ERROR: unsupported scaling mode"]

    D --> E["Validate quant_config\n(fast_math → BF16/FP16 only, no fused ops)"]
    E --> F{"is_single_tensor?"}
    F -->|SAME_BOTH_DIMS / VARYING_FIRST_DIM| G["work_blocks_Y = DIVUP(rows, 128)\nwork_blocks_X = DIVUP(cols, 128)"]
    F -->|VARYING_LAST_DIM / VARYING_BOTH_DIMS| H["work_blocks_Y = 1\nwork_blocks_X = DIVUP(elts_total, 128×128)"]

    G --> I
    H --> I["PERSISTENT?\nlaunch_blocks = sm_count × 24\nelse launch_blocks = work_blocks"]
    I --> J["update_tma_descriptors\n<<<num_tensors, 1>>>"]
    J --> K["group_quantize_mxfp8_kernel\n<<<launch_blocks, 128, dshmem>>>"]

    K --> L["Per-CTA: persistent work loop"]
    L --> M["decode_job → JobDescriptor"]
    M --> N{"is_job_valid?"}
    N -->|no| O["break — all work done"]
    N -->|yes| P{"job_has_work\n(rows>0 && cols>0)?"}
    P -->|no empty tensor| Q["advance_to_next_job → continue"]
    P -->|yes| R["Prime pipeline\nPREFETCH_STAGES=1 TMA loads"]
    R --> S["STAGES=4 processing loop\n(32 rows each)"]
    S --> T["colwise stage:\nload → amax → e8m0 scale → quantize"]
    T --> U["rowwise stage:\nload → amax → e8m0 scale → quantize"]
    U --> V["TMA store outputs"]
    V --> W["advance_to_next_job"]
    W --> L
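The two dispatch branches in the flowchart can be sketched as plain host-side arithmetic (helper names are illustrative; the 128 tile size matches the DIVUP(rows, 128) nodes above):

```cpp
#include <cstddef>

constexpr size_t DIVUP(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t CHUNK = 128;  // CHUNK_DIM_Y == CHUNK_DIM_X == 128 in the flowchart

struct WorkGrid { size_t blocks_Y, blocks_X; };

// SAME_BOTH_DIMS / VARYING_FIRST_DIM: 2-D tiling over one logical tensor.
WorkGrid work_grid_uniform(size_t rows, size_t cols) {
  return {DIVUP(rows, CHUNK), DIVUP(cols, CHUNK)};
}

// VARYING_LAST_DIM / VARYING_BOTH_DIMS: 1-D tiling over the flattened
// element count of the whole group.
WorkGrid work_grid_flat(size_t elts_total) {
  return {1, DIVUP(elts_total, CHUNK * CHUNK)};
}
```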

Comments Outside Diff (1)

  1. transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh, line 1147-1150 (link)

    Drain of outstanding TMA write groups missing before persistent-loop job transition

    At the end of each job's stage loop, store_output_stage issues a TMA store (cp.async.bulk.tensor.2d.global.shared) and calls ptx::cp_async_bulk_commit_group(). These write groups are never explicitly drained before the next job's priming phase begins.

    In the current double-buffered setup (BUFFS_NUM = 2), the fence_proxy_async_shared_cta + __syncthreads inside the per-stage processing makes the shared memory visible to the TMA engine, but this applies to the current stage's store. When the while-loop moves to the next job and immediately starts the priming phase (prefetch_input_stage), there is no ptx::cp_async_bulk_wait_group_read<0>() (or equivalent) to ensure all outstanding write groups from the previous job have committed before returning.

    While this is unlikely to cause data corruption in practice (TMA writes and reads target independent memory ranges), it leaves write groups indefinitely pending and may interact poorly with the SM's async engine bookkeeping — especially under high persistent-kernel occupancy where many CTAs are cycling through jobs concurrently.

    Consider adding ptx::cp_async_bulk_wait_group_read<0>() (drain all outstanding groups) after the stage loop and before advance_to_next_job, or confirm via the PTX spec that incomplete write groups do not require explicit draining at a job boundary.

Last reviewed commit: "[pre-commit.ci] auto..."

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 924ff91 to 325181b on March 6, 2026 10:39
}

const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec;
OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
Contributor

Output stride assumes uniform cols across all tensors

The output write offset is computed as:

OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;

where cols is last_logical_dim — a single value shared across all tensors in the group. This is correct for SAME_BOTH_DIMS and VARYING_FIRST_DIM (where all tensors share the same last dimension), but the kernel receives shape_rep as a parameter and does not enforce that restriction.

For VARYING_LAST_DIM or VARYING_BOTH_DIMS where per-tensor cols differ, the fixed tensor_id * cols stride would compute wrong output offsets. Currently, tests skip dbias validation for these cases, but the kernel would produce incorrect results if actually called with varying-last-dim tensors.

Consider adding a device-side assertion to enforce the precondition:

Suggested change
- OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
+ if (shape_rep != ShapeRepresentation::SAME_BOTH_DIMS && shape_rep != ShapeRepresentation::VARYING_FIRST_DIM) {
+   NVTE_DEVICE_ERROR("group_reduce_dbias_kernel requires uniform last dimensions across tensors");
+ }
+ OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
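To illustrate the comment's point, a host-side sketch (hypothetical helpers, not kernel code) contrasts the uniform `tensor_id * cols` stride with the prefix-sum offsets that varying last dimensions would require:

```cpp
#include <cstddef>
#include <vector>

// With a uniform last dimension, tensor t's output row starts at t * cols.
std::vector<size_t> uniform_stride_offsets(size_t num_tensors, size_t cols) {
  std::vector<size_t> out(num_tensors);
  for (size_t t = 0; t < num_tensors; ++t) out[t] = t * cols;
  return out;
}

// With varying last dimensions, the correct start is a running prefix sum of
// per-tensor column counts; a fixed stride cannot reproduce this.
std::vector<size_t> prefix_sum_offsets(const std::vector<size_t> &per_tensor_cols) {
  std::vector<size_t> out(per_tensor_cols.size());
  size_t acc = 0;
  for (size_t t = 0; t < per_tensor_cols.size(); ++t) {
    out[t] = acc;
    acc += per_tensor_cols[t];
  }
  return out;
}
```

For cols = {128, 256, 512} the prefix sum gives {0, 128, 384}, whereas any single uniform stride gives an arithmetic progression, hence the restriction to SAME_BOTH_DIMS / VARYING_FIRST_DIM.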

Oleg-Goncharov and others added 15 commits March 10, 2026 11:58
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@ptrendx ptrendx force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 5815335 to aa484a3 on March 10, 2026 19:07
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
Contributor

Wrong units in rowwise_scale_is_within_bounds guard

scales_offset_X_rowwise is a scale index (one entry per 32-element column group), while cols is the number of data columns. Comparing them directly means the guard almost never fires.

Concretely, with cols = 96 and SCALE_DIM_X = 32:

  • scales_offset_X_rowwise for the four threads of the last (and only) X-block is {0, 1, 2, 3}
  • Valid scale positions covering real data: {0, 1, 2} (covering columns 0–31, 32–63, 64–95)
  • The current check 3 < 96 evaluates to true, so thread 3 still writes a spurious scale for the nonexistent columns 96–127

The correct comparison multiplies the scale index back to column units:

Suggested change
- const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
+ const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise * SCALE_DIM_X < cols;

This correctly excludes scale index 3 because 3 * 32 = 96, which is not < 96.
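The buggy and fixed guards can be written as plain predicates to check the arithmetic above (illustrative only; SCALE_DIM_X = 32 as in the comment):

```cpp
#include <cstddef>

constexpr size_t SCALE_DIM_X = 32;  // 32 data columns per rowwise scale entry

// Buggy: compares a scale index directly against a column count.
bool guard_buggy(size_t scale_idx, size_t cols) { return scale_idx < cols; }

// Fixed: converts the scale index back to column units before comparing.
bool guard_fixed(size_t scale_idx, size_t cols) {
  return scale_idx * SCALE_DIM_X < cols;
}
```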

Comment on lines +174 to +191
__device__ __forceinline__ JobDescriptor decode_job(
const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors,
const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X,
const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr,
const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr) {
JobDescriptor job{};
job.block_id = ctaid_Y * work_blocks_X + ctaid_X;
job.block_global_offset = is_single_tensor
? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X)
: (job.block_id * ELTS_PER_CHUNK);
job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y,
first_logical_dim, last_logical_dim, offsets_ptr);
job.rows =
get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
return job;
}
Member

This should be a constructor of the JobDescriptor struct (you can make the constructor __device__ too).

Comment on lines +218 to +232
__device__ __forceinline__ BlockDescriptor
decode_block(const JobDescriptor &job, const bool is_single_tensor,
const int64_t *const __restrict__ offsets_ptr) {
BlockDescriptor block{};
block.tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[job.tensor_id]);
const size_t CHUNK_DIM_X_ = CHUNK_DIM_X;
const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_);
block.block_id_in_current_tensor =
is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK);
block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor;
block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor;
block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y;
block.block_offset_X = block.block_id_X * CHUNK_DIM_X;
return block;
}
Member

Similarly this should be a constructor too.
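A simplified host-side sketch of what such a constructor could look like (keeping only the arithmetic visible in the diff; on device it would additionally be marked __device__ __forceinline__, and the constant values here are assumptions):

```cpp
#include <cstddef>
#include <cstdint>

constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;
constexpr size_t DIVUP(size_t a, size_t b) { return (a + b - 1) / b; }

// decode_block folded into a constructor, as the reviewer suggests.
struct BlockDescriptor {
  size_t tensor_base, block_id_in_current_tensor;
  size_t block_id_Y, block_id_X, block_offset_Y, block_offset_X;

  BlockDescriptor(size_t job_block_id, size_t job_cols, bool is_single_tensor,
                  int64_t tensor_offset) {
    tensor_base = is_single_tensor ? 0 : static_cast<size_t>(tensor_offset);
    const size_t blocks_X_num = DIVUP(job_cols, CHUNK_DIM_X);
    block_id_in_current_tensor =
        is_single_tensor ? job_block_id
                         : (job_block_id - tensor_base / ELTS_PER_CHUNK);
    block_id_Y = block_id_in_current_tensor / blocks_X_num;
    block_id_X = block_id_in_current_tensor % blocks_X_num;
    block_offset_Y = block_id_Y * CHUNK_DIM_Y;
    block_offset_X = block_id_X * CHUNK_DIM_X;
  }
};
```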

const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size,
uint64_t *barrier, const bool leading_thread) {
if (leading_thread) {
ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);
Member

2 questions - why is this done before the TMA call and why is it done only by the leading_thread? In the other parts of the code (e.g. in ptx::copy_2d_to_shared) we do transfer, then arrive_expect on the leading thread and just arrive on all the other threads.

Collaborator Author

ptx::mbarrier_arrive_expect_tx is also called by a single thread in ptx::copy_2d_to_shared. I initialized the barriers using a single thread, which is sufficient for it to work. But we can also keep the previous approach, where all threads in the block participate explicitly. And since the async copy and expect_tx are in the same phase, it’s also valid to issue expect_tx first.

Comment on lines +714 to +716
if (launch_block_id >= total_work_blocks) {
return;
}
Member

Is this possible?

Collaborator Author

For example, for small input tensors where total_work_blocks is less than SMs * K, with K = STATIC_PERSISTENT_BLOCKS_PER_SM
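The author's reply can be checked with trivial arithmetic (the values below are hypothetical; 24 matches the sm_count × 24 node in the flowchart):

```cpp
#include <cstddef>

// With launch_blocks = sm_count * K physical CTAs, small inputs can have
// fewer work blocks than launched CTAs; the surplus CTAs find
// launch_block_id >= total_work_blocks and must return immediately.
size_t idle_ctas(size_t sm_count, size_t blocks_per_sm,
                 size_t total_work_blocks) {
  const size_t launch_blocks = sm_count * blocks_per_sm;
  return launch_blocks > total_work_blocks ? launch_blocks - total_work_blocks
                                           : 0;
}
```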

last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr,
first_dims_ptr, last_dims_ptr);
allow_next_job_prefetch =
is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr);
Member

If we are prevalidating the next job here, then why do we need earlier the check if the job we are about to do is going to be valid and draining it if it is not?

Collaborator Author

We prefetch the first stage of the next CTA at the end of processing the current CTA. This check is only to avoid copying data for null blocks. The main termination check, i.e., when to stop processing the current chunk and exit the loop is at line 770.

is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr);
if (!current_job_is_valid) {
if (has_prefetched_current_job) {
// A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting.
Member

Why do you need to drain it?

Collaborator Author

We destroy the barriers after exiting the loop. But this invalidation can only be done once the mbarrier objects are guaranteed to have completed their current phase (drained). Otherwise, the TMA engine may finish the copy and attempt to signal completion on an already-invalidated mbarrier.

@ptrendx ptrendx marked this pull request as draft March 12, 2026 16:26
Oleg-Goncharov and others added 2 commits March 13, 2026 17:08
Oleg-Goncharov and others added 8 commits March 17, 2026 17:41
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov marked this pull request as ready for review March 18, 2026 14:07
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 7c41a6a to 6874935 on March 18, 2026 14:24
Oleg-Goncharov and others added 3 commits March 18, 2026 15:25
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 7bdc696 to 5068556 on March 18, 2026 14:34
pre-commit-ci bot and others added 6 commits March 18, 2026 14:34
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
// Empty tensor in the middle of the group must not terminate the persistent work loop.
{VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128},
Collaborator

Can you also add a test case for varying first dim only that contains a group of size zero?

Oleg-Goncharov and others added 8 commits March 18, 2026 18:29
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
