Source code for chop.passes.graph.analysis.verify.verify

import logging

from chop.passes.graph.utils import vf

from .common_metadata_layers import (
    verify_common_metadata_flatten,
    verify_common_metadata_general,
    verify_common_metadata_input,
    verify_common_metadata_linear,
    verify_common_metadata_output,
    verify_common_metadata_relu,
)
from .hardware_metadata_layers import (
    verify_hardware_metadata_general,
    verify_hardware_metadata_linear,
    verify_hardware_metadata_relu,
)

logger = logging.getLogger(__name__)


def verify_node_common_metadata(node):
    """
    Verify the common metadata of a node.

    This function checks the common metadata of a node and performs specific verification based on the mase_op parameter.

    :param node: The node to verify.
    :type node: Node
    :raises ValueError: If the mase_op parameter is unknown.
    """
    verify_common_metadata_general(node.meta["mase"])

    if node.meta["mase"].parameters["common"]["mase_op"] == "placeholder":
        verify_common_metadata_input(node.meta["mase"])
    elif node.meta["mase"].parameters["common"]["mase_op"] == "output":
        verify_common_metadata_output(node.meta["mase"])
    elif node.meta["mase"].parameters["common"]["mase_op"] == "linear":
        verify_common_metadata_linear(node.meta["mase"])
    elif node.meta["mase"].parameters["common"]["mase_op"] == "relu":
        verify_common_metadata_relu(node.meta["mase"])
    elif node.meta["mase"].parameters["common"]["mase_op"] == "flatten":
        verify_common_metadata_flatten(node.meta["mase"])
    elif node.meta["mase"].parameters["common"]["mase_op"] == "constant":
        # Add specific verification for constant operation
        pass
        verify_common_metadata_input(node.meta["mase"])
    else:
        raise ValueError(
            "Unknown mase op: {}".format(
                node.meta["mase"].parameters["common"]["mase_op"]
            )
        )


