quanda.utils.datasets.dataset_handlers module

Dataset handler classes.

class quanda.utils.datasets.dataset_handlers.DatasetHandler

Bases: ABC

Abstract base class for dataset handling.

abstract create_dataloader(dataset: torch.utils.data.Dataset | datasets.Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader

Create a DataLoader for the dataset.

Parameters:
  • dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) – The dataset to load.

  • batch_size (int) – Batch size.

  • shuffle (bool, optional) – Whether to shuffle the dataset, by default False.

  • num_workers (int, optional) – Number of workers for data loading, by default 0.

  • collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None.

Returns:

Configured DataLoader.

Return type:

DataLoader

abstract get_label(item: Any) → Any

Extract the label from a single dataset item.

Parameters:

item (Any) – A single item as returned by dataset[i].

Returns:

The label associated with the item.

Return type:

Any

abstract get_model_inputs(inputs: Any) → Any

Extract model inputs from the processed batch inputs.

Parameters:

inputs (Any) – The processed batch inputs, i.e. the inputs element returned by process_batch.

Returns:

Model-ready inputs.

Return type:

Any

abstract get_predictions(outputs: Any) → Tensor

Extract predictions from model outputs.

Parameters:

outputs (Any) – Raw outputs from the model.

Returns:

The extracted predictions.

Return type:

torch.Tensor

abstract process_batch(batch: Any, device: str | device) → Tuple[Any, Tensor]

Process a batch of data and return model inputs and labels.

Parameters:
  • batch (Any) – A batch of data.

  • device (Union[str, torch.device]) – Device to move the data to.

Returns:

A tuple (inputs, labels) where inputs may be a tensor or dict, and labels is a tensor.

Return type:

Tuple[Any, torch.Tensor]
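A minimal sketch of how a caller might chain these methods in an evaluation loop, assuming the order process_batch, then get_model_inputs, then the model call, then get_predictions, and assuming dict inputs are passed as keyword arguments (the model(**inputs) convention mentioned for the HuggingFace handlers). ToyListHandler, accuracy and toy_model are hypothetical stand-ins that use plain lists instead of tensors and ignore the device argument.

```python
from typing import Any, Callable, Iterable, List, Tuple


class ToyListHandler:
    """Hypothetical stand-in for a DatasetHandler; batches are (inputs, labels)."""

    def process_batch(self, batch: Tuple[Any, Any], device: str) -> Tuple[Any, Any]:
        # A real handler would move tensors to `device`; plain lists have no device.
        inputs, labels = batch
        return inputs, labels

    def get_model_inputs(self, inputs: Any) -> Any:
        return inputs

    def get_predictions(self, outputs: List[List[float]]) -> List[int]:
        # Argmax over the class scores of each example.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def accuracy(handler: Any, loader: Iterable[Any],
             model: Callable[..., Any], device: str = "cpu") -> float:
    correct = total = 0
    for batch in loader:
        inputs, labels = handler.process_batch(batch, device)
        model_inputs = handler.get_model_inputs(inputs)
        # Dict inputs (HuggingFace style) are unpacked as keyword arguments.
        if isinstance(model_inputs, dict):
            outputs = model(**model_inputs)
        else:
            outputs = model(model_inputs)
        preds = handler.get_predictions(outputs)
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total


def toy_model(xs: List[float]) -> List[List[float]]:
    # Scores class 1 higher when x > 0, class 0 when x < 0.
    return [[-x, x] for x in xs]


loader = [([1.0, -2.0], [1, 0]), ([3.0, 0.5], [1, 0])]
print(accuracy(ToyListHandler(), loader, toy_model))  # 0.75
```

Because the loop only touches the handler's methods, the same code would work for tuple-style and dict-style batches as long as each handler honours the documented return types.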

abstract with_label(item: Any, label: Any) → Any

Return a copy of item with its label replaced.

Parameters:
  • item (Any) – A single item as returned by dataset[i].

  • label (Any) – The replacement label.

Returns:

Item with the label replaced, matching the input item’s format.

Return type:

Any
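get_label and with_label let code read and rewrite labels without knowing whether items are tuples or dicts. The sketch below relabels chosen indices using only those two methods; ToyTupleHandler is a hypothetical stand-in for (sample, label) tuple items, and relabel is a hypothetical helper, not part of quanda.

```python
from typing import Any, Dict, List, Sequence, Tuple


class ToyTupleHandler:
    """Hypothetical stand-in implementing only get_label / with_label."""

    def get_label(self, item: Tuple[Any, Any]) -> Any:
        return item[1]

    def with_label(self, item: Tuple[Any, Any], label: Any) -> Tuple[Any, Any]:
        # Build a new tuple; the original item is left untouched.
        return (item[0], label)


def relabel(handler: Any, items: Sequence[Any],
            new_labels: Dict[int, Any]) -> List[Any]:
    """Return a new list in which items at the given indices carry new labels."""
    return [
        handler.with_label(item, new_labels[i]) if i in new_labels else item
        for i, item in enumerate(items)
    ]


data = [("img0", 0), ("img1", 1), ("img2", 2)]
h = ToyTupleHandler()
flipped = relabel(h, data, {1: 7})
print([h.get_label(x) for x in flipped])  # [0, 7, 2]
print(data[1])  # ('img1', 1)
```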

class quanda.utils.datasets.dataset_handlers.HuggingFaceDatasetHandler

Bases: DatasetHandler

Handler for HuggingFace datasets.

create_dataloader(dataset: Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader

Create a DataLoader for the dataset.

Parameters:
  • dataset (datasets.Dataset) – The dataset to load.

  • batch_size (int) – The batch size to use.

  • shuffle (bool, optional) – Whether to shuffle the data, by default False.

  • num_workers (int, optional) – Number of workers for data loading, by default 0.

  • collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None.

Returns:

Configured DataLoader.

Return type:

DataLoader

get_label(item: Dict[str, Any]) → Any

Extract the label from a HuggingFace dict item.

get_model_inputs(inputs: Dict[str, Tensor]) → Dict[str, Tensor]

Extract model inputs from the processed batch inputs.

Parameters:

inputs (Dict[str, torch.Tensor]) – The processed batch inputs dictionary.

Returns:

The inputs to be passed to the model.

Return type:

Dict[str, torch.Tensor]

get_predictions(outputs: Any) → Tensor

Extract predictions from model outputs.

Parameters:

outputs (Any) – The model outputs.

Returns:

The extracted predictions.

Return type:

torch.Tensor

process_batch(batch: Dict[str, Tensor], device: str | device) → Tuple[Dict[str, Tensor], Tensor]

Process a batch of data from a HuggingFace dataset.

Parameters:
  • batch (Dict[str, torch.Tensor]) – The batch dictionary containing inputs and labels.

  • device (Union[str, torch.device]) – The device to move the tensors to.

Returns:

The processed inputs dictionary and labels on the specified device.

Return type:

Tuple[Dict[str, torch.Tensor], torch.Tensor]

with_label(item: Dict[str, Any], label: Any) → Dict[str, Any]

Return a copy of a HuggingFace dict item with its label replaced.
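For dict items and batches, the behaviour documented above amounts to separating the label entry from the remaining entries. A plain-dict sketch, assuming the label is stored under a "labels" key (the default label_key of HuggingFaceSequenceDatasetHandler; the key used by this base handler is not stated here) and that every other key is a model input. split_batch and dict_with_label are hypothetical helpers, and the device transfer is omitted.

```python
from typing import Any, Dict, Tuple

LABEL_KEY = "labels"  # assumption: name of the label column


def split_batch(batch: Dict[str, Any]) -> Tuple[Dict[str, Any], Any]:
    """Hypothetical analogue of process_batch, without moving data to a device."""
    inputs = {k: v for k, v in batch.items() if k != LABEL_KEY}
    return inputs, batch[LABEL_KEY]


def dict_with_label(item: Dict[str, Any], label: Any) -> Dict[str, Any]:
    """Hypothetical analogue of with_label: a shallow copy with the label swapped."""
    return {**item, LABEL_KEY: label}


batch = {"input_ids": [[101, 2023]], "attention_mask": [[1, 1]], "labels": [1]}
inputs, labels = split_batch(batch)
print(sorted(inputs), labels)  # ['attention_mask', 'input_ids'] [1]

item = {"input_ids": [101, 2023], "labels": 0}
print(dict_with_label(item, 1)["labels"], item["labels"])  # 1 0
```

Keeping the inputs as a dict is what lets a caller forward them as model(**inputs).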

class quanda.utils.datasets.dataset_handlers.HuggingFaceSequenceDatasetHandler(input_keys: Sequence[str] = ('input_ids', 'token_type_ids', 'attention_mask'), label_key: str = 'labels')

Bases: HuggingFaceDatasetHandler

HuggingFace dataset handler that yields positional list batches.

Unlike HuggingFaceDatasetHandler, which yields dict batches via default_data_collator, this handler’s DataLoader emits lists [input_key_0, ..., input_key_N, label_key] in the order given by input_keys. This layout is required by consumers that index batches positionally, e.g. dattri, which reads batch[0].shape[0] and, in Arnoldi’s cache(), assigns batch[i] = torch.cat(...). That in-place assignment fails on tuples, which is why the batches are lists.

process_batch and get_model_inputs still expose a dict view, so downstream quanda code (benchmarks, metrics) can call model(**inputs) exactly as it does with the dict handler.
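The list-versus-tuple requirement can be seen without torch: item assignment, which the description above attributes to dattri's Arnoldi cache(), works on a list but raises TypeError on a tuple. Plain Python lists stand in for tensors here, and list concatenation stands in for torch.cat.

```python
batch_list = [[1, 2], [3, 4], [0, 1]]  # [input_ids, attention_mask, labels]
batch_tuple = tuple(batch_list)

# Positional reads work for both containers.
assert batch_list[0] == batch_tuple[0]

# In-place replacement of one element only works on the list.
batch_list[0] = batch_list[0] + [5]  # stands in for batch[i] = torch.cat(...)

tuple_failed = False
try:
    batch_tuple[0] = batch_tuple[0] + [5]
except TypeError as err:
    tuple_failed = True
    print(err)  # 'tuple' object does not support item assignment
```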

__init__(input_keys: Sequence[str] = ('input_ids', 'token_type_ids', 'attention_mask'), label_key: str = 'labels')

Initialize the handler.

Parameters:
  • input_keys (Sequence[str], optional) – Keys to emit as the leading list elements, in order. Defaults to ("input_ids", "token_type_ids", "attention_mask").

  • label_key (str, optional) – Key emitted as the trailing list element. Defaults to "labels".

build_positional_batch(inputs: Tensor | Dict[str, Tensor], labels: Tensor, device: str | device) → Tuple[Tensor, ...]

Build a (*input_keys, labels) batch on device.

collate(samples: List[Dict[str, Any]]) → List[Tensor]

Stack HF dict samples into a list [*input_keys, label_key].

Projects each sample onto the required keys before collation so that non-numeric columns (e.g. raw "sentence"/"hypothesis" text fields carried alongside tokenized columns) never reach default_data_collator, which would fail trying to batch them.
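A plain-Python sketch of the project-then-stack idea, with lists standing in for tensors. collate_positional is a hypothetical helper, not quanda's collate; real code would produce batched tensors rather than nested lists.

```python
from typing import Any, Dict, List, Sequence


def collate_positional(
    samples: List[Dict[str, Any]],
    input_keys: Sequence[str] = ("input_ids", "token_type_ids", "attention_mask"),
    label_key: str = "labels",
) -> List[List[Any]]:
    keys = list(input_keys) + [label_key]
    # Project first, so columns such as raw "sentence" text are dropped before batching.
    projected = [{k: s[k] for k in keys} for s in samples]
    # One batched column per key: input_keys in order, label last.
    return [[p[k] for p in projected] for k in keys]


samples = [
    {"sentence": "a cat", "input_ids": [1, 2], "token_type_ids": [0, 0],
     "attention_mask": [1, 1], "labels": 0},
    {"sentence": "a dog", "input_ids": [1, 3], "token_type_ids": [0, 0],
     "attention_mask": [1, 1], "labels": 1},
]
batch = collate_positional(samples)
print(len(batch), batch[0], batch[-1])  # 4 [[1, 2], [1, 3]] [0, 1]
```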

create_dataloader(dataset: Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader

Create a list-emitting DataLoader for the HF dataset.

process_batch(batch: Any, device: str | device) → Tuple[Dict[str, Tensor], Tensor]

Unpack positional batch into (inputs_dict, labels) on device.
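build_positional_batch and process_batch are two directions of the same mapping. Assuming a batch laid out exactly as [*input_keys, label], the conversion is a zip over the keys. to_positional and to_dict below are hypothetical plain-Python analogues that skip the device transfer.

```python
from typing import Any, Dict, Sequence, Tuple

INPUT_KEYS: Sequence[str] = ("input_ids", "token_type_ids", "attention_mask")


def to_positional(inputs: Dict[str, Any], labels: Any) -> Tuple[Any, ...]:
    """Analogue of build_positional_batch: (*input_keys, labels)."""
    return tuple(inputs[k] for k in INPUT_KEYS) + (labels,)


def to_dict(batch: Sequence[Any]) -> Tuple[Dict[str, Any], Any]:
    """Analogue of process_batch: leading elements keyed by INPUT_KEYS, last is labels."""
    *features, labels = batch
    return dict(zip(INPUT_KEYS, features)), labels


inputs = {"input_ids": [[1, 2]], "token_type_ids": [[0, 0]], "attention_mask": [[1, 1]]}
positional = to_positional(inputs, [1])
restored, labels = to_dict(positional)
print(restored == inputs, labels)  # True [1]
```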

class quanda.utils.datasets.dataset_handlers.TorchDatasetHandler

Bases: DatasetHandler

Handler for PyTorch datasets.

build_positional_batch(inputs: Tensor | Dict[str, Tensor], labels: Tensor, device: str | device) → Tuple[Tensor, ...]

Build an (inputs, labels) batch on device.

create_dataloader(dataset: Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader

Create a DataLoader for the dataset.

Parameters:
  • dataset (torch.utils.data.Dataset) – The dataset to load.

  • batch_size (int) – The batch size to use.

  • shuffle (bool, optional) – Whether to shuffle the data, by default False.

  • num_workers (int, optional) – Number of workers for data loading, by default 0.

  • collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None. Accepted for interface compatibility but ignored by this handler.

Returns:

Configured DataLoader.

Return type:

DataLoader

get_label(item: Tuple[Any, Any]) → Any

Extract the label from a (sample, label) tuple.

get_model_inputs(inputs: Tensor) → Tensor

Extract model inputs from the processed batch inputs.

Parameters:

inputs (torch.Tensor) – The processed batch inputs.

Returns:

The inputs to be passed to the model.

Return type:

torch.Tensor

get_predictions(outputs: Tensor) → Tensor

Extract predictions from model outputs.

Parameters:

outputs (torch.Tensor) – The model outputs.

Returns:

The extracted predictions.

Return type:

torch.Tensor

process_batch(batch: Tuple[Tensor, Tensor], device: str | device) → Tuple[Tensor, Tensor]

Process an (inputs, labels) batch from a PyTorch dataset.

Parameters:
  • batch (Tuple[torch.Tensor, torch.Tensor]) – A tuple of (inputs, labels).

  • device (Union[str, torch.device]) – The device to move the tensors to.

Returns:

The processed inputs and labels on the specified device.

Return type:

Tuple[torch.Tensor, torch.Tensor]

with_label(item: Tuple[Any, Any], label: Any) → Tuple[Any, Any]

Return a (sample, label) tuple with the label replaced.

quanda.utils.datasets.dataset_handlers.get_dataset_handler(dataset: torch.utils.data.Dataset | datasets.Dataset) → DatasetHandler

Return the appropriate DatasetHandler for the given dataset.

Parameters:

dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) – The dataset, either a PyTorch Dataset or a HuggingFace Dataset.

Returns:

A handler instance suited for the dataset.

Return type:

DatasetHandler
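A hedged sketch of the kind of type-based dispatch get_dataset_handler performs. choose_handler_name is hypothetical: it returns a class name rather than an instance, treats the HuggingFace datasets package as optional, and assumes that anything that is not a datasets.Dataset is handled as a PyTorch dataset. The actual function may instead reject unsupported types, and this page does not say when, if ever, HuggingFaceSequenceDatasetHandler is selected.

```python
from typing import Any


def choose_handler_name(dataset: Any) -> str:
    """Hypothetical dispatch rule; returns a handler class name, not an instance."""
    try:
        import datasets  # HuggingFace `datasets`, treated as optional here
    except ImportError:
        datasets = None
    if datasets is not None and isinstance(dataset, datasets.Dataset):
        return "HuggingFaceDatasetHandler"
    # Assumption: everything else is treated as a map-style PyTorch dataset.
    return "TorchDatasetHandler"


print(choose_handler_name([("x0", 0), ("x1", 1)]))  # TorchDatasetHandler
```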