# Source code for chop.distributed.launcher

import os
from functools import partial
from time import time

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed._tensor import (
    DeviceMesh,
    Replicate,
    Shard,
)

from chop.distributed.tensor import distribute_module, distribute_tensor

from chop.distributed.utils import rlog
from ..tools import get_logger

# Module-level logger shared by every helper in this file.
logger = get_logger(__name__)
# Force DEBUG so this logger does not filter out debug-level records.
logger.setLevel("DEBUG")


def distributed_timing(fn, *args, **kwargs):
    """
    Call ``fn(*args, **kwargs)`` once and measure its wall-clock duration.

    A blocking barrier is issued on the default process group right before the
    clock starts and again right after ``fn`` returns, so the reported time
    spans from "every rank is ready" to "every rank is done". The default
    process group must already be initialized.

    Args:
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        tuple: ``(result, elapsed)`` where ``result`` is whatever ``fn``
        returned and ``elapsed`` is the duration in seconds.
    """
    # The barriers must be blocking: with async_op=True the call returns a
    # work handle immediately and no rank actually waits for the others.
    dist.barrier()
    start = time()
    result = fn(*args, **kwargs)
    dist.barrier()
    end = time()

    return result, (end - start)


def distributed_average_timing(fn, repeat, args):
    """
    Run ``fn(*args)`` ``repeat`` times and return its average duration.

    Every iteration is bracketed by blocking barriers on the default process
    group so that all ranks start and stop the measurement together. The
    first two iterations are treated as warm-up and excluded from the
    average whenever more than two iterations were run.

    Args:
        fn: Callable to time.
        repeat (int): Number of times to call ``fn``.
        args: Iterable of positional arguments unpacked into ``fn``.

    Returns:
        tuple: ``(result, average)`` where ``result`` is the return value of
        the last call (``None`` if ``repeat`` is 0) and ``average`` is the
        mean duration in seconds (``0.0`` if nothing was run).
    """
    warmup_iterations = 2
    rank = dist.get_rank()
    times = []
    result = None  # stays None when repeat == 0 instead of being unbound

    for itr in range(repeat):
        rlog(logger, rank, f"Running iteration {itr}", "debug")

        # Blocking barriers: async_op=True would return immediately and the
        # ranks would not actually be synchronized around the measurement.
        dist.barrier()
        start = time()
        result = fn(*args)
        dist.barrier()
        elapsed = time() - start

        times.append(elapsed)
        rlog(logger, rank, f"Time taken: {elapsed}s", "debug")

    # Drop the warm-up runs, but fall back to every sample when there are not
    # enough iterations, rather than dividing by zero.
    measured = times[warmup_iterations:] or times
    average = sum(measured) / len(measured) if measured else 0.0

    return result, average


def dist_model_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
    rank: int,
    tensor_sharding_map=None,
) -> None:
    """
    Partition function handed to ``distribute_module``; invoked once per module.

    If ``module`` is a key of ``tensor_sharding_map``, each attribute listed in
    its ``"sharding"`` entry is distributed over ``device_mesh`` with the
    placements given by the entry's ``placements`` attribute, and then
    re-registered on the module as an ``nn.Parameter``. Modules that are not in
    the map are left untouched.

    Args:
        name: Name of the module (unused here).
        module: Module whose tensors should be distributed.
        device_mesh: Mesh to distribute the tensors over.
        rank: Rank of the calling process, used for logging only.
        tensor_sharding_map: Mapping ``module -> {"node": str, "sharding":
            {attribute_name: sharding_config}}``. ``None`` (the default) is
            treated as an empty mapping.
    """
    if tensor_sharding_map is None or module not in tensor_sharding_map:
        return

    module_config = tensor_sharding_map[module]
    node_name = module_config["node"]

    for parameter, sharding_config in module_config["sharding"].items():
        # These keys are skipped rather than looked up on the module;
        # presumably they describe activations, not module tensors.
        if parameter in ("data_in_0", "output", "data_out_0"):
            continue

        if not hasattr(module, parameter):
            rlog(
                logger,
                rank,
                f"Module {module} does not have parameter {parameter}",
                level="warning",
            )
            continue

        placement = sharding_config.placements

        # Best effort: a failure on one tensor is logged and the remaining
        # tensors are still processed.
        try:
            rlog(
                logger,
                rank,
                f"Distributing parameter {parameter} of module {node_name} to {placement}",
                level="debug",
            )
            distributed_tensor = distribute_tensor(
                getattr(module, parameter), device_mesh, placement
            )
            setattr(module, parameter, torch.nn.Parameter(distributed_tensor))
        except Exception as e:
            rlog(
                logger,
                rank,
                f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}",
                level="error",
            )


