Source code for wibench.attacks.distortions

from .base import BaseAttack
import torch
import torchvision
from PIL import Image
import io
from wibench.typing import TorchImg
from typing import Literal
import torchvision.transforms.functional as F
import torchvision.transforms as T
import random


[docs]class JPEGCompression(BaseAttack): """JPEG compression attack. Parameters ---------- quality : int JPEG quality factor (1-100) """ name = "jpeg" def __init__(self, quality: int = 50) -> None: super().__init__() self.quality = int(quality) def __call__(self, img: TorchImg) -> TorchImg: """Apply JPEG compression to image. Parameters ---------- img : TorchImg Input image tensor Returns ------- TorchImg Compressed image tensor """ distorted_image = [] img_pil = torchvision.transforms.functional.to_pil_image(img) buffer = io.BytesIO() img_pil.save(buffer, format="JPEG", quality=self.quality) img_pil_compressed = Image.open(buffer) img_compressed = torchvision.transforms.functional.to_tensor( img_pil_compressed ) distorted_image = img_compressed.to(device=img.device, dtype=img.dtype) return distorted_image
[docs]class Rotate90(BaseAttack): """Rotates image by 90 degrees clockwise or counter-clockwise. Parameters ---------- direction : Literal["clock", "counter"], optional Rotation direction, either "clock" for clockwise or "counter" for counter-clockwise. Default is "clock". """ def __init__(self, direction: Literal["clock", "counter"] = "clock"): assert direction in ["clock", "counter"] self.direction = direction def __call__(self, image: TorchImg) -> TorchImg: """Rotate image by 90 degrees in specified direction. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Rotated image tensor """ if self.direction == "clock": return torch.rot90(image, dims=(-1, -2)) elif self.direction == "counter": return torch.rot90(image, k=1, dims=(-2, -1))
[docs]class Rotate(BaseAttack): """Rotates image by arbitrary angle counter-clockwise. Parameters ---------- angle : float Rotation angle in degrees counter-clockwise. For clockwise rotation use negative numbers interpolation : str, optional Interpolation mode ('nearest', 'bilinear', 'bicubic'). Default is 'bilinear'. expand : bool, optional Whether to expand output image size to fit rotated image. Default is False. """ def __init__( self, angle: float, interpolation: str = "bilinear", expand=False ): self.angle = angle self.interpolation = T.InterpolationMode(interpolation) self.expand = expand def __call__(self, image: TorchImg) -> TorchImg: """Rotate image by specified angle. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Rotated image tensor """ res = F.rotate(image, self.angle, self.interpolation, self.expand) return res
[docs]class GaussianBlur(BaseAttack): """Applies Gaussian blur to image. Parameters ---------- kernel_size : int Size of Gaussian kernel (must be odd and positive) """ def __init__(self, kernel_size: int): self.kernel_size = kernel_size def __call__(self, image: TorchImg) -> TorchImg: """Apply Gaussian blur to image. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Blurred image tensor """ return F.gaussian_blur(image, kernel_size=self.kernel_size)
[docs]class GaussianNoise(BaseAttack): """Adds Gaussian noise to image. Parameters ---------- sigma : float Standard deviation of Gaussian noise distribution """ def __init__(self, sigma: float) -> None: super().__init__() self.sigma = sigma def __call__(self, img: TorchImg) -> TorchImg: """Add Gaussian noise to image. Parameters ---------- img : TorchImg Input image tensor Returns ------- TorchImg Noisy image tensor """ noise = torch.randn(img.shape, device=img.device) * self.sigma distorted_image = (img + noise).clamp(0, 1) return distorted_image
[docs]class CenterCrop(BaseAttack): """Center crops image by specified area ratio. Parameters ---------- ratio : float Ratio of area to keep (0-1). For example, 0.5 keeps 50% of image area. """ def __init__(self, ratio: float): self.dim_ratio = ratio ** (1 / 2) def __call__(self, image: TorchImg) -> TorchImg: """Center crop image. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Center-cropped image tensor """ h, w = image.shape[-2:] new_h = int(round(h * self.dim_ratio)) new_w = int(round(w * self.dim_ratio)) return F.center_crop(image, output_size=(new_h, new_w))
[docs]class Resize(BaseAttack): """Resizes image by specified width and height ratios. Parameters ---------- x_ratio : float, optional Width scaling factor. Default is 1 (no change). y_ratio : float, optional Height scaling factor. Default is 1 (no change). interpolation : str, optional Interpolation mode ('nearest', 'bilinear', 'bicubic'). Default is 'bilinear'. """ def __init__( self, x_ratio: float = 1, y_ratio: float = 1, interpolation: str = "bilinear", ): self.x_ratio = x_ratio self.y_ratio = y_ratio self.interpolation = F.InterpolationMode(interpolation) def __call__(self, image: TorchImg) -> TorchImg: """Resize image by specified ratios. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Resized image tensor """ h, w = image.shape[-2:] new_h = int(round(h * self.y_ratio)) new_w = int(round(w * self.x_ratio)) return F.resize( image, size=[new_h, new_w], interpolation=self.interpolation )
[docs]class RandomCropout(BaseAttack): """Randomly crops out a rectangular region of specified area ratio. Fills the remaining area with black color. Parameters ---------- ratio : float Ratio of area to keep (0-1). For example, 0.8 keeps 80% of image area. """ def __init__(self, ratio: float): self.ratio = ratio def __call__(self, image: TorchImg) -> TorchImg: """Apply random cropout to image. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Image with random rectangular region kept and remaining zeroed out """ w_ratio = random.random() * (1 - self.ratio) + self.ratio h_ratio = self.ratio / w_ratio h, w = image.shape[-2:] crop_w, crop_h = round(w_ratio * w), round(h_ratio * h) left_pos = random.randint(0, w - crop_w) top_pos = random.randint(0, h - crop_h) mask = torch.zeros_like(image, dtype=torch.bool) mask[..., top_pos : top_pos + crop_h, left_pos : left_pos + crop_w] = ( True ) result = image.clone() result[~mask] = 0 return result
[docs]class Brightness(BaseAttack): """Adjusts image brightness. Parameters ---------- factor : float Brightness adjustment factor: * 1.0 returns original image, * <1.0 darkens image, * >1.0 brightens image """ def __init__(self, factor: float): self.factor = factor def __call__(self, image: TorchImg) -> TorchImg: """Adjust image brightness. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Brightness-adjusted image tensor """ return F.adjust_brightness(image, brightness_factor=self.factor)
[docs]class Contrast(BaseAttack): """Adjusts image contrast. Parameters ---------- factor : float Contrast adjustment factor: * 1.0 returns original image, * <1.0 reduces contrast, * >1.0 increases contrast """ def __init__(self, factor: float): self.factor = factor def __call__(self, image: TorchImg) -> TorchImg: """Adjust image contrast. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Contrast-adjusted image tensor """ return F.adjust_contrast(image, contrast_factor=self.factor)
[docs]class PixelShift(BaseAttack): """Shifts image pixels horizontally with edge wrapping. Parameters ---------- delta : int, optional Number of pixels to shift right. Leftmost pixels wrap around to right. Default is 7. """ def __init__(self, delta: int = 7): self.delta = delta def __call__(self, image: TorchImg) -> TorchImg: """Shift image pixels horizontally. Parameters ---------- image : TorchImg Input image tensor Returns ------- TorchImg Pixel-shifted image tensor """ img_shifted = torch.roll(image, shifts=self.delta, dims=-1) return img_shifted
[docs]class ColorInversion(BaseAttack): """Inverts colors in image. """ def __call__(self, image: TorchImg) -> TorchImg: return -image + 1