Source code for chop.tools.checkpoint_load

import logging
import os

import torch

logger = logging.getLogger(__name__)


def load_lightning_ckpt_to_unwrapped_model(checkpoint: str, model: torch.nn.Module):
    """
    Load a PyTorch Lightning checkpoint to a PyTorch model.

    The checkpoint file must contain a ``"state_dict"`` entry. Its keys are
    matched against ``model.state_dict()``; entries with no counterpart in
    the model are ignored.

    Args:
        checkpoint (str): path to the Lightning checkpoint file.
        model (torch.nn.Module): model that receives the weights.

    Returns:
        torch.nn.Module: the same ``model`` object, with the matched weights
        loaded.

    Raises:
        KeyError: the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: raised by ``load_state_dict`` (called with its default
            strictness) when model entries are left without a match.
    """
    # Deserialize onto the CPU so that a checkpoint saved on an accelerator
    # can still be read on a host without that device; load_state_dict then
    # copies the values into the model's existing parameters.
    src_state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
    tgt_state_dict = model.state_dict()

    wrapper_prefix = "model."
    new_tgt_state_dict = {}
    for src_key, value in src_state_dict.items():
        if src_key.startswith(wrapper_prefix):
            # Usual case: drop only the leading wrapper attribute name.
            candidates = (src_key[len(wrapper_prefix):],)
        elif wrapper_prefix in src_key:
            # "model." appears later in the key (e.g. "encoder.model.weight").
            # Prefer the key as-is; fall back to dropping the first component,
            # which is what earlier versions of this function always did.
            candidates = (src_key, src_key.split(".", 1)[1])
        else:
            candidates = (src_key,)

        for tgt_key in candidates:
            if tgt_key in tgt_state_dict:
                new_tgt_state_dict[tgt_key] = value
                break

    model.load_state_dict(state_dict=new_tgt_state_dict)
    return model
def load_unwrapped_ckpt(checkpoint: str, model: torch.nn.Module):
    """
    Load a PyTorch state dict or checkpoint containing state dict to a PyTorch model.

    Args:
        checkpoint (str): path to a file holding either a bare state dict or
            a mapping with the state dict stored under ``"state_dict"``.
        model (torch.nn.Module): model that receives the weights.

    Returns:
        torch.nn.Module: the same ``model`` object, with the weights loaded.

    Raises:
        RuntimeError: raised by ``load_state_dict`` (called with its default
            strictness) when the keys do not match the model.
    """
    # Deserialize onto the CPU so that a checkpoint saved on an accelerator
    # can still be read on a host without that device; load_state_dict then
    # copies the values into the model's existing parameters.
    loaded = torch.load(checkpoint, map_location="cpu")

    # A full checkpoint nests the weights under "state_dict"; a bare state
    # dict is used as-is.
    state_dict = loaded["state_dict"] if "state_dict" in loaded else loaded

    model.load_state_dict(state_dict=state_dict)
    return model
def load_graph_module_ckpt(checkpoint: str, weights_only: bool = False):
    """
    Load a serialized graph module.

    Args:
        checkpoint (str): path to the serialized file, or to a directory
            that contains a file named ``graph_module.mz``.
        weights_only (bool, optional): forwarded unchanged to ``torch.load``.
            Defaults to False.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.
    """
    # NOTE(review): with weights_only=False torch.load performs full
    # unpickling, so this should only be pointed at trusted files.
    ckpt_file = (
        os.path.join(checkpoint, "graph_module.mz")
        if os.path.isdir(checkpoint)
        else checkpoint
    )
    return torch.load(ckpt_file, weights_only=weights_only)
[docs] def load_model( load_name: str, load_type: str = "mz", model: torch.nn.Module = None ) -> torch.nn.Module | torch.fx.GraphModule: """Load a pytorch/lightning/mase checkpoint to a model. Args: load_name (str): path to the checkpoint load_type (str, optional): checkpoint type, must be one of ['pt', 'pl', 'mz'], representing pytorch/lightning/mase. Defaults to "auto" inferred from the extension. model (torch.nn.Module, optional): Model candidate to load checkpoint. Note that 'ms' checkpoint loads the model as well as state dict, thus does not need this arg. Defaults to None. Raises: ValueError: Unknown extension for 'load_type'. Returns: nn.Module/fx.GraphModule: the model with the checkpoint loaded """ if load_type == "hf": raise RuntimeError( "HuggingFace checkpoint should be loaded using model_inst_fn." ) elif load_type not in ["pt", "pl", "mz"]: raise ValueError(f"Unknown extension for 'load_type': {load_type}") if load_type == "pt": model = load_unwrapped_ckpt(checkpoint=load_name, model=model) logger.info(f"Loaded pytorch checkpoint from {load_name}") elif load_type == "pl": if not load_name.endswith(".ckpt"): logger.warning( f"Lightning checkpoint should end with '.ckpt', but got {load_name}" ) model = load_lightning_ckpt_to_unwrapped_model( checkpoint=load_name, model=model ) logger.info(f"Loaded pytorch lightning checkpoint from {load_name}") else: assert load_name.endswith( ".mz" ), f"Invalid extension for 'load_type=mz': {load_name}, must be a '.mz' file, but got {load_name}." model = load_graph_module_ckpt(checkpoint=load_name) logger.info(f"Loaded mase checkpoint from {load_name}") return model