Source code for chop.passes.graph.transforms.quantize.summary

import logging
import os

import numpy as np
import pandas as pd
from tabulate import tabulate

from ...utils import get_mase_op, get_mase_type, get_node_actual_target


logger = logging.getLogger(__name__)


def graph_iterator_compare_nodes(
    ori_graph, graph, save_path=None, silent=False
) -> pd.DataFrame:
    """List all nodes in the graph and compare the original and quantized nodes."""

    def get_type_str(node):
        if node.op == "call_module":
            return type(get_node_actual_target(node)).__name__
        elif get_mase_type(node) in [
            "builtin_func",
            "module_related_func",
            "patched_func",
        ]:
            return get_node_actual_target(node).__name__
        elif get_mase_type(node) in ["implicit_func"]:
            actual_target = get_node_actual_target(node)
            if isinstance(actual_target, str):
                return actual_target
            else:
                return actual_target.__name__
        else:
            return node.target

    headers = [
        "Ori name",
        "New name",
        "MASE_TYPE",
        "Mase_OP",
        "Original type",
        "Quantized type",
        "Changed",
    ]
    rows = []
    for ori_n, n in zip(ori_graph.fx_graph.nodes, graph.fx_graph.nodes):
        rows.append(
            [
                ori_n.name,
                n.name,
                get_mase_type(n),
                get_mase_op(n),
                get_type_str(ori_n),
                get_type_str(n),
                type(get_node_actual_target(n)) != type(get_node_actual_target(ori_n)),
            ]
        )
    if not silent:
        logger.debug("Compare nodes:")
        logger.debug("\n" + tabulate(rows, headers=headers, tablefmt="orgtbl"))
    if save_path is not None:
        with open(save_path, "w") as f:
            f.write(tabulate(rows, headers=headers))

    df = pd.DataFrame(rows, columns=headers)
    if save_path is not None:
        df.to_csv(save_path)

    return df


def graph_iterator_node_histogram(ori_graph, graph, save_path: str = None):
    """Group nodes by their types and count the number of nodes in each group."""
    df = graph_iterator_compare_nodes(ori_graph, graph, save_path=None, silent=True)
    histogram_df = df.groupby(["Original type"]).agg(
        OP=pd.NamedAgg(column="Mase_OP", aggfunc="first"),
        Total=pd.NamedAgg(column="Changed", aggfunc="count"),
        Changed=pd.NamedAgg(column="Changed", aggfunc=lambda x: np.sum(x)),
        Unchanged=pd.NamedAgg(
            column="Changed", aggfunc=lambda x: np.sum(1 - np.array(x))
        ),
    )
    logger.info("Quantized graph histogram:")
    logger.info("\n" + tabulate(histogram_df, headers="keys", tablefmt="orgtbl"))
    if save_path is not None:
        histogram_df.to_csv(save_path)


# def graph_iterator_compare_nodes(*args, **kwargs):
#     # TODO: remove this function when the add_common_metadata is fixed
#     pass


# def graph_iterator_node_histogram(*args, **kwargs):
#     # TODO: remove this function when the add_common_metadata is fixed
#     pass


[docs] def summarize_quantization_analysis_pass( graph, pass_args={"save_dir": None, "original_graph": None} ) -> None: """ Summarizes the quantization analysis pass. Apply quantization transformation to the given graph. :param graph: The input graph to be transformed. :type graph: MaseGraph :param pass_args: Additional arguments for the transformation. :type pass_args: dict, optional .. code-block: python pass_args = { "save_dir": "quantize_summary", "original_graph": ori_mg, # original graph, type should be MaseGraph } :return: The transformed MaseGraph. :rtype: tuple """ save_dir, ori_graph = pass_args["save_dir"], pass_args["original_graph"] if save_dir is not None: os.makedirs(save_dir, exist_ok=True) table_path = os.path.join(save_dir, "quantize_table.csv") if save_dir else None histogram_path = ( os.path.join(save_dir, "quantize_histogram.csv") if save_dir else None ) graph_iterator_compare_nodes(ori_graph, graph, save_path=table_path, silent=False) graph_iterator_node_histogram(ori_graph, graph, save_path=histogram_path) return graph, {}