# Source code for quanda.benchmarks.resources.modules

"""Lightning modules for the benchmarks."""

import torch
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models.resnet import (  # type: ignore
    Bottleneck,
    ResNet,
)
from transformers import (  # type: ignore
    AutoConfig,
    AutoModel,
    GPT2Config,
    GPT2LMHeadModel,
)


class _Mul(torch.nn.Module):
    """Multiply input by a constant weight."""

    def __init__(self, weight):
        """Initialize with scalar weight."""
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """Forward pass."""
        return x * self.weight


class _Residual(torch.nn.Module):
    """Residual connection wrapper."""

    def __init__(self, module):
        """Initialize with inner module."""
        super().__init__()
        self.module = module

    def forward(self, x):
        """Forward pass."""
        return x + self.module(x)


def _conv_bn(c_in, c_out, ks=3, stride=1, padding=1):
    """Create a Conv2d-BatchNorm-ReLU block."""
    return torch.nn.Sequential(
        torch.nn.Conv2d(
            c_in,
            c_out,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        torch.nn.BatchNorm2d(c_out),
        torch.nn.ReLU(inplace=True),
    )


class ResNet9(torch.nn.Module, PyTorchModelHubMixin):
    """ResNet-9 classifier for CIFAR-10.

    Adapted from https://github.com/MadryLab/trak.
    """

    def __init__(self, num_classes=10):
        """Assemble the network as a single ``torch.nn.Sequential``.

        Parameters
        ----------
        num_classes : int
            Output width of the final linear layer.

        """
        super().__init__()

        def residual_pair(channels):
            # Two width-preserving conv blocks wrapped in a skip connection.
            inner = torch.nn.Sequential(
                _conv_bn(channels, channels),
                _conv_bn(channels, channels),
            )
            return _Residual(inner)

        # Layers are created in network order so that the positional names
        # inside ``self.model`` (``model.0``, ``model.1``, ...) are unchanged.
        layers = [
            _conv_bn(3, 64),
            _conv_bn(64, 128, ks=5, stride=2, padding=2),
            residual_pair(128),
            _conv_bn(128, 256),
            torch.nn.MaxPool2d(2),
            residual_pair(256),
            _conv_bn(256, 128, padding=0),
            torch.nn.AdaptiveMaxPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(128, num_classes, bias=False),
            # Constant scaling applied to the outputs of the linear layer.
            _Mul(0.2),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)
class ResNet50(ResNet, PyTorchModelHubMixin):
    """ResNet-50 whose flattening step is a real ``nn.Flatten`` submodule.

    Subclasses ``torchvision``'s ``ResNet`` and overrides ``_forward_impl``
    to swap the inline ``torch.flatten`` call for an ``nn.Flatten``
    submodule, so activation hooks can target the flattened features (e.g.
    for ``RepresenterPoints``).
    """

    def __init__(self, num_classes: int = 50):
        """Build the Bottleneck ``[3, 4, 6, 3]`` ResNet and add ``flatten``.

        Parameters
        ----------
        num_classes : int
            Number of outputs of the final fully connected layer.

        """
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            num_classes=num_classes,
        )
        # Parameter-free module; exists only to give hooks a target.
        self.flatten = torch.nn.Flatten()

    def _forward_impl(self, x):
        """Forward pass exposing ``flatten`` as a hookable module."""
        # Stem: convolution, batch norm, ReLU and max pooling.
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # The four residual stages, applied in order.
        features = self.layer4(self.layer3(self.layer2(self.layer1(stem))))
        # Head: pooling, hookable flatten, then the FC classifier.
        pooled = self.avgpool(features)
        flat = self.flatten(pooled)
        return self.fc(flat)
