import logging
import math
import os
import struct
import time
import torch
from chop.passes.graph.utils import vf, v2p, get_module_by_name, init_project
from chop.nn.quantizers import integer_quantizer_for_hw
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
from pathlib import Path
def iceil(x):
    """
    Round x up to the nearest integer and return it as a built-in int.
    """
    rounded_up = math.ceil(x)
    return int(rounded_up)
def clog2(x):
    """
    Return ceil(log2(x)) as a built-in int.

    For a positive Python int this is the smallest n such that 2**n >= x.
    Ints are handled with exact integer arithmetic, because going through
    floating-point log2 rounds for large values (e.g. 2**60 + 1 would be
    reported as 60 instead of 61). Any other numeric type falls back to
    math.log2.

    Raises ValueError when x is not strictly positive.
    """
    if isinstance(x, int):
        if x <= 0:
            # Same exception type (and text) math.log2 raises for this input.
            raise ValueError("math domain error")
        # (x - 1).bit_length() == ceil(log2(x)) for every integer x >= 1.
        return (x - 1).bit_length()
    return int(math.ceil(math.log2(x)))
def _cap(name):
"""
capitalize a string
"""
return str(name).upper()
def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
    """
    Emit single-port ROM hardware components for each parameter
    (Mostly because Vivado does not support string type parameters...)

    Three SystemVerilog modules are written to file_name:
    - <node>_<param>_rom: the memory array, initialised with $readmemh
      from data_name and read through two register stages;
    - <node>_<param>: a thin wrapper around the ROM;
    - <node>_<param>_source: a counter-addressed source that streams the
      ROM contents and keeps data_out_valid high.

    Args:
        node: graph node; node.meta["mase"].parameters supplies the
            "common" (shape, precision) and "hardware" (verilog_param)
            metadata, and node.name is used for the module names.
        param_name: parameter name; dots are replaced by underscores to
            form the Verilog identifier.
        file_name: path of the SystemVerilog file to write.
        data_name: path of the data file, interpolated into $readmemh.

    Raises:
        ValueError: if the parameter has fewer elements than one ROM word.
    """
    # ! TO DO: currently emitting too many parameters
    verilog_param_name = param_name.replace(".", "_")
    cap_name = _cap(verilog_param_name)

    mase_params = node.meta["mase"].parameters
    arg_info = mase_params["common"]["args"][verilog_param_name]
    verilog_params = mase_params["hardware"]["verilog_param"]

    total_size = math.prod(arg_info["shape"])

    # TO DO: change setting parallelism for weight in metadata
    # One ROM word holds PARALLELISM_DIM_0 * PARALLELISM_DIM_1 elements.
    out_size = int(
        verilog_params[f"{cap_name}_PARALLELISM_DIM_0"]
        * verilog_params[f"{cap_name}_PARALLELISM_DIM_1"]
    )
    # Integer division keeps the depth exact regardless of tensor size.
    out_depth = total_size // out_size
    if out_depth < 1:
        raise ValueError(
            f"Parameter {param_name} has {total_size} elements, "
            f"fewer than one ROM word of {out_size} elements."
        )
    out_width = int(arg_info["precision"][0])

    node_param_name = f"{vf(node.name)}_{verilog_param_name}"
    rom_verilog = f"""
// =====================================
// Mase Hardware
// Parameter: {node_param_name}
// {time.strftime('%d/%m/%Y %H:%M:%S')}
// =====================================
`timescale 1 ns / 1 ps
module {node_param_name}_rom #(
parameter DWIDTH = {out_size*out_width},
parameter MEM_SIZE = {out_depth},
parameter AWIDTH = $clog2(MEM_SIZE) + 1
) (
input clk,
input logic [AWIDTH-1:0] addr0,
input ce0,
output logic [DWIDTH-1:0] q0
);
logic [DWIDTH-1:0] ram[0:MEM_SIZE-1];
logic [DWIDTH-1:0] q0_t0;
logic [DWIDTH-1:0] q0_t1;
initial begin
$readmemh("{data_name}", ram);
end
assign q0 = q0_t1;
always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0;
always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0];
endmodule
`timescale 1 ns / 1 ps
module {node_param_name} #(
parameter DATA_WIDTH = 32'd{out_width*out_size},
parameter ADDR_RANGE = 32'd{out_depth},
parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1
) (
input reset,
input clk,
input logic [ADDR_WIDTH - 1:0] address0,
input ce0,
output logic [DATA_WIDTH - 1:0] q0
);
{node_param_name}_rom {node_param_name}_rom_U (
.clk(clk),
.addr0(address0),
.ce0(ce0),
.q0(q0)
);
endmodule
`timescale 1ns / 1ps
module {node_param_name}_source #(
parameter {cap_name}_TENSOR_SIZE_DIM_0 = 32,
parameter {cap_name}_TENSOR_SIZE_DIM_1 = 1,
parameter {cap_name}_PRECISION_0 = 16,
parameter {cap_name}_PRECISION_1 = 3,
parameter {cap_name}_PARALLELISM_DIM_0 = 1,
parameter {cap_name}_PARALLELISM_DIM_1 = 1,
parameter OUT_DEPTH = {cap_name}_TENSOR_SIZE_DIM_0 / {cap_name}_PARALLELISM_DIM_0
) (
input clk,
input rst,
output logic [{cap_name}_PRECISION_0-1:0] data_out [{cap_name}_PARALLELISM_DIM_0 * {cap_name}_PARALLELISM_DIM_1-1:0],
output data_out_valid,
input data_out_ready
);
// 1-bit wider so IN_DEPTH also fits.
localparam COUNTER_WIDTH = $clog2(OUT_DEPTH);
logic [COUNTER_WIDTH:0] counter;
always_ff @(posedge clk)
if (rst) counter <= 0;
else begin
if (data_out_ready) begin
if (counter == OUT_DEPTH - 1) counter <= 0;
else counter <= counter + 1;
end
end
logic ce0;
assign ce0 = 1;
logic [{cap_name}_PRECISION_0*{cap_name}_TENSOR_SIZE_DIM_0-1:0] data_vector;
{node_param_name} #(
.DATA_WIDTH({cap_name}_PRECISION_0 * {cap_name}_TENSOR_SIZE_DIM_0),
.ADDR_RANGE(OUT_DEPTH)
) {node_param_name}_mem (
.clk(clk),
.reset(rst),
.address0(counter),
.ce0(ce0),
.q0(data_vector)
);
// Cocotb/verilator does not support array flattening, so
// we need to manually add some reshaping process.
for (genvar j = 0; j < {cap_name}_PARALLELISM_DIM_0 * {cap_name}_PARALLELISM_DIM_1; j++)
assign data_out[j] = data_vector[{cap_name}_PRECISION_0*j+{cap_name}_PRECISION_0-1:{cap_name}_PRECISION_0*j];
assign data_out_valid = 1;
endmodule
"""
    with open(file_name, "w", encoding="utf-8") as outf:
        outf.write(rom_verilog)
    logger.debug(
        f"ROM module {verilog_param_name} successfully written into {file_name}"
    )
    # Post-condition sanity check on the file that was just written.
    assert os.path.isfile(file_name), "ROM Verilog generation failed."
    # Optional formatting step, currently disabled:
    # os.system(f"verible-verilog-format --inplace {file_name}")
