"""Common utility functions for the Quanda package."""
import functools
import json
import os
import random
import re
import time
from abc import ABC
from contextlib import contextmanager
from dataclasses import dataclass
from functools import reduce
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sized,
Union,
cast,
)
import torch
import yaml
from torch import nn
# Type of a user-supplied checkpoint loader: called as
# ``func(model, checkpoint_path)``. Its return value is not used by
# ``load_last_checkpoint``.
CheckpointLoadFunc = Callable[[torch.nn.Module, str], Any]
[docs]
@dataclass
class DatasetSplit(ABC):
    """Store dynamically named index splits (e.g., train, val, test).

    ``splits`` maps each split name to a tensor of dataset indices. Use
    :meth:`split` to build a seeded random partition of ``range(n)`` and
    :meth:`save` / :meth:`load` to persist it as YAML.
    """

    splits: Dict[str, torch.Tensor]

    def __getitem__(self, key):
        """Return the index tensor of the split named ``key``.

        Raises
        ------
        KeyError
            If there is no split named ``key``.
        """
        if key not in self.splits:
            raise KeyError(f"Key '{key}' not found in splits.")
        return self.splits[key]

    def __init__(
        self,
        splits: Dict[str, torch.Tensor],
    ):
        """Create a DatasetSplit from a dictionary of indices.

        Parameters
        ----------
        splits : Dict[str, torch.Tensor]
            Mapping from split name to the tensor of indices belonging to
            that split.

        Raises
        ------
        ValueError
            If ``splits`` is empty.
        """
        if not splits:
            raise ValueError("splits cannot be empty.")
        self.splits = splits

    @classmethod
    def split(
        cls, n_indices: int, seed: int, split_ratios: Dict[str, float]
    ) -> "DatasetSplit":
        """Split the indices into named sets based on split_ratios.

        Parameters
        ----------
        n_indices : int
            Total number of indices to split.
        seed : int
            Random seed for reproducibility.
        split_ratios : Dict[str, float]
            A dictionary where keys are split names (e.g., 'train', 'val',
            'test') and values are the non-negative ratios for each split.

        Returns
        -------
        DatasetSplit
            An object with keys corresponding to split_ratios.

        Raises
        ------
        ValueError
            If ``split_ratios`` is empty, contains a negative ratio, or its
            ratios sum to more than 1.0.

        Notes
        -----
        Split ``name`` receives ``int(ratio * n_indices)`` indices, taken
        consecutively (in ``split_ratios`` order) from a single seeded
        permutation, so the splits are disjoint. Indices left over by the
        flooring, or by ratios summing to less than 1.0, are not assigned to
        any split. A dedicated ``torch.Generator`` is used, so the global
        torch RNG state is left untouched. The permutation equals the one
        produced by ``torch.manual_seed(seed)`` followed by
        ``torch.randperm(n_indices)``.
        """
        if not split_ratios:
            raise ValueError("split_ratios cannot be empty.")
        negative = [name for name, ratio in split_ratios.items() if ratio < 0]
        if negative:
            # A negative ratio would move ``start`` backwards and make the
            # following splits overlap.
            raise ValueError(
                f"Split ratios must be non-negative; got negative ratios "
                f"for {negative}."
            )
        total_ratio = sum(split_ratios.values())
        # Small tolerance so ratios meant to sum to exactly 1.0 are not
        # rejected because of floating-point rounding.
        if total_ratio > 1.0 + 1e-9:
            raise ValueError("Sum of split ratios must not exceed 1.0")
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(n_indices, generator=generator)
        split_indices = {}
        start = 0
        for name, ratio in split_ratios.items():
            end = start + int(ratio * n_indices)
            split_indices[name] = indices[start:end]
            start = end
        return cls(splits=split_indices)

    @classmethod
    def load(cls, path: str, name: str) -> "DatasetSplit":
        """Load a split previously written by :meth:`save`.

        Parameters
        ----------
        path : str
            Directory containing the split file.
        name : str
            File name of the split inside ``path``.

        Returns
        -------
        DatasetSplit
            The splits read from the YAML mapping of name -> index list.

        Raises
        ------
        ValueError
            If the file is empty or contains no splits.
        """
        with open(os.path.join(path, name), "r") as f:
            # An empty file parses to None; treat it as "no splits" so the
            # constructor reports it with a ValueError.
            data = yaml.safe_load(f) or {}
        splits = {
            # ``torch.tensor([])`` would be float, which is not a valid
            # index tensor, so empty splits are created as int64.
            k: (
                torch.tensor(v, dtype=torch.long)
                if isinstance(v, list) and not v
                else torch.tensor(v)
            )
            for k, v in data.items()
        }
        return cls(splits=splits)

    def save(self, path: str, name: str) -> None:
        """Save the split to disk atomically.

        The YAML is written to a process-specific temporary file that is
        then moved over ``path/name`` with ``os.replace``. If writing fails,
        the temporary file is removed and the error is re-raised.
        """
        os.makedirs(path, exist_ok=True)
        data = {k: v.tolist() for k, v in self.splits.items()}
        final_path = os.path.join(path, name)
        tmp_path = f"{final_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w") as f:
                yaml.safe_dump(data, f)
            os.replace(tmp_path, final_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; keep the
            # original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def to_dict(self) -> Dict[str, torch.Tensor]:
        """Return the underlying name -> indices dictionary (not a copy)."""
        return self.splits

    @staticmethod
    def exists(path: str, name: str) -> bool:
        """Check if split file exists."""
        return os.path.exists(os.path.join(path, name))
[docs]
def resolve_config(config: Union[dict, str]) -> dict:
    """Turn a benchmark ``config`` argument into a configuration dict.

    Parameters
    ----------
    config : Union[dict, str]
        One of:

        * a dict, which is returned as-is (same object);
        * a ``bench_id`` registered in
          :data:`quanda.benchmarks.resources.config_map.config_map`, whose
          mapped YAML file is loaded;
        * any other string, treated as a path to a benchmark YAML file.

    Returns
    -------
    dict
        The resolved benchmark configuration.

    Raises
    ------
    TypeError
        If ``config`` is neither a ``dict`` nor a ``str``, or if the loaded
        YAML is not a mapping.
    """
    if not isinstance(config, (dict, str)):
        raise TypeError(
            f"config must be a dict, a registered bench_id, or a YAML path; "
            f"got {type(config).__name__}."
        )
    if isinstance(config, dict):
        return config

    # Imported here so utils has no import-time dependency on benchmarks.
    from quanda.benchmarks.resources.config_map import config_map

    yaml_path = str(config_map[config]) if config in config_map else config
    with open(yaml_path, "r") as handle:
        loaded = yaml.safe_load(handle)
    if isinstance(loaded, dict):
        return loaded
    raise TypeError(
        f"YAML at {yaml_path} did not parse to a dict (got "
        f"{type(loaded).__name__})."
    )
[docs]
def chunked_logits(
    model: torch.nn.Module,
    inputs: Any,
    chunk_size: Optional[int],
) -> torch.Tensor:
    """Run ``model`` forward on ``inputs`` in chunks; return logits tensor.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate. Dict inputs are passed as keyword arguments
        (``model(**batch)``), anything else positionally. If the output has
        a ``logits`` attribute, that attribute is returned instead.
    inputs : Any
        A tensor sliced along dim 0, or a dict of tensors. For a dict, the
        number of rows is taken from its first value.
    chunk_size : Optional[int]
        Maximum number of rows per forward pass; ``None`` runs one pass.

    Returns
    -------
    torch.Tensor
        Logits for all rows, concatenated along dim 0. Runs under
        ``torch.no_grad()``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """

    def _call(batch: Any) -> torch.Tensor:
        out = model(**batch) if isinstance(batch, dict) else model(batch)
        return out.logits if hasattr(out, "logits") else out

    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer or None; got {chunk_size}."
        )

    with torch.no_grad():
        if chunk_size is None:
            return _call(inputs)
        if isinstance(inputs, dict):
            total = next(iter(inputs.values())).shape[0]
        else:
            total = inputs.shape[0]
        # A single chunk (including zero rows) needs only one pass; this
        # also avoids ``torch.cat`` on an empty list for empty inputs.
        if total <= chunk_size:
            return _call(inputs)
        if isinstance(inputs, dict):
            chunks: Any = (
                {k: v[i : i + chunk_size] for k, v in inputs.items()}
                for i in range(0, total, chunk_size)
            )
        else:
            chunks = (
                inputs[i : i + chunk_size]
                for i in range(0, total, chunk_size)
            )
        return torch.cat([_call(chunk) for chunk in chunks], dim=0)
