# Source code for chop.passes.graph.transforms.pruning.prune_detach_hook
import torch
def prune_graph_iterator(graph, config: dict):
    """Detach pruning artefacts from every ``call_module`` node of ``graph``.

    For each module referenced by a ``call_module`` node this:

    * removes the parametrization registered on ``weight`` (if any) through
      ``torch.nn.utils.parametrize.remove_parametrizations``, called with its
      default arguments;
    * deletes every forward pre-hook whose ``__name__`` contains
      ``"sparsify_input"``; all other hooks are left in place.

    Nodes with any other ``op`` are skipped, since pruning only deals with
    modules at the moment.

    :param graph: Object exposing ``fx_graph.nodes`` and a ``modules`` mapping
        keyed by node target.
    :param config: Not read by this function; kept so the signature stays
        compatible with existing callers.
    :type config: dict
    :return: The same ``graph`` object, modified in place.
    """
    parametrize = torch.nn.utils.parametrize

    for node in graph.fx_graph.nodes:
        # pruning only deals with modules at the moment
        if node.op != "call_module":
            continue

        module = graph.modules[node.target]

        # remove_parametrizations raises ValueError on a module that is not
        # parametrized, so test for an actual parametrization on "weight"
        # rather than for the mere presence of a weight attribute.
        if parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")

        # Guard on the same collection that is modified below.
        pre_hooks = getattr(module, "_forward_pre_hooks", None)
        if not pre_hooks:
            continue

        # Iterate over a snapshot: deleting from the live dict while iterating
        # it raises "RuntimeError: OrderedDict mutated during iteration".
        for key, hook in list(pre_hooks.items()):
            # Hooks such as functools.partial objects carry no __name__.
            if "sparsify_input" in getattr(hook, "__name__", ""):
                del pre_hooks[key]

    return graph
def hook_inspector(m):
    """List the hooks registered on ``m``.

    The hook dictionaries ``_forward_hooks``, ``_forward_pre_hooks`` and
    ``_backward_hooks`` are read in that order; a dictionary that ``m`` does
    not have is skipped, so an object with none of them yields ``[]``.

    :param m: Object to inspect, normally a ``torch.nn.Module``.
    :return: A list of ``(type_name, hook_key, hook_name)`` tuples, where
        ``type_name`` is ``type(m).__name__`` and ``hook_name`` is the hook's
        ``__name__``, or the hook's type name when it has no ``__name__``
        (for example a ``functools.partial`` object).
    :rtype: list
    """
    module_name = type(m).__name__
    info = []
    for attr in ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks"):
        hooks = getattr(m, attr, None)
        if hooks is None:
            continue
        for key, hook in hooks.items():
            # Not every callable has __name__; fall back to its type name
            # instead of raising AttributeError.
            hook_name = getattr(hook, "__name__", type(hook).__name__)
            info.append((module_name, key, hook_name))
    return info
# [docs]
def prune_detach_hook_transform_pass(graph, pass_args: dict = None):
    """
    Detach the pruning instrumentation from the given graph.

    The work is delegated to ``prune_graph_iterator``, which walks the
    ``call_module`` nodes of the graph to strip the weight parametrization
    and the ``sparsify_input`` forward pre-hooks from their modules.

    :param graph: The input graph whose pruning hooks should be detached.
    :type graph: MaseGraph

    pass_args can be None or an empty dictionary.

    :param pass_args: Optional arguments for the transformation; forwarded
        unchanged to ``prune_graph_iterator``.
    :type pass_args: dict

    :return: The processed graph and an empty dictionary.
    :rtype: tuple
    """
    # A None default avoids sharing one mutable dict across calls; callers
    # that pass nothing (or None) get a fresh empty dict.
    if pass_args is None:
        pass_args = {}

    graph = prune_graph_iterator(graph, pass_args)
    return graph, {}