def emit_parameters_in_dat_internal(node, param_name, file_name):
    """
    Emit initialised data for the ROM block.

    The file holds one line per ROM word. Each word packs
    PARALLELISM_DIM_0 * PARALLELISM_DIM_1 consecutive elements, with the
    first element of the group in the least-significant bits, and is
    written as lower-case hex without a "0x" prefix and without zero
    padding.

    Each element is passed through integer_quantizer_for_hw, scaled by
    2**frac_width and wrapped modulo 2**width, so negative values get
    their two's-complement bit pattern.

    Args:
        node: graph node carrying the metadata in node.meta["mase"].
        param_name: parameter name as known to the module; dots are
            replaced by underscores for the metadata lookups.
        file_name: path of the data file to write.

    Raises:
        AssertionError: if the parameter type is not "fixed".
    """
    verilog_param_name = param_name.replace(".", "_")
    cap_name = _cap(verilog_param_name)

    mase_meta = node.meta["mase"]
    arg_info = mase_meta.parameters["common"]["args"][verilog_param_name]
    hardware = mase_meta.parameters["hardware"]
    verilog_params = hardware["verilog_param"]

    total_size = math.prod(arg_info["shape"])

    # TO DO: change setting parallelism for weight in metadata
    out_size = int(
        verilog_params[f"{cap_name}_PARALLELISM_DIM_0"]
        * verilog_params[f"{cap_name}_PARALLELISM_DIM_1"]
    )
    out_depth = total_size // out_size

    param_data = mase_meta.module.get_parameter(param_name).data
    if hardware["interface"][verilog_param_name]["transpose"]:
        # View the data as (out size, in depth, in size) and swap the first
        # two axes before flattening.
        param_data = torch.reshape(
            param_data,
            (
                verilog_params["DATA_OUT_0_SIZE"],
                verilog_params["DATA_IN_0_DEPTH"],
                verilog_params["DATA_IN_0_SIZE"],
            ),
        )
        param_data = torch.transpose(param_data, 0, 1)
    param_data = torch.flatten(param_data).tolist()

    if arg_info["type"] != "fixed":
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError("Emitting non-fixed parameters is not supported.")

    width = arg_info["precision"][0]
    frac_width = arg_info["precision"][1]
    scale = 2**frac_width
    thresh = 2**width

    hex_lines = []
    for word_index in range(out_depth):
        base = word_index * out_size
        word = 0
        # Walk the group from its last element to its first so that the
        # first element ends up in the least-significant bits of the word.
        for offset in reversed(range(out_size)):
            quantized = integer_quantizer_for_hw(
                torch.tensor(param_data[base + offset]), width, frac_width
            ).item()
            word = (word << width) | (int(quantized * scale) % thresh)
        hex_lines.append(format(word, "x"))

    data_buff = "".join(f"{line}\n" for line in hex_lines)
    with open(file_name, "w", encoding="utf-8") as outf:
        outf.write(data_buff)
    logger.debug(f"Init data {param_name} successfully written into {file_name}")
    assert os.path.isfile(file_name), "ROM data generation failed."
