Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, including qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.

### Changed

Expand Down
2 changes: 0 additions & 2 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,11 @@ def get_mask_cpp_check_expr(mask: str) -> str:
QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
}

QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
}

BIAS_MAP = {
Expand Down
7 changes: 2 additions & 5 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def get_pipelines(
# no need lse/dropout kernels
for logits, qscale, mask, bias, sink in itertools.product(
["t", "f"],
["no", "pertensor", "blockscale"],
["no", "pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
Expand Down Expand Up @@ -1152,10 +1152,7 @@ def get_pipelines(
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
Expand Down
26 changes: 0 additions & 26 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
Expand Down Expand Up @@ -259,19 +257,13 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;

ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
Expand All @@ -284,9 +276,6 @@ struct fmha_fwd_args

std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;

ck_tile::index_t block_scale_size_q;
ck_tile::index_t block_scale_size_kv;
};

struct fmha_fwd_pagedkv_args
Expand Down Expand Up @@ -626,8 +615,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
Expand All @@ -647,9 +634,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
Expand All @@ -658,8 +642,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
Expand Down Expand Up @@ -697,28 +679,20 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.batch_stride_q_descale,
args.batch_stride_k_descale,
args.batch_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
Expand Down
Loading