19 changes: 14 additions & 5 deletions backends/cadence/aot/ops_registrations.py
@@ -689,11 +689,11 @@ def register_fake(
)

lib.define(
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden) -> Tensor"
)

lib.define(
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
@@ -3060,11 +3060,20 @@ def quantized_w8a32_gru_meta(
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> torch.Tensor:
return hidden.new_empty((2, *hidden.shape), dtype=torch.float32)
seq_len = inputs.shape[1]
assert seq_len == 1
Comment on lines +3066 to +3067
Copilot AI Apr 14, 2026

quantized_w8a32_gru_meta indexes inputs.shape[1] without first validating inputs rank, which can raise an IndexError (and produce a confusing failure) if the op is ever invoked with a 1D/2D input. Add an explicit shape/rank check (e.g., inputs.dim() == 3 / hidden.dim() == 3) before reading shape[1], and then assert seq_len == 1 with a clear message.

Suggested change
seq_len = inputs.shape[1]
assert seq_len == 1
assert inputs.dim() == 3, (
"quantized_w8a32_gru expects inputs to have shape "
"[batch, seq_len, input_size]"
)
assert hidden.dim() == 3, (
"quantized_w8a32_gru expects hidden to have shape "
"[batch, seq_len, hidden_size]"
)
seq_len = inputs.shape[1]
assert (
seq_len == 1
), "quantized_w8a32_gru fake kernel only supports seq_len == 1"

# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
(2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
Comment on lines +3066 to +3075
Copilot AI Apr 14, 2026

The meta kernel reads seq_len = inputs.shape[1] and asserts seq_len == 1, but the ref implementation and unit tests treat inputs as a 2D tensor shaped [batch, input_dim] (e.g., 1x2, 1x3). With the current meta logic, any input_dim != 1 will trip this assert during fake-tensor shape propagation / export. Please remove or relax this assertion and update the shape comments to match the actual operator contract (and keep the output shape computation consistent with ref_implementations.quantized_w8a32_gru).

Suggested change
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
(2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
# inputs comes in shape [batch, input_size]
# hidden comes in shape [batch, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
(2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32

)


@register_fake("cadence::slice_scatter_")
45 changes: 26 additions & 19 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv(
torch.ops.aten.permute.default,
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
)
assert "val" in other_inputs[0].meta, "Missing val metadata on input node"
original_val = other_inputs[0].meta["val"]
assert original_val.fake_mode is not None, "fake_mode is None on input node"
with original_val.fake_mode:
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
original_val, [0, 2, 1]
)
# Propagate val metadata for transposed_inputs
if "val" in other_inputs[0].meta:
original_val = other_inputs[0].meta["val"]
fake_mode = original_val.fake_mode
if fake_mode is not None:
with fake_mode:
transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1])
transposed_inputs.meta["val"] = transposed_val
else:
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
original_val, [0, 2, 1]
)
copy_node_metadata(transposed_inputs, other_inputs[0])

transposed_weights = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
)
assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node"
original_val = weights_inputs[0].meta["val"]
assert original_val.fake_mode is not None, "fake_mode is None on weight node"
with original_val.fake_mode:
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
original_val, [2, 0, 1]
)
# Propagate val metadata for transposed_weights
if "val" in weights_inputs[0].meta:
original_val = weights_inputs[0].meta["val"]
fake_mode = original_val.fake_mode
if fake_mode is not None:
with fake_mode:
transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1])
transposed_weights.meta["val"] = transposed_val
else:
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
original_val, [2, 0, 1]
)
copy_node_metadata(transposed_weights, weights_inputs[0])

args = (
@@ -511,12 +521,10 @@ def get_args_and_kwargs_mixed_w8a32_gru(
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
# Stride, padding, dilation, groups not supported yet

assert len(dequants_weights) == 2
assert len(dequants_biases) == 2
w_i_scale = dequants_weights[0].args[1]
w_h_scale = dequants_weights[1].args[1]
b_i_scale = dequants_biases[0].args[1]
b_h_scale = dequants_biases[1].args[1]
b_scale = dequants_biases[0].args[1]
Comment on lines 524 to +527
Copilot AI Apr 14, 2026

get_args_and_kwargs_mixed_w8a32_gru still indexes dequants_weights[0] and [1], but the defensive assert len(dequants_weights) == 2 was removed. If the partition ever produces an unexpected number of dequant nodes, this will fail with an IndexError and be harder to diagnose. Please restore the assert (or otherwise validate length) before indexing.

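A minimal restoration of those guards, placed before the indexing (a sketch; the assert messages are illustrative):

# Sketch: validate the partition shape before indexing (messages illustrative).
assert len(dequants_weights) == 2, (
    f"w8a32 GRU fusion expects 2 dequantized weights, got {len(dequants_weights)}"
)
assert len(dequants_biases) == 2, (
    f"w8a32 GRU fusion expects 2 dequantized biases, got {len(dequants_biases)}"
)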

args = (
other_inputs[0],
@@ -526,9 +534,8 @@ def get_args_and_kwargs_mixed_w8a32_gru(
weights_inputs[1],
w_h_scale,
bias_inputs[0],
b_i_scale,
b_scale,
bias_inputs[1],
b_h_scale,
)
kwargs = {}

35 changes: 31 additions & 4 deletions backends/cadence/aot/quantizer/patterns.py
@@ -718,7 +718,7 @@ def get_anchors(
)

cnn_weights = conv_layer.args[1]
if hasattr(cnn_weights.meta, "tensor_meta"):
if "tensor_meta" in cnn_weights.meta:
cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
# Bail if the channels are not multiple of 4 (SIMD)
if cnn_weights_shape[0] % 4 != 0:
@@ -744,6 +744,18 @@ def get_anchors(
conv_layer,
)

inputs = conv_layer.args[0]
if "tensor_meta" in inputs.meta:
inputs_shape = inputs.meta["tensor_meta"].shape
# Bail if length != kernel size - Not yet supported
if inputs_shape[-1] != cnn_weights_shape[2]:
Comment on lines +750 to +751
Copilot AI Apr 14, 2026

In MixedW8A32ConvPattern, this new guard compares the input tensor's last dimension (which is conv1d length for NCL inputs) against the kernel size. That will reject common valid conv1d shapes (e.g., length=5, kernel=3) even though the w8a32 conv ref/meta implementations support them. If the intent is to validate shapes, this should instead check that the input channel dimension matches the weight's in_channels and that length >= kernel_size (not length == kernel_size).

Suggested change
# Bail if length != kernel size - Not yet supported
if inputs_shape[-1] != cnn_weights_shape[2]:
# Bail if the input channels do not match the weight's in_channels
# or if the input length is smaller than the kernel size.
if (
len(inputs_shape) < 3
or inputs_shape[1] != cnn_weights_shape[1]
or inputs_shape[-1] < cnn_weights_shape[2]
):

return (
PartitionAnchors(
empty=True,
),
conv_layer,
)

return (
PartitionAnchors(
inputs=[],
@@ -777,14 +789,16 @@ def get_anchors(
)

# Bail if input or states are not multiple of 4 (SIMD)
if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None)
if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0:
return (
PartitionAnchors(
empty=True,
),
gru_layer,
)
if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None)
if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0:
return (
PartitionAnchors(
empty=True,
@@ -799,13 +813,26 @@ def __init__(self, args, meta):

wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)

# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
# Both biases get the same quantization scale to match the cpp operator
bias_ih_node = wrapper.args[2]
bias_ih_edge = (bias_ih_node, gru_layer)
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)

Comment on lines +816 to +821
Copilot AI Apr 14, 2026

This code assumes the GRU params tuple contains both biases and immediately indexes wrapper.args[2]/[3]. For aten.gru.input, has_biases can be false (e.g. nn.GRU(..., bias=False)), in which case the params list may not contain bias entries and this will raise an IndexError during pattern matching. Consider explicitly checking the has_biases argument (and/or len(wrapper.args) >= 4) and returning empty=True anchors when biases are not present, since the Cadence quantized_w8a32_gru replacement requires bias tensors.

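A guard along those lines, placed before indexing the params tuple, could look like this (a sketch, assuming the tuple carries weight_ih, weight_hh, bias_ih, bias_hh when biases are present):

# Sketch: bail out when the GRU was built without bias tensors,
# since the quantized_w8a32_gru replacement requires both biases.
if len(wrapper.args) < 4:
    return (
        PartitionAnchors(
            empty=True,
        ),
        gru_layer,
    )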
return (
PartitionAnchors(
inputs=[],
# pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
weights=[(wrapper, 0), (wrapper, 1)],
# pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
biases=[(wrapper, 2), (wrapper, 3)],
biases=[
(wrapper, 2), # bias_ih gets normal qspec
(
wrapper,
3,
shared_bias_qspec,
), # bias_hh shares observer with bias_ih
],
output=[],
others=[(gru_layer, 0), (gru_layer, 1)],
),
19 changes: 11 additions & 8 deletions backends/cadence/aot/ref_implementations.py
@@ -1257,9 +1257,8 @@ def quantized_w8a32_gru(
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> torch.Tensor:
assert weights_inputs.dtype == torch.int8
assert weights_hidden.dtype == torch.int8
@@ -1288,10 +1287,8 @@ def quantized_w8a32_gru(
dequant_weights_inputs = weights_inputs.float() * w_i_scale
dequant_weights_hidden = weights_hidden.float() * w_h_scale

# C++ implementation averages the two bias scales
avg_bias_scale = (b_i_scale + b_h_scale) / 2
dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
dequant_bias_hidden = bias_hidden.float() * avg_bias_scale
dequant_bias_inputs = bias_inputs.float() * b_scale
dequant_bias_hidden = bias_hidden.float() * b_scale

gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)
@@ -1310,8 +1307,14 @@ def quantized_w8a32_gru(

assert new_hidden.shape == original_hidden_shape

new_hidden = new_hidden.view(original_hidden_shape)
return torch.stack([new_hidden, new_hidden], dim=0)
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(
batch_size, input_dim, hidden_dim
)
return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)


Comment on lines +1310 to 1319
Copilot AI Apr 14, 2026

new_hidden is already shaped like the computed GRU output (e.g., [batch, seq_len, hidden] when inputs is 3D), but the new unsqueeze(1).expand(batch_size, input_dim, hidden_dim) path will throw for 3D inputs/hidden because it adds an extra dimension and then calls expand with too few sizes. This currently only works when new_hidden is 2D.

Consider normalizing new_hidden to [batch, seq_len, hidden] (e.g., only unsqueeze when it’s missing the seq_len dim) and then stacking directly, rather than expanding based on inputs.shape[1].

Suggested change
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]
new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
if new_hidden.dim() == 1:
new_hidden_normalized = new_hidden.unsqueeze(0).unsqueeze(0)
elif new_hidden.dim() == 2:
new_hidden_normalized = new_hidden.unsqueeze(1)
elif new_hidden.dim() == 3:
new_hidden_normalized = new_hidden
else:
raise ValueError(
f"Hidden state must be 1D, 2D, or 3D, got shape {tuple(new_hidden.shape)}"
)
return torch.stack([new_hidden_normalized, new_hidden_normalized], dim=0)

@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
30 changes: 11 additions & 19 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -2901,9 +2901,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 (3*4 x 4)
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"invalid_batch_size_2",
@@ -2918,9 +2917,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"non_zero_biases",
@@ -2933,11 +2931,10 @@ def test_softmax_f32_f32(self) -> None:
torch.tensor(
[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.tensor(
[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
), # bias_hidden: 12
0.1, # b_h_scale
),
(
"negative_weights",
@@ -2954,9 +2951,8 @@ def test_softmax_f32_f32(self) -> None:
), # weights_hidden: 12x4 (alternating pattern)
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"hidden_dim_8",
@@ -2969,9 +2965,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((24, 8), dtype=torch.int8), # weights_hidden: 24x8 (3*8 x 8)
0.1, # w_h_scale
torch.zeros(24, dtype=torch.int8), # bias_inputs: 24
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(24, dtype=torch.int8), # bias_hidden: 24
0.1, # b_h_scale
),
]
)
@@ -2985,9 +2980,8 @@ def test_quantized_w8a32_gru(
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> None:

if name == "invalid_batch_size_2":
@@ -3000,9 +2994,8 @@ def test_quantized_w8a32_gru(
weights_hidden,
w_h_scale,
bias_inputs,
b_i_scale,
b_scale,
bias_hidden,
b_h_scale,
)
self.assertIn(
"Leading dimension 0 of hidden state must be 1", str(context.exception)
@@ -3017,9 +3010,8 @@ def test_quantized_w8a32_gru(
weights_hidden,
w_h_scale,
bias_inputs,
b_i_scale,
b_scale,
bias_hidden,
b_h_scale,
)

# Verify output properties
@@ -3028,10 +3020,11 @@ def test_quantized_w8a32_gru(
torch.float32,
f"Output dtype should be float32 in {name}",
)
expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
self.assertEqual(
output.shape,
(2, *hidden.shape),
f"Output shape should match {(2, *hidden.shape)} in {name}",
expected_shape,
f"Output shape should match {expected_shape} in {name}",
Comment on lines +3023 to +3027
Copilot AI Apr 14, 2026

The updated expected shape uses inputs.shape[1] as the sequence-length dimension, but the test vectors still pass 2D inputs/hidden tensors. This means the test isn't exercising the [batch, seq_len, input_size] / [batch, seq_len, hidden_size] path documented in quantized_w8a32_gru_meta and can mask shape bugs for the intended 3D case. Consider updating these fixtures (and/or adding an additional case) to use inputs and hidden with an explicit seq_len dimension (typically seq_len == 1) so the test matches the operator contract.

)
Comment on lines +3023 to 3028
Copilot AI Apr 14, 2026

The test only validates the new 4D output shape against inputs.shape[:2], but it still uses 2D inputs / 2D hidden fixtures. Since this op is produced from aten.gru fusion, it’s important to add a case with the expected real input ranks (e.g. 3D inputs and 3D hidden/state) so the output-shape logic is exercised for those ranks as well (this would also catch shape errors like an unsqueeze/expand mismatch).

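A fixture exercising that path might look like the following (a sketch; the case name and values are illustrative, with seq_len == 1 to satisfy the meta kernel):

(
    "seq_len_1_3d",
    torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32),  # inputs: [1, 1, 4]
    torch.zeros((1, 1, 4), dtype=torch.float32),  # hidden: [1, 1, 4]
    torch.ones((12, 4), dtype=torch.int8),  # weights_inputs: 12x4 (3*4 x 4)
    0.1,  # w_i_scale
    torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
    0.1,  # w_h_scale
    torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
    0.1,  # b_scale
    torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
),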
assert isinstance(output, torch.Tensor)

Expand Down Expand Up @@ -3064,7 +3057,6 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
bias_inputs,
0.1,
bias_hidden,
0.1,
)

self.assertIn(