class LeNet(torch.nn.Module, PyTorchModelHubMixin):
    """LeNet-style convolutional classifier implemented in torch.

    Adapted from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch.
    """

    def __init__(self, num_outputs=10):
        """Create the two convolutional stages and the three-layer head.

        Parameters
        ----------
        num_outputs : int
            Number of outputs of the last linear layer.

        """
        super().__init__()

        # Convolutional stage 1: one input channel -> 6 feature maps.
        self.conv_1 = torch.nn.Conv2d(1, 6, 5)
        self.pool_1 = torch.nn.MaxPool2d(2, 2)
        self.relu_1 = torch.nn.ReLU()

        # Convolutional stage 2: 6 -> 16 feature maps.
        self.conv_2 = torch.nn.Conv2d(6, 16, 5)
        self.pool_2 = torch.nn.MaxPool2d(2, 2)
        self.relu_2 = torch.nn.ReLU()

        # Fully connected head. ``fc_1`` takes 256 flattened features, which
        # is what the two stages yield for 1x28x28 inputs (16 * 4 * 4).
        self.fc_1 = torch.nn.Linear(256, 120)
        self.relu_3 = torch.nn.ReLU()
        self.fc_2 = torch.nn.Linear(120, 84)
        self.relu_4 = torch.nn.ReLU()
        self.fc_3 = torch.nn.Linear(84, num_outputs)

    def forward(self, x):
        """Forward pass.

        Each convolutional stage applies convolution, ReLU, then pooling.
        The result is flattened per sample and fed through the linear head;
        the output of ``fc_3`` is returned without a final activation.
        """
        x = self.conv_1(x)
        x = self.relu_1(x)
        x = self.pool_1(x)

        x = self.conv_2(x)
        x = self.relu_2(x)
        x = self.pool_2(x)

        # Collapse everything except the batch dimension.
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)

        x = self.fc_1(x)
        x = self.relu_3(x)
        x = self.fc_2(x)
        x = self.relu_4(x)
        return self.fc_3(x)
