# Source code for wibench.algorithms.sshidden.wrapper

import torch

from dataclasses import dataclass
from torchvision import transforms
from typing_extensions import Dict, Any
from pathlib import Path

from wibench.module_importer import ModuleImporter
from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.typing import TorchImg
from wibench.utils import (
    resize_torch_img,
    normalize_image,
    denormalize_image,
    overlay_difference
)
from wibench.config import Params
from wibench.watermark_data import TorchBitWatermarkData


# Fallback module location handed to ModuleImporter.pop_resolve_module_path()
# in SSHiddenWrapper.__init__.
DEFAULT_MODULE_PATH = "./submodules/stable_signature/hidden"
# Default value of SSHiddenParams.ckpt_path. SSHiddenWrapper.__init__ resolves
# the path with Path(...).resolve(), so a relative value is interpreted
# against the current working directory.
DEFAULT_CHECKPOINT_PATH = "./submodules/stable_signature/hidden/ckpts/hidden_replicate.pth"


@dataclass
class SSHiddenParams(Params):
    """Configuration parameters for the HiDDeN (Hiding Data in Deep Networks) algorithm from StableSignature.

    These parameters define the working image dimensions, the watermark
    length, and the architecture of the encoder and decoder networks used for
    image watermarking.

    Attributes
    ----------
    ckpt_path : str
        Path to the pretrained checkpoint (default ``DEFAULT_CHECKPOINT_PATH``)
    encoder_depth : int
        Number of blocks in the encoder network; passed to ``HiddenEncoder``
        as ``num_blocks`` (default 4)
    encoder_channels : int
        Channel count passed to ``HiddenEncoder`` as ``channels`` (default 64)
    decoder_depth : int
        Number of blocks in the decoder network; passed to ``HiddenDecoder``
        as ``num_blocks`` (default 8)
    decoder_channels : int
        Channel count passed to ``HiddenDecoder`` as ``channels`` (default 64)
    num_bits : int
        Length of the watermark message to embed, in bits (default 48)
    attenuation : str
        Attenuation strategy. ``SSHiddenWrapper`` builds a ``JND`` attenuation
        when the value is exactly ``"jnd"`` and uses no attenuation for any
        other value (default 'jnd')
    scale_channels : bool
        Flag forwarded to ``EncoderWithJND`` (default False)
    scaling_i : float
        Image scaling factor forwarded to ``EncoderWithJND`` (default 1.0)
    scaling_w : float
        Watermark scaling factor forwarded to ``EncoderWithJND`` (default 1.5)
    H : int
        Height, in pixels, that images are resized to before being fed to the
        encoder/decoder (default 512)
    W : int
        Width, in pixels, that images are resized to before being fed to the
        encoder/decoder (default 512)
    """

    # NOTE: the docstring above is deliberately a plain string literal. An
    # f-string is an ordinary expression, not a docstring, so it would never
    # be stored in ``__doc__`` and the documentation would silently vanish.

    ckpt_path: str = DEFAULT_CHECKPOINT_PATH
    encoder_depth: int = 4
    encoder_channels: int = 64
    decoder_depth: int = 8
    decoder_channels: int = 64
    num_bits: int = 48
    attenuation: str = "jnd"
    scale_channels: bool = False
    scaling_i: float = 1.0
    scaling_w: float = 1.5
    H: int = 512
    W: int = 512
# Per-channel (R, G, B) normalization with the ImageNet mean/std values.
NORMALIZE_IMAGENET = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Inverse of NORMALIZE_IMAGENET expressed as another Normalize: with
# mean' = -mean / std and std' = 1 / std, (x - mean') / std' == x * std + mean.
UNNORMALIZE_IMAGENET = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
)

