# Source code for chop.passes.graph.analysis.add_metadata.add_hardware_metadata
import logging
import toml
import torch
import torch.fx as fx
from chop.ir.graph.mase_graph import MaseGraph
from chop.ir.graph.mase_metadata import MaseMetadata
from chop.nn.modules import GroupedQueryAttention
from chop.passes.graph.analysis.utils import (
get_input_nodes,
get_output_nodes,
)
from chop.passes.graph.utils import get_mase_op, deepgetattr, get_module_by_name
from torch import nn
from .hardware_metadata_layers import INTERNAL_COMP
logger = logging.getLogger(__name__)
# Here we assume each data has up to three dimensions
MAX_DIM = 3
def _cap(name):
"""
capitalize a string
"""
return str(name).upper()
def add_component_source(node):
    """Annotate a node's hardware metadata with its implementation source.

    Selects the toolchain for the node's mase op:

    - ``user_defined_module``: match the module instance against the model's
      registered custom ops and use that op's RTL module/files.
    - ops listed in ``INTERNAL_COMP``: use the first internal RTL component.
    - anything else: fall back to HLS with no module/files.

    Also initialises the ``interface`` entry for every non-data argument
    (currently only on-chip BRAM storage is supported) and resets
    ``device_id`` to -1. Implicit nodes are left untouched.

    :param node: a mase node whose ``meta["mase"]`` metadata is populated
    """
    if node.meta["mase"]["hardware"]["is_implicit"]:
        return

    node.meta["mase"]["hardware"]["interface"] = {}

    mase_op = node.meta["mase"]["common"]["mase_op"]
    if mase_op == "user_defined_module":
        # NOTE(review): no break after a match — if several custom ops match
        # this module, the last one wins; confirm this is intended.
        for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items():
            if isinstance(
                deepgetattr(node.meta["mase"].model, node.target),
                custom_op,
            ):
                node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
                node.meta["mase"]["hardware"]["module"] = op_info["module"]
                node.meta["mase"]["hardware"]["dependence_files"] = op_info[
                    "dependence_files"
                ]
    elif mase_op in INTERNAL_COMP:
        node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
        # take the first ip in the component list by default
        node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"]
        node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][0][
            "dependence_files"
        ]
    else:
        node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_HLS"
        node.meta["mase"]["hardware"]["module"] = None
        node.meta["mase"]["hardware"]["dependence_files"] = []

    node.meta["mase"]["hardware"]["device_id"] = -1

    # Currently only on-chip parameters are supported: every non-data
    # dict-typed argument (e.g. weight/bias) is stored in BRAM.
    for arg, arg_info in node.meta["mase"]["common"]["args"].items():
        if "data_in" in arg:
            continue
        if isinstance(arg_info, dict):
            node.meta["mase"]["hardware"]["interface"][arg] = {
                "storage": "BRAM",
                "transpose": False,
            }
        else:
            node.meta["mase"]["hardware"]["interface"][arg] = {}
def _add_tensor_verilog_params(vp, name, info, max_parallelism):
    """Emit precision / tensor-size / parallelism parameters for one tensor.

    ``info`` is the common-metadata dict of an arg or result (with
    ``precision`` and ``shape`` keys); parameters are written into ``vp``
    with upper-cased names such as ``<NAME>_TENSOR_SIZE_DIM_<d>``.
    """
    for i, precision in enumerate(info["precision"]):
        vp[f"{name}_precision_{i}".upper()] = precision
    # Hardware dimension 0 is the innermost (last) torch dimension, so
    # both the shape and the max-parallelism list are indexed reversed.
    reversed_shape = info["shape"][::-1]
    for dim, size in enumerate(reversed_shape):
        vp[f"{name}_tensor_size_dim_{dim}".upper()] = size
        if max_parallelism is not None:
            # Parallelism is capped at the defined maximum for this dimension.
            vp[f"{name}_parallelism_dim_{dim}".upper()] = min(
                max_parallelism[::-1][dim], size
            )
        else:
            # Otherwise default to full parallelism across the dimension.
            vp[f"{name}_parallelism_dim_{dim}".upper()] = size


