Source code for chop.actions.transform
import os
from copy import deepcopy
from pathlib import Path
import logging
import torch
from chop.passes.graph import PASSES
from chop.passes.graph.analysis import (
add_common_metadata_analysis_pass,
add_software_metadata_analysis_pass,
init_metadata_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph
from chop.passes.graph.interface import (
load_mase_graph_interface_pass,
save_mase_graph_interface_pass,
)
from chop.passes.graph.utils import deepcopy_mase_graph
from chop.tools.checkpoint_load import load_model
from chop.tools.config_load import load_config
from chop.tools.get_input import InputGenerator, get_cf_args, get_dummy_input
from chop.tools.utils import parse_accelerator, to_numpy_if_tensor
from chop.passes.graph.transforms import metadata_value_type_cast_transform_pass
from chop.passes.module import PASSES as MODULE_PASSES
logger = logging.getLogger(__name__)
def pre_transform_load(
load_name: str,
load_type: str,
model: torch.nn.Module,
):
"""
    Load a checkpoint into the model before any transform is applied, if one is provided.

    Args:
        load_name (str): path to the checkpoint to load; if None, the model is returned unchanged.
        load_type (str): checkpoint format, e.g. "pt" (PyTorch) or "pl" (PyTorch Lightning).
        model (torch.nn.Module): the model to load the checkpoint into.

    Returns:
        torch.nn.Module: the model, with the checkpoint loaded when applicable.
"""
if load_name is not None and load_type in ["pt", "pl"]:
model = load_model(load_name=load_name, load_type=load_type, model=model)
return model
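
# A minimal usage sketch of ``pre_transform_load`` (illustrative only; the model class and
# checkpoint path below are hypothetical, not part of this module):
#
#     model = MyModel()                        # any torch.nn.Module
#     model = pre_transform_load(
#         load_name="checkpoints/best.ckpt",   # hypothetical checkpoint path
#         load_type="pl",                      # a PyTorch Lightning checkpoint
#         model=model,
#     )
#
# With ``load_name=None`` or a ``load_type`` other than "pt"/"pl", the model is returned unchanged.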
def transform(
model: torch.nn.Module,
model_info: dict,
model_name: str,
data_module,
task: str,
config: str,
save_dir: str = None,
load_name: str = None,
load_type: str = None,
accelerator: str = "auto",
):
"""
    Transform the model according to the given configuration, dispatching to a graph-level
    or module-level transform based on the configured style.

    Args:
        model (torch.nn.Module): the model to transform.
        model_info (dict): metadata describing the model.
        model_name (str): name of the model.
        data_module: data module providing the dataloaders used to build dummy inputs.
        task (str): the task the model is used for.
        config (str): path to the transform configuration file.
        save_dir (str, optional): directory to save transformed outputs to. Defaults to None.
        load_name (str, optional): path to a checkpoint to load before transforming. Defaults to None.
        load_type (str, optional): checkpoint format, e.g. "pt", "pl" or "mz". Defaults to None.
        accelerator (str, optional): device to run on. Defaults to "auto".

    Raises:
        ValueError: if the configured transform style is not "graph" or "module".
"""
accelerator = parse_accelerator(accelerator)
model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
model.to(accelerator)
config = load_config(config)
transform_config = config["transform"]
style = transform_config.get("style", "graph")
if style == "graph":
transform_graph(
model=model,
model_info=model_info,
model_name=model_name,
data_module=data_module,
task=task,
config=config,
save_dir=save_dir,
load_name=load_name,
load_type=load_type,
accelerator=accelerator,
)
elif style == "module":
transform_module(
model=model,
model_info=model_info,
model_name=model_name,
data_module=data_module,
task=task,
config=config,
save_dir=save_dir,
load_name=load_name,
load_type=load_type,
accelerator=accelerator,
)
else:
raise ValueError(f"Style {style} is not supported!")
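
# For reference, ``transform`` expects ``config`` to point at a file that ``load_config``
# can parse into a dict containing the keys read above. A minimal sketch of such a
# configuration, assuming a TOML-style layout (the "quantize" pass and its options are
# illustrative; any key under ``passes`` must name an entry in ``PASSES`` or ``MODULE_PASSES``):
#
#     [transform]
#     style = "graph"            # or "module"
#
#     [passes.quantize]
#     # pass-specific options, forwarded to the pass as ``pass_args``
#
# Some passes (e.g. "tensorrt", "onnxruntime") additionally read top-level keys such as
# ``dataset``, ``batch_size`` and ``model`` from the config, as seen in transform_graph below.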
def transform_module(
model: torch.nn.Module,
model_info,
model_name,
data_module,
task: str,
    config: dict,
save_dir: str = None,
load_name: str = None,
load_type: str = None,
accelerator: str = "auto",
):
"""
    Transform the model at the PyTorch module level.

    Args:
        model (torch.nn.Module): the model to transform.
        model_info: metadata describing the model.
        model_name: name of the model.
        data_module: data module providing the dataloaders.
        task (str): the task the model is used for.
        config (dict): the loaded transform configuration.
        save_dir (str, optional): directory to save the transformed checkpoint to. Defaults to None.
        load_name (str, optional): path to a checkpoint to load before transforming. Defaults to None.
        load_type (str, optional): checkpoint format, e.g. "pt" or "pl". Defaults to None.
        accelerator (str, optional): device to run on. Defaults to "auto".

    Returns:
        torch.nn.Module: the transformed model.
"""
accelerator = parse_accelerator(accelerator)
model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
model.to(accelerator)
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
    if load_name is not None:
        model = load_model(load_name, load_type=load_type, model=model)
        logger.info(f"'{load_type}' checkpoint loaded before transform")
    pass_configs = config["passes"]
    for pass_name, pass_config in pass_configs.items():
pass_name: str
pass_config: dict
match pass_name:
case _:
my_pass = MODULE_PASSES[pass_name]
model, _ = my_pass(model, pass_args=pass_config)
if save_dir is not None:
transformed_ckpt = save_dir / "transformed_ckpt"
state_dict_ckpt = os.path.join(transformed_ckpt, "state_dict.pt")
transformed_ckpt.mkdir(parents=True, exist_ok=True)
state_dict = model.state_dict()
torch.save(state_dict, state_dict_ckpt)
logger.info(f"model saved at {state_dict_ckpt}")
return model
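
# A minimal sketch of driving ``transform_module`` directly (illustrative; the pass name
# "quantize" is only an example and must exist in ``MODULE_PASSES``, and the paths and
# task name are hypothetical):
#
#     config = {"passes": {"quantize": {}}}    # per-pass options are forwarded as ``pass_args``
#     model = transform_module(
#         model=model,
#         model_info=model_info,
#         model_name="my-model",
#         data_module=data_module,
#         task="cls",
#         config=config,
#         save_dir="mase_output/my-model",
#     )
#
# The transformed state dict is written to ``<save_dir>/transformed_ckpt/state_dict.pt``.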
def transform_graph(
model: torch.nn.Module,
model_info,
model_name,
data_module,
task: str,
    config: dict,
save_dir: str = None,
load_name: str = None,
load_type: str = None,
accelerator: str = "auto",
):
"""
    Transform the model at the FX graph level.

    Args:
        model (torch.nn.Module): the model to transform.
        model_info: metadata describing the model.
        model_name: name of the model.
        data_module: data module providing the dataloaders used to build dummy inputs.
        task (str): the task the model is used for.
        config (dict): the loaded transform configuration.
        save_dir (str, optional): directory to save transformed outputs to. Defaults to None.
        load_name (str, optional): path to a checkpoint or saved MaseGraph to load. Defaults to None.
        load_type (str, optional): checkpoint format, e.g. "pt", "pl" or "mz". Defaults to None.
        accelerator (str, optional): device to run on. Defaults to "auto".

    Returns:
        MaseGraph: the transformed graph.
"""
accelerator = parse_accelerator(accelerator)
model = pre_transform_load(load_name=load_name, load_type=load_type, model=model)
model.to(accelerator)
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
    # concrete forward args used to freeze dynamic control flow in the forward pass during tracing
if "cf_args" not in config:
cf_args = get_cf_args(model_info=model_info, task=task, model=model)
else:
cf_args = config["cf_args"]
# graph generation
graph = MaseGraph(model=model, cf_args=cf_args)
graph, _ = init_metadata_analysis_pass(graph, pass_args=None)
# create or load metadata.parameters and mase_graph.model
if load_name is not None and load_type == "mz":
graph, _ = load_mase_graph_interface_pass(
graph, pass_args={"load_dir": load_name}
)
else:
dummy_in = get_dummy_input(
model_info=model_info,
data_module=data_module,
task=task,
device=accelerator,
)
if len(graph.model.additional_inputs) > 0:
dummy_in = dummy_in | graph.model.additional_inputs
graph, _ = add_common_metadata_analysis_pass(
graph, pass_args={"dummy_in": dummy_in}
)
graph, _ = add_software_metadata_analysis_pass(graph, pass_args=None)
passes_config = config["passes"]
for pass_name, pass_config in passes_config.items():
pass_name: str
pass_config: dict
match pass_name:
case "tensorrt":
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
ori_graph = deepcopy_mase_graph(graph)
pass_save_dir = save_dir / "tensorrt"
pass_config["task"] = task
pass_config["dataset"] = config["dataset"]
pass_config["batch_size"] = config["batch_size"]
pass_config["model"] = config["model"]
pass_config["data_module"] = data_module
pass_config["accelerator"] = accelerator.type
if accelerator.type == "cuda":
                    # TODO: this seems ineffective - known issue - https://github.com/NVIDIA/TensorRT/issues/2468
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
                # First, fake-quantize the model for calibration (only applied when using int8 precision; otherwise skipped)
graph, _ = PASSES["tensorrt_fake_quantize"](
graph, pass_args=pass_config
)
# Summarize to show what has been quantized
PASSES["summarize_quantization"](
graph, {"save_dir": pass_save_dir, "original_graph": ori_graph}
)
# Then calibrate the model using the fake quantization to set AMAXs
graph, _ = PASSES["tensorrt_calibrate"](graph, pass_args=pass_config)
# Apply post-quantization fine tuning (Quantization Aware Training)
graph, _ = PASSES["tensorrt_fine_tune"](graph, pass_args=pass_config)
# Apply fp16 or layer-wise mixed precision quantization if necessary and convert the model to TensorRT format
graph, runtime_meta = PASSES["tensorrt"](graph, pass_args=pass_config)
                # Perform runtime analysis on the original graph and on the generated TensorRT engine
_, _ = PASSES["runtime_analysis"](ori_graph, pass_args=pass_config)
_, _ = PASSES["runtime_analysis"](
runtime_meta["trt_engine_path"], pass_args=pass_config
)
case "onnxruntime":
pass_save_dir = save_dir / "onnxruntime"
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
ori_graph = deepcopy_mase_graph(graph)
pass_config["data_module"] = data_module
                # the data module's train dataloader is used downstream as the calibration dataloader
pass_config["data_module"].train_dataloader
pass_config["task"] = task
pass_config["accelerator"] = accelerator.type
pass_config["batch_size"] = config["batch_size"]
pass_config["model"] = config["model"]
pass_config["dataset"] = config["dataset"]
if accelerator.type == "cuda":
                    # TODO: this seems ineffective - known issue - https://github.com/NVIDIA/TensorRT/issues/2468
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
graph, runtime_meta = PASSES["onnxruntime"](
graph, pass_args=pass_config
)
                # if the user has configured runtime_analysis, run the runtime analysis passes below
if "runtime_analysis" not in pass_config:
break
                # Extract the 'runtime_analysis' dictionary from the config and merge its options in
runtime_analysis = pass_config.pop("runtime_analysis", {})
pass_config.update(runtime_analysis)
original_graph_analysis = pass_config.get(
"original_graph_analysis", True
)
if original_graph_analysis:
logger.info("Performing runtime analysis on original graph...")
_, _ = PASSES["runtime_analysis"](ori_graph, pass_args=pass_config)
optimized_graph_analysis = pass_config.get(
"optimized_graph_analysis", True
)
if optimized_graph_analysis:
logger.info(
"Performing runtime analysis on onnx-optimized graph..."
)
_, _ = PASSES["runtime_analysis"](
runtime_meta["onnx_path"], pass_args=pass_config
)
                # Perform runtime analysis on the quantized variants where appropriate
quantized_graph_analysis = pass_config.get(
"quantized_graph_analysis", True
)
if quantized_graph_analysis:
try:
quant_types = pass_config["default"]["config"]["quantize_types"]
except KeyError:
quant_types = []
for quant_type in quant_types:
match quant_type:
case "static":
logger.info(
"Performing runtime analysis on static quantized graph..."
)
_, _ = PASSES["runtime_analysis"](
runtime_meta["onnx_static_quantized_path"],
pass_args=pass_config,
)
case "dynamic":
logger.info(
"Performing runtime analysis on dynamic quantized graph..."
)
_, _ = PASSES["runtime_analysis"](
runtime_meta["onnx_dynamic_quantized_path"],
pass_args=pass_config,
)
case "auto":
logger.info(
"Performing runtime analysis on auto mixed precision quantized graph..."
)
_, _ = PASSES["runtime_analysis"](
runtime_meta["onnx_auto_mixed_precision_path"],
pass_args=pass_config,
)
case "quantize":
pass_save_dir = save_dir / "quantize"
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
ori_graph = deepcopy_mase_graph(graph)
graph, _ = PASSES["quantize"](graph, pass_args=pass_config)
PASSES["summarize_quantization"](
graph, {"save_dir": pass_save_dir, "original_graph": ori_graph}
)
case "profile_statistics":
input_generator = InputGenerator(
model_info=model_info,
data_module=data_module,
task=task,
which_dataloader="train",
)
pass_config["input_generator"] = input_generator
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case "report_graph":
pass_file_name = pass_config.get(
"file_name", save_dir / "report_graph.txt"
)
graph, _ = PASSES[pass_name](graph, file_name=pass_file_name)
case "report_node_type":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "report_node_meta_param":
# {"save_path": ..., "which": "all"|["common", "hardware", "software"]}
pass_save_path = pass_config.get("save_path", save_dir / "report")
pass_config["save_path"] = pass_save_path
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case "report_node_shape":
graph, _ = PASSES[pass_name](graph, pass_args=None)
            case "report_node_hardware_type":
                graph, _ = PASSES[pass_name](graph, pass_args=None)
case "load_mase_graph":
pass_load_dir = pass_config["load_dir"]
graph, _ = PASSES[pass_name](graph, pass_args=pass_load_dir)
case "load_node_meta_param":
pass_load_path = pass_config["load_path"]
graph, _ = PASSES[pass_name](graph, pass_args=pass_load_path)
case "save_mase_graph":
pass_save_dir = pass_config.get(
"save_dir", save_dir / "saved_mase_graph"
)
graph, _ = PASSES[pass_name](graph, pass_args=pass_save_dir)
case "save_node_meta_param":
pass_save_path = pass_config.get(
"save_path",
save_dir / "save_node_meta_param" / "node_meta_param.toml",
)
                # TODO: fix me
                # To save the meta parameters of the nodes we have to run this cast,
                # because the current meta parameters contain tensors.
                # However, the cast is not invertible: if any later pass relies on the
                # tensor/numpy values in the meta parameters, that transform/analysis will fail.
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
graph, _ = PASSES[pass_name](graph, pass_args=pass_save_path)
case "prune":
                # NOTE: The input generator is only used when the user wants to
                # enforce or observe activation sparsity; otherwise, it's ignored.
                # We use the validation dataloader because it doesn't shuffle the
                # input data. This determinism helps establish fair ground for drawing
                # layer-wise comparisons between activation pruning strategies.
input_generator = InputGenerator(
model_info=model_info,
data_module=data_module,
task=task,
which_dataloader="val",
)
pass_config["model_name"] = model_name
pass_config["input_generator"] = input_generator
prune_save_dir = save_dir / "prune"
prune_save_dir.mkdir(parents=True, exist_ok=True)
graph, _ = PASSES[pass_name](
graph,
save_dir=prune_save_dir,
config=pass_config,
)
case "remove_prune_wrappers":
# Removes the pruning-related hooks and makes pruning permanent
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "conv_bn_fusion":
graph, _ = PASSES[pass_name](graph, pass_args=None)
case "logicnets_fusion":
graph, _ = PASSES[pass_name](graph, pass_args=pass_config)
case "onnx_annotate":
onnx_dir = save_dir / "onnx"
onnx_dir.mkdir(parents=True, exist_ok=True)
kwargs = {
"save_path": onnx_dir,
"data_path": pass_config["data_path"],
}
graph, _ = PASSES[pass_name](graph, **kwargs)
case _:
my_pass = PASSES[pass_name]
graph, _ = my_pass(graph, pass_args=pass_config)
assert isinstance(
graph, MaseGraph
), f"Return type of {pass_name} must be MaseGraph, got {type(graph)}"
if save_dir is not None:
transformed_ckpt = save_dir / "transformed_ckpt"
transformed_ckpt.mkdir(parents=True, exist_ok=True)
graph, _ = metadata_value_type_cast_transform_pass(
graph, pass_args={"fn": to_numpy_if_tensor}
)
save_mase_graph_interface_pass(graph, pass_args=transformed_ckpt)
return graph
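
# End-to-end, the graph-level flow above amounts to: trace the model into a MaseGraph,
# attach common and software metadata (or reload a saved graph when ``load_type == "mz"``),
# run each configured pass in order, and save the transformed graph. A rough sketch of the
# equivalent manual sequence, using names defined in this module (the "quantize" pass and
# its config are illustrative):
#
#     graph = MaseGraph(model=model, cf_args=cf_args)
#     graph, _ = init_metadata_analysis_pass(graph, pass_args=None)
#     graph, _ = add_common_metadata_analysis_pass(graph, pass_args={"dummy_in": dummy_in})
#     graph, _ = add_software_metadata_analysis_pass(graph, pass_args=None)
#     graph, _ = PASSES["quantize"](graph, pass_args=quantize_config)
#     graph, _ = metadata_value_type_cast_transform_pass(graph, pass_args={"fn": to_numpy_if_tensor})
#     save_mase_graph_interface_pass(graph, pass_args=save_dir / "transformed_ckpt")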