[docs]
def get_parent_module_from_name(
    model: torch.nn.Module, layer_name: str
) -> Any:
    """Return the object that directly contains ``layer_name`` in ``model``.

    ``layer_name`` is a dotted attribute path such as ``"encoder.fc"``.
    Every component except the last is followed with ``getattr``, starting
    from ``model``. For a name without dots, ``model`` itself is returned.

    Parameters
    ----------
    model : torch.nn.Module
        The model to search.
    layer_name : str
        Dotted name of the (child) module whose parent is wanted.

    Returns
    -------
    Any
        The parent of the named module.

    Raises
    ------
    AttributeError
        If one of the intermediate path components does not exist.
    """
    *parent_path, _leaf = layer_name.split(".")
    parent: Any = model
    for attr in parent_path:
        parent = getattr(parent, attr)
    return parent
[docs]
def resolve_device(
    model: torch.nn.Module, device: Optional[str] = None
) -> str:
    """Pick the device string to use for ``model``.

    An explicitly passed ``device`` always wins. Otherwise the device of
    the model's first parameter is used, or ``"cpu"`` when the model has no
    parameters. Explainer wrappers use this so callers can omit ``device``
    without silently ending up on the wrong device.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are inspected when ``device`` is not given.
    device : Optional[str], optional
        Explicit device; returned unchanged when not ``None``.

    Returns
    -------
    str
        The resolved device.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = "cpu" if first_param is None else str(first_param.device)
    return device
[docs]
def make_func(
    func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, **kwargs
) -> functools.partial:
    """Bind keyword arguments to ``func`` and return a ``functools.partial``.

    Parameters
    ----------
    func : Callable
        The function to create a partial function from.
    func_kwargs : Optional[Mapping[str, Any]]
        Optional keyword arguments to fix for the function. On a name
        clash these take precedence over ``kwargs``.
    kwargs : Any
        Additional keyword arguments for the function.

    Returns
    -------
    functools.partial
        ``func`` with the merged keyword arguments bound.
    """
    bound: Dict[str, Any] = dict(kwargs)
    if func_kwargs is not None:
        bound.update(func_kwargs)
    return functools.partial(func, **bound)
[docs]
def cache_result(method):
    """Decorate a method so that its first result is cached on the instance.

    The first call's return value is stored in ``self.__dict__`` under
    ``"_<method name>_cache"`` and returned by every later call. The cache
    is not keyed on arguments: later calls return the first result whatever
    arguments they pass. The instance must have a ``__dict__``.
    """
    cache_attr = f"_{method.__name__}_cache"
    missing = object()

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        result = self.__dict__.get(cache_attr, missing)
        if result is missing:
            result = method(self, *args, **kwargs)
            self.__dict__[cache_attr] = result
        return result

    return wrapper
[docs]
def class_accuracy(
    net: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device] = "cpu",
    single_class: Optional[int] = None,
) -> float:
    """Return accuracy on a dataset given by the data loader.

    Batches are either ``(inputs, targets)`` pairs, passed to ``net``
    positionally, or dicts with a ``"labels"`` entry, whose other entries
    are passed to ``net`` as keyword arguments. For dict batches, an output
    with a ``logits`` attribute is unwrapped. The predicted class is the
    argmax over dim 1. ``net`` is evaluated in whatever train/eval mode it
    is currently in, under ``torch.no_grad()``. Batches are not modified.

    Parameters
    ----------
    net : torch.nn.Module
        The model to evaluate.
    loader : torch.utils.data.DataLoader
        The data loader to evaluate the model on.
    device : Union[str, torch.device], optional
        The device to evaluate the model on, by default "cpu".
    single_class : Optional[int], optional
        If provided, all targets will be set to this class, by default None.

    Returns
    -------
    float
        Fraction of correctly classified samples; ``0.0`` if the loader
        yields no samples.
    """
    if len(loader) == 0:
        return 0.0
    correct = 0
    total = 0
    # Inference only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        for batch in loader:
            if isinstance(batch, dict):
                # Read rather than pop "labels" so the caller's batch dict is
                # left intact (e.g. a loader that yields stored dicts).
                targets = batch["labels"].to(device)
                inputs = {
                    k: v.to(device) for k, v in batch.items() if k != "labels"
                }
                outputs = net(**inputs)
                if hasattr(outputs, "logits"):
                    outputs = outputs.logits
            else:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
            if single_class is not None:
                targets = torch.full_like(targets, single_class)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total if total else 0.0
# Originally taken from Captum with minor changes; since extended with
# retry-on-read and prefix-stripping logic.
def _load_flexible_state_dict(
    model: torch.nn.Module, path: str, device: Union[str, torch.device]
) -> float:
    """Load a checkpoint into ``model``, tolerating ``DataParallel`` prefixes.

    Accepts checkpoints saved from a model wrapped in
    ``torch.nn.DataParallel`` (keys prefixed with ``"module."``) as well as
    plain checkpoints. ``model`` may be bare or ``DataParallel``-wrapped.
    The checkpoint file is expected to hold a state dict (it is read with
    ``weights_only=True``). It is passed to ``load_state_dict`` with that
    method's default arguments.

    Parameters
    ----------
    model : torch.nn.Module
        The model (bare or ``DataParallel``-wrapped) to load the
        checkpoint into.
    path : str
        The filepath to the checkpoint.
    device : Union[str, torch.device]
        Device used as ``map_location`` when reading the checkpoint.

    Returns
    -------
    float
        The checkpoint's ``"learning_rate"`` entry, or ``1.0`` if absent.

    Raises
    ------
    RuntimeError
        If ``torch.load`` raises ``RuntimeError`` on all 3 attempts.

    Notes
    -----
    The module state_dict is modified in-place.

    Reading is retried with a 1 s and then a 2 s back-off between attempts,
    with no sleep after the final attempt.

    A leading ``"module."`` prefix is stripped from the loaded state dict
    rather than temporarily wrapping ``model`` in a new ``DataParallel``,
    whose constructor can move the module onto a GPU.

    NOTE(review): a top-level ``"learning_rate"`` entry is handed to
    ``load_state_dict`` along with the weights and would presumably be
    rejected as an unexpected key. Confirm checkpoints never carry it.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    if isinstance(device, str):
        device = torch.device(device)

    max_attempts = 3
    last_err: Optional[Exception] = None
    for attempt in range(max_attempts):
        try:
            checkpoint = torch.load(
                path, map_location=device, weights_only=True
            )
            break
        except RuntimeError as e:
            # Presumably guards against transient read failures (e.g. a file
            # still being written); back off before retrying, but not after
            # the last attempt.
            last_err = e
            if attempt < max_attempts - 1:
                time.sleep(2**attempt)
    else:
        raise RuntimeError(
            f"Failed to load checkpoint at {path} after {max_attempts} "
            f"attempts"
        ) from last_err

    learning_rate = checkpoint.get("learning_rate", 1.0)

    # A DataParallel's state-dict keys are its inner module's keys with a
    # "module." prefix, so loading the unprefixed dict into the inner module
    # is equivalent and never re-wraps the model.
    target = (
        model.module if isinstance(model, torch.nn.DataParallel) else model
    )
    first_key = next(iter(checkpoint), "")
    if first_key.startswith("module."):
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
    target.load_state_dict(checkpoint)
    return learning_rate
