# Source code for chop.passes.graph.analysis.pruning.calculate_sparsity
import torch
from chop.passes.graph.analysis.utils import fetch_attr, load_arg
def graph_iterator_for_metadata(graph, dummy_in=None, add_value=True):
    """
    Interpret the FX graph node-by-node (largely adapted from
    https://pytorch.org/docs/stable/fx.html) and record pruning sparsity
    statistics in each pruned Conv2d/Linear module's mase metadata.

    :param graph: MaseGraph-like object exposing ``model``, ``fx_graph`` and
        ``modules``.
    :param dummy_in: mapping from placeholder node names to example inputs;
        required to drive the interpretation (every placeholder is looked up
        here).
    :param add_value: when True, also store the raw mask tensors in the
        metadata alongside the scalar sparsity figures.
    :return: ``(graph, sparsity_info)`` where ``sparsity_info`` maps a module
        target name to its ``weight_sparsity``/``activation_sparsity``.
    """
    model, fx_graph, modules = graph.model, graph.fx_graph, graph.modules
    sparsity_info = {}
    env = {}  # node name -> computed value, consumed by load_arg
    for node in graph.fx_graph.nodes:
        args, kwargs = None, None
        if node.op == "placeholder":
            result = dummy_in[node.name]
        elif node.op == "get_attr":
            result = fetch_attr(model, node.target)
        elif node.op == "call_function":
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = node.target(*args, **kwargs)
        elif node.op == "call_method":
            self_obj, *args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == "call_module":
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = modules[node.target](*args, **kwargs)
            meta = node.meta["mase"]
            if isinstance(modules[node.target], (torch.nn.Conv2d, torch.nn.Linear)):
                # parametrizations is a list; we assume a single entry (the
                # pruning parametrization) whose ``mask`` marks kept weights.
                mask = modules[node.target].parametrizations.weight[0].mask
                # sparsity = fraction of *zeroed* entries
                weight_sparsity = 1 - float(mask.sum() / mask.numel())
                meta.parameters["software"]["args"]["weight"][
                    "sparsity"
                ] = weight_sparsity
                # assumes the pruning pass attached ``activation_mask`` to
                # the module — TODO confirm against the wrapping pass
                act_mask = modules[node.target].activation_mask
                act_sparsity = 1 - float(act_mask.sum() / act_mask.numel())
                meta.parameters["software"]["args"]["data_in_0"][
                    "sparsity"
                ] = act_sparsity
                if add_value:
                    meta.parameters["software"]["args"]["weight"][
                        "mask_value"
                    ] = mask
                    # Fix: store the activation mask under the input entry it
                    # belongs to, symmetrically with the weight mask above
                    # (it was previously written to a stray "weight_mask" key,
                    # inconsistent with the "data_in_0" sparsity write).
                    meta.parameters["software"]["args"]["data_in_0"][
                        "mask_value"
                    ] = act_mask
                sparsity_info[node.target] = {
                    "weight_sparsity": weight_sparsity,
                    "activation_sparsity": act_sparsity,
                }
        env[node.name] = result
    return graph, sparsity_info