import logging
from typing import Tuple, Dict
import math
import os
import time
from multiprocessing import Process, Queue
import torch.fx as fx
from chop.passes.graph.utils import vf, v2p, init_project
import mase_components.helper.generate_memory as gen_lut
import torch.nn as nn
logger = logging.getLogger(__name__)
from .util import get_verilog_parameters
from pathlib import Path
# =============================================================================
# Utilities
# =============================================================================
def _remove_last_comma(string):
    """
    Return `string` truncated just before its last comma.

    The last comma and everything after it (for example a trailing newline)
    are dropped. A string that contains no comma is returned unchanged.
    """
    last_comma = string.rfind(",")
    # rfind returns -1 when there is no comma; slicing with -1 would silently
    # chop the last character off instead of leaving the string untouched.
    if last_comma == -1:
        return string
    return string[:last_comma]
def _cap(name):
    """
    Return the upper-cased string form of `name`
    """
    text = str(name)
    return text.upper()
def get_input_name(from_node, to_node):
    """
    Find the name of the `to_node` argument that comes from `from_node`.

    Returns "data_in_0" when both nodes compare equal. Otherwise returns the
    key of the first entry in to_node.meta["mase"].parameters["common"]["args"]
    whose "from" field equals `from_node`.

    Raises AssertionError when no argument of `to_node` comes from `from_node`.
    """
    if from_node == to_node:
        return "data_in_0"
    for key, val in to_node.meta["mase"].parameters["common"]["args"].items():
        # Constant function arguments are stored as plain values rather than
        # dicts, so they cannot carry a "from" entry and are skipped.
        if isinstance(val, dict) and "from" in val and val["from"] == from_node:
            return key
    # Raised explicitly (rather than `assert False`) so the failure is still
    # reported when Python runs with assertions disabled (-O).
    raise AssertionError(
        f"Cannot find edge from {from_node.name} to {to_node.name}"
    )
def param_needs_signals(node, param, value, qualifier="data_in"):
    """
    Decide whether internal signals must be emitted for a node argument/result.

    Returns False for constant function arguments (any `value` that is not a
    dict). Returns True when `qualifier` occurs in `param` (any data_in or
    data_out). For every other parameter, returns True only when its storage
    interface in the node's hardware metadata is "BRAM".
    """
    # isinstance (rather than an exact type comparison) matches the other
    # dict checks in this file and also accepts dict subclasses.
    if not isinstance(value, dict):
        # Constant function arguments don't have precision/shape info
        # TODO: change common metadata so constant arguments are passed with
        # is_constant flag
        return False
    if qualifier in param:
        return True
    interface = node.meta["mase"].parameters["hardware"]["interface"]
    return interface[param]["storage"] == "BRAM"
def is_real_input_arg(node, arg_idx):
    """
    Tell whether argument `arg_idx` of `node` is a placeholder fx node.
    """
    # module parameter arguments are appended after fx args
    if arg_idx >= len(node.args):
        return False
    candidate = node.args[arg_idx]
    # Drop None arguments
    if not isinstance(candidate, fx.Node):
        return False
    # Drop arguments that are inputs to this node, but not the whole graph
    return candidate.op == "placeholder"
# =============================================================================
# Verilog parameters
# =============================================================================
class VerilogParameterEmitter:
    """
    Emits the parameter list of the top-level Verilog module.
    """

    def __init__(self, graph):
        self.graph = graph

    def emit(self, graph, parameter_map) -> str:
        """
        Emit parameters at the top-level for the top-level module

        Args:
            graph: not used by this emitter; kept so existing callers that
                pass it keep working.
            parameter_map: mapping of Verilog parameter name to its value.

        Returns:
            The parameter declarations as a string to be embedded in the
            Verilog file: one `parameter <name> = <value>` per entry,
            separated by a comma and a newline, with no trailing comma.
            An empty map yields an empty string.
        """
        # Joining the declarations places the separators only between
        # entries, so no trailing comma has to be stripped afterwards.
        return ",\n".join(
            f""" parameter {key} = {value}""" for key, value in parameter_map.items()
        )