[docs]
def get_load_state_dict_func(
    device: Union[str, torch.device],
) -> Callable[[torch.nn.Module, str], float]:
    """Build a checkpoint loader bound to ``device``.

    Parameters
    ----------
    device : Union[str, torch.device]
        The device to load the model on; it is forwarded to
        ``_load_flexible_state_dict`` on every call.

    Returns
    -------
    Callable[[torch.nn.Module, str], float]
        ``load_state_dict(model, path)``, which loads the checkpoint at
        ``path`` into ``model`` via ``_load_flexible_state_dict`` and
        returns that function's result.
    """
    bound_device = device

    def load_state_dict(model: torch.nn.Module, path: str) -> float:
        return _load_flexible_state_dict(
            model=model, path=path, device=bound_device
        )

    return load_state_dict
[docs]
@contextmanager
def default_tensor_type(device: Union[str, torch.device]):
    """Context manager to temporarily change the default tensor type.

    On entry, the default tensor type is set to the float tensor type of
    ``device``. For CUDA devices, ``torch.cuda.set_device(device)`` is also
    called. The body then runs inside ``with torch.device(device)``. On exit,
    including when setup or the body raises, the default tensor type is
    restored and, for CUDA, the previously current CUDA device is
    re-selected.

    Parameters
    ----------
    device : Union[str, torch.device]
        The device to which the default tensor type should be set.

    Returns
    -------
    None

    Notes
    -----
    NOTE(review): the saved type comes from ``torch.FloatTensor([0.0])``
    and so appears to always be ``"torch.FloatTensor"``. Nested use, or a
    non-float32 default dtype set before entry, is therefore not restored
    exactly. Confirm this is acceptable.
    """
    # Save the current default tensor type and compute the new one.
    float_tensor = torch.FloatTensor([0.0])
    original_tensor_type = float_tensor.type()
    new_tensor_type = float_tensor.to(device).type()

    device_type = device.type if isinstance(device, torch.device) else device
    uses_cuda = "cuda" in device_type
    previous_cuda_device = torch.cuda.current_device() if uses_cuda else None

    torch.set_default_tensor_type(new_tensor_type)
    # Everything after changing global state lives inside try/finally, so a
    # failure in set_device / torch.device(...) cannot leak the new default.
    try:
        if uses_cuda:
            torch.cuda.set_device(device)
        with torch.device(device):
            # Yield control back to the calling context
            yield
    finally:
        # Restore the original tensor type (and CUDA device, if changed)
        torch.set_default_tensor_type(original_tensor_type)
        if previous_cuda_device is not None:
            torch.cuda.set_device(previous_cuda_device)
