Source code for quanda.utils.common

"""Common utility functions for the Quanda package."""

import functools
import json
import os
import random
import re
import time
from abc import ABC
from contextlib import contextmanager
from dataclasses import dataclass
from functools import reduce
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sized,
    Union,
    cast,
)

import torch
import yaml
from torch import nn

# Signature of a checkpoint-loading callable: it is invoked as
# ``func(model, checkpoint_path)``; its return value is unconstrained (``Any``).
CheckpointLoadFunc = Callable[[torch.nn.Module, str], Any]


@dataclass
class DatasetSplit(ABC):
    """Store dynamically named index splits (e.g., train, val, test).

    Each split name maps to a tensor of dataset indices. Instances can be
    created directly from a dictionary, generated randomly with
    :meth:`split`, or round-tripped through a YAML file with :meth:`save`
    and :meth:`load`.
    """

    splits: Dict[str, torch.Tensor]

    def __getitem__(self, key):
        """Get the indices for the specified key.

        Raises
        ------
        KeyError
            If ``key`` is not the name of a stored split.

        """
        if key not in self.splits:
            raise KeyError(f"Key '{key}' not found in splits.")
        return self.splits[key]

    def __init__(
        self,
        splits: Dict[str, torch.Tensor],
    ):
        """Create a DatasetSplit from a dictionary of indices.

        Parameters
        ----------
        splits : Dict[str, torch.Tensor]
            Mapping from split name to the tensor of indices belonging to
            that split. The dictionary is stored as-is (not copied).

        Raises
        ------
        ValueError
            If ``splits`` is empty.

        """
        if not splits:
            raise ValueError("splits cannot be empty.")
        self.splits = splits

    @classmethod
    def split(
        cls, n_indices: int, seed: int, split_ratios: Dict[str, float]
    ) -> "DatasetSplit":
        """Split the indices into named sets based on split_ratios.

        A random permutation of ``range(n_indices)`` is cut into consecutive
        slices, one per entry of ``split_ratios`` in iteration order. Each
        slice has ``int(ratio * n_indices)`` elements, so indices left over
        by truncation (or by ratios summing to less than 1) are not assigned
        to any split.

        Parameters
        ----------
        n_indices : int
            Total number of indices to split.
        seed : int
            Random seed for reproducibility. Note that this seeds torch's
            global RNG via ``torch.manual_seed``.
        split_ratios : Dict[str, float]
            A dictionary where keys are split names (e.g., 'train', 'val',
            'test') and values are the ratios for each split.

        Returns
        -------
        DatasetSplit:
            An object with keys corresponding to split_ratios.

        Raises
        ------
        ValueError
            If ``split_ratios`` is empty, contains a negative ratio, or its
            values sum to more than 1.0.

        """
        if not split_ratios:
            raise ValueError("split_ratios cannot be empty.")
        if any(ratio < 0 for ratio in split_ratios.values()):
            raise ValueError("Split ratios must be non-negative.")

        # Allow for floating-point rounding: ratios such as 0.1/0.2/0.6/0.1
        # can add up to marginally more than 1.0 without being wrong.
        tolerance = 1e-9
        total_ratio = sum(split_ratios.values())
        if total_ratio > 1.0 + tolerance:
            raise ValueError("Sum of split ratios must not exceed 1.0")

        torch.manual_seed(seed)
        indices = torch.randperm(n_indices)

        split_indices = {}
        start = 0
        for name, ratio in split_ratios.items():
            end = start + int(ratio * n_indices)
            split_indices[name] = indices[start:end]
            start = end

        return cls(splits=split_indices)

    @classmethod
    def load(cls, path: str, name: str) -> "DatasetSplit":
        """Load the split from the YAML file ``name`` inside ``path``.

        Raises
        ------
        ValueError
            If the file does not contain a mapping, or the mapping is empty.

        """
        file_path = os.path.join(path, name)
        with open(file_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"Split file {file_path} did not parse to a mapping "
                f"(got {type(data).__name__})."
            )
        splits = {k: torch.tensor(v) for k, v in data.items()}
        return cls(splits=splits)

    def save(self, path: str, name: str) -> None:
        """Save the split to disk atomically.

        The data is first written to a process-specific temporary file next
        to the target and then moved into place with ``os.replace``. If
        writing or replacing fails, the temporary file is removed and the
        error is re-raised.
        """
        os.makedirs(path, exist_ok=True)
        data = {k: v.tolist() for k, v in self.splits.items()}
        final_path = os.path.join(path, name)
        tmp_path = f"{final_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(data, f)
            os.replace(tmp_path, final_path)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def to_dict(self) -> Dict[str, torch.Tensor]:
        """Return the underlying name-to-indices dictionary (not a copy)."""
        return self.splits

    @staticmethod
    def exists(path: str, name: str) -> bool:
        """Check if split file exists."""
        return os.path.exists(os.path.join(path, name))
