Source code for quanda.utils.cache

"""Module for caching explanations."""

import glob
import os
from typing import Any, Optional, Union

import torch


class Cache:
    """Abstract interface for caches.

    Every method is a static method that raises ``NotImplementedError``;
    subclasses are expected to provide the actual implementations.
    """

    @staticmethod
    def save(*args, **kwargs) -> None:
        """Write an explanation to the cache.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.

        """
        raise NotImplementedError

    @staticmethod
    def load(*args, **kwargs) -> Any:
        """Read an explanation from the cache.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.

        """
        raise NotImplementedError

    @staticmethod
    def exists(*args, **kwargs) -> bool:
        """Report whether an explanation is present in the cache.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.

        """
        raise NotImplementedError
class BatchedCachedExplanations:
    """Utility class for lazily loading batched explanations from disk.

    The cache directory is scanned once at construction time for ``*.pt``
    files. Each file is indexed by its filename stem (its ``num_id``) and is
    only read from disk when it is requested through ``__getitem__``; the
    single exception is the first file, which is loaded during construction
    to determine ``batch_size``.
    """

    def __init__(
        self,
        cache_dir: str,
        device: Optional[str] = None,
    ):
        """Index the cached explanation files found in ``cache_dir``.

        Parameters
        ----------
        cache_dir: str
            Directory containing the cached explanations.
        device: Optional[str]
            Device to load the explanations on.

        Raises
        ------
        IndexError
            If ``cache_dir`` contains no ``*.pt`` files.

        """
        super().__init__()
        self.cache_dir = cache_dir
        self.device = device

        # Escape glob metacharacters in the directory itself so that a path
        # such as "runs/exp[1]" is matched literally rather than as a
        # character class (which would silently find no files).
        self.av_filesearch = os.path.join(glob.escape(cache_dir), "*.pt")

        # Index files by their num_id (filename stem). Numeric stems are
        # stored as ints so int lookups (e.g. batch index) work.
        # NOTE: stems that parse to the same int (e.g. "7" and "007") share
        # one key, so only one of those files remains reachable.
        found = {}
        for fl in glob.glob(self.av_filesearch):
            stem = os.path.splitext(os.path.basename(fl))[0]
            found[self._normalize_id(stem)] = fl

        # Integer ids first (ascending), then string ids (lexicographic).
        # Rebuilding the dict in that order makes ``keys()`` deterministic
        # and aligned with ``files`` instead of following glob order.
        ordered = sorted(found, key=lambda x: (isinstance(x, str), x))
        self._by_id: dict = {k: found[k] for k in ordered}
        self.files = list(self._by_id.values())

        if not self.files:
            # Same exception type the bare ``self.files[0]`` lookup would
            # raise, but with a message that says what is actually wrong.
            raise IndexError(
                f"No cached explanations (*.pt files) found in {cache_dir}."
            )

        self.batch_size = torch.load(
            self.files[0], map_location=self.device, weights_only=True
        ).shape[0]

    @staticmethod
    def _normalize_id(num_id: Union[int, str]) -> Union[int, str]:
        """Map a num_id or filename stem to the key used in the index."""
        if isinstance(num_id, str) and num_id.lstrip("-").isdigit():
            try:
                return int(num_id)
            except ValueError:
                # e.g. "--5" or "²": passes the digit check but is not an
                # integer literal, so it stays a string identifier.
                return num_id
        return num_id

    def keys(self):
        """Return the num_ids available in the cache.

        Returns
        -------
        list
            Integer ids in ascending order followed by string ids in
            lexicographic order (the same order as ``self.files``).

        """
        return list(self._by_id)

    def __getitem__(self, num_id: Union[int, str]) -> torch.Tensor:
        """Load the explanation tensor saved with the given ``num_id``.

        Parameters
        ----------
        num_id: Union[int, str]
            Identifier the tensor was saved under via
            :meth:`ExplanationsCache.save`. Numeric strings are treated
            like the corresponding integer, so ``cache["5"]`` and
            ``cache[5]`` both resolve to ``5.pt``.

        Returns
        -------
        torch.Tensor
            The explanation stored under ``num_id``.

        Raises
        ------
        KeyError
            If no file for ``num_id`` was found in the cache directory.

        """
        try:
            fl = self._by_id[self._normalize_id(num_id)]
        except KeyError:
            raise KeyError(
                f"num_id {num_id!r} not found in cache {self.cache_dir}."
            ) from None
        return torch.load(fl, map_location=self.device, weights_only=True)

    def __len__(self) -> int:
        """Get the number of explanations in the cache.

        Returns
        -------
        int
            Number of explanation files indexed in the cache.

        """
        return len(self.files)
class ExplanationsCache(Cache):
    """Class for caching generated explanations at a given path.

    Explanations are stored as ``<num_id>.pt`` files inside a directory.
    """

    @staticmethod
    def exists(
        path: str,
        num_id: Optional[Union[str, int]] = None,
    ) -> bool:
        """Check if the explanations exist at the given path.

        Parameters
        ----------
        path: str
            Path to the explanations.
        num_id: Optional[Union[str, int]]
            Number identifier for the explanations. If None, any ``*.pt``
            file in ``path`` counts.

        Returns
        -------
        bool
            True if the explanations exist, False otherwise.

        """
        if num_id is not None:
            # Direct file check: a glob here would interpret wildcard
            # characters in ``num_id`` or ``path`` (e.g. "*" or "[1]") as
            # a pattern instead of a literal name.
            return os.path.isfile(os.path.join(path, f"{num_id}.pt"))

        # Escape the directory so only the "*.pt" part acts as a pattern.
        # A missing directory simply yields no matches.
        pattern = os.path.join(glob.escape(path), "*.pt")
        return next(glob.iglob(pattern), None) is not None

    @staticmethod
    def save(
        path: str,
        exp_tensors: torch.Tensor,
        num_id: Union[str, int],
    ) -> None:
        """Save the explanations to the given path.

        The tensor is detached and moved to the CPU before being written to
        ``<path>/<num_id>.pt``. The directory is created if it is missing.

        Parameters
        ----------
        path: str
            Path to save the explanations.
        exp_tensors: torch.Tensor
            Explanations to save.
        num_id: Union[str, int]
            Number identifier for the explanations.

        Returns
        -------
        None

        """
        if path:
            # An empty path means "current directory"; os.makedirs("")
            # would raise, so only create non-empty paths.
            os.makedirs(path, exist_ok=True)
        av_save_fl_path = os.path.join(path, f"{num_id}.pt")
        torch.save(exp_tensors.detach().cpu(), av_save_fl_path)

    @staticmethod
    def load(
        path: str,
        device: Optional[str] = None,
    ) -> BatchedCachedExplanations:
        """Load the explanations from the given path.

        Parameters
        ----------
        path: str
            Path to load the explanations.
        device: Optional[str]
            Device to load the explanations on.

        Returns
        -------
        BatchedCachedExplanations
            BatchedCachedExplanations object that can load explanations
            lazily by index.

        Raises
        ------
        RuntimeError
            If ``path`` does not exist.

        """
        if not os.path.exists(path):
            raise RuntimeError(f"Explanations were not found at path {path}")
        return BatchedCachedExplanations(cache_dir=path, device=device)