import logging
import math
from typing import Any
import itertools
import numpy as np
import toml
import torch
from torch.fx import Interpreter
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from tqdm import tqdm
from ...utils import get_mase_op, get_mase_type, get_module_by_target
from .stat import _StatBase, create_new_stat
from .utils import get_meta_arg_stat, set_meta_arg_stat
logger = logging.getLogger(__name__)
class ActStatCollection:
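    """Collect running statistics over activation tensors.

    ``stats`` maps a stat name understood by ``create_new_stat`` to its config dict.
    ``update`` consumes a whole batch (falling back to per-sample updates) and
    ``compute`` exports the accumulated results.
    """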
def __init__(self, stats: dict[str, dict]) -> None:
self.stats: list[_StatBase] = []
for stat_name, stat_config in stats.items():
self.stats.append(create_new_stat(stat_name, **stat_config))
def update(self, batch: torch.Tensor):
assert isinstance(batch, torch.Tensor)
for stat in self.stats:
if hasattr(stat, "update_a_batch"):
stat.update_a_batch(batch)
else:
for i in range(batch.size(0)):
stat.update_a_sample(batch[[i], ...])
def compute(self) -> dict[str, dict[str, list]]:
results = {}
for stat in self.stats:
results.update(stat.export())
return results
def __repr__(self) -> str:
return "ActStatCollection(stats={})".format(
", ".join([type(stat).__name__ for stat in self.stats])
)
class WeightStatCollection:
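    """Collect statistics over a weight/buffer tensor; each ``update`` call
    feeds the whole tensor as a single sample."""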
def __init__(self, stats: dict[str, dict]) -> None:
self.stats: list[_StatBase] = []
for stat_name, stat_config in stats.items():
self.stats.append(create_new_stat(stat_name, **stat_config))
def update(self, weight: torch.Tensor):
assert isinstance(weight, torch.Tensor)
for stat in self.stats:
stat: _StatBase
stat.update_a_sample(weight)
def compute(self) -> dict[str, dict[str, list]]:
results = {}
for stat in self.stats:
results.update(stat.export())
return results
def __repr__(self) -> str:
return "WeightStatCollection(stats={})".format(
", ".join([type(stat).__name__ for stat in self.stats])
)
def graph_iterator_register_stat_collections_by_name(
graph,
target_weight_nodes: list[str],
target_act_nodes: list[str],
weight_stats: dict[str, dict],
act_stats: dict[str, dict],
profile_output_act: bool = False,
):
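    """Attach stat collections to the software arg metadata of nodes selected by
    node name: ``WeightStatCollection`` for the param/buffer entries of call_module
    nodes in ``target_weight_nodes``, ``ActStatCollection`` for the arg entries of
    nodes in ``target_act_nodes``."""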
# weight stats
for node in graph.fx_graph.nodes:
if node.name not in target_weight_nodes:
continue
if node.op != "call_module":
logger.warning(
f"Node {node.name} is not a call_module node, but is in target_weight_nodes. Skip."
)
continue
# only create registered param/buffer stats for nn.Module
for entry, s_meta in node.meta["mase"].parameters["software"]["args"].items():
stat = s_meta["stat"]
if "data_in" in entry:
continue
if isinstance(stat, (WeightStatCollection,)):
continue
set_meta_arg_stat(node, entry, WeightStatCollection(weight_stats))
# act stats
for node in graph.fx_graph.nodes:
if node.name not in target_act_nodes:
continue
for entry, s_meta in node.meta["mase"].parameters["software"]["args"].items():
stat = s_meta["stat"]
if isinstance(stat, (WeightStatCollection, ActStatCollection)):
continue
            # activation entries: data_in_0, data_in_1, ..., plus weight/bias args of nn.functional calls
set_meta_arg_stat(node, entry, ActStatCollection(act_stats))
return graph
def graph_iterator_register_stat_collections_by_type(
graph,
target_weight_nodes: list[str],
target_act_nodes: list[str],
weight_stats: dict[str, dict],
act_stats: dict[str, dict],
profile_output_act: bool = False,
):
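    """Attach stat collections to the software arg metadata of nodes selected by
    their mase op type (see ``get_mase_op``): ``WeightStatCollection`` for the
    param/buffer entries of call_module nodes, ``ActStatCollection`` for the arg
    entries of activation nodes."""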
# weight stats
for node in graph.fx_graph.nodes:
if get_mase_op(node) not in target_weight_nodes:
continue
if node.op != "call_module":
logger.warning(
f"Node {node.name} is not a call_module node, but is in target_weight_nodes. Skip."
)
continue
# only create registered param/buffer stats for nn.Module
for entry, s_meta in node.meta["mase"].parameters["software"]["args"].items():
stat = s_meta["stat"]
if "data_in" in entry:
continue
if isinstance(stat, (WeightStatCollection,)):
continue
set_meta_arg_stat(node, entry, WeightStatCollection(weight_stats))
# act stats
for node in graph.fx_graph.nodes:
if get_mase_op(node) not in target_act_nodes:
continue
for entry, s_meta in node.meta["mase"].parameters["software"]["args"].items():
stat = s_meta["stat"]
if isinstance(stat, (WeightStatCollection, ActStatCollection)):
continue
set_meta_arg_stat(node, entry, ActStatCollection(act_stats))
return graph
def graph_iterator_register_stat_collections(
graph,
by,
target_weight_nodes,
target_act_nodes,
weight_stats,
act_stats,
profile_output_act=False,
):
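    """Register stat collections on the graph, selecting target nodes either by
    node name or by mase op type."""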
match by:
case "name":
graph = graph_iterator_register_stat_collections_by_name(
graph,
target_weight_nodes,
target_act_nodes,
weight_stats,
act_stats,
                profile_output_act=profile_output_act,
            )
case "type":
graph = graph_iterator_register_stat_collections_by_type(
graph,
target_weight_nodes,
target_act_nodes,
weight_stats,
act_stats,
                profile_output_act=profile_output_act,
            )
case _:
raise ValueError(f"Unknown by: {by}")
return graph
class ActProfiler(Interpreter):
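    """FX ``Interpreter`` that, for every executed node, feeds each tensor argument
    into the ``ActStatCollection`` registered for the corresponding software arg
    entry, then runs the node on its owning module's device."""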
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
super().__init__(module, garbage_collect_values)
def run_node(self, n: Node) -> Any:
with self._set_current_node(n):
args, kwargs = self.fetch_args_kwargs_from_env(n)
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
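            # gather the ActStatCollection objects registered on this node's software args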
act_stats = []
for arg_entry in (
n.meta["mase"].parameters["software"].get("args", {}).keys()
):
stat = get_meta_arg_stat(n, arg_entry)
if isinstance(stat, ActStatCollection):
act_stats.append(stat)
if len(act_stats) > 0:
                numeric_args = tuple(
                    arg
                    for arg in args + tuple(kwargs.values())
                    if isinstance(arg, torch.Tensor)
                )
assert len(numeric_args) == len(act_stats), (
f"Number of tensor args ({len(numeric_args)}) "
f"does not match number of act entries ({len(act_stats)})"
)
for tensor_arg, stat in zip(numeric_args, act_stats):
stat.update(tensor_arg)
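            # infer the owning module's device from its parameters/buffers and move
            # tensor args onto it before execution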
device = None
if isinstance(n.meta["mase"].module, torch.nn.Module):
try:
device = next(
itertools.chain(
n.meta["mase"].module.parameters(),
n.meta["mase"].module.buffers(),
)
).device
except StopIteration:
pass
if device is not None:
args = tuple(
arg.to(device=device) if isinstance(arg, torch.Tensor) else arg
for arg in args
)
kwargs = {
k: v.to(device=device) if isinstance(v, torch.Tensor) else v
for k, v in kwargs.items()
}
output = getattr(self, n.op)(n.target, args, kwargs)
return output
def graph_iterator_profile_act(graph, input_generator, num_samples):
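    """Run roughly ``num_samples`` inputs (rounded up to whole batches) drawn from
    ``input_generator`` through the model with ``ActProfiler`` so activation stats
    are accumulated."""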
act_profiler = ActProfiler(graph.model, garbage_collect_values=True)
max_batches = math.ceil(num_samples / input_generator.batch_size)
for i in tqdm(range(max_batches), desc="Profiling act statistics"):
batch = next(input_generator)
act_profiler.run(*batch.values())
return graph
def graph_iterator_profile_weight(graph):
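    """Feed the named parameters and buffers of each ``call_module`` node into the
    ``WeightStatCollection`` registered for the matching arg entry."""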
for node in tqdm(
graph.fx_graph.nodes,
total=len(list(graph.fx_graph.nodes)),
desc="Profiling weight statistics",
):
if node.op != "call_module":
continue
param_dict = dict(node.meta["mase"].module.named_parameters())
buffer_dict = dict(node.meta["mase"].module.named_buffers())
p_b_dict = {**param_dict, **buffer_dict}
for w_name, s_meta in node.meta["mase"].parameters["software"]["args"].items():
stat = s_meta["stat"]
if not isinstance(stat, WeightStatCollection):
continue
w = p_b_dict[w_name]
stat.update(w.data)
return graph
def graph_iterator_compute_and_unregister_stats(graph):
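    """Replace each registered stat collection in the node arg metadata with its
    computed result dict, unregistering the collection objects."""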
for node in graph.fx_graph.nodes:
for entry, s_meta in (
node.meta["mase"].parameters["software"].get("args", {}).items()
):
stat = s_meta["stat"]
if isinstance(stat, (WeightStatCollection, ActStatCollection)):
result = stat.compute()
set_meta_arg_stat(node, entry, result)
return graph
def profile_statistics_analysis_pass(graph, pass_args: dict):
"""
    Profile statistics of weights and activations in the given graph and store the computed results in the MASE node metadata.
:param graph: The graph to perform analysis on.
:type graph: MaseGraph
:param pass_args: The arguments for the analysis pass.
:type pass_args: dict
.. code-block:: python
pass_args = {
"by": "type", # pick from ["name", "type"]
"target_weight_nodes": "linear", # ["conv2d", "linear" ...],
"target_activation_nodes": "relu", # ["relu", "sigmoid" ...],
"weight_statistics": {
"variance_precise": {
"device": "cpu",
"dims": "all"
},
},
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
"input_generator": input_generator,
"num_samples": 1,
"profile_output_activation": False,
:return: The modified graph and an empty dictionary.
:rtype: tuple(MaseGraph, dict)
"""
graph = graph_iterator_register_stat_collections(
graph,
by=pass_args["by"],
target_weight_nodes=pass_args["target_weight_nodes"],
target_act_nodes=pass_args["target_activation_nodes"],
weight_stats=pass_args["weight_statistics"],
act_stats=pass_args["activation_statistics"],
profile_output_act=pass_args.get("profile_output_activation", False),
)
graph = graph_iterator_profile_weight(graph)
graph = graph_iterator_profile_act(
graph,
input_generator=pass_args["input_generator"],
num_samples=pass_args["num_samples"],
)
graph = graph_iterator_compute_and_unregister_stats(graph)
return graph, {}