# Source code for chop.actions.transform
import os
from copy import deepcopy
from pathlib import Path
import logging
import torch
from chop.passes.graph import PASSES
from chop.passes.graph.analysis import (
add_common_metadata_analysis_pass,
add_software_metadata_analysis_pass,
init_metadata_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph
from chop.passes.graph.interface import (
load_mase_graph_interface_pass,
save_mase_graph_interface_pass,
)
from chop.passes.graph.utils import deepcopy_mase_graph
from chop.tools.checkpoint_load import load_model
from chop.tools.config_load import load_config
from chop.tools.get_input import InputGenerator, get_cf_args, get_dummy_input
from chop.tools.utils import parse_accelerator, to_numpy_if_tensor
from chop.passes.graph.transforms import metadata_value_type_cast_transform_pass
from chop.passes.module import PASSES as MODULE_PASSES
logger = logging.getLogger(__name__)
# [docs]
def pre_transform_load(
    load_name: str,
    load_type: str,
    model: torch.nn.Module,
):
    """
    Load a checkpoint into the model when one is provided.

    Args:
        load_name (str): path to the checkpoint; ``None`` skips loading.
        load_type (str): checkpoint format; only "pt" and "pl" are loaded here.
        model (torch.nn.Module): the model to populate.

    Returns:
        torch.nn.Module: the model, reloaded from the checkpoint when
        ``load_name`` is given and ``load_type`` is "pt" or "pl",
        otherwise unchanged.
    """
    if load_name is None or load_type not in ("pt", "pl"):
        return model
    return load_model(load_name=load_name, load_type=load_type, model=model)
# [docs]
def transform(
model: torch.nn.Module,
model_info: dict,
model_name: str,
data_module,
task: str,
config: str,
save_dir: str = None,
load_name: str = None,
load_type: str = None,
accelerator: str = "auto",
):
"""
Transform the model based on the configuration.
Args:
model (torch.nn.Module): _description_
model_info (dict): _description_
model_name (str): _description_
data_module (_type_): _description_
task (str): _description_
config (str): _description_
save_dir (str, optional): _description_. Defaults to None.
load_name (str, optional): _description_. Defaults to None.
load_type (str, optional): _description_. Defaults to None.
accelerator (str, optional): _description_. Defaults to "auto".
Raises:
ValueError: _description_
"""
accelerator = parse_accelerator(accelerator)
model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
model.to(accelerator)
config = load_config(config)
transform_config = config["transform"]
style = transform_config.get("style", "graph")
if style == "graph":
transform_graph(
model=model,
model_info=model_info,
model_name=model_name,
data_module=data_module,
task=task,
config=config,
save_dir=save_dir,
load_name=load_name,
load_type=load_type,
accelerator=accelerator,
)
elif style == "module":
transform_module(
model=model,
model_info=model_info,
model_name=model_name,
data_module=data_module,
task=task,
config=config,
save_dir=save_dir,
load_name=load_name,
load_type=load_type,
accelerator=accelerator,
)
else:
raise ValueError(f"Style {style} is not supported!")
# [docs]
def transform_module(
    model: torch.nn.Module,
    model_info,
    model_name,
    data_module,
    task: str,
    config: str,
    save_dir: str = None,
    load_name: str = None,
    load_type: str = None,
    accelerator: str = "auto",
):
    """
    Transform the model at PyTorch module level.

    Applies each pass listed under ``config["passes"]`` (looked up in
    ``MODULE_PASSES``) to the model in order, then optionally saves the
    resulting state dict under ``save_dir / "transformed_ckpt"``.

    Args:
        model (torch.nn.Module): model to transform.
        model_info: metadata describing the model (opaque here; forwarded).
        model_name: name of the model (unused in this function).
        data_module: data module (unused in this function).
        task (str): task name (unused in this function).
        config (str): loaded configuration dict with a "passes" section.
        save_dir (str, optional): output directory; when None, nothing is
            saved. Defaults to None.
        load_name (str, optional): checkpoint path. Defaults to None.
        load_type (str, optional): checkpoint format. Defaults to None.
        accelerator (str, optional): device spec. Defaults to "auto".

    Returns:
        torch.nn.Module: the transformed model.
    """
    accelerator = parse_accelerator(accelerator)
    model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
    model.to(accelerator)

    # Fix: the original called Path(save_dir) unconditionally, raising a
    # TypeError when save_dir is None even though the signature defaults it
    # to None and the save step below already guards on it.
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    if load_name is not None:
        # NOTE(review): pre_transform_load above may already have loaded a
        # "pt"/"pl" checkpoint; this second load is kept for parity with the
        # original behavior (it also covers other load_type values).
        model = load_model(load_name, load_type=load_type, model=model)
        logger.info(f"'{load_type}' checkpoint loaded before training")

    # Run each configured module-level pass in order. The original rebound
    # the dict variable inside its own loop and used a match statement with
    # only a wildcard arm; use distinct names and a plain lookup instead.
    passes_config = config["passes"]
    for pass_name, pass_config in passes_config.items():
        my_pass = MODULE_PASSES[pass_name]
        model, _ = my_pass(model, pass_args=pass_config)

    if save_dir is not None:
        transformed_ckpt = save_dir / "transformed_ckpt"
        transformed_ckpt.mkdir(parents=True, exist_ok=True)
        # Use pathlib throughout instead of mixing in os.path.join.
        state_dict_ckpt = transformed_ckpt / "state_dict.pt"
        torch.save(model.state_dict(), state_dict_ckpt)
        logger.info(f"model saved at {state_dict_ckpt}")
    return model
# [docs]
def transform_graph(
model: torch.nn.Module,
model_info,
model_name,
data_module,
task: str,
config: str,
save_dir: str = None,
load_name: str = None,
load_type: str = None,
accelerator: str = "auto",
):
"""
Transform the model at FX graph level.
Args:
model (torch.nn.Module): _description_
model_info (_type_): _description_
model_name (_type_): _description_
data_module (_type_): _description_
task (str): _description_
config (str): _description_
save_dir (str, optional): _description_. Defaults to None.
load_name (str, optional): _description_. Defaults to None.
load_type (str, optional): _description_. Defaults to None.
accelerator (str, optional): _description_. Defaults to "auto".
Returns:
_type_: _description_
"""
accelerator = parse_accelerator(accelerator)
model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
model.to(accelerator)
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
# concrete forward args for freezing dynamic control flow in forward pass
if "cf_args" not in config:
cf_args = get_cf_args(model_info=model_info, task=task, model=model)
else:
cf_args = config["cf_args"]
# graph generation
graph = MaseGraph(model=model, cf_args=cf_args)
# graph_metadata = Mase
graph, _ = init_metadata_analysis_pass(graph, pass_args=None)
# create or load metadata.parameters and mase_graph.model
if load_name is not None and load_type == "mz":
graph, _ = load_mase_graph_interface_pass(
graph, pass_args={"load_dir": load_name}
)
else:
dummy_in = get_dummy_input(
model_info=model_info,
data_module=data_module,
task=task,
device=accelerator,
)
if len(graph.model.additional_inputs) > 0:
dummy_in = dummy_in | graph.model.additional_inputs
graph, _ = add_common_metadata_analysis_pass(
graph, pass_args={"dummy_in": dummy_in}
)
graph, _ = add_software_metadata_analysis_pass(graph, pass_args=None)
passes_config = config["passes"]
for pass_name, pass_config in passes_config.items():
pass_name: str
pass_config: dict
match pass_name:
case "quantize":
pass_save_dir = save_dir / "quantize"
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
ori_graph = deepcopy_mase_graph(graph)
graph, _ = PASSES["quantize"](graph, pass_args=pass_config)
PASSES["summarize_quantization"](
graph, {"save_dir": pass_save_dir, "original_graph": ori_graph}
)
case "profile_statistics":
input_generator = InputGenerator(
model_info=model_info,
data_module=data_module,
task=task,
which_dataloader="train",
)
pass_config["input_generator"] = input_generator
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case "report_graph":
pass_file_name = pass_config.get(
"file_name", save_dir / "report_graph.txt"
)
graph, _ = PASSES[pass_name](graph, file_name=pass_file_name)
case "report_node_type":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "report_node_meta_param":
# {"save_path": ..., "which": "all"|["common", "software"]}
pass_save_path = pass_config.get("save_path", save_dir / "report")
pass_config["save_path"] = pass_save_path
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case "report_node_shape":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "report_node_type":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "report_node_shape":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "report_node_type":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "load_mase_graph":
pass_load_dir = pass_config["load_dir"]
graph, _ = PASSES[pass_name](graph, pass_args=pass_load_dir)
case "load_node_meta_param":
pass_load_path = pass_config["load_path"]
graph, _ = PASSES[pass_name](graph, pass_args=pass_load_path)
case "save_mase_graph":
pass_save_dir = pass_config.get(
"save_dir", save_dir / "saved_mase_graph"
)
graph, _ = PASSES[pass_name](graph, pass_args=pass_save_dir)
case "save_node_meta_param":
pass_save_path = pass_config.get(
"save_path",
save_dir / "save_node_meta_param" / "node_meta_param.toml",
)
# TODO: fix me
# to save the meta parameters of the nodes,
# we have to run this cast,
# because current meta parameters contains tensors
# but this cast is not inveritble
# if there are other passes after "save_node_meta_param"
# relying on the tensor/numpy values in the meta parameters
# the transform/analysis will fail
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
graph, _ = PASSES[pass_name](graph, pass_args=pass_save_path)
case "prune":
# NOTE: The input generator is only used for when the user wants to
# enforce or observe activation sparsity. Otherwise, it's ignored.
# We use the validation dataloader as that doesn't shuffle the input
# data. This determinism helps establish a fair ground in draw
# layer-wise comparisons between activation pruning strategies.
input_generator = InputGenerator(
model_info=model_info,
data_module=data_module,
task=task,
which_dataloader="val",
)
pass_config["model_name"] = model_name
pass_config["input_generator"] = input_generator
prune_save_dir = save_dir / "prune"
prune_save_dir.mkdir(parents=True, exist_ok=True)
graph, _ = PASSES[pass_name](
graph,
save_dir=prune_save_dir,
config=pass_config,
)
case "remove_prune_wrappers":
# Removes the pruning-related hooks and makes pruning permanent
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "conv_bn_fusion":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "logicnets_fusion":
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case _:
my_pass = PASSES[pass_name]
graph, _ = my_pass(graph, pass_args=pass_config)
assert isinstance(
graph, MaseGraph
), f"Return type of {pass_name} must be MaseGraph, got {type(graph)}"
if save_dir is not None:
transformed_ckpt = save_dir / "transformed_ckpt"
transformed_ckpt.mkdir(parents=True, exist_ok=True)
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
save_mase_graph_interface_pass(graph, pass_args=transformed_ckpt)
return graph