39 changes: 25 additions & 14 deletions backends/cadence/aot/fuse_ops.py
@@ -203,8 +203,9 @@ class FuseBatchNormWithConv(ExportPass):
in the graph.
"""

def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None:
def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> bool:
graph = graph_module.graph
modified = False
for conv in graph.nodes:
# We want to discover a chain of conv1d -> batch_norm.
# Only proceed if the current node is a conv1d node, and has a single
@@ -252,6 +253,8 @@ def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None:
assert isinstance(running_var, torch.Tensor)
eps = bn.args[-1]

modified = True

# Compute the updated weight and bias after fusing conv op
# with batchnorm op.
fused_weight, fused_bias = fuse_conv_bn_weights(
@@ -287,16 +290,18 @@ def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None:
user.replace_all_uses_with(conv)
self.counter += 1

graph_module.recompile()
if modified:
graph_module.recompile()
return modified
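For readers skimming past the folded hunks: the discovery step walks the FX graph for a conv whose sole user is a batch norm, and only then rewrites the weights. A minimal sketch of that matching logic, assuming illustrative CONV_TARGETS / BN_TARGETS sets (the actual pass matches its own concrete op targets):

import torch

# Hypothetical target sets; the real pass checks specific conv/batch-norm ops.
CONV_TARGETS = {torch.ops.aten.convolution.default}
BN_TARGETS = {torch.ops.aten._native_batch_norm_legit_no_training.default}

def find_conv_bn_pairs(graph: torch.fx.Graph):
    # Yield (conv, bn) pairs where the batch norm is the conv's only consumer.
    for node in graph.nodes:
        if node.op != "call_function" or node.target not in CONV_TARGETS:
            continue
        # A conv with more than one user cannot be fused away safely:
        # other consumers still need the un-normalized output.
        if len(node.users) != 1:
            continue
        user = next(iter(node.users))
        if user.op == "call_function" and user.target in BN_TARGETS:
            yield node, user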

def __init__(self):
super().__init__()
self.counter = 0

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.fuse_batch_norm_with_conv(graph_module)
modified = self.fuse_batch_norm_with_conv(graph_module)
result = super().call(graph_module)
return result
return PassResult(result.graph_module, result.modified and modified)
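For context on fuse_conv_bn_weights: folding a batch norm into the preceding conv rescales each output channel by gamma / sqrt(running_var + eps) and absorbs the running statistics into the bias. The pass presumably uses torch's stock helper (torch.nn.utils.fusion.fuse_conv_bn_weights); the illustrative sketch below shows the underlying algebra:

import torch

def fuse_conv_bn_weights_sketch(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # Per-output-channel scale: gamma / sqrt(running_var + eps).
    scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    # Broadcast the scale across the weight's non-output dimensions.
    fused_w = conv_w * scale.reshape(-1, *([1] * (conv_w.dim() - 1)))
    # Absorb the running mean and the batch-norm shift into the bias.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b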


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -310,8 +315,9 @@ class FuseQuantizedBatchNormWithConv(ExportPass):

def fuse_quantized_batch_norm_with_conv(
self, graph_module: torch.fx.GraphModule
) -> None:
) -> bool:
graph = graph_module.graph
modified = False
for conv in graph.nodes:
# We want to discover a chain of quantized::conv1d ->
# quantized::batch_norm. Only proceed if the current node is a
@@ -386,6 +392,8 @@ def fuse_quantized_batch_norm_with_conv(
assert isinstance(running_mean_tensor, torch.Tensor)
assert isinstance(running_var_tensor, torch.Tensor)

modified = True

# Get the fused weights and bias
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_weight_tensor,
@@ -439,22 +447,25 @@ def fuse_quantized_batch_norm_with_conv(
graph.erase_node(bn)
self.counter += 1

# Note: there is a quantized.conv2d.new operator in the resulting graph
# that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of the input
# this prevents us to directly call graph_module.recompile().
# pyre-fixme[16]: `GraphModule` has no attribute `_code`.
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `python_code`.
graph_module._code = graph_module._graph.python_code(root_module="self").src
if modified:
# Note: there is a quantized.conv2d.new operator in the resulting graph
# that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of its inputs;
# this prevents us from directly calling graph_module.recompile().
# pyre-fixme[16]: `GraphModule` has no attribute `_code`.
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `python_code`.
graph_module._code = graph_module._graph.python_code(root_module="self").src

return modified

def __init__(self):
super().__init__()
self.counter = 0

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.fuse_quantized_batch_norm_with_conv(graph_module)
modified = self.fuse_quantized_batch_norm_with_conv(graph_module)
result = super().call(graph_module)
return result
return PassResult(result.graph_module, result.modified and modified)
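The reason an accurate modified flag is worth threading through both passes: a pass manager can use it to decide whether the graph has reached a fixed point, instead of blindly re-running everything. A hedged sketch of such a consumer loop (run_to_fixed_point is illustrative, not the actual ExecuTorch scheduler; it assumes each pass returns a PassResult):

def run_to_fixed_point(passes, graph_module, max_iters=10):
    # Re-run the pass list until a full sweep reports no modification.
    for _ in range(max_iters):
        any_modified = False
        for pass_ in passes:
            result = pass_(graph_module)
            graph_module = result.graph_module
            any_modified = any_modified or result.modified
        if not any_modified:
            break
    return graph_module

Masking super().call's flag with the pass's own modified (rather than OR-ing them) appears intentional: the retracing base call may report a change unconditionally, so the pass only reports modification when the fusion actually fired.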


@register_cadence_pass(CadencePassAttribute(opt_level=1))