Source code for chop.passes.graph.interface.save_and_load
import logging
import os
import numpy as np
import toml
import torch
import torch.fx as fx
from chop.passes.graph.analysis.init_metadata import init_metadata_analysis_pass
from chop.tools.config_load import convert_none_to_str_na, convert_str_na_to_none
# Module-level logger; the save/load interface passes in this module report the
# directory they wrote to / read from through logger.info.
logger = logging.getLogger(__name__)
def save_graph_module_ckpt(graph_module: fx.GraphModule, save_path: str) -> None:
    """Save graph as a checkpoint.

    The whole module object is serialized with :func:`torch.save` (i.e. it is
    pickled), so reading it back needs a full unpickle rather than a
    weights-only load.

    :param graph_module: graph module to serialize
    :type graph_module: fx.GraphModule
    :param save_path: path of the checkpoint *file* to write (not a directory);
        its parent directory must already exist
    :type save_path: str
    """
    torch.save(graph_module, save_path)
def save_state_dict_ckpt(graph_module: fx.GraphModule, save_path: str) -> None:
    """Write only the ``state_dict()`` of ``graph_module`` to ``save_path``.

    Unlike a full module checkpoint, this stores just the module's state
    dictionary (as returned by ``state_dict()``), serialized with
    :func:`torch.save`.
    """
    torch.save(graph_module.state_dict(), save_path)
def graph_iterator_remove_metadata(graph):
    """
    Reset the ``"mase"`` entry of every node's ``meta`` dict to ``{}``.

    Nodes that have no ``meta`` attribute are skipped; other ``meta`` keys are
    left untouched. The graph is modified in place and also returned.
    """
    annotated = (n for n in graph.fx_graph.nodes if hasattr(n, "meta"))
    for n in annotated:
        n.meta["mase"] = {}
    return graph
def collect_n_meta_param(graph) -> dict:
    """
    Map each node name to its mase ``metadata.parameters``.

    Only the ``parameters`` attribute of ``node.meta["mase"]`` is gathered; the
    values are the same objects held by the nodes (no copy is made). If two
    nodes share a name, the later one wins.
    """
    return {n.name: n.meta["mase"].parameters for n in graph.fx_graph.nodes}
def save_n_meta_param(node_meta: dict, save_path: str) -> None:
    """
    Save a mase graph metadata to a toml file.

    TOML has no null type, so ``None`` values are first replaced by a string
    placeholder via ``convert_none_to_str_na``; ``load_n_meta_param`` reverses
    this with ``convert_str_na_to_none``.

    :param node_meta: mapping of node name -> metadata.parameters
    :param save_path: path of the toml file to write
    """
    # NOTE(review): if convert_none_to_str_na converts containers in place, the
    # caller's dict is modified as well -- confirm against chop.tools.config_load.
    node_meta = convert_none_to_str_na(node_meta)
    # TOML documents must be UTF-8; don't depend on the platform default encoding.
    with open(save_path, "w", encoding="utf-8") as f:
        toml.dump(node_meta, f)


def load_n_meta_param(load_path: str) -> dict:
    """
    Load a mase graph metadata from a toml file.

    String placeholders written by ``save_n_meta_param`` are turned back into
    ``None`` via ``convert_str_na_to_none``.

    :param load_path: path of the toml file to read
    :return: mapping of node name -> metadata.parameters
    """
    # Must match the UTF-8 encoding used by save_n_meta_param.
    with open(load_path, "r", encoding="utf-8") as f:
        node_meta = toml.load(f)
    return convert_str_na_to_none(node_meta)
def graph_iterator_add_n_meta_param(graph, node_n_meta_param: dict):
    """
    Add metadata to the graph.

    Sets ``node.meta["mase"].parameters`` of every node in ``graph.fx_graph``
    from ``node_n_meta_param``, keyed by node name. The graph is modified in
    place and also returned.

    :raises KeyError: if a node's name is missing from ``node_n_meta_param``
    """
    for node in graph.fx_graph.nodes:
        node.meta["mase"].parameters = node_n_meta_param[node.name]
    return graph


def load_graph_module_ckpt(checkpoint: str) -> fx.GraphModule:
    """
    Load a serialized graph module written by ``save_graph_module_ckpt``.

    The checkpoint is a pickled module object, so a full (non weights-only)
    unpickle is required.

    SECURITY: full unpickling can execute arbitrary code -- only load
    checkpoints from trusted sources.

    :param checkpoint: path of the checkpoint file
    :return: the deserialized graph module
    """
    import inspect

    # torch >= 2.6 defaults torch.load to weights_only=True, which rejects a
    # pickled GraphModule; torch < 1.13 has no weights_only keyword at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(checkpoint, weights_only=False)
    return torch.load(checkpoint)
