Source code for chop.ir.graph.mase_graph

import logging
import math
import os
from pathlib import Path

from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, Optional, Tuple

import toml
import torch
import torch.fx as fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

from chop.ir.common import MASE_IMPLICIT_FUNCS
from chop.nn import MASE_LEAF_LAYERS
from chop.nn.quantized import (
    quantized_func_map,
    quantized_module_map,
)

from transformers import PreTrainedModel
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
from transformers.utils.fx import HFTracer

logger = logging.getLogger(__name__)

# ----------------------------------------
#   Mase Tracer
# ----------------------------------------


class MaseTracer(fx.Tracer):
    def __init__(
        self,
        custom_leaf_modules: tuple[ModuleType] = (),
        custom_leaf_layers: tuple[torch.nn.Module] = (),
        custom_leaf_functions: tuple[Callable] = (),
        param_shapes_constant: bool = False,
    ) -> None:
        """Mase Tracer is an extended version of FX Tracer.

        :param custom_leaf_modules: Python modules whose functions should be
            wrapped automatically without needing to use fx.wrap(). Backward
            compatibility for this parameter is guaranteed, defaults to ()
        :type custom_leaf_modules: tuple[ModuleType], optional
        :param custom_leaf_layers: torch.nn.Module classes that should be
            treated as leaf modules during tracing, i.e. kept as single
            call_module nodes rather than traced through. Backward
            compatibility for this parameter is guaranteed, defaults to ()
        :type custom_leaf_layers: tuple[torch.nn.Module], optional
        :param custom_leaf_functions: Python functions that should be wrapped
            automatically without needing to use fx.wrap(), defaults to ()
        :type custom_leaf_functions: tuple[Callable], optional
        :param param_shapes_constant: When this flag is set, calls to shape,
            size and a few other shape-like attributes of a module's parameter
            will be evaluated directly, rather than returning a new Proxy value
            for an attribute access. Backward compatibility for this parameter
            is guaranteed, defaults to False
        :type param_shapes_constant: bool, optional
        """
        self.custom_leaf_layers = tuple(set(custom_leaf_layers))
        self.custom_leaf_modules = tuple(set(custom_leaf_modules))
        self.custom_leaf_functions = tuple(set(custom_leaf_functions))
        self.param_shapes_constant = param_shapes_constant
        super().__init__(
            self.custom_leaf_modules + (math,),
            self.custom_leaf_functions,
            self.param_shapes_constant,
        )

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        is_fx_built_in_leaf_module = super().is_leaf_module(m, module_qualified_name)
        is_mase_leaf_layers = isinstance(m, MASE_LEAF_LAYERS)
        is_custom_layer = isinstance(m, self.custom_leaf_layers)
        return any(
            (
                is_fx_built_in_leaf_module,
                is_mase_leaf_layers,
                is_custom_layer,
            )
        )
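
# Usage sketch (illustrative only; ToyNet is a hypothetical model defined just
# for this example): trace a small nn.Module with MaseTracer, treating
# torch.nn.LayerNorm as a custom leaf layer so it is kept as a single
# call_module node rather than traced through.
#
#     class ToyNet(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.norm = torch.nn.LayerNorm(8)
#             self.fc = torch.nn.Linear(8, 4)
#
#         def forward(self, x):
#             return self.fc(self.norm(x))
#
#     net = ToyNet()
#     tracer = MaseTracer(custom_leaf_layers=(torch.nn.LayerNorm,))
#     graph_module = fx.GraphModule(net, tracer.trace(net))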

# ----------------------------------------
#   Mase Graph IR
# ----------------------------------------

class MaseGraph:
    implicit_nodes = MASE_IMPLICIT_FUNCS

    def __init__(
        self,
        model,
        cf_args: Optional[Dict[str, Any]] = None,
        custom_ops: dict = None,
        hf_input_names: list = None,
    ) -> None:
        """Mase takes a torch.fx graph representation of a model and translates
        it into a customised representation (Mase graph IR). The Mase graph IR
        is a dataflow representation of the model with both software and
        hardware constraints.

        :param model: Input model to construct the MaseGraph. When an nn.Module
            is provided, it is parsed into an fx.GraphModule using the MaseTracer.
        :type model: torch.nn.Module | fx.GraphModule
        :param cf_args: Concrete forward arguments passed to the tracer as
            concrete_args when parsing an nn.Module, defaults to None
        :type cf_args: Optional[Dict[str, Any]], optional
        """
        if isinstance(model, fx.GraphModule):
            self.model = model
            self.model.patched_op_names = []
            self.model.patched_custom_layers = []
            self.model.additional_inputs = []
        elif isinstance(model, torch.nn.Module):
            self.model = self.trace_torch_module(
                model,
                cf_args,
                custom_ops,
                hf_input_names=hf_input_names,
            )
        else:
            raise ValueError(
                f"Expected fx.GraphModule or nn.Module, but received model: {type(model)}"
            )
        self.cf_args = cf_args
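
    # Usage sketch (illustrative only; ToyNet is a hypothetical nn.Module as in
    # the earlier MaseTracer sketch): constructing a MaseGraph from an
    # nn.Module, which parses the model into an fx.GraphModule via MaseTracer.
    #
    #     mg = MaseGraph(ToyNet())
    #     print(mg.fx_graph)  # the underlying fx.Graph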

    def trace_torch_module(
        self,
        model: torch.nn.Module,
        cf_args: Optional[Dict[str, Any]] = None,
        custom_ops: dict = None,
        hf_input_names: list = None,
    ):
        # * HuggingFace model
        if isinstance(model, PreTrainedModel):
            tracer_cls = HFTracer

            if custom_ops is not None:
                custom_modules = tuple(custom_ops.get("modules", {}).keys())
            else:
                custom_ops = {"modules": {}, "functions": {}}
                custom_modules = ()

            def wrap_is_leaf_module(hf_is_leaf_module):
                def is_leaf_module(
                    self, m: torch.nn.Module, module_qualified_name: str
                ) -> bool:
                    is_hf_built_in_leaf_module = hf_is_leaf_module(
                        self, m, module_qualified_name
                    )
                    is_custom_module = isinstance(m, custom_modules)
                    is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS)
                    return any(
                        (
                            is_hf_built_in_leaf_module,
                            is_custom_module,
                            is_mase_leaf_layer,
                        )
                    )

                return is_leaf_module

            setattr(
                tracer_cls,
                "is_leaf_module",
                wrap_is_leaf_module(tracer_cls.is_leaf_module),
            )

            graph_module = hf_symbolic_trace(
                model,
                tracer_cls=tracer_cls,
                input_names=hf_input_names,
            )
            graph_module.custom_ops = custom_ops

            # ! TO DO: remove this legacy stuff
            graph_module.patched_op_names = []
            graph_module.patched_custom_layers = []
            graph_module.additional_inputs = {}

        # * Other models
        else:
            # MASE internal auto-wrapped functions/layers
            custom_leaf_modules = ()
            custom_leaf_functions = ()
            custom_leaf_layers = ()
            # quantized functions/layers
            custom_leaf_functions += tuple(quantized_func_map.values())
            custom_leaf_layers += tuple(quantized_module_map.values())
            # patched functions/layers
            patched_nodes = getattr(model, "patched_nodes", None)
            if patched_nodes is not None:
                custom_leaf_modules += tuple(patched_nodes["modules"])
                custom_leaf_functions += tuple(patched_nodes["functions"])
                custom_leaf_layers += tuple(patched_nodes["layers"])

            self.tracer = MaseTracer(
                custom_leaf_modules=custom_leaf_modules,
                custom_leaf_functions=custom_leaf_functions,
                custom_leaf_layers=custom_leaf_layers,
            )
            graph_module = fx.GraphModule(model, self.tracer.trace(model, cf_args))

            if patched_nodes is not None:
                graph_module.patched_op_names = [
                    obj.__name__.lower()
                    for obj in model.patched_nodes["layers"]
                    + model.patched_nodes["functions"]
                ]
                # these are layers we believe the user will provide system verilog for
                graph_module.patched_custom_layers = model.patched_nodes["layers"]
                graph_module.additional_inputs = model.patched_nodes[
                    "additional_inputs"
                ]
                graph_module.patched_nodes = model.patched_nodes
            else:
                graph_module.patched_op_names = []
                graph_module.patched_custom_layers = []
                graph_module.additional_inputs = {}

        return graph_module
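
    # Usage sketch (illustrative only; the checkpoint name is an assumption for
    # this example): a HuggingFace PreTrainedModel is routed through HFTracer,
    # and hf_input_names selects which forward arguments become graph inputs.
    #
    #     from transformers import AutoModel
    #     model = AutoModel.from_pretrained("bert-base-uncased")
    #     mg = MaseGraph(model, hf_input_names=["input_ids", "attention_mask"])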

    @classmethod
    def from_module(
        cls,
        model: torch.nn.Module,
        cf_args: Optional[Dict[str, Any]] = None,
        custom_ops: dict = {},
    ):
        assert isinstance(
            model, torch.nn.Module
        ), f"model must be a torch.nn.Module. Received: {type(model)}"
        # __init__ traces nn.Module inputs with MaseTracer, so construct directly
        return cls(model=model, cf_args=cf_args, custom_ops=custom_ops)

    def draw(self, file="mase_graph.svg"):
        drawer = FxGraphDrawer(self.model, "masegraph")
        drawer.get_dot_graph().write_svg(file)
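
    # Usage sketch (illustrative only; ToyNet as in the earlier sketches):
    # render the traced graph to an SVG file. FxGraphDrawer produces a pydot
    # graph, so pydot and a Graphviz installation are needed for write_svg.
    #
    #     mg = MaseGraph(ToyNet())
    #     mg.draw("toynet_graph.svg")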

    @property
    def fx_graph(self):
        return self.model.graph

    @fx_graph.setter
    def fx_graph(self, graph: fx.Graph):
        self.model.graph = graph

    @property
    def nodes(self):
        return self.model.graph.nodes

    @property
    def modules(self):
        return dict(self.model.named_modules())
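
# Usage sketch (illustrative only; ToyNet is a hypothetical nn.Module): the
# fx_graph, nodes and modules properties expose the underlying fx.Graph, its
# node list, and the named submodules of the traced model.
#
#     mg = MaseGraph(ToyNet())
#     for node in mg.nodes:
#         print(node.op, node.target)
#     submodules = mg.modules  # dict keyed by qualified name, e.g. "norm"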