# Single-step pipeline that only applies the ImageNet normalization.
default_transform = transforms.Compose([NORMALIZE_IMAGENET])
class SSHiddenWrapper(BaseAlgorithmWrapper):
    """HiDDeN watermarking algorithm adapted from the Stable Signature (SSHiDDeN)
    [`paper <https://arxiv.org/pdf/2303.15435>`__].

    The wrapper loads a pretrained HiDDeN encoder/decoder pair and, when
    ``attenuation == "jnd"``, wraps the encoder with a Just Noticeable
    Difference (JND) attenuation that modulates the embedding strength.

    Based on the code from `here <https://github.com/facebookresearch/stable_signature/tree/main>`__.

    Parameters
    ----------
    params : Dict[str, Any]
        SSHiDDeN algorithm configuration parameters, used as keyword arguments
        for :class:`SSHiddenParams` after the module-path entry has been
        resolved (default: empty dict)
    """

    name = "sshidden"

    def __init__(self, params: Dict[str, Any] = {}) -> None:
        """Import the HiDDeN sources, build the networks and load the checkpoint.

        Parameters
        ----------
        params : Dict[str, Any]
            Configuration overrides; see :class:`SSHiddenParams`. The mapping
            is copied and never modified.

        Raises
        ------
        KeyError
            If the checkpoint has no ``'encoder_decoder'`` entry.
        """
        # Work on a private shallow copy. The dict is handed to
        # ``pop_resolve_module_path`` (which, judging by its name, removes keys
        # from it), and the default argument is a single shared ``{}`` object;
        # without the copy the caller's dict or that shared default could be
        # mutated across instantiations.
        params = dict(params)
        self.module_path = ModuleImporter.pop_resolve_module_path(params, DEFAULT_MODULE_PATH)

        # NOTE(review): the original indentation was lost in this listing; only
        # the imports are placed inside the ModuleImporter context here.
        # Confirm nothing below needs the context to still be active.
        with ModuleImporter("SSHIDDEN", self.module_path):
            from SSHIDDEN.models import (
                HiddenEncoder,
                HiddenDecoder,
                EncoderWithJND,
            )
            from SSHIDDEN.attenuations import JND

        super().__init__(SSHiddenParams(**params))
        self.params: SSHiddenParams
        self.device = self.params.device

        # SECURITY: ``weights_only=False`` lets torch.load unpickle arbitrary
        # objects, which can execute code. Only point ``ckpt_path`` at
        # checkpoints from a trusted source.
        checkpoint = torch.load(
            Path(self.params.ckpt_path).resolve(),
            map_location=self.device,
            weights_only=False,
        )
        state_dict = checkpoint['encoder_decoder']

        # Drop the 'module.' fragments from the keys (presumably left by a
        # DataParallel/DDP wrapper at training time), then split the joint
        # state dict into its encoder and decoder parts.
        encoder_decoder_state_dict = {
            k.replace('module.', ''): v for k, v in state_dict.items()
        }
        encoder_state_dict = {
            k.replace('encoder.', ''): v
            for k, v in encoder_decoder_state_dict.items()
            if 'encoder' in k
        }
        decoder_state_dict = {
            k.replace('decoder.', ''): v
            for k, v in encoder_decoder_state_dict.items()
            if 'decoder' in k
        }

        self.decoder = HiddenDecoder(
            num_blocks=self.params.decoder_depth,
            num_bits=self.params.num_bits,
            channels=self.params.decoder_channels,
        )
        encoder = HiddenEncoder(
            num_blocks=self.params.encoder_depth,
            num_bits=self.params.num_bits,
            channels=self.params.encoder_channels,
        )

        # Any value other than "jnd" disables attenuation. JND is given the
        # un-normalization transform as its ``preprocess`` step.
        attenuation = JND(preprocess=UNNORMALIZE_IMAGENET) if self.params.attenuation == "jnd" else None
        self.encoder_with_jnd = EncoderWithJND(
            encoder,
            attenuation,
            self.params.scale_channels,
            self.params.scaling_i,
            self.params.scaling_w,
        )

        encoder.load_state_dict(encoder_state_dict)
        self.decoder.load_state_dict(decoder_state_dict)

        # Inference only: move to the target device and switch to eval mode.
        self.encoder_with_jnd = self.encoder_with_jnd.to(self.device).eval()
        self.decoder = self.decoder.to(self.device).eval()

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
        """Embed watermark into input image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64

        Returns
        -------
        TorchImg
            Watermarked image as returned by ``overlay_difference``.
        """
        # Map the bit message to a signed float message: 0 -> -1.0, 1 -> +1.0.
        msg = 2 * watermark_data.watermark.type(torch.float) - 1

        # The network works on a fixed (H, W) ImageNet-normalized input.
        resized_image = resize_torch_img(image, [self.params.H, self.params.W])
        normalized_resized_image = normalize_image(resized_image, NORMALIZE_IMAGENET)

        with torch.no_grad():
            img_w = self.encoder_with_jnd(normalized_resized_image.to(self.device), msg.to(self.device))

        denormalized_marked_image = denormalize_image(img_w.cpu(), UNNORMALIZE_IMAGENET)
        # Presumably transfers the change between the resized and the marked
        # image back onto the original ``image`` - confirm in wibench.utils.
        marked_image = overlay_difference(image, resized_image, denormalized_marked_image)
        return marked_image

    def extract(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> Any:
        """Extract watermark from marked image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64. Not used by the
            extraction itself.

        Returns
        -------
        Any
            Boolean tensor on the CPU: ``True`` where the decoder output is
            strictly positive.
        """
        resized_image = resize_torch_img(image, (self.params.H, self.params.W))
        normalized_image = normalize_image(resized_image, NORMALIZE_IMAGENET)
        with torch.no_grad():
            ft = self.decoder(normalized_image.to(self.device)).cpu()
        # Threshold the decoder output at zero to obtain the bit decisions.
        return ft > 0

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate watermark payload data for SSHiDDeN watermarking algorithm.

        Returns
        -------
        TorchBitWatermarkData
            Random bit message produced by
            ``TorchBitWatermarkData.get_random(num_bits)``.

        Notes
        -----
        - Intended to be called by the framework when a payload is needed for
          embedding.
        """
        return TorchBitWatermarkData.get_random(self.params.num_bits)