def resolve_config(config: Union[dict, str]) -> dict:
    """Turn a benchmark ``config`` argument into a configuration dict.

    Parameters
    ----------
    config : Union[dict, str]
        A ready-made config dict (returned unchanged), a key of
        :data:`quanda.benchmarks.resources.config_map.config_map`, or the
        path of a benchmark YAML file.

    Returns
    -------
    dict
        The resolved benchmark configuration.

    Raises
    ------
    TypeError
        If ``config`` is neither a ``dict`` nor a ``str``, or if the YAML
        file does not parse to a dict.

    """
    if isinstance(config, dict):
        return config

    if not isinstance(config, str):
        raise TypeError(
            f"config must be a dict, a registered bench_id, or a YAML path; "
            f"got {type(config).__name__}."
        )

    # Lazy import to avoid a hard dep from utils → benchmarks.
    from quanda.benchmarks.resources.config_map import config_map

    # A registered bench_id wins; anything else is treated as a file path.
    if config in config_map:
        path = str(config_map[config])
    else:
        path = config

    with open(path, "r") as stream:
        loaded = yaml.safe_load(stream)

    if isinstance(loaded, dict):
        return loaded
    raise TypeError(
        f"YAML at {path} did not parse to a dict (got "
        f"{type(loaded).__name__})."
    )
def chunked_logits(
    model: torch.nn.Module,
    inputs: Any,
    chunk_size: Optional[int],
) -> torch.Tensor:
    """Run ``model`` forward on ``inputs`` in chunks; return logits tensor.

    The forward passes run under ``torch.no_grad()``. Dict inputs are passed
    to the model as keyword arguments, anything else positionally. If the
    model output has a ``logits`` attribute, that attribute is used.

    Parameters
    ----------
    model : torch.nn.Module
        The model to evaluate.
    inputs : Any
        Either a sliceable object with a ``shape`` (e.g. a tensor) or a dict
        of such objects; chunking slices along the first dimension. For
        dicts the batch size is read from the first value.
    chunk_size : Optional[int]
        Maximum number of samples per forward pass. ``None`` runs a single
        forward pass over all inputs.

    Returns
    -------
    torch.Tensor
        The outputs of all chunks concatenated along dimension 0.

    Raises
    ------
    ValueError
        If ``chunk_size`` is given but not positive.

    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer or None, got {chunk_size}."
        )

    def _call(batch: Any) -> torch.Tensor:
        out = model(**batch) if isinstance(batch, dict) else model(batch)
        return out.logits if hasattr(out, "logits") else out

    with torch.no_grad():
        if chunk_size is None:
            return _call(inputs)

        if isinstance(inputs, dict):
            total = next(iter(inputs.values())).shape[0]
        else:
            total = inputs.shape[0]

        # Everything fits into one chunk. This also covers an empty batch,
        # for which chunking would produce nothing to concatenate.
        if total <= chunk_size:
            return _call(inputs)

        outputs = []
        for start in range(0, total, chunk_size):
            stop = start + chunk_size
            if isinstance(inputs, dict):
                chunk: Any = {k: v[start:stop] for k, v in inputs.items()}
            else:
                chunk = inputs[start:stop]
            outputs.append(_call(chunk))
        return torch.cat(outputs, dim=0)
def get_parent_module_from_name(
    model: torch.nn.Module, layer_name: str
) -> Any:
    """Return the parent of the module addressed by a dotted name.

    Parameters
    ----------
    model : torch.nn.Module
        The model to start the attribute lookup from.
    layer_name : str
        Dot-separated name of a module inside ``model``.

    Returns
    -------
    Any
        The object reached by following every component of ``layer_name``
        except the last one; ``model`` itself if there is only one component.

    """
    *parent_path, _ = layer_name.split(".")
    parent = model
    for attribute in parent_path:
        parent = getattr(parent, attribute)
    return parent
def resolve_device(
    model: torch.nn.Module, device: Optional[str] = None
) -> str:
    """Pick the device to use: the explicit one, else the model's.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose first parameter determines the device when ``device`` is
        not given.
    device : Optional[str], optional
        Explicit device; returned unchanged when not ``None``.

    Returns
    -------
    str
        ``device`` if it was provided, otherwise the device of the model's
        first parameter as a string, or ``"cpu"`` for a model without
        parameters.

    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            return "cpu"
        return str(first_param.device)
    return device
def make_func(
    func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, **kwargs
) -> functools.partial:
    """Bind keyword arguments to ``func`` and return the partial.

    Parameters
    ----------
    func : Callable
        The function to wrap.
    func_kwargs : Optional[Mapping[str, Any]]
        Keyword arguments to bind; these take precedence over ``kwargs``
        when the same name appears in both.
    kwargs : Any
        Further keyword arguments to bind.

    Returns
    -------
    functools.partial
        ``func`` with the merged keyword arguments pre-filled.

    """
    if func_kwargs is None:
        bound = kwargs
    else:
        # Later entries win, so func_kwargs overrides kwargs.
        bound = {**kwargs, **func_kwargs}
    return functools.partial(func, **bound)
def cache_result(method):
    """Decorate a method so its first result is stored on the instance.

    The value is kept in the instance ``__dict__`` under
    ``_<method name>_cache``. The call arguments are not part of the cache
    key: once a value is stored, it is returned for every later call on
    that instance.
    """
    cache_attr = "_" + method.__name__ + "_cache"

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if cache_attr in self.__dict__:
            return self.__dict__[cache_attr]
        value = method(self, *args, **kwargs)
        self.__dict__[cache_attr] = value
        return value

    return wrapper
def class_accuracy(
    net: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device] = "cpu",
    single_class: Optional[int] = None,
) -> float:
    """Return accuracy on a dataset given by the data loader.

    Batches may be ``(inputs, targets)`` pairs or dicts containing a
    ``"labels"`` entry; for dict batches the remaining entries are passed to
    ``net`` as keyword arguments and a ``logits`` attribute on the output is
    used if present. Predictions are the argmax over dimension 1. The
    evaluation runs under ``torch.no_grad()``; the train/eval mode of ``net``
    is left untouched.

    Parameters
    ----------
    net : torch.nn.Module
        The model to evaluate.
    loader : torch.utils.data.DataLoader
        The data loader to evaluate the model on.
    device : Union[str, torch.device], optional
        The device to evaluate the model on, by default "cpu".
    single_class : Optional[int], optional
        If provided, all targets will be set to this class, by default None.

    Returns
    -------
    float
        Fraction of correctly classified samples; ``0.0`` if the loader is
        empty or yields no samples.

    """
    if len(loader) == 0:
        return 0.0

    correct = 0
    total = 0
    # Gradients are never needed for an accuracy computation.
    with torch.no_grad():
        for batch in loader:
            if isinstance(batch, dict):
                # Read (rather than pop) the labels so the caller's batch
                # object is left intact and can be iterated again.
                targets = batch["labels"].to(device)
                inputs = {
                    k: v.to(device) for k, v in batch.items() if k != "labels"
                }
                outputs = net(**inputs)
                if hasattr(outputs, "logits"):
                    outputs = outputs.logits
            else:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

            if single_class is not None:
                targets = torch.full_like(targets, single_class)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        return 0.0
    return correct / total
# Taken directly from Captum with minor changes def _load_flexible_state_dict( model: torch.nn.Module, path: str, device: Union[str, torch.device] ) -> float: """Load pytorch models. This function attempts to find compatibility for loading models that were trained on different devices / with DataParallel but are being loaded in a different environment. Assumes that the model has been saved as a state_dict in some capacity. This can either be a single state dict, or a nesting dictionary which contains the model state_dict and other information. Parameters ---------- model : torch.nn.Module The model for which to load a checkpoint path : str The filepath to the checkpoint device : Union[str, torch.device] The device to use. Returns ------- float The learning rate. Notes ----- The module state_dict is modified in-place. """ if isinstance(device, str): device = torch.device(device) last_err: Optional[Exception] = None for attempt in range(3): try: checkpoint = torch.load( path, map_location=device, weights_only=True ) break except RuntimeError as e: last_err = e time.sleep(2**attempt) else: raise RuntimeError( f"Failed to load checkpoint at {path} after 3 attempts" ) from last_err learning_rate = checkpoint.get("learning_rate", 1.0) if "module." in next(iter(checkpoint)): if isinstance(model, torch.nn.DataParallel): model.load_state_dict(checkpoint) else: model = torch.nn.DataParallel(model) model.load_state_dict(checkpoint) model = model.module else: if isinstance(model, torch.nn.DataParallel): model = model.module model.load_state_dict(checkpoint) model = torch.nn.DataParallel(model) else: model.load_state_dict(checkpoint) return learning_rate
def get_load_state_dict_func(
    device: Union[str, torch.device],
) -> Callable[[torch.nn.Module, str], float]:
    """Build a checkpoint loader bound to ``device``.

    Parameters
    ----------
    device : Union[str, torch.device]
        The device to load the model on.

    Returns
    -------
    Callable[[torch.nn.Module, str], float]
        A function taking ``(model, path)`` that forwards to
        ``_load_flexible_state_dict`` with the given ``device`` and returns
        its result.

    """

    def load_state_dict(model: torch.nn.Module, path: str) -> float:
        return _load_flexible_state_dict(model=model, path=path, device=device)

    return load_state_dict
@contextmanager
def default_tensor_type(device: Union[str, torch.device]):
    """Context manager to temporarily change the default tensor type.

    Inside the context the default tensor type is the float tensor type of
    ``device`` and ``torch.device(device)`` is active as a device context.
    For CUDA devices, ``torch.cuda.set_device(device)`` is called as well;
    that selection is not undone on exit.

    Parameters
    ----------
    device : Union[str, torch.device]
        The device to which the default tensor type should be set.

    Returns
    -------
    None

    """
    # Type name to restore on exit.
    # NOTE(review): this is derived from a CPU ``torch.FloatTensor``, so it
    # does not reflect a non-standard default that was active before entry.
    float_tensor = torch.FloatTensor([0.0])
    original_tensor_type = float_tensor.type()

    new_float_tensor = float_tensor.to(device)
    new_tensor_type = new_float_tensor.type()

    # Set the new tensor type
    torch.set_default_tensor_type(new_tensor_type)
    # From here on the global default is modified, so every remaining step
    # sits inside the try block: a failure while selecting the CUDA device
    # or entering the device context must not leave the default changed.
    try:
        device_type = (
            device.type if isinstance(device, torch.device) else device
        )
        if "cuda" in device_type:
            torch.cuda.set_device(device)

        with torch.device(device):
            # Yield control back to the calling context
            yield
    finally:
        # Restore the original tensor type
        torch.set_default_tensor_type(original_tensor_type)
@contextmanager
def map_location_context(device: Union[str, torch.device]):
    """Temporarily force the ``map_location`` used by ``torch.load``.

    While the context is active, ``torch.load`` is replaced by a wrapper
    that always passes ``map_location=device`` (overriding any value given
    by the caller) and defaults ``weights_only`` to ``True`` unless the
    caller specifies it. The previous ``torch.load`` is put back on exit,
    also when the body raises.

    Parameters
    ----------
    device: Union[str, torch.device]
        The device to which the tensors should be loaded.

    Returns
    -------
    None

    """
    saved_load = torch.load

    def load_with_map_location(f, *args, **kwargs):
        forced = dict(kwargs)
        forced["map_location"] = device
        if "weights_only" not in forced:
            forced["weights_only"] = True
        return saved_load(f, *args, **forced)

    # Swap in the wrapper for the duration of the ``with`` body.
    torch.load = load_with_map_location
    try:
        yield
    finally:
        torch.load = saved_load