[docs]
@contextmanager
def map_location_context(device: Union[str, torch.device]):
    """Temporarily force ``map_location`` for every ``torch.load`` call.

    While the context is active, ``torch.load`` is replaced by a wrapper
    that always passes ``map_location=device``, overriding any value the
    caller gave. It also defaults ``weights_only`` to ``True`` unless the
    caller sets it explicitly. The original ``torch.load`` is put back on
    exit, even if the body raises.

    Parameters
    ----------
    device: Union[str, torch.device]
        The device to which the tensors should be loaded.

    Returns
    -------
    None
    """
    saved_load = torch.load

    def _patched_load(f, *args, **kwargs):
        kwargs = {**kwargs, "map_location": device}
        kwargs.setdefault("weights_only", True)
        return saved_load(f, *args, **kwargs)

    torch.load = _patched_load
    try:
        yield
    finally:
        torch.load = saved_load
[docs]
def ds_len(dataset: torch.utils.data.Dataset) -> int:
    """Get the length of the dataset.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset to get the length of.

    Returns
    -------
    int
        The length of the dataset.

    Notes
    -----
    * Datasets with ``__len__`` use ``len(dataset)``.
    * An ``IterableDataset`` without ``__len__`` is iterated once and its
      items are counted. This does not terminate for an infinite stream.
    * Other datasets fall back to ``len(DataLoader(dataset, batch_size=1))``.
      NOTE(review): that likely raises ``TypeError``, since the sampler
      itself needs ``len(dataset)``. Confirm.
    """
    if isinstance(dataset, Sized):
        return len(dataset)
    if isinstance(dataset, torch.utils.data.IterableDataset):
        # DataLoader.__len__ defers to len(dataset) for iterable datasets, so
        # it cannot help here; count the items directly instead.
        return sum(1 for _ in dataset)
    dl = torch.utils.data.DataLoader(dataset, batch_size=1)
    return len(dl)
[docs]
def process_targets(
    targets: Union[List[int], torch.Tensor], device: Union[str, torch.device]
) -> torch.Tensor:
    """Return ``targets`` as a tensor on ``device``.

    Parameters
    ----------
    targets : Union[List[int], torch.Tensor]
        Target labels as a Python list or a tensor. ``None`` is also
        accepted and passed through unchanged.
    device: Union[str, torch.device]
        The device to use.

    Returns
    -------
    torch.Tensor or None
        A list is first converted with ``torch.tensor``; the result is
        moved to ``device``. ``None`` when ``targets`` is ``None``.
    """
    if targets is None:
        return targets
    as_tensor = torch.tensor(targets) if isinstance(targets, list) else targets
    return as_tensor.to(device)
[docs]
def get_targets(item: Union[tuple, dict]) -> int:
    """Extract the target from a single dataset item.

    Parameters
    ----------
    item : Union[tuple, dict]
        Either a tuple whose second element is the target, e.g.
        ``(data, target)``, or a dict holding the target under ``"labels"``.

    Returns
    -------
    int
        The target value (returned as stored; it may be a tensor).

    Raises
    ------
    ValueError
        If a dict item has no ``"labels"`` key, or the item is neither a
        tuple nor a dict.
    """
    if isinstance(item, dict):
        if "labels" not in item:
            raise ValueError(
                f"Dataset item missing required 'labels' key: {item}."
            )
        return item["labels"]
    if isinstance(item, tuple):
        return item[1]
    raise ValueError(
        f"Unsupported dataset item type: {type(item)}. "
        "Expected tuple (data, target) or dict with 'labels' key."
    )