# =============================================================================
# Verilog interface
# =============================================================================
class VerilogInterfaceEmitter:
    """
    Emits the data/valid/ready port declarations of the top-level module.
    """

    def __init__(self, graph):
        self.graph = graph

    @staticmethod
    def _parallelism_expr(parameter_map, prefix):
        """
        Return every parameter name in `parameter_map` that starts with
        `prefix`, joined with '*'.
        """
        return "*".join(param for param in parameter_map if param.startswith(prefix))

    def emit(self, graph, parameter_map):
        """
        Emit interface signal declarations for the top-level module

        Input ports are numbered data_in_0, data_in_1, ... across all real
        input arguments of the graph's input nodes; output ports are numbered
        data_out_0, ... across all "data_out" results of the graph's output
        nodes. The nodes are read from the graph given at construction time;
        the `graph` argument is not used.

        Returns the declarations as a string without a trailing comma.
        """
        interface = ""
        # TODO: here we just enumerate the inputs of the input nodes - which may be
        # order insensitive and require manual connection when adding the graph to
        # a system.
        i = 0
        for node in self.graph.nodes_in:
            for arg_idx, arg in enumerate(
                node.meta["mase"].parameters["common"]["args"]
            ):
                if not is_real_input_arg(node, arg_idx):
                    continue
                arg_name = _cap(arg)
                parallelism = self._parallelism_expr(
                    parameter_map, f"{arg_name}_PARALLELISM_DIM"
                )
                interface += f"""
input [{arg_name}_PRECISION_0-1:0] data_in_{i} [{parallelism}-1:0],
input data_in_{i}_valid,
output data_in_{i}_ready,"""
                i += 1
        i = 0
        for node in self.graph.nodes_out:
            for result in node.meta["mase"].parameters["common"]["results"]:
                if "data_out" not in result:
                    continue
                result_name = _cap(result)
                parallelism = self._parallelism_expr(
                    parameter_map, f"{result_name}_PARALLELISM_DIM"
                )
                interface += f"""
output [{result_name}_PRECISION_0-1:0] data_out_{i} [{parallelism}-1:0],
output data_out_{i}_valid,
input data_out_{i}_ready,"""
                i += 1
        # TODO: emit off-chip parameter interface
        return _remove_last_comma(interface)
