diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768c..54c8b776ddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,6 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. -* Added FP8 block scale quantization for FMHA forward kernel. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index cac6671ca5f..a3cfe2622a6 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -77,13 +77,11 @@ def get_mask_cpp_check_expr(mask: str) -> str: QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", - "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", - "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b59f442663f..81c7b067d33 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1024,7 +1024,7 @@ def get_pipelines( # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( ["t", "f"], - ["no", "pertensor", "blockscale"], + ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"], ["f", "t"], @@ -1152,10 +1152,7 @@ def get_pipelines( elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], - ["no", "pertensor", "blockscale"], - get_mask_map(mask_impl).keys(), - ["no"], + ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index aedbb0e17c2..fdd720fd75b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,8 +230,6 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) - const void* block_scale_seqstart_q_ptr; - const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -259,9 +257,6 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_q_descale; - ck_tile::index_t nhead_stride_k_descale; - ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -269,9 +264,6 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; - ck_tile::index_t batch_stride_q_descale; - ck_tile::index_t batch_stride_k_descale; - ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -284,9 +276,6 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; - - ck_tile::index_t block_scale_size_q; - ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -626,8 +615,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, - args.block_scale_seqstart_q_ptr, - args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -647,9 +634,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, - args.nhead_stride_q_descale, - args.nhead_stride_k_descale, - args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -658,8 +642,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_size_q, - args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); @@ -697,9 +679,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, - args.nhead_stride_q_descale, - args.nhead_stride_k_descale, - args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -707,9 +686,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, - args.batch_stride_q_descale, - args.batch_stride_k_descale, - args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -717,8 +693,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_size_q, - args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index b6287245a0d..0c988b2acce 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,11 +210,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { - // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the - // compute block size - constexpr ck_tile::index_t block_scale_size_q_ = 128; - constexpr ck_tile::index_t block_scale_size_kv_ = 128; - const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -476,11 +471,7 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t i_block_scale_q = 0; - size_t i_block_scale_k = 0; - std::vector block_scale_seqstart_q_host = {0}; - std::vector block_scale_seqstart_k_host = {0}; - auto max_seqlen_k = std::numeric_limits::min(); + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -496,10 +487,6 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } - i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); - i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); - block_scale_seqstart_q_host.push_back(i_block_scale_q); - block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -561,15 +548,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); - const ck_tile::index_t num_block_scale_q = - (mode == mode_enum::batch) - ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) - : i_block_scale_q; - const ck_tile::index_t num_block_scale_kv = - (mode == mode_enum::batch) - ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) - : i_block_scale_k; - ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor sink_host({nhead}); @@ -621,18 +599,9 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead, num_block_scale_q} - : std::array{1, 1, 1}); - ck_tile::HostTensor k_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); - ck_tile::HostTensor v_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); + ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -748,12 +717,6 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } - else if(qscale.type == quant_scale_enum::blockscale) - { - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); - } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -774,10 +737,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * - sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -823,8 +782,6 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); - block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); - block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -1018,14 +975,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; - const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; - const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -1043,9 +997,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; - const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; - const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1133,39 +1084,9 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - if(qscale.type == quant_scale_enum::blockscale) - { - args.q_descale_ptr = - reinterpret_cast(q_descale_buf.GetDeviceBuffer()); - args.k_descale_ptr = - reinterpret_cast(k_descale_buf.GetDeviceBuffer()); - args.v_descale_ptr = - reinterpret_cast(v_descale_buf.GetDeviceBuffer()); - - args.block_scale_seqstart_q_ptr = - (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() - : nullptr); - args.block_scale_seqstart_k_ptr = - (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() - : nullptr); - - args.nhead_stride_q_descale = nhead_stride_q_descale; - args.nhead_stride_k_descale = nhead_stride_k_descale; - args.nhead_stride_v_descale = nhead_stride_v_descale; - - args.batch_stride_q_descale = batch_stride_q_descale; - args.batch_stride_k_descale = batch_stride_k_descale; - args.batch_stride_v_descale = batch_stride_v_descale; - - args.block_scale_size_q = block_scale_size_q_; - args.block_scale_size_kv = block_scale_size_kv_; - } - else - { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); - } + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1668,42 +1589,14 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - if(qscale.type == quant_scale_enum::blockscale) - { - const ck_tile::index_t q_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; - const ck_tile::index_t k_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; - ck_tile::reference_batched_quant_gemm( + ck_tile:: + reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::idx_identity{}, - ck_tile::idx_identity{}, - [&](auto idx, auto value) { - return value * scale_s * - q_descale_host(b_idx, - std::get<0>(idx), - q_offset + std::get<1>(idx) / block_scale_size_q_) * - k_descale_host(b_idx, - std::get<0>(idx) / nr, - k_offset + std::get<2>(idx) / block_scale_size_kv_); - }); - } - else - { - ck_tile:: - reference_batched_gemm( - q_host_ref, - k_host_ref, - s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); - } + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); if(0.f < logits_soft_cap) { @@ -1901,35 +1794,13 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - if(qscale.type == quant_scale_enum::blockscale) - { - const ck_tile::index_t v_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; - ck_tile:: - reference_batched_quant_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::idx_identity{}, - [&](auto idx, auto value) { - return ck_tile::type_convert(value) * - v_descale_host(b_idx, - std::get<0>(idx) / nr, - v_offset + - std::get<2>(idx) / block_scale_size_kv_); - }, - ck_tile::idx_identity{}); - } - else - { - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); - } + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1937,6 +1808,7 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on + auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1994,33 +1866,31 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results( - *json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale - ? "no_scale" - : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results(*json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale ? "no_scale" + : "pertensor", + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24e..59d4ac17073 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,7 +13,6 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, - blockscale, }; struct quant_scale_info @@ -26,8 +25,6 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; - else if(type == quant_scale_enum::blockscale) - os << "bs"; } static quant_scale_info decode(std::string str) @@ -41,10 +38,6 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } - else if(str == "bs" || str == "2") - { - info.type = quant_scale_enum::blockscale; - } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 227f26c8f36..596542eb9dd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -95,11 +95,10 @@ run_fp8bf16_tests() { for perm in 0 1 ; do for b in 1 2 ; do for hdim in 64 128 256 ; do - for scale in 1 2; do - $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS - done ; done ; done ; done + done ; done ; done } run_fp8fp32_tests() { diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index a46ae509dd0..96e76f669dd 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -37,13 +37,6 @@ struct scales return lhs_ * rhs; } - template - CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const - { - auto new_scale = lhs_ * other; - return scales>(new_scale); - } - private: Scale lhs_; }; diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index aa4bfa3f150..898d21574e5 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -119,18 +119,6 @@ struct identity } }; -// Similar to identity, but takes an additional index parameter as the first argument. -// The index is ignored and only the second argument (value) is forwarded. -// Useful for indexed element-wise operations where the functor signature requires an index. -struct idx_identity -{ - template - CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept - { - return std::forward(arg); - } -}; - namespace detail { // RemainLengths: sequence<...> diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index d7424267407..63f13b1b161 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -47,44 +47,4 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } -template -CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, - const HostTensor& b_b_n_k, - HostTensor& c_b_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) -{ - const int N = b_b_n_k.mDesc.get_lengths()[1]; - const int K = b_b_n_k.mDesc.get_lengths()[2]; - - auto f = [&](auto batch, auto m) { - for(int n = 0; n < N; ++n) - { - AccDataType v_acc = 0; - - for(int k = 0; k < K; ++k) - { - AccDataType v_a = ck_tile::type_convert( - a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k))); - AccDataType v_b = ck_tile::type_convert( - b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k))); - - v_acc += v_a * v_b; - } - - c_b_m_n(batch, m, n) = ck_tile::type_convert( - acc_element_op(std::make_tuple(batch, m, n), v_acc)); - } - }; - - make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( - std::thread::hardware_concurrency()); -} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 7e0f704bef8..3755a2bc719 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,7 +12,6 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, - BLOCKSCALE, }; template @@ -28,10 +27,5 @@ struct BlockAttentionQuantScaleEnumToStr -struct BlockAttentionQuantScaleEnumToStr -{ - static constexpr const char* name = "blockscale"; -}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0039c57cfce..adbedc52599 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -168,29 +168,6 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; - struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs - { - ck_tile::index_t nhead_stride_q_descale; - ck_tile::index_t nhead_stride_k_descale; - ck_tile::index_t nhead_stride_v_descale; - - ck_tile::index_t block_scale_size_q; - ck_tile::index_t block_scale_size_kv; - }; - - struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs - { - ck_tile::index_t batch_stride_q_descale; - ck_tile::index_t batch_stride_k_descale; - ck_tile::index_t batch_stride_v_descale; - }; - - struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs - { - const int32_t* block_scale_seqstart_q_ptr; - const int32_t* block_scale_seqstart_k_ptr; - }; - struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -266,12 +243,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t< - QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, - FmhaFwdCommonQScaleKargs, - std::conditional_t>>, + std::conditional_t>, std::conditional_t>, std::conditional_t> { @@ -295,12 +269,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t< - QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, - FmhaFwdCommonQScaleKargs, - std::conditional_t>>, + std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -357,9 +328,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -367,9 +335,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -378,8 +343,6 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -450,23 +413,6 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - kargs.q_descale_ptr = q_descale_ptr; - kargs.k_descale_ptr = k_descale_ptr; - kargs.v_descale_ptr = v_descale_ptr; - - kargs.nhead_stride_q_descale = nhead_stride_q_descale; - kargs.nhead_stride_k_descale = nhead_stride_k_descale; - kargs.nhead_stride_v_descale = nhead_stride_v_descale; - - kargs.batch_stride_q_descale = batch_stride_q_descale; - kargs.batch_stride_k_descale = batch_stride_k_descale; - kargs.batch_stride_v_descale = batch_stride_v_descale; - - kargs.block_scale_size_q = block_scale_size_q; - kargs.block_scale_size_kv = block_scale_size_kv; - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -532,9 +478,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -542,9 +485,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -552,8 +492,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -590,9 +528,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -600,9 +535,6 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, - batch_stride_q_descale, - batch_stride_k_descale, - batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -610,8 +542,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -651,9 +581,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -661,9 +588,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -671,8 +595,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -709,9 +631,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -719,9 +638,6 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, - batch_stride_q_descale, - batch_stride_k_descale, - batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -729,8 +645,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -752,8 +666,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -773,9 +685,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -785,8 +694,6 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -856,24 +763,6 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - kargs.q_descale_ptr = q_descale_ptr; - kargs.k_descale_ptr = k_descale_ptr; - kargs.v_descale_ptr = v_descale_ptr; - - kargs.nhead_stride_q_descale = nhead_stride_q_descale; - kargs.nhead_stride_k_descale = nhead_stride_k_descale; - kargs.nhead_stride_v_descale = nhead_stride_v_descale; - - kargs.block_scale_size_q = block_scale_size_q; - kargs.block_scale_size_kv = block_scale_size_kv; - - kargs.block_scale_seqstart_q_ptr = - reinterpret_cast(block_scale_seqstart_q_ptr); - kargs.block_scale_seqstart_k_ptr = - reinterpret_cast(block_scale_seqstart_k_ptr); - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -925,8 +814,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -946,9 +833,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -957,8 +841,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -978,8 +860,6 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - block_scale_seqstart_q_ptr, - block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -999,9 +879,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -1010,8 +887,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1034,8 +909,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1055,9 +928,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -1066,8 +936,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -1087,8 +955,6 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - block_scale_seqstart_q_ptr, - block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -1108,9 +974,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -1119,8 +982,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1250,16 +1111,13 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - long_index_t batch_offset_q_descale = 0; - long_index_t batch_offset_k_descale = 0; - long_index_t batch_offset_v_descale = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s @@ -1295,14 +1153,6 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; - const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; - batch_offset_q_descale = bquery_start; - batch_offset_k_descale = bkey_start; - batch_offset_v_descale = bkey_start; - } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1370,15 +1220,6 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - batch_offset_q_descale = - static_cast(i_batch) * kargs.batch_stride_q_descale; - batch_offset_k_descale = - static_cast(i_batch) * kargs.batch_stride_k_descale; - batch_offset_v_descale = - static_cast(i_batch) * kargs.batch_stride_v_descale; - } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1699,8 +1540,7 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - - auto o_acc_tile = [&, i_nhead_ = i_nhead]() { + auto o_acc_tile = [&]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1741,62 +1581,8 @@ struct FmhaFwdKernel block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_value); } - else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - const float* q_descale_ptr = - reinterpret_cast(kargs.q_descale_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_q_descale + - batch_offset_q_descale; - const float* k_descale_ptr = - reinterpret_cast(kargs.k_descale_ptr) + - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_k_descale + - batch_offset_k_descale; - const float* v_descale_ptr = - reinterpret_cast(kargs.v_descale_ptr) + - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_v_descale + - batch_offset_v_descale; - - size_t idx = i_m0 / kargs.block_scale_size_q; - float q_descale = q_descale_ptr[idx]; - // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 - // Both P and rowsum are scaled by 2^shift, canceling in normalization - // No additional scaling needed in p_compute_element_func or o_acc_element_func - - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - scales(q_descale), // s_acc_element_func - identity{}, // p_compute_element_func - No scaling (done in exp2) - identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout, - k_descale_ptr, - v_descale_ptr, - kargs.block_scale_size_kv, - sink_value); - } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2fbc9fdb545..dcccdf541cf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,13 +57,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; - // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] - static constexpr float OCP_FP8_SHIFT = 8.0f; - static constexpr float FNUZ_FP8_SHIFT = 7.0f; - static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -172,9 +167,6 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, - const float* k_descale_ptr, - const float* v_descale_ptr, - const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -366,13 +358,6 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { - float k_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - k_descale = k_descale_ptr[kv_idx]; - } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -442,20 +427,11 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } - // dequant - auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return s_acc_element_func * k_descale; - } - else - return s_acc_element_func; - }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -473,7 +449,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -490,7 +466,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -595,21 +571,7 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - // For BLOCKSCALE: precompute (m - shift) once per row - // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) - // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) - auto validated_m = get_validated_m(m[i_idx]); - auto row_max = scale_s * validated_m; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap - row_max -= OCP_FP8_SHIFT; // for else branch -#else - validated_m -= FNUZ_FP8_SHIFT; - row_max -= FNUZ_FP8_SHIFT; -#endif - } + auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -617,13 +579,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { @@ -714,39 +676,18 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); - float v_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - v_descale = v_descale_ptr[kv_idx]; - } // STAGE 3, KV gemm - auto o_acc0 = decltype(o_acc){}; - clear_tile(o_acc0); - - auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return o_acc0; - } - else - { - return o_acc; - } - }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -781,16 +722,11 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - tile_elementwise_inout( - [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); - } } while(++i_total_loops < num_total_loop); // store lse @@ -910,9 +846,6 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 81bd8d5ab52..797e572d58b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,7 +46,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -65,10 +64,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; - // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] - static constexpr float OCP_FP8_SHIFT = 8.0f; - static constexpr float FNUZ_FP8_SHIFT = 7.0f; - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -195,9 +190,6 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, - const float* k_descale_ptr, - const float* v_descale_ptr, - const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -411,13 +403,6 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { - float k_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - k_descale = k_descale_ptr[kv_idx]; - } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -464,20 +449,11 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); - // dequant - auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return s_acc_element_func * k_descale; - } - else - return s_acc_element_func; - }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -495,7 +471,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -512,7 +488,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -654,21 +630,7 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - // For BLOCKSCALE: precompute (m - shift) once per row - // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) - // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) - auto validated_m = get_validated_m(m[i_idx]); - auto row_max = scale_s * validated_m; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap - row_max -= OCP_FP8_SHIFT; // for else branch -#else - validated_m -= FNUZ_FP8_SHIFT; - row_max -= FNUZ_FP8_SHIFT; -#endif - } + auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -676,13 +638,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { @@ -773,27 +735,7 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); - float v_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - v_descale = v_descale_ptr[kv_idx]; - } // STAGE 3, KV gemm - auto o_acc0 = decltype(o_acc){}; - clear_tile(o_acc0); - - auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return o_acc0; - } - else - { - return o_acc; - } - }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -803,7 +745,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -866,19 +808,13 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc_, + o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } - - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - tile_elementwise_inout( - [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); - } } while(i_total_loops < num_total_loop); // store lse @@ -986,9 +922,6 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_v); } };