"""Source code for chop.ir.graph.mase_graph."""

import logging
import math
import os
from pathlib import Path
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, Optional, Tuple
import dill

import toml
import torch
import torch.fx as fx
from torch.fx.passes.graph_drawer import FxGraphDrawer
import matplotlib.pyplot as plt

from transformers import PreTrainedModel
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
from transformers.utils.fx import HFTracer

from chop.ir.common import MASE_IMPLICIT_FUNCS
from chop.nn import MASE_LEAF_LAYERS
from chop.nn.quantized import (
    quantized_func_map,
    quantized_module_map,
)
from chop.tools import get_logger

from .mase_metadata import MaseMetadata

logger = get_logger(__name__)
logger.setLevel("INFO")

# ----------------------------------------
#   Mase Tracer
# ----------------------------------------


class MaseTracer(fx.Tracer):
    """Extended ``fx.Tracer`` that also treats MASE leaf layers and any
    user-registered custom layers as leaf modules during tracing."""

    def __init__(
        self,
        custom_leaf_modules: tuple[ModuleType] = (),
        custom_leaf_layers: tuple[torch.nn.Module] = (),
        custom_leaf_functions: tuple[Callable] = (),
        param_shapes_constant: bool = False,
    ) -> None:
        """Mase Tracer is an extended version of FX Tracer.

        :param custom_leaf_modules: Python modules whose functions should be wrapped
            automatically without needing to use fx.wrap(). Backward-compatibility for
            this parameter is guaranteed, defaults to ()
        :type custom_leaf_modules: tuple[ModuleType], optional
        :param custom_leaf_layers: layer classes that should be kept as single
            call_module nodes instead of being traced through, defaults to ()
        :type custom_leaf_layers: tuple[torch.nn.Module], optional
        :param custom_leaf_functions: Python functions that should be wrapped
            automatically without needing to use fx.wrap(), defaults to ()
        :type custom_leaf_functions: tuple[Callable], optional
        :param param_shapes_constant: When this flag is set, calls to shape, size and
            a few other shape-like attributes of a module's parameter will be evaluated
            directly rather than returning a new Proxy value. Backward compatibility
            for this parameter is guaranteed, defaults to False
        :type param_shapes_constant: bool, optional
        """
        # Deduplicate the user-supplied collections but keep them as tuples,
        # since fx.Tracer expects tuple arguments.
        self.custom_leaf_layers = tuple(set(custom_leaf_layers))
        self.custom_leaf_modules = tuple(set(custom_leaf_modules))
        self.custom_leaf_functions = tuple(set(custom_leaf_functions))
        self.param_shapes_constant = param_shapes_constant
        # math is always auto-wrapped, mirroring fx.Tracer's own default.
        super().__init__(
            autowrap_modules=self.custom_leaf_modules + (math,),
            autowrap_functions=self.custom_leaf_functions,
            param_shapes_constant=self.param_shapes_constant,
        )

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True when ``m`` should become a single call_module node."""
        if super().is_leaf_module(m, module_qualified_name):
            return True
        if isinstance(m, MASE_LEAF_LAYERS):
            return True
        return isinstance(m, self.custom_leaf_layers)
def trace_torch_module(
    model: torch.nn.Module,
    cf_args: Optional[Dict[str, Any]] = None,
    custom_ops: dict = None,
    hf_input_names: list = None,
):
    """
    Trace a torch.nn.Module using the MaseTracer. This function is a wrapper around
    the HFTracer and MaseTracer, and is used to trace a torch.nn.Module into a
    fx.GraphModule. The fx.GraphModule is a dataflow representation of the model with
    both software and hardware constraints. The MaseTracer is used to trace the model,
    and the custom_ops are used to provide custom operations to the tracer.

    Args:
        model (torch.nn.Module): Input model to trace.
        cf_args (Optional[Dict[str, Any]], optional): Concrete forward arguments to trace the model with. Defaults to None.
        custom_ops (dict, optional): Custom operations to be used in the model. Defaults to None.
        hf_input_names (list, optional): Input names forwarded to the HuggingFace symbolic tracer. Defaults to None.

    Returns:
        fx.GraphModule: Traced model as a fx.GraphModule.
    """
    # * HuggingFace model
    if isinstance(model, PreTrainedModel):
        if custom_ops is not None:
            custom_modules = tuple(custom_ops.get("modules", {}).keys())
        else:
            custom_ops = {"modules": {}, "functions": {}}
            custom_modules = ()

        def is_leaf_module(
            self, m: torch.nn.Module, module_qualified_name: str
        ) -> bool:
            # HF built-in leaves, user-supplied custom modules and MASE leaf
            # layers are all kept as single call_module nodes.
            return any(
                (
                    HFTracer.is_leaf_module(self, m, module_qualified_name),
                    isinstance(m, custom_modules),
                    isinstance(m, MASE_LEAF_LAYERS),
                )
            )

        # Build a throwaway subclass rather than setattr-ing HFTracer itself:
        # patching the shared class in place stacked a new wrapper (with stale
        # closed-over custom_modules) on every call to trace_torch_module.
        tracer_cls = type("MaseHFTracer", (HFTracer,), {"is_leaf_module": is_leaf_module})

        graph_module = hf_symbolic_trace(
            model,
            tracer_cls=tracer_cls,
            input_names=hf_input_names,
        )
        graph_module.custom_ops = custom_ops

        # ! TO DO: remove this legacy stuff
        graph_module.patched_op_names = []
        graph_module.patched_custom_layers = []
        graph_module.additional_inputs = {}

    # * Other models
    else:
        # MASE internal auto-wrapped functions/layers
        custom_leaf_modules = ()
        custom_leaf_functions = ()
        custom_leaf_layers = ()
        # quantized functions/layers
        custom_leaf_functions += tuple(quantized_func_map.values())
        custom_leaf_layers += tuple(quantized_module_map.values())
        # patched functions/layers
        patched_nodes = getattr(model, "patched_nodes", None)
        if patched_nodes is not None:
            custom_leaf_modules += tuple(patched_nodes["modules"])
            custom_leaf_functions += tuple(patched_nodes["functions"])
            custom_leaf_layers += tuple(patched_nodes["layers"])

        tracer = MaseTracer(
            custom_leaf_modules=custom_leaf_modules,
            custom_leaf_functions=custom_leaf_functions,
            custom_leaf_layers=custom_leaf_layers,
        )
        graph_module = fx.GraphModule(model, tracer.trace(model, cf_args))

        if patched_nodes is not None:
            graph_module.patched_op_names = [
                obj.__name__.lower()
                for obj in model.patched_nodes["layers"]
                + model.patched_nodes["functions"]
            ]
            # these are layers we believe the user will provide system verilog for
            graph_module.patched_custom_layers = model.patched_nodes["layers"]
            graph_module.additional_inputs = model.patched_nodes["additional_inputs"]
            graph_module.patched_nodes = model.patched_nodes
        else:
            graph_module.patched_op_names = []
            graph_module.patched_custom_layers = []
            graph_module.additional_inputs = {}

    return graph_module
# ---------------------------------------- # Mase Graph IR # ----------------------------------------
class MaseGraph:
    """Dataflow representation of a model with both software and hardware
    constraints, built on top of ``torch.fx``."""

    # Functions treated as implicit graph nodes by MASE passes.
    implicit_nodes = MASE_IMPLICIT_FUNCS

    def __init__(
        self,
        model: torch.nn.Module | fx.GraphModule,
        cf_args: Optional[Dict[str, Any]] = None,
        custom_ops: dict = None,
        hf_input_names: list = None,
        skip_init_metadata: bool = False,
        add_metadata_args: dict = None,
    ) -> None:
        """MaseGraph is a dataflow representation of a model with both software and
        hardware constraints.

        The MaseGraph can be constructed from a torch.nn.Module:

        .. code-block:: python

            from chop.ir.graph import MaseGraph
            from transformers import BertModel

            model = BertModel.from_pretrained("bert-base-uncased")
            mase_graph = MaseGraph(model)
            # Or, equivalently:
            mase_graph = MaseGraph.from_module(model)

        A MaseGraph can also be constructed from a pre-traced fx.GraphModule:

        .. code-block:: python

            import torch
            import torch.fx as fx

            traced_model = fx.symbolic_trace(torch.nn.Linear(10, 10))
            mase_graph = MaseGraph(traced_model)

        A MaseGraph can be exported, re-loaded, and visualized:

        .. code-block:: python

            mase_graph.export("masegraph")
            mase_graph = MaseGraph.from_checkpoint("masegraph")
            mase_graph.draw("mase_graph.svg")

        Args:
            model (torch.nn.Module | fx.GraphModule): Input model to construct the MaseGraph.
            cf_args (Optional[Dict[str, Any]], optional): Concrete forward arguments to trace the model with. Defaults to None.
            custom_ops (dict, optional): Custom operations to be used in the model. Defaults to None.
            hf_input_names (list, optional): Input names for HuggingFace models. Defaults to None.
            skip_init_metadata (bool, optional): Skip initializing metadata for the nodes. Defaults to False.
            add_metadata_args (dict, optional): Additional arguments for metadata initialization. Defaults to None.

        Raises:
            ValueError: If the input model is not a torch.nn.Module or fx.GraphModule.
        """
        # is_huggingface flag is used in passes to automate dummy input generation etc
        self.is_huggingface = isinstance(model, PreTrainedModel)
        self.cf_args = cf_args

        # Generate the GraphModule according to the model type
        if isinstance(model, fx.GraphModule):
            self.model = model
            self.model.patched_op_names = []
            self.model.patched_custom_layers = []
            # dict, not list: trace_torch_module initializes additional_inputs
            # as {} and downstream code treats it as a mapping.
            self.model.additional_inputs = {}
        elif isinstance(model, torch.nn.Module):
            self.model = trace_torch_module(
                model,
                cf_args,
                custom_ops,
                hf_input_names=hf_input_names,
            )
        else:
            raise ValueError(
                f"Expected fx.GraphModule or nn.Module, but received model: {type(model)}"
            )

        # Initialize metadata for each node
        # todo: will need to move metadata analysis passes into chop.ir for this to work
        # if not skip_init_metadata and add_metadata_args is not None:
        #     mg, _ = passes.init_metadata_analysis_pass(self.fx_graph)
        #     mg, _ = passes.add_common_metadata_analysis_pass(
        #         mg,
        #         pass_args=add_metadata_args,
        #     )

    @classmethod
    def from_module(
        cls,
        model: torch.nn.Module,
        cf_args: Optional[Dict[str, Any]] = None,
        custom_ops: dict = {},
    ):
        """
        Construct a MaseGraph from a torch.nn.Module.

        Args:
            model (torch.nn.Module): Input model to construct the MaseGraph.
            cf_args (Optional[Dict[str, Any]], optional): Concrete forward arguments to trace the model with. Defaults to None.
            custom_ops (dict, optional): Custom operations to be used in the model. Defaults to {}.

        Returns:
            MaseGraph: Constructed MaseGraph.
        """
        assert isinstance(
            model, torch.nn.Module
        ), f"model must be a torch.nn.Module. Received: {type(model)}"

        graph_module = trace_torch_module(model, cf_args, custom_ops)
        return cls(
            model=graph_module,
            cf_args=cf_args,
        )

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: str,
        propagate_missing_metadata: bool = True,
    ):
        """
        Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two
        files: {checkpoint}.pt (the GraphModule) and {checkpoint}.mz (the
        MaseMetadata). If propagate_missing_metadata is set to True, the MaseGraph
        will attempt to propagate metadata for missing nodes. This is useful when the
        exported metadata is incomplete due to serialization errors.

        Args:
            checkpoint (str): Checkpoint to load the MaseGraph from.
            propagate_missing_metadata (bool, optional): Propagate metadata for missing nodes. Defaults to True.

        Returns:
            MaseGraph: Loaded MaseGraph.
        """
        # NOTE(security): torch.load and dill deserialize arbitrary pickled
        # objects — only load checkpoints from trusted sources.
        with open(f"{checkpoint}.pt", "rb") as f:
            loaded_model = torch.load(f)

        assert isinstance(
            loaded_model, fx.GraphModule
        ), f"Expected fx.GraphModule, but received model: {type(loaded_model)}"

        mg = cls(loaded_model)

        with open(f"{checkpoint}.mz", "rb") as f:
            loaded_meta = dill.load(f)
        # Each node's parameters were pickled individually during export.
        loaded_meta = {k: dill.loads(v) for k, v in loaded_meta.items()}

        for node in mg.nodes:
            if node.name in loaded_meta.keys():
                parameters = loaded_meta[node.name]
                node.meta["mase"] = MaseMetadata(
                    node=node,
                    model=loaded_model,
                )
                node.meta["mase"].parameters = parameters
            else:
                # todo: propagate metadata for missing nodes
                logger.warning(f"Node {node.name} not found in loaded metadata.")
                node.meta["mase"] = MaseMetadata(
                    node=node,
                    model=loaded_model,
                )

        # Restore tracing parameters stashed on the GraphModule during export.
        for attr in [
            "class_for_deserialization",
            "config",
            "device",
        ]:
            if hasattr(mg.model, attr):
                setattr(mg, attr, getattr(mg.model, attr))

        return mg

    def export(
        self,
        fname: str = "masegraph",
    ):
        """
        Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz.
        {fname}.pt contains the GraphModule, and {fname}.mz contains the MaseMetadata.

        Args:
            fname (str): Filename to save the MaseGraph to. Defaults to "masegraph".
        """
        # Strip only the final extension; the previous fname.split(".")[0]
        # truncated any path containing a dot (e.g. "runs.v2/graph" -> "runs").
        fname = os.path.splitext(fname)[0]
        logger.info(f"Exporting MaseGraph to {fname}.pt, {fname}.mz")

        logger.debug(f"Recompiling GraphModule to preserve any transforms...")
        self.model.recompile()

        # The following parameters must be set as attributes in the GraphModule
        # for tracing to work during deserialization. These get overwritten during
        # transform passes so they are read from the MaseGraph attributes (which are
        # set during the import process).
        logger.debug(f"Storing tracing parameters into mg.model for deserialization...")
        for attr in [
            "class_for_deserialization",
            "config",
            "device",
        ]:
            if hasattr(self, attr):
                logger.debug(f"Setting {attr}")
                setattr(self.model, attr, getattr(self, attr))

        logger.info(f"Exporting GraphModule to {fname}.pt")
        with open(f"{fname}.pt", "wb") as f:
            torch.save(self.model, f)

        logger.info(f"Exporting MaseMetadata to {fname}.mz")
        combined_meta = {}
        for node in self.nodes:
            parameters = node.meta["mase"].parameters
            try:
                # Pickle each node separately so one unpicklable node does not
                # lose the metadata of every other node.
                pickled = dill.dumps(parameters)
                combined_meta[node.name] = pickled
            except Exception as e:
                logger.warning(f"Failed to pickle {node.op} node: {node.name}")
                logger.warning(e)
        with open(f"{fname}.mz", "wb") as f:
            dill.dump(combined_meta, f)

    def draw(self, file="mase_graph.svg"):
        """
        Draw the MaseGraph using the FxGraphDrawer.

        Args:
            file (str, optional): File to save the graph to. Defaults to "mase_graph.svg".
        """
        try:
            import pydot
        except ImportError:
            # narrowed from a bare except: only a missing pydot should be
            # reported as this ImportError
            raise ImportError("pydot is required to draw the graph")

        drawer = FxGraphDrawer(self.model, "masegraph")
        dot_graph = drawer.get_dot_graph()
        # some dot_graph contains .weight and .bias, this cause pydot to crash in
        # plotting so we need to remove them. For instance, in BERT you have
        # bert_embeddings_word_embeddings.weight as a node, which is not allowed
        # in graphviz
        dot_string = dot_graph.to_string()
        dot_string = dot_string.replace(".weight", "_weight")
        dot_string = dot_string.replace(".bias", "_bias")

        # the following code snippet is how to plot in networkx, but it does not look nice
        # new_dot_graph = pydot.graph_from_dot_data(dot_string)
        # new_dot_graph = new_dot_graph[0]
        # graph = nx.drawing.nx_pydot.from_pydot(dot_graph)
        # nx.draw(graph)
        # plt.tight_layout()
        # plt.savefig("test.png", format="PNG")

        dot_graph = pydot.graph_from_dot_data(dot_string)
        dot_graph = dot_graph[0]
        dot_graph.write_svg(file)

    @property
    def fx_graph(self):
        """The fx.Graph representation of the MaseGraph.

        Returns:
            fx.Graph: fx.Graph representation of the MaseGraph.
        """
        return self.model.graph

    @property
    def nodes(self):
        """The nodes of the MaseGraph.

        Returns:
            list: List of nodes in the MaseGraph.
        """
        return self.model.graph.nodes

    @property
    def modules(self):
        """
        Get all the modules in the model.

        Returns:
            dict: Dictionary of all the modules in the model.
        """
        return dict(self.model.named_modules())

    @fx_graph.setter
    def fx_graph(self, graph: fx.Graph):
        self.model.graph = graph