def emit_parameters_in_dat_hls(node, param_name, file_name):
    """
    Emit initialised data for the ROM block of an HLS node.

    The file holds one element per line, written as lower-case hex without
    a "0x" prefix and without zero padding.

    - "fixed": each element is passed through integer_quantizer_for_hw,
      scaled by 2**frac_width and wrapped modulo 2**width.
    - "float": each element is written as the little-endian bit pattern of
      its 32-bit float encoding.

    Args:
        node: graph node carrying the metadata in node.meta["mase"].
        param_name: parameter name, used as-is for the metadata lookups.
        file_name: path of the data file to write.

    Raises:
        AssertionError: if the type is neither "fixed" nor "float", or if
            a float parameter is not 32 bits wide.
    """
    arg_info = node.meta["mase"].parameters["common"]["args"][param_name]
    # One element per ROM word, so the depth equals the element count.
    out_depth = math.prod(arg_info["shape"])

    param_data = node.meta["mase"].module.get_parameter(param_name).data
    param_data = torch.flatten(param_data).tolist()

    param_type = arg_info["type"]
    hex_lines = []
    if param_type == "fixed":
        width = arg_info["precision"][0]
        frac_width = arg_info["precision"][1]
        scale = 2**frac_width
        thresh = 2**width
        # Indexing by the metadata depth (rather than iterating the list)
        # keeps a shape/data mismatch visible as an IndexError.
        for index in range(out_depth):
            quantized = integer_quantizer_for_hw(
                torch.tensor(param_data[index]), width, frac_width
            ).item()
            hex_lines.append(format(int(quantized * scale) % thresh, "x"))
    elif param_type == "float":
        width = arg_info["precision"][0]
        if width != 32:
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError("Only float32 is supported for now.")
        for index in range(out_depth):
            # Reinterpret the float32 encoding as an unsigned 32-bit integer.
            # Double will then be:
            # struct.unpack('<Q', struct.pack('<d', value))[0]
            raw_bits = struct.unpack("<I", struct.pack("<f", param_data[index]))[0]
            hex_lines.append(format(raw_bits, "x"))
    else:
        raise AssertionError("Emitting unknown type of parameters is not supported.")

    data_buff = "".join(f"{line}\n" for line in hex_lines)
    with open(file_name, "w", encoding="utf-8") as outf:
        outf.write(data_buff)
    logger.debug(f"Init data {param_name} successfully written into {file_name}")
    assert os.path.isfile(file_name), "ROM data generation failed."
def emit_bram_handshake(node, rtl_dir):
    """
    Enumerate input parameters of the internal node and emit a ROM block
    with handshake interface for each parameter

    For every named parameter of the node's module, two files are written
    into rtl_dir: <node>_<param>_source.sv (the ROM modules) and
    <node>_<param>_rom.dat (the initial contents).

    Args:
        node: graph node carrying the module and metadata in
            node.meta["mase"].
        rtl_dir: directory that receives the generated files.

    Raises:
        AssertionError: if a parameter's storage is not "BRAM".
    """
    node_name = vf(node.name)
    for param_name, _ in node.meta["mase"].module.named_parameters():
        param_verilog_name = param_name.replace(".", "_")
        # Looked up per parameter so nodes without parameters never touch
        # the interface metadata.
        storage = node.meta["mase"].parameters["hardware"]["interface"][
            param_verilog_name
        ]["storage"]
        if storage != "BRAM":
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError(
                "Emitting parameters in non-BRAM hardware is not supported."
            )

        logger.debug(
            f"Emitting DAT file for node: {node_name}, parameter: {param_verilog_name}"
        )
        verilog_name = os.path.join(
            rtl_dir, f"{node_name}_{param_verilog_name}_source.sv"
        )
        data_name = os.path.join(rtl_dir, f"{node_name}_{param_verilog_name}_rom.dat")
        emit_parameters_in_mem_internal(node, param_name, verilog_name, data_name)
        emit_parameters_in_dat_internal(node, param_name, data_name)
