from dataclasses import dataclass
from functools import partialmethod
from pathlib import Path
from typing import Any, Dict, Literal, Optional

import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from trustmark import TrustMark

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.config import Params
from wibench.download import requires_download
from wibench.typing import TorchImg
from wibench.watermark_data import TorchBitWatermarkData
# Remote share holding the TrustMark model files, and the name under which
# they are handed to ``requires_download``.
URL: str = "https://nextcloud.ispras.ru/index.php/s/roAn4YYpXfq5Y7E"
NAME: str = "trustmark"

# Every file that must be present in the models cache. For each model variant
# (P, B, C, Q) there is an encoder/decoder checkpoint with its
# ``trustmark_*.yaml`` config, plus a ``trustmark_rm_*`` checkpoint/config pair.
REQUIRED_FILES = [
    # Variant P
    "trustmark_rm_P.ckpt",
    "trustmark_rm_P.yaml",
    "encoder_P.ckpt",
    "decoder_P.ckpt",
    "trustmark_P.yaml",
    # Variants B, C and Q
    "trustmark_rm_B.ckpt",
    "trustmark_rm_C.ckpt",
    "trustmark_rm_Q.ckpt",
    "decoder_B.ckpt",
    "decoder_Q.ckpt",
    "decoder_C.ckpt",
    "encoder_C.ckpt",
    "encoder_Q.ckpt",
    "encoder_B.ckpt",
    "trustmark_B.yaml",
    "trustmark_C.yaml",
    "trustmark_Q.yaml",
    "trustmark_rm_B.yaml",
    "trustmark_rm_C.yaml",
    "trustmark_rm_Q.yaml",
]

# Directory the model files are read from when no ``models_cache`` is given.
DEFAULT_MODELS_CACHE: str = "./model_files/trustmark"
@dataclass
class TrustMarkParams(Params):
    """Configuration parameters for the TrustMark algorithm.

    Attributes
    ----------
    wm_length : int
        Length of the watermark message to be embedded, in bits (default 100).
    model_type : Literal['Q', 'B', 'C', 'P']
        Specifies the model architecture variant (default 'Q'):

        - 'Q': (Quality) Trade-off between robustness and imperceptibility. Uses a ResNet-50 decoder.
        - 'B': (Beta) Very similar to Q, included mainly for reproducing the paper. Uses a ResNet-50 decoder.
        - 'C': (Compact) Uses a ResNet-18 decoder (smaller model size). Slightly lower visual quality.
        - 'P': (Perceptual) Very high visual quality and good robustness. ResNet-50 decoder trained
          with a much higher weight on the perceptual loss (see paper).
    wm_strength : float
        Controls the visibility/strength of the watermark embedding (default 0.75).
    """

    wm_length: int = 100
    model_type: Literal['Q', 'B', 'C', 'P'] = 'Q'
    wm_strength: float = 0.75
@requires_download(URL, NAME, REQUIRED_FILES)
class TrustMarkWrapper(BaseAlgorithmWrapper):
    """`TrustMark <https://arxiv.org/abs/2311.18297>`_: Universal Watermarking for Arbitrary Resolution Images - Image Watermarking Algorithm.

    Provides an interface for embedding and extracting watermarks using the TrustMark watermarking algorithm.
    Based on the code from `here <https://github.com/adobe/trustmark>`__.

    Parameters
    ----------
    params : Dict[str, Any], optional
        TrustMark algorithm configuration parameters (default: empty).
        The optional ``models_cache`` key names the directory that model
        configs/checkpoints are loaded from (default ``DEFAULT_MODELS_CACHE``;
        ``None`` disables the path redirection). All remaining keys are passed
        to :class:`TrustMarkParams`. The mapping passed in is not modified.
    """

    name = NAME

    # ``TrustMark.load_model`` as it was when this class was defined, i.e.
    # before any patch installed by ``__init__``. Every instance wraps this
    # original rather than the currently installed (possibly already patched)
    # method: wrapping the patched one would stack wrappers, and the innermost
    # wrapper -- carrying the *first* instance's ``models_cache`` -- would
    # override the paths requested by every later instance.
    _original_load_model = staticmethod(TrustMark.load_model)

    @staticmethod
    def patched_load_model(trustmark, config_path, weight_path, *args, models_cache, old_func, **kwargs):
        """Call ``old_func`` with the config/weight paths redirected into ``models_cache``.

        Only the file names of ``config_path`` and ``weight_path`` are kept and
        joined onto ``models_cache``. When ``models_cache`` is falsy the paths
        are passed through unchanged (as strings). All other positional and
        keyword arguments are forwarded to ``old_func`` as-is.
        """
        if models_cache:
            config_path = Path(models_cache) / Path(config_path).name
            weight_path = Path(models_cache) / Path(weight_path).name
        return old_func(trustmark, str(config_path), str(weight_path), *args, **kwargs)

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # Work on a copy: popping ``models_cache`` must not mutate the caller's
        # dict (a reused config would otherwise lose the key on the next use).
        params = dict(params) if params is not None else {}
        models_cache = params.pop("models_cache", DEFAULT_MODELS_CACHE)
        super().__init__(TrustMarkParams(**params))
        self.params: TrustMarkParams
        self.device = self.params.device
        # ``None`` is a legitimate value (it disables the redirection in
        # ``patched_load_model``); ``Path(None)`` would raise TypeError.
        self.models_cache = Path(models_cache) if models_cache is not None else None
        # Redirect TrustMark's model loading into ``models_cache``. The patch is
        # installed on the TrustMark class itself (process-wide) and remains in
        # place after construction; it always wraps the original loader.
        TrustMark.load_model = partialmethod(
            self.patched_load_model,
            models_cache=models_cache,
            old_func=self._original_load_model,
        )
        self.tm = TrustMark(use_ECC=False, device=self.device,
                            model_type=self.params.model_type)

    def _wm_to_str(self, wm: torch.Tensor) -> str:
        """Serialize a bit tensor into a string of '0'/'1' characters (flattened order)."""
        # detach()/cpu() so grad-tracking or GPU tensors can be read; int() so a
        # bool tensor yields '0'/'1' instead of 'False'/'True'.
        return ''.join(str(int(bit)) for bit in wm.detach().cpu().flatten().tolist())

    def _str_to_wm(self, wm_str: str) -> torch.Tensor:
        """Parse a '0'/'1' string into an int64 bit tensor of shape (1, len(wm_str))."""
        return torch.tensor([int(i) for i in wm_str], dtype=torch.int64).unsqueeze(0)

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData):
        """Embed a watermark into the input image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data : TorchBitWatermarkData
            Torch bit message with data type torch.int64

        Returns
        -------
        torch.Tensor
            The watermarked image, converted back from PIL with ``to_tensor``.
        """
        # TrustMark's encoder takes a PIL image; in 'binary' mode the payload is
        # given as a string of '0'/'1' characters.
        img_pil = to_pil_image(image)
        wm_str = self._wm_to_str(watermark_data.watermark)
        emb_pil = self.tm.encode(
            img_pil, wm_str, MODE='binary', WM_STRENGTH=self.params.wm_strength)
        # NOTE(review): the result is built from a PIL image, so it does not
        # follow ``image.device`` -- confirm callers do not expect that.
        return to_tensor(emb_pil)

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate watermark payload data for the TrustMark watermarking algorithm.

        Returns
        -------
        TorchBitWatermarkData
            Random bit message of ``params.wm_length`` bits with data type torch.int64.

        Notes
        -----
        - Called automatically during embedding
        """
        return TorchBitWatermarkData.get_random(self.params.wm_length)