Source code for chop.passes.graph.transforms.utils.conv_bn_fusion
# A pass to fuse batch normalisation layers with the preceding convolutional layers
# NOTE: This implementation is a derivative of the following:
# https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/optimization.py
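#
# Fusion folds the batchnorm's affine transform into the convolution's parameters, which
# is what ``fuse_conv_bn_eval`` computes under the hood. As a rough sketch (W and b are
# the conv weight and bias; gamma, beta, running_mean, running_var and eps come from the
# batchnorm; everything broadcasts over the output channels):
#
#   scale   = gamma / sqrt(running_var + eps)
#   W_fused = W * scale
#   b_fused = (b - running_mean) * scale + beta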
import logging
import torch.nn as nn
import torch.fx as fx
from tqdm.contrib.logging import tqdm_logging_redirect
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
# Housekeeping -------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# logger.propagate = False # Avoid duplicate logs
def conv_bn_fusion_transform_pass(graph, pass_args={}):
    """Perform Conv-BN fusion on the given graph.

    :param graph: a MaseGraph
    :type graph: MaseGraph
    :param pass_args: this pass does not use any additional arguments, defaults to {}
    :type pass_args: dict, optional
    :return: return a tuple of a MaseGraph and an empty dict (no additional info to return)
    :rtype: tuple(MaseGraph, dict)
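
    A minimal usage sketch (assuming ``graph`` is a MaseGraph built from a model
    containing Conv -> BatchNorm pairs)::

        graph, _ = conv_bn_fusion_transform_pass(graph)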
"""
    PATTERNS = [
        (nn.Conv1d, nn.BatchNorm1d),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv3d, nn.BatchNorm3d),
    ]
    modules = dict(graph.model.named_modules())
    logger.debug(f"Found {len(modules)} modules in the model:\n{list(modules.keys())}")

    # Modify the graph in place.
    total = len(graph.fx_graph.nodes) * len(PATTERNS)
    with tqdm_logging_redirect(total=total, loggers=[logger]) as pbar:
        for pattern in PATTERNS:
            fst, snd = pattern[0].__name__, pattern[1].__name__
            pbar.set_description(f"Looking for pattern {fst} -> {snd}")
            # Iterate over the graph and fuse the nodes that match the patterns
            for node in graph.fx_graph.nodes:
                if matches_module_pattern(pattern, node, modules):
                    # Some architectures feed the conv output to nodes other than the
                    # batchnorm. Fusion isn't trivial in those cases, so for now we
                    # just skip them.
                    if len(node.args[0].users) > 1:
                        logger.warning("Conv output used by other nodes. Skipped!")
                        continue
                    conv = modules[node.args[0].target]
                    bn = modules[node.target]
                    if not bn.track_running_stats:
                        # When track_running_stats is False, the batch norm module's
                        # running mean and variance buffers are set to None.
                        logger.warning("Batchnorm not tracking stats. Skipped!")
                        continue
                    # Set both modules to eval mode
                    conv_prev_mode, bn_prev_mode = conv.training, bn.training
                    conv.train(False)
                    bn.train(False)
                    # Fuse!
                    fused_conv = fuse_conv_bn_eval(conv, bn)
                    # Restore the previous modes
                    conv.train(conv_prev_mode)
                    bn.train(bn_prev_mode)
                    # NOTE: We may need to update metadata here. Currently unclear.
                    # Replace conv with the fused module and erase the batchnorm node
                    replace_node_module(node.args[0], modules, fused_conv)
                    node.replace_all_uses_with(node.args[0])
                    graph.fx_graph.erase_node(node)
                    pbar.update(1)  # Account for removed node :)
                pbar.update(1)
        # Update the model to reflect the changes in the graph. Rebuilding the
        # GraphModule recompiles the forward and only copies submodules that the new
        # graph still references, so the fused-away batchnorms are dropped.
        graph.model = fx.GraphModule(graph.model, graph.fx_graph)
        pbar.set_description("Done")
    modules = dict(graph.model.named_modules())
    logger.debug(f"Found {len(modules)} modules in the model:\n{list(modules.keys())}")
    return graph, {}
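
# A minimal numerical sanity check (a sketch, not part of the pass). It assumes ``graph``
# is a MaseGraph traced from ``model``, an eval-mode network containing Conv2d ->
# BatchNorm2d pairs that takes 3-channel inputs; the names and shapes below are
# illustrative only:
#
#   import torch
#
#   model.eval()
#   graph, _ = conv_bn_fusion_transform_pass(graph)
#   x = torch.randn(1, 3, 32, 32)
#   assert torch.allclose(model(x), graph.model(x), atol=1e-5)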