def ds_len(dataset: torch.utils.data.Dataset) -> int:
    """Return the number of items in ``dataset``.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset to measure.

    Returns
    -------
    int
        ``len(dataset)`` when the dataset defines a length; otherwise the
        length of a ``DataLoader`` over it with ``batch_size=1``.

    """
    if not isinstance(dataset, Sized):
        # No __len__ on the dataset: ask a batch-size-1 loader instead.
        return len(torch.utils.data.DataLoader(dataset, batch_size=1))
    return len(dataset)
def process_targets(
    targets: Union[List[int], torch.Tensor], device: Union[str, torch.device]
) -> torch.Tensor:
    """Turn target labels into a tensor on ``device``.

    Parameters
    ----------
    targets : Union[List[int], torch.Tensor]
        The target labels. A list is converted with ``torch.tensor``;
        ``None`` is passed through unchanged.
    device: Union[str, torch.device]
        The device to move the targets to.

    Returns
    -------
    torch.Tensor or None
        The targets on ``device``, or ``None`` if ``targets`` is ``None``.

    """
    if targets is None:
        return targets
    as_tensor = torch.tensor(targets) if isinstance(targets, list) else targets
    return as_tensor.to(device)
def get_targets(item: Union[tuple, dict]) -> int:
    """Pull the target out of a single dataset item.

    Parameters
    ----------
    item : Union[tuple, dict]
        A tuple whose second element is the target, or a dict holding the
        target under the ``'labels'`` key.

    Returns
    -------
    int
        The target value.

    Raises
    ------
    ValueError
        If ``item`` is a dict without a ``'labels'`` key, or is neither a
        tuple nor a dict.

    """
    if isinstance(item, tuple):
        return item[1]

    if not isinstance(item, dict):
        raise ValueError(
            f"Unsupported dataset item type: {type(item)}. "
            "Expected tuple (data, target) or dict with 'labels' key."
        )

    if "labels" not in item:
        raise ValueError(
            f"Dataset item missing required 'labels' key: {item}."
        )
    return item["labels"]