# =============================================================================
# Verilog signals
# =============================================================================
class VerilogSignalEmitter:
    """
    Emits the internal signal declarations of the top-level module, one group
    per non-implicit node of the graph.
    """

    def __init__(self, graph):
        self.graph = graph

    def _emit_signals_top_internal(self, node, parameter_map):
        """
        Declare data/valid/ready signals for a node with an INTERNAL toolchain.

        Signals are emitted for every dict-valued argument (result) whose name
        contains "data_in" ("data_out") or whose hardware interface storage is
        "BRAM".
        """
        signals = ""
        node_name = vf(node.name)
        # Input signals
        for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
            if not isinstance(arg_info, dict):
                continue
            # Skip off-chip parameters as they will be directly connected to the top level
            if (
                "data_in" in arg
                or node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"]
                == "BRAM"
            ):
                arg_name = v2p(arg)
                parallelism_params = [
                    param
                    for param in parameter_map
                    if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param
                ]
                # Getitem argument always get mapped to port 0 irrespective of
                # actual argument index
                if node.meta["mase"]["common"]["mase_op"] == "getitem":
                    arg = "data_in_0"
                signals += f"""
logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0];
logic {node_name}_{arg}_valid;
logic {node_name}_{arg}_ready;"""
        # Output signals
        for result, result_info in (
            node.meta["mase"].parameters["common"]["results"].items()
        ):
            if not isinstance(result_info, dict):
                continue
            # Skip off-chip parameters as they will be directly connected to the top level
            if (
                "data_out" in result
                or node.meta["mase"].parameters["hardware"]["interface"][result][
                    "storage"
                ]
                == "BRAM"
            ):
                result_name = v2p(result)
                parallelism_params = [
                    param
                    for param in parameter_map
                    if f"{node_name}_{result_name}_PARALLELISM_DIM" in param
                ]
                signals += f"""
logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_{result} [{'*'.join(parallelism_params)}-1:0];
logic {node_name}_{result}_valid;
logic {node_name}_{result}_ready;"""
        return signals

    def _emit_signals_top_hls(self, node, parameter_map):
        """
        Declare the control and memory-port signals for a node with the HLS
        toolchain: start/done/idle/ready/ce, then q0/address0/ce0 per argument
        and d0/address0/ce0/we0 per result that needs signals.
        """
        node_name = vf(node.name)
        # Control signals for HLS component
        signals = f"""
logic {node_name}_start;
logic {node_name}_done;
logic {node_name}_idle;
logic {node_name}_ready;
logic {node_name}_ce;"""
        # Input signals
        for key, value in node.meta["mase"].parameters["common"]["args"].items():
            # No internal signals if the memory is stored off chip
            if not param_needs_signals(node, key, value, qualifier="data_in"):
                continue
            cap_key = v2p(key)
            size = math.prod(value["shape"])
            # NOTE(review): argument keys elsewhere in this file look like
            # "data_in_0", so the depth-scaled branch may never be taken -
            # confirm whether an exact "data_in" key is ever produced.
            if key != "data_in":
                a_width = math.ceil(math.log2(size))
            else:
                depth = parameter_map[f"{node_name}_{cap_key}_DEPTH"]
                a_width = math.ceil(math.log2(depth * size))
            signals += f"""
logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_q0;
logic [{a_width}-1:0] {node_name}_{key}_address0;
logic {node_name}_{key}_ce0;"""
        # Output signals
        for key, value in node.meta["mase"].parameters["common"]["results"].items():
            # No internal signals if the memory is stored off chip
            if not param_needs_signals(node, key, value, qualifier="data_out"):
                continue
            cap_key = v2p(key)
            size = math.prod(value["shape"])
            a_width = math.ceil(math.log2(size))
            signals += f"""
logic [{node_name}_{cap_key}_PRECISION_0-1:0] {node_name}_{key}_d0;
logic [{a_width}-1:0] {node_name}_{key}_address0;
logic {node_name}_{key}_ce0;
logic {node_name}_{key}_we0;"""
        return signals

    def emit(self, graph, parameter_map):
        """
        Emit internal signal declarations for the top-level module

        Implicit nodes are skipped. Every other node gets a comment header
        followed by the declarations for its toolchain.

        Raises AssertionError when a node's toolchain neither contains
        "INTERNAL" nor equals "HLS".
        """
        signals = ""
        for node in graph.fx_graph.nodes:
            hardware = node.meta["mase"].parameters["hardware"]
            if hardware["is_implicit"]:
                continue
            toolchain = hardware["toolchain"]
            if "INTERNAL" in toolchain:
                emit_node_signals = self._emit_signals_top_internal
            elif toolchain == "HLS":
                emit_node_signals = self._emit_signals_top_hls
            else:
                # Raised explicitly (rather than `assert False`) so an unknown
                # toolchain is still rejected when assertions are disabled (-O).
                raise AssertionError("Unknown node toolchain for signal declarations.")
            node_name = vf(node.name)
            signals += f"""
// --------------------------
// {node_name} signals
// --------------------------"""
            signals += emit_node_signals(node, parameter_map)
        return signals
