From 210887f4f75adff909b21ed876c5769e43aa0b23 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 5 Jan 2026 03:04:03 -0600 Subject: [PATCH 1/8] Allow passing descales to fmha v3 kernel --- example/ck_tile/01_fmha/fmha_fwd.hpp | 6 + .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 111 +++++++++++++++--- 2 files changed, 103 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3ff4acfc156..cfa760f6c1e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -731,6 +731,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqstart_q_ptr, @@ -764,6 +767,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqlen_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 6fe1de634d9..f82c5bc6526 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel float logits_soft_cap_rcp; }; + struct FmhaFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, @@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, @@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasLogitsSoftCap) { kargs.init_logits_soft_cap(logits_soft_cap); @@ -640,32 +675,80 @@ struct FmhaFwdV3Kernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + AttentionVariant variant; const auto variant_params = [&] { if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lse_dram_window, - mask, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::scales{scale_o}; + }(); + + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); + } }(); // O DRAM and O DRAM window From 8f973efed9b24c1e9377212f5babf5a3c6c4095e Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:48:51 -0600 Subject: [PATCH 2/8] Allow enabling quantization scale feature for FMHA v3 --- include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..14dd9c8db27 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -300,7 +300,6 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout && - (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), "enable unsupported features"); From 5876cd86bbeaa7567f6d93f72434d12400807a42 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:50:35 -0600 Subject: [PATCH 3/8] Add fp8 32x32x32 warp gemm (C-transposed) --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 7 +++++++ include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 ++ 2 files changed, 9 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 7bcc9107da9..e86e5d9f86a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -276,6 +276,13 @@ using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; +template +using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed = + WarpGemmImpl, + 2, + AttrNumAccess>>; + using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index d6c21e88b56..940447cc22f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -98,6 +98,8 @@ template<> struct Dispatcher { u // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; }; From 941c7e67bbbf5f6bec7b681b5179a194fa7d396f Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:52:30 -0600 Subject: [PATCH 4/8] Add fp8 QK block gemm config --- .../block_fmha_fwd_v3_pipeline_default_policy.hpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index ce097b6741b..957e404b35b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && + constexpr auto warp_gemm = [] { + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here From fa404ca9703c0930242f8e6f8fff72a4036bc072 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Jan 2026 00:54:57 -0600 Subject: [PATCH 5/8] Add fp8 FMHA v3 instances --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 13 +++++++++++++ .../fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b3..a1919e954fb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1048,6 +1048,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: if (128, 128) in result.keys(): result[(128, 128)].append( FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + elif dtype in cls._DT_FP8BF16: + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip return result @classmethod @@ -1085,6 +1089,15 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + elif dtype in cls._DT_FP8BF16: + # no need lse/dropout kernels + # qr_async_trload_v3 only supports (generic) causal mask + for logits, qscale, mask in itertools.product( + ["t", "f"], + ["no", "pertensor"], + ["no", "causal"], + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip return pipelines diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 14dd9c8db27..29c52e2fd6b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -436,7 +436,7 @@ struct BlockFmhaFwdV3Pipeline kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + // static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); From c5e0c500401f9ab4601a038148e5f55c0558593a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 14 Jan 2026 11:34:39 -0600 Subject: [PATCH 6/8] Fix fmha_fwd_v3() dispatch logic --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index a1919e954fb..4257b84fbfc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -210,10 +210,12 @@ ((0 < args.window_size_left) or (0 < args.window_size_right)); const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and - (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + (((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + (traits.qscale_type == quant_scale_enum::no_scale)) or + ((traits.data_type.compare("fp8bf16") == 0) and + (traits.qscale_type == quant_scale_enum::pertensor))) and traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and - (not traits.has_lse) and (not traits.has_dropout) and - (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and + (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); From 1be789fc69c6c991030b205b32cf15c9e13c671a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 20 Jan 2026 23:58:37 -0600 Subject: [PATCH 7/8] Add missing P tile fp32 -> fp8 converison logic --- .../ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 29c52e2fd6b..09ca59be420 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -853,12 +853,21 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } - else + else if constexpr(std::is_same_v) { auto casted = detail::cvt_pk_bf16_f32(x, y); sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } + else if constexpr(std::is_same_v) + { + sp(sp_reg_idx).p.thread_buf_[idx] = type_convert(x); + sp(sp_reg_idx).p.thread_buf_[idx + 1] = type_convert(y); + } + else + { + static_assert(false, "unsupported data type for P"); + } }); /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly From 35af63aeed70d7347cb8ec8970cc9bcb3c9f2d4f Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 21 Jan 2026 00:21:20 -0600 Subject: [PATCH 8/8] Update functor creation logics --- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index f82c5bc6526..c2e0fe0d4cc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -712,29 +712,31 @@ struct FmhaFwdV3Kernel auto o_acc_element_func = [&]() { if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); + return make_composes( + ck_tile::saturates{}, + ck_tile::scales>{scale_o}); else - return ck_tile::scales{scale_o}; + return ck_tile::scales>{scale_o}; }(); - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{scale_p}, // p_compute_element_func - o_acc_element_func, - mask, - scale_s, - variant, - variant_params, - block_indices, - smem_ptr); + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales>{scale_p}, // p_compute_element_func + o_acc_element_func, + mask, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr); } else {