def move_ds_item_to_device(
    data: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Union[str, torch.device],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Move a tensor, or every value of a dict of tensors, to ``device``.

    Parameters
    ----------
    data : Union[torch.Tensor, Dict[str, torch.Tensor]]
        The data to move.
    device: Union[str, torch.device]
        The target device.

    Returns
    -------
    Union[torch.Tensor, Dict[str, torch.Tensor]]
        For a dict, a new dict with the same keys and moved values;
        otherwise the result of ``data.to(device)``.

    """
    if not isinstance(data, dict):
        return data.to(device)

    moved = {}
    for key, value in data.items():
        moved[key] = value.to(device)
    return moved
def load_last_checkpoint(
    model: torch.nn.Module,
    checkpoints: List[str],
    checkpoints_load_func: CheckpointLoadFunc,
):
    """Load the last checkpoint of ``checkpoints`` into ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load the checkpoint into.
    checkpoints : List[str]
        Paths to the model checkpoint files. Nothing happens if the list is
        empty.
    checkpoints_load_func : CheckpointLoadFunc
        Function to load the model from the checkpoint file, takes
        (model, checkpoint path) as two arguments.

    """
    if len(checkpoints) > 0:
        # Only the final entry is loaded.
        checkpoints_load_func(model, checkpoints[-1])
# Matches the tail of Python's default ``<... object at 0x...>`` repr.
_DEFAULT_REPR_RE = re.compile(r" object at 0x[0-9a-fA-F]+>")


def stable_repr(obj: Any) -> str:
    """Process-stable string form of ``obj`` for hashing/serialization.

    Callables that have a ``__qualname__`` are rendered as
    ``module.qualname``. Other objects use ``str(obj)`` unless that string
    contains a default ``object at 0x...`` address; in that case the
    fully-qualified class name is used, followed by the JSON-encoded
    (key-sorted) instance ``__dict__`` in parentheses when it is non-empty.
    Values JSON cannot encode are rendered recursively with this function.
    """

    def _dotted_name(owner: Any) -> str:
        # "<module>.<qualname>", or only the qualname without a module.
        module_name = getattr(owner, "__module__", "") or ""
        if module_name:
            return f"{module_name}.{owner.__qualname__}"
        return owner.__qualname__

    if callable(obj) and hasattr(obj, "__qualname__"):
        return _dotted_name(obj)

    text = str(obj)
    if _DEFAULT_REPR_RE.search(text) is None:
        return text

    class_name = _dotted_name(type(obj))
    state = getattr(obj, "__dict__", None)
    if not state:
        return class_name
    encoded = json.dumps(state, sort_keys=True, default=stable_repr)
    return f"{class_name}({encoded})"
def subsample_indices(n: int, max_n: Optional[int], seed: int) -> List[int]:
    """Deterministic subsample of ``range(n)`` matching ``_subsample_dataset``.

    When ``max_n`` is ``None`` or at least ``n``, all indices ``0..n-1`` are
    returned as a list. Otherwise ``max_n`` indices are drawn with
    ``random.Random(seed).sample`` and returned in ascending order.
    """
    if max_n is not None and max_n < n:
        rng = random.Random(seed)
        chosen = rng.sample(range(n), max_n)
        chosen.sort()
        return chosen
    return list(range(n))
def subsample_dataset(
    dataset: torch.utils.data.Dataset,
    max_n: Optional[int],
    seed: int,
) -> torch.utils.data.Dataset:
    """Deterministically reduce ``dataset`` to at most ``max_n`` items.

    The dataset itself is returned when ``max_n`` is ``None`` or not smaller
    than the dataset length. Otherwise the indices to keep come from
    ``subsample_indices``. Datasets exposing a ``filtered`` attribute have
    it called with those indices; all others are wrapped in
    ``torch.utils.data.Subset``.
    """
    if max_n is not None:
        size = len(dataset)  # type: ignore[arg-type]
        if max_n < size:
            keep = subsample_indices(size, max_n, seed)
            if hasattr(dataset, "filtered"):
                return dataset.filtered(keep)
            return torch.utils.data.Subset(dataset, keep)
    return dataset
def replace_conv1d_with_linear(model: nn.Module) -> None:
    """Swap HF ``Conv1D`` modules in-place with ``nn.Linear`` equivalents.

    HF GPT-2 uses ``Conv1D`` (a transposed Linear) for attention/MLP
    projections; kronfluence only wraps ``nn.Linear`` and ``nn.Conv2d``, so
    the swap is required before ``prepare_model`` for those models.

    Modules are recognised by class name (``"Conv1D"``) and the model is
    traversed recursively. Each replacement is created on the device and in
    the dtype of the original weight, so the converted model does not end up
    with mixed precisions. The ``Conv1D`` weight is read as
    ``(in_features, out_features)`` and transposed into the ``nn.Linear``
    layout.

    Parameters
    ----------
    model : nn.Module
        The model to modify in place.

    """
    # Snapshot the children: modules are re-assigned while looping.
    for name, module in list(model.named_children()):
        if len(list(module.children())) > 0:
            replace_conv1d_with_linear(module)

        if module.__class__.__name__ == "Conv1D":
            weight = cast(torch.Tensor, module.weight)
            bias = cast(Optional[torch.Tensor], getattr(module, "bias", None))
            new_module = nn.Linear(
                in_features=weight.shape[0],
                out_features=weight.shape[1],
                bias=bias is not None,
                device=weight.device,
                dtype=weight.dtype,
            )
            with torch.no_grad():
                new_module.weight.copy_(weight.t())
                if bias is not None:
                    new_module.bias.copy_(bias)
            setattr(model, name, new_module)