Source code for quanda.utils.datasets.image_datasets
"""Dataset classes for image datasets."""
from typing import Optional
import torch
# Candidate sample keys that may hold the image. They are probed in this
# order and the first one present in a sample is used.
_IMAGE_KEYS = ("image", "img", "pixel_values")
[docs]
class HFtoTV(torch.utils.data.Dataset):
    """Wrapper to make Hugging Face datasets compatible with torchvision.

    Samples of the wrapped dataset are mappings; indexing this wrapper
    returns ``(image, label)`` tuples instead.
    """

    def __init__(
        self, dataset, transform=None, label_override: Optional[int] = None
    ):
        """Construct the HFtoTV dataset.

        Parameters
        ----------
        dataset
            The dataset to wrap. It must support ``len()``, indexing with an
            integer (returning a mapping) and, for ``get_label``, indexing
            with the string ``"label"``.
        transform : callable, optional
            Applied to every image before it is returned. ``None`` (the
            default) returns images unchanged.
        label_override : Optional[int], optional
            If given, this value is returned as the label of every sample
            and the ``"label"`` entry of the dataset is never read.

        Raises
        ------
        ValueError
            If the first sample contains none of the keys in
            ``_IMAGE_KEYS``.

        """
        self.dataset = dataset
        self.transform = transform
        self.label_override = label_override
        # Filled lazily by ``get_label`` on first use.
        self._labels_cache: Optional[list] = None

        # Probe the first sample once to learn which key holds the image.
        # NOTE: this indexes ``dataset[0]``, so an empty dataset fails here
        # with whatever error the wrapped dataset raises.
        sample = dataset[0]
        for key in _IMAGE_KEYS:
            if key in sample:
                self.image_key = key
                break
        else:
            raise ValueError(
                f"Could not find image key in dataset. "
                f"Expected one of {_IMAGE_KEYS}, got {list(sample.keys())}."
            )

    @staticmethod
    def _to_index(idx):
        """Unwrap a tensor index into a plain Python number."""
        if isinstance(idx, torch.Tensor):
            return idx.item()
        return idx

    def __len__(self):
        """Get dataset length."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Get a sample by index.

        Parameters
        ----------
        idx : int or torch.Tensor
            Index of the sample. A tensor is unwrapped with ``.item()``.

        Returns
        -------
        tuple
            ``(image, label)``, where the image has been passed through
            ``transform`` (if one was given) and the label is
            ``label_override`` if set, otherwise ``int(item["label"])``.

        """
        item = self.dataset[self._to_index(idx)]
        img = item[self.image_key]
        # Compare against None rather than relying on truthiness: a valid
        # callable may be falsy (e.g. it defines ``__len__`` returning 0)
        # and must still be applied.
        if self.transform is not None:
            img = self.transform(img)
        if self.label_override is not None:
            return img, self.label_override
        return img, int(item["label"])

    def get_label(self, idx):
        """Return the label at ``idx`` without decoding the image.

        Parameters
        ----------
        idx : int or torch.Tensor
            Index of the sample. A tensor is unwrapped with ``.item()``.

        Returns
        -------
        int
            ``label_override`` if set, otherwise the entry at ``idx`` of the
            dataset's ``"label"`` column.

        """
        if self.label_override is not None:
            return self.label_override
        idx = self._to_index(idx)
        # Column-only access bypasses HF's lazy Image decoder.
        if self._labels_cache is None:
            self._labels_cache = self.dataset["label"]
        return int(self._labels_cache[idx])