Added Causal Mask Pattern Fusion for LongRoPe Models and Cache Insertion for Phi4-mini-reasoning #2461
tadani3 wants to merge 12 commits into microsoft:main
Conversation
mask_key = self._get_mask_key(attention_mask)

if mask_key in self._mask_cache:
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice (Code scanning / CodeQL): Unused local variable
Codecov Report ❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
##             main    #2461      +/-   ##
==========================================
- Coverage   69.81%   69.03%   -0.78%
==========================================
  Files         209      210       +1
  Lines       25313    25790     +477
  Branches     2525     2603      +78
==========================================
+ Hits        17673    17805     +132
- Misses       6762     7110     +348
+ Partials      878      875       -3

☔ View full report in Codecov by Sentry.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
Pull Request Overview
This PR adds causal mask pattern fusion support specifically for LongRoPe models such as Phi-4-mini-reasoning. The implementation extends the existing GQA (Group Query Attention) fusion rules to handle the complex attention mask patterns used by LongRoPe models, optimizing the mask computation process while maintaining compatibility with ModelBuilder optimizations.
Key changes:
- Addition of a new LongRoPeGQACausalMask class that implements specialized mask pattern matching and fusion
- Extension of the GQA rewrite rules to include LongRoPe-specific optimizations
- Implementation of a mask caching mechanism to avoid recomputation
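For readers unfamiliar with the rewriter, here is a rough structural sketch of how such a rule class plugs into the onnxscript pattern API, based only on the identifiers visible in this diff; the method bodies are elided and the parameter names are illustrative, so this is not the PR's actual implementation:

```python
from onnxscript.rewriter import pattern


class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    """Sketch: match the LongRoPe causal-mask subgraph and fuse it for GQA."""

    def __init__(self):
        # remove_nodes=False keeps the matched nodes in place (as in the diff),
        # leaving cleanup to later passes.
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)

    def pattern(self, op, input_ids, attention_mask, total_seq_length):
        ...  # describe the mask-construction subgraph (Shape/Expand/And/Slice/...)

    def rewrite(self, op, input_ids, attention_mask, total_seq_length):
        ...  # emit the replacement that feeds the fused GroupQueryAttention op


# As in the diff, LongRoPeGQACausalMask.rule() would turn the class into a
# rewrite rule that can be grouped into a pattern.RewriteRuleSet.
```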
_basic_gqa_rule = GroupQueryAttention.rule()
_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) is completely ignored.
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
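One way to resolve this, assuming both rules are meant to be active (a sketch, not necessarily the PR's final code), is to register them in a single rule set:

```python
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule])
```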
# Propagation to GQA
mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])

#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.
Suggested change:
- #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])

# EXPAND A/B TO AND
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.
Suggested change:
- mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
+ mask_expanded_A_sub = op.Sub(mask_expanded_A, MASK_OFFSET, _outputs=["mask_expanded_A_sub"])
Better to make it a pattern-variable, I think ... if I understand right, this is actually a magic sequence-length constant? Perhaps model-specific?
On second thoughts, I am guessing this is the window_size, which should become an attribute-parameter to the GQA op.
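If the value does end up staying as a literal rather than becoming a GQA attribute, the named-constant suggestion amounts to something like this (the constant name is illustrative):

```python
# Illustrative name; per the discussion above, 262144 may in fact be the
# model's sliding-window size and belong on the GQA op as an attribute.
_LONGROPE_MASK_OFFSET = 262144

mask_expanded_A_sub = op.Sub(mask_expanded_A, _LONGROPE_MASK_OFFSET, _outputs=["mask_expanded_A_sub"])
```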
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Using id() for cache keys is fragile because object ids can be reused after garbage collection. This could lead to incorrect cache hits with different attention_mask objects that happen to have the same id.
Suggested change:
- Generate a unique key for the mask based on input_ids and past_kv_cache.
+ Generate a unique key for the mask based on the content of attention_mask.
  This is used to cache the mask to avoid recomputation.
  """
- return (id(attention_mask))
+ if isinstance(attention_mask, np.ndarray):
+     return hash(attention_mask.tobytes())
+ elif isinstance(attention_mask, (list, tuple)):
+     return hash(tuple(attention_mask))
+ else:
+     raise TypeError("Unsupported type for attention_mask: {}".format(type(attention_mask)))
If a cache is used, it should be cleaned up like in this example so that it is not carried over from one graph/model to another
And I am not sure if we need to handle np arrays? If the key is either one or two ir.Values, that should be fine ... ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.
I agree _get_mask_key seems unnecessary. We can use the Value objects directly as keys.
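A minimal sketch of that suggestion (assuming attention_mask is the matched ir.Value; _build_seqlens is a hypothetical stand-in for the PR's mask-computation logic):

```python
# ir.Value objects can be used as dict keys directly, so no _get_mask_key is needed.
cached = self._mask_cache.get(attention_mask)
if cached is None:
    # Hypothetical helper standing in for the compute_mask logic that emits the
    # total_seq_length / seqlens_k values consumed by the fused GQA op.
    cached = self._build_seqlens(op, attention_mask)
    self._mask_cache[attention_mask] = cached
total_seq_length_int32, seqlens_k_int32 = cached
```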
| """ | ||
| return (id(attention_mask)) | ||
|
|
||
| def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']): |
The rewriter doesn't use onnxscript types (yet). Could you instead use a comment to document the shape of attention_mask?
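That is, something along these lines (sketch):

```python
def compute_mask(self, op, attention_mask):
    # attention_mask: INT64 tensor of shape [batch, seq_len]
    ...
```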
    _outputs=3,
)

class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.
The following is my understanding; if this is correct, maybe they can be renamed appropriately: I believe that A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in the other positions (for constructing a final attention-mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast).
I am a bit puzzled that query_range and kv_range look to be the same here; it might be an artifact of this model usage, I guess.
I wasn't sure what the branches referred to but I'll make changes following what Rama is suggesting.
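Based on that reading, the class docstring might look roughly like this (a sketch reusing the class declaration from the diff; the branch interpretation above is a reviewer's guess, not confirmed by the PR):

```python
class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    """Fuses the LongRoPe causal-mask construction that feeds GroupQueryAttention.

    The matched subgraph builds an attention mask of shape
    [batch, num_heads, query_range, kv_range] by broadcasting three branches:
      * kv_range branch    (formerly "A"): range over key/value positions
      * query_range branch (formerly "B"): range over query positions
      * batch_range branch (formerly "C"): range over the batch dimension
    Each branch is expanded to a 4D tensor with size 1 in the remaining axes
    and combined via And/Where into the final causal mask. The replacement
    feeds the equivalent seqlens_k / total_sequence_length values to the fused
    GQA op instead of materializing this mask.
    """
```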
    total_seq_length,
):
    seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
    seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
Suggested change:
- seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
+ seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])

Prefer snake_case for variable names when possible.
mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])

# EXPAND B/C TO AND
I would document the branches in plain English for readers
class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}
The copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a self._mask_cache?
I think use of state for this purpose is okay? It has been used before for a similar purpose, namely to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE ... but I am not sure it is worth the bother.)
BTW: my GQA fusion doesn't use state, and produces multiple copies (as described above).
My concern is that the state will carry over from one model to another if we are not careful, which is probably not a good idea. Maybe we can have a class-managed state dict that is cleared by the class?
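A minimal sketch of that idea; the clear_cache hook and the point at which the driver calls it are assumptions, not an existing rewriter API:

```python
class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

    def clear_cache(self):
        # Hypothetical hook: invoked once per model (e.g., by the fusion driver
        # before/after rewriting) so cached ir.Values never leak across graphs.
        self._mask_cache.clear()
```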
Hi @tadani3, sorry about the concurrent changes I had merged into GQA fusion recently, which might impact some of your changes ... but I am a bit confused by the diffs shown, which don't seem to reflect the changes I had made. Briefly, the earlier version did the fusion in two steps: the first rule ignored the attention-mask and focused on the rest of the computation, and the second rule explicitly handled the attention-mask. The more recent version merged the two into one, for various reasons. I think it shouldn't impact your changes much, except that you will have to make the changes in rule 1 instead of rule 2. But, as I said, I am a bit confused why I am not seeing those changes in the diffs.
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

    def _get_mask_key(self, attention_mask):
In general, avoid creating class methods that do not require state from self; instead make them module-level private functions for testability and clarity.
import numpy as np
import onnx_ir as ir

import onnxscript.onnx_types as _onnx_types
Suggested change:
- import onnxscript.onnx_types as _onnx_types
_onnx_types is not compatible with the rewriter (yet)
@microsoft-github-policy-service agree company="Microsoft" |
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import onnx

Check notice (Code scanning / CodeQL): Unused import

# --------------------------------------------------------------------------
import onnx
from onnxscript import ir
import onnx.helper

Check notice (Code scanning / CodeQL): Unused import
cache_length = self.rotemb_attrs["cache_length"]
position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0)  # Shape: (1, cache_length)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # (1, dim//2, 1)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable

with torch.autocast(device_type=device_type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # (1, cache_length, dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (1, cache_length, dim)
    cos_cache = emb.cos() * attention_factor  # (1, cache_length, dim)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable

attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable

if "rescale_inv_freq" in self.rotemb_attrs:
    inv_freq = self.make_inv_freq_rescaled(inv_freq)

return inv_freq, attention_factor

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
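These CodeQL errors typically mean inv_freq and attention_factor are only assigned inside one branch of a conditional. A minimal sketch of one way to quiet them, as a fragment of the surrounding method; the branch condition shown is an assumption about the rope-scaling configuration, not a quote of the PR:

```python
# Initialize up front so every code path defines both values, then fail loudly
# if the configuration does not populate them.
inv_freq = None
attention_factor = None

if "multi_cache" in self.rotemb_attrs:  # assumed LongRoPe / multi-cache branch
    attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]
    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

if inv_freq is None or attention_factor is None:
    raise ValueError("rotemb_attrs did not yield inv_freq/attention_factor")
```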
I added a class called Phi4MiniReasoningPostProcessor which uses the ONNX IR to fulfill two tasks:
Reasoning and Motivation
…xscript into longrope_causal_mask
If you can move this file to a separate PR we can merge the fusion rules. Thanks
Modification of the GQA causal mask fusion rule to handle the attention mask fusion for LongRoPe models such as Phi-4-mini-reasoning. The causal mask modification leads to a result that matches the optimizations made in ModelBuilder.