quanda.benchmarks.resources.modules module
Lightning modules for the benchmarks.
- class quanda.benchmarks.resources.modules.BertClassifier(*args, **kwargs)
Bases: Module, PyTorchModelHubMixin
BERT-based sequence classifier without final nonlinearity.
Uses a pretrained BERT encoder with a linear classification head. The final tanh is removed to prevent output saturation, following the setup in https://arxiv.org/pdf/2303.14186.
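The saturation concern mentioned above can be illustrated with a small standalone sketch. It is not quanda code, and the input values are hypothetical, chosen only to show the effect: as its input grows, tanh maps it ever closer to 1 and its derivative shrinks towards 0, so large activations become nearly indistinguishable.

```python
import math

# Hypothetical pre-activation values, chosen only for illustration.
pre_activations = [0.5, 2.0, 5.0, 10.0]

# tanh squashes every input into (-1, 1); large inputs end up close to 1.
squashed = [math.tanh(x) for x in pre_activations]

# The derivative of tanh is 1 - tanh(x)^2, which approaches 0 as |x| grows.
gradients = [1.0 - t * t for t in squashed]

for x, t, g in zip(pre_activations, squashed, gradients):
    print(f"x={x:5.1f}  tanh(x)={t:.6f}  d/dx tanh(x)={g:.6f}")
```

A linear head applied directly to un-squashed features does not compress them in this way, which is presumably the motivation for dropping the tanh.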
- __init__(num_labels: int = 2)
Initialize the BertClassifier.
- Parameters:
pretrained_model_name (str) – HuggingFace model name or path for the BERT encoder.
num_labels (int) – Number of output classes.
- forward(input_ids, attention_mask=None, token_type_ids=None)
Forward pass.
- Parameters:
input_ids (torch.Tensor) – Token IDs of shape (batch, seq_len).
attention_mask (torch.Tensor, optional) – Attention mask of shape (batch, seq_len).
token_type_ids (torch.Tensor, optional) – Token type IDs of shape (batch, seq_len).
- class quanda.benchmarks.resources.modules.HFGPT2(config=None, **kwargs)
Bases: GPT2LMHeadModel
HuggingFace GPT-2 model.
- config_class
alias of GPT2Config
- class quanda.benchmarks.resources.modules.LeNet(*args, **kwargs)
Bases: Module, PyTorchModelHubMixin
A torch implementation of the LeNet architecture.
Adapted from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch.
- class quanda.benchmarks.resources.modules.ResNet50(*args, **kwargs)
Bases: ResNet, PyTorchModelHubMixin
ResNet-50 with an explicit Flatten module before the FC head.
Subclasses torchvision’s ResNet and overrides _forward_impl to swap the inline torch.flatten call for an nn.Flatten submodule, so activation hooks can target the flattened features (e.g. for RepresenterPoints).
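The reason for using a Flatten submodule can be sketched with a toy model. The TinyNet class and the save_features hook below are hypothetical stand-ins, not the quanda ResNet50; the sketch only assumes that PyTorch is installed. Because flattening is done by a named nn.Flatten submodule rather than an inline torch.flatten call, a forward hook can be registered on it to capture the flattened features.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model, not the quanda ResNet50."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # A submodule instead of an inline torch.flatten(x, 1) call,
        # so that hooks have something to attach to.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.avgpool(self.conv(x))
        x = self.flatten(x)
        return self.fc(x)


model = TinyNet().eval()
captured = {}


def save_features(module, inputs, output):
    # Store the flattened features produced by the Flatten submodule.
    captured["features"] = output.detach()


handle = model.flatten.register_forward_hook(save_features)
with torch.no_grad():
    logits = model(torch.randn(5, 3, 8, 8))
handle.remove()  # Detach the hook once the activations are captured.

print(captured["features"].shape)  # torch.Size([5, 4])
print(logits.shape)  # torch.Size([5, 2])
```

With an inline torch.flatten call there is no module at that point in the forward pass, so the same features could only be reached indirectly, for example through the input of the final linear layer.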