quanda.utils.common module

Common utility functions for the Quanda package.

class quanda.utils.common.DatasetSplit(splits: Dict[str, Tensor])

Bases: ABC

Class to store dynamically named splits (e.g., train, val, test).

__init__(splits: Dict[str, Tensor])

Create a DatasetSplit from a dictionary of indices.

Parameters:

splits (Dict[str, torch.Tensor]) – A dictionary mapping each split name to a tensor of indices.

static exists(path: str, name: str) → bool

Check whether the split file exists.

classmethod load(path: str, name: str) → DatasetSplit

Load the split from disk.

save(path: str, name: str) → None

Save the split to disk atomically.
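Saving "atomically" usually means writing to a temporary file in the target directory and then renaming it over the final path, so a reader never sees a half-written file. The sketch below shows that pattern under stated assumptions: the JSON format, the `<name>.json` file name, and the helper `save_split_atomically` are all hypothetical; quanda's actual on-disk format is not documented here.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_split_atomically(path: str, name: str, splits: Dict[str, List[int]]) -> str:
    """Hypothetical helper: write splits to <path>/<name>.json atomically."""
    os.makedirs(path, exist_ok=True)
    target = os.path.join(path, f"{name}.json")
    # Create the temp file in the same directory so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=path, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(splits, f)
            f.flush()
            os.fsync(f.fileno())
        # os.replace swaps the finished file into place in a single step.
        os.replace(tmp_path, target)
    except BaseException:
        # Clean up the partial temp file; the target is left untouched.
        os.remove(tmp_path)
        raise
    return target
```

If the process dies mid-write, only the temporary file is affected and any existing split file keeps its previous contents.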

classmethod split(n_indices: int, seed: int, split_ratios: Dict[str, float]) → DatasetSplit

Split the indices into named sets based on split_ratios.

Parameters:
  • n_indices (int) – Total number of indices to split.

  • seed (int) – Random seed for reproducibility.

  • split_ratios (Dict[str, float]) – A dictionary where keys are split names (e.g., ‘train’, ‘val’, ‘test’) and values are the ratios for each split.

Returns:

An object whose split names are the keys of split_ratios.

Return type:

DatasetSplit
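A seeded ratio split can be sketched in plain Python as follows. This is an illustration, not quanda's implementation: the function name `split_indices` is hypothetical, it returns lists rather than tensors, and it assumes the ratios sum to 1, with the last split absorbing any rounding remainder.

```python
import random
from typing import Dict, List


def split_indices(n_indices: int, seed: int, split_ratios: Dict[str, float]) -> Dict[str, List[int]]:
    """Hypothetical sketch: partition range(n_indices) into disjoint named splits."""
    indices = list(range(n_indices))
    # A dedicated, seeded RNG makes the split reproducible and independent of global state.
    random.Random(seed).shuffle(indices)
    names = list(split_ratios)
    splits: Dict[str, List[int]] = {}
    start = 0
    for i, name in enumerate(names):
        # Assumption: the last split takes the remainder so every index is assigned exactly once.
        if i == len(names) - 1:
            end = n_indices
        else:
            end = start + int(split_ratios[name] * n_indices)
        splits[name] = indices[start:end]
        start = end
    return splits
```

For example, `split_indices(10, 0, {"train": 0.6, "val": 0.2, "test": 0.2})` yields splits of sizes 6, 2 and 2, and repeating the call with the same seed gives identical splits.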

splits: Dict[str, Tensor]
to_dict() → Dict[str, Tensor]

Return the splits as a dictionary.

quanda.utils.common.cache_result(method)

Decorator that caches the results of a method.
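A per-instance caching decorator might look like the sketch below. The docs do not say how quanda keys its cache, so this hypothetical `cache_result_sketch` assumes one cached value per instance and ignores call arguments.

```python
import functools
from typing import Any, Callable


def cache_result_sketch(method: Callable) -> Callable:
    """Hypothetical sketch: cache a method's return value once per instance."""
    attr = f"_cached_{method.__name__}"

    @functools.wraps(method)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        # Assumption: arguments are ignored; the first result is stored on the instance.
        if not hasattr(self, attr):
            setattr(self, attr, method(self, *args, **kwargs))
        return getattr(self, attr)

    return wrapper
```

Decorating an expensive method with it means the body runs only on the first call for each object; later calls return the stored value.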

quanda.utils.common.chunked_logits(model: Module, inputs: Any, chunk_size: int | None) → Tensor

Run the model's forward pass on inputs in chunks and return the logits as a tensor.
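Chunked inference bounds peak memory by feeding the model a slice of the batch at a time and concatenating the outputs. A minimal sketch, assuming `inputs` is a single tensor batched along dimension 0, that `chunk_size=None` means one forward pass, and that gradients are not needed; `chunked_logits_sketch` is a hypothetical name.

```python
from typing import Optional

import torch


def chunked_logits_sketch(
    model: torch.nn.Module, inputs: torch.Tensor, chunk_size: Optional[int]
) -> torch.Tensor:
    """Hypothetical sketch: forward inputs in chunks and concatenate the logits."""
    with torch.no_grad():
        if chunk_size is None:
            # Assumption: no chunk size means a single forward pass.
            return model(inputs)
        # torch.split yields consecutive slices of at most chunk_size rows along dim 0.
        outputs = [model(chunk) for chunk in torch.split(inputs, chunk_size)]
    return torch.cat(outputs, dim=0)
```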

quanda.utils.common.class_accuracy(net: Module, loader: DataLoader, device: str | device = 'cpu', single_class: int | None = None)

Return accuracy on a dataset given by the data loader.

Parameters:
  • net (torch.nn.Module) – The model to evaluate.

  • loader (torch.utils.data.DataLoader) – The data loader to evaluate the model on.

  • device (Union[str, torch.device], optional) – The device to evaluate the model on, by default “cpu”.

  • single_class (Optional[int], optional) – If provided, all targets will be set to this class, by default None.

Return type:

float
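The accuracy computation, including the single_class override, can be sketched as below. `class_accuracy_sketch` is hypothetical; it assumes the loader yields `(inputs, targets)` pairs and that predictions are the argmax over the last logit dimension.

```python
from typing import Optional, Union

import torch
from torch.utils.data import DataLoader


def class_accuracy_sketch(
    net: torch.nn.Module,
    loader: DataLoader,
    device: Union[str, torch.device] = "cpu",
    single_class: Optional[int] = None,
) -> float:
    """Hypothetical sketch: fraction of samples whose argmax prediction matches the target."""
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if single_class is not None:
                # Treat every sample as belonging to single_class.
                y = torch.full_like(y, single_class)
            preds = net(x).argmax(dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
    return correct / total
```

With `single_class` set, the result is the fraction of samples predicted as that class.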

quanda.utils.common.default_tensor_type(device: str | device)

Context manager to temporarily change the default tensor type.

Parameters:

device (Union[str, torch.device]) – The device to which the default tensor type should be set.

Return type:

None

quanda.utils.common.ds_len(dataset: Dataset) → int

Get the length of the dataset.

Parameters:

dataset (torch.utils.data.Dataset) – The dataset to get the length of.

Returns:

The length of the dataset.

Return type:

int

quanda.utils.common.get_load_state_dict_func(device: str | device) → Callable[[Module, str], float]

Get a load_state_dict function that loads a model state dict.

Parameters:

device (Union[str, torch.device]) – The device to load the model on.

quanda.utils.common.get_parent_module_from_name(model: Module, layer_name: str) → Any

Get the parent module of a module in a model by name.

Parameters:
  • model (torch.nn.Module) – The model containing the module.

  • layer_name (str) – The name of the module whose parent is returned.

Returns:

The parent module of the named module.

Return type:

Any
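Resolving a parent from a dotted module name amounts to walking every path component except the last. A minimal sketch, assuming PyTorch-style dotted names such as those returned by `named_modules()`; `get_parent_module_sketch` is hypothetical.

```python
import torch.nn as nn


def get_parent_module_sketch(model: nn.Module, layer_name: str) -> nn.Module:
    """Hypothetical sketch: return the module that directly contains layer_name."""
    parent = model
    # For "encoder.layers.0.fc", walk "encoder", "layers", "0"; "fc" is the child itself.
    for part in layer_name.split(".")[:-1]:
        parent = getattr(parent, part)
    return parent
```

Returning the parent, rather than the module itself, is what makes in-place replacement possible: `setattr(parent, child_name, new_module)`.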

quanda.utils.common.get_targets(item: tuple | dict) → int

Extract the target from a dataset item.

Parameters:

item (Union[tuple, dict]) – Dataset item, either a tuple (data, target) or a dict with a ‘labels’ key.

Returns:

The target value.

Return type:

int
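The two item layouts described above can be handled as in this sketch; `get_targets_sketch` is a hypothetical name, not quanda's code.

```python
from typing import Any, Union


def get_targets_sketch(item: Union[tuple, dict]) -> Any:
    """Hypothetical sketch: pull the target out of a tuple or dict dataset item."""
    if isinstance(item, dict):
        # Dict items (e.g. tokenized text) carry the target under "labels".
        return item["labels"]
    # Tuple items are laid out as (data, target).
    return item[1]
```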

quanda.utils.common.load_last_checkpoint(model: Module, checkpoints: List[str], checkpoints_load_func: Callable[[Module, str], Any])

Load the last checkpoint in checkpoints into the model.

Parameters:
  • model (torch.nn.Module) – The model to load the checkpoint into.

  • checkpoints (List[str]) – Paths to the model checkpoint files.

  • checkpoints_load_func (Callable[[torch.nn.Module, str], Any]) – Function that loads a checkpoint file into the model; it takes (model, checkpoint path) as its two arguments.
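A loader with the `(model, checkpoint path)` shape that checkpoints_load_func expects can be as small as the one below. Both functions are hypothetical sketches: `load_state_dict_into` assumes each file holds a plain state dict saved with `torch.save`, and `load_last_checkpoint_sketch` assumes only the final path is loaded.

```python
from typing import Any, Callable, List

import torch


def load_state_dict_into(model: torch.nn.Module, path: str) -> None:
    """Hypothetical (model, checkpoint path) loader for plain state-dict files."""
    model.load_state_dict(torch.load(path, map_location="cpu"))


def load_last_checkpoint_sketch(
    model: torch.nn.Module,
    checkpoints: List[str],
    checkpoints_load_func: Callable[[torch.nn.Module, str], Any],
) -> None:
    # Assumption: only the final entry of checkpoints is loaded.
    checkpoints_load_func(model, checkpoints[-1])
```

Passing the loader as a function keeps the checkpoint format pluggable: a Lightning or safetensors checkpoint only needs a different `(model, path)` callable.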

quanda.utils.common.make_func(func: Callable, func_kwargs: Mapping[str, Any] | None = None, **kwargs) → partial

Create a partial function with the given arguments.

Parameters:
  • func (Callable) – The function to create a partial function from.

  • func_kwargs (Optional[Mapping[str, Any]]) – Optional keyword arguments to fix for the function.

  • kwargs (Any) – Additional keyword arguments for the function.

Returns:

The partial function with the given arguments.

Return type:

functools.partial
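Building a partial from both a keyword mapping and extra keyword arguments can be sketched with `functools.partial`. The precedence rule here, where explicit `**kwargs` override keys repeated in `func_kwargs`, is an assumption; the docs do not specify it. `make_func_sketch` is hypothetical.

```python
import functools
from typing import Any, Callable, Mapping, Optional


def make_func_sketch(
    func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any
) -> functools.partial:
    """Hypothetical sketch: fix keyword arguments of func."""
    # Assumption: explicit **kwargs win over duplicate keys in func_kwargs.
    merged = {**(func_kwargs or {}), **kwargs}
    return functools.partial(func, **merged)
```

For instance, with `def scale(x, factor=1, offset=0): return x * factor + offset`, the call `make_func_sketch(scale, {"factor": 2}, offset=1)(3)` evaluates to 7.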

quanda.utils.common.map_location_context(device: str | device)

Context manager to temporarily change the map_location of torch.load.

Parameters:

device (Union[str, torch.device]) – The device to which the tensors should be loaded.

Return type:

None

quanda.utils.common.move_ds_item_to_device(data: Tensor | Dict[str, Tensor], device: str | device) → Tensor | Dict[str, Tensor]

Move a dataset item (a tensor or a dictionary of tensors) to the device.

Parameters:
  • data (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The data to process.

  • device (Union[str, torch.device]) – The device to use.

Returns:

The data on the specified device.

Return type:

Union[torch.Tensor, Dict[str, torch.Tensor]]
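Handling both item shapes takes one type check, as in this hypothetical sketch (`move_to_device_sketch` is not quanda's function).

```python
from typing import Dict, Union

import torch


def move_to_device_sketch(
    data: Union[torch.Tensor, Dict[str, torch.Tensor]], device: Union[str, torch.device]
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Hypothetical sketch: move a tensor, or every tensor in a dict, to device."""
    if isinstance(data, dict):
        # Keep the keys (e.g. input_ids, attention_mask) and move each value.
        return {k: v.to(device) for k, v in data.items()}
    return data.to(device)
```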

quanda.utils.common.process_targets(targets: List[int] | Tensor, device: str | device) → Tensor

Convert target labels to torch.Tensor and move them to the device.

Parameters:
  • targets (Optional[Union[List[int], torch.Tensor]], optional) – The target labels, either as a list or tensor.

  • device (Union[str, torch.device]) – The device to use.

Returns:

The processed targets as a tensor, or None if no targets are provided.

Return type:

torch.Tensor or None

quanda.utils.common.replace_conv1d_with_linear(model: Module) → None

Replace Hugging Face (HF) Conv1D modules in place with equivalent nn.Linear modules.

HF GPT-2 uses Conv1D (a transposed Linear) for attention/MLP projections; kronfluence only wraps nn.Linear and nn.Conv2d, so the swap is required before prepare_model for those models.
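The swap works because HF's Conv1D computes `x @ weight + bias` with `weight` stored as `(in_features, out_features)`, while nn.Linear stores `(out_features, in_features)`; copying the transposed weight gives an identical layer. To stay self-contained, the sketch below uses a hypothetical stand-in class `Conv1DStandIn` instead of importing transformers, and `replace_conv1d_sketch` is likewise hypothetical.

```python
import torch
import torch.nn as nn


class Conv1DStandIn(nn.Module):
    """Hypothetical stand-in for HF Conv1D: y = x @ weight + bias, weight is (in, out)."""

    def __init__(self, nf: int, nx: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias


def replace_conv1d_sketch(model: nn.Module, conv1d_cls: type = Conv1DStandIn) -> None:
    """Recursively swap conv1d_cls children for equivalent nn.Linear modules."""
    for name, child in list(model.named_children()):
        if isinstance(child, conv1d_cls):
            in_features, out_features = child.weight.shape
            linear = nn.Linear(
                in_features, out_features,
                device=child.weight.device, dtype=child.weight.dtype,
            )
            with torch.no_grad():
                # nn.Linear stores (out, in), so the Conv1D weight is transposed.
                linear.weight.copy_(child.weight.t())
                linear.bias.copy_(child.bias)
            setattr(model, name, linear)
        else:
            replace_conv1d_sketch(child, conv1d_cls)
```

After the swap the model's outputs are unchanged (up to floating-point rounding), but every projection is now an nn.Linear that kronfluence can wrap.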

quanda.utils.common.resolve_config(config: dict | str) → dict

Resolve a benchmark config into a dict.

Parameters:

config (Union[dict, str]) – Either a config dict (passed through unchanged), a registered bench_id (resolved via quanda.benchmarks.resources.config_map.config_map), or a path to a benchmark YAML file.

Returns:

The resolved benchmark configuration.

Return type:

dict

Raises:

TypeError – If config is not a dict or str, or if the loaded YAML does not parse to a mapping.

quanda.utils.common.resolve_device(model: Module, device: str | None = None) → str

Return device if set, else infer from model’s parameters.

Falls back to "cpu" when the model has no parameters. This is the canonical device-resolution used by every explainer wrapper so that callers can opt out of explicit device passing without silently ending up on the wrong device.
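The resolution order described above (explicit device, then the model's parameters, then "cpu") can be sketched as below. `resolve_device_sketch` is hypothetical, and it assumes all parameters live on the same device, so it inspects only the first one.

```python
from typing import Optional

import torch


def resolve_device_sketch(model: torch.nn.Module, device: Optional[str] = None) -> str:
    """Hypothetical sketch: explicit device first, then the model's parameters, then "cpu"."""
    if device is not None:
        return device
    # Assumption: all parameters share one device, so the first one is representative.
    param = next(model.parameters(), None)
    return "cpu" if param is None else str(param.device)
```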

quanda.utils.common.stable_repr(obj: Any) → str

Return a string form of obj that is stable across processes, for use in hashing and serialization.
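The plain `repr` of a set of strings can change between interpreter runs because string hashing is randomized per process, so a process-stable form has to impose its own ordering. The sketch below illustrates one way to do that; it is an assumption about the idea, not quanda's implementation, and `stable_repr_sketch` is hypothetical. It does not handle objects whose default repr embeds a memory address.

```python
from typing import Any


def stable_repr_sketch(obj: Any) -> str:
    """Hypothetical sketch: repr with sets and dict keys in a deterministic order."""
    if isinstance(obj, dict):
        # Sort entries by the stable form of their keys, not by insertion order.
        items = sorted((stable_repr_sketch(k), stable_repr_sketch(v)) for k, v in obj.items())
        return "{" + ", ".join(f"{k}: {v}" for k, v in items) + "}"
    if isinstance(obj, (set, frozenset)):
        # Set iteration order depends on hashing, so sort the element strings.
        return "{" + ", ".join(sorted(stable_repr_sketch(x) for x in obj)) + "}"
    if isinstance(obj, (list, tuple)):
        inner = ", ".join(stable_repr_sketch(x) for x in obj)
        return f"[{inner}]" if isinstance(obj, list) else f"({inner})"
    return repr(obj)
```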

quanda.utils.common.subsample_dataset(dataset: Dataset, max_n: int | None, seed: int) → Dataset

Deterministically subsample a dataset.

Subsampling is reproducible across platforms because it uses Python’s random.Random(seed).sample over range(N) and stores the indices in sorted order. For datasets that expose a filtered method (e.g. TransformedDataset), that method is used instead of Subset, so the subset preserves the original dataset type and remaps any transform indices.
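The dispatch between a dataset's own `filtered` method and a generic torch Subset can be sketched as follows. `subsample_dataset_sketch` is hypothetical; it assumes `filtered` accepts a list of indices, and that the dataset is returned unchanged when `max_n` is None or at least `len(dataset)`.

```python
import random
from typing import Optional

from torch.utils.data import Dataset, Subset


def subsample_dataset_sketch(dataset: Dataset, max_n: Optional[int], seed: int) -> Dataset:
    """Hypothetical sketch: seeded, sorted subsample that preserves the dataset type if possible."""
    n = len(dataset)
    if max_n is None or max_n >= n:
        # Assumption: no subsampling needed, so the dataset is returned as-is.
        return dataset
    indices = sorted(random.Random(seed).sample(range(n), max_n))
    if hasattr(dataset, "filtered"):
        # Assumption: filtered(indices) returns a dataset of the same type.
        return dataset.filtered(indices)
    return Subset(dataset, indices)
```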

quanda.utils.common.subsample_indices(n: int, max_n: int | None, seed: int) → List[int]

Deterministic subsample of range(n), matching the indices used by subsample_dataset.

Returns the full range(n) (as a list) when no subsampling is needed.
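Based on the mechanism described for subsample_dataset (a seeded random.Random sample over the index range, stored sorted), the index selection can be sketched as below. `subsample_indices_sketch` is hypothetical, and treating "no subsampling needed" as `max_n is None or max_n >= n` is an assumption.

```python
import random
from typing import List, Optional


def subsample_indices_sketch(n: int, max_n: Optional[int], seed: int) -> List[int]:
    """Hypothetical sketch: sorted, seeded subsample of range(n)."""
    if max_n is None or max_n >= n:
        # Assumption: no subsampling needed, so every index is returned.
        return list(range(n))
    # A dedicated random.Random(seed) is independent of the global RNG state.
    return sorted(random.Random(seed).sample(range(n), max_n))
```

Sorting the sample keeps the subset in the dataset's original order, which makes the result easier to compare and cache.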