quanda.benchmarks.resources.modules module

Lightning modules for the benchmarks.

class quanda.benchmarks.resources.modules.BertClassifier(*args, **kwargs)

Bases: Module, PyTorchModelHubMixin

BERT-based sequence classifier without final nonlinearity.

Uses a pretrained BERT encoder with a linear classification head. The final tanh is removed to prevent output saturation, following the setup in https://arxiv.org/pdf/2303.14186.
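The saturation that motivates dropping the final tanh can be seen with plain arithmetic. The sketch below is not taken from the quanda source; the input values and the helper name `tanh_gradient` are illustrative assumptions. It shows only the general property that tanh flattens out for large inputs, so its gradient vanishes there.

```python
import math


def tanh_gradient(z: float) -> float:
    """Derivative of tanh at z, i.e. 1 - tanh(z)^2."""
    return 1.0 - math.tanh(z) ** 2


# Hypothetical pre-activation values for a single pooled feature.
pre_activations = [0.5, 2.0, 5.0, 10.0]

for z in pre_activations:
    squashed = math.tanh(z)  # what a final tanh would emit
    # The gradient shrinks rapidly as |z| grows, which is the saturation effect.
    print(f"z={z:5.1f}  tanh(z)={squashed:.6f}  gradient={tanh_gradient(z):.2e}")
```

With the nonlinearity removed, the linear head receives the unsquashed features, so large activations stay distinguishable from one another.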

__init__(num_labels: int = 2)

Initialize the BertClassifier.

Parameters:
  • pretrained_model_name (str) – HuggingFace model name or path for the BERT encoder.

  • num_labels (int) – Number of output classes.

forward(input_ids, attention_mask=None, token_type_ids=None)

Forward pass.

Parameters:
  • input_ids (torch.Tensor) – Token IDs of shape (batch, seq_len).

  • attention_mask (torch.Tensor, optional) – Attention mask of shape (batch, seq_len).

  • token_type_ids (torch.Tensor, optional) – Token type IDs of shape (batch, seq_len).

classmethod from_pretrained_base(pretrained_model_name: str, num_labels: int = 2)

Override to avoid HuggingFace trying to load pretrained weights.

class quanda.benchmarks.resources.modules.HFGPT2(config=None, **kwargs)

Bases: GPT2LMHeadModel

HuggingFace GPT-2 model.

__init__(config=None, **kwargs)

Construct from a config or from GPT2Config kwargs.

config_class

alias of GPT2Config

forward(input_ids=None, attention_mask=None, **kwargs)

Forward pass, unpacking dict batches from Captum’s AV.

Captum’s AV.generate_dataset_activations hands the whole batch through as a single positional arg; for HF datasets that arg is a dict of input tensors, which is unpacked here.
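The unpacking described above can be sketched without any deep learning dependency. This is a minimal illustration, not the HFGPT2 implementation: the class `DictUnpackingForward` and its `_inner_forward` placeholder are hypothetical names, and plain lists stand in for tensors.

```python
class DictUnpackingForward:
    """Hypothetical stand-in for a model whose forward accepts dict batches."""

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # If the whole batch arrived as one positional dict, spread it
        # into the named arguments and fall through to the normal path.
        if isinstance(input_ids, dict):
            batch = dict(input_ids)  # copy so the caller's dict is untouched
            input_ids = batch.pop("input_ids", None)
            attention_mask = batch.pop("attention_mask", attention_mask)
            kwargs.update(batch)  # any remaining keys pass through as kwargs
        return self._inner_forward(input_ids, attention_mask, **kwargs)

    def _inner_forward(self, input_ids, attention_mask, **kwargs):
        # Placeholder for the real model call; it echoes what it received.
        return {"input_ids": input_ids, "attention_mask": attention_mask, **kwargs}


model = DictUnpackingForward()
batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

# Both call styles reach the inner forward with the same arguments.
from_dict = model.forward(batch)
from_kwargs = model.forward(input_ids=[[1, 2, 3]], attention_mask=[[1, 1, 1]])
```

Handling the dict inside forward keeps the calling code unchanged, at the cost of one extra type check per call.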

class quanda.benchmarks.resources.modules.LeNet(*args, **kwargs)

Bases: Module, PyTorchModelHubMixin

A PyTorch implementation of the LeNet architecture.

Adapted from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch.

__init__(num_outputs=10)

Initialize the LeNet model.

forward(x)

Forward pass.

class quanda.benchmarks.resources.modules.ResNet50(*args, **kwargs)

Bases: ResNet, PyTorchModelHubMixin

ResNet-50 with an explicit Flatten module before the FC head.

Subclasses torchvision’s ResNet and overrides _forward_impl to swap the inline torch.flatten call for an nn.Flatten submodule, so activation hooks can target the flattened features (e.g. for RepresenterPoints).
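To illustrate why a Flatten submodule matters, the sketch below registers a forward hook on an nn.Flatten layer in a small toy network. `TinyNet` and `save_features` are hypothetical names chosen for this example, and the architecture is an assumption, not the ResNet50 class documented here. An inline `torch.flatten` call would leave no module to attach the hook to.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Toy network (an assumption for illustration) with an explicit Flatten."""

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()  # a real submodule, so it can be hooked
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.avgpool(torch.relu(self.conv(x)))
        x = self.flatten(x)  # shape (batch, 8)
        return self.fc(x)


captured = {}


def save_features(module, inputs, output):
    # Store the flattened features that are fed into the FC head.
    captured["features"] = output.detach()


model = TinyNet()
handle = model.flatten.register_forward_hook(save_features)
with torch.no_grad():
    logits = model(torch.randn(2, 3, 16, 16))
handle.remove()  # detach the hook once the activations are captured
```

After the forward pass, `captured["features"]` holds one flattened feature vector per input, which is the kind of penultimate-layer activation that representation-based attribution methods consume.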

__init__(num_classes: int = 50)

Initialize the ResNet50 model.

class quanda.benchmarks.resources.modules.ResNet9(*args, **kwargs)

Bases: Module, PyTorchModelHubMixin

ResNet-9 for CIFAR-10 classification.

Adapted from https://github.com/MadryLab/trak.

__init__(num_classes=10)

Initialize the ResNet9 model.

forward(x)

Forward pass.