from wibench.attacks.base import BaseAttack
import torch
from torchvision import transforms
from diffusers import AutoencoderKL
class VAEAttack(BaseAttack):
    """
    Adversarial attack using a VAE to generate noisy image reconstructions.

    Encodes an image into latent space, adds Gaussian noise to the latents,
    then decodes multiple noisy versions. Returns the average of these
    reconstructions as an attacked image. Uses the `FLUX.1-schnell <https://huggingface.co/black-forest-labs/FLUX.1-schnell>`__ VAE.

    Parameters
    ----------
    n_avg_imgs: int
        Number of noisy reconstructions to average. Must be at least 1.
    noise_level: float
        Standard deviation of Gaussian noise added to latents.
    device: str
        Device to run the VAE on.
    cache_dir: str
        Directory for caching the VAE model. ``None`` is forwarded unchanged
        to ``AutoencoderKL.from_pretrained``.

    Raises
    ------
    ValueError
        If ``n_avg_imgs`` is smaller than 1.
    """

    def __init__(self,
                 n_avg_imgs: int = 100,
                 noise_level: float = 0.5,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 cache_dir: str = None,
                 ) -> None:
        # Fail fast, before downloading/loading the model: with fewer than one
        # reconstruction there is nothing to average in ``__call__``.
        if n_avg_imgs < 1:
            raise ValueError(f"n_avg_imgs must be at least 1, got {n_avg_imgs}")

        self.n_avg_imgs = n_avg_imgs
        self.noise_level = noise_level
        self.device = device
        self.vae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            revision="refs/pr/1",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        ).to(device)
        # No ``ToTensor`` step: ``__call__`` receives a tensor already.
        # ``Normalize(0.5, 0.5)`` computes ``(x - 0.5) / 0.5``; ``__call__``
        # undoes it with ``x * 0.5 + 0.5`` after decoding.
        # NOTE(review): this presumes inputs are scaled to [0, 1] - confirm
        # against callers.
        self.preprocess_transform = transforms.Compose([
            transforms.Normalize([0.5], [0.5]),
        ])

    def add_noise_to_embeddings(self, embeddings):
        """
        Return ``embeddings`` plus zero-mean Gaussian noise.

        The noise has the same shape, dtype and device as ``embeddings`` and a
        standard deviation of ``self.noise_level``. The input is not modified.
        """
        noise = torch.randn_like(embeddings) * self.noise_level
        return embeddings + noise

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
        Attack ``img`` by averaging ``n_avg_imgs`` noisy VAE reconstructions.

        Parameters
        ----------
        img: torch.Tensor
            Image tensor. A tensor with fewer than 4 dimensions gets a leading
            batch dimension added before encoding.

        Returns
        -------
        torch.Tensor
            The averaged reconstruction, clamped to ``[0, 1]``, as a float32
            tensor on the CPU. A leading dimension of size 1 is squeezed away.
        """
        img_transformed = self.preprocess_transform(img).to(self.device)
        if img_transformed.dim() < 4:
            img_transformed = img_transformed.unsqueeze(0)

        with torch.no_grad():
            # The VAE weights are bfloat16, so the input is cast to match.
            latents = self.vae.encode(
                img_transformed.to(torch.bfloat16)
            ).latent_dist.sample()

            # Keep a running float32 sum instead of collecting every decoded
            # image and stacking them: peak memory stays constant regardless
            # of ``n_avg_imgs`` and the accumulation does not lose precision
            # to bfloat16 rounding.
            running_sum = None
            for _ in range(self.n_avg_imgs):
                noisy_latents = self.add_noise_to_embeddings(latents)
                decoded = self.vae.decode(noisy_latents).sample.to(torch.float32)
                if running_sum is None:
                    running_sum = decoded
                else:
                    running_sum.add_(decoded)
            att_img_tensor = running_sum / self.n_avg_imgs

        # Invert the Normalize(0.5, 0.5) preprocessing and clip to the valid
        # image range.
        att_img_tensor = (att_img_tensor * 0.5 + 0.5).clamp(0, 1)
        return att_img_tensor.squeeze(0).cpu()