quanda.utils.common module
Common utility functions for the Quanda package.
- class quanda.utils.common.DatasetSplit(splits: Dict[str, Tensor])
Bases: ABC
Class to store dynamically named splits (e.g., train, val, test).
- __init__(splits: Dict[str, Tensor])
Create a DatasetSplit from a dictionary of indices.
- Parameters:
splits (Dict[str, torch.Tensor]) – A dictionary mapping each split name to a tensor of indices.
- Returns:
An object with a single split named ‘default’.
- Return type:
DatasetSplit
- classmethod load(path: str, name: str) → DatasetSplit
Load the split from disk.
- classmethod split(n_indices: int, seed: int, split_ratios: Dict[str, float]) → DatasetSplit
Split the indices into named sets based on split_ratios.
- Parameters:
n_indices (int) – Total number of indices to split.
seed (int) – Random seed for reproducibility.
split_ratios (Dict[str, float]) – A dictionary where keys are split names (e.g., ‘train’, ‘val’, ‘test’) and values are the ratios for each split.
- Returns:
An object with keys corresponding to split_ratios.
- Return type:
DatasetSplit
- splits: Dict[str, Tensor]
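The ratio-based splitting described for split can be illustrated with a minimal stdlib sketch. This is not the Quanda implementation: split_indices is a hypothetical helper, it returns plain lists rather than tensors, and both the shuffle-then-cut strategy and the rule that the last split takes the remainder are assumptions.

```python
import random
from typing import Dict, List


def split_indices(n_indices: int, seed: int, split_ratios: Dict[str, float]) -> Dict[str, List[int]]:
    """Hypothetical helper: shuffle range(n_indices) and cut it by ratio."""
    if abs(sum(split_ratios.values()) - 1.0) > 1e-8:
        raise ValueError("split_ratios must sum to 1")

    indices = list(range(n_indices))
    random.Random(seed).shuffle(indices)  # seeded, so the shuffle is repeatable

    splits: Dict[str, List[int]] = {}
    names = list(split_ratios)
    start = 0
    for i, name in enumerate(names):
        # Assumption: the last split takes the remainder so that no index
        # is lost to rounding.
        is_last = i == len(names) - 1
        end = n_indices if is_last else start + int(n_indices * split_ratios[name])
        splits[name] = indices[start:end]
        start = end
    return splits


parts = split_indices(10, seed=0, split_ratios={"train": 0.6, "val": 0.2, "test": 0.2})
```

Each index lands in exactly one named split, and calling the helper again with the same seed reproduces the same assignment.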
- quanda.utils.common.chunked_logits(model: Module, inputs: Any, chunk_size: int | None) → Tensor
Run model forward on inputs in chunks; return the logits tensor.
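The chunking pattern can be sketched without torch. Here chunked_apply is a hypothetical stand-in that works on plain sequences, and treating chunk_size=None as "process everything in one pass" is an assumption, not documented behaviour.

```python
from typing import Callable, List, Optional, Sequence


def chunked_apply(
    fn: Callable[[Sequence[float]], List[float]],
    inputs: Sequence[float],
    chunk_size: Optional[int],
) -> List[float]:
    """Hypothetical helper: apply fn to inputs slice by slice and concatenate."""
    if chunk_size is None:
        # Assumption: None means a single pass over all inputs.
        return list(fn(inputs))
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    outputs: List[float] = []
    for start in range(0, len(inputs), chunk_size):
        # Only one slice is passed to fn at a time, which bounds peak memory.
        outputs.extend(fn(inputs[start:start + chunk_size]))
    return outputs


def double(batch: Sequence[float]) -> List[float]:
    return [2 * x for x in batch]


chunked = chunked_apply(double, [1, 2, 3, 4, 5], chunk_size=2)
unchunked = chunked_apply(double, [1, 2, 3, 4, 5], chunk_size=None)
```

For a function that treats each input independently, the chunked and unchunked results are identical; only peak memory differs.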
- quanda.utils.common.class_accuracy(net: Module, loader: DataLoader, device: str | device = 'cpu', single_class: int | None = None)
Return accuracy on a dataset given by the data loader.
- Parameters:
net (torch.nn.Module) – The model to evaluate.
loader (torch.utils.data.DataLoader) – The data loader to evaluate the model on.
device (Union[str, torch.device], optional) – The device to evaluate the model on, by default “cpu”.
single_class (Optional[int], optional) – If provided, all targets will be set to this class, by default None.
- Return type:
float
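The effect of single_class can be shown on plain lists of predicted and true labels. This is a sketch under assumptions: accuracy is a hypothetical helper, and the model forward pass and data loader are left out entirely.

```python
from typing import List, Optional


def accuracy(predictions: List[int], targets: List[int], single_class: Optional[int] = None) -> float:
    """Hypothetical helper: fraction of predictions that match the targets."""
    if single_class is not None:
        # Every target is replaced by single_class, so the result is the
        # fraction of samples predicted as that class.
        targets = [single_class] * len(predictions)
    correct = sum(int(p == t) for p, t in zip(predictions, targets))
    return correct / len(predictions)


plain = accuracy([0, 1, 1, 2], [0, 1, 2, 2])
forced = accuracy([1, 1, 0, 1], [0, 0, 0, 0], single_class=1)
```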
- quanda.utils.common.default_tensor_type(device: str | device)
Context manager to temporarily change the default tensor type.
- Parameters:
device (Union[str, torch.device]) – The device to which the default tensor type should be set.
- Return type:
None
- quanda.utils.common.ds_len(dataset: Dataset) → int
Get the length of the dataset.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to get the length of.
- Returns:
The length of the dataset.
- Return type:
int
- quanda.utils.common.get_load_state_dict_func(device: str | device) → Callable[[Module, str], float]
Get a load_state_dict function that loads a model state dict.
- Parameters:
device (Union[str, torch.device]) – The device to load the model on.
- quanda.utils.common.get_parent_module_from_name(model: Module, layer_name: str) → Any
Get the parent module of a module in a model by name.
- Parameters:
model (torch.nn.Module) – The model to extract the module from.
layer_name (str) – The name of the module to extract.
- Returns:
The parent module of the module named by layer_name.
- Return type:
Any
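One way to resolve a parent by name is to walk every component of the name except the last. The sketch below assumes layer_name is dot-separated (as in the names yielded by torch.nn.Module.named_modules) and uses SimpleNamespace objects as stand-ins for modules; parent_from_name is a hypothetical helper, not the library function.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def parent_from_name(root: Any, dotted_name: str) -> Any:
    """Hypothetical helper: follow all but the last component of dotted_name."""
    parent_path = dotted_name.split(".")[:-1]
    # With an empty path (a top-level name) reduce returns root itself.
    return reduce(getattr, parent_path, root)


# Stand-in for a nested model: model.encoder.layer.fc
model = SimpleNamespace(
    encoder=SimpleNamespace(layer=SimpleNamespace(fc="fc-module"))
)

parent = parent_from_name(model, "encoder.layer.fc")
top_level_parent = parent_from_name(model, "encoder")
```

Holding the parent rather than the module itself is what allows a child to be swapped with setattr(parent, child_name, new_module).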
- quanda.utils.common.get_targets(item: tuple | dict) → int
Extract targets from dataset item.
- Parameters:
item (Union[tuple, dict]) – Dataset item which can be either a tuple (data, target) or a dict with ‘labels’ key.
- Returns:
The target value.
- Return type:
int
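The two item layouts named above can be handled as in the following sketch. extract_target is a hypothetical name, and the sketch does no validation beyond what the parameter description states.

```python
from typing import Any, Dict, Tuple, Union


def extract_target(item: Union[Tuple[Any, int], Dict[str, Any]]) -> int:
    """Hypothetical helper: read the target from a tuple or dict item."""
    if isinstance(item, dict):
        # Dict-style items carry the target under the 'labels' key.
        return item["labels"]
    # Tuple-style items are (data, target).
    return item[1]


from_tuple = extract_target(("image-bytes", 3))
from_dict = extract_target({"input_ids": [101, 102], "labels": 1})
```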
- quanda.utils.common.load_last_checkpoint(model: Module, checkpoints: List[str], checkpoints_load_func: Callable[[Module, str], Any])
Load the model from the checkpoint file.
- Parameters:
model (torch.nn.Module) – The model to load the checkpoint into.
checkpoints (List[str]) – Paths to the model checkpoint files.
checkpoints_load_func (Callable[[torch.nn.Module, str], Any]) – Function that loads the model from a checkpoint file; it takes (model, checkpoint path) as its two arguments.
- quanda.utils.common.make_func(func: Callable, func_kwargs: Mapping[str, Any] | None = None, **kwargs) → partial
Create a partial function with the given arguments.
- Parameters:
func (Callable) – The function to create a partial function from.
func_kwargs (Optional[Mapping[str, Any]]) – Optional keyword arguments to fix for the function.
kwargs (Any) – Additional keyword arguments for the function.
- Returns:
The partial function with the given arguments.
- Return type:
functools.partial
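A sketch of the same idea with functools.partial follows. make_partial is a hypothetical helper, and the merge order (explicit keyword arguments overriding entries in func_kwargs) is an assumption rather than documented behaviour.

```python
from functools import partial
from typing import Any, Callable, Mapping, Optional


def make_partial(func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> partial:
    """Hypothetical helper: fix keyword arguments of func."""
    merged = dict(func_kwargs or {})
    # Assumption: explicit kwargs win over func_kwargs on key collisions.
    merged.update(kwargs)
    return partial(func, **merged)


def scale(x: float, factor: float = 1, offset: float = 0) -> float:
    return x * factor + offset


scale_and_shift = make_partial(scale, {"factor": 3}, offset=1)
result = scale_and_shift(2)
```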
- quanda.utils.common.map_location_context(device: str | device)
Context manager to temporarily change the map_location of torch.load.
- Parameters:
device (Union[str, torch.device]) – The device to which the tensors should be loaded.
- Return type:
None
- quanda.utils.common.move_ds_item_to_device(data: Tensor | Dict[str, Tensor], device: str | device) → Tensor | Dict[str, Tensor]
Move test data to the device.
- Parameters:
data (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The data to process.
device (Union[str, torch.device]) – The device to use.
- Returns:
The data on the specified device.
- Return type:
Union[torch.Tensor, Dict[str, torch.Tensor]]
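The tensor-or-dict dispatch can be sketched with duck typing so that it runs without torch. FakeTensor is a stand-in that only records a device string, and move_item is a hypothetical helper; with real tensors the same .to(device) call would apply.

```python
from typing import Dict, Union


class FakeTensor:
    """Stand-in for torch.Tensor: it only records which device it is on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a new object rather than mutating in place.
        return FakeTensor(device)


def move_item(
    data: Union[FakeTensor, Dict[str, FakeTensor]], device: str
) -> Union[FakeTensor, Dict[str, FakeTensor]]:
    """Hypothetical helper: move a tensor, or every value of a dict, to device."""
    if isinstance(data, dict):
        return {key: value.to(device) for key, value in data.items()}
    return data.to(device)


single = move_item(FakeTensor(), "cuda:0")
batch = move_item({"input_ids": FakeTensor(), "attention_mask": FakeTensor()}, "cuda:0")
```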
- quanda.utils.common.process_targets(targets: List[int] | Tensor, device: str | device) → Tensor
Convert target labels to torch.Tensor and move them to the device.
- Parameters:
targets (Optional[Union[List[int], torch.Tensor]], optional) – The target labels, either as a list or tensor.
device (Union[str, torch.device]) – The device to use.
- Returns:
The processed targets as a tensor, or None if no targets are provided.
- Return type:
torch.Tensor or None
- quanda.utils.common.replace_conv1d_with_linear(model: Module) → None
Swap HF Conv1D modules in-place with nn.Linear equivalents.
HF GPT-2 uses Conv1D (a transposed Linear) for attention/MLP projections; kronfluence only wraps nn.Linear and nn.Conv2d, so the swap is required before prepare_model for those models.
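The "transposed Linear" relationship can be checked numerically with nested lists. HF Conv1D stores its weight as (in_features, out_features) and computes x @ W + b, while nn.Linear stores (out_features, in_features) and computes x @ W.T + b, so transposing the weight gives an equivalent layer. The helper names below are hypothetical and the sketch does not perform the in-place module swap itself.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def conv1d_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # Conv1D convention: weight has shape (in_features, out_features).
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, weight)]


def linear_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # Linear convention: weight has shape (out_features, in_features).
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, transpose(weight))]


x = [[1.0, 2.0, 3.0]]                                # one sample, 3 input features
conv_weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]   # shape (3, 2)
bias = [0.5, -0.5]

conv_out = conv1d_forward(x, conv_weight, bias)
linear_out = linear_forward(x, transpose(conv_weight), bias)  # transposed weight
```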
- quanda.utils.common.resolve_config(config: dict | str) → dict
Resolve a benchmark config into a dict.
- Parameters:
config (Union[dict, str]) – Either a config dict (passed through unchanged), a registered bench_id (resolved via quanda.benchmarks.resources.config_map.config_map), or a path to a benchmark YAML file.
- Returns:
The resolved benchmark configuration.
- Return type:
dict
- Raises:
TypeError – If config is not a dict or str, or if the loaded YAML does not parse to a mapping.
- quanda.utils.common.resolve_device(model: Module, device: str | None = None) → str
Return device if set, else infer from model’s parameters.
Falls back to "cpu" when the model has no parameters. This is the canonical device-resolution used by every explainer wrapper so that callers can opt out of explicit device passing without silently ending up on the wrong device.
- quanda.utils.common.stable_repr(obj: Any) → str
Process-stable string form of obj for hashing/serialization.
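Process stability matters because Python's built-in hash() of strings is salted per interpreter run and set iteration order can differ between runs. The sketch below shows one possible approach, which is not necessarily what stable_repr does: canonicalise containers, serialise to JSON with sorted keys, and hash with SHA-256. All helper names are hypothetical, and only JSON-compatible leaves are supported.

```python
import hashlib
import json
from typing import Any


def canonical(obj: Any) -> Any:
    """Hypothetical helper: rewrite containers into an order-independent form."""
    if isinstance(obj, dict):
        return {str(k): canonical(v) for k, v in sorted(obj.items(), key=lambda kv: str(kv[0]))}
    if isinstance(obj, (set, frozenset)):
        # Sets have no defined order, so impose one.
        return sorted((canonical(v) for v in obj), key=repr)
    if isinstance(obj, (list, tuple)):
        return [canonical(v) for v in obj]
    return obj


def stable_string(obj: Any) -> str:
    return json.dumps(canonical(obj), sort_keys=True)


def stable_digest(obj: Any) -> str:
    # hashlib digests do not depend on PYTHONHASHSEED, unlike hash().
    return hashlib.sha256(stable_string(obj).encode("utf-8")).hexdigest()


text_a = stable_string({"b": 1, "a": {3, 1, 2}})
text_b = stable_string({"a": {2, 3, 1}, "b": 1})
digest = stable_digest({"b": 1, "a": {3, 1, 2}})
```

Two objects that differ only in insertion order yield the same string, so a digest of that string can serve as a cache key across runs.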
- quanda.utils.common.subsample_dataset(dataset: Dataset, max_n: int | None, seed: int) → Dataset
Deterministically subsample a dataset.
Subsampling is reproducible across platforms because it uses Python’s random.Random(seed).sample over range(N) and stores the indices in sorted order. For datasets that expose filtered (e.g. TransformedDataset), that is used instead of Subset so the subset preserves the original type and remaps any transform indices.
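The index selection described above can be sketched as follows. subsample_indices is a hypothetical helper that returns only the chosen indices; keeping every index when max_n is None or at least N is an assumption, and wrapping the result in Subset or filtered is left out.

```python
import random
from typing import List, Optional


def subsample_indices(n: int, max_n: Optional[int], seed: int) -> List[int]:
    """Hypothetical helper: pick at most max_n indices from range(n)."""
    if max_n is None or max_n >= n:
        # Assumption: nothing to subsample, keep every index.
        return list(range(n))
    # A dedicated Random instance leaves the global random state untouched;
    # sorting preserves the dataset's original ordering.
    return sorted(random.Random(seed).sample(range(n), max_n))


picked = subsample_indices(100, max_n=5, seed=42)
picked_again = subsample_indices(100, max_n=5, seed=42)
everything = subsample_indices(4, max_n=None, seed=42)
```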