[docs]
def save_mase_graph_interface_pass(graph, pass_args: dict = {}):
    """Save a mase graph.

    Writes three files into one directory:

    - ``graph_module.mz``: the serialized graph module (``graph.model``)
    - ``state_dict.pt``: its state dict
    - ``node_meta_param.toml``: the per-node ``metadata.parameters``

    Args:
        graph (MaseGraph): mase_graph to save
        pass_args (str | os.PathLike | dict): the save directory, given either
            directly or as ``{"save_dir": <directory>}``. The directory is
            created if it does not exist. The default is never mutated.

    Returns:
        tuple(MaseGraph, dict): the graph with its metadata.parameters
        restored, and an empty dict.

    Raises:
        ValueError: if no save directory is given.
    """
    # Accept the dict form for symmetry with load_mase_graph_interface_pass,
    # while keeping the original "pass the directory itself" convention.
    if isinstance(pass_args, dict):
        save_dir = pass_args.get("save_dir")
    else:
        save_dir = pass_args
    if not save_dir:
        raise ValueError(f"save dir cannot be {save_dir!r}")

    os.makedirs(save_dir, exist_ok=True)
    graph_module_ckpt = os.path.join(save_dir, "graph_module.mz")
    state_dict_ckpt = os.path.join(save_dir, "state_dict.pt")
    n_meta_param_ckpt = os.path.join(save_dir, "node_meta_param.toml")

    # collect metadata.parameters and save them to toml
    node_n_meta_param = collect_n_meta_param(graph)
    save_n_meta_param(node_n_meta_param, n_meta_param_ckpt)

    # reset metadata to empty dict {} before serializing graph.model
    # (presumably graph.fx_graph is graph.model's graph, so the mase metadata
    # would otherwise be pickled along with the module -- confirm)
    graph = graph_iterator_remove_metadata(graph)
    try:
        save_graph_module_ckpt(graph.model, graph_module_ckpt)
        save_state_dict_ckpt(graph.model, state_dict_ckpt)
    finally:
        # restore metadata.parameters even if serialization failed, so the
        # caller's graph is never left stripped of its metadata
        graph, _ = init_metadata_analysis_pass(graph)
        graph = graph_iterator_add_n_meta_param(graph, node_n_meta_param)
    logger.info(f"Saved mase graph to {save_dir}")
    return graph, {}
[docs]
def load_mase_graph_interface_pass(graph, pass_args: dict = {"load_dir": None}):
    """
    Load the MASE graph interface pass.

    Reads ``graph_module.mz`` and ``node_meta_param.toml`` (as written by
    ``save_mase_graph_interface_pass``) from the load directory. It replaces
    ``graph.model`` with the loaded module, re-initializes the metadata and
    restores the saved ``metadata.parameters``.

    :param graph: The input graph to be transformed.
    :type graph: MaseGraph
    :param pass_args: Optional arguments for the transformation pass. Default
        is {'load_dir': None}, load_dir is required. The directory may also be
        passed directly as a str / os.PathLike. If it is not an existing
        directory, its parent directory is used instead (e.g. when a path to
        ``graph_module.mz`` itself is given).
    :type pass_args: dict
    :return: The transformed graph and an empty dictionary.
    :rtype: tuple(MaseGraph, dict)
    :raises ValueError: If the load directory is not specified.
    """
    # Also accept the directory itself, mirroring save_mase_graph_interface_pass.
    if isinstance(pass_args, (str, os.PathLike)):
        load_dir = pass_args
    else:
        load_dir = pass_args.get("load_dir")
    if load_dir is None:
        raise ValueError(f"load dir cannot be {load_dir}")
    if not os.path.isdir(load_dir):
        # Not a directory: treat it as a path inside the save directory.
        load_dir = os.path.dirname(load_dir)
    graph_module_ckpt = os.path.join(load_dir, "graph_module.mz")
    n_meta_param_ckpt = os.path.join(load_dir, "node_meta_param.toml")

    # load metadata.parameters from toml
    node_n_meta_param = load_n_meta_param(n_meta_param_ckpt)
    # load graph module
    graph.model = load_graph_module_ckpt(graph_module_ckpt)
    graph.model.additional_inputs = {}
    # init_metadata_analysis_pass returns (graph, info), exactly as unpacked in
    # save_mase_graph_interface_pass; binding the whole tuple to `graph` would
    # break the call below.
    graph, _ = init_metadata_analysis_pass(graph)
    # add metadata.parameters to graph
    graph = graph_iterator_add_n_meta_param(graph, node_n_meta_param)
    logger.info(f"Loaded mase graph from {load_dir}")
    return graph, {}