"""Hook inspection analysis pass (chop.passes.graph.analysis.pruning.hook_inspector)."""

import torch


from chop.passes.graph.analysis.utils import fetch_attr, load_arg


def graph_iterator_for_metadata(
    graph, module_types=(torch.nn.Conv2d, torch.nn.Linear)
):
    """
    Collect the hooks registered on selected modules of a graph.

    Walks the FX graph (approach largely adapted from
    https://pytorch.org/docs/stable/fx.html). For every ``call_module`` node
    whose target module is an instance of ``module_types``, it records the
    entries of the module's ``_forward_hooks``, ``_forward_pre_hooks`` and
    ``_backward_hooks`` dictionaries, in that order. Hooks are only read and
    are never removed or modified.

    :param graph: The graph to inspect. Its ``fx_graph.nodes`` and
        ``modules`` (a mapping from node target to module) are used.
    :param module_types: A module class, or tuple of classes, whose hooks
        are collected. Defaults to ``(torch.nn.Conv2d, torch.nn.Linear)``.
    :return: ``(graph, hook_info)``. ``graph`` is the same object that was
        passed in. ``hook_info`` maps ``"<module_name>_<hook_id>"`` to
        ``(hook_id, hook_fn)``.
    """
    modules = graph.modules
    hook_info = {}

    for node in graph.fx_graph.nodes:
        if node.op != "call_module":
            continue
        name = node.target
        module = modules[name]
        if not isinstance(module, module_types):
            continue
        # All three hook kinds share one key namespace. If two ids collide,
        # the later entry overwrites the earlier one. This assumes hook ids
        # are unique across the dicts, which holds for hooks registered via
        # the register_*_hook APIs.
        for hooks in (
            module._forward_hooks,
            module._forward_pre_hooks,
            module._backward_hooks,
        ):
            for hook_id, hook_fn in hooks.items():
                hook_info[f"{name}_{hook_id}"] = (hook_id, hook_fn)

    return graph, hook_info


def hook_inspection_analysis_pass(graph, pass_args: "dict | None" = None):
    """
    Provide information about the hooks registered on the graph's modules.

    Only ``torch.nn.Conv2d`` and ``torch.nn.Linear`` modules reached through
    ``call_module`` nodes are inspected (see ``graph_iterator_for_metadata``).
    Hooks are reported, not removed, and the same graph object is returned.

    :param graph: The MaseGraph whose module hooks will be inspected.
    :type graph: MaseGraph
    :param pass_args: Additional arguments for the pass. This pass does not
        read any values, so ``None`` (the default) or an empty dictionary
        is fine.
    :type pass_args: dict or None
    :return: The graph and the hook information. The returned dict maps
        ``'<module_name>_<hook_id>'`` to ``(hook_id, hook_fn)``.
    :rtype: tuple(MaseGraph, dict)

    Examples:

    A sample output dict:

    .. code-block:: python

        {
            'feature_layers.0_0': (
                0,
                <function get_activation_hook.<locals>.sparsify_input at 0x7f9544528c10>),
            'feature_layers.3_1': (
                1,
                <function get_activation_hook.<locals>.sparsify_input at 0x7f9544528ca0>),
            'feature_layers.7_2': (
                2,
                <function get_activation_hook.<locals>.sparsify_input at 0x7f9544528d30>),
        }
    """
    # pass_args is unused. It is presumably kept so that all passes share one
    # signature. It defaults to None instead of a shared mutable {}.
    graph, hook_info = graph_iterator_for_metadata(graph)
    return graph, hook_info