Source code for chop.passes.graph.transforms.utils.logicnets_fusion
# A pass that removes activation functions already absorbed into the preceding
# LogicNets layer; the activation function is taken into account when the
# LogicNets layer is initialized.
# NOTE: This implementation is a derivative of the following:
# https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/optimization.py
import logging
import tqdm
import torch.nn as nn
import torch.fx as fx
from tqdm.contrib.logging import tqdm_logging_redirect
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from chop.nn.quantized.modules.linear import (
    LinearLogicNets,
)
from chop.nn.quantized.modules.conv2d import (
    Conv2DLogicNets,
)
# Housekeeping -------------------------------------------------------------------------
logger = logging.getLogger(__file__)
# logger.propagate = False # Avoid duplicate logs
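# NOTE: The sketch below is illustrative and not part of the original source.
# It shows the `pass_args` structure this pass expects, inferred from the key
# accesses inside `logicnets_fusion_transform_pass`; the node names
# ("fc1", "bn1", "relu1") are hypothetical placeholders.
#
# pass_args = {
#     "fc1": {  # name of a LogicNets node in the fx graph
#         "config": {
#             "additional_layers_inputs": ["bn1"],     # merged nodes on the input side
#             "additional_layers_outputs": ["relu1"],  # merged nodes on the output side
#         }
#     },
# }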
def logicnets_fusion_transform_pass(graph, pass_args, **_):
    modules = dict(graph.model.named_modules())
    logger.debug(f"Found {len(modules)} modules in the model:\n{list(modules.keys())}")
    # Names of the LogicNets nodes listed in the pass config
    logicnets_nodes = list(pass_args.keys())
    # Collect the nodes that have been merged into a LogicNets layer and should be erased
    nodes_to_erase = []
    for node, config in pass_args.items():
        layer_inputs = config["config"]["additional_layers_inputs"]
        layer_outputs = config["config"]["additional_layers_outputs"]
        nodes_to_erase = nodes_to_erase + layer_inputs + layer_outputs
    # Modify the graph in place.
    total = len(graph.fx_graph.nodes)
    with tqdm_logging_redirect(total=total, loggers=[logger]) as pbar:
        pbar.set_description(
            f"Fusing these nodes into LogicNets layers {nodes_to_erase}"
        )
        # Iterate over the graph and erase the nodes that have been merged into a LogicNets layer
        for node in graph.fx_graph.nodes:
            if node.name in logicnets_nodes:
                # Mark the LogicNets node as fused so the merged modules are
                # applied internally in its forward pass
                assert isinstance(
                    modules[node.target], LinearLogicNets
                ), f"{node} is not a LinearLogicNets module. Double check your model and the config file."
                modules[node.target].set_fused(True)
                # Recalculate the truth tables after fusion
                modules[node.target].calculate_truth_tables()
            elif node.name in nodes_to_erase:
                # Erase this node, which has been merged into a LogicNets layer.
                # If the LogicNets output feeding this node is also used by
                # other nodes, fusion isn't trivial; for now, we simply skip
                # these cases.
                if len(node.args[0].users) > 1:
                    logger.warning("Logicnets output used by other nodes. Skipped!")
                    continue
                # Reroute all uses of this node to its input so that the node
                # no longer exists in the graph once erased
                node.replace_all_uses_with(node.args[0])
                graph.fx_graph.erase_node(node)
                pbar.update(1)  # Account for removed node :)
            pbar.update(1)
        # Update the model to reflect the changes in the graph
        graph.model = fx.GraphModule(graph.model, graph.fx_graph)
        pbar.set_description("Done")
    modules = dict(graph.model.named_modules())
    logger.debug(f"Found {len(modules)} modules in the model:\n{list(modules.keys())}")
    return graph
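# Example usage (a minimal sketch, not part of the original source), assuming
# `graph` is a MaseGraph-like object exposing `.model` and `.fx_graph`, and
# that `pass_args` follows the structure sketched above:
#
#     graph = logicnets_fusion_transform_pass(graph, pass_args)
#
# After the pass, the listed input/output nodes are removed from the fx graph
# and the LogicNets modules apply the fused layers internally in their
# forward pass.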