diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 00000000000..697b12bbde8
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,112 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Build Commands
+
+```bash
+# Development build (from build directory)
+mkdir build && cd build
+../script/cmake-ck-dev.sh ..          # Uses gfx908;gfx90a;gfx942 by default
+../script/cmake-ck-dev.sh .. gfx90a   # Specific GPU target
+make -j32                             # Limit threads; ~2GB RAM per thread
+
+# Manual cmake configuration
+cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
+      -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+      -D CMAKE_BUILD_TYPE=Release \
+      -D GPU_TARGETS="gfx90a" \
+      ..
+
+# Common build flags
+-D DTYPES="fp16;fp32"                  # Build only specific data types (speeds up build)
+-D CK_USE_FP8_ON_UNSUPPORTED_ARCH=ON   # Enable FP8 on MI100/MI200
+-D BUILD_DEV=ON                        # Development mode with verbose errors
+```
+
+## Testing
+
+```bash
+make -j check        # Build and run all tests
+make -j smoke        # Quick tests only (<30s each)
+make -j regression   # Long-running tests (>=30s each)
+
+# CTest direct commands
+ctest --output-on-failure -L "SMOKE_TEST"
+ctest --output-on-failure -L "REGRESSION_TEST"
+
+# Run single test executable
+./test/gemm/test_gemm_fp16
+```
+
+## Architecture Overview
+
+### Two Programming Models
+
+1. **CK (Legacy)** - `/include/ck/`: Traditional template-based approach
+2. **CK Tile (Modern)** - `/include/ck_tile/`: Newer unified model with simpler component structure
+
+CK Tile is independently maintained and preferred for new development. Include single headers like `#include "ck_tile/core.hpp"` or `#include "ck_tile/ops/fmha.hpp"`.
+
+### Four-Layer Architecture
+
+1. **Templated Tile Operators** - High-level tile abstractions
+2. **Templated Kernel and Invoker** - Generic kernel templates
+3. **Instantiated Kernel and Invoker** - Concrete implementations per data type/hardware
+4. **Client API** - User-facing interfaces
+
+### Core Concepts
+
+- **Tile-Based Programming**: Operations work on tiles (sub-regions) of tensors
+- **Tensor Coordinate Transformation**: Maps ND tensor indices to 1D memory offsets through transform primitives (merge/unmerge/embed); see the sketch below
+- **Distributed Tensor**: Describes how threads collaboratively process a tensor tile
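+
+The coordinate-transform idea can be sketched in plain C++ (illustrative only; this is not the CK or CK Tile API, and the names below are invented for the example):
+
+```cpp
+// A "merge"-style transform: fold an N-D tile coordinate into a single
+// linear memory offset using the tensor's strides (row-major 4x8x16 tile).
+#include <array>
+#include <cstdio>
+
+int main()
+{
+    std::array<int, 3> strides{8 * 16, 16, 1}; // packed row-major strides
+    std::array<int, 3> idx{1, 2, 3};           // N-D coordinate inside the tile
+
+    int offset = 0;
+    for(int d = 0; d < 3; ++d)
+        offset += idx[d] * strides[d];         // merge: N-D -> 1-D
+
+    std::printf("linear offset = %d\n", offset); // 1*128 + 2*16 + 3*1 = 163
+}
+```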
+
+### CK Tile Components
+
+- `core/` - Basic structures: array, tuple, sequence, numeric types, coordinate transforms
+- `host/` - Kernel launch utilities, device buffers, reference implementations
+- `ops/` - Operation implementations (gemm, fmha, reduce)
+- `ref/` - CPU/GPU reference implementations for validation
+
+### Instance Organization
+
+`/library/src/tensor_operation_instance/gpu/` contains 100+ operation variants organized by:
+- Operation type (gemm, batched_gemm, conv, etc.)
+- Data type (fp16, fp32, fp8, bf16, int8)
+- Instruction set (xdl, wmma, dl, dpp)
+
+## Key Directories
+
+- `/example/ck_tile/01_fmha/` - Flash Multi-Head Attention (main FMHA implementation)
+- `/example/01_gemm/` - Foundational GEMM example
+- `/profiler/` - Performance benchmarking tool (`make -j ckProfiler`)
+- `/test/` - 68+ test directories with smoke/regression classification
+
+## Supported Hardware
+
+- **MI Series**: gfx908 (MI100), gfx90a (MI200), gfx942/gfx950 (MI300)
+- **NAVI Series**: gfx1030-1032 (NAVI2x), gfx1100-1102 (NAVI3x), gfx1200-1201 (RDNA4)
+
+## Code Style
+
+Pre-commit hooks enforce formatting. Install with:
+
+```bash
+sudo script/install_precommit.sh
+```
+
+Bypass temporarily with `git commit --no-verify`.
+
+## Profiling
+
+```bash
+make -j ckProfiler
+./profiler/ckProfiler gemm_xdl -M 4096 -N 4096 -K 4096 -A fp16 -B fp16 -C fp16
+```
+
+## sccache (Faster Rebuilds)
+
+```bash
+sccache --start-server
+cmake ... -DCMAKE_HIP_COMPILER_LAUNCHER=sccache \
+          -DCMAKE_CXX_COMPILER_LAUNCHER=sccache
+```
diff --git a/GEMINI.md b/GEMINI.md
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index dd65c0298b3..4257b84fbfc 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -210,10 +210,12 @@
             ((0 < args.window_size_left) or (0 < args.window_size_right));
         const bool can_dispatch_v3 =
             (device_name.compare(0, 6, "gfx950") == 0) and
-            (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
+            (((traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
+              (traits.qscale_type == quant_scale_enum::no_scale)) or
+             ((traits.data_type.compare("fp8bf16") == 0) and
+              (traits.qscale_type == quant_scale_enum::pertensor))) and
             traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
-            (not traits.has_lse) and (not traits.has_dropout) and
-            (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
+            (not traits.has_lse) and (not traits.has_dropout) and (not is_swa) and
             (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
         if ({F_is_v3_enabled} and can_dispatch_v3) {{
             return fmha_fwd_v3(traits, args, config);
@@ -1048,6 +1050,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
             if (128, 128) in result.keys():
                 result[(128, 128)].append(
                     FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1))  # fmt: skip
+        elif dtype in cls._DT_FP8BF16:
+            if (128, 128) in result.keys():
+                result[(128, 128)].append(
+                    FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1))  # fmt: skip
         return result
 
     @classmethod
@@ -1085,6 +1091,15 @@ def get_pipelines(
                 pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f"))  # fmt: skip
+        elif dtype in cls._DT_FP8BF16:
+            # no need lse/dropout kernels
+            # qr_async_trload_v3 only supports (generic) causal mask
+            for logits, qscale, mask in itertools.product(
+                ["t", "f"],
+                ["no", "pertensor"],
+                ["no", "causal"],
+            ):
+                pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f"))  # fmt: skip
 
         return pipelines
diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index fdd720fd75b..d5b06d6e5a3 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -738,6 +738,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
         return FmhaKernel::MakeKargs(args.q_ptr,
                                      args.k_ptr,
                                      args.v_ptr,
+                                     args.q_descale_ptr,
+                                     args.k_descale_ptr,
+                                     args.v_descale_ptr,
                                      nullptr, // lse_ptr
                                      args.o_ptr,
                                      args.seqstart_q_ptr,
@@ -771,6 +774,9 @@
         return FmhaKernel::MakeKargs(args.q_ptr,
                                      args.k_ptr,
                                      args.v_ptr,
+                                     args.q_descale_ptr,
+                                     args.k_descale_ptr,
+                                     args.v_descale_ptr,
                                      nullptr, // lse_ptr
                                      args.o_ptr,
                                      args.seqlen_q,
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
index 6fe1de634d9..c2e0fe0d4cc 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
     using QDataType    = ck_tile::remove_cvref_t;
     using KDataType    = ck_tile::remove_cvref_t;
     using VDataType    = ck_tile::remove_cvref_t;
+    using PDataType    = ck_tile::remove_cvref_t;
     using LSEDataType  = ck_tile::remove_cvref_t;
     using ODataType    = ck_tile::remove_cvref_t;
     using SaccDataType = ck_tile::remove_cvref_t;
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
     static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
+    static constexpr auto QScaleEnum        = FmhaPipeline::Problem::QScaleEnum;
 
     using AttentionVariant = ck_tile::remove_cvref_t;
     using FmhaMask         = ck_tile::remove_cvref_t;
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
         float logits_soft_cap_rcp;
     };
 
+    struct FmhaFwdCommonQScaleKargs
+    {
+        const void* q_descale_ptr = nullptr;
+        const void* k_descale_ptr = nullptr;
+        const void* v_descale_ptr = nullptr;
+    };
+
     struct FmhaFwdBatchModeKargs
         : FmhaFwdCommonKargs,
           std::conditional_t>,
          std::conditional_t>,
-         std::conditional_t>
+         std::conditional_t>,
+         std::conditional_t>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
         : FmhaFwdCommonKargs,
          std::conditional_t>,
          std::conditional_t>,
-         std::conditional_t>
+         std::conditional_t>,
+         std::conditional_t>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
     MakeKargs(const void* q_ptr,
               const void* k_ptr,
               const void* v_ptr,
+              const void* q_descale_ptr,
+              const void* k_descale_ptr,
+              const void* v_descale_ptr,
               void* lse_ptr,
               void* o_ptr,
               ck_tile::index_t seqlen_q,
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
                       nhead_stride_o}, // args for common karg
                      {}, // placeholder for mask
                      {}, // placeholder for lse
+                     {}, // placeholder for qscale
                      {}, // placeholder for logits_soft_cap
                      batch_stride_q,
                      batch_stride_k,
@@ -237,6 +256,12 @@
        	kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
+       if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+       {
+           kargs.q_descale_ptr = q_descale_ptr;
+           kargs.k_descale_ptr = k_descale_ptr;
+           kargs.v_descale_ptr = v_descale_ptr;
+       }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
     MakeKargs(const void* q_ptr,
               const void* k_ptr,
               const void* v_ptr,
+              const void* q_descale_ptr,
+              const void* k_descale_ptr,
+              const void* v_descale_ptr,
               void* lse_ptr,
               void* o_ptr,
               const void* seqstart_q_ptr,
@@ -301,6 +329,7 @@
                       nhead_stride_o}, // args for common karg
                      {}, // placeholder for mask
                      {}, // placeholder for lse
+                     {}, // placeholder for qscale
                      {}, // placeholder for logits_soft_cap
                      reinterpret_cast(seqstart_q_ptr),
                      reinterpret_cast(seqstart_k_ptr),
@@ -319,6 +348,12 @@
            kargs.lse_ptr = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
+       if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+       {
+           kargs.q_descale_ptr = q_descale_ptr;
+           kargs.k_descale_ptr = k_descale_ptr;
+           kargs.v_descale_ptr = v_descale_ptr;
+       }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
@@ -640,32 +675,82 @@ struct FmhaFwdV3Kernel
            return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();
 
+       const float scale_s = [&] {
+           if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+           {
+               float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
+               float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
+               return kargs.scale_s * q_descale * k_descale;
+           }
+           else
+           {
+               return kargs.scale_s;
+           }
+       }();
+
        AttentionVariant variant;
        const auto variant_params = [&] {
            if constexpr(kHasLogitsSoftCap)
            {
                return ck_tile::LogitsSoftCapParams{
-                   mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+                   mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
            }
            else
            {
-               return ck_tile::StandardAttentionParams{mask, kargs.scale_s};
+               return ck_tile::StandardAttentionParams{mask, scale_s};
            }
        }();
 
        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
 
        auto o_acc_tile = [&]() {
-           return FmhaPipeline{}(q_dram_window,
-                                 k_dram_window,
-                                 v_dram_window,
-                                 lse_dram_window,
-                                 mask,
-                                 kargs.scale_s,
-                                 variant,
-                                 variant_params,
-                                 block_indices,
-                                 smem_ptr);
+           if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+           {
+               float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
+               float scale_p   = ck_tile::type_convert(ck_tile::numeric::max());
+               float scale_o   = v_descale / scale_p;
+
+               auto o_acc_element_func = [&]() {
+                   if constexpr(std::is_same_v)
+                       return make_composes(
+                           ck_tile::saturates{},
+                           ck_tile::scales>{scale_o});
+                   else
+                       return ck_tile::scales>{scale_o};
+               }();
+
+               return FmhaPipeline{}(
+                   q_dram_window,
+                   identity{}, // q_element_func
+                   k_dram_window,
+                   identity{}, // k_element_func
+                   v_dram_window,
+                   identity{}, // v_element_func
+                   lse_dram_window,
+                   identity{}, // lse_element_func
+                   identity{}, // s_acc_element_func
+                   scales>{scale_p}, // p_compute_element_func
+                   o_acc_element_func,
+                   mask,
+                   scale_s,
+                   variant,
+                   variant_params,
+                   block_indices,
+                   smem_ptr);
+           }
+           else
+           {
+               return FmhaPipeline{}(q_dram_window,
+                                     k_dram_window,
+                                     v_dram_window,
+                                     lse_dram_window,
+                                     mask,
+                                     scale_s,
+                                     variant,
+                                     variant_params,
+                                     block_indices,
+                                     smem_ptr);
+           }
        }();
 
        // O DRAM and O DRAM window
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
index c25f57632fa..09ca59be420 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
@@ -300,7 +300,6 @@ struct BlockFmhaFwdV3Pipeline
     static constexpr auto QScaleEnum      = Problem::QScaleEnum;
     static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
     static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
-                   (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
                    !kSkipMinSeqlenQ),
                   "enable unsupported features");
@@ -437,7 +436,7 @@
                       kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");
 
-        static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
+        // static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
 
        auto s_lds = make_tensor_view(
            reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc());
@@ -854,12 +853,21 @@
                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
-               else
+               else if constexpr(std::is_same_v)
                {
                    auto casted = detail::cvt_pk_bf16_f32(x, y);
                    sp(sp_reg_idx).p.thread_buf_[idx]     = casted.x;
                    sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
                }
+               else if constexpr(std::is_same_v)
+               {
+                   sp(sp_reg_idx).p.thread_buf_[idx]     = type_convert(x);
+                   sp(sp_reg_idx).p.thread_buf_[idx + 1] = type_convert(y);
+               }
+               else
+               {
+                   static_assert(false, "unsupported data type for P");
+               }
            });
 
            /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp
index ce097b6741b..957e404b35b 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp
@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
                                            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                                            typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
-        constexpr auto warp_gemm = []() {
-            if constexpr(std::is_same_v &&
-                         std::is_same_v &&
+        constexpr auto warp_gemm = [] {
+            if constexpr(std::is_same_v &&
+                         std::is_same_v &&
                          std::is_same_v)
+            {
+                constexpr index_t swizzle_factor = 4;
+                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
+                    swizzle_factor>{};
+            }
+            else if constexpr(std::is_same_v &&
+                              std::is_same_v &&
+                              std::is_same_v)
             {
                 /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
                 /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index 00512424752..ff724f41d68 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -288,6 +288,13 @@
 using WarpGemmMfma_f32_32x32x32_fp8_fp8 =
     WarpGemmImpl, 2>>;
 
+template
+using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
+    WarpGemmImpl,
+                 2,
+                 AttrNumAccess>>;
+
 using WarpGemmMfma_f32_32x32x32_bf8_bf8 =
     WarpGemmImpl, 2>>;
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index a7d71d4fa3c..67b6586ea81 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -100,6 +100,8 @@ template<> struct Dispatcher { u
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
 template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
+template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; };
 template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
 template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
 template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
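For reference, the per-tensor descale handling in `fmha_fwd_v3_kernel.hpp` above reduces to the scalar folding sketched below (standalone illustrative C++, not part of the patch; the concrete values and the assumption of e4m3 FP8 are hypothetical):

```cpp
#include <cmath>
#include <cstdio>

// Sketch of the per-tensor FP8 descale folding used by the new v3 dispatch path:
//   - the q/k descale factors fold into the softmax scale applied to S = Q*K^T
//   - P (softmax output) is stretched to the FP8 representable range before P*V
//   - the O accumulator is rescaled by v_descale / fp8_max to undo both effects
int main()
{
    const float scale_s   = 1.0f / std::sqrt(128.0f); // base softmax scale for hdim_q == 128
    const float q_descale = 0.02f;                     // hypothetical *q_descale_ptr
    const float k_descale = 0.03f;                     // hypothetical *k_descale_ptr
    const float v_descale = 0.04f;                     // hypothetical *v_descale_ptr

    const float folded_scale_s = scale_s * q_descale * k_descale; // passed to the pipeline as scale_s
    const float scale_p        = 448.0f;                          // assuming e4m3 FP8 max
    const float scale_o        = v_descale / scale_p;             // applied by o_acc_element_func

    std::printf("scale_s=%g scale_p=%g scale_o=%g\n", folded_scale_s, scale_p, scale_o);
    return 0;
}
```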