Source code for chop.actions.test
import logging
import os
import pickle
import pytorch_lightning as pl
from chop.tools.plt_wrapper import get_model_wrapper
from chop.tools.checkpoint_load import load_model
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
import pytest
# Logger scoped to this module so messages can be filtered by module path.
logger = logging.getLogger(name=__name__)
[docs]
@pytest.mark.skip(reason="This isn't a test")
def test(
    model: pl.LightningModule,
    model_info: dict,
    data_module: pl.LightningDataModule,
    dataset_info: dict,
    task: str,
    optimizer: str,
    learning_rate: float,
    weight_decay: float,
    plt_trainer_args: dict,
    auto_requeue: bool,
    save_path: str,
    visualizer: TensorBoardLogger,
    load_name: str,
    load_type: str,
):
    """
    Evaluate a trained model using PyTorch Lightning.

    If the data module reports a test split, ``trainer.test`` is run on it.
    Otherwise, if it reports a prediction split, ``trainer.predict`` is run and
    the predictions are pickled to ``<save_path>/predicted_result.pkl``
    (skipped with a warning when ``save_path`` is None).

    Args:
        model (pl.LightningModule): Model to be evaluated.
        model_info (dict): Information about the model (used to pick the wrapper).
        data_module (pl.LightningDataModule): Data module for the model.
        dataset_info (dict): Information about the dataset, forwarded to the wrapper.
        task (str): Task to be performed (used to pick the wrapper).
        optimizer (str): Optimizer name, forwarded to the wrapper.
        learning_rate (float): Learning rate, forwarded to the wrapper.
        weight_decay (float): Weight decay, forwarded to the wrapper.
        plt_trainer_args (dict): Arguments for the PyTorch Lightning Trainer.
            The caller's dict is not modified.
        auto_requeue (bool): Requeue on SLURM.
        save_path (str): Directory for saved prediction results; may be None.
        visualizer (TensorBoardLogger): Logger passed to the Trainer.
        load_name (str): Name of the checkpoint to load, or None to skip loading.
        load_type (str): Type of the checkpoint to load.

    Raises:
        ValueError: If the dataset has neither a test nor a prediction split.
    """
    if save_path is not None:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_path, exist_ok=True)

    # Shallow copy so the caller's trainer-args dict is left untouched.
    plt_trainer_args = dict(plt_trainer_args)
    plt_trainer_args["callbacks"] = []
    plt_trainer_args["logger"] = visualizer
    # Only attach the SLURM environment plugin when requeueing is requested.
    plt_trainer_args["plugins"] = (
        [SLURMEnvironment(auto_requeue=auto_requeue)] if auto_requeue else None
    )

    wrapper_cls = get_model_wrapper(model_info, task)
    if load_name is not None:
        model = load_model(load_name, load_type=load_type, model=model)
    plt_model = wrapper_cls(
        model,
        dataset_info=dataset_info,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        optimizer=optimizer,
    )

    trainer = pl.Trainer(**plt_trainer_args)
    split_info = data_module.dataset_info
    if split_info.test_split_available:
        trainer.test(plt_model, datamodule=data_module)
    elif split_info.pred_split_available:
        predicted_results = trainer.predict(plt_model, datamodule=data_module)
        if save_path is None:
            # Previously os.path.join(None, ...) raised TypeError here.
            logger.warning(
                "save_path is None; predicted results are not saved to disk"
            )
        else:
            pred_save_name = os.path.join(save_path, "predicted_result.pkl")
            with open(pred_save_name, "wb") as f:
                pickle.dump(predicted_results, f)
            logger.info(f"Predicted results is saved to {pred_save_name}")
    else:
        # Use dataset_info (as the split checks above do); the former
        # `data_module.info` lookup would raise AttributeError instead.
        # NOTE(review): assumes dataset_info exposes `name` — confirm.
        raise ValueError(
            f"Test or pred split not available for dataset {split_info.name}"
        )