# A pass to convert a MASE graph to ONNX and annotate the relevant layers with sparsity
# and quantisation information. The code here is derived from Zhewen's SparseCNN codebase:
# https://github.com/Yu-Zhewen/sparseCNN/blob/main/onnx_sparsity_attribute.py
import logging
from collections import OrderedDict
import torch.nn as nn
import torch
import onnx
import toml
from pathlib import Path
# Housekeeping -------------------------------------------------------------------------
# Name the logger after the module's dotted import path (not its filesystem path)
# so it slots into the standard logging hierarchy and inherits package-level config.
logger = logging.getLogger(__name__)
# https://github.com/Xilinx/finn-base/blob/dev/src/finn/custom_op/base.py
def _set_nodeattr(node, name: str, value):
    """Append an ONNX attribute called *name*, holding *value*, to *node*.

    The attribute is built with ``onnx.helper.make_attribute`` and appended to
    ``node.attribute``. An attribute that already carries the same name is
    neither checked for nor replaced, so calling this twice with the same
    *name* leaves two entries on the node.
    """
    node.attribute.append(onnx.helper.make_attribute(name, value))
def _annotate_quantisation(model, ww: int, dw: int, aw: int, bfp: bool):
    """Attach bit-width / number-format attributes to every node of an ONNX graph.

    Every node in ``model.graph.node`` receives a ``data_width`` attribute.
    ``Conv`` and ``Gemm`` nodes additionally receive ``weight_width``,
    ``acc_width`` and ``block_floating_point``, which are appended before
    ``data_width``. Nodes are modified in place.

    Args:
        model: ONNX model whose graph nodes are annotated.
        ww: value stored as ``weight_width`` (Conv/Gemm nodes only).
        dw: value stored as ``data_width`` (all nodes).
        aw: value stored as ``acc_width`` (Conv/Gemm nodes only).
        bfp: value stored as ``block_floating_point`` (Conv/Gemm nodes only).
    """
    weighted_ops = ("Conv", "Gemm")
    for onnx_node in model.graph.node:
        if onnx_node.op_type in weighted_ops:
            weighted_attrs = (
                ("weight_width", ww),
                ("acc_width", aw),
                ("block_floating_point", bfp),
            )
            for attr_name, attr_value in weighted_attrs:
                _set_nodeattr(onnx_node, attr_name, attr_value)
        _set_nodeattr(onnx_node, "data_width", dw)
    logger.info("Quantisation annotation complete.")
def _annotate_sparsity_from_toml(model: onnx.ModelProto, path: Path):
    """Attach per-layer input-sparsity data from a TOML file to the graph's Conv nodes.

    The top-level TOML entries are consumed in file order, one per ``Conv`` node
    in graph order. The entry's ``"avg"`` value is stored on the node as the
    ``"input sparsity"`` attribute. Nodes are modified in place.

    Args:
        model: ONNX model whose ``Conv`` nodes are annotated.
        path: TOML file whose top-level entries each contain an ``"avg"`` key.

    Raises:
        ValueError: if the TOML file has fewer entries than the graph has Conv nodes.
        KeyError: if an entry lacks the ``"avg"`` key.
    """
    # TOML documents are UTF-8 by specification; don't depend on the platform default.
    with open(path, encoding="utf-8") as f:
        data = toml.load(f)
    # NOTE: Here, we're solely relying on the coherence of the TOML file and the ONNX
    # graph in terms of the order of the layers. This works for our use case.
    entries = iter(data.items())
    for node in model.graph.node:
        if node.op_type != "Conv":
            continue
        try:
            layer_name, layer_info = next(entries)
        except StopIteration:
            # A bare StopIteration escaping here is confusing (and becomes a
            # RuntimeError if this is ever called from a generator).
            raise ValueError(
                f"Sparsity file {path} has fewer entries than the ONNX graph has "
                f"Conv nodes (ran out at node {node.name!r})."
            ) from None
        sparsity_data = layer_info["avg"]
        logger.info("Annotating %s - %s", node.name, layer_name)
        _set_nodeattr(node, "input sparsity", sparsity_data)
    # Unused entries mean the TOML and the graph are out of step. Only warn here,
    # so inputs that were accepted before are still accepted.
    leftover = sum(1 for _ in entries)
    if leftover:
        logger.warning(
            "%d sparsity entries in %s were not matched to any Conv node.",
            leftover,
            path,
        )
    logger.info("Layer sparsity annotation complete.")
def _torch_onnx_exporter(model: nn.Module, input: torch.Tensor, path: Path):
    """Export *model* to ONNX at *path*, using *input* as the example input.

    Before export, every ``nn.ReLU6`` found in ``model.modules()`` is swapped for
    a fresh ``nn.ReLU`` via ``_replace_modules``, which assigns onto *model*'s
    submodules. *model* itself is therefore modified by this call.
    """
    # TODO: ReLU6 should become a Clip node; replacing it with ReLU drops the upper
    # clamp and is only a temporary fix. Other unsupported layers may exist as well.
    substitutions = {
        layer: nn.ReLU() for layer in model.modules() if isinstance(layer, nn.ReLU6)
    }
    _replace_modules(model, substitutions)
    torch.onnx.export(
        model,
        input,
        path,
        verbose=False,
        keep_initializers_as_inputs=True,
    )
def _replace_modules(model: nn.Module, replace_dict: dict):
    """Substitute submodules of *model*, in place, with their chosen replacements.

    Every registered child (at any depth below *model*) that is a key of
    *replace_dict* is re-assigned to the corresponding value. *model* itself is
    never replaced, because it has no parent to assign into.

    Args:
        model: root module whose descendants are rewritten in place.
        replace_dict: mapping from an existing submodule to its replacement.
    """
    if not replace_dict:
        return
    # Collect every (parent, attribute name, replacement) triple before touching
    # the tree, so modules are never reassigned while ``model.modules()`` is
    # still walking them.
    # ``parent._modules`` is read directly instead of ``named_children()``,
    # because ``named_children()`` skips a child already yielded under another
    # name. A module registered twice on one parent must be replaced under both.
    pending = [
        (parent, name, replace_dict[child])
        for parent in model.modules()
        for name, child in parent._modules.items()
        if child is not None and child in replace_dict
    ]
    for parent, name, replacement in pending:
        setattr(parent, name, replacement)