Source code for chop.passes.graph.analysis.report.report_graph

import logging
from tabulate import tabulate

logger = logging.getLogger(__name__)


def report_graph_analysis_pass(graph, pass_args=None):
    """
    Generate a report for the graph analysis with an overview of the model.

    The report contains three sections:

    * a table with one row per FX node (opcode, name, target, args, kwargs),
    * the textual form of ``graph.fx_graph``,
    * a per-opcode node count and the list of non-``None``
      ``node.meta["mase"].module`` entries.

    :param graph: a MaseGraph
    :type graph: MaseGraph
    :param pass_args: optional dict. Its only recognised key is
        ``"file_name"``. When that key is missing or ``None`` (or when
        ``pass_args`` itself is ``None``), the report is printed to stdout.
        Otherwise it is written, UTF-8 encoded, to that path and replaces
        any existing file.
    :type pass_args: dict, optional
    :return: a tuple of the unchanged MaseGraph and an empty dict (this pass
        produces no additional info)
    :rtype: tuple(MaseGraph, dict)
    """
    # A None default avoids sharing one mutable dict across calls. Callers
    # that pass a dict, or rely on the old default, see identical behaviour.
    if pass_args is None:
        pass_args = {}
    file_name = pass_args.get("file_name")

    buff = """
Graph Analysis Report

===================== Graph Summary =====================

"""
    node_specs = [
        [n.op, n.name, n.target, n.args, n.kwargs] for n in graph.fx_graph.nodes
    ]
    buff += tabulate(
        node_specs, headers=["opcode", "name", "target", "args", "kwargs"]
    )

    # Seed every FX opcode so that opcodes with zero nodes still appear in
    # the overview.
    count = {
        "placeholder": 0,
        "get_attr": 0,
        "call_function": 0,
        "call_method": 0,
        "call_module": 0,
        "output": 0,
    }
    layer_types = []
    # A single pass over the nodes collects both the opcode tally and the
    # layer types.
    for node in graph.fx_graph.nodes:
        # .get keeps the tally from raising KeyError on an unexpected opcode.
        count[node.op] = count.get(node.op, 0) + 1
        # NOTE(review): requires node.meta["mase"] to have been populated by
        # an earlier pass (presumably metadata initialisation — confirm).
        # A node without it raises KeyError here.
        module = node.meta["mase"].module
        if module is not None:
            layer_types.append(module)

    buff += f"""

===================== Graph Syntax =====================

{str(graph.fx_graph)}

===================== Graph Overview =====================

Network overview:
{count}

Layer types:
{layer_types}"""

    if file_name is None:
        print(buff)
    else:
        with open(file_name, "w", encoding="utf-8") as outf:
            outf.write(buff)
    return graph, {}