From 7cf5a48191daf69e77d547510656bf47fc67596d Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Thu, 5 Feb 2026 15:12:02 -0800 Subject: [PATCH] Unify fusion passes. (#17251) Summary: Group fusion passes together. Differential Revision: D92339098 --- backends/cadence/aot/fuse_ops.py | 39 ++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index d8265e6ab86..78d00637225 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -203,8 +203,9 @@ class FuseBatchNormWithConv(ExportPass): in the graph. """ - def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None: + def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> bool: graph = graph_module.graph + modified = False for conv in graph.nodes: # We want to discover a chain of conv1d -> batch_norm. # Only proceed if the current node is a conv1d node, and has a single @@ -252,6 +253,8 @@ def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None: assert isinstance(running_var, torch.Tensor) eps = bn.args[-1] + modified = True + # Compute the updated weight and bias after fusing conv op # with batchnorm op. 
fused_weight, fused_bias = fuse_conv_bn_weights( @@ -287,16 +290,18 @@ def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None: user.replace_all_uses_with(conv) self.counter += 1 - graph_module.recompile() + if modified: + graph_module.recompile() + return modified def __init__(self): super().__init__() self.counter = 0 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_batch_norm_with_conv(graph_module) + modified = self.fuse_batch_norm_with_conv(graph_module) result = super().call(graph_module) - return result + return PassResult(result.graph_module, result.modified & modified) @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -310,8 +315,9 @@ class FuseQuantizedBatchNormWithConv(ExportPass): def fuse_quantized_batch_norm_with_conv( self, graph_module: torch.fx.GraphModule - ) -> None: + ) -> bool: graph = graph_module.graph + modified = False for conv in graph.nodes: # We want to discover a chain of quantized::conv1d -> # quantized::batch_norm. Only proceed if the current node is a @@ -386,6 +392,8 @@ def fuse_quantized_batch_norm_with_conv( assert isinstance(running_mean_tensor, torch.Tensor) assert isinstance(running_var_tensor, torch.Tensor) + modified = True + # Get the fused weights and bias fused_weight, fused_bias = fuse_conv_bn_weights( conv_weight_tensor, @@ -439,22 +447,25 @@ def fuse_quantized_batch_norm_with_conv( graph.erase_node(bn) self.counter += 1 - # Note: there is a quantized.conv2d.new operator in the resulting graph - # that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of the input - # this prevents us to directly call graph_module.recompile(). - # pyre-fixme[16]: `GraphModule` has no attribute `_code`. - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `python_code`. 
- graph_module._code = graph_module._graph.python_code(root_module="self").src + if modified: + # Note: there is a quantized.conv2d.new operator in the resulting graph + # that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of its inputs; + # this prevents us from directly calling graph_module.recompile(). + # pyre-fixme[16]: `GraphModule` has no attribute `_code`. + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `python_code`. + graph_module._code = graph_module._graph.python_code(root_module="self").src + + return modified def __init__(self): super().__init__() self.counter = 0 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_quantized_batch_norm_with_conv(graph_module) + modified = self.fuse_quantized_batch_norm_with_conv(graph_module) result = super().call(graph_module) - return result + return PassResult(result.graph_module, result.modified & modified) @register_cadence_pass(CadencePassAttribute(opt_level=1))