def emit_parameters_in_mem_hls(node, param_name, file_name, data_name):
    """
    Emit single-port ROM hardware components for each parameter
    (Mostly because Vivado does not support string type parameters...)

    Two SystemVerilog modules are written to file_name:
    - <node>_<param>_rom: the memory array, initialised with $readmemh
      from data_name and read through two register stages;
    - <node>_<param>_source: a thin wrapper around the ROM.

    Args:
        node: graph node; node.meta["mase"].parameters supplies the shape
            and the <PARAM>_WIDTH verilog parameter, and node.name is used
            for the module names.
        param_name: parameter name, used as-is in lookups and identifiers.
        file_name: path of the SystemVerilog file to write.
        data_name: path of the data file, interpolated into $readmemh.

    Raises:
        ValueError: if the parameter has no elements.
    """
    # The depth of parameters matches with the input depth
    total_size = math.prod(
        node.meta["mase"].parameters["common"]["args"][param_name]["shape"]
    )
    out_depth = total_size
    if out_depth < 1:
        raise ValueError(
            f"Cannot emit a ROM for empty parameter {param_name} "
            f"(total size = {total_size})."
        )
    # Each ROM word stores exactly one element, so no partitioning is needed.
    out_size = 1

    # Assume the first index is the total width
    out_width = node.meta["mase"].parameters["hardware"]["verilog_param"][
        f"{param_name.upper()}_WIDTH"
    ]

    node_name = vf(node.name)
    node_param_name = f"{node_name}_{param_name}"
    time_to_emit = time.strftime("%d/%m/%Y %H:%M:%S")
    rom_verilog = f"""
// =====================================
// Mase Hardware
// Parameter: {node_param_name}
// {time_to_emit}
// =====================================
`timescale 1 ns / 1 ps
module {node_param_name}_rom #(
parameter DWIDTH = {out_size*out_width},
parameter MEM_SIZE = {out_depth},
parameter AWIDTH = $clog2(MEM_SIZE) + 1
) (
input clk,
input logic [AWIDTH-1:0] addr0,
input ce0,
output logic [DWIDTH-1:0] q0
);
logic [DWIDTH-1:0] ram[0:MEM_SIZE-1];
logic [DWIDTH-1:0] q0_t0;
logic [DWIDTH-1:0] q0_t1;
initial begin
$readmemh("{data_name}", ram);
end
assign q0 = q0_t1;
always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0;
always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0];
endmodule
`timescale 1 ns / 1 ps
module {node_param_name}_source #(
parameter DATA_WIDTH = 32'd{out_width*out_size},
parameter ADDR_RANGE = 32'd{out_depth},
parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1
) (
input reset,
input clk,
input logic [ADDR_WIDTH - 1:0] address0,
input ce0,
output logic [DATA_WIDTH - 1:0] q0
);
{node_param_name}_rom {node_param_name}_rom_U (
.clk(clk),
.addr0(address0),
.ce0(ce0),
.q0(q0)
);
endmodule
"""
    with open(file_name, "w", encoding="utf-8") as outf:
        outf.write(rom_verilog)
    logger.debug(f"ROM module {param_name} successfully written into {file_name}")
    # Post-condition sanity check on the file that was just written.
    assert os.path.isfile(file_name), "ROM Verilog generation failed."
    # Optional formatting step, currently disabled:
    # os.system(f"verible-verilog-format --inplace {file_name}")
def emit_bram_hls(node, rtl_dir):
    """
    Enumerate input parameters of the hls node and emit a ROM block
    with handshake interface for each parameter

    For every named parameter of the node's module, two files are written
    into rtl_dir: <node>_<param>.sv (the ROM modules) and
    <node>_<param>_rom.dat (the initial contents).

    Args:
        node: graph node carrying the module and metadata in
            node.meta["mase"].
        rtl_dir: directory that receives the generated files.

    Raises:
        AssertionError: if a parameter's storage is not "BRAM".
    """
    node_name = vf(node.name)
    for param_name, _ in node.meta["mase"].module.named_parameters():
        # Looked up per parameter so nodes without parameters never touch
        # the interface metadata.
        storage = node.meta["mase"].parameters["hardware"]["interface"][param_name][
            "storage"
        ]
        if storage != "BRAM":
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError(
                "Emitting parameters in non-BRAM hardware is not supported."
            )

        # Verilog code of the ROM has been emitted using mlir passes
        verilog_name = os.path.join(rtl_dir, f"{node_name}_{param_name}.sv")
        data_name = os.path.join(rtl_dir, f"{node_name}_{param_name}_rom.dat")
        emit_parameters_in_mem_hls(node, param_name, verilog_name, data_name)
        emit_parameters_in_dat_hls(node, param_name, data_name)