-
Notifications
You must be signed in to change notification settings - Fork 268
[CK_TILE][FMHA] Fix uninitialized sink_size in mask_info::decode() and filter redundant no-mask+sink instances #3504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Fixes an FMHA runtime dispatch hazard caused by uninitialized sink_size for no_mask, adds a compile-time guard against invalid sink+no-mask template combinations, and reduces redundant kernel instantiations in codegen.
Changes:
- Initialize
left/right/sinkwhen decodingno_maskinmask_info::decode(). - Add
static_assert(FmhaMask::IsMasking || !kHasSink)to prevent invalid pipeline instantiations. - Filter out
no_mask + sink=truecombinations in FMHA fwd-related codegen scripts.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | Adds compile-time validation to prevent kHasSink=true when masking is disabled. |
| example/ck_tile/01_fmha/mask.hpp | Fixes uninitialized fields for no_mask decoding (prevents bogus runtime has_sink). |
| example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py | Skips generating redundant/invalid no_mask + sink kernel variants. |
| example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | Skips generating redundant/invalid no_mask + sink kernel variants. |
| example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | Adds compatibility filtering to avoid no_mask + sink kernels in fwd generation. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| static constexpr auto QScaleEnum = Traits::QScaleEnum; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| static constexpr bool kIsPagedKV = Traits::kIsPagedKV; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| # sink_size is only meaningful when mask is applied | ||
| if ( | ||
| kernel_ctx.pipeline.F_mask in no_mask_keys | ||
| and kernel_ctx.pipeline.F_sink == "t" | ||
| ): | ||
| return False |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| or pipeline.F_logits == "f" | ||
| ): | ||
| continue | ||
| # sink_size is only meaningful when mask is applied |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled, so disallow sink when no mask is applied |
| or pipeline.F_logits == "f" | ||
| ): | ||
| continue | ||
| # sink_size is only meaningful when mask is applied |
Copilot
AI
Jan 21, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled; disallow sink when no mask is used |
|
LGTM @asleepzzz Please approve it. |
Problem
When
mask_info::decode()parses"0"(no_mask), it only set thetypefield but leftleft,right, andsinkuninitialized. This caused:sinkcould be arbitrary garbage valuetraits.has_sink = (mask.sink > 0)in fmha_fwd_runner.hpp:882 might evaluate to truekHasSink=trueinstantiationsSolution
left=-1,right=-1,sink=0when decoding no_mask in mask.hppstatic_assert(FmhaMask::IsMasking || !kHasSink)to pipeline problemsF_mask=no_mask + F_sink=truecombinations in codegen scripts:Impact
Testing
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered