Source code for chop.passes.graph.analysis.autosharding.autosharding

import numpy as np
import cvxpy as cp
from time import time
import dill

from torch.distributed._tensor._op_schema import DTensorSpec
from torch.distributed._tensor.placement_types import Replicate

from chop.tools import get_logger
from .mesh_model import MeshModel

logger = get_logger(__name__)
logger.setLevel("INFO")


def deepgetattr(obj, attr, default=None):
    """Recurses through an attribute chain to get the ultimate value."""
    import functools

    try:
        return functools.reduce(getattr, attr.split("."), obj)
    except AttributeError:
        return default
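
# Illustrative usage of deepgetattr (a sketch, not part of the pass; the module
# names below are hypothetical):
#
#     wte = deepgetattr(model, "transformer.wte")
#     missing = deepgetattr(model, "transformer.does_not_exist", default=None)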


def _import_solution(
    mg,
    solution: dict,
    mesh: MeshModel,
    extrapolate_sharding: bool = True,
):
    """Import an autosharding solution into the metadata of the MaseGraph.

    Args:
        mg (MaseGraph): input mase graph.
        solution (dict): autosharding solution.
        mesh (MeshModel): device mesh model used when building the DTensorSpec annotations.
        extrapolate_sharding (bool): extrapolate the solution from the first layer to the rest.

    Returns:
        MaseGraph: input mase graph.
        dict: empty dictionary.
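
    The solution dictionary has the structure produced by _export_solution: node names
    map to per-argument and per-result placement tuples. A rough illustration (node
    names and placements below are hypothetical, not taken from a real run):

        {
            "transformer_h_0_attn_c_attn": {
                "args": {"data_in_0": (Replicate(), Shard(1))},
                "results": {"data_out_0": (Replicate(), Shard(2))},
            },
        }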
    """
    for node in mg.fx_graph.nodes:
        logger.debug(f"Importing solution for node: {node.name}")

        # Only import the solution for transformer submodule nodes, matched by name prefix
        # TODO: this is hard-coded for GPT2; figure out how to generalize
        if not node.name.startswith("transformer_"):
            continue

        # Extrapolate from the first layer by string matching
        if node.name not in solution and extrapolate_sharding:

            # Expect the layer number to be the first numeric token in the node name
            layer_num = int([i for i in node.name.split("_") if i.isdigit()][0])

            # Replace only the first occurrence to find the equivalent node in the first layer
            extrapolate_node = node.name.replace(f"_{layer_num}_", "_0_", 1)

            if extrapolate_node in solution.keys():
                logger.warning(
                    f"Node: {node.name} not found in solution. Extrapolating from solution for: {extrapolate_node}"
                )
                solution[node.name] = solution[extrapolate_node]
            else:
                logger.debug(
                    f"Node: {node.name} not found in solution, and cannot extrapolate."
                )
                continue

        # Annotate the metadata for each argument
        for arg, arg_spec in solution[node.name].get("args", {}).items():
            node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = DTensorSpec(
                mesh=mesh,
                placements=arg_spec,
            )

        # Annotate the metadata for each result
        for result, result_spec in solution[node.name].get("results", {}).items():
            node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = (
                DTensorSpec(
                    mesh=mesh,
                    placements=result_spec,
                )
            )

    return mg, {}


def _export_solution(mg, export_file: str = "ilp_solution.pkl"):
    """Export the ILP solution to a pickle file.

    Args:
        mg (MaseGraph): input mase graph.
        export_file (str, optional): output file name. Defaults to "ilp_solution.pkl".

    Returns:
        MaseGraph: input mase graph.
        dict: empty dictionary.
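
    The exported pickle can later be reloaded and re-applied to a graph, which is
    what the preload path of autosharding_analysis_pass does. Sketch (assumes the
    same mesh model is reconstructed):

        with open("ilp_solution.pkl", "rb") as file:
            solution = dill.load(file)
        mg, _ = _import_solution(mg, solution, mesh)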
    """
    # Reduce metadata to autosharding solution
    out_dict = {}
    for node in mg.fx_graph.nodes:
        node_name = node.name
        out_dict[node_name] = {
            "args": {},
            "results": {},
        }
        for arg, arg_info in node.meta["mase"]["common"]["args"].items():
            if not isinstance(arg_info, dict):
                continue

            if "dtensor_spec" not in arg_info:
                logger.warning(
                    f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution."
                )
                spec = DTensorSpec(
                    None,
                    (Replicate(), Replicate()),
                )
            else:
                spec = arg_info["dtensor_spec"]

            out_dict[node_name]["args"][arg] = spec.placements

        for result, result_info in node.meta["mase"]["common"]["results"].items():
            if not isinstance(result_info, dict):
                continue

            if "dtensor_spec" not in result_info:
                logger.warning(
                    f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution."
                )
                spec = DTensorSpec(
                    None,
                    (Replicate(), Replicate()),
                )
            else:
                spec = result_info["dtensor_spec"]
            out_dict[node_name]["results"][result] = spec.placements

    with open(export_file, "wb") as file:
        dill.dump(out_dict, file)

    return mg, {}


def _get_sharding_map(mg):
    """
    Export the tensor sharding map to a dictionary, to be used by the MaseLauncher for
    distributed deployment.

    Args:
        mg (MaseGraph): input mase graph.

    Returns:
        dict: tensor sharding map.

    The tensor sharding map is a dictionary with the following structure.
    {
        module: {
            node: node_name,
            sharding: {
                attr: out_specs,
            },
        },
    }
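
    Here the keys are the nn.Module instances that own each parameter, and out_specs
    are the DTensorSpec annotations attached to the corresponding get_attr nodes. A
    hypothetical entry for a GPT2 attention projection might look like:

        {
            <Conv1D module>: {
                "node": "transformer_h_0_attn_c_attn_weight",
                "sharding": {"weight": DTensorSpec(mesh, (Replicate(), Shard(0)))},
            },
        }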
    """

    logger.info("Exporting tensor sharding map from MaseGraph for MaseLauncher.")

    tensor_sharding_map = {}
    for node in mg.fx_graph.nodes:
        if node.op == "get_attr":
            module_str = ".".join(node.target.split(".")[:-1])
            attr = node.target.split(".")[-1]
            module = deepgetattr(node.meta["mase"].model, module_str)

            if (
                "dtensor_spec"
                not in node.meta["mase"]["common"]["results"]["data_out_0"]
            ):
                raise ValueError(
                    f"Couldn't find DTensor sharding specification in solution for node: {node.name}"
                )
            else:
                out_specs = node.meta["mase"]["common"]["results"]["data_out_0"][
                    "dtensor_spec"
                ]

            logger.debug(
                f"Exporting sharding map for {node.name} with spec: {out_specs}"
            )

            if module not in tensor_sharding_map:
                tensor_sharding_map[module] = {
                    "node": node.name,
                    "sharding": {
                        attr: out_specs,
                    },
                }
            else:
                tensor_sharding_map[module]["sharding"][attr] = out_specs

    return tensor_sharding_map


def autosharding_analysis_pass(mg, pass_args: dict = {}):
    """Annotate the metadata of each operator in the graph with a parallelization strategy.

    Args:
        mg (MaseGraph): input mase graph.
        pass_args (dict, optional): pass arguments. Defaults to {}.

    Returns:
        MaseGraph: annotated mase graph.

    The pass_args dictionary expects the following elements.

    - mesh_shape -> tuple : Shape of the device cluster. Should be a 2-dimensional tuple.
    - inter_node_bandwidth -> int : Inter-node bandwidth, i.e. between GPU nodes.
    - intra_node_bandwidth -> int : Intra-node bandwidth, i.e. between GPU devices in each node.

    Additionally, the following elements can be passed.

    - algo (optional) -> str : Sharding algorithm to use. Default is "alpa".
    - communications_backend (optional) -> str : Communications backend to use, e.g. "nccl" or "gloo". Default is "nccl".
    - skip_fully_replicated (optional) -> bool : If set to true, do not consider fully replicated sharding as an option for any operator.
    - time_limit (optional) -> int : Time limit for the ILP solver, in seconds. Default is 10000.
    - mip_rel_gap (optional) -> int : MIP relative gap for the ILP solver. Default is 0 (i.e. obtain the full solution).
    - run_checks (optional) -> bool : If set to true, run checks on the autosharding solution. Default is False.
    - preload_solution (optional) -> bool : If set to true, preload the autosharding solution from file.
    - ilp_solution_file (optional) -> str : File to export the autosharding solution to. Defaults to "ilp_solution.pkl".
    """
    from .alpa import alpa_autosharding_pass
    from .megatron import megatron_autosharding_pass

    assert (
        "mesh_shape" in pass_args
    ), "Logical description for device cluster was not specified."
    assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified."
    assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified."

    # Initialize device mesh model, used for cost estimation
    mesh = MeshModel(pass_args["mesh_shape"])

    # Preload autosharding solution
    if pass_args.get("preload_solution", False):
        fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl")
        logger.info(f"Preloading autosharding solution from: {fname}")
        with open(fname, "rb") as file:
            solution = dill.load(file)

        # Annotate the metadata of each operator with the autosharding solution
        mg, pass_outs = _import_solution(mg, solution, mesh)
        autosharding_time = 0

    # Run autosharding pass
    else:
        # Define autosharding backend
        algo = pass_args.get("algo", "alpa")

        # The communication cost model depends on the cluster bandwidths and backend
        mesh.set_cost_model_parameters(
            intra_node_bandwidth=pass_args["intra_node_bandwidth"],
            inter_node_bandwidth=pass_args["inter_node_bandwidth"],
            backend=pass_args.get("communications_backend", "default"),
        )

        # Run intra-operator pass
        start_time = time()
        if algo == "alpa":
            mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args)
        elif algo == "megatron":
            mg, pass_outs = megatron_autosharding_pass(mg, mesh, pass_args)
        else:
            raise ValueError(f"Autosharding algorithm {algo} not recognized.")
        end_time = time()
        autosharding_time = end_time - start_time

        logger.info(
            f"Autosharding pass complete. Time taken: {autosharding_time} seconds. "
            f"Solution: {pass_outs['solution']}"
        )

        # Export solution
        fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl")
        logger.info(f"Exporting solution to {fname}")
        mg, _ = _export_solution(mg, export_file=fname)

    # Export the tensor sharding map for the MaseLauncher unless skipped
    tensor_sharding_map = {}
    if not pass_args.get("skip_forward", False):
        tensor_sharding_map = _get_sharding_map(mg)

    return mg, {
        "autosharding_time": autosharding_time,
        "tensor_sharding_map": tensor_sharding_map,
        **pass_outs,
    }
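

# Example invocation of the pass (a sketch; the MaseGraph import path, model
# construction and bandwidth figures below are illustrative assumptions and
# depend on the surrounding workflow):
#
#     from chop.ir import MaseGraph
#
#     mg = MaseGraph(model)
#     mg, out = autosharding_analysis_pass(
#         mg,
#         pass_args={
#             "mesh_shape": (2, 4),
#             "inter_node_bandwidth": 10e9,
#             "intra_node_bandwidth": 100e9,
#             "algo": "alpa",
#             "communications_backend": "nccl",
#         },
#     )
#     tensor_sharding_map = out["tensor_sharding_map"]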