# Source code for chop.passes.graph.analysis.pruning.calculate_sparsity
import torch
from chop.passes.graph.analysis.utils import fetch_attr, load_arg
def _mask_sparsity(mask):
    """Return ``1 - mask.sum() / mask.numel()`` as a Python float.

    For a 0/1 (or bool) mask this is the fraction of zeroed entries. An empty
    mask yields NaN rather than raising.
    """
    return 1 - float(mask.sum() / mask.numel())


def _record_layer_sparsity(meta, module, add_value):
    """Write weight/activation sparsity of a pruned ``module`` into ``meta``.

    Returns ``{"weight_sparsity": float, "activation_sparsity": float}``.
    """
    sw_args = meta.parameters["software"]["args"]

    # parametrizations.weight is a list; we assume it holds a single entry
    # (the pruning parametrization that owns the ``mask``).
    weight_mask = module.parametrizations.weight[0].mask
    weight_sparsity = _mask_sparsity(weight_mask)
    sw_args["weight"]["sparsity"] = weight_sparsity

    act_mask = module.activation_mask
    act_sparsity = _mask_sparsity(act_mask)
    sw_args["data_in_0"]["sparsity"] = act_sparsity

    if add_value:
        sw_args["weight"]["mask_value"] = weight_mask
        # NOTE(review): the *activation* mask is stored under the "weight_mask"
        # key, unlike the weight mask above ("weight"/"mask_value"). This looks
        # like a possible copy-paste slip; it is kept as-is for downstream
        # compatibility -- confirm the intended key before relying on it.
        sw_args["weight_mask"]["value"] = act_mask

    return {
        "weight_sparsity": weight_sparsity,
        "activation_sparsity": act_sparsity,
    }


def graph_iterator_for_metadata(graph, dummy_in=None, add_value=True):
    """
    Execute ``graph`` node by node and record pruning sparsity in its metadata.

    Largely adapted from the FX interpreter example in
    https://pytorch.org/docs/stable/fx.html. Each node is evaluated on the
    values produced by its input nodes. After every ``call_module`` node whose
    module is a ``torch.nn.Conv2d`` or ``torch.nn.Linear``, the sparsity of the
    module's weight mask and activation mask is written into
    ``node.meta["mase"]``.

    Args:
        graph: the graph wrapper (a MaseGraph) exposing ``model``,
            ``fx_graph`` and ``modules``.
        dummy_in: mapping from placeholder node name to the value fed into
            that placeholder. Required whenever the graph has placeholders.
        add_value: when True, the raw mask tensors are also stored in the
            node metadata.

    Returns:
        ``(graph, sparsity_info)``, where ``sparsity_info`` maps each pruned
        module's target name to
        ``{"weight_sparsity": float, "activation_sparsity": float}``.

    Raises:
        TypeError: if the graph contains a placeholder but ``dummy_in`` is None.
    """
    model, modules = graph.model, graph.modules
    sparsity_info = {}
    env = {}

    for node in graph.fx_graph.nodes:
        if node.op == "placeholder":
            if dummy_in is None:
                raise TypeError(
                    f"dummy_in is required to feed placeholder '{node.name}'"
                )
            result = dummy_in[node.name]
        elif node.op == "get_attr":
            result = fetch_attr(model, node.target)
        elif node.op == "call_function":
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = node.target(*args, **kwargs)
        elif node.op == "call_method":
            self_obj, *args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == "call_module":
            submodule = modules[node.target]
            args = load_arg(node.args, env)
            kwargs = load_arg(node.kwargs, env)
            result = submodule(*args, **kwargs)
            if isinstance(submodule, (torch.nn.Conv2d, torch.nn.Linear)):
                sparsity_info[node.target] = _record_layer_sparsity(
                    node.meta["mase"], submodule, add_value
                )
        else:
            # The "output" node (or any other non-executable op): nothing to
            # evaluate. Skip it rather than storing the previous node's stale
            # ``result`` under this node's name.
            continue
        env[node.name] = result
    return graph, sparsity_info
def add_movement_metadata_analysis_pass(graph, pass_args=None):
    """
    Attach a zero-initialised movement statistic to every Conv2d / Linear layer.

    For each ``torch.nn.Conv2d`` or ``torch.nn.Linear`` submodule of
    ``graph.model`` that has a ``weight`` attribute, ``module.metadata`` is
    created if missing. Its ``"weight"`` entry is then (re)set to
    ``{"stats": {"movement": zeros_like(module.weight)}}``. Other keys already
    present in ``module.metadata`` are left untouched.

    Args:
        graph: the MaseGraph whose ``model`` is scanned.
        pass_args: unused; reserved for future expansion.

    Returns:
        A tuple of the (in-place updated) graph and an empty dictionary.
    """
    layer_types = (torch.nn.Conv2d, torch.nn.Linear)

    for layer in graph.model.modules():
        # Guard clauses: only weighted Conv2d / Linear layers are tracked.
        if not isinstance(layer, layer_types):
            continue
        if not hasattr(layer, "weight"):
            continue

        if not hasattr(layer, "metadata"):
            layer.metadata = {}
        movement = torch.zeros_like(layer.weight)
        layer.metadata["weight"] = {"stats": {"movement": movement}}

    return graph, {}