Source code for chop.pipelines.auto_pipeline
from chop.ir import MaseGraph
from chop.tools.logger import get_logger
# Module-level logger, requested from chop's logging helper under this module's name.
logger = get_logger(__name__)
[docs]
class AutoPipeline:
    """Base class for pipelines that run groups of passes over a MaseGraph.

    A pipeline holds an ordered list of *pass groups*; each group is an ordered
    list of pass callables. Calling the pipeline runs every group in order and
    threads the graph returned by one pass into the next. Within a group, the
    output of each pass is stored under the pass's ``__name__`` so that later
    passes in the same group can reference it from their arguments.
    """

    def __init__(
        self,
        pass_groups=None,
        run_training: bool = False,
    ) -> None:
        """Initialize the AutoPipeline.

        Args:
            pass_groups (list, optional): List of pass groups to run in order,
                each group being a list of pass callables. Defaults to an
                empty list.
            run_training (bool, optional): Accepted but not read anywhere in
                this class. Defaults to False.
        """
        self.pass_groups = pass_groups if pass_groups is not None else []
        # Size from the normalized attribute (the raw argument may be None) and
        # build one independent dict per group rather than aliasing a single
        # dict via list multiplication.
        self.pass_outputs = [{} for _ in self.pass_groups]

    @staticmethod
    def _resolve_pass_args(args: dict, pass_outputs: dict) -> dict:
        """Return a copy of ``args`` with ``"self/<name>"`` values resolved.

        Any string value starting with ``"self/"`` is replaced by
        ``pass_outputs[<name>]``. The input dict is left untouched so the
        caller's configuration can be reused across runs.

        Args:
            args (dict): Arguments configured for a single pass.
            pass_outputs (dict): Outputs recorded so far, keyed by pass name.

        Returns:
            dict: A new dict with every reference replaced by its value.

        Raises:
            KeyError: If a reference names a pass with no recorded output.
        """
        prefix = "self/"
        resolved = {}
        for key, value in args.items():
            if isinstance(value, str) and value.startswith(prefix):
                ref = value[len(prefix):]
                if ref not in pass_outputs:
                    raise KeyError(
                        f"Argument '{key}' references '{value}', but no output "
                        f"named '{ref}' has been produced in this pass group."
                    )
                value = pass_outputs[ref]
            resolved[key] = value
        return resolved

    def _run_pass_group(
        self,
        mg: MaseGraph,
        pass_group: list,
        pass_args: dict,
        skip_passes: list = None,
    ):
        """Run every pass of one group in order.

        Args:
            mg (MaseGraph): Graph handed to the first pass; each pass returns
                the graph handed to the next one.
            pass_group (list): Pass callables, each invoked as
                ``pass_fn(mg, pass_args=...)`` and expected to return a
                ``(graph, output)`` pair.
            pass_args (dict): Maps a pass ``__name__`` to its argument dict.
                Passes without an entry receive an empty dict. This mapping is
                not modified.
            skip_passes (list, optional): Pass callables to skip. Defaults to
                no skipped passes.

        Returns:
            tuple: ``(mg, pass_outputs)`` where ``pass_outputs`` maps the name
            of every executed pass to its output.

        Raises:
            KeyError: If an argument references the output of a pass that has
                not run earlier in this group (including skipped passes).
        """
        skip_passes = [] if skip_passes is None else skip_passes
        pass_outputs = {}
        for pass_fn in pass_group:
            pass_name = pass_fn.__name__
            # Check if need to skip this pass
            if pass_fn in skip_passes:
                logger.debug(f"Skipping pass: {pass_name}")
                continue
            logger.debug(f"Running pass: {pass_name}")
            # Resolve self/ references on a copy so the caller's pass_args
            # keeps its symbolic references for subsequent runs.
            args = self._resolve_pass_args(
                pass_args.get(pass_name, {}),
                pass_outputs,
            )
            mg, pass_output = pass_fn(mg, pass_args=args)
            pass_outputs[pass_name] = pass_output
        return mg, pass_outputs

    def __call__(
        self,
        mg: MaseGraph,
        pass_args: dict,
        skip_passes: list = None,
    ):
        """Run all pass groups in order on the given graph.

        Args:
            mg (MaseGraph): Graph to transform.
            pass_args (dict): Maps a pass ``__name__`` to its argument dict;
                shared by all groups.
            skip_passes (list, optional): Pass callables to skip in every
                group. Defaults to no skipped passes.

        Returns:
            tuple: ``(mg, pass_outputs)`` where ``pass_outputs`` is the list
            stored on ``self.pass_outputs``, holding one output dict per group.
        """
        skip_passes = [] if skip_passes is None else skip_passes
        for idx, pass_group in enumerate(self.pass_groups):
            logger.debug(f"Running pass group {idx}/{len(self.pass_groups)}.")
            logger.debug(
                f"The following passes will be executed: {[pass_fn.__name__ for pass_fn in pass_group]}"
            )
            mg, pass_outputs = self._run_pass_group(
                mg,
                pass_group,
                pass_args,
                skip_passes,
            )
            self.pass_outputs[idx] = pass_outputs
        return mg, self.pass_outputs