# Source code for wibench.attacks.averaging.averaging

import torch
import torchvision.transforms.functional as F
from wibench.attacks import BaseAttack
from wibench.typing import TorchImg
from wibench.datasets.base import ImageFolderDataset


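# Pre-trained patterns for this attack are available in ./attack_resources/averaging; they were trained for the Stable Signature, StegaStamp, and Tree-Ring watermarks.
# NOTE(review): the default `pattern_load_path` of `Averaging` points to ./resources/averaging/ instead — confirm which location is correct.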
class Averaging(BaseAttack):
    """Attack based on simple averaging from https://arxiv.org/abs/2406.09026.

    The watermark is estimated as the difference between the average of
    watermarked images and the average of clean images; the attack subtracts
    this pattern from the input image.

    Args:
        pattern_load_path: path of a precomputed pattern to load at construction.
            If falsy, no pattern is loaded and ``compute_pattern`` or
            ``load_pattern`` must be called before the attack is applied.
        num_images: if None, use all images in the directories to compute the
            pattern; if ``n``, use only the first ``n`` images. Defaults to None.
        device: device to compute on. Defaults to "cuda" when available,
            otherwise "cpu".
    """

    def __init__(
        self,
        pattern_load_path: str | None = "./resources/averaging/pattern_stegastamp.pth",
        num_images: int | None = None,
        device: torch.device | str = "cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.num_images = num_images
        self.pattern: torch.Tensor | None = None
        if pattern_load_path:
            self.load_pattern(pattern_load_path)

    def __call__(self, img: TorchImg) -> TorchImg:
        """Subtract the pattern from ``img`` and clip the result to [0, 1].

        The pattern is resized to the spatial size of ``img`` before subtraction.
        The result is returned on the CPU.

        Raises:
            ValueError: if no pattern has been computed or loaded.
        """
        if self.pattern is None:
            raise ValueError("Pattern is not computed, call compute_pattern or load_pattern first")
        # Drop the leading batch dim of the pattern and match the image's (h, w)
        # so a single pattern can be applied to images of any resolution.
        pattern = F.resize(self.pattern.squeeze(0), list(img.shape[-2:]))
        out = img.to(self.device) - pattern
        return torch.clip(out, 0, 1).cpu()

    def compute_pattern(self, dir_watermarked: str, dir_clean: str, batch_size: int = 1) -> torch.Tensor:
        """Compute the attack pattern as mean(watermarked) - mean(clean).

        The pattern is also stored in ``self.pattern``.

        Args:
            dir_watermarked: directory with watermarked images.
            dir_clean: directory with clean, non-watermarked images.
            batch_size: number of images moved to the device and summed at once
                while computing each average.

        Returns:
            The computed pattern, a (1, c, h, w) tensor.
        """
        mean_watermarked = self.compute_average_on_directory(dir_watermarked, batch_size)
        mean_clean = self.compute_average_on_directory(dir_clean, batch_size)
        self.pattern = mean_watermarked - mean_clean
        return self.pattern

    def save_pattern(self, save_path: str) -> None:
        """Save the current pattern to ``save_path`` with ``torch.save``.

        Raises:
            ValueError: if no pattern has been computed or loaded.
        """
        if self.pattern is None:
            raise ValueError("Pattern is not computed, call compute_pattern or load_pattern first")
        torch.save(self.pattern, save_path)

    def load_pattern(self, load_path: str) -> None:
        """Load a pattern from ``load_path`` onto ``self.device``."""
        # NOTE(review): torch.load unpickles the file, so only load patterns from
        # trusted sources (consider weights_only=True on torch >= 1.13).
        self.pattern = torch.load(load_path, map_location=self.device)

    def compute_average_on_directory(self, directory: str, batch_size: int = 1) -> torch.Tensor:
        """Return the per-pixel mean image of the images in ``directory``.

        Only the first ``self.num_images`` images are used when it is not None.
        The mean is computed as a running sum divided by the total image count,
        so batches of unequal size are weighted correctly.

        Args:
            directory: directory with images.
            batch_size: number of images concatenated, moved to ``self.device``
                and summed at once. Defaults to 1.

        Returns:
            A (1, c, h, w) tensor with the average image.

        Raises:
            ValueError: if ``batch_size`` or ``num_images`` is < 1, or if no
                images are found.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if self.num_images is None:
            # None means "use every image", so no sample range is passed.
            dataset = ImageFolderDataset(directory)
        else:
            if self.num_images < 1:
                raise ValueError(f"num_images must be >= 1 or None, got {self.num_images}")
            dataset = ImageFolderDataset(directory, sample_range=(0, self.num_images - 1))

        total: torch.Tensor | None = None
        count = 0
        pending: list[torch.Tensor] = []
        pending_count = 0

        def flush() -> None:
            """Sum the pending images on the device and add them to the running total."""
            nonlocal total, count, pending_count
            batch = torch.cat(pending, dim=0).to(self.device)
            batch_sum = batch.sum(dim=0, keepdim=True)
            total = batch_sum if total is None else total + batch_sum
            count += batch.shape[0]
            pending.clear()
            pending_count = 0

        for imgs in dataset.generator():
            # NOTE(review): treat an unbatched (c, h, w) item as a batch of one so
            # the result keeps the documented (1, c, h, w) shape — confirm what
            # ImageFolderDataset.generator yields.
            if imgs.dim() == 3:
                imgs = imgs.unsqueeze(0)
            pending.append(imgs)
            pending_count += imgs.shape[0]
            if pending_count >= batch_size:
                flush()
        if pending:
            flush()

        if total is None or count == 0:
            raise ValueError(f"No images found in {directory!r}")
        return total / count