from typing_extensions import Dict, Any, Optional
from dataclasses import dataclass
import torch
import scipy
from torchvision import transforms
from diffusers import DPMSolverMultistepScheduler
from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.config import Params
from wibench.typing import TorchImg
from wibench.module_importer import ModuleImporter
# Default location of the Tree-Ring source checkout; TreeRingWrapper passes it
# to ModuleImporter.pop_resolve_module_path as the default module path.
DEFAULT_MODULE_PATH = "./submodules/tree-ring-watermark"
@dataclass
class TreeRingParams(Params):
    """
    Parameters of the Tree-Ring watermarking algorithm.

    The whole object is forwarded to the Tree-Ring helper functions
    (pattern generation, mask construction, watermark injection), so the
    ``w_*`` field names must match what those helpers read.
    """
    run_name: str = "test"
    dataset: str = "Gustavosta/Stable-Diffusion-Prompts"
    start: int = 1
    end: int = 10
    # Height and width, in pixels, of generated images.
    image_length: int = 512
    # Diffusers model id used to load both the scheduler and the pipeline.
    model_id: str = "WIBE-HuggingFace/stable-diffusion-2-1-base"
    # Boolean switch. The former default, the argparse action name
    # "store_true", is a non-empty string and therefore always truthy.
    with_tracking: bool = False
    # Images requested per prompt; only the first one is returned by embed().
    num_images: int = 1
    # Classifier-free guidance scale passed to the diffusion pipeline.
    guidance_scale: float = 7.5
    # Denoising steps used during generation.
    num_inference_steps: int = 50
    # Steps for detection; None is replaced by num_inference_steps at init.
    test_num_inference_steps: Optional[int] = None
    reference_model: Optional[str] = None
    reference_model_pretrain: Optional[str] = None
    max_num_log_image: int = 100
    # Seed applied once when the wrapper is constructed.
    gen_seed: int = 10
    # Watermark key configuration consumed by the Tree-Ring helpers.
    w_seed: int = 999999
    w_channel: int = 0
    w_pattern: str = "rand"
    w_mask_shape: str = "circle"
    w_radius: int = 10
    w_measurement: str = "l1_complex"
    w_injection: str = "complex"
    w_pattern_const: int = 0
    # NOTE(review): not read in the wrapper code visible here; confirm its
    # meaning and units against the detection code.
    threshold: int = 77
@dataclass
class TreeRingWatermarkData:
    """Per-image payload produced for the Tree-Ring watermarking algorithm.

    Attributes
    ----------
    watermark : torch.Tensor
        Initial diffusion latents carrying the injected key; they are fed to
        the pipeline as the starting noise for generation.
    watermarking_mask : torch.Tensor
        Mask selecting the positions of the (fftshifted) latent spectrum that
        carry the key.
    gt_patch : torch.Tensor
        Ground-truth key in the Fourier domain, compared against at detection.
    """

    # Field order matters: instances are built positionally elsewhere.
    watermark: torch.Tensor
    watermarking_mask: torch.Tensor
    # NOTE(review): TreeRingWrapper.watermark_data_gen stores this field as a
    # complex64 NumPy array, not a tensor; the annotation does not say so.
    gt_patch: torch.Tensor
[docs]class TreeRingWrapper(BaseAlgorithmWrapper):
"""`Tree-Ring Watermarks <https://arxiv.org/abs/2305.20030>`_: Fingerprints for Diffusion Images that are Invisible and Robust - Image Watermarking Algorithm.
Provides an interface for embedding and extracting watermarks in Text2Image task using the Tree-Ring watermarking algorithm.
Based on the code from `here <https://github.com/YuxinWenRick/tree-ring-watermark>`__.
Parameters
----------
params : Dict[str, Any]
Tree-Ring algorithm configuration parameters (default EmptyDict)
"""
name = "treering"
def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
    """Load the Stable Diffusion pipeline and prepare the Tree-Ring key.

    Parameters
    ----------
    params : Dict[str, Any], optional
        Algorithm configuration. The module-path entry is resolved here and
        the remaining entries build :class:`TreeRingParams`. The caller's
        dictionary is not modified; ``None`` means an empty configuration.
    """
    # The Tree-Ring helpers are bound as module globals so that the other
    # methods of this class can call them by bare name. The declaration has
    # to come before the import below binds these names.
    global eval_watermark, get_watermarking_mask, get_watermarking_pattern, inject_watermark, set_random_seed, transform_img

    # Work on a private copy: pop_resolve_module_path may remove the
    # module-path key from the mapping it receives (as its name suggests),
    # and a shared mutable default would leak state between instances.
    params = dict(params) if params is not None else {}
    self.module_path = ModuleImporter.pop_resolve_module_path(params, DEFAULT_MODULE_PATH)
    super().__init__(TreeRingParams(**params))
    self.params: TreeRingParams

    with ModuleImporter("TreeRing", self.module_path):
        from TreeRing.inverse_stable_diffusion import InversableStableDiffusionPipeline
        from TreeRing.optim_utils import (
            eval_watermark,
            get_watermarking_mask,
            get_watermarking_pattern,
            inject_watermark,
            set_random_seed,
            transform_img,
        )

    set_random_seed(self.params.gen_seed)
    if self.params.test_num_inference_steps is None:
        # Detection defaults to the same number of steps as generation.
        self.params.test_num_inference_steps = self.params.num_inference_steps

    self.model_id = self.params.model_id
    self.device = self.params.device
    self.scheduler = DPMSolverMultistepScheduler.from_pretrained(self.model_id, subfolder='scheduler')
    pipe = InversableStableDiffusionPipeline.from_pretrained(
        self.model_id,
        scheduler=self.scheduler,
        torch_dtype=torch.float16,
    )
    self.pipe = pipe.to(self.device)
    # Watermark key injected into every image this wrapper embeds.
    self.ground_truth_patch = get_watermarking_pattern(self.pipe, self.params, self.device)
    self.tester_prompt = ''  # assume at the detection time, the original prompt is unknown
    self.text_embeddings = self.pipe.get_text_embedding(self.tester_prompt)
