# Source code for wibench.attacks.vae.vae

from typing import Optional

import torch
from diffusers import AutoencoderKL
from torchvision import transforms

from wibench.attacks.base import BaseAttack


class VAEAttack(BaseAttack):
    """
    Adversarial attack using a VAE to generate noisy image reconstructions.

    Encodes an image into latent space, adds Gaussian noise to the latents,
    then decodes multiple noisy versions. Returns the average of these
    reconstructions as an attacked image.

    Uses the `FLUX.1-schnell <https://huggingface.co/black-forest-labs/FLUX.1-schnell>`__ VAE.

    Parameters
    ----------
    n_avg_imgs: int
        Number of noisy reconstructions to average. Must be at least 1.
    noise_level: float
        Standard deviation of Gaussian noise added to latents.
    device: str
        Device to run the VAE on.
    cache_dir: str, optional
        Directory for caching the VAE model. Passed unchanged to
        ``AutoencoderKL.from_pretrained``.

    Raises
    ------
    ValueError
        If ``n_avg_imgs`` is smaller than 1.
    """

    def __init__(
        self,
        n_avg_imgs: int = 100,
        noise_level: float = 0.5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        cache_dir: Optional[str] = None,
    ) -> None:
        if n_avg_imgs < 1:
            # With zero reconstructions there is nothing to average; fail
            # here (before downloading the model) rather than with an
            # IndexError on the first call.
            raise ValueError(f"n_avg_imgs must be >= 1, got {n_avg_imgs}")
        self.n_avg_imgs = n_avg_imgs
        self.noise_level = noise_level
        self.device = device
        # The VAE weights are loaded in bfloat16; __call__ casts inputs to match.
        self.vae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            revision="refs/pr/1",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        ).to(device)
        # (x - 0.5) / 0.5: maps [0, 1] pixels to [-1, 1]. __call__ applies
        # the inverse (x * 0.5 + 0.5) to the output. Assumes inputs are
        # already float tensors in [0, 1] -- TODO confirm with callers.
        self.preprocess_transform = transforms.Compose([
            transforms.Normalize([0.5], [0.5]),
        ])

    def add_noise_to_embeddings(self, embeddings):
        """Return ``embeddings`` plus i.i.d. Gaussian noise scaled by ``noise_level``.

        The noise tensor matches ``embeddings`` in shape, dtype and device
        (``torch.randn_like``). The input tensor is not modified.
        """
        noise = torch.randn_like(embeddings) * self.noise_level
        return embeddings + noise

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Attack ``img`` by averaging ``n_avg_imgs`` noisy VAE reconstructions.

        Parameters
        ----------
        img: torch.Tensor
            Image tensor, presumably ``(C, H, W)`` or ``(B, C, H, W)`` with
            values in [0, 1]. A 3-D input gets a batch dimension added.

        Returns
        -------
        torch.Tensor
            float32 CPU tensor with values clamped to [0, 1]. A leading
            dimension of size 1 is squeezed away.
        """
        img_transformed = self.preprocess_transform(img).to(self.device)
        if img_transformed.dim() < 4:
            img_transformed = img_transformed.unsqueeze(0)

        with torch.no_grad():
            latents = self.vae.encode(
                img_transformed.to(torch.bfloat16)
            ).latent_dist.sample()

            # Keep a float32 running sum instead of storing every decoded
            # image. Memory stays constant in n_avg_imgs, and we avoid the
            # bfloat16 rounding error of averaging many reconstructions.
            acc = None
            for _ in range(self.n_avg_imgs):
                noisy_latents = self.add_noise_to_embeddings(latents)
                # .to(float32) makes a fresh tensor, so in-place add_ is safe.
                decoded = self.vae.decode(noisy_latents).sample.to(torch.float32)
                acc = decoded if acc is None else acc.add_(decoded)
            att_img_tensor = acc / self.n_avg_imgs

        # Undo the [-1, 1] normalisation and clip to the valid pixel range.
        att_img_tensor = (att_img_tensor * 0.5 + 0.5).clamp(0, 1)
        return att_img_tensor.squeeze(0).cpu()