[docs]
def move_ds_item_to_device(
    data: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Union[str, torch.device],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Move a tensor, or every value of a dict of tensors, to ``device``.

    Parameters
    ----------
    data : Union[torch.Tensor, Dict[str, torch.Tensor]]
        The data to process.
    device: Union[str, torch.device]
        The device to use.

    Returns
    -------
    Union[torch.Tensor, Dict[str, torch.Tensor]]
        For a dict, a new dict with the same keys and moved values;
        otherwise ``data.to(device)``.
    """
    return (
        {key: tensor.to(device) for key, tensor in data.items()}
        if isinstance(data, dict)
        else data.to(device)
    )
[docs]
def load_last_checkpoint(
    model: torch.nn.Module,
    checkpoints: Optional[Union[str, List[str]]],
    checkpoints_load_func: CheckpointLoadFunc,
):
    """Load the most recent checkpoint into ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load the checkpoint into.
    checkpoints : Optional[Union[str, List[str]]]
        A single checkpoint path, or a list of paths of which the last one
        is loaded. ``None`` or an empty list/string loads nothing.
    checkpoints_load_func : CheckpointLoadFunc
        Function to load the model from the checkpoint file. It is called
        as ``checkpoints_load_func(model, checkpoint_path)``, and its return
        value is ignored.
    """
    if not checkpoints:
        return
    if isinstance(checkpoints, str):
        # A single path: ``checkpoints[-1]`` would pick its last character.
        checkpoints_load_func(model, checkpoints)
        return
    checkpoints_load_func(model, checkpoints[-1])
# Matches the address suffix of the default ``object.__repr__`` form
# ``<module.Class object at 0x7f...>``, whose address differs per process.
_DEFAULT_REPR_RE = re.compile(r" object at 0x[0-9a-fA-F]+>")


def stable_repr(obj: Any) -> str:
    """Process-stable string form of ``obj`` for hashing/serialization.

    Rules, applied in order:

    * ``functools.partial`` objects are rendered as
      ``functools.partial(<func>, <args...>, <key>=<value>...)``, with each
      part produced by ``stable_repr`` and keywords sorted by name. Their
      built-in ``str`` embeds the wrapped function's memory address.
    * Other callables with a ``__qualname__`` (functions, classes, methods)
      become ``"module.qualname"``.
    * Objects whose ``str`` does not look like the default
      ``<... object at 0x...>`` form are returned as ``str(obj)``.
    * Otherwise the fully qualified class name is used, followed by the
      instance ``__dict__`` as sorted-key JSON in parentheses when it is
      non-empty. Values JSON cannot encode are rendered recursively.

    NOTE(review): an object whose ``__dict__`` (transitively) refers back
    to itself recurses without bound.
    """
    if isinstance(obj, functools.partial):
        parts = [stable_repr(obj.func)]
        parts.extend(stable_repr(arg) for arg in obj.args)
        parts.extend(
            f"{key}={stable_repr(value)}"
            for key, value in sorted(obj.keywords.items())
        )
        return f"functools.partial({', '.join(parts)})"
    if callable(obj) and hasattr(obj, "__qualname__"):
        module = getattr(obj, "__module__", "") or ""
        return f"{module}.{obj.__qualname__}" if module else obj.__qualname__
    s = str(obj)
    if not _DEFAULT_REPR_RE.search(s):
        return s
    cls = type(obj)
    module = getattr(cls, "__module__", "") or ""
    qualname = cls.__qualname__
    fq = f"{module}.{qualname}" if module else qualname
    attrs = getattr(obj, "__dict__", None)
    if attrs:
        inner = json.dumps(attrs, sort_keys=True, default=stable_repr)
        return f"{fq}({inner})"
    return fq
[docs]
def subsample_indices(n: int, max_n: Optional[int], seed: int) -> List[int]:
    """Return a deterministic, sorted subsample of ``range(n)``.

    ``max_n`` distinct indices are drawn with ``random.Random(seed).sample``
    and returned in ascending order. When ``max_n`` is ``None`` or at least
    ``n``, every index is returned as ``list(range(n))``. This is the rule
    ``subsample_dataset`` uses.
    """
    if max_n is not None and max_n < n:
        picked = random.Random(seed).sample(range(n), max_n)
        picked.sort()
        return picked
    return list(range(n))
[docs]
def subsample_dataset(
    dataset: torch.utils.data.Dataset,
    max_n: Optional[int],
    seed: int,
) -> torch.utils.data.Dataset:
    """Deterministically subsample a dataset to at most ``max_n`` items.

    The dataset is returned unchanged when ``max_n`` is ``None`` or not
    smaller than ``len(dataset)``. Otherwise the kept indices come from
    ``subsample_indices(len(dataset), max_n, seed)``, i.e. Python's
    ``random.Random(seed).sample`` over ``range(N)``, sorted ascending,
    which is reproducible across platforms. Datasets exposing ``filtered``
    (e.g. ``TransformedDataset``) are narrowed with
    ``dataset.filtered(indices)`` so the original type is preserved and
    transform indices are remapped. Others are wrapped in ``Subset``.
    """
    if max_n is None:
        return dataset
    n_total = len(dataset)  # type: ignore[arg-type]
    if n_total <= max_n:
        return dataset
    keep = subsample_indices(n_total, max_n, seed)
    if not hasattr(dataset, "filtered"):
        return torch.utils.data.Subset(dataset, keep)
    return dataset.filtered(keep)
[docs]
def replace_conv1d_with_linear(model: nn.Module) -> None:
    """Swap HF ``Conv1D`` modules in-place with ``nn.Linear`` equivalents.

    HF GPT-2 uses ``Conv1D`` (a transposed Linear) for attention/MLP
    projections; kronfluence only wraps ``nn.Linear`` and ``nn.Conv2d``,
    so the swap is required before ``prepare_model`` for those models.

    Modules are matched by class name (``"Conv1D"``), so ``transformers``
    does not need to be imported. A ``Conv1D`` weight is laid out as
    ``(in_features, out_features)`` and is transposed into ``nn.Linear``'s
    ``(out_features, in_features)`` layout. The replacement is created on
    the weight's device, with the weight's dtype (so fp16/bf16/fp64 models
    keep working). It only gets a bias if the original had one, and each
    parameter keeps its ``requires_grad`` flag.

    Parameters
    ----------
    model : nn.Module
        Model modified in place; children are visited recursively.
    """
    # Iterate over a snapshot because children are replaced during the loop.
    for name, module in list(model.named_children()):
        if len(list(module.children())) > 0:
            replace_conv1d_with_linear(module)
        if module.__class__.__name__ != "Conv1D":
            continue
        weight = cast(torch.Tensor, module.weight)
        bias = cast(Optional[torch.Tensor], module.bias)
        new_module = nn.Linear(
            in_features=weight.shape[0],
            out_features=weight.shape[1],
            bias=bias is not None,
            device=weight.device,
            dtype=weight.dtype,
        )
        with torch.no_grad():
            new_module.weight.copy_(weight.t())
            if bias is not None:
                new_module.bias.copy_(bias)
        new_module.weight.requires_grad_(weight.requires_grad)
        if bias is not None:
            new_module.bias.requires_grad_(bias.requires_grad)
        setattr(model, name, new_module)