# =============================================================================
# Verilog components (INTERNAL)
# =============================================================================
class VerilogInternalComponentEmitter:
    """
    Emits the instantiation of a node built with an INTERNAL toolchain, plus
    the source modules that drive its non-data arguments.
    """

    def __init__(self, graph):
        self.graph = graph

    def _emit_module_parameters_top_internal(self, key, value, node, parameter_map):
        """
        Instantiate the `<node>_<key>_source` module that drives argument `key`.

        Only the node's Verilog parameters whose name contains the upper-cased
        key followed by an underscore are forwarded to the source module.
        """
        inst_prefix = vf(node.name)
        source_module = f"{inst_prefix}_{key}_source"
        source_inst = f"{source_module}_0"
        wanted = f"{_cap(key)}_"
        verilog_params = node.meta["mase"].parameters["hardware"]["verilog_param"]
        bindings = "".join(
            f" .{name}({inst_prefix}_{name}),\n"
            for name in verilog_params
            if wanted in name
        )
        bindings = _remove_last_comma(bindings)
        return f"""
{source_module} #(
{bindings}
) {source_inst} (
.clk(clk),
.rst(rst),
.data_out({inst_prefix}_{key}),
.data_out_ready({inst_prefix}_{key}_ready),
.data_out_valid({inst_prefix}_{key}_valid)
);
"""

    def _emit_getitem_signals(self, node):
        """
        Getitem nodes have arg list like (None, None, None, Arg, None, None)
        where the meaningful arg is at an arbitrary index, but always maps to
        data_in_0 interface of the hardware
        """
        inst_prefix = vf(node.name)
        return f"""
.data_in_0 ({inst_prefix}_data_in_0),
.data_in_0_valid ({inst_prefix}_data_in_0_valid),
.data_in_0_ready ({inst_prefix}_data_in_0_ready),
.data_out_0 ({inst_prefix}_data_out_0),
.data_out_0_valid ({inst_prefix}_data_out_0_valid),
.data_out_0_ready ({inst_prefix}_data_out_0_ready),
"""

    def emit(self, node, parameter_map):
        """
        Emit the module instance for `node`, followed by one source instance
        per dict-valued argument whose name does not contain "data_in".
        """
        inst_prefix = vf(node.name)
        module_name = node.meta["mase"].parameters["hardware"]["module"]

        # Component instantiation parameters, each annotated with its value
        param_lines = []
        verilog_params = node.meta["mase"].parameters["hardware"]["verilog_param"]
        for name, setting in verilog_params.items():
            if setting is None:
                continue
            resolved = parameter_map[f"{inst_prefix}_{name}"]
            param_lines.append(
                f""" .{name}({inst_prefix}_{name}), // = {resolved}\n"""
            )
        param_block = _remove_last_comma("".join(param_lines))

        if node.meta["mase"]["common"]["mase_op"] == "getitem":
            # An arbitrary argument index of a getitem node is always mapped
            # to the data_in_0 interface of the hardware
            port_block = self._emit_getitem_signals(node)
        else:
            common = node.meta["mase"].parameters["common"]
            # Inputs first ("inplace" and non-dict arguments have no port),
            # then every result
            port_names = [
                name
                for name, info in common["args"].items()
                if "inplace" not in name and isinstance(info, dict)
            ]
            port_names.extend(common["results"])
            port_chunks = []
            for name in port_names:
                port_chunks.append(
                    f"""
.{name}({inst_prefix}_{name}),
.{name}_valid({inst_prefix}_{name}_valid),
.{name}_ready({inst_prefix}_{name}_ready),
"""
                )
            port_block = "".join(port_chunks)
        # Remove final comma in signal list
        port_block = _remove_last_comma(port_block)

        # Combine component instantiation
        emitted = f"""
// {inst_prefix}
{module_name} #(
{param_block}
) {inst_prefix}_inst (
.clk(clk),
.rst(rst),
{port_block}
);
"""
        # Module parameter instances (e.g. weights and biases)
        for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
            if "data_in" in arg or not isinstance(arg_info, dict):
                continue
            emitted += self._emit_module_parameters_top_internal(
                arg, arg_info, node, parameter_map
            )
        return emitted
