Source code for chop.actions.train

import logging
import os
from pathlib import Path

import pytorch_lightning as pl
from chop.tools.plt_wrapper import get_model_wrapper
from chop.tools.checkpoint_load import load_model
from chop.tools.get_input import get_dummy_input
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.passes.graph.interface import save_mase_graph_interface_pass
from chop.passes.graph.transforms import metadata_value_type_cast_transform_pass
from chop.ir.graph import MaseGraph
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment

from torch.distributed.fsdp import FullyShardedDataParallel
from pytorch_lightning.strategies import DDPStrategy
from chop.tools.utils import parse_accelerator, to_numpy_if_tensor


logger = logging.getLogger(__name__)


# class CustomFSDPStrategy(DDPStrategy):
#     def configure_ddp(self):
#         # model = DistributedDataParallel(model)
#         fsdp_model = FullyShardedDataParallel(
#             self.model,
#             # fsdp_auto_wrap_policy=default_auto_wrap_policy,
#             # cpu_offload=CPUOffload(offload_params=True),
#         )
#         self.model = fsdp_model
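# NOTE: the commented-out strategy above wraps the model in torch's
# FullyShardedDataParallel, which shards parameters and optimizer state
# across ranks instead of replicating them as DDP does. It is kept only for
# reference; recent PyTorch Lightning releases expose a built-in
# pytorch_lightning.strategies.FSDPStrategy that serves the same purpose.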


def train(
    model,
    model_info,
    data_module,
    dataset_info,
    task,
    optimizer,
    learning_rate,
    weight_decay,
    scheduler_args,
    plt_trainer_args,
    auto_requeue,
    save_path,
    visualizer,
    load_name,
    load_type,
):
    if save_path is not None:
        # if save_path is None, the model will not be saved
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        checkpoint_callback = ModelCheckpoint(
            save_top_k=1,
            monitor="val_loss_epoch",
            mode="min",
            filename="best",
            dirpath=save_path,
            save_last=True,
        )
        # tb_logger = TensorBoardLogger(save_dir=save_path, name="logs")
        lr_monitor_callback = LearningRateMonitor(logging_interval="step")
        plt_trainer_args["callbacks"] = [
            checkpoint_callback,
            lr_monitor_callback,
        ]
        plt_trainer_args["logger"] = visualizer

    # plugin
    if auto_requeue:
        plugins = [SLURMEnvironment(auto_requeue=auto_requeue)]
    else:
        plugins = None
    plt_trainer_args["plugins"] = plugins

    # Check optimizer
    # if plt_trainer_args["strategy"] in ["deepspeed_stage_3"]:
    #     assert optimizer in [
    #         "FusedAdam",
    #         "fused_adam",
    #     ], "optimizer should be 'fused_adam' given --strategy={}".format(
    #         plt_trainer_args["strategy"]
    #     )
    # elif plt_trainer_args["strategy"] in ["fsdp_custom"]:
    #     plt_trainer_args["strategy"] = CustomFSDPStrategy()

    wrapper_cls = get_model_wrapper(model_info, task)

    if load_name is not None:
        model = load_model(load_name, load_type=load_type, model=model)
        logger.info(f"'{load_type}' checkpoint loaded before training")

    pl_model = wrapper_cls(
        model,
        dataset_info=dataset_info,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        scheduler_args=scheduler_args,
        epochs=plt_trainer_args["max_epochs"],
        optimizer=optimizer,
    )

    trainer = pl.Trainer(**plt_trainer_args)

    trainer.fit(
        pl_model,
        datamodule=data_module,
    )

    # Save the trained model along with relevant metadata in the training_ckpts folder.
    # NOTE: This is important if the model was previously transformed with architectural
    # changes. The state dictionary that's saved by PyTorch Lightning wouldn't work.
    if save_path is not None and load_name is not None and load_type == "mz":
        accelerator = plt_trainer_args["accelerator"]
        accelerator = parse_accelerator(accelerator)
        graph = MaseGraph(model)
        dummy_input = get_dummy_input(model_info, data_module, task, device=accelerator)
        graph, _ = init_metadata_analysis_pass(graph, None)
        graph, _ = add_common_metadata_analysis_pass(graph, {"dummy_in": dummy_input})
        graph, _ = add_software_metadata_analysis_pass(graph, None)
        transformed_ckpt = Path(save_path) / "transformed_ckpt"
        transformed_ckpt.mkdir(parents=True, exist_ok=True)
        graph, _ = metadata_value_type_cast_transform_pass(
            graph, pass_args={"fn": to_numpy_if_tensor}
        )
        save_mase_graph_interface_pass(graph, pass_args=transformed_ckpt)
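
# Usage sketch: a minimal, hypothetical call to train(). Every placeholder
# below (my_model, my_model_info, my_data_module, my_dataset_info, the task
# name and trainer arguments) is an illustrative assumption, not part of this
# module; in practice these objects come from chop's model and dataset
# factories, and the CLI populates the remaining arguments.
#
# train(
#     model=my_model,                  # torch.nn.Module to be trained
#     model_info=my_model_info,        # consumed by get_model_wrapper / get_dummy_input
#     data_module=my_data_module,      # a pytorch_lightning.LightningDataModule
#     dataset_info=my_dataset_info,
#     task="cls",                      # assumed classification task label
#     optimizer="adam",
#     learning_rate=1e-4,
#     weight_decay=0.0,
#     scheduler_args=None,
#     plt_trainer_args={"max_epochs": 5, "accelerator": "cpu"},  # "max_epochs" is always read
#     auto_requeue=False,              # set True only when running under SLURM
#     save_path="training_ckpts",      # None disables checkpointing entirely
#     visualizer=None,                 # or a logger instance for plt_trainer_args["logger"]
#     load_name=None,                  # checkpoint path, or None to train from scratch
#     load_type=None,                  # "mz" re-saves a transformed MaseGraph after training
# )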