# Source code for chop.pipelines.auto_pipeline
from chop.ir import MaseGraph
from chop.tools.logger import get_logger
logger = get_logger(__name__)
# [docs]
class AutoPipeline:
    """Base class for pipelines that run groups of passes in order.

    A pipeline holds a list of pass groups, each of which is a list of pass
    callables. Calling the pipeline runs every group in order and, within a
    group, every pass in order. Each pass is invoked as
    ``pass_fn(mg, pass_args=...)`` and must return ``(mg, pass_output)``.

    Outputs are recorded per group, keyed by the pass function's
    ``__name__``. A later pass in the same group can consume an earlier
    output by using the string ``"self/<pass name>"`` as an argument value.
    """

    def __init__(
        self,
        pass_groups=None,
        run_training: bool = False,
    ) -> None:
        """Initialize the AutoPipeline.

        Args:
            pass_groups (list, optional): List of pass groups, where each
                group is a list of pass callables. Defaults to an empty list.
            run_training (bool, optional): Currently unused by this class.
                Defaults to False.
        """
        self.pass_groups = pass_groups if pass_groups is not None else []
        # Build one independent dict per group from the normalised attribute:
        # ``[{}] * n`` would alias a single dict n times, and calling ``len``
        # on the raw argument fails when it is left at its ``None`` default.
        self.pass_outputs = [{} for _ in self.pass_groups]

    def _run_pass_group(
        self,
        mg: "MaseGraph",
        pass_group: list,
        pass_args: dict,
        skip_passes: "list | None" = None,
    ):
        """Run a single group of passes, in order, on ``mg``.

        Args:
            mg: Graph handed to the first pass. Every pass returns the graph
                that is handed to the next one.
            pass_group: Pass callables to run.
            pass_args: Maps a pass function's ``__name__`` to the dictionary
                of arguments for that pass. A string value starting with
                ``"self/"`` is replaced by the output of the pass named after
                the prefix, which must already have run in this group.
                Passes without an entry receive an empty dictionary.
            skip_passes: Pass callables that must not be run. ``None`` means
                nothing is skipped.

        Returns:
            A tuple ``(mg, pass_outputs)`` where ``pass_outputs`` maps each
            executed pass's ``__name__`` to its output.

        Raises:
            KeyError: If a ``"self/"`` reference names a pass that has no
                recorded output in this group (not run yet, or skipped).
        """
        skipped = () if skip_passes is None else skip_passes
        pass_outputs = {}
        for pass_fn in pass_group:
            # Check if need to skip this pass
            if pass_fn in skipped:
                logger.debug(f"Skipping pass: {pass_fn.__name__}")
                continue

            logger.debug(f"Running pass: {pass_fn.__name__}")

            # Resolve self/ references into a fresh dictionary. Writing the
            # resolved values back into the caller's ``pass_args`` would
            # destroy the references, so a second run of the pipeline would
            # silently reuse outputs from the first run.
            args = {}
            for key, value in pass_args.get(pass_fn.__name__, {}).items():
                if isinstance(value, str) and value.startswith("self/"):
                    value = pass_outputs[value[len("self/"):]]
                args[key] = value

            mg, pass_output = pass_fn(mg, pass_args=args)
            pass_outputs[pass_fn.__name__] = pass_output
        return mg, pass_outputs

    def __call__(
        self,
        mg: "MaseGraph",
        pass_args: dict,
        skip_passes: "list | None" = None,
    ):
        """Run every pass group in order on ``mg``.

        Args:
            mg: Graph to transform.
            pass_args: Maps a pass function's ``__name__`` to its argument
                dictionary; the same mapping is used for every group.
            skip_passes: Pass callables to skip in every group. ``None``
                means nothing is skipped.

        Returns:
            A tuple ``(mg, pass_outputs)`` where ``pass_outputs`` is the
            pipeline's ``self.pass_outputs`` list holding one output
            dictionary per pass group.
        """
        for idx, pass_group in enumerate(self.pass_groups):
            logger.debug(f"Running pass group {idx}/{len(self.pass_groups)}.")
            logger.debug(
                f"The following passes will be executed: {[pass_fn.__name__ for pass_fn in pass_group]}"
            )
            mg, pass_outputs = self._run_pass_group(
                mg,
                pass_group,
                pass_args,
                skip_passes,
            )
            self.pass_outputs[idx] = pass_outputs
        return mg, self.pass_outputs