import torch
from chop.passes.graph.analysis.utils import fetch_attr, load_arg


def graph_iterator(graph, dummy_in, add_meta=False):
    """Profile the natural sparsity (fraction of non-zero elements) of the
    weights and input activations of every Linear and Conv2d module in the
    graph by running it on `dummy_in`."""
    hooks = []
    names, w_infos, a_infos = [], [], []

    # Forward pre-hook: record (element count, fraction of non-zero elements)
    # for the input activation of each hooked module.
    def get_sparsify(module, args):
        if len(args) > 1:
            raise ValueError(
                f"{module.__class__.__name__} takes more than one argument at "
                "inference; the current sparsify_input pre-forward hook only "
                "supports a single input!"
            )
        x = args[0]
        a_infos.append((x.numel(), (x != 0).sum() / x.numel()))

    # Add the hook to every Linear/Conv2d module; pruning only deals with
    # modules at the moment.
    for node in graph.fx_graph.nodes:
        if node.op == "call_module":
            name = node.target
            if isinstance(graph.modules[name], (torch.nn.Linear, torch.nn.Conv2d)):
                names.append(name)
                hooks.append(
                    graph.modules[name].register_forward_pre_hook(get_sparsify)
                )

    # Run the graph on the dummy input so the pre-forward hooks fire,
    # collecting weight statistics for Linear/Conv2d modules along the way.
    env = {}
    model, fx_graph, modules = graph.model, graph.fx_graph, graph.modules
    for node in fx_graph.nodes:
        args, kwargs = None, None
        if node.op == "placeholder":
            result = dummy_in[node.name]
        elif node.op == "get_attr":
            result = fetch_attr(model, node.target)
        elif node.op == "call_function":
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = node.target(*args, **kwargs)
        elif node.op == "call_method":
            self_obj, *args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == "call_module":
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = modules[node.target](*args, **kwargs)
            name = node.target
            if isinstance(modules[name], (torch.nn.Linear, torch.nn.Conv2d)):
                w = modules[name].weight
                w_infos.append((w.numel(), (w != 0).sum() / w.numel()))
        env[node.name] = result

    # Remove the hooks so subsequent forward passes are unaffected.
    for hook in hooks:
        hook.remove()

    w_sparsity_info = {f"{k}_weight": v for k, v in zip(names, w_infos)}
    a_sparsity_info = {f"{k}_activation": v for k, v in zip(names, a_infos)}

    if add_meta:
        # Attach the per-layer natural sparsity to the MASE metadata of each node.
        for node in graph.fx_graph.nodes:
            if node.op == "call_module":
                name = node.target
                meta = node.meta["mase"]
                if isinstance(modules[name], (torch.nn.Conv2d, torch.nn.Linear)):
                    meta.parameters["software"]["args"]["weight"][
                        "natural_sparsity"
                    ] = w_sparsity_info[f"{name}_weight"]
                    meta.parameters["software"]["args"]["data_in_0"][
                        "natural_sparsity"
                    ] = a_sparsity_info[f"{name}_activation"]

    # Element-count-weighted averages across all hooked layers.
    avg_w_sparsity = sum(x[0] * x[1] for x in w_sparsity_info.values()) / sum(
        x[0] for x in w_sparsity_info.values()
    )
    avg_a_sparsity = sum(x[0] * x[1] for x in a_sparsity_info.values()) / sum(
        x[0] for x in a_sparsity_info.values()
    )
    w_sparsity_info["avg_weight"] = avg_w_sparsity
    a_sparsity_info["avg_activation"] = avg_a_sparsity
    return graph, {**w_sparsity_info, **a_sparsity_info}
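

if __name__ == "__main__":
    # Illustrative usage sketch: in practice `graph` would be a chop MaseGraph,
    # but this example builds a small stand-in exposing the three attributes
    # the pass reads (.model, .fx_graph, .modules) from a plain torch.fx trace,
    # so it stays self-contained. `TinyNet` is just a toy model for the demo;
    # `add_meta` stays False because raw fx nodes carry no MASE metadata.
    from types import SimpleNamespace

    import torch.fx

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(8, 16)
            self.fc2 = torch.nn.Linear(16, 4)

        def forward(self, x):
            # ReLU introduces zeros, so fc2 sees a naturally sparse activation.
            return self.fc2(torch.relu(self.fc1(x)))

    traced = torch.fx.symbolic_trace(TinyNet())
    stand_in = SimpleNamespace(
        model=traced,
        fx_graph=traced.graph,
        modules=dict(traced.named_modules()),
    )
    _, sparsity_info = graph_iterator(stand_in, {"x": torch.randn(2, 8)})
    # Per-layer entries are (num_elements, non_zero_fraction) tuples, plus
    # "avg_weight" / "avg_activation" element-weighted averages.
    for key, value in sparsity_info.items():
        print(key, value)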