From e355c38c111175a7d46b7dd27c38523135e9429b Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 2 Mar 2026 03:10:24 +0000 Subject: [PATCH 1/6] fix merge conflicts, now things working Signed-off-by: Varun Thumbe --- 3rdparty/cudnn-frontend | 2 +- tests/cpp/operator/CMakeLists.txt | 1 + .../operator/test_multi_tensor_adam_mxfp8.cu | 266 ++++++++++++ tests/cpp/test_common.h | 10 + .../distributed/run_fsdp2_fused_adam.py | 8 +- tests/pytorch/distributed/test_torch_fsdp2.py | 5 - .../include/transformer_engine/multi_tensor.h | 35 ++ .../common/multi_tensor/adam.cu | 384 +++++++++++++++--- .../multi_tensor/multi_tensor_apply.cuh | 94 +++++ transformer_engine/pytorch/csrc/extensions.h | 6 + .../csrc/extensions/multi_tensor/adam.cpp | 20 + .../pytorch/csrc/extensions/pybind.cpp | 2 + .../pytorch/optimizers/fused_adam.py | 45 +- 13 files changed, 812 insertions(+), 66 deletions(-) create mode 100644 tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..8d19d3182b 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 5e73675f4f..4241ada3ba 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable(test_operator test_memset.cu test_splits_to_offsets.cu test_multi_cast_transpose.cu + test_multi_tensor_adam_mxfp8.cu test_multi_padding.cu test_multi_unpadding.cu test_causal_softmax.cu diff --git a/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu new file mode 100644 index 0000000000..470917580f --- /dev/null +++ b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu @@ -0,0 +1,266 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +uint8_t fp8_to_u8(fp8e4m3 v) { + uint8_t out = 0; + std::memcpy(&out, &v, sizeof(uint8_t)); + return out; +} + +uint8_t fp8_to_u8(fp8e5m2 v) { + uint8_t out = 0; + std::memcpy(&out, &v, sizeof(uint8_t)); + return out; +} + +void run_mxfp8_adam_test(DType fp8_dtype) { + const std::vector shape1{64, 128}; + const std::vector shape2{32, 64}; + const float lr = 1e-3f; + const float beta1 = 0.9f; + const float beta2 = 0.999f; + const float eps = 1e-8f; + const int step = 1; + const int mode = 1; + const int bias_correction = 1; + const float weight_decay = 0.0f; + + // Run with 25 tensors > 24[MXFP8_MAX_TENSORS] to check + // the chunking logic + const size_t tensor_count = 25; + std::vector> shapes; + shapes.reserve(tensor_count); + for (size_t i = 0; i < tensor_count; ++i) { + shapes.push_back((i % 2 == 0) ? shape1 : shape2); + } + + std::vector names; + names.reserve(tensor_count * 11); + std::vector g; + std::vector p; + std::vector m; + std::vector v; + std::vector p_ref_t; + std::vector m_ref_t; + std::vector v_ref_t; + std::vector q_ref; + std::vector dq; + std::vector dq_ref; + std::vector q; + g.reserve(tensor_count); + p.reserve(tensor_count); + m.reserve(tensor_count); + v.reserve(tensor_count); + p_ref_t.reserve(tensor_count); + m_ref_t.reserve(tensor_count); + v_ref_t.reserve(tensor_count); + q_ref.reserve(tensor_count); + dq.reserve(tensor_count); + dq_ref.reserve(tensor_count); + q.reserve(tensor_count); + + for (size_t i = 0; i < tensor_count; ++i) { + const std::vector &shape = shapes[i]; + names.push_back("g" + std::to_string(i)); + g.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("p" + std::to_string(i)); + p.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("m" + std::to_string(i)); + m.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("v" + std::to_string(i)); + v.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + + fillUniform(&g.back()); + fillUniform(&p.back()); + std::fill_n(m.back().rowwise_cpu_dptr(), product(m.back().rowwise_shape()), 0.0f); + std::fill_n(v.back().rowwise_cpu_dptr(), product(v.back().rowwise_shape()), 0.0f); + m.back().from_cpu(); + v.back().from_cpu(); + + names.push_back("p_ref_" + std::to_string(i)); + p_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("m_ref_" + std::to_string(i)); + m_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("v_ref_" + std::to_string(i)); + v_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + const size_t n = shape[0] * shape[1]; + std::memcpy(p_ref_t.back().rowwise_cpu_dptr(), p.back().rowwise_cpu_dptr(), + n * sizeof(float)); + std::memcpy(m_ref_t.back().rowwise_cpu_dptr(), m.back().rowwise_cpu_dptr(), + n * sizeof(float)); + std::memcpy(v_ref_t.back().rowwise_cpu_dptr(), v.back().rowwise_cpu_dptr(), + n * sizeof(float)); + p_ref_t.back().from_cpu(); + m_ref_t.back().from_cpu(); + v_ref_t.back().from_cpu(); + + names.push_back("q_ref_" + std::to_string(i)); + q_ref.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING); + q_ref.back().set_with_gemm_swizzled_scales(false); + + names.push_back("dq" + std::to_string(i)); + dq.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("dq_ref_" + std::to_string(i)); + dq_ref.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + + names.push_back("q" + std::to_string(i)); + q.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING); + q.back().set_with_gemm_swizzled_scales(false); + } + + Tensor noop("noop", std::vector{1}, DType::kInt32, true, false); + int zero = 0; + std::memcpy(noop.rowwise_cpu_dptr(), &zero, sizeof(int)); + noop.from_cpu(); + + std::vector> lists(8); + std::vector extra_wrappers; + extra_wrappers.reserve(tensor_count * 4); + + auto add_tensor = [&](Tensor &g, Tensor &p, Tensor &m, Tensor &v, Tensor &q) { + lists[0].push_back(g.data()); + lists[1].push_back(p.data()); + lists[2].push_back(m.data()); + lists[3].push_back(v.data()); + + extra_wrappers.emplace_back(q.rowwise_dptr(), q.rowwise_shape(), fp8_dtype); + lists[4].push_back(extra_wrappers.back().data()); + extra_wrappers.emplace_back(q.columnwise_dptr(), q.columnwise_shape(), fp8_dtype); + lists[5].push_back(extra_wrappers.back().data()); + extra_wrappers.emplace_back(q.rowwise_scale_inv_dptr(), q.rowwise_scale_inv_shape(), + DType::kByte); + lists[6].push_back(extra_wrappers.back().data()); + extra_wrappers.emplace_back(q.columnwise_scale_inv_dptr(), q.columnwise_scale_inv_shape(), + DType::kByte); + lists[7].push_back(extra_wrappers.back().data()); + }; + + for (size_t i = 0; i < tensor_count; ++i) { + add_tensor(g[i], p[i], m[i], v[i], q[i]); + } + + std::vector list_ptrs; + list_ptrs.reserve(lists.size()); + for (auto &l : lists) { + list_ptrs.push_back(l.data()); + } + + nvte_multi_tensor_adam_mxfp8_cuda(65536, noop.data(), list_ptrs.data(), lists.size(), + lists[0].size(), static_cast(fp8_dtype), lr, beta1, + beta2, eps, step, mode, bias_correction, weight_decay, 0); + + std::vector> ref_lists(4); + for (size_t i = 0; i < tensor_count; ++i) { + ref_lists[0].push_back(g[i].data()); + ref_lists[1].push_back(p_ref_t[i].data()); + ref_lists[2].push_back(m_ref_t[i].data()); + ref_lists[3].push_back(v_ref_t[i].data()); + } + std::vector ref_list_ptrs; + ref_list_ptrs.reserve(ref_lists.size()); + for (auto &l : ref_lists) { + ref_list_ptrs.push_back(l.data()); + } + + nvte_multi_tensor_adam_cuda(65536, noop.data(), ref_list_ptrs.data(), ref_lists.size(), + ref_lists[0].size(), lr, beta1, beta2, eps, step, mode, + bias_correction, weight_decay, 0); + + for (size_t i = 0; i < tensor_count; ++i) { + nvte_quantize(p_ref_t[i].data(), q_ref[i].data(), 0); + nvte_dequantize(q[i].data(), dq[i].data(), 0); + nvte_dequantize(q_ref[i].data(), dq_ref[i].data(), 0); + } + + cudaDeviceSynchronize(); + + for (size_t i = 0; i < tensor_count; ++i) { + q[i].to_cpu(); + p[i].to_cpu(); + m[i].to_cpu(); + v[i].to_cpu(); + q_ref[i].to_cpu(); + dq[i].to_cpu(); + dq_ref[i].to_cpu(); + p_ref_t[i].to_cpu(); + m_ref_t[i].to_cpu(); + v_ref_t[i].to_cpu(); + } + + for (size_t i = 0; i < lists[0].size(); ++i) { + const Tensor &g_i = g[i]; + const Tensor &p_i = p[i]; + const Tensor &m_i = m[i]; + const Tensor &v_i = v[i]; + Tensor &q_i = q[i]; + const Tensor &p_ref_t_i = p_ref_t[i]; + const Tensor &m_ref_t_i = m_ref_t[i]; + const Tensor &v_ref_t_i = v_ref_t[i]; + Tensor &q_ref_i = q_ref[i]; + + compareResults("p", p_i, p_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + compareResults("m", m_i, m_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + compareResults("v", v_i, v_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + + const Tensor &dq_i = dq[i]; + const Tensor &dq_ref_i = dq_ref[i]; + compareResults("dequantized", dq_i, dq_ref_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, + 0); + + const size_t rs = q_i.rowwise_scale_inv_shape().data[1]; + const size_t cs = q_i.columnwise_scale_inv_shape().data[1]; + const size_t rowwise_scale_size = q_i.rowwise_scale_inv_shape().data[0] * rs; + const size_t colwise_scale_size = q_i.columnwise_scale_inv_shape().data[0] * cs; + compareResults("rowwise_scale", q_i.rowwise_cpu_scale_inv_ptr(), + q_ref_i.rowwise_cpu_scale_inv_ptr(), rowwise_scale_size, 0.0f); + compareResults("colwise_scale", q_i.columnwise_cpu_scale_inv_ptr(), + q_ref_i.columnwise_cpu_scale_inv_ptr(), colwise_scale_size, 0.0f); + + uint8_t *row_data = nullptr; + uint8_t *col_data = nullptr; + uint8_t *row_data_ref = nullptr; + uint8_t *col_data_ref = nullptr; + if (fp8_dtype == DType::kFloat8E4M3) { + row_data = reinterpret_cast(q_i.rowwise_cpu_dptr()); + col_data = reinterpret_cast(q_i.columnwise_cpu_dptr()); + row_data_ref = reinterpret_cast(q_ref_i.rowwise_cpu_dptr()); + col_data_ref = reinterpret_cast(q_ref_i.columnwise_cpu_dptr()); + } else { + row_data = reinterpret_cast(q_i.rowwise_cpu_dptr()); + col_data = reinterpret_cast(q_i.columnwise_cpu_dptr()); + row_data_ref = reinterpret_cast(q_ref_i.rowwise_cpu_dptr()); + col_data_ref = reinterpret_cast(q_ref_i.columnwise_cpu_dptr()); + } + const size_t data_size = q_i.rowwise_shape().data[0] * q_i.rowwise_shape().data[1]; + compareResults("rowwise_data", row_data, row_data_ref, data_size, 0.0f); + compareResults("colwise_data", col_data, col_data_ref, data_size, 0.0f); + } +} + +} // namespace + +TEST(MultiTensorAdamMXFP8, E4M3) { run_mxfp8_adam_test(DType::kFloat8E4M3); } + +TEST(MultiTensorAdamMXFP8, E5M2) { run_mxfp8_adam_test(DType::kFloat8E5M2); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 927407f478..eab181fa82 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -200,6 +200,16 @@ class Tensor { return tensor_.get_columnwise_data().data_ptr; } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().data_ptr; + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + template T *rowwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py index c39957cf13..34764d4e0a 100644 --- a/tests/pytorch/distributed/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -36,7 +36,11 @@ def get_recipe_from_string(recipe): SEQ_LEN = 32 BATCH_PER_RANK = 2 NUM_STEPS = 3 +LOCAL_RANK = None +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) def save_custom_attrs(module): custom_attrs = {} @@ -151,6 +155,8 @@ def test_fused_adam_fp8_master_weights(recipe=None): - Training loop completes without error - DTensor wrapping and QuantizedTensor local tensors are preserved """ + global LOCAL_RANK + LOCAL_RANK = int(os.environ["LOCAL_RANK"]) world_size, _, device = _setup() model = _build_model(fp8_init=True, recipe=recipe) @@ -183,7 +189,7 @@ def test_fused_adam_fp8_master_weights(recipe=None): loss = F.mse_loss(output, target) loss.backward() optimizer.step() - + dist_print(f"Step {step} completed with loss {loss.item()}") # Verify optimizer states for param in model.parameters(): state = optimizer.state[param] diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 02e45d99cb..6d7ae4d7bb 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -224,11 +224,6 @@ def test_fsdp2_dcp_output_parity_async(fp_recipe): @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") def test_fsdp2_safetensors_fp32_export(fp_recipe): """Export FP32 model from optimizer master weights to safetensors.""" - if fp_recipe == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access" - ) _run_fused_adam_test("safetensors_fp32_export", fp_recipe) diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 09ab260f15..90c87b166e 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -149,6 +149,41 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float weight_decay, const NVTEDType fp8_dtype, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * when model parameters are in MXFP8 precision. + * + * The update is applied to FP32 master parameters, then the master + * parameters are quantized to MXFP8 rowwise and columnwise data + * (both are always required). + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors with 8 lists in order: + * (0) gradients, (1) FP32 master params, (2) first moment, + * (3) second moment, (4) rowwise MXFP8 data, + * (5) columnwise MXFP8 data, (6) rowwise scale-inv, + * (7) columnwise scale-inv. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. Must be 8. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] fp8_dtype MXFP8 element type for quantization (E4M3/E5M2). + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_multi_tensor_adam_mxfp8_cuda( + int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, + const size_t num_tensor_lists, const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, const float epsilon, const int step, + const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream); + /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support and LR scheduling. * diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 29a073be84..fa75c645f3 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -4,12 +4,16 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include #include +#include "../common.h" +#include "../util/math.h" #include "../utils.cuh" +#include "../util/ptx.cuh" #include "multi_tensor_apply.cuh" namespace transformer_engine { @@ -27,6 +31,7 @@ typedef enum { using MATH_T = float; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using e8m0_t = transformer_engine::e8m0_t; template struct is_fp8 : std::false_type {}; @@ -49,6 +54,31 @@ struct FP8Data { template <> struct FP8Data {}; +template +__device__ __forceinline__ void adam_update(T &r_g, T &r_p, T &r_m, T &r_v, const float beta1, + const float beta2, const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + if (mode == ADAM_MODE_0) { // L2 + r_g = r_g + (decay * r_p); + r_m = beta1 * r_m + (1 - beta1) * r_g; + r_v = beta2 * r_v + (1 - beta2) * r_g * r_g; + T next_m_unbiased = r_m / beta1_correction; + T next_v_unbiased = r_v / beta2_correction; + T denom = sqrtf(next_v_unbiased) + epsilon; + T update = next_m_unbiased / denom; + r_p = r_p - (lr * update); + } else { // weight decay + r_m = beta1 * r_m + (1 - beta1) * r_g; + r_v = beta2 * r_v + (1 - beta2) * r_g * r_g; + T next_m_unbiased = r_m / beta1_correction; + T next_v_unbiased = r_v / beta2_correction; + T denom = sqrtf(next_v_unbiased) + epsilon; + T update = (next_m_unbiased / denom) + (decay * r_p); + r_p = r_p - (lr * update); + } +} + template struct AdamFunctorMaster { static constexpr bool is_fp8_type = is_fp8::value; @@ -122,24 +152,8 @@ struct AdamFunctorMaster { } #pragma unroll for (int ii = 0; ii < ILP; ii++) { - if (mode == ADAM_MODE_0) { // L2 - r_g[ii] = r_g[ii] + (decay * r_p[ii]); - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = next_m_unbiased / denom; - r_p[ii] = r_p[ii] - (lr * update); - } else { // weight decay - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); - r_p[ii] = r_p[ii] - (lr * update); - } + adam_update(r_g[ii], r_p[ii], r_m[ii], r_v[ii], beta1, beta2, beta1_correction, + beta2_correction, epsilon, lr, mode, decay); } #pragma unroll @@ -572,6 +586,188 @@ struct AdamCapturableMasterFunctor { } }; +template +__device__ __forceinline__ FP8_T cast_to_fp8(float x) { + return static_cast(x); +} + +__device__ __forceinline__ float fp8_max_norm_rcp(uint8_t fp8_dtype) { + if (fp8_dtype == static_cast(transformer_engine::DType::kFloat8E4M3)) { + return transformer_engine::Quantized_Limits::max_norm_rcp; + } + return transformer_engine::Quantized_Limits::max_norm_rcp; +} + +template +__global__ void adam_mxfp8_fused_kernel( + int64_t chunk_size, volatile int *noop_gmem, MXFP8TensorListMetadata tl, float beta1, + float beta2, float beta1_correction, float beta2_correction, float epsilon, float lr, int mode, + float weight_decay) { + // Stage 0: optional early-exit if a noop flag is set. + if (noop_gmem != nullptr && *noop_gmem == 1) { + return; + } + (void)chunk_size; + + // Stage 1: map this block to a specific tensor tile. + const int block_idx = blockIdx.x; + const int tensor_idx = tl.block_to_tensor[block_idx]; + const int tile_idx = tl.block_to_tile[block_idx]; + const int64_t rows_val = tl.rows[tensor_idx]; + const int64_t cols_val = tl.cols[tensor_idx]; + if (rows_val == 0 || cols_val == 0) { + return; + } + + const int64_t tiles_per_row = (cols_val + MXFP8_TILE - 1) / MXFP8_TILE; + const int64_t tile_row = tile_idx / tiles_per_row; + const int64_t tile_col = tile_idx % tiles_per_row; + const int64_t row_base = tile_row * MXFP8_TILE; + const int64_t col_base = tile_col * MXFP8_TILE; + + // Stage 2: load pointers for grads/params/moments and MXFP8 outputs/scales. + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_idx]); + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_idx]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_idx]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_idx]); + + auto *rowwise_data = reinterpret_cast(tl.addresses[4][tensor_idx]); + auto *colwise_data = reinterpret_cast(tl.addresses[5][tensor_idx]); + auto *rowwise_scale_inv = reinterpret_cast(tl.addresses[6][tensor_idx]); + auto *colwise_scale_inv = reinterpret_cast(tl.addresses[7][tensor_idx]); + + const int64_t unpadded_scales_X_rowwise = (cols_val + MXFP8_TILE - 1) / MXFP8_TILE; + constexpr int64_t kRowwiseScaleAlign = 4; + const int64_t row_stride = + DIVUP_TO_MULTIPLE(unpadded_scales_X_rowwise, kRowwiseScaleAlign); + constexpr int64_t kColwiseScaleAlign = 128; + const int64_t col_stride = DIVUP_TO_MULTIPLE(cols_val, kColwiseScaleAlign); + const uint8_t dtype = tl.fp8_dtype[tensor_idx]; + const auto adam_mode = static_cast(mode); + + // Stage 3: initialize shared amax accumulators per row/col within the tile. + __shared__ float row_max_vals[MXFP8_TILE]; + __shared__ float col_max_vals[MXFP8_TILE]; + if (threadIdx.x < MXFP8_TILE) { + row_max_vals[threadIdx.x] = 0.0f; + col_max_vals[threadIdx.x] = 0.0f; + } + __syncthreads(); + + for (int t = threadIdx.x; t < MXFP8_TILE_ELEMS; t += blockDim.x) { + const int local_r = t / MXFP8_TILE; + const int local_c = t % MXFP8_TILE; + const int64_t r = row_base + local_r; + const int64_t c = col_base + local_c; + if (r >= rows_val || c >= cols_val) { + continue; + } + const index_t idx = static_cast(r * cols_val + c); + + float r_g = static_cast(g[idx]); + float r_p = static_cast(p[idx]); + float r_m = static_cast(m[idx]); + float r_v = static_cast(v[idx]); + + // Stage 4: apply Adam update in FP32 and write back updated p/m/v. + transformer_engine::multi_tensor_adam::adam_update( + r_g, r_p, r_m, r_v, beta1, beta2, beta1_correction, beta2_correction, epsilon, lr, + adam_mode, weight_decay); + + p[idx] = static_cast(r_p); + m[idx] = static_cast(r_m); + v[idx] = static_cast(r_v); + + // Stage 5: accumulate per-row/col absmax for MXFP8 scaling. + const float abs_p = fabsf(r_p); + transformer_engine::atomicMaxFloat(&row_max_vals[local_r], abs_p); + transformer_engine::atomicMaxFloat(&col_max_vals[local_c], abs_p); + } + + __syncthreads(); + + // Stage 6: write rowwise/colwise scale-inverse exponents for the tile. + const float max_norm_rcp = fp8_max_norm_rcp(dtype); + + for (int r = threadIdx.x; r < MXFP8_TILE; r += blockDim.x) { + const int64_t row = row_base + r; + if (row >= rows_val) { + continue; + } + const float amax = row_max_vals[r]; + const ::transformer_engine::e8m0_t biased_exponent = + transformer_engine::ptx::float_to_e8m0(amax * max_norm_rcp); + const size_t scale_idx = static_cast(row * row_stride + tile_col); + rowwise_scale_inv[scale_idx] = reinterpret_cast(biased_exponent); + } + + for (int c = threadIdx.x; c < MXFP8_TILE; c += blockDim.x) { + const int64_t col = col_base + c; + if (col >= cols_val) { + continue; + } + const float amax = col_max_vals[c]; + const ::transformer_engine::e8m0_t biased_exponent = + transformer_engine::ptx::float_to_e8m0(amax * max_norm_rcp); + const size_t scale_idx = static_cast(tile_row * col_stride + col); + colwise_scale_inv[scale_idx] = reinterpret_cast(biased_exponent); + } + + __syncthreads(); + + // Stage 7: quantize updated params to MXFP8 using rowwise and colwise scales. + for (int t = threadIdx.x; t < MXFP8_TILE_ELEMS; t += blockDim.x) { + const int local_r = t / MXFP8_TILE; + const int local_c = t % MXFP8_TILE; + const int64_t r = row_base + local_r; + const int64_t c = col_base + local_c; + if (r >= rows_val || c >= cols_val) { + continue; + } + const index_t idx = static_cast(r * cols_val + c); + const float r_p = static_cast(p[idx]); + + const size_t row_scale_idx = static_cast(r * row_stride + tile_col); + const uint8_t row_raw = rowwise_scale_inv[row_scale_idx]; + const ::transformer_engine::e8m0_t row_biased = + reinterpret_cast(row_raw); + const float row_scale_inv = transformer_engine::ptx::exp2f_rcp(row_biased); + if (dtype == static_cast(transformer_engine::DType::kFloat8E4M3)) { + auto *out = reinterpret_cast(rowwise_data); + out[idx] = cast_to_fp8(r_p * row_scale_inv); + } else { + auto *out = reinterpret_cast(rowwise_data); + out[idx] = cast_to_fp8(r_p * row_scale_inv); + } + + const size_t col_scale_idx = static_cast(tile_row * col_stride + c); + const uint8_t col_raw = colwise_scale_inv[col_scale_idx]; + const ::transformer_engine::e8m0_t col_biased = + reinterpret_cast(col_raw); + const float col_scale_inv = transformer_engine::ptx::exp2f_rcp(col_biased); + if (dtype == static_cast(transformer_engine::DType::kFloat8E4M3)) { + auto *out = reinterpret_cast(colwise_data); + out[idx] = cast_to_fp8(r_p * col_scale_inv); + } else { + auto *out = reinterpret_cast(colwise_data); + out[idx] = cast_to_fp8(r_p * col_scale_inv); + } + } +} + +inline bool requires_64bit_indexing(const std::vector> &tensor_lists) { + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 0; i < num_tensor_lists; ++i) { + for (size_t j = 0; j < num_tensors_per_list; ++j) { + if (tensor_lists[i][j]->numel() >= INT_MAX) { + return true; + } + } + } + return false; +} + void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, @@ -624,25 +820,13 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } - // Check if 64-bit indices are required - bool requires_64bit_indexing = false; - for (size_t i = 0; i < num_tensor_lists; i++) { - for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() >= INT_MAX) { - requires_64bit_indexing = true; - break; - } - } - if (requires_64bit_indexing) { - break; - } - } + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); // Get moment dtype (m and v have the same dtype, already validated above) const auto moment_type_te = tensor_lists[2][0]->dtype(); // Launch kernel - if (requires_64bit_indexing) { + if (use_64bit_indexing) { if (num_tensor_lists == 4) { // g, p, m, v TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -766,28 +950,41 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); } -void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, const DType fp8_dtype, - cudaStream_t stream) { - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; +inline std::pair compute_bias_correction(int bias_correction, float beta1, + float beta2, int step) { + float bias_correction1 = 1.0f; + float bias_correction2 = 1.0f; if (bias_correction == 1) { bias_correction1 = 1 - std::pow(beta1, step); bias_correction2 = 1 - std::pow(beta2, step); } + return {bias_correction1, bias_correction2}; +} - // Check tensor list sizes - // 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv +inline void check_tensor_list_sizes(const std::vector> &tensor_lists, + size_t expected_lists) { const size_t num_tensor_lists = tensor_lists.size(); - NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists); + NVTE_CHECK(num_tensor_lists == expected_lists, "Expected ", expected_lists, + " tensor lists, but found ", num_tensor_lists); const size_t num_tensors_per_list = tensor_lists[0].size(); - for (size_t i = 1; i < num_tensor_lists; i++) { + for (size_t i = 1; i < num_tensor_lists; ++i) { NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); } +} + + +void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, const DType fp8_dtype, + cudaStream_t stream) { + auto [bias_correction1, bias_correction2] = + compute_bias_correction(bias_correction, beta1, beta2, step); + check_tensor_list_sizes(tensor_lists, 8); + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); // Check tensor dtypes const auto g_in_type_te = tensor_lists[0][0]->dtype(); @@ -819,22 +1016,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ", but expected dtype=", to_string(DType::kFloat32)); } - // Check if 64-bit indices are required - bool requires_64bit_indexing = false; - for (size_t i = 0; i < num_tensor_lists; i++) { - for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() >= INT_MAX) { - requires_64bit_indexing = true; - break; - } - } - if (requires_64bit_indexing) { - break; - } - } + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); // Launch kernel - if (requires_64bit_indexing) { + if (use_64bit_indexing) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -856,6 +1041,76 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); } +void multi_tensor_adam_mxfp8_cuda(int chunk_size, Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, const DType fp8_dtype, + cudaStream_t stream) { + auto [bias_correction1, bias_correction2] = + compute_bias_correction(bias_correction, beta1, beta2, step); + check_tensor_list_sizes(tensor_lists, 8); + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); + + NVTE_CHECK(fp8_dtype == DType::kFloat8E4M3 || fp8_dtype == DType::kFloat8E5M2, + "fp8_dtype must be E4M3 or E5M2 for MXFP8 fused Adam."); + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + const auto p_in_type_te = tensor_lists[1][0]->dtype(); + const auto moment_type_te = tensor_lists[2][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; ++j) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(p_in_type_te)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } + } + + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); + + if (use_64bit_indexing) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + p_in_type_te, p_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + g_in_type_te, g_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply_mxfp8< + transformer_engine::multi_tensor_adam::adam_mxfp8_fused_kernel< + p_in_type, g_in_type, moment_type, int64_t>>( + chunk_size, noop_flag, tensor_lists, static_cast(fp8_dtype), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, mode, + weight_decay);))); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + p_in_type_te, p_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + g_in_type_te, g_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply_mxfp8< + transformer_engine::multi_tensor_adam::adam_mxfp8_fused_kernel< + p_in_type, g_in_type, moment_type, int32_t>>( + chunk_size, noop_flag, tensor_lists, static_cast(fp8_dtype), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, mode, + weight_decay);))); + } +} + void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor lr, const float beta1, const float beta2, const float epsilon, @@ -1018,6 +1273,19 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), stream); } +void nvte_multi_tensor_adam_mxfp8_cuda( + int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, + const size_t num_tensor_lists, const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, const float epsilon, const int step, + const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_adam_mxfp8_cuda); + using namespace transformer_engine; + multi_tensor_adam::multi_tensor_adam_mxfp8_cuda( + chunk_size, *convertNVTETensorCheck(noop_flag), + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, + epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), stream); +} + void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index 3062ead551..c334f3908e 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -35,6 +35,23 @@ struct TensorListMetadata : public TensorListMetadataBase { void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; }; +constexpr int MXFP8_TILE = 32; +constexpr int MXFP8_TILE_ELEMS = MXFP8_TILE * MXFP8_TILE; +constexpr int MXFP8_BLOCK_THREADS = 256; +constexpr int MXFP8_MAX_TENSORS = 24; +constexpr int MXFP8_MAX_BLOCKS = 320; + +struct MXFP8TensorListMetadata { + void *addresses[8][MXFP8_MAX_TENSORS]; + int sizes[MXFP8_MAX_TENSORS]; + int rows[MXFP8_MAX_TENSORS]; + int cols[MXFP8_MAX_TENSORS]; + uint8_t fp8_dtype[MXFP8_MAX_TENSORS]; + unsigned char block_to_tensor[MXFP8_MAX_BLOCKS]; + int block_to_tile[MXFP8_MAX_BLOCKS]; + int start_tensor_this_launch; +}; + template __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, U callable, ArgTypes... args) { @@ -113,3 +130,80 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, } } } + +template +void multi_tensor_apply_mxfp8(int64_t chunk_size, const transformer_engine::Tensor &noop_flag, + std::vector> tensor_lists, + uint8_t fp8_dtype, cudaStream_t stream, ArgTypes... args) { + constexpr size_t kNumTensorLists = 8; + NVTE_CHECK(tensor_lists.size() == kNumTensorLists, + "Expected 8 tensor lists for MXFP8, but found ", tensor_lists.size()); + + const size_t num_tensors_per_list = tensor_lists[0].size(); + if (num_tensors_per_list == 0) { + return; + } + for (size_t i = 1; i < tensor_lists.size(); ++i) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + MXFP8TensorListMetadata tl; + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + + for (size_t t = 0; t < num_tensors_per_list; ++t) { + + const auto &g = tensor_lists[0][t]; + const auto &rowwise_data = tensor_lists[4][t]; + const auto &colwise_data = tensor_lists[5][t]; + + const int rows_val = static_cast(rowwise_data->data.shape[0]); + const int cols_val = static_cast(rowwise_data->data.shape[1]); + + tl.sizes[loc_tensor_info] = g->numel(); + tl.rows[loc_tensor_info] = rows_val; + tl.cols[loc_tensor_info] = cols_val; + tl.fp8_dtype[loc_tensor_info] = fp8_dtype; + + for (int d = 0; d < kNumTensorLists; ++d) { + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr; + } + loc_tensor_info++; + + const int tiles_y = (rows_val + MXFP8_TILE - 1) / MXFP8_TILE; + const int tiles_x = (cols_val + MXFP8_TILE - 1) / MXFP8_TILE; + const int tiles_this_tensor = tiles_y * tiles_x; + + for (int tile = 0; tile < tiles_this_tensor; ++tile) { + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_tile[loc_block_info] = tile; + loc_block_info++; + + const bool blocks_full = (loc_block_info == MXFP8_MAX_BLOCKS); + const bool tensors_full = + (loc_tensor_info == MXFP8_MAX_TENSORS && tile == tiles_this_tensor - 1); + const bool last_tile = (t == num_tensors_per_list - 1 && tile == tiles_this_tensor - 1); + if (blocks_full || tensors_full || last_tile) { + Kernel<<>>( + chunk_size, reinterpret_cast(noop_flag.data.dptr), tl, args...); + NVTE_CHECK_CUDA(cudaGetLastError()); + loc_block_info = 0; + if (tile == tiles_this_tensor - 1) { + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + tl.rows[0] = tl.rows[loc_tensor_info - 1]; + tl.cols[0] = tl.cols[loc_tensor_info - 1]; + tl.fp8_dtype[0] = tl.fp8_dtype[loc_tensor_info - 1]; + for (int d = 0; d < kNumTensorLists; ++d) { + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..65e2c54d67 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -517,6 +517,12 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, const int step, const int mode, const int bias_correction, const float weight_decay, DType fp8_dtype); +void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 145e1d4b40..01a21d44bb 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "../../extensions.h" +#include "pybind.h" namespace transformer_engine::pytorch { @@ -51,6 +52,25 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, at::cuda::getCurrentCUDAStream()); } +void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = + makeTransformerEngineTensorList(tensor_lists); + + NVTE_CHECK(num_lists == 8, + "Expected 8 tensor lists (g, p_master, m, v, rowwise_data, colwise_data, " + "rowwise_scale_inv, colwise_scale_inv), but found ", + num_lists); + nvte_multi_tensor_adam_mxfp8_cuda( + chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, + static_cast(fp8_dtype), lr, beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, at::cuda::getCurrentCUDAStream()); +} + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c590a3c9e2..6def07b08e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -525,6 +525,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_mxfp8", &transformer_engine::pytorch::multi_tensor_adam_mxfp8_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); m.def("multi_tensor_adam_capturable", &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda, "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index bcfd2bef19..f4ab2e7c37 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -14,6 +14,7 @@ from torch.distributed._tensor import DTensor import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier @@ -189,6 +190,7 @@ def __init__( self.multi_tensor_adam = tex.multi_tensor_adam self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 + self.multi_tensor_adam_mxfp8 = tex.multi_tensor_adam_mxfp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -544,18 +546,27 @@ def step(self, closure=None, grad_scaler=None): # create lists for multi-tensor apply p_main_of_fp8_model = [] p_main_of_f16_model = [] + p_main_of_mxfp8_model = [] g_of_fp8_model = [] g_of_f16_model = [] g_of_f32_model = [] + g_of_mxfp8_model = [] m_of_fp8_model = [] m_of_f16_model = [] m_of_f32_model = [] + m_of_mxfp8_model = [] v_of_fp8_model = [] v_of_f16_model = [] v_of_f32_model = [] + v_of_mxfp8_model = [] p_fp8_model = [] p_f16_model = [] p_f32_model = [] + # mxfp8 meta + p_mxfp8_rowwise = [] + p_mxfp8_colwise = [] + p_mxfp8_rowwise_scale_inv = [] + p_mxfp8_colwise_scale_inv = [] # fp8 meta scales = [] amaxes = [] @@ -623,10 +634,30 @@ def step(self, closure=None, grad_scaler=None): g_of_fp8_model.append(p_grad.data) m_of_fp8_model.append(unscaled_state["exp_avg"]) v_of_fp8_model.append(unscaled_state["exp_avg_sq"]) + elif isinstance(p, MXFP8Tensor) or ( + isinstance(p, DTensor) and isinstance(p._local_tensor, MXFP8Tensor) + ): + p = p._local_tensor if isinstance(p, DTensor) else p + if p._rowwise_data is None or p._columnwise_data is None: + raise RuntimeError("MXFP8Tensor does not have one of rowwise/columnwise data.") + if self.capturable: + raise RuntimeError( + "FusedAdam does not support MXFP8 model weights with capturable=True." + ) + if self.master_weights: + p_main_of_mxfp8_model.append(unscaled_state["master_param"].data) + g_of_mxfp8_model.append(p_grad.data) + m_of_mxfp8_model.append(unscaled_state["exp_avg"]) + v_of_mxfp8_model.append(unscaled_state["exp_avg_sq"]) + p_mxfp8_rowwise.append(p._rowwise_data) + p_mxfp8_colwise.append(p._columnwise_data) + p_mxfp8_rowwise_scale_inv.append(p._rowwise_scale_inv) + p_mxfp8_colwise_scale_inv.append(p._columnwise_scale_inv) + out_dtype = p._fp8_dtype elif isinstance(p, QuantizedTensor) or ( isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) ): - # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor, + # Block-scaling quantized params (Float8BlockwiseQTensor, # NVFP4Tensor). Operate on FP32 master weights, requantize back after # Adam update. # Note: a fused Adam+requantize kernel (like multi_tensor_adam_fp8 @@ -797,6 +828,18 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N scale_invs, ] apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype) + if len(p_mxfp8_rowwise) > 0 and len(p_mxfp8_colwise) > 0: + tensor_lists = [ + g_of_mxfp8_model, + p_main_of_mxfp8_model, + m_of_mxfp8_model, + v_of_mxfp8_model, + p_mxfp8_rowwise, + p_mxfp8_colwise, + p_mxfp8_rowwise_scale_inv, + p_mxfp8_colwise_scale_inv, + ] + apply_multi_tensor_adam(self.multi_tensor_adam_mxfp8, tensor_lists, out_dtype) if len(p_f32_model) > 0: tensor_lists = [ g_of_f32_model, From 5b733fc8ea8ece8b2906bbaf22cbc9da897010b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:35:22 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../operator/test_multi_tensor_adam_mxfp8.cu | 2 +- .../distributed/run_fsdp2_fused_adam.py | 2 ++ .../include/transformer_engine/multi_tensor.h | 12 ++++--- .../common/multi_tensor/adam.cu | 35 ++++++++++--------- .../multi_tensor/multi_tensor_apply.cuh | 1 - transformer_engine/pytorch/csrc/extensions.h | 8 ++--- .../csrc/extensions/multi_tensor/adam.cpp | 16 ++++----- .../pytorch/optimizers/fused_adam.py | 4 ++- 8 files changed, 43 insertions(+), 37 deletions(-) diff --git a/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu index 470917580f..d0eed33781 100644 --- a/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu +++ b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu @@ -170,7 +170,7 @@ void run_mxfp8_adam_test(DType fp8_dtype) { nvte_multi_tensor_adam_mxfp8_cuda(65536, noop.data(), list_ptrs.data(), lists.size(), lists[0].size(), static_cast(fp8_dtype), lr, beta1, beta2, eps, step, mode, bias_correction, weight_decay, 0); - + std::vector> ref_lists(4); for (size_t i = 0; i < tensor_count; ++i) { ref_lists[0].push_back(g[i].data()); diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py index 34764d4e0a..44884934a4 100644 --- a/tests/pytorch/distributed/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -38,10 +38,12 @@ def get_recipe_from_string(recipe): NUM_STEPS = 3 LOCAL_RANK = None + def dist_print(msg): if LOCAL_RANK == 0: print(msg) + def save_custom_attrs(module): custom_attrs = {} for name, param in module.named_parameters(): diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 90c87b166e..a50be47b95 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -178,11 +178,13 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] weight_decay L2 penalty for weight decay. * \param[in] stream CUDA stream used for this operation. */ -void nvte_multi_tensor_adam_mxfp8_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, - const size_t num_tensor_lists, const size_t num_tensors_per_list, const NVTEDType fp8_dtype, - const float lr, const float beta1, const float beta2, const float epsilon, const int step, - const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream); +void nvte_multi_tensor_adam_mxfp8_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay, + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support and LR scheduling. diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index fa75c645f3..36ff7a84ab 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -4,16 +4,17 @@ * See LICENSE for license information. ************************************************************************/ -#include #include #include #include #include +#include + #include "../common.h" #include "../util/math.h" -#include "../utils.cuh" #include "../util/ptx.cuh" +#include "../utils.cuh" #include "multi_tensor_apply.cuh" namespace transformer_engine { @@ -599,10 +600,10 @@ __device__ __forceinline__ float fp8_max_norm_rcp(uint8_t fp8_dtype) { } template -__global__ void adam_mxfp8_fused_kernel( - int64_t chunk_size, volatile int *noop_gmem, MXFP8TensorListMetadata tl, float beta1, - float beta2, float beta1_correction, float beta2_correction, float epsilon, float lr, int mode, - float weight_decay) { +__global__ void adam_mxfp8_fused_kernel(int64_t chunk_size, volatile int *noop_gmem, + MXFP8TensorListMetadata tl, float beta1, float beta2, + float beta1_correction, float beta2_correction, + float epsilon, float lr, int mode, float weight_decay) { // Stage 0: optional early-exit if a noop flag is set. if (noop_gmem != nullptr && *noop_gmem == 1) { return; @@ -638,8 +639,7 @@ __global__ void adam_mxfp8_fused_kernel( const int64_t unpadded_scales_X_rowwise = (cols_val + MXFP8_TILE - 1) / MXFP8_TILE; constexpr int64_t kRowwiseScaleAlign = 4; - const int64_t row_stride = - DIVUP_TO_MULTIPLE(unpadded_scales_X_rowwise, kRowwiseScaleAlign); + const int64_t row_stride = DIVUP_TO_MULTIPLE(unpadded_scales_X_rowwise, kRowwiseScaleAlign); constexpr int64_t kColwiseScaleAlign = 128; const int64_t col_stride = DIVUP_TO_MULTIPLE(cols_val, kColwiseScaleAlign); const uint8_t dtype = tl.fp8_dtype[tensor_idx]; @@ -670,9 +670,9 @@ __global__ void adam_mxfp8_fused_kernel( float r_v = static_cast(v[idx]); // Stage 4: apply Adam update in FP32 and write back updated p/m/v. - transformer_engine::multi_tensor_adam::adam_update( - r_g, r_p, r_m, r_v, beta1, beta2, beta1_correction, beta2_correction, epsilon, lr, - adam_mode, weight_decay); + transformer_engine::multi_tensor_adam::adam_update(r_g, r_p, r_m, r_v, beta1, beta2, + beta1_correction, beta2_correction, epsilon, + lr, adam_mode, weight_decay); p[idx] = static_cast(r_p); m[idx] = static_cast(r_m); @@ -973,7 +973,6 @@ inline void check_tensor_list_sizes(const std::vector> &te } } - void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, @@ -1273,11 +1272,13 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), stream); } -void nvte_multi_tensor_adam_mxfp8_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, - const size_t num_tensor_lists, const size_t num_tensors_per_list, const NVTEDType fp8_dtype, - const float lr, const float beta1, const float beta2, const float epsilon, const int step, - const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream) { +void nvte_multi_tensor_adam_mxfp8_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay, + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_mxfp8_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_mxfp8_cuda( diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index c334f3908e..fdef340a0e 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -154,7 +154,6 @@ void multi_tensor_apply_mxfp8(int64_t chunk_size, const transformer_engine::Tens int loc_tensor_info = 0; for (size_t t = 0; t < num_tensors_per_list; ++t) { - const auto &g = tensor_lists[0][t]; const auto &rowwise_data = tensor_lists[4][t]; const auto &colwise_data = tensor_lists[5][t]; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 65e2c54d67..5096cfb252 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -518,10 +518,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, const float weight_decay, DType fp8_dtype); void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, DType fp8_dtype); + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 01a21d44bb..677165b063 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -53,10 +53,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, } void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, DType fp8_dtype) { + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); @@ -65,10 +65,10 @@ void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, "Expected 8 tensor lists (g, p_master, m, v, rowwise_data, colwise_data, " "rowwise_scale_inv, colwise_scale_inv), but found ", num_lists); - nvte_multi_tensor_adam_mxfp8_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, - static_cast(fp8_dtype), lr, beta1, beta2, epsilon, step, mode, bias_correction, - weight_decay, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_adam_mxfp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), + num_lists, num_tensors, static_cast(fp8_dtype), lr, + beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index f4ab2e7c37..967d3846dd 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -639,7 +639,9 @@ def step(self, closure=None, grad_scaler=None): ): p = p._local_tensor if isinstance(p, DTensor) else p if p._rowwise_data is None or p._columnwise_data is None: - raise RuntimeError("MXFP8Tensor does not have one of rowwise/columnwise data.") + raise RuntimeError( + "MXFP8Tensor does not have one of rowwise/columnwise data." + ) if self.capturable: raise RuntimeError( "FusedAdam does not support MXFP8 model weights with capturable=True." From 175e43edaa3f80cda4a7e83890a2178ce804c0b1 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 18 Mar 2026 16:40:42 +0000 Subject: [PATCH 3/6] xfail isnt fixe yet Signed-off-by: Varun Thumbe --- tests/pytorch/distributed/test_torch_fsdp2.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 6d7ae4d7bb..02e45d99cb 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -224,6 +224,11 @@ def test_fsdp2_dcp_output_parity_async(fp_recipe): @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") def test_fsdp2_safetensors_fp32_export(fp_recipe): """Export FP32 model from optimizer master weights to safetensors.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) _run_fused_adam_test("safetensors_fp32_export", fp_recipe) From 2ac16293beb9039d4af97365ab1f4959de4877fd Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Wed, 18 Mar 2026 09:43:33 -0700 Subject: [PATCH 4/6] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 6def07b08e..bb2a8b6227 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -526,7 +526,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); m.def("multi_tensor_adam_mxfp8", &transformer_engine::pytorch::multi_tensor_adam_mxfp8_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); m.def("multi_tensor_adam_capturable", &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda, "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " From 6b916f56efdc6972d3c27ba1ba50f2af4159e406 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 18 Mar 2026 16:46:41 +0000 Subject: [PATCH 5/6] address review comment Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/optimizers/fused_adam.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 967d3846dd..3aed89309b 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -646,8 +646,12 @@ def step(self, closure=None, grad_scaler=None): raise RuntimeError( "FusedAdam does not support MXFP8 model weights with capturable=True." ) - if self.master_weights: - p_main_of_mxfp8_model.append(unscaled_state["master_param"].data) + if not self.master_weights: + raise RuntimeError( + "FusedAdam without master_weights does not support " + "MXFP8 model weights. Use master_weights=True." + ) + p_main_of_mxfp8_model.append(unscaled_state["master_param"].data) g_of_mxfp8_model.append(p_grad.data) m_of_mxfp8_model.append(unscaled_state["exp_avg"]) v_of_mxfp8_model.append(unscaled_state["exp_avg_sq"]) From 31d0aa5933da9d6d59ea35cf32837273c1a41b01 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 18 Mar 2026 16:52:44 +0000 Subject: [PATCH 6/6] address review comments Signed-off-by: Varun Thumbe --- 3rdparty/cudnn-frontend | 2 +- transformer_engine/common/multi_tensor/multi_tensor_apply.cuh | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 8d19d3182b..d33027a41a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc +Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index fdef340a0e..7a836c270b 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -43,7 +43,6 @@ constexpr int MXFP8_MAX_BLOCKS = 320; struct MXFP8TensorListMetadata { void *addresses[8][MXFP8_MAX_TENSORS]; - int sizes[MXFP8_MAX_TENSORS]; int rows[MXFP8_MAX_TENSORS]; int cols[MXFP8_MAX_TENSORS]; uint8_t fp8_dtype[MXFP8_MAX_TENSORS]; @@ -154,14 +153,11 @@ void multi_tensor_apply_mxfp8(int64_t chunk_size, const transformer_engine::Tens int loc_tensor_info = 0; for (size_t t = 0; t < num_tensors_per_list; ++t) { - const auto &g = tensor_lists[0][t]; const auto &rowwise_data = tensor_lists[4][t]; - const auto &colwise_data = tensor_lists[5][t]; const int rows_val = static_cast(rowwise_data->data.shape[0]); const int cols_val = static_cast(rowwise_data->data.shape[1]); - tl.sizes[loc_tensor_info] = g->numel(); tl.rows[loc_tensor_info] = rows_val; tl.cols[loc_tensor_info] = cols_val; tl.fp8_dtype[loc_tensor_info] = fp8_dtype;