def add_verilog_param(node):
    """Populate ``verilog_param`` for a node from its common metadata.

    For every dict-typed arg/result, emits precision, tensor-size and
    parallelism parameters (see :func:`_add_tensor_verilog_params`).
    Scalar args are passed through as-is, except booleans which become
    0/1. Implicit nodes are left untouched.

    :param node: a mase node whose ``meta["mase"]`` metadata is populated
    """
    if node.meta["mase"]["hardware"]["is_implicit"]:
        return

    node.meta["mase"]["hardware"]["verilog_param"] = {}
    vp = node.meta["mase"]["hardware"]["verilog_param"]
    max_parallelism = node.meta["mase"]["hardware"]["max_parallelism"]

    for arg, arg_info in node.meta["mase"]["common"]["args"].items():
        if isinstance(arg_info, dict):
            _add_tensor_verilog_params(vp, arg, arg_info, max_parallelism)
        elif isinstance(arg_info, bool):
            # Booleans become 0/1 scalar verilog parameters.
            vp[arg.upper()] = 1 if arg_info else 0
        else:
            vp[arg.upper()] = arg_info

    for result, result_info in node.meta["mase"]["common"]["results"].items():
        if isinstance(result_info, dict):
            _add_tensor_verilog_params(vp, result, result_info, max_parallelism)
        else:
            # NOTE: unlike args, scalar results (including bools) are passed
            # through unchanged, matching the original behaviour.
            vp[result.upper()] = result_info
def add_extra_verilog_param(node, graph: MaseGraph):
    """Adds extra verilog parameters based on the node module type."""
    if node.op != "call_module":
        return

    vp = node.meta["mase"]["hardware"]["verilog_param"]
    module = get_module_by_name(graph.model, node.target)
    if not isinstance(module, GroupedQueryAttention):
        return

    vp["NUM_HEADS"] = module.num_heads
    vp["NUM_GROUPS"] = module.num_kv_heads
    vp["WEIGHTS_PRE_TRANSPOSED"] = 0  # TODO: support transpose in mase?
    vp["HAS_BIAS"] = 1 if module.bias else 0

    # Also fix up weight parallelism: mirror dim 0 onto dim 1 for every
    # projection weight.
    for proj in ("Q", "K", "V", "O"):
        vp[f"{proj}_PROJECTION_WEIGHT_PARALLELISM_DIM_1"] = vp[
            f"{proj}_PROJECTION_WEIGHT_PARALLELISM_DIM_0"
        ]
