chop.actions
chop.actions.train
- chop.actions.train.train(model: LightningModule, model_info: dict, data_module: LightningDataModule, dataset_info: dict, task: str, optimizer: str, learning_rate: float, weight_decay: float, scheduler_args: dict, plt_trainer_args: dict, auto_requeue: bool, save_path: str, visualizer: TensorBoardLogger, load_name: str, load_type: str)
Train the model using PyTorch Lightning.
- Parameters:
model (pl.LightningModule) – Model to be trained.
model_info (dict) – Information about the model.
data_module (pl.LightningDataModule) – Data module for the model.
dataset_info (dict) – Information about the dataset.
task (str) – Task to be performed.
optimizer (str) – Optimizer to be used.
learning_rate (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
scheduler_args (dict) – Arguments for the scheduler.
plt_trainer_args (dict) – Arguments for the PyTorch Lightning Trainer.
auto_requeue (bool) – Whether to automatically requeue the job on SLURM.
save_path (str) – Path to save the model.
visualizer (TensorBoardLogger) – TensorBoard logger.
load_name (str) – Name of the checkpoint to load.
load_type (str) – Type of the checkpoint to load.
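Because train() takes fifteen arguments, it can help to keep the plain-value settings in one object and merge in the runtime objects (model, data module, logger) at call time. The sketch below does this with a dataclass. TrainSettings, to_kwargs, and every default value are hypothetical choices for illustration, not MASE defaults; plt_trainer_args is assumed to be forwarded to the PyTorch Lightning Trainer as keyword arguments (max_epochs and accelerator are standard Trainer arguments), and passing None for load_name is assumed to mean "no checkpoint".

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TrainSettings:
    """Hypothetical container for the plain-value arguments of train().

    Field names follow the documented signature; the defaults are
    illustrative assumptions, not MASE defaults.
    """

    task: str
    optimizer: str
    learning_rate: float
    weight_decay: float = 0.0
    scheduler_args: Dict[str, Any] = field(default_factory=dict)
    # Assumed to be passed on to the PyTorch Lightning Trainer as keyword arguments.
    plt_trainer_args: Dict[str, Any] = field(default_factory=dict)
    auto_requeue: bool = False
    save_path: str = "checkpoints"
    # Assumption: None means "train from scratch" (no checkpoint to load).
    load_name: Optional[str] = None
    load_type: Optional[str] = None

    def to_kwargs(
        self,
        model: Any,
        model_info: Dict[str, Any],
        data_module: Any,
        dataset_info: Dict[str, Any],
        visualizer: Any,
    ) -> Dict[str, Any]:
        """Merge the stored settings with the runtime objects train() also needs."""
        kwargs = asdict(self)
        kwargs.update(
            model=model,
            model_info=model_info,
            data_module=data_module,
            dataset_info=dataset_info,
            visualizer=visualizer,
        )
        return kwargs


# Example usage (requires MASE plus real model, data, and logger objects):
# settings = TrainSettings(task=..., optimizer=..., learning_rate=1e-3,
#                          plt_trainer_args={"max_epochs": 10, "accelerator": "auto"})
# from chop.actions.train import train
# train(**settings.to_kwargs(model, model_info, data_module, dataset_info, visualizer))
```

Keeping the settings separate from the heavy objects also makes them easy to log or serialize with asdict().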
chop.actions.test
- chop.actions.test.test(model: LightningModule, model_info: dict, data_module: LightningDataModule, dataset_info: dict, task: str, optimizer: str, learning_rate: float, weight_decay: float, plt_trainer_args: dict, auto_requeue: bool, save_path: str, visualizer: TensorBoardLogger, load_name: str, load_type: str)
Evaluate a trained model using PyTorch Lightning.
- Parameters:
model (pl.LightningModule) – Model to be evaluated.
model_info (dict) – Information about the model.
data_module (pl.LightningDataModule) – Data module for the model.
dataset_info (dict) – Information about the dataset.
task (str) – Task to be performed.
optimizer (str) – Optimizer to be used.
learning_rate (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
plt_trainer_args (dict) – Arguments for the PyTorch Lightning Trainer.
auto_requeue (bool) – Whether to automatically requeue the job on SLURM.
save_path (str) – Path to save the model.
visualizer (TensorBoardLogger) – TensorBoard logger.
load_name (str) – Name of the checkpoint to load.
load_type (str) – Type of the checkpoint to load.
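Comparing the two signatures, test() accepts the same arguments as train() except scheduler_args, so a keyword set prepared for training can be reused for evaluation by filtering it. A minimal sketch, assuming the arguments are held in a plain dict; the helper name eval_kwargs_from_train is hypothetical.

```python
from typing import Any, Dict

# Parameter names copied from the train() and test() signatures above.
TRAIN_PARAMS = (
    "model", "model_info", "data_module", "dataset_info", "task",
    "optimizer", "learning_rate", "weight_decay", "scheduler_args",
    "plt_trainer_args", "auto_requeue", "save_path", "visualizer",
    "load_name", "load_type",
)
TEST_PARAMS = (
    "model", "model_info", "data_module", "dataset_info", "task",
    "optimizer", "learning_rate", "weight_decay",
    "plt_trainer_args", "auto_requeue", "save_path", "visualizer",
    "load_name", "load_type",
)


def eval_kwargs_from_train(train_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that test() accepts (hypothetical helper)."""
    return {name: train_kwargs[name] for name in TEST_PARAMS}


# Example usage (requires MASE):
# from chop.actions.test import test
# test(**eval_kwargs_from_train(train_kwargs))
```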
chop.actions.transform
- chop.actions.transform.pre_transform_load(load_name: str, load_type: str, model: Module)
Load the model if a checkpoint is provided.
- Parameters:
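load_name (str) – Name of the checkpoint to load.
load_type (str) – Type of the checkpoint to load.
model (torch.nn.Module) – Model to load the checkpoint into.
- Returns:
The model, with the checkpoint loaded if one was provided.
- Return type:
torch.nn.Module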
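The "load only if a checkpoint is provided" behaviour can be sketched as a small dispatcher. This is not MASE's implementation: load_if_requested, the loaders mapping, and its keys are hypothetical, and the set of load_type values MASE actually supports is not listed in this reference. A loader for a plain PyTorch state dict could, for example, call model.load_state_dict(torch.load(path)) and then return the model.

```python
from typing import Callable, Mapping, Optional, TypeVar

M = TypeVar("M")  # the model type, e.g. torch.nn.Module


def load_if_requested(
    model: M,
    load_name: Optional[str],
    load_type: Optional[str],
    loaders: Mapping[str, Callable[[M, str], M]],
) -> M:
    """Hypothetical sketch: load a checkpoint into `model` only if one is named."""
    if load_name is None:
        # No checkpoint requested: return the model unchanged.
        return model
    loader = loaders.get(load_type) if load_type is not None else None
    if loader is None:
        raise ValueError(f"unsupported load_type: {load_type!r}")
    # The loader reads `load_name` and returns the updated model.
    return loader(model, load_name)
```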
- chop.actions.transform.transform(model: Module, model_info: dict, model_name: str, data_module, task: str, config: str, save_dir: str = None, load_name: str = None, load_type: str = None, accelerator: str = 'auto')
Transform the model based on the configuration.
- Parameters:
model (torch.nn.Module) – Model to be transformed.
model_info (dict) – Information about the model.
model_name (str) – Name of the model.
data_module – Data module for the model.
task (str) – Task to be performed.
config (str) – Configuration describing the transformation.
save_dir (str, optional) – Directory to save outputs to. Defaults to None.
load_name (str, optional) – Name of the checkpoint to load. Defaults to None.
load_type (str, optional) – Type of the checkpoint to load. Defaults to None.
accelerator (str, optional) – Accelerator to use. Defaults to “auto”.
- Raises:
ValueError – If the configuration is invalid.
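transform() applies a configuration, and this module also provides a module-level variant (transform_module) and an FX-graph-level variant (transform_graph). One plausible way to route between them is sketched below, assuming the configuration has already been parsed into a dict. The "style" key, its "graph" default, and the dispatch_transform helper are assumptions for illustration, not documented MASE behaviour; the sketch raises ValueError for a style it cannot handle.

```python
from typing import Any, Callable, Mapping


def dispatch_transform(
    config: Mapping[str, Any],
    handlers: Mapping[str, Callable[..., Any]],
    default_style: str = "graph",
    **kwargs: Any,
) -> Any:
    """Hypothetical router between module-level and graph-level transforms.

    The "style" key and the "graph" default are assumptions for illustration.
    """
    style = config.get("style", default_style)
    handler = handlers.get(style)
    if handler is None:
        # Signal an unusable configuration with ValueError.
        raise ValueError(f"unsupported transform style: {style!r}")
    # Forward everything else (model, task, save_dir, ...) unchanged.
    return handler(**kwargs)
```

In this sketch the handlers would be functions such as transform_module and transform_graph, called with the remaining keyword arguments.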
- chop.actions.transform.transform_module(model: Module, model_info, model_name, data_module, task: str, config: str, save_dir: str = None, load_name: str = None, load_type: str = None, accelerator: str = 'auto')
Transform the model at the PyTorch module level.
- Parameters:
model (torch.nn.Module) – Model to be transformed.
model_info (dict) – Information about the model.
model_name (str) – Name of the model.
data_module – Data module for the model.
task (str) – Task to be performed.
config (str) – Configuration describing the transformation.
save_dir (str, optional) – Directory to save outputs to. Defaults to None.
load_name (str, optional) – Name of the checkpoint to load. Defaults to None.
load_type (str, optional) – Type of the checkpoint to load. Defaults to None.
accelerator (str, optional) – Accelerator to use. Defaults to “auto”.
- Returns:
_description_
- Return type:
_type_
- chop.actions.transform.transform_graph(model: Module, model_info, model_name, data_module, task: str, config: str, save_dir: str = None, load_name: str = None, load_type: str = None, accelerator: str = 'auto')
Transform the model at the FX graph level.
- Parameters:
model (torch.nn.Module) – Model to be transformed.
model_info (dict) – Information about the model.
model_name (str) – Name of the model.
data_module – Data module for the model.
task (str) – Task to be performed.
config (str) – Configuration describing the transformation.
save_dir (str, optional) – Directory to save outputs to. Defaults to None.
load_name (str, optional) – Name of the checkpoint to load. Defaults to None.
load_type (str, optional) – Type of the checkpoint to load. Defaults to None.
accelerator (str, optional) – Accelerator to use. Defaults to “auto”.
- Returns:
_description_
- Return type:
_type_
chop.actions.search
- chop.actions.search.search.parse_search_config(search_config: dict)
Parse the search config from a dict or a TOML file and perform a sanity check. The search config must consist of two parts: strategy and search_space.
- Parameters:
search_config – A dictionary or a path to a TOML file containing the search config.
- Returns:
_description_
- Return type:
_type_
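A minimal sketch of the parse-and-check step, assuming strategy and search_space are top-level tables of the config (the real file layout may nest them differently). The helper name load_and_check_search_config is hypothetical; TOML files are read with the standard-library tomllib module (Python 3.11+).

```python
import os
from typing import Any, Dict, Union

REQUIRED_PARTS = ("strategy", "search_space")


def load_and_check_search_config(
    search_config: Union[Dict[str, Any], str, os.PathLike],
) -> Dict[str, Any]:
    """Hypothetical sketch: accept a dict or a TOML path, then check its two parts."""
    if not isinstance(search_config, dict):
        # tomllib is in the standard library from Python 3.11; importing it here
        # lets dict inputs work on older interpreters too.
        import tomllib

        with open(search_config, "rb") as f:  # tomllib requires binary mode
            search_config = tomllib.load(f)
    missing = [part for part in REQUIRED_PARTS if part not in search_config]
    if missing:
        raise ValueError(f"search config is missing required part(s): {missing}")
    return search_config
```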
- chop.actions.search.search.search(model: Module, model_info, task: str, dataset_info, data_module, search_config: dict | PathLike, save_path: PathLike, accelerator: str, load_name: PathLike = None, load_type: str = None, visualizer=None)
Perform search using a defined search strategy and a search space.
- Parameters:
model (torch.nn.Module) – Model on which to perform the search.
model_info (dict) – Information about the model.
task (str) – Task to be performed.
dataset_info (dict) – Information about the dataset.
data_module – Data module for the model.
search_config (dict | PathLike) – A dictionary or a path to a TOML file containing the search config.
save_path (PathLike) – Path to save the search results.
accelerator (str) – Accelerator to use.
load_name (PathLike, optional) – Path of the checkpoint to load. Defaults to None.
load_type (str, optional) – Type of the checkpoint to load. Defaults to None.
visualizer (optional) – Logger used for visualization. Defaults to None.
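search() pairs a search strategy with a search space. To make the two roles concrete, the toy sketch below implements the simplest possible strategy, an exhaustive grid search over a discrete space that minimizes a user-supplied objective. It does not reproduce any of MASE's own strategies; grid_search, the dict-of-lists space format, and the example parameter names bits and block are assumptions.

```python
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple


def grid_search(
    search_space: Dict[str, List[Any]],
    objective: Callable[[Dict[str, Any]], float],
) -> Tuple[Optional[Dict[str, Any]], float]:
    """Toy strategy: score every point in a discrete space, keep the lowest score."""
    names = list(search_space)
    best_point: Optional[Dict[str, Any]] = None
    best_score = float("inf")
    # itertools.product enumerates every combination of candidate values.
    for values in itertools.product(*(search_space[n] for n in names)):
        point = dict(zip(names, values))
        score = objective(point)
        if score < best_score:  # strict: ties keep the first point found
            best_point, best_score = point, score
    return best_point, best_score


if __name__ == "__main__":
    space = {"bits": [4, 8], "block": [16, 32]}
    print(grid_search(space, lambda p: abs(p["bits"] - 8) + p["block"] / 32))
    # prints ({'bits': 8, 'block': 16}, 0.5)
```

Exhaustive search is only practical for small spaces, since the number of evaluations is the product of the candidate-list lengths.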