from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from torchvision.transforms import transforms

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.config import Params
from wibench.download import requires_download
from wibench.module_importer import ModuleImporter
from wibench.typing import TorchImg
from wibench.watermark_data import TorchBitWatermarkData
# Share link passed to ``requires_download`` on the wrapper class
# (presumably where the pretrained files are fetched from -- confirm in
# ``wibench.download``).
URL = "https://nextcloud.ispras.ru/index.php/s/6PWaxBTBJTA688x"
# Algorithm identifier: exposed as ``RobustWideWrapper.name`` and also passed
# to ``requires_download``.
NAME = "robust_wide"
# Files handed to ``requires_download`` as the required artefacts.
REQUIRED_FILES = ["wm_model.ckpt"]
# Fallback location of the RobustWide submodule when the caller does not
# supply a module path in the wrapper parameters.
DEFAULT_MODULE_PATH = "./submodules/RobustWide"
# Default value of ``RobustWideParams.checkpoint_path``.
DEFAULT_CHECKPOINT_PATH = "./model_files/robust_wide/wm_model.ckpt"
@dataclass
class RobustWideEncoderParams:
    """Constructor arguments for the Robust-Wide watermark encoder.

    The wrapper converts this dataclass to a plain dict (via ``asdict`` on the
    enclosing model config) and forwards it to the upstream ``WatermarkModel``.
    ``message_length`` is additionally used by the wrapper to size the random
    payload it generates.

    NOTE(review): the exact meaning of ``channels``, ``norm_type`` and
    ``final_skip`` is defined by the upstream RobustWide code -- confirm there
    before changing the defaults.
    """

    # Field order is part of the positional constructor signature; keep it.
    image_size: int = 512
    message_length: int = 64
    in_channels: int = 3
    channels: int = 64
    norm_type: str = "batch"
    final_skip: bool = True
@dataclass
class RobustWideDecoderParams:
    """Constructor arguments for the Robust-Wide watermark decoder.

    The wrapper converts this dataclass to a plain dict (via ``asdict`` on the
    enclosing model config) and forwards it to the upstream ``WatermarkModel``.
    ``image_size`` is additionally read by the wrapper to pick the resize /
    centre-crop resolution applied to input images.

    NOTE(review): the semantics of ``norm_type`` are defined by the upstream
    RobustWide code -- confirm there before changing the default.
    """

    # Field order is part of the positional constructor signature; keep it.
    image_size: int = 512
    message_length: int = 64
    in_channels: int = 3
    norm_type: str = "batch"
@dataclass
class RobustWideWmModelParams:
    """Bundle of encoder and decoder settings for the Robust-Wide model.

    The wrapper expands this object with ``asdict`` and passes the result as
    keyword arguments to the upstream ``WatermarkModel``, so the field names
    ``wm_enc_config`` / ``wm_dec_config`` must match that constructor.
    """

    # ``default_factory`` gives every instance its own nested config objects
    # instead of sharing one mutable default between instances.
    wm_enc_config: RobustWideEncoderParams = field(default_factory=RobustWideEncoderParams)
    wm_dec_config: RobustWideDecoderParams = field(default_factory=RobustWideDecoderParams)
@dataclass
class RobustWideParams(Params):
    """Configuration parameters for the Robust-Wide watermarking algorithm.

    Attributes
    ----------
    checkpoint_path : str
        Path to pretrained Robust-Wide model weights
        (default ``DEFAULT_CHECKPOINT_PATH``, i.e. "./model_files/robust_wide/wm_model.ckpt")
    wm_model_config: RobustWideWmModelParams
        Parameters for encoder-decoder network (default RobustWideWmModelParams)
    """

    # NOTE: the summary above is deliberately a plain string literal. An
    # f-string in docstring position is evaluated and discarded -- Python never
    # stores it in ``__doc__`` -- so the default path is spelled out instead of
    # being interpolated.
    checkpoint_path: str = DEFAULT_CHECKPOINT_PATH
    wm_model_config: RobustWideWmModelParams = field(default_factory=RobustWideWmModelParams)
[docs]@requires_download(URL, NAME, REQUIRED_FILES)
class RobustWideWrapper(BaseAlgorithmWrapper):
"""Robust-Wide: Robust Watermarking Against Instruction-Driven Image Editing --- Image Watermarking Algorithm [`paper <https://arxiv.org/abs/2402.12688>`__].
Provides an interface for embedding and extracting watermarks using the Robust-Wide watermarking algorithm.
Based on the code from the github `repository <https://github.com/hurunyi/Robust-Wide>`__.
Parameters
----------
params : Dict[str, Any]
Robust-Wide algorithm configuration parameters (default EmptyDict)
"""
name = NAME
def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
    """Load the pretrained Robust-Wide model and build the image transforms.

    Parameters
    ----------
    params : Dict[str, Any], optional
        Robust-Wide algorithm configuration parameters. ``None`` (the default)
        is treated as an empty dict. The mapping is copied before use, so the
        caller's dict is never modified.
    """
    # ``pop_resolve_module_path`` is handed the dict and (going by its name)
    # removes the module-path entry from it. Work on a private copy so that
    # neither a caller-owned config nor a shared default gets mutated; reusing
    # one config dict for several wrappers would otherwise silently lose the
    # custom module path after the first construction.
    params = dict(params) if params is not None else {}
    self.module_path = ModuleImporter.pop_resolve_module_path(params, DEFAULT_MODULE_PATH)
    super().__init__(RobustWideParams(**params))
    self.params: RobustWideParams
    self.device = self.params.device

    # NOTE(review): the source indentation was lost, so the exact extent of
    # this ``with`` block is reconstructed. All model set-up is kept inside it
    # so that anything needing the RobustWide package importable still works.
    with ModuleImporter("RobustWide", self.module_path):
        from RobustWide.model import WatermarkModel

        model = WatermarkModel(**asdict(self.params.wm_model_config))
        # SECURITY NOTE: ``torch.load`` unpickles the file; only load
        # checkpoints from a trusted source (consider ``weights_only=True``
        # where the installed torch version supports it).
        model_ckpt = torch.load(Path(self.params.checkpoint_path).resolve(), map_location="cpu")
        model.load_state_dict(model_ckpt)
        model.eval()
        self.model = model.to(self.device)

    size = self.params.wm_model_config.wm_dec_config.image_size
    # Resize + centre-crop to the model resolution, then map [0, 1] -> [-1, 1].
    self.transform_and_normalize = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size),
        transforms.Normalize([0.5], [0.5]),
    ])
    # Inverse of the normalisation above: (x - (-1)) / 2 maps [-1, 1] -> [0, 1].
    self.denormalize = transforms.Compose([
        transforms.Normalize(mean=[-1.0], std=[2.0]),
    ])
def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
    """Embed watermark into input image.

    Parameters
    ----------
    image : TorchImg
        Input image tensor in (C, H, W) format
    watermark_data: TorchBitWatermarkData
        Torch bit message with data type torch.int64

    Returns
    -------
    TorchImg
        Watermarked image on the CPU with values clamped to [0, 1] and the
        batch dimension removed.

    Notes
    -----
    - The input is passed through ``self.transform_and_normalize`` first, so
      the result has the model resolution, not the original H x W
      (assuming the encoder preserves spatial size -- TODO confirm).
    """
    model_input = self.transform_and_normalize(image).unsqueeze(0).clamp(-1, 1).to(self.device)
    message = watermark_data.watermark.float().to(self.device)
    # Inference only: skip autograd bookkeeping instead of building a graph
    # that is thrown away immediately afterwards.
    with torch.no_grad():
        encoded = self.model.encoder(model_input, message)
    return self.denormalize(encoded.detach().cpu()).squeeze(0).clamp(0, 1)
def watermark_data_gen(self) -> TorchBitWatermarkData:
    """Generate watermark payload data for Robust-Wide watermarking algorithm.

    Returns
    -------
    TorchBitWatermarkData
        Random bit message produced by ``TorchBitWatermarkData.get_random``
        for the encoder's configured ``message_length``.

    Notes
    -----
    - Called automatically during embedding
    - NOTE(review): the previous docstring stated a shape of
      (0, message_length); that is presumably a typo for
      (1, message_length) -- confirm against ``get_random``.
    """
    encoder_config = self.params.wm_model_config.wm_enc_config
    return TorchBitWatermarkData.get_random(encoder_config.message_length)