# [docs]
def add_hardware_metadata_analysis_pass(graph, pass_args=None):
    """add hardware metadata

    :param graph: a MaseGraph
    :type graph: MaseGraph
    :param pass_args: optional dict of pass arguments; supports
        ``"max_parallelism"`` (a per-dimension cap, defaults to ``[4, 4, 4, 4]``),
        defaults to None
    :type pass_args: dict, optional
    :return: return a tuple of a MaseGraph and an empty dict (no additional info to return)
    :rtype: tuple(MaseGraph, Dict)

    The hardware metadata of a Mase node in a Mase graph describes the constraints of the
    node for any static analysis or possible transformation. The metadata has a
    tree structure, e.g.

    - hardware
        - is_implicit -> bool : whether the node is mapped on hardware or software annotation only
        - verilog_param -> {} : parameters need for customise the hardware module
        - toolchain -> str : tool chain for code generation, e.g. INTERNAL_RTL or INTERNAL_HLS
        - module -> str : the name of the used hardware module
        - device_id -> int : the ID of the device where the node is mapped, default = -1
        - interface -> {}
            - name : name of the parameters
                - storage : the hardware interface implemented, must be BRAM
                - transpose : whether the data needs to be transposed before emitting
        - dependence_files -> [] : the dependent files for the generated module

    The verilog parameters follow the following naming rules:

    - Hardware signal naming rules

        - Data with tensor types are explicit as hardware signals, such as weight and bias,
          and Data with scalar/tuple types are implicit as parameters (TODO).
        - Each op is a node with a set of inputs, outputs and parameters
        - The input is named by: data_in_0 (data_in_0_ready, data_in_valid), data_in_1,
        - The output is named by: data_out_0 (data_out_0_ready, data_out_valid), data_out_1, ..
        - The parameters are named by PyTorch names: weight (weight_ready, weight_valid), bias (bias_ready, bias_valid)

    - Hardware parameters naming rules

      Parameters with tensor types are explicit as hardware signals, such as weight and bias,
      and parameters with scalar/tuple types are implicit as hardware parameters.

        - Taking data_in_0 for example:
            - `DATA_IN_0_PRECISION_0`
            - `DATA_IN_0_PRECISION_1`
            - ...
            - (depending on how many precision parameters we have.
            - The order matches the same order as the mase precision metadata)
            - `DATA_IN_0_TENSOR_SIZE_DIM_0`
            - `DATA_IN_0_TENSOR_SIZE_DIM_1`
            - `DATA_IN_0_TENSOR_SIZE_DIM_2`
            - `DATA_IN_0_PARALLELISM_DIM_0`
            - `DATA_IN_0_PARALLELISM_DIM_1`
            - `DATA_IN_0_PARALLELISM_DIM_2`
            - (This means that the number of iterations = tensor_size / spatial_size)
        - Implicit parameters are directly translated into verilog parameters, e.g.
          STRIDE, DIM

    Examples:

    A linear layer in a mase graph with the following common metadata:

    .. code-block:: shell

        %fc1 : [num_users=1] = call_module[target=fc1](args = (%flatten,), kwargs = {})

    .. code-block:: JSON

        {
            "common": {
                "mase_type": "module_related_func",
                "mase_op": "linear",
                "args": {
                    "data_in_0": {
                        "shape": [1, 784],
                        "torch_dtype": torch.float32,
                        "type": "float",
                        "precision": [32],
                    },
                    "weight": {"type": "float", "precision": [32], "shape": [784, 784]},
                    "bias": {"type": "float", "precision": [32], "shape": [784]},
                },
                "results": {
                    "data_out_0": {
                        "type": "float",
                        "precision": [32],
                        "shape": [1, 784],
                        "torch_dtype": torch.float32,
                    }
                },
            },
            "software": {},
            "hardware": {},
        }

    The hardware metadata of the linear layer after this pass:

    .. code-block:: JSON

        {
            "common": {...},
            "software": {},
            "hardware": {
                "is_implicit": False,
                "interface": {
                    "weight": {"storage": "BRAM", "transpose": False},
                    "bias": {"storage": "BRAM", "transpose": False},
                },
                "toolchain": "INTERNAL_RTL",
                "module": "fixed_linear",
                "device_id": -1,
                "dependence_files": [
                    "cast/fixed_cast.sv",
                    "fixed_arithmetic/fixed_dot_product.sv",
                    "fixed_arithmetic/fixed_vector_mult.sv",
                    "fixed_arithmetic/register_slice.sv",
                    "fixed_arithmetic/fixed_accumulator.sv",
                    "fixed_arithmetic/fixed_adder_tree.sv",
                    "fixed_arithmetic/fixed_adder_tree_layer.sv",
                    "fixed_arithmetic/fixed_mult.sv",
                    "common/join2.sv",
                    "linear/fixed_linear.sv",
                ],
                "verilog_param": {
                    "DATA_IN_0_PRECISION_0": 8,
                    "DATA_IN_0_PRECISION_1": 3,
                    "DATA_IN_0_TENSOR_SIZE_DIM_0": 1,
                    "DATA_IN_0_PARALLELISM_DIM_0": 1,
                    "DATA_IN_0_TENSOR_SIZE_DIM_1": 784,
                    "DATA_IN_0_PARALLELISM_DIM_1": 784,
                    "DATA_IN_0_TENSOR_SIZE_DIM_2": 1,
                    "DATA_IN_0_PARALLELISM_DIM_2": 1,
                    "WEIGHT_PRECISION_0": 8,
                    "WEIGHT_PRECISION_1": 3,
                    "WEIGHT_TENSOR_SIZE_DIM_0": 784,
                    "WEIGHT_PARALLELISM_DIM_0": 784,
                    "WEIGHT_TENSOR_SIZE_DIM_1": 784,
                    "WEIGHT_PARALLELISM_DIM_1": 784,
                    "WEIGHT_TENSOR_SIZE_DIM_2": 1,
                    "WEIGHT_PARALLELISM_DIM_2": 1,
                    "BIAS_PRECISION_0": 8,
                    "BIAS_PRECISION_1": 3,
                    "BIAS_TENSOR_SIZE_DIM_0": 784,
                    "BIAS_PARALLELISM_DIM_0": 784,
                    "BIAS_TENSOR_SIZE_DIM_1": 1,
                    "BIAS_PARALLELISM_DIM_1": 1,
                    "BIAS_TENSOR_SIZE_DIM_2": 1,
                    "BIAS_PARALLELISM_DIM_2": 1,
                    "DATA_OUT_0_PRECISION_0": 8,
                    "DATA_OUT_0_PRECISION_1": 3,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_0": 1,
                    "DATA_OUT_0_PARALLELISM_DIM_0": 1,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_1": 784,
                    "DATA_OUT_0_PARALLELISM_DIM_1": 784,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_2": 1,
                    "DATA_OUT_0_PARALLELISM_DIM_2": 1,
                },
            },
        }

    A relu layer in a mase graph with the following common metadata:

    .. code-block:: shell

        %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%fc1,), kwargs = {inplace: False})

    .. code-block:: JSON

        {
            "common": {
                "mase_type": "module_related_func",
                "mase_op": "relu",
                "results": {
                    "data_out_0": {
                        "type": "float",
                        "precision": [32],
                        "shape": [1, 784],
                        "torch_dtype": torch.float32,
                    }
                },
                "args": {
                    "data_in_0": {
                        "shape": [1, 784],
                        "torch_dtype": torch.float32,
                        "type": "float",
                        "precision": [32],
                    },
                    "inplace": False,
                },
            },
            "software": {},
            "hardware": {},
        }

    The hardware metadata of the relu layer after this pass:

    .. code-block:: JSON

        {
            "common": {...},
            "software": {},
            "hardware": {
                "is_implicit": False,
                "interface": {"inplace": {}},
                "toolchain": "INTERNAL_RTL",
                "module": "fixed_relu",
                "device_id": -1,
                "dependence_files": ["activations/fixed_relu.sv"],
                "verilog_param": {
                    "DATA_IN_0_PRECISION_0": 8,
                    "DATA_IN_0_PRECISION_1": 3,
                    "DATA_IN_0_TENSOR_SIZE_DIM_0": 1,
                    "DATA_IN_0_PARALLELISM_DIM_0": 1,
                    "DATA_IN_0_TENSOR_SIZE_DIM_1": 784,
                    "DATA_IN_0_PARALLELISM_DIM_1": 784,
                    "DATA_IN_0_TENSOR_SIZE_DIM_2": 1,
                    "DATA_IN_0_PARALLELISM_DIM_2": 1,
                    "INPLACE": 0,
                    "DATA_OUT_0_PRECISION_0": 8,
                    "DATA_OUT_0_PRECISION_1": 3,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_0": 1,
                    "DATA_OUT_0_PARALLELISM_DIM_0": 1,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_1": 784,
                    "DATA_OUT_0_PARALLELISM_DIM_1": 784,
                    "DATA_OUT_0_TENSOR_SIZE_DIM_2": 1,
                    "DATA_OUT_0_PARALLELISM_DIM_2": 1,
                },
            },
        }
    """
    # BUG FIX: the documented default is pass_args=None, but .get() is called
    # on it below; fall back to an empty dict so the pass works without args.
    if pass_args is None:
        pass_args = {}

    # Find implicit mase nodes
    for node in graph.nodes:
        node.meta["mase"]["hardware"]["is_implicit"] = False
        node.meta["mase"]["hardware"]["device_id"] = 0
    graph.nodes_in = get_input_nodes(graph.fx_graph)
    graph.nodes_out = get_output_nodes(graph.fx_graph)

    # Add component source
    for node in graph.nodes:
        add_component_source(node)

    # * Fix max parallelism to small value to enable verilator simulation;
    # can be overridden through pass_args["max_parallelism"].
    for node in graph.nodes:
        node.meta["mase"]["hardware"]["max_parallelism"] = pass_args.get(
            "max_parallelism", [4] * 4
        )

    # Add hardware parameters
    for node in graph.nodes:
        add_verilog_param(node)
        add_extra_verilog_param(node, graph)

    # Add graph metadata: collect the dependence files of all explicit nodes
    graph.meta["mase"]["hardware"]["verilog_sources"] = []
    for node in graph.nodes:
        if node.meta["mase"]["hardware"]["is_implicit"]:
            continue
        graph.meta["mase"]["hardware"]["verilog_sources"] += node.meta["mase"][
            "hardware"
        ]["dependence_files"]

    return graph, {}