Fused Adam Support for MXFP8 + FSDP2 integration #2780
Draft
vthumbe1503 wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Collaborator
Author
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Contributor
Greptile Summary

This PR adds a fused Adam optimizer kernel for MXFP8 (MX block-scaling FP8) model weights, integrating it into the existing FusedAdam optimizer.

Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as FusedAdam.step() (Python)
    participant EXT as adam.cpp (PyTorch ext)
    participant CU as adam.cu (CUDA)
    participant APPLY as multi_tensor_apply_mxfp8
    participant KERNEL as adam_mxfp8_fused_kernel
    PY->>PY: per-param loop — classify params<br/>(Float8/MXFP8/F16/F32)
    PY->>PY: accumulate into p_mxfp8_rowwise,<br/>p_mxfp8_colwise, moments, master_param
    PY->>EXT: multi_tensor_adam_mxfp8(chunk_size,<br/>noop_flag, 8 tensor lists, …, fp8_dtype)
    EXT->>EXT: makeTransformerEngineTensorList()<br/>validate num_lists == 8
    EXT->>CU: nvte_multi_tensor_adam_mxfp8_cuda(…)
    CU->>CU: compute_bias_correction()<br/>check_tensor_list_sizes()<br/>dtype validation
    CU->>APPLY: multi_tensor_apply_mxfp8<kernel>(…)
    loop For each tensor (batched ≤ MXFP8_MAX_TENSORS=24 tensors,<br/>≤ MXFP8_MAX_BLOCKS=320 blocks per launch)
        APPLY->>APPLY: build MXFP8TensorListMetadata<br/>(block_to_tensor, block_to_tile, rows, cols)
        APPLY->>KERNEL: Kernel<<<blocks, 256>>>(tl, β1, β2, ε, lr, …)
        KERNEL->>KERNEL: Stage 4: Adam update → p/m/v (FP32)
        KERNEL->>KERNEL: Stage 5: atomicMaxFloat → row/col amax (shared mem)
        KERNEL->>KERNEL: Stage 6: write rowwise & colwise scale-inv (e8m0)
        KERNEL->>KERNEL: Stage 7: quantise p → MXFP8 rowwise + colwise data
    end
    CU-->>PY: return (master params, moments, MXFP8 data, scales updated in-place)
```
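The kernel stages in the diagram can be sketched in scalar Python. This is an illustrative sketch, not the PR's code: the block size, the E4M3 constants, and the helper names (`e8m0_scale`, `quantize_block`, `adam_step`) are assumptions; only the stage numbering and `compute_bias_correction` come from the review summary above.

```python
import math

# Illustrative constants (assumptions, not taken from the PR):
FP8_E4M3_MAX = 448.0   # largest finite E4M3 value
E4M3_MAX_EXP = 8       # exponent of the largest E4M3 binade (448 = 1.75 * 2^8)
MX_BLOCK = 32          # MX block-scaling: one shared scale per 32 elements

def e8m0_scale(amax: float) -> float:
    """Power-of-two scale per MX block, stored as an e8m0 exponent (Stage 6)."""
    if amax == 0.0:
        return 1.0
    exp = math.floor(math.log2(amax)) - E4M3_MAX_EXP
    exp = max(-127, min(127, exp))   # clamp to the e8m0-representable range
    return 2.0 ** exp

def quantize_block(vals):
    """Stages 5-7 in scalar form: block amax -> e8m0 scale -> scale and clamp."""
    amax = max(abs(v) for v in vals)   # the kernel reduces this via atomicMaxFloat
    scale = e8m0_scale(amax)
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in vals]
    return q, scale                    # kernel stores q as FP8 bytes, scale as e8m0

def adam_step(p, m, v, g, lr, beta1, beta2, eps, step):
    """Stage 4: standard Adam update on the FP32 master copy."""
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** step)   # bias correction; the real kernel takes
    v_hat = v / (1.0 - beta2 ** step)   # these factors from compute_bias_correction()
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v
```

For a block whose amax is 1.0, the shared scale comes out as 2^-8, so the scaled values land well inside the E4M3 range before the FP8 cast.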
Last reviewed commit: "address review comme..."
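The per-launch batching described in the diagram (at most MXFP8_MAX_TENSORS=24 tensors and MXFP8_MAX_BLOCKS=320 blocks per kernel launch) can be sketched host-side. This is a simplified sketch: `plan_launches` is a hypothetical helper, and for brevity it enforces only the block limit, while the real multi_tensor_apply_mxfp8 also flushes on the tensor limit and fills in the full MXFP8TensorListMetadata.

```python
# Limits taken from the PR's metadata struct; plan_launches itself is hypothetical.
MXFP8_MAX_TENSORS = 24   # tensors tracked per launch (not enforced in this sketch)
MXFP8_MAX_BLOCKS = 320   # CUDA blocks per kernel launch

def plan_launches(tensor_sizes, chunk_size):
    """Split each tensor into chunks and pack (tensor, chunk) pairs into
    launches, flushing whenever the per-launch block limit is reached."""
    launches, cur = [], []
    for t, n in enumerate(tensor_sizes):
        n_chunks = (n + chunk_size - 1) // chunk_size   # ceiling division
        for c in range(n_chunks):
            cur.append((t, c))   # becomes block_to_tensor / block_to_tile entries
            if len(cur) == MXFP8_MAX_BLOCKS:
                launches.append(cur)   # metadata struct full: launch the kernel
                cur = []
    if cur:                            # flush the final partial launch
        launches.append(cur)
    return launches
```

For example, two tensors of 64 000 and 5 000 elements with chunk_size=100 produce 690 blocks, i.e. three launches of 320, 320, and 50 blocks.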
…rmerEngine into fused_adam_for_mxfp8
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…rmerEngine into fused_adam_for_mxfp8
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Collaborator
Author
Needs more perf tuning for MXFP8; I will mark the PR as ready for review once the desired perf is achieved.