
Fused Adam Support for MXFP8 + FSDP2 integration #2780

Draft
vthumbe1503 wants to merge 8 commits into NVIDIA:main from vthumbe1503:fused_adam_for_mxfp8

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits March 18, 2026 16:29
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 marked this pull request as ready for review March 18, 2026 16:36
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 18, 2026

Greptile Summary

This PR adds a fused Adam optimizer kernel for MXFP8 (MX block-scaling FP8) model weights, integrating it into the existing FusedAdam optimizer path and enabling FSDP2-compatible distributed training. The implementation follows the established multi-tensor-apply pattern but introduces a new tile-based dispatch (multi_tensor_apply_mxfp8) because MXFP8 scaling is inherently 2-D (32×32 tiles), unlike the 1-D chunk-based approach used for all existing Adam variants.

Key changes:

  • New CUDA kernel (adam_mxfp8_fused_kernel): fuses Adam update, per-tile absmax accumulation, scale-inverse computation (e8m0), and rowwise + colwise FP8 quantisation into a single kernel, avoiding a separate cast pass.
  • New tiling scheduler (multi_tensor_apply_mxfp8): batches up to 320 blocks / 24 tensors per launch and handles mid-tensor flush when the block budget is exhausted.
  • Python optimizer integration: adds the MXFP8 branch in FusedAdam.step() with guards for capturable=True and master_weights=False; refactors shared helpers (compute_bias_correction, check_tensor_list_sizes, requires_64bit_indexing) for reuse across the FP8 and MXFP8 paths.
  • C++ test: exercises E4M3 and E5M2 paths with 25 tensors (intentionally exceeding MXFP8_MAX_TENSORS=24) to validate the chunking logic.
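
For context, the scaling scheme the summary describes can be sketched in NumPy. This is an illustrative simplification, not the kernel code: it uses one power-of-two (e8m0-style) scale per 32×32 tile and omits the actual FP8 rounding, whereas the real kernel maintains separate rowwise and colwise scales; `tile_scale_inv`, `quantize_tiles`, and the constants are hypothetical names.

```python
import numpy as np

TILE = 32          # MX tile size along each dimension (32x32 tiles per the summary)
E4M3_MAX = 448.0   # largest representable magnitude in FP8 E4M3

def tile_scale_inv(p):
    """Per-tile power-of-two scale-inverse (e8m0-style), illustrative only.

    For each TILE x TILE tile of p, take the absolute maximum and pick
    the smallest power of two that brings the tile into the E4M3 range.
    """
    nr, nc = p.shape[0] // TILE, p.shape[1] // TILE
    tiles = p.reshape(nr, TILE, nc, TILE)
    amax = np.abs(tiles).max(axis=(1, 3))                       # per-tile absmax
    exp = np.ceil(np.log2(np.maximum(amax, 1e-38) / E4M3_MAX))  # e8m0 exponent
    return np.exp2(exp)                                         # scale-inverse per tile

def quantize_tiles(p, scale_inv):
    """Divide each tile by its scale-inverse and clip to the E4M3 range."""
    nr, nc = scale_inv.shape
    q = p.reshape(nr, TILE, nc, TILE) / scale_inv[:, None, :, None]
    return np.clip(q, -E4M3_MAX, E4M3_MAX).reshape(p.shape)
```

Because the scale-inverse is an exact power of two, the divide is lossless in FP32; only the final cast to FP8 (not modeled here) introduces rounding error.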

Issues found:

  • out_dtype shared between FP8 and MXFP8 kernel calls (fused_adam.py lines 836/848): if a parameter group contains both Float8Tensor and MXFP8Tensor parameters, a single out_dtype variable is overwritten by whichever type is processed last in the loop, potentially passing the wrong element dtype to one of the two kernels.
  • rows/cols stored as int in MXFP8TensorListMetadata (multi_tensor_apply.cuh): the 32-bit fields can silently overflow for tensors whose individual dimensions exceed INT_MAX, despite the outer launch code correctly selecting a 64-bit index path for such tensors.
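
The first issue can be reduced to a minimal pattern (names are hypothetical, not the actual fused_adam.py code): one variable written in two classification branches ends up holding whichever format was seen last, so one of the two kernel calls receives the other format's dtype.

```python
def classify_buggy(params):
    """One shared out_dtype for the whole group -- last writer wins."""
    out_dtype = None
    for kind, dtype in params:          # params: list of (kind, dtype) pairs
        if kind == "fp8":
            out_dtype = dtype           # may be overwritten by a later MXFP8 param
        elif kind == "mxfp8":
            out_dtype = dtype           # may be overwritten by a later FP8 param
    return out_dtype

def classify_fixed(params):
    """One dtype per format avoids the cross-contamination."""
    fp8_dtype = mxfp8_dtype = None
    for kind, dtype in params:
        if kind == "fp8":
            fp8_dtype = dtype
        elif kind == "mxfp8":
            mxfp8_dtype = dtype
    return fp8_dtype, mxfp8_dtype
```

With a mixed group such as one E4M3 Float8Tensor followed by one E5M2 MXFP8Tensor, the buggy version reports E5M2 for both kernel calls.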

Confidence Score: 3/5

  • Safe to merge for the common single-dtype use-case, but has two correctness gaps worth fixing before wider adoption.
  • The core CUDA kernel logic, tile-scheduling, and Python guard logic are all sound and well-tested. However, two issues reduce confidence: (1) out_dtype can silently carry the wrong dtype to either kernel when a parameter group mixes Float8Tensor and MXFP8Tensor — a silent correctness bug that produces wrong quantisation with no error; (2) int rows/cols in MXFP8TensorListMetadata can overflow for very large tensors despite the 64-bit indexing path being selected upstream. The PR description is also left as a template with unchecked checklist items, and there are no added Python-level tests for the MXFP8 optimizer path.
  • transformer_engine/pytorch/optimizers/fused_adam.py (out_dtype sharing) and transformer_engine/common/multi_tensor/multi_tensor_apply.cuh (int overflow in metadata struct).
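
The second issue is an ordinary 32-bit width hazard: assigning a dimension larger than INT_MAX to an `int` field wraps silently to a negative value. A minimal illustration (the helper name is hypothetical):

```python
import ctypes

INT_MAX = 2**31 - 1

def store_as_int32(value):
    """Simulate assigning a Python int to a C `int` struct field.

    ctypes.c_int32 reproduces C's two's-complement truncation, so
    values past INT_MAX wrap around to negative numbers with no error.
    """
    return ctypes.c_int32(value).value
```

Any downstream index arithmetic using such a wrapped `rows`/`cols` value would be silently wrong, which is why widening the fields (or rejecting oversized dimensions) matters even though the launch path already selects 64-bit indexing.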

Important Files Changed

Filename Overview
transformer_engine/common/multi_tensor/adam.cu Adds the MXFP8 fused Adam kernel and launch wrapper; refactors shared helpers (bias-correction, tensor-list size check, 64-bit index detection). The new adam_mxfp8_fused_kernel is logically sound, but compute_bias_correction uses std::pow on an int step which can have precision edge-cases.
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh Adds MXFP8TensorListMetadata and multi_tensor_apply_mxfp8; the tile-based chunking logic mirrors the existing chunk-based path. The rows/cols fields in the metadata struct are typed as int, which can overflow for tensors whose individual dimensions exceed INT_MAX, despite the outer code selecting a 64-bit indexing path.
transformer_engine/pytorch/optimizers/fused_adam.py Adds MXFP8 parameter handling with proper guards (capturable=False, master_weights=True, rowwise+colwise data present). However, the single shared out_dtype variable is written by both the Float8Tensor and MXFP8Tensor branches; in a mixed-type parameter group the wrong dtype could be forwarded to either kernel.
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp Adds the multi_tensor_adam_mxfp8_cuda PyTorch extension wrapper; straightforward ATen-to-TE tensor conversion with a sensible early validation of the 8-list requirement.
transformer_engine/pytorch/csrc/extensions/pybind.cpp Registers multi_tensor_adam_mxfp8 pybind11 binding; includes py::call_guard<py::gil_scoped_release>() consistent with all other Adam bindings.
transformer_engine/common/include/transformer_engine/multi_tensor.h Adds the public C API declaration nvte_multi_tensor_adam_mxfp8_cuda; documentation accurately describes the 8-list convention and parameter semantics.
tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu New C++ test covering E4M3 and E5M2 with 25 tensors (> MXFP8_MAX_TENSORS=24) to exercise the chunking path; validates updated FP32 params, moments, MXFP8 quantized data, and scale-inverses against a reference run.
tests/pytorch/distributed/run_fsdp2_fused_adam.py Adds per-step loss logging via a new dist_print helper. The helper depends on a module-level LOCAL_RANK variable that is only initialised in one test function, making it silently a no-op if called from other contexts.
tests/cpp/test_common.h Adds rowwise_scale_inv_dptr and columnwise_scale_inv_dptr accessors to the test Tensor helper class; straightforward and correct.
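
For reference, the bias correction that the shared helper computes is the standard Adam one. A numerically explicit sketch (function names are illustrative, not the TE source):

```python
def bias_correction(beta1, beta2, step):
    """Standard Adam bias-correction terms for a 1-indexed step.

    The moment estimates m and v are biased toward zero early in
    training; dividing by these terms removes that bias.
    """
    return 1.0 - beta1 ** step, 1.0 - beta2 ** step

def adam_update(p, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    """One element-wise Adam step on FP32 master values (illustrative)."""
    m = beta1 * m + (1.0 - beta1) * g          # first-moment EMA
    v = beta2 * v + (1.0 - beta2) * g * g      # second-moment EMA
    bc1, bc2 = bias_correction(beta1, beta2, step)
    m_hat = m / bc1
    v_hat = v / bc2
    p = p - lr * m_hat / (v_hat ** 0.5 + eps)
    return p, m, v
```

Integer exponentiation of `step` avoids the `std::pow(double, int)` precision edge-cases the table mentions, since repeated multiplication of the betas is exact up to ordinary FP rounding.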

Sequence Diagram

sequenceDiagram
    participant PY as FusedAdam.step() (Python)
    participant EXT as adam.cpp (PyTorch ext)
    participant CU as adam.cu (CUDA)
    participant APPLY as multi_tensor_apply_mxfp8
    participant KERNEL as adam_mxfp8_fused_kernel

    PY->>PY: per-param loop — classify params<br/>(Float8/MXFP8/F16/F32)
    PY->>PY: accumulate into p_mxfp8_rowwise,<br/>p_mxfp8_colwise, moments, master_param

    PY->>EXT: multi_tensor_adam_mxfp8(chunk_size,<br/>noop_flag, 8 tensor lists, …, fp8_dtype)
    EXT->>EXT: makeTransformerEngineTensorList()<br/>validate num_lists == 8
    EXT->>CU: nvte_multi_tensor_adam_mxfp8_cuda(…)
    CU->>CU: compute_bias_correction()<br/>check_tensor_list_sizes()<br/>dtype validation
    CU->>APPLY: multi_tensor_apply_mxfp8<kernel>(…)

    loop For each tensor (batched ≤ MXFP8_MAX_TENSORS=24 tensors,<br/>≤ MXFP8_MAX_BLOCKS=320 blocks per launch)
        APPLY->>APPLY: build MXFP8TensorListMetadata<br/>(block_to_tensor, block_to_tile, rows, cols)
        APPLY->>KERNEL: Kernel<<<blocks, 256>>>(tl, β1, β2, ε, lr, …)
        KERNEL->>KERNEL: Stage 4: Adam update → p/m/v (FP32)
        KERNEL->>KERNEL: Stage 5: atomicMaxFloat → row/col amax (shared mem)
        KERNEL->>KERNEL: Stage 6: write rowwise & colwise scale-inv (e8m0)
        KERNEL->>KERNEL: Stage 7: quantise p → MXFP8 rowwise + colwise data
    end

    CU-->>PY: return (master params, moments, MXFP8 data, scales updated in-place)
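
The batching policy in the loop above can be sketched in Python. The caps (24 tensors and 320 blocks per launch) come from the summary; `schedule_launches` and the representation of each tensor as a tile count are hypothetical, and tiles stand in for CUDA blocks one-to-one.

```python
MXFP8_MAX_TENSORS = 24   # per-launch tensor-slot cap (from the PR summary)
MXFP8_MAX_BLOCKS = 320   # per-launch block budget (from the PR summary)

def schedule_launches(tile_counts):
    """Group tensors (given as per-tensor tile counts) into kernel launches.

    A tensor whose tiles do not fit in the remaining block budget is
    split mid-tensor: the current launch is flushed and the tensor
    continues in the next one, mirroring the mid-tensor flush described
    for multi_tensor_apply_mxfp8.
    """
    launches, current, blocks, slots = [], [], 0, 0
    for t, n_tiles in enumerate(tile_counts):
        start = 0
        while start < n_tiles:
            take = min(n_tiles - start, MXFP8_MAX_BLOCKS - blocks)
            current.append((t, start, start + take))   # (tensor, tile range)
            blocks += take
            slots += 1
            start += take
            if blocks == MXFP8_MAX_BLOCKS or slots == MXFP8_MAX_TENSORS:
                launches.append(current)               # flush this launch
                current, blocks, slots = [], 0, 0
    if current:
        launches.append(current)
    return launches
```

A tensor split across launches occupies one metadata slot in each launch it touches, which is why the slot cap and the block budget are checked together.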

Last reviewed commit: "address review comme..."

vthumbe1503 and others added 5 commits March 18, 2026 16:42
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft March 19, 2026 18:24
@vthumbe1503
Collaborator Author

Needs more perf tuning for MXFP8; will mark the PR as active once the desired perf is achieved.
