Source code for chop.tools.get_input

import inspect
from typing import Literal
from enum import Enum


class ModelSource(Enum):
    """
    The source of the model, must be one of the following:
    - HF_TRANSFORMERS: HuggingFace transformers
    - MANUAL: manually implemented
    - PATCHED: patched HuggingFace
    - TOY: toy model for testing and debugging
    - TORCHVISION: torchvision model (description inferred from the member name)
    - VISION_OTHERS: other vision model (description inferred from the member name)
    - PHYSICAL: model that perform classification using physical data point vectors
    - NERF: model that estimates neural radiance field (NeRF) of a 3D scene

    Each member's value is the lowercase string form of its name, so a member
    can be looked up from that string, e.g. ``ModelSource("patched")``.
    """

    HF_TRANSFORMERS = "hf_transformers"
    MANUAL = "manual"
    PATCHED = "patched"
    TOY = "toy"
    TORCHVISION = "torchvision"
    VISION_OTHERS = "vision_others"
    PHYSICAL = "physical"
    NERF = "nerf"
def _get_default_args(func):
    """Map every parameter name of ``func`` to its declared default value.

    Parameters are listed in signature order. Parameters that declare no
    default are deliberately kept as well; their value is the
    ``inspect.Parameter.empty`` sentinel.

    Args:
        func: any callable accepted by ``inspect.signature``.

    Returns:
        dict: ``{parameter_name: default_or_inspect.Parameter.empty}``.
    """
    defaults = {}
    for name, parameter in inspect.signature(func).parameters.items():
        # No filtering on ``inspect.Parameter.empty``: every parameter is reported.
        defaults[name] = parameter.default
    return defaults
def get_cf_args(model_info, task: str, model):
    """Get concrete forward args for freezing dynamic control flow in forward pass.

    Args:
        model_info: object exposing ``model_source``, ``is_vision_model``,
            ``is_physical_model``, ``is_nlp_model`` and ``name``.
        task (str): task name; only consulted for NLP models. One of
            ["cls", "classification", "lm", "language_modeling",
            "translation", "tran"].
        model: model whose ``forward`` signature is inspected. Patched models
            must also expose ``patched_nodes["concrete_forward_args"]``.

    Returns:
        dict: for patched models, the stored ``concrete_forward_args``; for
        vision/physical models, an empty dict; for NLP models, every
        ``forward`` parameter mapped to its default, minus the task's runtime
        inputs.

    Raises:
        ValueError: the task is not supported for an NLP model.
        KeyError: a runtime input of the task is not a parameter of
            ``model.forward``.
        RuntimeError: the model is neither patched, vision, physical nor NLP.
    """
    # The forward signature is inspected up front, whichever branch is taken.
    forward_defaults = _get_default_args(model.forward)

    if model_info.model_source == ModelSource.PATCHED:
        return model.patched_nodes["concrete_forward_args"]

    if model_info.is_vision_model or model_info.is_physical_model:
        return {}

    if not model_info.is_nlp_model:
        raise RuntimeError(f"Unsupported model+task: {model_info.name}+{task}")

    # Names of the arguments supplied at call time for each supported task.
    if task in ("classification", "cls", "language_modeling", "lm"):
        runtime_inputs = ("input_ids", "attention_mask", "labels")
    elif task in ("translation", "tran"):
        runtime_inputs = (
            "input_ids",
            "attention_mask",
            "decoder_input_ids",
            "decoder_attention_mask",
        )
    else:
        raise ValueError(f"Task {task} is not supported for {model_info.name}")

    # Whatever is left after removing the runtime inputs is returned as the
    # concrete args.
    for arg_name in runtime_inputs:
        del forward_defaults[arg_name]
    return forward_defaults
