# Source code for quanda.benchmarks.resources.sample_transforms

"""Torchvision transforms for benchmarks."""

import torchvision.transforms as transforms  # type: ignore
from PIL import Image, ImageDraw


def add_white_square_mnist(img):
    """Paste a 12x12 white block into the top-left corner of ``img``.

    The block is built as a single-channel ("L") image filled with 255 and
    pasted at pixel offset (0, 0). ``img`` is modified in place and the same
    object is returned, so the function can be used directly as a transform.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to modify.

    Returns
    -------
    PIL.Image.Image
        The same ``img`` object, now carrying the white block.
    """
    side = 12
    patch = Image.new("L", (side, side), 255)
    top_left = (0, 0)
    img.paste(patch, top_left)
    return img
def add_yellow_square(img):
    """Paste a 15x15 yellow block onto ``img`` at pixel offset (10, 10).

    The block is an RGB image filled with (255, 255, 0). ``img`` is modified
    in place and the same object is returned, so the function can be used
    directly as a transform.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to modify.

    Returns
    -------
    PIL.Image.Image
        The same ``img`` object, now carrying the yellow block.
    """
    yellow_rgb = (255, 255, 0)
    offset = (10, 10)
    patch = Image.new("RGB", (15, 15), yellow_rgb)
    img.paste(patch, offset)
    return img
def add_pink_frame(img):
    """Draw a centred pink square outline on ``img`` (modified in place).

    The square's bounding box spans 80% of the image's shorter side and is
    centred on the image. The outline colour is (255, 105, 180), and its
    stroke width is 4% of the shorter side, with a minimum of 3 pixels.

    The geometry is chosen so the frame is preserved by
    ``Resize(256) + CenterCrop(224)``.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to draw on.

    Returns
    -------
    PIL.Image.Image
        The same ``img`` object, now carrying the frame.
    """
    width_px, height_px = img.size
    short_side = min(width_px, height_px)
    # After Resize(256) + CenterCrop(224), the central 224/256 = 87.5% of the
    # short side survives (±43.75% around the centre). Keeping the bounding
    # box at ±40% leaves the frame inside the retained region.
    half_extent = short_side * 0.40
    stroke = max(3, int(short_side * 0.04))
    center_x = width_px / 2
    center_y = height_px / 2
    bbox = (
        center_x - half_extent,
        center_y - half_extent,
        center_x + half_extent,
        center_y + half_extent,
    )
    ImageDraw.Draw(img).rectangle(bbox, outline=(255, 105, 180), width=stroke)
    return img
# Per-channel normalisation statistics used by the pipelines below.
# ImageNet statistics (shared by the Tiny-ImageNet and AwA2 pipelines).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# CIFAR-10 statistics.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)
# MNIST: maps ToTensor's [0, 1] output to [-1, 1].
_MNIST_MEAN = (0.5,)
_MNIST_STD = (0.5,)


def _to_rgb(img):
    """Convert a PIL image to RGB.

    Defined at module level instead of as an inline ``lambda`` so that
    ``transforms.Lambda(_to_rgb)`` can be pickled. Pickling happens, for
    example, when a ``DataLoader`` starts worker processes with the
    ``spawn`` start method.
    """
    return img.convert("RGB")


# Registry of named transforms / image-modification callables used by the
# benchmarks, looked up by string key.
sample_transforms = {
    "mnist_transforms": transforms.Compose(
        [
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
        ]
    ),
    "adversarial_transforms": transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
        ]
    ),
    # Inverse of Normalize(_MNIST_MEAN, _MNIST_STD), i.e. x = y * std + mean.
    # Normalize computes (input - mean) / std. The first step therefore
    # divides by 1/std (multiplies by std), and the second step subtracts
    # -mean (adds mean).
    "mnist_denormalize": transforms.Compose(
        [
            transforms.Normalize(mean=[0.0], std=[1 / _MNIST_STD[0]]),
            transforms.Normalize(mean=[-_MNIST_MEAN[0]], std=[1.0]),
        ]
    ),
    "add_white_square_mnist": add_white_square_mnist,
    "tiny_imagenet_transforms": transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
    # NOTE: misspelled key ("imagener") kept for backward compatibility; a
    # correctly spelled alias is registered right after this dict.
    "tiny_imagener_adversarial_transforms": transforms.Compose(
        [
            transforms.Resize(size=(64, 64)),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
    "add_yellow_square": add_yellow_square,
    "add_pink_frame": add_pink_frame,
    "cifar10_transforms": transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
        ]
    ),
    "cifar10_adversarial_transforms": transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
        ]
    ),
    "awa2_train_transforms": transforms.Compose(
        [
            transforms.Lambda(_to_rgb),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
    "awa2_transforms": transforms.Compose(
        [
            transforms.Lambda(_to_rgb),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
    "awa2_adversarial_transforms": transforms.Compose(
        [
            transforms.Lambda(_to_rgb),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
}

# Correctly spelled alias for the legacy "tiny_imagener_..." key above; both
# keys refer to the same transform object.
sample_transforms["tiny_imagenet_adversarial_transforms"] = sample_transforms[
    "tiny_imagener_adversarial_transforms"
]