[docs] def verify_common_metadata_analysis_pass(graph, pass_args: dict = {}): """Verify pass for mase graph This pass is used for verification of MaseGraph. It does sanity checks the common metadata of each mase node locally and then verify the inter-node invariants, particularly for the following: :param graph: The input graph to analyze. :type graph: MaseGraph :param pass_args: Additional arguments for the analysis pass (optional). :type pass_args: dict pass_args is not used in this pass, defaults to {}. :return: The analyzed graph and an empty dictionary. :rtype: tuple(MaseGraph, Dict) """ # Verify each node in the graph for node in graph.fx_graph.nodes: verify_node_common_metadata(node) # Each node must have a unique name and a unique verilog name node_names = [] node_vf_names = [] for node in graph.fx_graph.nodes: assert node.name not in node_names assert vf(node.name) not in node_vf_names node_names.append(node.name) node_vf_names.append(vf(node.name)) # Each node must have at most one result for node in graph.fx_graph.nodes: assert len(node.meta["mase"].parameters["common"]["results"]) <= 1 # Inter-node verification # Each edge between nodes must have the same size for node in graph.fx_graph.nodes: if len(node.all_input_nodes) > 0: for i, args in enumerate(node.args): data_in = node.meta["mase"].parameters["common"]["args"][f"data_in_{i}"] dst_size = data_in["size"] src_size = ( data_in["from"] .meta["mase"] .parameters["common"]["results"][f"data_out_0"]["size"] ) assert dst_size == src_size return graph, {}
[docs] def verify_software_metadata_analysis_pass(graph, pass_args: dict = {}): """ Verify pass for mase graph This pass is used for verification of MaseGraph. It does sanity checks the software metadata of each mase node locally and then verify the inter-node invariants, particularly for the following: * TODO :param graph: The input graph to analyze. :type graph: MaseGraph :param pass_args: Additional arguments for the analysis pass (optional). :type pass_args: dict pass_args is not used in this pass, defaults to {}. :return: The analyzed graph and an empty dictionary. :rtype: tuple(MaseGraph, Dict) """ return graph, {}
def verify_node_hardware_metadata(node): """ Verify pass for metadata at node level This pass is used for verification of Metadata. It does sanity checks the metadata of each mase node locally, particularly for the following: * TODO """ verify_hardware_metadata_general(node.meta["mase"]) if node.meta["mase"].parameters["common"]["mase_op"] == "linear": verify_hardware_metadata_linear(node.meta["mase"]) elif node.meta["mase"].parameters["common"]["mase_op"] == "relu": verify_hardware_metadata_relu(node.meta["mase"]) else: raise ValueError(f"Unknown mase op: {node.op}")
[docs] def verify_hardware_metadata_analysis_pass(graph, pass_args: dict = {}): """ Verify pass for mase graph This pass is used for verification of MaseGraph. It does sanity checks the hardware metadata of each mase node locally and then verify the inter-node invariants, particularly for the following: * TODO :param graph: The input graph to analyze. :type graph: MaseGraph :param pass_args: Additional arguments for the analysis pass (optional). :type pass_args: dict pass_args is not used in this pass, defaults to {}. :return: The analyzed graph and an empty dictionary. :rtype: tuple(MaseGraph, Dict) """ # Verify each node int the graph for node in graph.fx_graph.nodes: verify_node_hardware_metadata(node) # Inter-node verification # Each edge between nodes must have the same size nodes_in = graph.nodes_in nodes_out = graph.nodes_out while nodes_in != nodes_out: next_nodes_in = [] for node in nodes_in: for next_node, x in node.users.items(): # This might have a bug - for now assume there is only one result if next_node.meta["mase"].parameters["hardware"]["is_implicit"]: if node not in next_nodes_in: next_nodes_in.append(node) continue next_nodes_in.append(next_node) arg_count = len(next_node.all_input_nodes) if arg_count == 1: assert ( next_node.meta["mase"].parameters["hardware"][ "verilog_parameters" ]["IN_SIZE"] == node.meta["mase"].parameters["hardware"][ "verilog_parameters" ]["OUT_SIZE"] ), "Verilog input and output sizes mismatch: {} = {} and {} = {}".format( node.name, node.meta["mase"].parameters["hardware"]["verilog_parameters"][ "OUT_SIZE" ], next_node.name, next_node.meta["mase"].parameters["hardware"][ "verilog_parameters" ]["IN_SIZE"], ) else: i = get_input_index(node, next_node) assert ( next_node.meta["mase"].parameters["hardware"][ "verilog_parameters" ][f"IN_{i}_SIZE"] == node.meta["mase"].parameters["hardware"][ "verilog_parameters" ]["OUT_SIZE"] ), "Verilog input and output sizes mismatch: {} = {} and {} = {}".format( node.name, node.meta["mase"].parameters["hardware"]["verilog_parameters"][ "OUT_SIZE" ], next_node.name, next_node.meta["mase"].parameters["hardware"][ "verilog_parameters" ][f"IN_{i}_SIZE"], ) assert ( nodes_in != next_nodes_in ), f"Parsing error: cannot find the next nodes: {nodes_in}." nodes_in = next_nodes_in return graph, {}
[docs] def verify_metadata_analysis_pass(graph, pass_args: dict = {}): """ Verify pass for mase graph This pass is used for verification of MaseGraph. It does sanity checks all the metadata of each mase node locally and then verify the inter-node invariants, particularly for the following: * TODO :param graph: The input graph to analyze. :type graph: MaseGraph :param pass_args: Additional arguments for the analysis pass (optional). :type pass_args: dict pass_args is not used in this pass, defaults to {}. :return: The analyzed graph and an empty dictionary. :rtype: tuple(MaseGraph, Dict) """ _, _ = verify_common_metadata_analysis_pass(graph) _, _ = verify_software_metadata_analysis_pass(graph) _, _ = verify_hardware_metadata_analysis_pass(graph) return graph, {}