def embed(self, prompt: str, watermark_data: TreeRingWatermarkData) -> TorchImg:
    """Generate an image for ``prompt`` starting from watermarked latents.

    Parameters
    ----------
    prompt : str
        Text prompt that conditions the diffusion model.
    watermark_data : TreeRingWatermarkData
        Payload whose ``watermark`` latents are used as the initial noise.

    Returns
    -------
    TorchImg
        The first generated image, converted by ``transforms.ToTensor``.
    """
    generation_kwargs = dict(
        num_images_per_prompt=self.params.num_images,
        guidance_scale=self.params.guidance_scale,
        num_inference_steps=self.params.num_inference_steps,
        height=self.params.image_length,
        width=self.params.image_length,
        latents=watermark_data.watermark,
    )
    result = self.pipe(prompt, **generation_kwargs)
    # Only the first image is returned, even if several were requested.
    first_image = result.images[0]
    to_tensor = transforms.ToTensor()
    return to_tensor(first_image)
def _get_p_value(self, reversed_latents_w: torch.Tensor, watermarking_mask: torch.Tensor, gt_patch: torch.Tensor) -> float:
reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1, -2))[watermarking_mask].flatten()
target_patch = gt_patch[watermarking_mask].flatten()
target_patch = torch.concatenate([target_patch.real, target_patch.imag])
reversed_latents_w_fft = torch.concatenate([reversed_latents_w_fft.real, reversed_latents_w_fft.imag])
sigma_w = reversed_latents_w_fft.std()
lambda_w = (target_patch ** 2 / sigma_w ** 2).sum().item()
x_w = (((reversed_latents_w_fft - target_patch) / sigma_w) ** 2).sum().item()
p_w = scipy.stats.ncx2.cdf(x=x_w, df=len(target_patch), nc=lambda_w)
return p_w
def watermark_data_gen(self) -> TreeRingWatermarkData:
    """Create the watermark payload for one Tree-Ring image.

    Fresh initial latents are drawn from the pipeline, a watermarking mask is
    built for them, and the wrapper's ground-truth pattern is injected.

    Returns
    -------
    TreeRingWatermarkData
        Watermarked latents, the mask, and the injected pattern converted to
        a complex64 NumPy array.

    Notes
    -----
    - Called automatically during embedding
    - The pattern computed once in ``__init__`` is both injected and stored,
      so the stored ground truth is exactly what was embedded.
    """
    # Reuse the key prepared in __init__. Recomputing it here was redundant
    # and stored a different object from the one injected. NOTE(review):
    # upstream get_watermarking_pattern appears to re-seed the global RNG,
    # which also made every call draw identical initial latents.
    gt_patch = self.ground_truth_patch
    init_latents_w = self.pipe.get_random_latents()
    # get watermarking mask
    watermarking_mask = get_watermarking_mask(init_latents_w, self.params, self.device)
    # inject watermark
    init_latents_w = inject_watermark(init_latents_w, watermarking_mask, gt_patch, self.params)
    # Cast to complex64 before .numpy(); presumably the patch can be a
    # half-precision complex tensor, which NumPy cannot represent.
    return TreeRingWatermarkData(
        init_latents_w,
        watermarking_mask,
        gt_patch.cpu().type(torch.complex64).numpy(),
    )