class MaseTracer(custom_leaf_modules: tuple[ModuleType] = (), custom_leaf_layers: tuple[Module] = (), custom_leaf_functions: tuple[Callable] = (), param_shapes_constant: bool = False)[source]#

Bases: Tracer

__init__(custom_leaf_modules: tuple[ModuleType] = (), custom_leaf_layers: tuple[Module] = (), custom_leaf_functions: tuple[Callable] = (), param_shapes_constant: bool = False) None[source]#

Mase Tracer is an extended version of FX Tracer.

  • custom_leaf_modules (tuple[ModuleType], optional) – Python modules whose functions should be wrapped automatically without needing to use fx.wrap(). Backward-compatibility for this parameter is guaranteed, defaults to ()

  • custom_leaf_layers (tuple[torch.nn.Module], optional) – torch.nn.Module layer types that should be treated as leaf modules, so that they appear as a single call_module node instead of being traced through, defaults to ()

  • custom_leaf_functions (tuple[Callable], optional) – Python functions that should be wrapped automatically without needing to use fx.wrap(), defaults to ()

  • param_shapes_constant (bool, optional) – When this flag is set, calls to shape, size and a few other shape-like attributes of a module’s parameter are evaluated directly, rather than returning a new Proxy value for the attribute access. Backward compatibility for this parameter is guaranteed, defaults to False

is_leaf_module(m: Module, module_qualified_name: str) bool[source]#

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to this module from the root. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
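
As a rough illustration of the leaf-module mechanism described above, the sketch below subclasses torch.fx.Tracer directly rather than using MaseTracer. LeafLayerTracer, Block, Net and call_module_targets are hypothetical names introduced only for this example. Treating instances of user-supplied classes as leaves through isinstance is an assumption about how a custom_leaf_layers-style option could behave, not a statement of how MaseTracer is implemented.

```python
import torch
import torch.fx as fx


class Block(torch.nn.Module):
    """A user-defined layer; it lives outside torch.nn, so FX traces through it by default."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x) + 1


class LeafLayerTracer(fx.Tracer):
    """Hypothetical tracer that treats instances of the given classes as leaves."""

    def __init__(self, leaf_layers=()):
        super().__init__()
        self.leaf_layers = tuple(leaf_layers)

    def is_leaf_module(self, m, module_qualified_name):
        # User-supplied layer types become atomic call_module nodes.
        if isinstance(m, self.leaf_layers):
            return True
        # Otherwise fall back to the default FX rule (torch.nn modules are leaves).
        return super().is_leaf_module(m, module_qualified_name)


def call_module_targets(tracer, model):
    """Return the qualified names of all call_module nodes produced by the tracer."""
    graph = tracer.trace(model)
    return [node.target for node in graph.nodes if node.op == "call_module"]


default_targets = call_module_targets(fx.Tracer(), Net())
custom_targets = call_module_targets(LeafLayerTracer((Block,)), Net())
print(default_targets, custom_targets)
```

With the default tracer, Block is traced through and only its inner Linear shows up as a call_module node, under its qualified name. With Block registered as a leaf, the whole block becomes one node.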


Backwards-compatibility for this API is guaranteed.

…(model: Module, cf_args: Dict[str, Any] | None = None, custom_ops: dict = None, hf_input_names: list = None)[source]#

Trace a torch.nn.Module into an fx.GraphModule using the MaseTracer. This function wraps both the HFTracer and the MaseTracer. The resulting fx.GraphModule is a dataflow representation of the model with both software and hardware constraints. Any custom_ops are passed to the tracer as custom operations.

  • model (torch.nn.Module) – Input model to trace.

  • cf_args (Optional[Dict[str, Any]], optional) – Concrete forward arguments to trace the model with. Defaults to None.

  • custom_ops (dict, optional) – Custom operations to be used in the model. Defaults to None.

  • hf_input_names (list, optional) – Input names for HuggingFace models. Defaults to None.


Traced model as an fx.GraphModule.

Return type: fx.GraphModule


class MaseGraph(model: Module | GraphModule, cf_args: Dict[str, Any] | None = None, custom_ops: dict = None, hf_input_names: list = None, skip_init_metadata: bool = False, add_metadata_args: dict = None)[source]#

Bases: object

implicit_nodes = ['size', 'view', 'to', 'bool', 'int', 'flatten', 'squeeze', 'unsqueeze', 'transpose', 'permute', 'reshape', 'contiguous', 'dropout', 'eq', 'ne', 'gemm', 'ge', 'where', '_assert', 'getattr', 'long', 'type_as', 'clamp', 'abs', 'stack', 'cast', 'shape', 'gather', 'slice', 'cat', 'split', 'tile', 'expand', 'full', 'ones', 'dim', 'finfo', 'masked_fill', 'masked_fill_', 'index_select', 'detach', 'tensor']#
__init__(model: Module | GraphModule, cf_args: Dict[str, Any] | None = None, custom_ops: dict = None, hf_input_names: list = None, skip_init_metadata: bool = False, add_metadata_args: dict = None) None[source]#

MaseGraph is a dataflow representation of a model with both software and hardware constraints. The MaseGraph can be constructed from a torch.nn.Module:

from import MaseGraph
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
mase_graph = MaseGraph(model)

# Or, equivalently:
mase_graph = MaseGraph.from_module(model)

A MaseGraph can also be constructed from a pre-traced fx.GraphModule:

from import MaseGraph
import torch
import torch.fx as fx

model = torch.nn.Linear(10, 10)
traced_model = fx.symbolic_trace(model)
mase_graph = MaseGraph(traced_model)

A MaseGraph can be exported as follows:

from import MaseGraph
import torch
import torch.fx as fx

model = torch.nn.Linear(10, 10)
traced_model = fx.symbolic_trace(model)
mase_graph = MaseGraph(traced_model)

# Writes masegraph.pt (GraphModule) and masegraph.mz (MaseMetadata)
mase_graph.export("masegraph")

The MaseGraph can then be loaded from a checkpoint as follows:

from import MaseGraph

mase_graph = MaseGraph.from_checkpoint("masegraph")

To visualize the MaseGraph, the draw method can be used:

from import MaseGraph

mase_graph = MaseGraph.from_module(model)

# Saves a drawing of the graph to the given file
mase_graph.draw(file="mase_graph.svg")

  • model (torch.nn.Module | fx.GraphModule) – Input model to construct the MaseGraph.

  • cf_args (Optional[Dict[str, Any]], optional) – Concrete forward arguments to trace the model with. Defaults to None.

  • custom_ops (dict, optional) – Custom operations to be used in the model. Defaults to None.

  • hf_input_names (list, optional) – Input names for HuggingFace models. Defaults to None.

  • skip_init_metadata (bool, optional) – Skip initializing metadata for the nodes. Defaults to False.

  • add_metadata_args (dict, optional) – Additional arguments for metadata initialization. Defaults to None.


ValueError – If the input model is not a torch.nn.Module or fx.GraphModule.

classmethod from_module(model: Module, cf_args: Dict[str, Any] | None = None, custom_ops: dict = {})[source]#

Construct a MaseGraph from a torch.nn.Module.

  • model (torch.nn.Module) – Input model to construct the MaseGraph.

  • cf_args (Optional[Dict[str, Any]], optional) – Concrete forward arguments to trace the model with. Defaults to None.

  • custom_ops (dict, optional) – Custom operations to be used in the model. Defaults to {}.


Constructed MaseGraph.

Return type: MaseGraph


classmethod from_checkpoint(checkpoint: str, propagate_missing_metadata: bool = True)[source]#

Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two files: {checkpoint}.pt and {checkpoint}.mz. {checkpoint}.pt contains the GraphModule, and {checkpoint}.mz contains the MaseMetadata.

If propagate_missing_metadata is set to True, the MaseGraph will attempt to propagate metadata for missing nodes. This is useful when the exported metadata is incomplete due to serialization errors.

  • checkpoint (str) – Checkpoint to load the MaseGraph from.

  • propagate_missing_metadata (bool, optional) – Propagate metadata for missing nodes. Defaults to True.


Loaded MaseGraph.

Return type: MaseGraph


export(fname: str = 'masegraph')[source]#

Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz. {fname}.pt contains the GraphModule, and {fname}.mz contains the MaseMetadata.


fname (str) – Filename to save the MaseGraph to. Defaults to “masegraph”.
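
The two-file checkpoint layout used by export and from_checkpoint can be sketched with pathlib alone. The helpers checkpoint_files and checkpoint_is_complete below are hypothetical and are not part of MaseGraph. They only mirror the naming rule stated above, where the .pt and .mz suffixes are appended to the same base name.

```python
from pathlib import Path


def checkpoint_files(fname: str = "masegraph") -> tuple[Path, Path]:
    """Hypothetical helper: the two paths that together form one checkpoint."""
    # Suffixes are appended rather than substituted, so "model.v2" stays intact.
    return Path(f"{fname}.pt"), Path(f"{fname}.mz")


def checkpoint_is_complete(fname: str) -> bool:
    """True only if both the GraphModule file and the metadata file exist."""
    return all(path.is_file() for path in checkpoint_files(fname))


graph_file, metadata_file = checkpoint_files("runs/bert/masegraph")
print(graph_file.name, metadata_file.name)
```

A check like checkpoint_is_complete can be useful before loading, since a checkpoint with only one of the two files present is not a full MaseGraph checkpoint.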


draw(file='mase_graph.svg')[source]#

Draw the MaseGraph using the FxGraphDrawer.


file (str, optional) – File to save the graph to. Defaults to “mase_graph.svg”.

property nodes#

The nodes of the MaseGraph.


List of nodes in the MaseGraph.

Return type: list


property modules#

Get all the modules in the model.


Dictionary of all the modules in the model.

Return type: dict


property fx_graph#

The fx.Graph representation of the MaseGraph.


fx.Graph representation of the MaseGraph.

Return type: fx.Graph



Bases: object

__init__(graph)[source]#

…, request_name)[source]#

class MaseMetadata(node=None, model=None)[source]#

Bases: object

The metadata of a Mase node in a Mase graph describes the constraints of the node for any static analysis or possible transformation. The metadata has a tree structure, e.g.

  • common

    • mase_op -> str : the mase op of the node, e.g. placeholder, linear, relu

    • mase_type -> str : the mase type of the node, e.g. module, builtin_func, module_related_func

    • args -> {}

      • $name : name of the arg

        • if the arg is a tensor:

          • type -> type of the arg, e.g. fixed point or float

          • precision -> format of the type, e.g. (10, 5)

          • shape -> shape of the arg

        • if the arg is not a tensor:

          • value of the arg

    • results -> {}

      • $name : name of the result

        • if the result is a tensor:

          • type -> type of the result, e.g. fixed point or float

          • precision -> format of the type, e.g. (10, 5)

          • size -> size of the result

        • if the result is not a tensor:

          • value of the result

  • software: dict

    • args: dict

      • $name (dict): name of the arg, e.g. data_in_0

        • “stat”: {

          “record”: {“data”: …, “count”: …},

          “variance_online”: {“variance”: …, “mean”: …, “count”: …},

          “variance_precise”: {“variance”: …, “mean”: …, “count”: …},

          “range_n_sigma”: {“min”: …, “max”: …, “count”: …},

          “range_quantile”: {“min”: …, “max”: …, “count”: …},

          “range_min_max”: {“min”: …, “max”: …, “count”: …}}


    • results: dict

      • $name (dict): name of the result, e.g. data_out_0

        • “stat”: {“stat_name”: { # stat_values } }

  • hardware

    • is_implicit -> bool : whether the node is mapped on hardware or software annotation only

    • verilog_param -> {} : parameters needed to customise the hardware module

    • device_id -> int : the ID of the device where the node is mapped, default = -1

    • toolchain -> str : tool chain for code generation, must be INTERNAL, EXTERNAL or HLS

    • module -> str : the name of the used hardware module

    • interface -> {}

      • $name : name of the parameter

        • storage : the hardware interface implemented, must be BRAM

        • transpose : whether the data needs to be transposed before emitting

    • dependence_files -> [] : the files that the generated module depends on
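
The “variance_online” entry in the software statistics holds a variance, a mean and a count. The source does not say how these values are computed. One common way to maintain exactly these three quantities in a single pass is Welford's algorithm, sketched below purely as an assumption. OnlineVariance and its methods are hypothetical names, not MASE APIs.

```python
class OnlineVariance:
    """Running mean and variance via Welford's algorithm (assumed, not from the source)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the current mean

    def update(self, x: float) -> None:
        """Fold one new sample into the running statistics."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (x - self.mean)

    def as_stat(self) -> dict:
        """Return the statistics in the same shape as the "variance_online" entry."""
        # Sample variance (divides by count - 1); undefined below two samples.
        variance = self._m2 / (self.count - 1) if self.count > 1 else float("nan")
        return {"variance": variance, "mean": self.mean, "count": self.count}


acc = OnlineVariance()
for value in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    acc.update(value)
stat = acc.as_stat()
print(stat)
```

An online update of this kind avoids storing every observed value, which is the usual reason to keep an online estimate alongside a precise one.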

known_types = ['fixed', 'float', 'NA']#
known_toolchain = ['INTERNAL', 'EXTERNAL', 'HLS']#
known_storage = ['BRAM']#
__init__(node=None, model=None)[source]#
property module#
property graph#
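
To make the metadata tree more concrete, the sketch below writes one node's metadata as a plain nested dict and checks it against the known_types, known_toolchain and known_storage lists. This is an illustration only: MaseMetadata is not constructed here, check_metadata is a hypothetical helper, and the values "fixed_linear", "weight", "data_in_0" shape and "data_out_0" size are placeholders chosen for the example.

```python
import copy

# Values copied from the class attributes listed above.
KNOWN_TYPES = ["fixed", "float", "NA"]
KNOWN_TOOLCHAIN = ["INTERNAL", "EXTERNAL", "HLS"]
KNOWN_STORAGE = ["BRAM"]

# Hypothetical metadata for a linear node, following the documented tree.
linear_meta = {
    "common": {
        "mase_op": "linear",
        "mase_type": "module",
        "args": {
            "data_in_0": {"type": "fixed", "precision": (10, 5), "shape": [1, 10]},
        },
        "results": {
            "data_out_0": {"type": "fixed", "precision": (10, 5), "size": [1, 10]},
        },
    },
    "software": {"args": {}, "results": {}},
    "hardware": {
        "is_implicit": False,
        "verilog_param": {},
        "device_id": -1,
        "toolchain": "INTERNAL",
        "module": "fixed_linear",  # placeholder module name
        "interface": {"weight": {"storage": "BRAM", "transpose": False}},
        "dependence_files": [],
    },
}


def check_metadata(meta: dict) -> list:
    """Hypothetical check: return a list of problems, empty if none were found."""
    problems = []
    for section in ("args", "results"):
        for name, entry in meta["common"][section].items():
            # Only tensor entries carry a type; non-tensor entries hold a value.
            if "type" in entry and entry["type"] not in KNOWN_TYPES:
                problems.append(f"common.{section}.{name}: unknown type {entry['type']!r}")
    hardware = meta["hardware"]
    if hardware["toolchain"] not in KNOWN_TOOLCHAIN:
        problems.append(f"hardware.toolchain: unknown toolchain {hardware['toolchain']!r}")
    for name, interface in hardware["interface"].items():
        if interface["storage"] not in KNOWN_STORAGE:
            problems.append(f"hardware.interface.{name}: unknown storage {interface['storage']!r}")
    return problems


bad_meta = copy.deepcopy(linear_meta)
bad_meta["hardware"]["toolchain"] = "UNKNOWN"

print(check_metadata(linear_meta))
print(check_metadata(bad_meta))
```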