quanda.utils.datasets.dataset_handlers module
Dataset handler classes.
- class quanda.utils.datasets.dataset_handlers.DatasetHandler
Bases: ABC
Abstract base class for dataset handling.
- abstract create_dataloader(dataset: torch.utils.data.Dataset | datasets.Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader
Create a DataLoader for the dataset.
- Parameters:
dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) – The dataset to load.
batch_size (int) – Batch size.
shuffle (bool, optional) – Whether to shuffle the dataset, by default False.
num_workers (int, optional) – Number of workers for data loading, by default 0.
collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None.
- Returns:
Configured DataLoader.
- Return type:
DataLoader
- abstract get_label(item: Any) → Any
Extract the label from a single dataset item.
- Parameters:
item (Any) – A single item as returned by dataset[i].
- Returns:
The label associated with the item.
- Return type:
Any
- abstract get_model_inputs(inputs: Any) → Any
Extract model inputs from the processed batch inputs.
- Parameters:
inputs (Any) – Raw inputs from the dataset.
- Returns:
Model-ready inputs.
- Return type:
Any
- abstract get_predictions(outputs: Any) → Tensor
Extract predictions from model outputs.
- Parameters:
outputs (Any) – Raw outputs from the model.
- Returns:
The extracted predictions.
- Return type:
torch.Tensor
- abstract process_batch(batch: Any, device: str | device) → Tuple[Any, Tensor]
Process a batch of data and return model inputs and labels.
- Parameters:
batch (Any) – A batch of data.
device (Union[str, torch.device]) – Device to move the data to.
- Returns:
A tuple (inputs, labels) where inputs may be a tensor or dict, and labels is a tensor.
- Return type:
Tuple[Any, torch.Tensor]
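Read together, the docstrings imply a call order for a consumer: build a loader with create_dataloader, split each batch with process_batch, pass the inputs through get_model_inputs to the model, then turn the outputs into predictions with get_predictions. The sketch below illustrates that flow under stated assumptions: MiniHandler, ListPairHandler and accuracy are hypothetical names, not part of quanda; plain Python lists stand in for tensors so it runs without torch; and the device argument is accepted but unused.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterator, List, Sequence, Tuple


class MiniHandler(ABC):
    """Hypothetical, torch-free mirror of part of the DatasetHandler interface."""

    @abstractmethod
    def create_dataloader(self, dataset: Sequence[Any], batch_size: int) -> Iterator[Any]: ...

    @abstractmethod
    def process_batch(self, batch: Any, device: str) -> Tuple[Any, Any]: ...

    @abstractmethod
    def get_model_inputs(self, inputs: Any) -> Any: ...

    @abstractmethod
    def get_predictions(self, outputs: Any) -> List[int]: ...


class ListPairHandler(MiniHandler):
    """Toy handler for a list of (x, y) pairs; `device` is accepted but unused."""

    def create_dataloader(self, dataset, batch_size):
        # Yield ([x, ...], [y, ...]) chunks in order (no shuffling).
        for start in range(0, len(dataset), batch_size):
            chunk = dataset[start:start + batch_size]
            yield [x for x, _ in chunk], [y for _, y in chunk]

    def process_batch(self, batch, device):
        inputs, labels = batch
        return inputs, labels

    def get_model_inputs(self, inputs):
        return inputs

    def get_predictions(self, outputs):
        # Argmax over each row of per-class scores.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def accuracy(handler: MiniHandler, model: Callable[[Any], Any], dataset,
             batch_size: int = 2, device: str = "cpu") -> float:
    """Drive a handler through the create -> process -> infer -> predict loop."""
    correct = total = 0
    for batch in handler.create_dataloader(dataset, batch_size):
        inputs, labels = handler.process_batch(batch, device)
        preds = handler.get_predictions(model(handler.get_model_inputs(inputs)))
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total


def toy_model(xs):
    # Scores for classes 0 and 1: positive x favours class 1.
    return [[-x, x] for x in xs]


if __name__ == "__main__":
    data = [(1, 1), (-2, 0), (3, 1), (-1, 1)]
    print(accuracy(ListPairHandler(), toy_model, data))  # prints 0.75
```

Keeping the loop generic over the handler is what lets the same evaluation code serve both PyTorch and HuggingFace datasets; only the handler changes.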
- class quanda.utils.datasets.dataset_handlers.HuggingFaceDatasetHandler
Bases: DatasetHandler
Handler for HuggingFace datasets.
- create_dataloader(dataset: Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader
Create a DataLoader for the dataset.
- Parameters:
dataset (datasets.Dataset) – The dataset to load.
batch_size (int) – The batch size to use.
shuffle (bool, optional) – Whether to shuffle the data, by default False.
num_workers (int, optional) – Number of workers for data loading, by default 0.
collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None.
- Returns:
Configured DataLoader.
- Return type:
DataLoader
- get_model_inputs(inputs: Dict[str, Tensor]) → Dict[str, Tensor]
Extract model inputs from the processed batch inputs.
- Parameters:
inputs (Dict[str, torch.Tensor]) – The processed batch inputs dictionary.
- Returns:
The inputs to be passed to the model.
- Return type:
Dict[str, torch.Tensor]
- get_predictions(outputs: Any) → Tensor
Extract predictions from model outputs.
- Parameters:
outputs (Any) – The model outputs.
- Returns:
The extracted predictions.
- Return type:
torch.Tensor
- process_batch(batch: Dict[str, Tensor], device: str | device) → Tuple[Dict[str, Tensor], Tensor]
Process a batch of data from a HuggingFace dataset.
- Parameters:
batch (Dict[str, torch.Tensor]) – The batch dictionary containing inputs and labels.
device (Union[str, torch.device]) – The device to move the tensors to.
- Returns:
The processed inputs dictionary and labels on the specified device.
- Return type:
Tuple[Dict[str, torch.Tensor], torch.Tensor]
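The essential reshaping in process_batch is separating the label entry from the rest of the batch dictionary. A minimal sketch, assuming the label lives under the key "labels" (the default label_key used by HuggingFaceSequenceDatasetHandler; the key this handler actually reads is not stated here) and leaving out the move to device so it runs without torch. split_dict_batch is a hypothetical helper, not a quanda function.

```python
from typing import Any, Dict, Tuple


def split_dict_batch(batch: Dict[str, Any], label_key: str = "labels") -> Tuple[Dict[str, Any], Any]:
    """Hypothetical sketch: separate the label entry from the model inputs.

    The documented method also moves every tensor to `device`; that step
    is left out here so the example needs no torch.
    """
    # Every key except the label key is treated as a model input.
    inputs = {k: v for k, v in batch.items() if k != label_key}
    return inputs, batch[label_key]
```

The resulting inputs dictionary has the shape that can be unpacked as model(**inputs).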
- class quanda.utils.datasets.dataset_handlers.HuggingFaceSequenceDatasetHandler(input_keys: Sequence[str] = ('input_ids', 'token_type_ids', 'attention_mask'), label_key: str = 'labels')
Bases: HuggingFaceDatasetHandler
HuggingFace dataset handler that yields positional list batches.
Unlike HuggingFaceDatasetHandler (which yields dict batches via default_data_collator), this handler’s DataLoader emits lists [input_key_0, ..., input_key_N, label_key] in the order given by input_keys. This is required for consumers that index batches positionally, e.g. dattri, which does batch[0].shape[0] and (in Arnoldi’s cache()) mutates batch[i] = torch.cat(...), an assignment that fails on tuples. process_batch / get_model_inputs still expose a dict view so that downstream quanda code (benchmarks, metrics) can call model(**inputs) the same way as with the dict handler.
- __init__(input_keys: Sequence[str] = ('input_ids', 'token_type_ids', 'attention_mask'), label_key: str = 'labels')
Initialize the handler.
- Parameters:
input_keys (Sequence[str], optional) – Keys to emit as the leading list elements, in order. Defaults to ('input_ids', 'token_type_ids', 'attention_mask').
label_key (str, optional) – Key emitted as the trailing list element. Defaults to "labels".
- build_positional_batch(inputs: Tensor | Dict[str, Tensor], labels: Tensor, device: str | device) → Tuple[Tensor, ...]
Build a (*input_keys, labels) batch on device.
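The dict view and the positional layout are two orderings of the same data, linked by input_keys. The sketch below shows that mapping in both directions, assuming plain Python lists in place of tensors and ignoring device placement; to_dict_view and to_positional are hypothetical helpers, not quanda functions.

```python
from typing import Any, Dict, Sequence, Tuple


def to_dict_view(batch: Sequence[Any], input_keys: Sequence[str]) -> Tuple[Dict[str, Any], Any]:
    """Map a positional [*inputs, labels] batch to ({key: input}, labels)."""
    if len(batch) != len(input_keys) + 1:
        raise ValueError("expected one element per input key, plus labels")
    # Leading elements pair up with input_keys; the last element is the labels.
    inputs = dict(zip(input_keys, batch[:-1]))
    return inputs, batch[-1]


def to_positional(inputs: Dict[str, Any], labels: Any, input_keys: Sequence[str]) -> Tuple[Any, ...]:
    """Map ({key: input}, labels) back to (*inputs in input_keys order, labels)."""
    # Order comes from input_keys, not from the dict's insertion order.
    return tuple(inputs[k] for k in input_keys) + (labels,)
```

to_positional returns a tuple, mirroring the documented Tuple[Tensor, ...] return type of build_positional_batch, whereas the DataLoader's own batches are lists so that item assignment such as batch[i] = ... keeps working.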
- collate(samples: List[Dict[str, Any]]) → List[Tensor]
Stack HF dict samples into a list [*input_keys, label_key].
Projects each sample onto the required keys before collation so that non-numeric columns (e.g. raw "sentence"/"hypothesis" text fields carried alongside tokenized columns) never reach default_data_collator, which would fail trying to batch them.
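The project-then-collate order can be sketched in a few lines. Per the description above, the real method hands the projected samples to default_data_collator and returns tensors; this hypothetical collate_positional instead gathers plain Python lists so that it runs without torch or transformers, and it raises KeyError if a sample lacks one of the requested keys.

```python
from typing import Any, Dict, List, Sequence


def collate_positional(
    samples: List[Dict[str, Any]],
    input_keys: Sequence[str],
    label_key: str = "labels",
) -> List[List[Any]]:
    """Hypothetical sketch of key projection followed by positional collation."""
    keys = [*input_keys, label_key]
    # Project first: columns outside `keys` (e.g. raw "sentence" text) are
    # dropped before any batching happens. A missing key raises KeyError.
    projected = [{k: sample[k] for k in keys} for sample in samples]
    # Collate per key, in positional order: one list of per-sample values
    # for each input key, followed by the labels.
    return [[p[k] for p in projected] for k in keys]
```

Projecting before batching means the collator never sees text columns at all, rather than having to skip them.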
- class quanda.utils.datasets.dataset_handlers.TorchDatasetHandler
Bases: DatasetHandler
Handler for PyTorch datasets.
- build_positional_batch(inputs: Tensor | Dict[str, Tensor], labels: Tensor, device: str | device) → Tuple[Tensor, ...]
Build an (inputs, labels) batch on device.
- create_dataloader(dataset: Dataset, batch_size: int, shuffle: bool = False, num_workers: int = 0, collate_fn: Callable | None = None) → DataLoader
Create a DataLoader for the dataset.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to load.
batch_size (int) – The batch size to use.
shuffle (bool, optional) – Whether to shuffle the data, by default False.
num_workers (int, optional) – Number of workers for data loading, by default 0.
collate_fn (Optional[Callable], optional) – Collate function for the DataLoader, by default None. Part of the DatasetHandler interface; ignored by this handler.
- Returns:
Configured DataLoader.
- Return type:
DataLoader
- get_model_inputs(inputs: Tensor) → Tensor
Extract model inputs from the processed batch inputs.
- Parameters:
inputs (torch.Tensor) – The processed batch inputs.
- Returns:
The inputs to be passed to the model.
- Return type:
torch.Tensor
- get_predictions(outputs: Tensor) → Tensor
Extract predictions from model outputs.
- Parameters:
outputs (torch.Tensor) – The model outputs.
- Returns:
The extracted predictions.
- Return type:
torch.Tensor
- process_batch(batch: Tuple[Tensor, Tensor], device: str | device) → Tuple[Tensor, Tensor]
Process a batch of data.
- Parameters:
batch (Tuple[torch.Tensor, torch.Tensor]) – A tuple of (inputs, labels).
device (Union[str, torch.device]) – The device to move the tensors to.
- Returns:
The processed inputs and labels on the specified device.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- quanda.utils.datasets.dataset_handlers.get_dataset_handler(dataset: torch.utils.data.Dataset | datasets.Dataset) → DatasetHandler
Return the correct DatasetHandler for the given dataset.
- Parameters:
dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) – The dataset, either a PyTorch Dataset or a HuggingFace Dataset.
- Returns:
A handler instance suited for the dataset.
- Return type: