Source code for wibench.algorithms.robust_wide.wrapper

from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from torchvision.transforms import transforms

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.typing import TorchImg
from wibench.watermark_data import TorchBitWatermarkData
from wibench.module_importer import ModuleImporter
from wibench.config import Params
from wibench.download import requires_download


# Arguments handed to the ``requires_download`` decorator on the wrapper class:
# the download location, the algorithm name and the files that must be present.
URL = "https://nextcloud.ispras.ru/index.php/s/6PWaxBTBJTA688x"
# Also used as the wrapper's ``name`` class attribute.
NAME = "robust_wide"
REQUIRED_FILES = ["wm_model.ckpt"]

# Fallback location of the RobustWide submodule when the params do not resolve one.
DEFAULT_MODULE_PATH = "./submodules/RobustWide"
# Default value of ``RobustWideParams.checkpoint_path``.
DEFAULT_CHECKPOINT_PATH = "./model_files/robust_wide/wm_model.ckpt"


@dataclass
class RobustWideEncoderParams:
    """Settings for the encoder half of the Robust-Wide watermark model.

    An instance is stored as ``wm_enc_config`` inside
    ``RobustWideWmModelParams`` and is converted to a plain dict before being
    passed to the model constructor.

    Attributes
    ----------
    image_size : int
        Image size setting for the encoder (default 512)
    message_length : int
        Number of bits in the watermark message (default 64)
    in_channels : int
        Input channel setting (default 3)
    channels : int
        Channel-width setting (default 64)
    norm_type : str
        Normalization type identifier (default "batch")
    final_skip : bool
        Final skip connection flag (default True)
    """

    # NOTE(review): the exact meaning of each value is defined by the upstream
    # RobustWide model code, which is not visible here.
    image_size: int = 512
    message_length: int = 64
    in_channels: int = 3
    channels: int = 64
    norm_type: str = "batch"
    final_skip: bool = True
@dataclass
class RobustWideDecoderParams:
    """Settings for the decoder half of the Robust-Wide watermark model.

    An instance is stored as ``wm_dec_config`` inside
    ``RobustWideWmModelParams``.

    Attributes
    ----------
    image_size : int
        Image size setting for the decoder (default 512)
    message_length : int
        Number of bits in the watermark message (default 64)
    in_channels : int
        Input channel setting (default 3)
    norm_type : str
        Normalization type identifier (default "batch")
    """

    # NOTE(review): the exact meaning of each value is defined by the upstream
    # RobustWide model code, which is not visible here.
    image_size: int = 512
    message_length: int = 64
    in_channels: int = 3
    norm_type: str = "batch"
@dataclass
class RobustWideWmModelParams:
    """Bundle of encoder and decoder settings for the Robust-Wide model.

    Attributes
    ----------
    wm_enc_config : RobustWideEncoderParams
        Encoder settings (default RobustWideEncoderParams())
    wm_dec_config : RobustWideDecoderParams
        Decoder settings (default RobustWideDecoderParams())
    """

    # default_factory gives every instance its own nested config objects.
    wm_enc_config: RobustWideEncoderParams = field(
        default_factory=RobustWideEncoderParams
    )
    wm_dec_config: RobustWideDecoderParams = field(
        default_factory=RobustWideDecoderParams
    )
@dataclass
class RobustWideParams(Params):
    # An f-string is never treated as a docstring by Python, so the formatted
    # text is assigned to ``__doc__`` explicitly. This keeps the interpolated
    # default path in the documentation instead of silently discarding it.
    __doc__ = f"""Configuration parameters for the Robust-Wide watermarking algorithm.

    Attributes
    ----------
        checkpoint_path : str
            Path to pretrained Robust-Wide model weights (default {DEFAULT_CHECKPOINT_PATH})
        wm_model_config: RobustWideWmModelParams
            Parameters for encoder-decoder network (default RobustWideWmModelParams)
    """

    checkpoint_path: str = DEFAULT_CHECKPOINT_PATH
    wm_model_config: RobustWideWmModelParams = field(
        default_factory=RobustWideWmModelParams
    )
