quanda.utils.training.trainer module

Module for training PyTorch models using Lightning.

class quanda.utils.training.trainer.BaseTrainer[source]

Bases: object

Base class for a trainer.

abstract fit(model: Module, train_dataloaders: DataLoader, val_dataloaders: DataLoader | None = None, accelerator: str = 'cpu', devices: int = 0, seed: int = 42, callbacks: List[Callback] | None = None, trainer_fit_kwargs: dict | None = None, *args, **kwargs) → Module

Train a model using the provided dataloaders.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • train_dataloaders (torch.utils.data.dataloader.DataLoader) – Dataloader for the training data.

  • val_dataloaders (Optional[torch.utils.data.dataloader.DataLoader]) – Dataloader for the validation data, defaults to None.

  • accelerator (str) – The accelerator to use for training, by default “cpu”.

  • devices (int) – The number of devices to use for training, by default 0 (i.e. all available).

  • seed (int) – Random seed, defaults to 42.

  • callbacks (Optional[List[L.Callback]]) – Lightning callbacks to attach to the trainer, defaults to None.

  • trainer_fit_kwargs (Optional[dict]) – Additional keyword arguments to pass to the trainer’s fit method, defaults to None.

  • args (Any) – Additional arguments to pass to the fit method.

  • kwargs (Any) – Additional keyword arguments to pass to the fit method.

Returns:

The trained model.

Return type:

torch.nn.Module

get_model() → Module

Get the model that was trained.

Returns:

The trained model.

Return type:

torch.nn.Module

class quanda.utils.training.trainer.Trainer(optimizer: Callable, lr: float, max_epochs: int, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, logger: Logger | None = None, seed: int = 27, num_workers: int = 0, enable_progress_bar: bool = True, gradient_clip_val: float | None = None)

Bases: BaseTrainer

Simple class for training PyTorch models using Lightning.

__init__(optimizer: Callable, lr: float, max_epochs: int, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, logger: Logger | None = None, seed: int = 27, num_workers: int = 0, enable_progress_bar: bool = True, gradient_clip_val: float | None = None)

Construct a Trainer instance.

Parameters:
  • optimizer (Callable) – Optimizer to use for training.

  • lr (float) – Learning rate for the optimizer.

  • max_epochs (int) – Maximum number of epochs to train for.

  • criterion (torch.nn.modules.loss._Loss) – Loss to use during training.

  • scheduler (Optional[Callable], optional) – Scheduler to use during training, defaults to None.

  • optimizer_kwargs (Optional[dict], optional) – Keyword arguments for the optimizer, defaults to None.

  • scheduler_kwargs (Optional[dict], optional) – Keyword arguments for the scheduler, defaults to None.

  • logger (Optional[Logger], optional) – Lightning logger to use during training, defaults to None.

  • seed (int, optional) – Random seed, by default 27.

  • num_workers (int, optional) – Number of workers to use for data loading, by default 0.

  • enable_progress_bar (bool, optional) – Whether to enable the progress bar during training, by default True.

  • gradient_clip_val (Optional[float], optional) – Value to use for gradient clipping, by default None (i.e. no gradient clipping).

fit(model: Module, train_dataloaders: DataLoader, val_dataloaders: DataLoader | None = None, accelerator: str = 'cpu', devices: int = 0, seed: int = 42, callbacks: List[Callback] | None = None, *args, **kwargs)

Train a model using the provided dataloaders.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • train_dataloaders (torch.utils.data.dataloader.DataLoader) – Dataloader for the training data.

  • val_dataloaders (Optional[torch.utils.data.dataloader.DataLoader]) – Dataloader for the validation data, defaults to None.

  • accelerator (str, optional) – The accelerator to use for training, by default “cpu”.

  • devices (int, optional) – The number of devices to use for training, by default 0 (i.e. all available).

  • seed (int) – Random seed, defaults to 42.

  • callbacks (Optional[List[L.Callback]]) – Lightning callbacks to attach to the trainer, defaults to None.

  • args (Any) – Additional arguments to pass to the fit method.

  • kwargs (Any) – Additional keyword arguments to pass to the fit method.