quanda.utils.training.base_pl_module module

Base PyTorch Lightning module for training models.

class quanda.utils.training.base_pl_module.BasicLightningModule(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)

Bases: LightningModule

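A basic PyTorch Lightning module that wraps a torch.nn.Module together with an optimizer, a loss function, and an optional learning rate scheduler.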

__init__(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)

Construct a BasicLightningModule instance.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • optimizer (Callable) – Callable that creates the optimizer used for training, e.g. an optimizer class such as torch.optim.SGD.

  • lr (float) – Learning rate for the optimizer.

  • criterion (torch.nn.modules.loss._Loss) – Loss function to use for training.

  • scheduler (Optional[Callable], optional) – Callable that creates the learning rate scheduler used for training; defaults to None (no scheduler).

  • optimizer_kwargs (Optional[dict], optional) – Keyword arguments for the optimizer, defaults to None.

  • scheduler_kwargs (Optional[dict], optional) – Keyword arguments for the scheduler, defaults to None.

  • args (Any) – Additional positional arguments passed to the superclass (LightningModule).

  • kwargs (Any) – Additional keyword arguments passed to the superclass (LightningModule).
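The Callable type of optimizer and scheduler, together with the separate lr and *_kwargs parameters, suggests that both objects are created later from factories rather than passed in ready-made. The sketch below assumes, without reference to the quanda source, that the optimizer is built roughly as optimizer(model.parameters(), lr=lr, **optimizer_kwargs). build_optimizer is a hypothetical helper written in plain PyTorch; it is not part of quanda.

```python
from typing import Callable, Optional

import torch
from torch import nn


def build_optimizer(
    model: nn.Module,
    optimizer: Callable,
    lr: float,
    optimizer_kwargs: Optional[dict] = None,
) -> torch.optim.Optimizer:
    # Hypothetical helper: combines the constructor arguments the way they are
    # assumed to be combined; None is treated as "no extra keyword arguments".
    return optimizer(model.parameters(), lr=lr, **(optimizer_kwargs or {}))


model = nn.Linear(4, 2)
# Pass the optimizer class itself (a factory), not an already-built instance.
opt = build_optimizer(model, torch.optim.SGD, lr=0.1, optimizer_kwargs={"momentum": 0.9})
print(type(opt).__name__, opt.param_groups[0]["lr"], opt.param_groups[0]["momentum"])  # SGD 0.1 0.9
```

Passing the class torch.optim.SGD rather than an instance is what lets lr and optimizer_kwargs be applied at construction time; functools.partial(torch.optim.SGD, momentum=0.9) is another way to bind extra arguments in advance.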

configure_optimizers()

Create the optimizer and scheduler for training.

Raises:
  • ValueError – If the optimizer is not an instance of the expected class.

  • ValueError – If the scheduler is not an instance of the expected class.
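Assuming optimizer and scheduler are factories that are only called here, configure_optimizers can be sketched in plain PyTorch as below. configure_optimizers_sketch is a hypothetical name, and the "expected classes" are assumed to be torch.optim.Optimizer and PyTorch's learning rate scheduler base class; the documentation does not name them. The dictionary return value uses the "optimizer" and "lr_scheduler" keys that Lightning accepts from configure_optimizers, but whether quanda returns exactly this shape is an assumption.

```python
from typing import Callable, Optional

import torch
from torch import nn

# LRScheduler is the public base class in PyTorch >= 2.0; older releases
# only expose the private _LRScheduler name.
_lr = torch.optim.lr_scheduler
SchedulerBase = _lr.LRScheduler if hasattr(_lr, "LRScheduler") else _lr._LRScheduler


def configure_optimizers_sketch(
    model: nn.Module,
    optimizer: Callable,
    lr: float,
    scheduler: Optional[Callable] = None,
    optimizer_kwargs: Optional[dict] = None,
    scheduler_kwargs: Optional[dict] = None,
):
    # Build the optimizer from its factory and validate the result.
    opt = optimizer(model.parameters(), lr=lr, **(optimizer_kwargs or {}))
    if not isinstance(opt, torch.optim.Optimizer):
        raise ValueError("optimizer must create a torch.optim.Optimizer")
    if scheduler is None:
        return opt
    # Build the scheduler around the optimizer and validate it as well.
    sched = scheduler(opt, **(scheduler_kwargs or {}))
    if not isinstance(sched, SchedulerBase):
        raise ValueError("scheduler must create a learning rate scheduler")
    # One of the return formats Lightning accepts from configure_optimizers.
    return {"optimizer": opt, "lr_scheduler": sched}


model = nn.Linear(4, 2)
result = configure_optimizers_sketch(
    model,
    torch.optim.SGD,
    lr=0.1,
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_kwargs={"step_size": 10, "gamma": 0.5},
)
print(sorted(result))  # ['lr_scheduler', 'optimizer']
```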

forward(inputs)

Forward pass of the model.

Parameters:
  • inputs (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Input to the model: either a tensor or a dict of tensors (e.g. for Hugging Face transformer models).

Returns:

Output of the model.

Return type:

Any
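The tensor-or-dict behaviour of forward can be illustrated with a small dispatch function: a dict is unpacked into keyword arguments, which is how Hugging Face models typically receive inputs such as input_ids and attention_mask, while a tensor is passed positionally. forward_sketch and KeywordModel (with its features argument) are hypothetical names, and the dispatch rule is an assumption based on the parameter description, not the verified quanda implementation.

```python
from typing import Dict, Union

import torch
from torch import nn


class KeywordModel(nn.Module):
    # Hypothetical stand-in for a model that can be called with keyword inputs.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.linear(features)


def forward_sketch(model: nn.Module, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]):
    # A dict is unpacked into keyword arguments; a tensor is passed positionally.
    if isinstance(inputs, dict):
        return model(**inputs)
    return model(inputs)


model = KeywordModel()
x = torch.randn(5, 3)
out_tensor = forward_sketch(model, x)
out_dict = forward_sketch(model, {"features": x})
print(torch.equal(out_tensor, out_dict))  # True
```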

on_load_checkpoint(checkpoint)

Load the model state from a checkpoint.

on_save_checkpoint(checkpoint)

Save the model state to a checkpoint.
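In Lightning, on_save_checkpoint receives the checkpoint dictionary before it is written to disk, and on_load_checkpoint receives it after loading, so both hooks can add or read entries. A minimal sketch of storing and restoring the wrapped model's weights follows; the key "model_state_dict" and both helper functions are hypothetical, since the documentation does not say which keys quanda uses.

```python
import torch
from torch import nn

# Hypothetical key; the key actually used by quanda is not documented here.
MODEL_KEY = "model_state_dict"


def on_save_checkpoint_sketch(model: nn.Module, checkpoint: dict) -> None:
    # Add the wrapped model's weights to the checkpoint dictionary.
    checkpoint[MODEL_KEY] = model.state_dict()


def on_load_checkpoint_sketch(model: nn.Module, checkpoint: dict) -> None:
    # Copy the stored weights back into the wrapped model.
    model.load_state_dict(checkpoint[MODEL_KEY])


checkpoint = {}
source = nn.Linear(4, 2)
on_save_checkpoint_sketch(source, checkpoint)

target = nn.Linear(4, 2)  # freshly initialised, so its weights differ
on_load_checkpoint_sketch(target, checkpoint)
print(torch.equal(source.weight, target.weight))  # True
```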

training_step(batch, batch_idx)

Perform a single training step.

Parameters:
  • batch – Single batch of data.

  • batch_idx – Index of the batch.

Returns:

Loss for the batch.

Return type:

torch.Tensor
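A training step of this kind usually applies criterion to the wrapped model's output and returns the loss, leaving the backward pass and optimizer step to Lightning's automatic optimization. The sketch below assumes each batch is an (inputs, targets) pair; training_step_sketch is a hypothetical plain-PyTorch stand-in, and the manual backward and step lines only imitate what the Trainer would do.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def training_step_sketch(batch, batch_idx):
    # Assumes the batch is an (inputs, targets) pair; batch_idx is unused here.
    inputs, targets = batch
    outputs = model(inputs)
    return criterion(outputs, targets)


batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
loss = training_step_sketch(batch, batch_idx=0)

# Inside Lightning the Trainer would run these steps with the returned loss.
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.dim(), loss.requires_grad)  # 0 True
```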

validation_step(batch, batch_idx)

Perform a single validation step.

Parameters:
  • batch – Single batch of data.

  • batch_idx – Index of the batch.

Returns:

Loss for the batch.

Return type:

torch.Tensor
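Validation computes the same kind of loss, but Lightning's validation loop puts the module in eval mode and disables gradient tracking. The plain-PyTorch sketch below makes both steps explicit; validation_step_sketch is hypothetical, and each batch is assumed to be an (inputs, targets) pair.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 3))
criterion = nn.CrossEntropyLoss()


def validation_step_sketch(batch, batch_idx):
    # Assumes the batch is an (inputs, targets) pair; batch_idx is unused here.
    inputs, targets = batch
    return criterion(model(inputs), targets)


batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

# Eval mode turns dropout off; no_grad stops autograd from recording the graph.
model.eval()
with torch.no_grad():
    val_loss = validation_step_sketch(batch, batch_idx=0)
    val_loss_again = validation_step_sketch(batch, batch_idx=0)
print(val_loss.requires_grad)  # False
```

Because dropout is inactive in eval mode, repeated validation passes over the same batch give the same loss.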