class BertClassifier(torch.nn.Module, PyTorchModelHubMixin):
    """BERT-based sequence classifier without final nonlinearity.

    Uses a BERT encoder with a linear classification head. The final tanh
    of the pooler is removed to prevent output saturation, following the
    setup in https://arxiv.org/pdf/2303.14186.
    """

    def __init__(
        self,
        num_labels: int = 2,
    ):
        """Initialize the BertClassifier.

        The encoder is built from the configuration of
        ``google-bert/bert-base-cased`` via ``AutoModel.from_config``. Use
        :meth:`from_pretrained_base` to build the encoder with
        ``AutoModel.from_pretrained`` instead.

        Parameters
        ----------
        num_labels : int
            Number of output classes.

        """
        super().__init__()
        self.num_labels = num_labels
        config = AutoConfig.from_pretrained(
            "google-bert/bert-base-cased",
            num_labels=num_labels,
            # ``forward`` builds the 4D additive mask for eager attention.
            attn_implementation="eager",
        )
        self._attach_encoder(AutoModel.from_config(config), config)

    def _attach_encoder(self, encoder, config):
        """Install ``encoder`` and build the dropout and linear head for it.

        Parameters
        ----------
        encoder : torch.nn.Module
            BERT-style encoder whose output exposes ``pooler_output``.
        config
            Configuration the encoder was built from; supplies
            ``hidden_dropout_prob`` and ``hidden_size``.

        """
        self.bert = encoder
        # Drop the pooler's activation so the pooled output stays linear.
        if getattr(self.bert, "pooler", None) is not None:
            self.bert.pooler.activation = torch.nn.Identity()
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Forward pass.

        Parameters
        ----------
        input_ids : torch.Tensor or dict
            Token IDs of shape ``(batch, seq_len)``, or a dict holding
            ``"input_ids"`` and optionally ``"attention_mask"`` and
            ``"token_type_ids"``.
        attention_mask : torch.Tensor, optional
            Attention mask of shape ``(batch, seq_len)``.
        token_type_ids : torch.Tensor, optional
            Token type IDs of shape ``(batch, seq_len)``.

        Returns
        -------
        torch.Tensor
            Output of the linear head on the pooled encoder output, i.e. one
            row of ``num_labels`` logits per sample.

        """
        # Captum's AV.generate_dataset_activations hands the whole batch
        # through as a single positional arg; for BERT-style datasets that
        # arg is a dict of input tensors. Unpack it here.
        if isinstance(input_ids, dict):
            d = input_ids
            input_ids = d["input_ids"]
            attention_mask = d.get("attention_mask", attention_mask)
            token_type_ids = d.get("token_type_ids", token_type_ids)
        if attention_mask is not None and attention_mask.dim() == 2:
            # Pre-expand to the 4D additive mask BERT's eager attention
            # expects. The 2D path routes through
            # ``create_bidirectional_mask`` which calls
            # ``padding_mask.all()`` as a Python bool, which is incompatible
            # with ``torch.func.vmap``.
            dtype = next(self.parameters()).dtype
            min_val = torch.finfo(dtype).min
            # Kept positions (1) become 0, masked positions (0) become the
            # most negative finite value of the parameter dtype.
            attention_mask = (1.0 - attention_mask.to(dtype))[
                :, None, None, :
            ] * min_val
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)

    @classmethod
    def from_pretrained_base(
        cls, pretrained_model_name: str, num_labels: int = 2
    ):
        """Create a classifier around a pretrained encoder.

        The instance is first constructed through ``__init__``; its encoder
        is then replaced by ``AutoModel.from_pretrained`` for
        ``pretrained_model_name``, and the dropout and linear head are
        rebuilt from that model's configuration.

        Parameters
        ----------
        pretrained_model_name : str
            HuggingFace model name or path for the BERT encoder.
        num_labels : int
            Number of output classes.

        Returns
        -------
        BertClassifier
            Classifier using the pretrained encoder and a new linear head.

        """
        model = cls(num_labels=num_labels)
        config = AutoConfig.from_pretrained(
            pretrained_model_name,
            num_labels=num_labels,
            # Same attention backend as ``__init__``; ``forward`` relies on
            # it when it pre-expands the attention mask.
            attn_implementation="eager",
        )
        model._attach_encoder(
            AutoModel.from_pretrained(pretrained_model_name, config=config),
            config,
        )
        return model
class HFGPT2(GPT2LMHeadModel):
    """GPT-2 language-model head from HuggingFace, forced to eager attention."""

    def __init__(self, config=None, **kwargs):
        """Construct from a config or from ``GPT2Config`` kwargs.

        Parameters
        ----------
        config : GPT2Config, optional
            Ready-made configuration. When given, ``kwargs`` are not used.
        **kwargs
            Arguments for ``GPT2Config``; only consulted when ``config`` is
            ``None``.

        """
        if config is None:
            # With no kwargs this is simply the default ``GPT2Config()``.
            config = GPT2Config(**kwargs)
        # SDPA's mem-efficient backward breaks dattri's attributors
        # (CG/LiSSA/DataInf/Arnoldi).
        config._attn_implementation = "eager"
        super().__init__(config)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Forward pass that also accepts a whole batch dict as first arg.

        Captum's ``AV.generate_dataset_activations`` passes the batch as a
        single positional argument; for HF datasets that argument is a dict
        of input tensors. Such a dict is unpacked into ``input_ids``,
        ``attention_mask`` and, unless already supplied as a keyword,
        ``labels``.
        """
        if isinstance(input_ids, dict):
            batch = input_ids
            input_ids = batch["input_ids"]
            attention_mask = batch.get("attention_mask", attention_mask)
            # Explicit keyword arguments win over entries of the batch dict.
            for key in ("labels",):
                if key in batch:
                    kwargs.setdefault(key, batch[key])
        return super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
# Lookup table from module identifier to the model class defined above.
pl_modules = {
    "MnistTorch": LeNet,
    "BertClassifier": BertClassifier,
    "ResNet9": ResNet9,
    "ResNet50": ResNet50,
    "HFGPT2": HFGPT2,
}