@requires_download(URL, NAME, REQUIRED_FILES)
class RobustWideWrapper(BaseAlgorithmWrapper):
    """Robust-Wide: Robust Watermarking Against Instruction-Driven Image Editing --- Image Watermarking Algorithm [`paper <https://arxiv.org/abs/2402.12688>`__].

    Provides an interface for embedding and extracting watermarks using the Robust-Wide watermarking algorithm.
    Based on the code from the github `repository <https://github.com/hurunyi/Robust-Wide>`__.

    Parameters
    ----------
    params : Dict[str, Any], optional
        Robust-Wide algorithm configuration parameters. ``None`` (the default)
        or an empty dict selects the defaults of ``RobustWideParams``.
    """

    name = NAME

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # Work on a private copy. pop_resolve_module_path() is handed the dict
        # and, as its name suggests, presumably removes the module-path entry
        # from it; neither the caller's dict nor a shared default object
        # should be modified as a side effect of constructing the wrapper.
        params = dict(params) if params else {}
        self.module_path = ModuleImporter.pop_resolve_module_path(params, DEFAULT_MODULE_PATH)
        super().__init__(RobustWideParams(**params))
        self.params: RobustWideParams
        self.device = self.params.device

        with ModuleImporter("RobustWide", self.module_path):
            from RobustWide.model import WatermarkModel

            model = WatermarkModel(**asdict(self.params.wm_model_config))
            # NOTE(review): torch.load unpickles the downloaded checkpoint.
            # Consider weights_only=True if every supported torch version
            # accepts it and the file is a plain state dict.
            model_ckpt = torch.load(Path(self.params.checkpoint_path).resolve(), map_location="cpu")
            model.load_state_dict(model_ckpt)
            model.eval()
            self.model = model.to(self.device)

        size = self.params.wm_model_config.wm_dec_config.image_size
        # Resize + center-crop to the configured size, then map [0, 1] -> [-1, 1].
        self.transform_and_normalize = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.Normalize([0.5], [0.5]),
        ])
        # Inverse of the normalization above: (x - (-1)) / 2 maps [-1, 1] -> [0, 1].
        self.denormalize = transforms.Compose([
            transforms.Normalize(mean=[-1.0], std=[2.0]),
        ])

    def _preprocess(self, image: TorchImg) -> torch.Tensor:
        """Resize, crop and normalize ``image``, add a batch dimension, clamp to [-1, 1] and move it to the model device."""
        return self.transform_and_normalize(image).unsqueeze(0).clamp(-1, 1).to(self.device)

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
        """Embed watermark into input image.

        The image is resized and center-cropped to the configured size before
        it is passed to the encoder.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64

        Returns
        -------
        TorchImg
            Encoder output mapped back from [-1, 1], with the batch dimension
            removed and values clamped to [0, 1], on the CPU
        """
        model_input = self._preprocess(image)
        message = watermark_data.watermark.float().to(self.device)
        # The result is detached right away, so no autograd graph is needed.
        with torch.no_grad():
            watermarked = self.model.encoder(model_input, message)
        return self.denormalize(watermarked.detach().cpu()).squeeze(0).clamp(0, 1)

    def extract(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> torch.Tensor:
        """Extract watermark from marked image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64; only the dtype of
            its ``watermark`` tensor is used here

        Returns
        -------
        torch.Tensor
            Decoder output thresholded at 0.5 and cast to the dtype of
            ``watermark_data.watermark``, on the CPU. The batch dimension added
            for the model is not removed.
        """
        model_input = self._preprocess(image)
        # The result is detached right away, so no autograd graph is needed.
        with torch.no_grad():
            extracted_bits = self.model.decoder(model_input)
        return extracted_bits.detach().cpu().gt(0.5).type(watermark_data.watermark.dtype)

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate watermark payload data for Robust-Wide watermarking algorithm.

        Returns
        -------
        TorchBitWatermarkData
            Random bit message produced by ``TorchBitWatermarkData.get_random``
            for the encoder's configured ``message_length``

        Notes
        -----
        - Called automatically during embedding
        """
        return TorchBitWatermarkData.get_random(self.params.wm_model_config.wm_enc_config.message_length)