from typing import Dict, Any
from dataclasses import dataclass, asdict, field
from pathlib import Path
import torch
from torchvision import transforms
from wibench.utils import normalize_image, overlay_difference, resize_torch_img, denormalize_image
from wibench.typing import TorchImg
from wibench.watermark_data import TorchBitWatermarkData
from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.config import Params
from wibench.module_importer import ModuleImporter
from wibench.download import requires_download
URL = "https://nextcloud.ispras.ru/index.php/s/MbxydymDFe2pgjp"
NAME = "maskwm"
REQUIRED_FILES = ["D_32bits.pth", "D_64bits.pth", "D_128bits.pth"]
DEFAULT_SUBMODULE_PATH = "./submodules/MaskWM"
DEFAULT_CHECKPOINT_PATH = "./model_files/maskwm/D_32bits.pth"
[docs]@dataclass
class WmEncoderConfig:
message_length: int = 32
in_channels: int = 3
tp_channels: int = 3
mask_channel: int = 0
channels: int = 64
norm_type: str = "group"
[docs]@dataclass
class WmDecoderConfig:
message_length: int = 32
in_channels: int = 3
tp_channels: int = 3
mask_channel: int = 1
channels: int = 128
norm_type: str = "group"
[docs]@dataclass
class MaskWMParams(Params):
checkpoint_path: str = DEFAULT_CHECKPOINT_PATH
use_jnd: bool = True
jnd_factor: float = 1.3
blue: bool = True
image_size: int = 256
wm_enc_config: WmEncoderConfig = field(default_factory=WmEncoderConfig)
wm_dec_config: WmDecoderConfig = field(default_factory=WmDecoderConfig)
[docs]@requires_download(URL, NAME, REQUIRED_FILES)
class MaskWMWrapper(BaseAlgorithmWrapper):
"""Mask Image Watermarking --- Image Watermarking Algorithm [`paper <https://arxiv.org/pdf/2504.12739>`__].
Provides an interface for embedding and extracting watermarks using the MaskWM algorithm.
Based on the code from `here <https://github.com/hurunyi/MaskWM>`__.
"""
name = NAME
def __init__(self, params: Dict[str, Any] = {}):
module_path = str(Path(params.pop("module_path", DEFAULT_SUBMODULE_PATH)).resolve())
super().__init__(MaskWMParams(**params))
self.params: MaskWMParams
self.device = self.params.device
with ModuleImporter("MaskWM", module_path):
from MaskWM.models.Mask_Model import WatermarkModel
encoder_decoder = WatermarkModel(wm_enc_config=asdict(self.params.wm_enc_config),
wm_dec_config=asdict(self.params.wm_dec_config))
encoder_decoder.load_state_dict(torch.load(Path(self.params.checkpoint_path).resolve(), map_location="cpu"), strict=True)
self.encoder_decoder = encoder_decoder.to(self.device)
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
self.normalize = transforms.Normalize(mean=mean,
std=std)
self.denormalize = transforms.Compose([
transforms.Normalize(mean=[0., 0., 0.], std=[1/x for x in std]),
transforms.Normalize(mean=[-x for x in mean], std=[1.,1.,1.])
])
[docs] def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
"""Embed watermark into input image.
Parameters
----------
image : TorchImg
Input image tensor in (C, H, W) format
watermark_data: TorchBitWatermarkData
Torch bit message with data type torch.int64
"""
image_size = self.params.image_size
resized_image = resize_torch_img(image, [image_size, image_size])
normalized_image = normalize_image(resized_image, self.normalize).to(self.device)
watermark_image_raw = self.encoder_decoder.encoder(normalized_image,
watermark_data.watermark.type(torch.float32).to(self.device),
use_jnd=self.params.use_jnd,
jnd_factor=self.params.jnd_factor,
blue=self.params.blue)
watermark_image = overlay_difference(image,
resized_image,
denormalize_image(watermark_image_raw.cpu(),
self.denormalize)).detach().cpu()
return watermark_image
[docs] def watermark_data_gen(self) -> TorchBitWatermarkData:
"""Generate watermark payload data for TrustMark watermarking algorithm.
Returns
-------
TorchBitWatermarkData
Torch bit message with data type torch.int64 and shape of (0, message_length)
Notes
-----
- Called automatically during embedding
"""
return TorchBitWatermarkData.get_random(self.params.wm_enc_config.message_length)