Source code for quanda.utils.datasets.on_device_dataset

"""Module to move a dataset to a device."""

from typing import Any, Sized, Union

import torch


def _move_item_to_device(item: Any, device: Union[str, torch.device]) -> Any:
    """Move every tensor in ``item`` to ``device``, preserving structure.

    Non-tensor leaves (ints, strings, etc.) are returned unchanged.
    Scalar targets coming out of ``TensorDataset`` / custom datasets are
    promoted to tensors on the target device so the returned structure is
    ``.to``-compatible with downstream code.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)
    if isinstance(item, tuple):
        return tuple(_move_item_to_device(x, device) for x in item)
    if isinstance(item, list):
        return [_move_item_to_device(x, device) for x in item]
    if isinstance(item, dict):
        return {k: _move_item_to_device(v, device) for k, v in item.items()}
    if isinstance(item, (int, float, bool)):
        return torch.tensor(item, device=device)
    return item


class OnDeviceDataset(torch.utils.data.Dataset):
    """Wrapper that moves a dataset's samples to a target device on access.

    Handles arbitrary sample structures returned by ``dataset[i]`` (a single
    tensor, a tuple/list of any length, or a dict) by recursively moving every
    tensor leaf to the target device. Python scalars (``int``, ``float``,
    ``bool``) are converted to tensors on the target device. Other non-tensor
    values are returned untouched.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
    ):
        """Construct the OnDeviceDataset class.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            The dataset to move to the device.
        device : Union[str, torch.device]
            The device to move the dataset to.
        """
        super().__init__()
        self.dataset = dataset
        self.device = device

    def __getitem__(self, idx: Any) -> Any:
        """Get a sample by index with all its tensors on ``self.device``."""
        return _move_item_to_device(self.dataset[idx], self.device)

    def __len__(self) -> int:
        """Get dataset length.

        Raises
        ------
        TypeError
            If the wrapped dataset does not define ``__len__``.
        """
        if isinstance(self.dataset, Sized):
            return len(self.dataset)
        # Wrapping the dataset in a DataLoader cannot recover a length: the
        # loader's ``__len__`` delegates to ``len(dataset)`` (through its
        # sampler), so it would only raise an opaque TypeError after doing
        # extra work. Fail early with a clear message instead.
        raise TypeError(
            f"Cannot determine the length of the wrapped dataset: "
            f"{type(self.dataset).__name__} does not define __len__."
        )