quanda.utils.training package
Training utilities.
- class quanda.utils.training.BaseTrainer
Bases: object
Base class for a trainer.
- abstract fit(model: Module, train_dataloaders: DataLoader, val_dataloaders: DataLoader | None = None, accelerator: str = 'cpu', devices: int = 0, seed: int = 42, callbacks: List[Callback] | None = None, trainer_fit_kwargs: dict | None = None, *args, **kwargs) → Module
Train a model using the provided dataloaders.
- Parameters:
model (torch.nn.Module) – Model to train.
train_dataloaders (torch.utils.data.dataloader.DataLoader) – Dataloader for the training data.
val_dataloaders (Optional[torch.utils.data.dataloader.DataLoader]) – Dataloader for the validation data, defaults to None.
accelerator (str) – The accelerator to use for training, by default “cpu”.
devices (int) – The number of devices to use for training, by default 0 (i.e. all available).
seed (int) – Random seed, by default 42.
callbacks (Optional[List[L.Callback]]) – Lightning callbacks to attach to the trainer, defaults to None.
trainer_fit_kwargs (Optional[dict]) – Additional keyword arguments to pass to the trainer’s fit method, defaults to None.
args (Any) – Additional arguments to pass to the fit method.
kwargs (Any) – Additional keyword arguments to pass to the fit method.
- Returns:
The trained model.
- Return type:
torch.nn.Module
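The abstract-method contract above can be illustrated with a small, dependency-free sketch. It assumes nothing about quanda's internals: `ToyBaseTrainer` and `RecordingTrainer` are hypothetical names, the parameters are loosely typed as `Any` in place of `torch.nn.Module` and `DataLoader`, and the subclass only records its configuration instead of training. It shows that an abstract `fit` prevents direct instantiation and that the `None` defaults for `callbacks` and `trainer_fit_kwargs` are normalized inside the method.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class ToyBaseTrainer(ABC):
    """Hypothetical stand-in for BaseTrainer: subclasses must implement fit()."""

    @abstractmethod
    def fit(
        self,
        model: Any,
        train_dataloaders: Any,
        val_dataloaders: Optional[Any] = None,
        accelerator: str = "cpu",
        devices: int = 0,
        seed: int = 42,
        callbacks: Optional[List[Any]] = None,
        trainer_fit_kwargs: Optional[dict] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Train `model` and return the trained model."""


class RecordingTrainer(ToyBaseTrainer):
    """Hypothetical subclass that records its configuration instead of training."""

    def fit(self, model, train_dataloaders, val_dataloaders=None, accelerator="cpu",
            devices=0, seed=42, callbacks=None, trainer_fit_kwargs=None, *args, **kwargs):
        # None defaults avoid shared mutable defaults; normalize them here.
        self.callbacks = list(callbacks or [])
        self.trainer_fit_kwargs = dict(trainer_fit_kwargs or {})
        self.seed = seed
        return model  # a real trainer would return the trained module


try:
    ToyBaseTrainer()  # abstract: raises TypeError
except TypeError as err:
    print("cannot instantiate:", err)

trainer = RecordingTrainer()
model = trainer.fit("my-model", train_dataloaders=[])
print(model, trainer.callbacks, trainer.trainer_fit_kwargs)  # my-model [] {}
```

A concrete subclass such as `Trainer` replaces the recording logic with an actual training run.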
- class quanda.utils.training.BasicLightningModule(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)
Bases: LightningModule
Wrapper for a basic PyTorch Lightning module.
- __init__(model: Module, optimizer: Callable, lr: float, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, *args, **kwargs)
Construct the BasicLightningModule class.
- Parameters:
model (torch.nn.Module) – Model to train.
optimizer (Callable) – Optimizer to use for training.
lr (float) – Learning rate for the optimizer.
criterion (torch.nn.modules.loss._Loss) – Loss function to use for training.
scheduler (Optional[Callable], optional) – Learning rate scheduler to use for training, defaults to None.
optimizer_kwargs (Optional[dict], optional) – Keyword arguments for the optimizer, defaults to None.
scheduler_kwargs (Optional[dict], optional) – Keyword arguments for the scheduler, defaults to None.
args (Any) – Any additional arguments to pass to the superclass.
kwargs (Any) – Any additional keyword arguments to pass to the superclass.
- configure_optimizers()
Create the optimizer and scheduler for training.
- Raises:
ValueError – If the optimizer is not an instance of the expected class.
ValueError – If the scheduler is not an instance of the expected class.
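A plausible shape for `configure_optimizers` is sketched below under stated assumptions: the stored `optimizer` factory is called as `optimizer(params, lr=lr, **optimizer_kwargs)` and the `scheduler` factory as `scheduler(optimizer, **scheduler_kwargs)`, following the `torch.optim` constructor convention. `FakeOptimizer`, `FakeScheduler`, and `configure_optimizers_sketch` are hypothetical stand-ins so the example runs without PyTorch. The dict with `"optimizer"` and `"lr_scheduler"` keys is one of the return formats Lightning accepts; whether quanda uses exactly this format is an assumption.

```python
from typing import Any, Callable, Iterable, Optional


class FakeOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""

    def __init__(self, params: Iterable[float], lr: float, **kwargs: Any) -> None:
        self.params = list(params)
        self.lr = lr
        self.extra = kwargs


class FakeScheduler:
    """Hypothetical stand-in for a torch.optim learning rate scheduler."""

    def __init__(self, optimizer: FakeOptimizer, **kwargs: Any) -> None:
        self.optimizer = optimizer
        self.extra = kwargs


def configure_optimizers_sketch(
    params: Iterable[float],
    optimizer: Callable[..., Any],
    lr: float,
    scheduler: Optional[Callable[..., Any]] = None,
    optimizer_kwargs: Optional[dict] = None,
    scheduler_kwargs: Optional[dict] = None,
) -> Any:
    # Build the optimizer from its factory; lr is passed explicitly.
    opt = optimizer(params, lr=lr, **(optimizer_kwargs or {}))
    if not isinstance(opt, FakeOptimizer):
        raise ValueError("optimizer factory did not return an optimizer")
    if scheduler is None:
        return opt  # Lightning also accepts a bare optimizer
    sched = scheduler(opt, **(scheduler_kwargs or {}))
    if not isinstance(sched, FakeScheduler):
        raise ValueError("scheduler factory did not return a scheduler")
    return {"optimizer": opt, "lr_scheduler": sched}


config = configure_optimizers_sketch(
    [0.5, -1.0], FakeOptimizer, lr=0.01,
    scheduler=FakeScheduler, optimizer_kwargs={"momentum": 0.9},
    scheduler_kwargs={"step_size": 10},
)
print(config["optimizer"].lr, config["optimizer"].extra, config["lr_scheduler"].extra)
# 0.01 {'momentum': 0.9} {'step_size': 10}
```

Passing factories rather than ready-made instances lets the module build the optimizer only once the model's parameters are available.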
- forward(inputs)
Forward pass of the model.
- Parameters:
inputs (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Input to the model. Can be a tensor or a dict of tensors (e.g. for HuggingFace transformer models).
- Returns:
Output of the model.
- Return type:
Any
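The tensor-or-dict dispatch described for `forward` could look like the following sketch. It assumes, without confirmation from the source, that a dict is unpacked into keyword arguments (the calling convention of HuggingFace models, e.g. `model(input_ids=..., attention_mask=...)`) and that anything else is passed positionally. `forward_sketch` and `toy_model` are hypothetical and use plain lists instead of tensors.

```python
from typing import Any, Callable, Dict, Union


def forward_sketch(model: Callable[..., Any], inputs: Union[Any, Dict[str, Any]]) -> Any:
    # Assumption: a dict of tensors is unpacked into keyword arguments.
    if isinstance(inputs, dict):
        return model(**inputs)
    # Otherwise the single input is passed positionally.
    return model(inputs)


def toy_model(input_ids=None, attention_mask=None):
    """Hypothetical model: sums the ids where the mask is 1."""
    if attention_mask is None:
        attention_mask = [1] * len(input_ids)
    return sum(i * m for i, m in zip(input_ids, attention_mask))


print(forward_sketch(toy_model, [1, 2, 3]))  # 6
print(forward_sketch(toy_model, {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 0]}))  # 3
```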
- class quanda.utils.training.Trainer(optimizer: Callable, lr: float, max_epochs: int, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, logger: Logger | None = None, seed: int = 27, num_workers: int = 0, enable_progress_bar: bool = True, gradient_clip_val: float | None = None)
Bases: BaseTrainer
Simple class for training PyTorch models using Lightning.
- __init__(optimizer: Callable, lr: float, max_epochs: int, criterion: _Loss, scheduler: Callable | None = None, optimizer_kwargs: dict | None = None, scheduler_kwargs: dict | None = None, logger: Logger | None = None, seed: int = 27, num_workers: int = 0, enable_progress_bar: bool = True, gradient_clip_val: float | None = None)
Construct the Trainer class.
- Parameters:
optimizer (Callable) – Optimizer to use for training.
lr (float) – Learning rate for the optimizer.
max_epochs (int) – Maximum number of epochs to train for.
criterion (torch.nn.modules.loss._Loss) – Loss to use during training.
scheduler (Optional[Callable], optional) – Scheduler to use during training, defaults to None.
optimizer_kwargs (Optional[dict], optional) – Keyword arguments for the optimizer, defaults to None.
scheduler_kwargs (Optional[dict], optional) – Keyword arguments for the scheduler, defaults to None.
logger (Optional[Logger], optional) – Logger to use during training, defaults to None.
seed (int, optional) – Random seed, by default 27.
num_workers (int, optional) – Number of workers to use for data loading, by default 0.
enable_progress_bar (bool, optional) – Whether to enable the progress bar during training, by default True.
gradient_clip_val (Optional[float], optional) – Value to use for gradient clipping, by default None (i.e. no gradient clipping).
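As a rough illustration of what `gradient_clip_val` controls, the sketch below rescales a gradient vector so that its global L2 norm does not exceed the clip value, and leaves it untouched when the value is `None`. It assumes the value is forwarded to Lightning, whose default clipping algorithm is by norm. `clip_grad_norm_sketch` is a hypothetical, list-based helper, not the PyTorch or Lightning routine.

```python
import math
from typing import List, Optional


def clip_grad_norm_sketch(grads: List[float], clip_val: Optional[float]) -> List[float]:
    """Rescale gradients so their global L2 norm is at most clip_val.

    clip_val=None mirrors gradient_clip_val=None: gradients pass through unchanged.
    """
    if clip_val is None:
        return list(grads)
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_val:
        return list(grads)  # already within the limit
    scale = clip_val / total_norm  # shrink every component by the same factor
    return [g * scale for g in grads]


print(clip_grad_norm_sketch([3.0, 4.0], None))  # [3.0, 4.0]
print(clip_grad_norm_sketch([3.0, 4.0], 1.0))   # same direction, norm 1.0
```

Scaling all components by one factor preserves the gradient's direction, which is why norm clipping is usually preferred over clipping each value independently.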
- fit(model: Module, train_dataloaders: DataLoader, val_dataloaders: DataLoader | None = None, accelerator: str = 'cpu', devices: int = 0, seed: int = 42, callbacks: List[Callback] | None = None, *args, **kwargs)
Train a model using the provided dataloaders.
- Parameters:
model (torch.nn.Module) – Model to train.
train_dataloaders (torch.utils.data.dataloader.DataLoader) – Dataloader for the training data.
val_dataloaders (Optional[torch.utils.data.dataloader.DataLoader]) – Dataloader for the validation data, defaults to None.
accelerator (str, optional) – The accelerator to use for training, by default “cpu”.
devices (int, optional) – The number of devices to use for training, by default 0 (i.e. all available).
seed (int) – Random seed, by default 42.
callbacks (Optional[List[L.Callback]]) – Lightning callbacks to attach to the trainer, defaults to None.
args (Any) – Additional arguments to pass to the fit method.
kwargs (Any) – Additional keyword arguments to pass to the fit method.
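Conceptually, `fit` makes repeated passes over the training data for the `max_epochs` configured on the `Trainer`, updating parameters with the optimizer at learning rate `lr`, while `seed` makes runs repeatable. The stdlib-only sketch below shows those pieces on a one-parameter model y = w * x trained by mini-batch SGD on squared error. `fit_sketch` and its arguments are hypothetical; this is not how quanda implements `fit`, which uses Lightning.

```python
import random
from typing import List, Tuple


def fit_sketch(
    data: List[Tuple[float, float]],
    lr: float = 0.5,
    max_epochs: int = 50,
    batch_size: int = 4,
    seed: int = 42,
) -> float:
    """Fit y = w * x by mini-batch SGD on the mean squared error; returns w."""
    rng = random.Random(seed)  # seeded RNG: same seed, same batch order
    w = 0.0
    for _ in range(max_epochs):
        order = list(range(len(data)))
        rng.shuffle(order)  # reshuffle each epoch, like a shuffling dataloader
        for start in range(0, len(order), batch_size):
            batch = [data[i] for i in order[start:start + batch_size]]
            # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
            grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
            w -= lr * grad
    return w


data = [(x / 10, 2 * x / 10) for x in range(1, 11)]  # noise-free, true w = 2
w1 = fit_sketch(data, seed=0)
w2 = fit_sketch(data, seed=0)
print(round(w1, 3), w1 == w2)
```

Calling it twice with the same seed yields identical weights, because the shuffled batch order is the only source of randomness in the loop.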