# =============================================================================
# Verilog components (HLS)
# =============================================================================
class VerilogHLSComponentEmitter:
    """
    Emits the instantiation of a node built with the HLS toolchain, plus the
    source modules for the arguments/results that need internal signals.
    """

    def __init__(self, graph):
        self.graph = graph

    def _emit_module_parameters_top_hls(self, key, value, node, parameter_map):
        """
        Instantiate the `<node>_<key>_source` memory module for `key`; its
        address width is ceil(log2(product of value["shape"])).
        """
        inst_prefix = vf(node.name)
        param_key = v2p(key)
        source_module = f"{inst_prefix}_{key}_source"
        source_inst = f"{inst_prefix}_{key}_0"
        element_count = math.prod(value["shape"])
        addr_bits = math.ceil(math.log2(element_count))
        return f"""
{source_module} #(
.DATA_WIDTH({inst_prefix}_{param_key}_PRECISION_0),
.ADDR_RANGE({inst_prefix}_{param_key}_TENSOR_SIZE_0),
.ADDR_WIDTH({addr_bits})
) {source_inst} (
.clk(clk),
.reset(rst),
.address0({inst_prefix}_{key}_address0),
.ce0({inst_prefix}_{key}_ce0),
.q0({inst_prefix}_{key}_q0)
);
"""

    def emit(self, node, parameter_map):
        """
        Emit the HLS kernel instance for `node`, followed by the source
        instances of the arguments and results that need internal signals.
        """
        inst_prefix = vf(node.name)
        module_name = node.meta["mase"].parameters["hardware"]["module"]

        # Kernel instance ports: read ports for arguments, write ports for results
        port_chunks = []
        for port in node.meta["mase"].parameters["common"]["args"]:
            port_chunks.append(
                f"""
.{port}_address0({inst_prefix}_{port}_address0),
.{port}_ce0({inst_prefix}_{port}_ce0),
.{port}_q0({inst_prefix}_{port}_q0),
"""
            )
        for port in node.meta["mase"].parameters["common"]["results"]:
            port_chunks.append(
                f"""
.{port}_address0({inst_prefix}_{port}_address0),
.{port}_ce0({inst_prefix}_{port}_ce0),
.{port}_we0({inst_prefix}_{port}_we0),
.{port}_d0({inst_prefix}_{port}_d0),
"""
            )
        port_block = _remove_last_comma("".join(port_chunks))

        emitted = f"""
// {inst_prefix}
{module_name} #(
) {inst_prefix}_inst (
.ap_clk(clk),
.ap_rst(rst),
.ap_start({inst_prefix}_start),
.ap_idle({inst_prefix}_idle),
.ap_ready({inst_prefix}_ready),
.ap_done({inst_prefix}_done),
.ap_ce({inst_prefix}_ce),
{port_block}
);
"""
        # Parameter instances: arguments first, then results. The instance is
        # skipped if the memory is stored off chip.
        for section, qualifier in (("args", "data_in"), ("results", "data_out")):
            entries = node.meta["mase"].parameters["common"][section]
            for key, value in entries.items():
                if param_needs_signals(node, key, value, qualifier=qualifier):
                    emitted += self._emit_module_parameters_top_hls(
                        key, value, node, parameter_map
                    )
        return emitted
# =============================================================================
# Verilog components
# =============================================================================
class VerilogComponentEmitter:
    """
    Emits the component instantiations of the top-level module by dispatching
    each non-implicit node to the emitter for its toolchain.
    """

    def __init__(self, graph):
        self.graph = graph
        self.internal_emitter = VerilogInternalComponentEmitter(graph)
        self.hls_emitter = VerilogHLSComponentEmitter(graph)

    def emit(self, graph, parameter_map):
        """
        Emit component declarations for the top-level module

        Implicit nodes are skipped. Nodes whose toolchain contains "INTERNAL"
        go to the internal emitter and nodes whose toolchain is "HLS" go to
        the HLS emitter.

        Raises AssertionError for any other toolchain.
        """
        components = """
// --------------------------
// Component instantiation
// --------------------------
"""
        for node in graph.fx_graph.nodes:
            hardware = node.meta["mase"].parameters["hardware"]
            if hardware["is_implicit"]:
                continue
            toolchain = hardware["toolchain"]
            if "INTERNAL" in toolchain:
                components += self.internal_emitter.emit(node, parameter_map)
            elif toolchain == "HLS":
                components += self.hls_emitter.emit(node, parameter_map)
            else:
                # Raised explicitly (rather than `assert False`) so an unknown
                # toolchain is still rejected when assertions are disabled (-O).
                raise AssertionError(
                    "Unknown node toolchain for component instantiation."
                )
        return components
