# Source code for chop.ir.graph.mase_metadata

import logging

from torch import nn

logger = logging.getLogger(__name__)


def get_module_by_name(model, request_name):
    """Return the sub-module of *model* registered under *request_name*.

    Args:
        model: Object exposing ``named_modules()`` that yields
            ``(name, layer)`` pairs.
        request_name: Name to look up; it is compared for equality against
            each yielded name.

    Returns:
        The first layer whose name equals ``request_name``, or ``None`` when
        no name matches.
    """
    for name, layer in model.named_modules():
        if name == request_name:
            return layer
    return None
class MaseMetadata:
    """
    The metadata of a Mase node in a Mase graph describes the constraints of the
    node for any static analysis or possible transformation. The metadata has a
    tree structure, e.g.

    - common
      - mase_op -> str : the mase op of the node, e.g. placeholder, linear, relu
      - mase_type -> str : the mase type of the node, e.g. module, builtin_func,
        module_related_func
      - args -> {}
        - $name : name of the arg
          (if the arg is a tensor)
          - type -> type of the arg, e.g. fixed point or float
          - precision -> format of the type, e.g. (10, 5)
          - shape -> shape of the arg
          (if the arg is not a tensor)
          - value of the arg
      - results -> {}
        - $name : name of the result
          (if the result is a tensor)
          - type -> type of the result, e.g. fixed point or float
          - precision -> format of the type, e.g. (10, 5)
          - size -> size of the result
          (if the result is not a tensor)
          - value of the result
    - software: dict
      - args: dict
        - $name (dict): name of the arg, e.g. data_in_0
          - "stat": {
                "record": {"data": ..., "count": ...},
                "variance_online": {"variance": ..., "mean": ..., "count": ...},
                "variance_precise": {"variance": ..., "mean": ..., "count": ...},
                "range_n_sigma": {"min": ..., "max": ..., "count": ...},
                "range_quantile": {"min": ..., "max": ..., "count": ...},
                "range_min_max": {"min": ..., "max": ..., "count": ...},
            }
      - results: dict
        - $name (dict): name of the result, e.g. data_out_0
          - "stat": {"stat_name": { # stat_values } }
    - hardware
      - is_implicit -> bool : whether the node is mapped on hardware or software
        annotation only
      - verilog_param -> {} : parameters needed to customise the hardware module
      - device_id -> int : the ID of the device where the node is mapped,
        default = -1
      - toolchain -> str : tool chain for code generation, must be INTERNAL,
        EXTERNAL or HLS
      - module -> str : the name of the used hardware module
      - interface -> {}
        - name : name of the parameters
          - storage : the hardware interface implemented, must be BRAM
          - transpose : whether the data needs to be transposed before emitting
      - dependence_files -> [] : the dependent files for the generated module

    Instances support ``meta[key]`` / ``meta[key] = value`` as shorthand for
    reading and writing the top-level entries of ``meta.parameters``.
    """

    # Hardware dict
    known_types = ["fixed", "float", "NA"]
    known_toolchain = ["INTERNAL", "EXTERNAL", "HLS"]
    known_storage = ["BRAM"]

    def __init__(
        self,
        node=None,
        model=None,
    ):
        """Create empty metadata for *node* inside *model*.

        Args:
            node: The fx node this metadata is attached to (may be ``None``).
            model: The top-level model that owns the node (may be ``None``).
        """
        # Top-level model
        self.model = model
        # The fx node of the module in the fx graph of the model
        self.node = node
        # layers that we have in RTL
        self.internal_layers = {nn.Linear: "linear", nn.ReLU: "relu"}
        # Root of the metadata tree; each branch starts out empty.
        self.parameters = {
            "common": {},
            "software": {},
            "hardware": {},
        }

    @property
    def module(self):
        """The target module in the model, or ``None``.

        ``None`` is returned when the node is not a ``"call_module"`` node,
        and also when no node was supplied at construction time (the default),
        instead of failing on ``None.op``.
        """
        if self.node is None or self.node.op != "call_module":
            return None
        return get_module_by_name(self.model, self.node.target)

    @property
    def graph(self):
        """The fx graph of the model (``self.model.graph``)."""
        return self.model.graph

    def __getitem__(self, key):
        """Return ``self.parameters[key]``; raises ``KeyError`` if absent."""
        return self.parameters[key]

    def __setitem__(self, key, value):
        """Store *value* under *key* in ``self.parameters``."""
        self.parameters[key] = value