# Source code for wibench.attacks.flux_regeneration.regeneration

import torch
from diffusers import FluxImg2ImgPipeline
from wibench.attacks.base import BaseAttack


class FluxRegeneration(BaseAttack):
    """Attack `regeneration` from `here <https://github.com/leiluk1/erasing-the-invisible-beige-box/blob/main/notebooks/treering_attack.ipynb>`__.

    Regenerates an image with the FLUX image-to-image diffusion pipeline
    (``black-forest-labs/FLUX.1-dev``). Each call runs the pipeline once on
    the input image with a fixed text prompt and returns the result.

    **TODO**: check if this works with batches.

    Parameters
    ----------
    device
        Device handed to the offload hooks (or to ``pipeline.to``) and used
        to create the random generator.
    dtype
        Name of a ``torch`` attribute (e.g. ``"bfloat16"``); resolved with
        ``getattr(torch, dtype)`` and used as the pipeline's ``torch_dtype``.
    cpu_offload
        Enable model-level CPU offload. Only consulted when
        ``sequential_cpu_offload`` is false.
    sequential_cpu_offload
        Enable sequential CPU offload; takes precedence over ``cpu_offload``.
    cache_dir
        Forwarded to ``FluxImg2ImgPipeline.from_pretrained``.
    prompt, strength, guidance_scale, num_inference_steps, max_sequence_length
        Stored and forwarded unchanged to every pipeline call.
    """

    def __init__(
        self,
        device: torch.device | str = "cuda" if torch.cuda.is_available() else "cpu",
        dtype: str = "bfloat16",
        cpu_offload: bool = True,
        sequential_cpu_offload: bool = False,
        cache_dir: str | None = None,
        prompt: str = "original image",
        strength: float = 0.3,
        guidance_scale: float = 8.5,
        num_inference_steps: int = 12,
        max_sequence_length: int = 512,
    ) -> None:
        super().__init__()
        self.device = device
        self.dtype = getattr(torch, dtype)

        # Generation settings reused on every call.
        self.prompt = prompt
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps
        self.max_sequence_length = max_sequence_length

        pipeline = FluxImg2ImgPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=self.dtype,
            cache_dir=cache_dir,
        )
        pipeline.set_progress_bar_config(disable=True)

        # Placement: the two offload modes trade speed for VRAM; with both
        # flags off the whole pipeline is moved onto `device`.
        if sequential_cpu_offload:
            pipeline.enable_sequential_cpu_offload(device=device)
        elif cpu_offload:
            pipeline.enable_model_cpu_offload(device=device)
        else:
            pipeline.to(device)
        self.flux_pipeline = pipeline

        # Seeded once here and shared by all subsequent calls.
        self.generator = torch.Generator(self.device).manual_seed(42)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Regenerate a single image.

        Parameters
        ----------
        img
            3-D image tensor; a leading batch axis is added and the result
            is unpacked as ``(batch, channels, height, width)``.
            NOTE(review): presumably values are in ``[0, 1]`` -- confirm
            against the callers.

        Returns
        -------
        torch.Tensor
            The pipeline output with the batch axis squeezed out, cast to
            ``float32`` and moved to the CPU.
        """
        batched = img.unsqueeze(0)
        # The 4-way unpack also rejects inputs that are not 3-D.
        _, _, height, width = batched.shape

        output = self.flux_pipeline(
            image=batched.to(self.dtype),
            prompt=self.prompt,
            height=height,
            width=width,
            strength=self.strength,
            guidance_scale=self.guidance_scale,
            num_inference_steps=self.num_inference_steps,
            max_sequence_length=self.max_sequence_length,
            generator=self.generator,
            output_type="pt",
        )
        regenerated = output.images  # [b,c,h,w]
        return regenerated.squeeze(0).to(torch.float32).cpu()
class FluxRinsing(FluxRegeneration):
    """Attack `rinse2x` from `here <https://github.com/leiluk1/erasing-the-invisible-beige-box/blob/main/notebooks/treering_attack.ipynb>`__.

    Multi-step image purification: the FLUX regeneration pass is applied
    ``rinsing_times`` times in a row, each pass consuming the previous
    pass's output.

    Parameters
    ----------
    rinsing_times
        Number of consecutive regeneration passes per call.
    device, dtype, cpu_offload, sequential_cpu_offload, cache_dir, prompt,
    strength, guidance_scale, num_inference_steps, max_sequence_length
        Forwarded unchanged to ``FluxRegeneration.__init__``.
    """

    def __init__(
        self,
        rinsing_times: int = 2,
        device: torch.device | str = "cuda:0",
        dtype: str = "bfloat16",
        cpu_offload: bool = True,
        sequential_cpu_offload: bool = False,
        cache_dir: str | None = None,
        prompt: str = "original image",
        strength: float = 0.3,
        guidance_scale: float = 8.5,
        num_inference_steps: int = 12,
        max_sequence_length: int = 512,
    ) -> None:
        super().__init__(
            device=device,
            dtype=dtype,
            cpu_offload=cpu_offload,
            sequential_cpu_offload=sequential_cpu_offload,
            cache_dir=cache_dir,
            prompt=prompt,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            max_sequence_length=max_sequence_length,
        )
        self.rinsing_times = rinsing_times

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Run the regeneration pass ``rinsing_times`` times on one image.

        Parameters
        ----------
        img
            3-D image tensor; a leading batch axis is added and the result
            is unpacked as ``(batch, channels, height, width)``.

        Returns
        -------
        torch.Tensor
            Output of the last pass (or the dtype-cast input when
            ``rinsing_times`` yields no iterations), with the batch axis
            squeezed out, cast to ``float32`` and moved to the CPU.
        """
        batched = img.unsqueeze(0)
        # The 4-way unpack also rejects inputs that are not 3-D.
        _, _, height, width = batched.shape
        current = batched.to(self.dtype)

        # Everything except the image is identical for every pass.
        pass_kwargs = {
            "prompt": self.prompt,
            "height": height,
            "width": width,
            "strength": self.strength,
            "guidance_scale": self.guidance_scale,
            "num_inference_steps": self.num_inference_steps,
            "max_sequence_length": self.max_sequence_length,
            "generator": self.generator,
            "output_type": "pt",
        }
        for _ in range(self.rinsing_times):
            current = self.flux_pipeline(image=current, **pass_kwargs).images  # [b,c,h,w]

        return current.squeeze(0).to(torch.float32).cpu()