# =============================================================================
# Verilog wires
# =============================================================================
class VerilogWireEmitter:
    """
    Emits the `assign` statements that connect the top-level ports to the
    input/output nodes and connect the nodes to each other.
    """

    def __init__(self, graph, parameter_map):
        self.graph = graph
        self.parameter_map = parameter_map
        # Kept separately so emit() can rebuild its output from scratch
        self._header = """
// --------------------------
// Interconnections
// --------------------------
"""
        self.wires = self._header

    def _emit_top_wires(self):
        """
        Connect the top-level data_in_<i> / data_out_<i> ports to the signals
        of the graph's input and output nodes.
        """
        nodes_in = self.graph.nodes_in
        nodes_out = self.graph.nodes_out
        # ============================================================
        # Top level wires
        # ============================================================
        wires = ""
        # TODO: here we just enumerate the inputs of the input nodes - which may be
        # order insensitive and require manual connection when adding the graph to
        # a system.
        i = 0
        for node in nodes_in:
            node_name = vf(node.name)
            for arg_idx, arg in enumerate(
                node.meta["mase"].parameters["common"]["args"].keys()
            ):
                if is_real_input_arg(node, arg_idx):
                    wires += f"""
assign data_in_{i}_ready = {node_name}_{arg}_ready;
assign {node_name}_{arg}_valid = data_in_{i}_valid;
assign {node_name}_{arg} = data_in_{i};
"""
                    i += 1
        i = 0
        for node in nodes_out:
            node_name = vf(node.name)
            for result in node.meta["mase"].parameters["common"]["results"].keys():
                if "data_out" in result:
                    wires += f"""
assign data_out_{i}_valid = {node_name}_{result}_valid;
assign {node_name}_{result}_ready = data_out_{i}_ready;
assign data_out_{i} = {node_name}_{result};
"""
                    i += 1
        # TODO: emit off-chip parameter interface
        return wires

    def _emit_getitem_wires(self, node):
        """
        Getitem nodes may receive an output from an arbitrary index of the parent node,
        which is always driven to port 0 of the getitem node
        """
        from_name = vf(node.args[0].name)
        to_name = vf(node.name)
        # Second fx argument selects which data_out_<select> of the parent is used
        select = node.args[1]
        return f"""
assign {from_name}_data_out_{select}_ready = {to_name}_data_in_0_ready;
assign {to_name}_data_in_0_valid = {from_name}_data_out_{select}_valid;
assign {to_name}_data_in_0 = {from_name}_data_out_{select};
"""

    def _emit_node2node_wires(self):
        """
        Connect data_out_0 of each producer node to data_in_<i> of its
        consumer, where i is the producer's position in the consumer's
        all_input_nodes. Implicit nodes and graph input nodes are skipped.
        """
        nodes_in = self.graph.nodes_in
        wires = ""
        for node in self.graph.fx_graph.nodes:
            if (
                # Skip implicit nodes
                node.meta["mase"].parameters["hardware"]["is_implicit"]
                # Input nodes were already connected by the previous process
                or node in nodes_in
            ):
                continue
            # Getitem nodes are handled separately
            if node.meta["mase"]["common"]["mase_op"] == "getitem":
                wires += self._emit_getitem_wires(node)
                continue
            to_name = vf(node.name)
            for i, node_in in enumerate(node.all_input_nodes):
                from_name = vf(node_in.name)
                wires += f"""
assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready;
assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid;
assign {to_name}_data_in_{i} = {from_name}_data_out_0;
"""
        return wires

    def emit(self):
        """
        Emit internal signal connections for the top-level module
        This includes two interconnection types:
        1. Type casting between inputs and outputs
        2. Interface casting between inputs and outputs

        The result is also stored in `self.wires`. It is rebuilt from the
        header on every call, so calling emit() repeatedly returns the same
        text instead of accumulating duplicated assignments.
        """
        self.wires = (
            self._header + self._emit_top_wires() + self._emit_node2node_wires()
        )
        return self.wires
# =============================================================================
# Emit Verilog
# =============================================================================
class VerilogEmitter:
    """
    Assembles the complete top-level Verilog module from the parameter,
    interface, signal, component and wire emitters.
    """

    def __init__(self, graph):
        self.graph = graph
        self.parameter_map = get_verilog_parameters(graph)

    def emit(self, graph, top_name):
        """
        Return the source text of the top-level module named `top_name`,
        stamped with the current local time.
        """
        pmap = self.parameter_map
        # Sections are generated in the order they appear in the module body
        sections = (
            VerilogParameterEmitter(graph).emit(graph, pmap),
            VerilogInterfaceEmitter(graph).emit(graph, pmap),
            VerilogSignalEmitter(graph).emit(graph, pmap),
            VerilogComponentEmitter(graph).emit(graph, pmap),
            VerilogWireEmitter(graph, pmap).emit(),
        )
        timestamp = time.strftime("%d/%m/%Y %H:%M:%S")
        template = """
// =====================================
// Mase Hardware
// Model: {}
// {}
// =====================================
`timescale 1ns/1ps
module {} #(
{}
) (
input clk,
input rst,
{}
);
{}
{}
{}
endmodule
"""
        # Placeholders are positional: model name, timestamp, module name,
        # then the five generated sections
        return template.format(top_name, timestamp, top_name, *sections)