quanda.utils.training.base_pl_module module
Base PyTorch Lightning module for training models.
- class quanda.utils.training.base_pl_module.BasicLightningModule(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)
Bases: LightningModule
Wrapper for a basic PyTorch Lightning module.
- __init__(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)
Construct a BasicLightningModule instance.
- Parameters:
model (torch.nn.Module) – Model to train.
optimizer (Callable) – Optimizer to use for training.
lr (float) – Learning rate for the optimizer.
criterion (torch.nn.modules.loss._Loss) – Loss function to use for training.
scheduler (Optional[Callable], optional) – Learning rate scheduler to use for training, defaults to None.
optimizer_kwargs (Optional[dict], optional) – Keyword arguments for the optimizer, defaults to None.
scheduler_kwargs (Optional[dict], optional) – Keyword arguments for the scheduler, defaults to None.
args (Any) – Any additional arguments to pass to the superclass.
kwargs (Any) – Any additional keyword arguments to pass to the superclass.
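The constructor receives the optimizer as a callable together with `lr` and `optimizer_kwargs`, rather than as a ready-made instance, so the optimizer can only be built once the model's parameters are available. The following is a minimal sketch of that deferred construction, assuming the callable is invoked as `optimizer(params, lr=lr, **optimizer_kwargs)`; this page does not show the actual call. `build_optimizer` and `FakeSGD` are hypothetical names, and `FakeSGD` stands in for a class such as `torch.optim.SGD` so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Iterable, Optional


class FakeSGD:
    """Hypothetical stand-in for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params: Iterable[Any], lr: float, momentum: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


def build_optimizer(
    optimizer: Callable[..., Any],
    params: Iterable[Any],
    lr: float,
    optimizer_kwargs: Optional[dict] = None,
) -> Any:
    # A None default is replaced by an empty dict so it can be unpacked safely.
    kwargs = optimizer_kwargs or {}
    # The optimizer class is only called here, once the parameters are known.
    return optimizer(params, lr=lr, **kwargs)


opt = build_optimizer(FakeSGD, params=["w", "b"], lr=0.01, optimizer_kwargs={"momentum": 0.9})
print(opt.lr, opt.momentum, opt.params)  # 0.01 0.9 ['w', 'b']
```

With PyTorch, the analogous call would pass `torch.optim.SGD` and `model.parameters()` in place of the stand-ins. Passing the class instead of an instance also keeps the module configuration serialisable as plain arguments.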
- configure_optimizers()
Create the optimizer and scheduler for training.
- Raises:
ValueError – If the optimizer is not an instance of the expected class.
ValueError – If the scheduler is not an instance of the expected class.
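The validation and return step can be sketched as follows. This is a minimal illustration, not the module's code: it assumes the "expected classes" are an optimizer base class and a learning-rate scheduler base class (for example `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LRScheduler`), represented here by hypothetical stand-ins so the sketch runs without PyTorch. The two return shapes, a bare optimizer or a dict with `"optimizer"` and `"lr_scheduler"` keys, are formats PyTorch Lightning accepts from `configure_optimizers`; which one this module actually returns is an assumption. `pack_optimizers` is a hypothetical name.

```python
from typing import Any, Optional, Union


class Optimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""


class LRScheduler:
    """Hypothetical stand-in for torch.optim.lr_scheduler.LRScheduler."""


def pack_optimizers(optimizer: Any, scheduler: Optional[Any] = None) -> Union[Optimizer, dict]:
    # Reject anything that is not an optimizer instance.
    if not isinstance(optimizer, Optimizer):
        raise ValueError("optimizer must be an instance of Optimizer")
    if scheduler is None:
        # Lightning accepts a bare optimizer when no scheduler is used.
        return optimizer
    # Reject anything that is not a scheduler instance.
    if not isinstance(scheduler, LRScheduler):
        raise ValueError("scheduler must be an instance of LRScheduler")
    # Lightning also accepts a dict with "optimizer" and "lr_scheduler" keys.
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


opt, sched = Optimizer(), LRScheduler()
print(pack_optimizers(opt) is opt)  # True
print(sorted(pack_optimizers(opt, sched)))  # ['lr_scheduler', 'optimizer']
```

Checking the types eagerly turns a misconfigured callable into an immediate, descriptive `ValueError` instead of a failure deep inside the training loop.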
- forward(inputs)
Forward pass of the model.
- Parameters:
inputs (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Input to the model. Can be a tensor or a dict of tensors (e.g. for HuggingFace transformer models).
- Returns:
Output of the model.
- Return type:
Any
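Because `inputs` may be either a single tensor or a dict of tensors, the forward pass has to dispatch on the input type. A minimal sketch, assuming a dict is unpacked as keyword arguments (the usual convention for Hugging Face models, which accept names such as `input_ids` and `attention_mask`) and anything else is passed positionally; the module's actual dispatch is not shown on this page. `forward` here is a free function and `toy_model` is a hypothetical stand-in model, so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Mapping


def forward(model: Callable[..., Any], inputs: Any) -> Any:
    # Dict inputs (e.g. input_ids / attention_mask for a transformer)
    # are unpacked as keyword arguments.
    if isinstance(inputs, Mapping):
        return model(**inputs)
    # Any other input (e.g. a single tensor) is passed positionally.
    return model(inputs)


def toy_model(x=None, input_ids=None, attention_mask=None):
    """Hypothetical stand-in model that reports how it was called."""
    if input_ids is not None:
        return ("kwargs", input_ids, attention_mask)
    return ("positional", x)


print(forward(toy_model, [1, 2, 3]))  # ('positional', [1, 2, 3])
print(forward(toy_model, {"input_ids": [5, 6], "attention_mask": [1, 1]}))
# ('kwargs', [5, 6], [1, 1])
```

Checking against `Mapping` rather than `dict` also covers dict-like batch containers, such as the `BatchEncoding` objects returned by Hugging Face tokenizers.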