"""Module for training PyTorch models using Lightning."""
import abc
import os
from abc import abstractmethod
from typing import Callable, List, Optional
import lightning as L
import torch
from lightning import seed_everything
from quanda.utils.training.base_pl_module import BasicLightningModule
class BaseTrainer(metaclass=abc.ABCMeta):
    """Base class for a trainer.

    Subclasses must implement :meth:`fit`. :meth:`get_model` is not abstract;
    the default implementation raises ``NotImplementedError``.
    """

    @abstractmethod
    def fit(
        self,
        model: torch.nn.Module,
        train_dataloaders: torch.utils.data.dataloader.DataLoader,
        val_dataloaders: Optional[
            torch.utils.data.dataloader.DataLoader
        ] = None,
        accelerator: str = "cpu",
        devices: int = 0,
        seed: int = 42,
        callbacks: Optional[List[L.Callback]] = None,
        trainer_fit_kwargs: Optional[dict] = None,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        """Train a model using the provided dataloaders.

        Parameters
        ----------
        model : torch.nn.Module
            Model to train.
        train_dataloaders : torch.utils.data.dataloader.DataLoader
            Dataloader for the training data.
        val_dataloaders : Optional[torch.utils.data.dataloader.DataLoader]
            Dataloader for the validation data, defaults to None.
        accelerator : str
            The accelerator to use for training, by default "cpu".
        devices : int
            Device specification for training, by default 0. How the value
            is interpreted is left to the implementing subclass.
            NOTE(review): this was previously documented as "number of
            devices, 0 meaning all available" -- confirm the intended
            contract against the concrete trainers.
        seed : int
            Random seed, by default 42.
        callbacks : Optional[List[L.Callback]]
            Lightning callbacks to attach to the trainer, defaults to None.
        trainer_fit_kwargs : Optional[dict]
            Additional keyword arguments to pass to the trainer's fit method,
            defaults to None.
        args : Any
            Additional positional arguments for the implementation.
        kwargs : Any
            Additional keyword arguments for the implementation.

        Returns
        -------
        torch.nn.Module
            The trained model.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is invoked directly.

        """
        raise NotImplementedError

    def get_model(self) -> torch.nn.Module:
        """Get the model that was trained.

        Returns
        -------
        torch.nn.Module
            The trained model.

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.

        """
        raise NotImplementedError
class Trainer(BaseTrainer):
    """Simple class for training PyTorch models using Lightning.

    The optimisation setup (optimizer, learning rate, loss, scheduler, ...)
    is fixed at construction time; :meth:`fit` wraps a given model into a
    ``BasicLightningModule`` and trains it with a fresh ``L.Trainer``.
    """

    def __init__(
        self,
        optimizer: Callable,
        lr: float,
        max_epochs: int,
        criterion: torch.nn.modules.loss._Loss,
        scheduler: Optional[Callable] = None,
        optimizer_kwargs: Optional[dict] = None,
        scheduler_kwargs: Optional[dict] = None,
        logger: Optional[L.pytorch.loggers.logger.Logger] = None,
        seed: int = 27,
        num_workers: int = 0,
        enable_progress_bar: bool = True,
        gradient_clip_val: Optional[float] = None,
    ):
        """Construct the Trainer class.

        Parameters
        ----------
        optimizer : Callable
            Optimizer to use for training.
        lr : float
            Learning rate for the optimizer.
        max_epochs : int
            Maximum number of epochs to train for.
        criterion : torch.nn.modules.loss._Loss
            Loss to use during training.
        scheduler : Optional[Callable], optional
            Scheduler to use during training, defaults to None.
        optimizer_kwargs : Optional[dict], optional
            Keyword arguments for the optimizer, defaults to None (stored as
            an empty dict).
        scheduler_kwargs : Optional[dict], optional
            Keyword arguments for the scheduler, defaults to None (stored as
            an empty dict).
        logger : Optional[L.pytorch.loggers.logger.Logger], optional
            Lightning logger to use during training, defaults to None.
        seed : int, optional
            Seed passed to ``seed_everything`` when the trainer is
            constructed, by default 27. :meth:`fit` re-seeds with its own
            ``seed`` argument.
        num_workers : int, optional
            Number of workers to use for data loading, by default 0. The
            value is stored on the instance; this class does not read it.
        enable_progress_bar : bool, optional
            Whether to enable the progress bar during training, by
            default True.
        gradient_clip_val : Optional[float], optional
            Value to use for gradient clipping, by default None
            (i.e. no gradient clipping).

        """
        self.optimizer = optimizer
        self.lr = lr
        self.max_epochs = max_epochs
        self.criterion = criterion
        self.scheduler = scheduler
        self.logger = logger
        # Normalise ``None`` to a fresh dict so the kwargs can always be
        # forwarded as-is.
        self.optimizer_kwargs = optimizer_kwargs or {}
        self.scheduler_kwargs = scheduler_kwargs or {}
        self.num_workers = num_workers
        self.enable_progress_bar = enable_progress_bar
        self.gradient_clip_val = gradient_clip_val
        seed_everything(seed, workers=True)
        super().__init__()

    def fit(
        self,
        model: torch.nn.Module,
        train_dataloaders: torch.utils.data.dataloader.DataLoader,
        val_dataloaders: Optional[
            torch.utils.data.dataloader.DataLoader
        ] = None,
        accelerator: str = "cpu",
        devices: int = 0,
        seed: int = 42,
        callbacks: Optional[List[L.Callback]] = None,
        trainer_fit_kwargs: Optional[dict] = None,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        """Train a model using the provided dataloaders.

        Parameters
        ----------
        model : torch.nn.Module
            Model to train. Its weights are overwritten with the trained
            weights before it is returned.
        train_dataloaders : torch.utils.data.dataloader.DataLoader
            Dataloader for the training data.
        val_dataloaders : Optional[torch.utils.data.dataloader.DataLoader]
            Dataloader for the validation data, defaults to None.
        accelerator : str, optional
            The accelerator to use for training, by default "cpu".
        devices : int, optional
            Index of the device to train on, by default 0. Only used when
            ``accelerator == "gpu"``, in which case ``[devices]`` is passed
            to ``L.Trainer``; for every other accelerator a single device
            (``devices=1``) is requested.
        seed : int
            Random seed passed to ``seed_everything`` before training, by
            default 42.
        callbacks : Optional[List[L.Callback]]
            Lightning callbacks to attach to the trainer, defaults to None.
        trainer_fit_kwargs : Optional[dict]
            Additional keyword arguments forwarded to ``L.Trainer.fit``,
            defaults to None (nothing extra is forwarded).
        args : Any
            Accepted for interface compatibility; currently ignored.
        kwargs : Any
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        torch.nn.Module
            The same ``model`` object, loaded with the trained weights.

        """
        seed_everything(seed, workers=True)

        module = BasicLightningModule(
            model=model,
            optimizer=self.optimizer,
            lr=self.lr,
            criterion=self.criterion,
            optimizer_kwargs=self.optimizer_kwargs,
            scheduler=self.scheduler,
            scheduler_kwargs=self.scheduler_kwargs,
        )

        trainer = L.Trainer(
            max_epochs=self.max_epochs,
            # A list selects a specific device index; an int is a count.
            devices=[devices] if accelerator == "gpu" else 1,
            accelerator=accelerator,
            logger=self.logger,
            enable_progress_bar=self.enable_progress_bar,
            gradient_clip_val=self.gradient_clip_val,
            callbacks=callbacks,
        )
        trainer.fit(
            module,
            train_dataloaders,
            val_dataloaders,
            **(trainer_fit_kwargs or {}),
        )

        # Copy the trained weights back so the caller's object is updated
        # regardless of how the Lightning module holds the model.
        model.load_state_dict(module.model.state_dict())
        return model
class _EpochSnapshotCallback(L.Callback):
    """Snapshot ``pl_module.model`` to disk at a fixed set of epochs.

    Used by ``Benchmark.train`` to capture intermediate checkpoints in a
    single training run when ``num_checkpoints > 1``.
    """

    def __init__(self, snapshot_epochs: List[int], snapshot_dirs: List[str]):
        """Pair every snapshot epoch with its output directory.

        Parameters
        ----------
        snapshot_epochs : List[int]
            Epoch indices, compared against ``trainer.current_epoch``, at
            which a snapshot is written.
        snapshot_dirs : List[str]
            Output directories; ``snapshot_dirs[i]`` receives the snapshot
            taken at ``snapshot_epochs[i]``.

        Raises
        ------
        ValueError
            If the two lists differ in length.

        """
        super().__init__()
        epochs = list(snapshot_epochs)
        dirs = list(snapshot_dirs)
        # ``zip`` would silently drop the unmatched tail, so a mismatch would
        # mean requested snapshots are never written.
        if len(epochs) != len(dirs):
            raise ValueError(
                "snapshot_epochs and snapshot_dirs must have the same "
                f"length, got {len(epochs)} and {len(dirs)}."
            )
        self._epoch_to_dir = dict(zip(epochs, dirs))

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Write a snapshot if the current epoch is a snapshot epoch.

        Parameters
        ----------
        trainer
            The running trainer; only ``trainer.current_epoch`` is read.
        pl_module
            The module being trained; ``pl_module.model`` is saved.

        """
        target = self._epoch_to_dir.get(trainer.current_epoch)
        if target is None:
            return
        os.makedirs(target, exist_ok=True)
        # Assumes the wrapped model exposes a Hugging Face-style
        # ``save_pretrained`` -- TODO confirm for all models used here.
        pl_module.model.save_pretrained(target, safe_serialization=True)