def device_fn(
    rank, world_size, model=None, device_mesh=None, tensor_sharding_map=None, inputs=None
):
    """
    Per-process entry point: set up the process group, distribute the model
    and time its forward pass, following the SPMD model.

    Args:
        rank (int): Rank of this process; also used as the CUDA device index.
        world_size (int): Total number of processes in the group.
        model: Model to distribute and run.
        device_mesh: Mesh layout passed to ``DeviceMesh("cuda", mesh=...)``.
        tensor_sharding_map: Sharding configuration forwarded to
            ``dist_model_fn``. ``None`` is treated as an empty mapping.
        inputs: Iterable of input tensors for the forward pass. ``None`` is
            treated as no inputs.
    """
    tensor_sharding_map = {} if tensor_sharding_map is None else tensor_sharding_map
    inputs = [] if inputs is None else inputs

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)

    # Initialize
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Everything after a successful init runs under try/finally so the process
    # group is torn down even if distribution or the forward pass raises.
    try:
        device = torch.device("cuda", rank)
        torch.cuda.set_device(device)

        # Distribute model parameters according to sharding configuration
        mesh = DeviceMesh("cuda", mesh=device_mesh)
        rlog(logger, rank, "Distributing module parameters...", level="info")
        model, dist_time = distributed_timing(
            distribute_module,
            model,
            mesh,
            partial(dist_model_fn, rank=rank, tensor_sharding_map=tensor_sharding_map),
            input_fn=None,
            output_fn=None,
        )
        rlog(
            logger,
            rank,
            f"Module distribution done. Time taken: {dist_time} seconds.",
        )

        # Run forward pass
        rlog(logger, rank, "Starting forward pass.", level="info")
        # NOTE(review): two placements are given per input, which presumably
        # requires a 2-D device mesh - confirm against callers.
        inputs = [
            distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()])
            for in_tensor in inputs
        ]
        _, time_taken = distributed_average_timing(
            fn=model,
            repeat=10,
            args=inputs,
        )
        rlog(
            logger,
            rank,
            f"Forward pass finished. Time taken: {time_taken}",
            level="info",
        )
    finally:
        dist.destroy_process_group()


class MaseLauncher:
    """
    MaseLauncher launches an optimized model on multiple GPUs using torch.distributed.
    """

    def __init__(self, mase_graph, world_size=None, device_mesh=None):
        """Initialize the MaseLauncher.

        Args:
            mase_graph (MaseGraph): The MaseGraph object containing the model.
            world_size (int, optional): Number of GPUs to use. Defaults to None.
            device_mesh (list, optional): List of GPUs to use. Defaults to None.
        """
        self.mg = mase_graph
        self.model = mase_graph.model
        self.world_size = world_size
        self.device_mesh = device_mesh

    def run(self, tensor_sharding_map=None, inputs=None):
        """Spawn one process per rank and run ``device_fn`` in each of them.

        Blocks until every spawned process has finished (``join=True``).

        Args:
            tensor_sharding_map (dict, optional): Sharding configuration
                forwarded to ``device_fn``. ``None`` is treated as an empty
                mapping.
            inputs (list, optional): Input tensors forwarded to ``device_fn``.
                ``None`` is treated as an empty list.
        """
        # Normalize here so device_fn always receives real containers.
        if tensor_sharding_map is None:
            tensor_sharding_map = {}
        if inputs is None:
            inputs = []

        logger.info(f"Launching model with world size {self.world_size}.")

        # mp.spawn passes the process index as the first positional argument
        # (the rank); world_size follows via ``args``.
        worker = partial(
            device_fn,
            model=self.model,
            device_mesh=self.device_mesh,
            tensor_sharding_map=tensor_sharding_map,
            inputs=inputs,
        )
        mp.spawn(
            worker,
            args=(self.world_size,),
            nprocs=self.world_size,
            join=True,
        )