[docs] def get_dummy_input( model_info, data_module, task: str, device: str = "meta", ) -> dict: """Create a single dummy input for a model. The dummy input is a single sample from the training set. Args: datamodule (MaseDataModule): a LightningDataModule instance (see machop/chop/dataset/__init__.py). Make sure the datamodule is prepared and setup. task (str): task name, one of ["cls", "classification", "lm", "language_modeling", "translation", "tran"] is_nlp_model (bool, optional): Whether the task is NLP task or not. Defaults to False. Returns: dict: a dummy input dict which can be passed to the wrapped lightning model's forward method, like model(**dummy_input) """ assert ( data_module.train_dataset is not None ), "DataModule is not setup. Please call data_module.prepare_data() and .setup()." index: int = 0 train_iter = iter(data_module.train_dataloader()) n_batches = len(data_module.train_dataloader()) if index >= n_batches * data_module.batch_size: raise ValueError(f"index {index} is out of range.") batch_index = index // data_module.batch_size sample_index = index % data_module.batch_size for _ in range(batch_index): next(train_iter) if model_info.is_vision_model or model_info.is_physical_model: match task: case "classification" | "cls": x, y = next(train_iter) # x = x[[0], ...].to(device) x = x.to(device) if data_module.name == "mnist" and model_info.is_vision_model: dummy_inputs = {"input_1": x} else: dummy_inputs = {"x": x} case _: raise ValueError(f"Task {task} is not supported for {model_info.name}") elif model_info.is_nerf_model: # TODO: pass elif model_info.is_nlp_model: match task: case "classification" | "cls": input_dict = next(train_iter) input_ids = input_dict["input_ids"][[sample_index], ...].to(device) attention_mask = input_dict["attention_mask"][[sample_index], ...].to( device ) labels = input_dict["labels"][[sample_index], ...].to(device) dummy_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } if 
"token_type_ids" in input_dict: dummy_inputs["token_type_ids"] = input_dict["token_type_ids"][ [sample_index], ... ].to(device) case "language_modeling" | "lm": input_dict = next(train_iter) input_ids = input_dict["input_ids"][[sample_index], ...].to(device) attention_mask = input_dict["attention_mask"][[sample_index], ...].to( device ) labels = input_dict["labels"][[sample_index], ...].to(device) dummy_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } case "translation" | "tran": input_dict = next(train_iter) input_ids = input_dict["input_ids"][[sample_index], ...].to(device) attention_mask = input_dict["attention_mask"][[sample_index], ...].to( device ) decoder_input_ids = input_dict["decoder_input_ids"][ [sample_index], ... ].to(device) decoder_attention_mask = input_dict["decoder_attention_mask"][ [sample_index], ... ].to(device) dummy_inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, } case _: raise ValueError(f"Task {task} is not supported for {model_info.name}") else: raise RuntimeError(f"Unsupported model+task: {model_info.name}+{task}") return dummy_inputs
[docs] class InputGenerator:
[docs] def __init__( self, model_info, data_module, task: str, which_dataloader: Literal["train", "val", "test"], max_batches: int = None, ) -> None: """ Input generator for feeding batches to models. This is used for software passes. Args: datamodule (MyDataModule): a MyDataModule instance (see machop/chop/dataset/data_module.py). Make sure the datamodule is prepared and setup. max_batches (int, optional): Maximum number of batches to generate. Defaults to None will stop when reaching the last batch in dataloader. Returns: (dict): a dummy input dict which can be passed to the wrapped lightning model's forward method, like model(**dummy_input) """ assert ( getattr(data_module, f"{which_dataloader}_dataset") is not None ), "DataModule is not setup. Please call data_module.prepare_data() and .setup()." self.model_info = model_info self.task = task self.batch_size = data_module.batch_size self.dataloader = getattr(data_module, f"{which_dataloader}_dataloader")() self.dataloader_iter = iter(self.dataloader) self.max_batches = max_batches self.current_batch = 0
def __iter__(self): return self def __next__(self) -> dict: if self.max_batches is not None and self.current_batch >= self.max_batches: raise StopIteration if self.model_info.is_vision_model or self.model_info.is_physical_model: match self.task: case "classification" | "cls": x, y = next(self.dataloader_iter) inputs = {"x": x} case _: raise ValueError( f"Task {self.task} is not supported for {self.model_info.name}" ) elif self.model_info.is_physical_model: match self.task: case "classification" | "cls": x, y = next(self.dataloader_iter) inputs = {"x": x} case _: raise ValueError( f"Task {self.task} is not supported for {self.model_info.name}" ) elif self.model_info.is_nlp_model: match self.task: case "classification" | "cls": input_dict = next(self.dataloader_iter) inputs = { "input_ids": input_dict["input_ids"], "attention_mask": input_dict["attention_mask"], "labels": input_dict["labels"], } if "token_type_ids" in input_dict: inputs["token_type_ids"] = input_dict["token_type_ids"] case "language_modeling" | "lm": input_dict = next(self.dataloader_iter) inputs = { "input_ids": input_dict["input_ids"], "attention_mask": input_dict["attention_mask"], "labels": input_dict["labels"], } case "translation" | "tran": input_dict = next(self.dataloader_iter) inputs = { "input_ids": input_dict["input_ids"], "attention_mask": input_dict["attention_mask"], "decoder_input_ids": input_dict["decoder_input_ids"], "decoder_attention_mask": input_dict["decoder_attention_mask"], } case _: raise ValueError( f"Task {self.task} is not supported for {self.model_info.name}" ) else: raise RuntimeError( f"Unsupported model+task: {self.model_info.name}+{self.task}" ) self.current_batch += 1 return inputs