from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from typing_extensions import Any, Dict

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.algorithms.stega_stamp.stega_stamp import StegaStamp
from wibench.download import requires_download
from wibench.typing import TorchImg
from wibench.utils import torch_img2numpy_bgr, numpy_bgr2torch_img
from wibench.watermark_data import TorchBitWatermarkData
# Algorithm identifier: used as the wrapper's ``name`` attribute and passed,
# together with URL and REQUIRED_FILES, to the ``requires_download`` decorator.
NAME = "stega_stamp"
# Location the ``requires_download`` decorator is given for the model files.
URL = "https://nextcloud.ispras.ru/index.php/s/K6wrA6KweXZ2DGL"
# File name of the ONNX model (the only required file).
_WEIGHTS_FILENAME = "stega_stamp.onnx"
REQUIRED_FILES = [_WEIGHTS_FILENAME]
# Default value of ``StegaStampParams.weights_path``.
DEFAULT_WEIGHT_PATH = "./model_files/stega_stamp/stega_stamp.onnx"
@dataclass
class StegaStampParams:
    """Configuration parameters for the StegaStamp watermarking algorithm.

    Attributes
    ----------
    weights_path : str
        Path to pretrained StegaStamp model weights
        (default ``DEFAULT_WEIGHT_PATH`` = "./model_files/stega_stamp/stega_stamp.onnx")
    wm_length : int
        Length of the watermark message to embed (in bits) (default 100)
    width : int
        Width of the input image (in pixels). Defines the horizontal dimension of the input tensor (default 400)
    height : int
        Height of the input image (in pixels). Defines the vertical dimension of the input tensor (default 400)
    alpha : float
        Weight parameter controlling the trade-off between watermark robustness and image quality during embedding (default 1.0)
    """

    # NOTE: the docstring above is deliberately a plain string literal. An
    # f-string in that position is an ordinary expression, not a docstring,
    # so ``__doc__`` would not contain this text.

    weights_path: str = DEFAULT_WEIGHT_PATH
    wm_length: int = 100
    width: int = 400
    height: int = 400
    alpha: float = 1.0
@requires_download(URL, NAME, REQUIRED_FILES)
class StegaStampWrapper(BaseAlgorithmWrapper):
    """StegaStamp: Invisible Hyperlinks in Physical Photographs --- Image Watermarking Algorithm [`paper <https://arxiv.org/abs/1904.05343>`__].

    Provides an interface for embedding and extracting watermarks using the StegaStamp watermarking algorithm.
    Based on the code from the github `repository <https://github.com/tancik/StegaStamp>`__.

    Parameters
    ----------
    params : Optional[Dict[str, Any]]
        StegaStamp algorithm configuration parameters, forwarded as keyword
        arguments to ``StegaStampParams``. ``None`` (the default) or an empty
        dict selects the default configuration.
    """

    # NOTE(review): only ``embed`` and ``watermark_data_gen`` are defined in
    # this class body; where extraction is implemented should be confirmed.

    name = NAME

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # ``None`` replaces the former shared mutable ``{}`` default; both
        # ``None`` and an empty dict yield the default StegaStampParams.
        super().__init__(StegaStampParams(**(params or {})))
        self.params: StegaStampParams
        # Resolve to an absolute path before handing it to the model.
        self.model_filepath = Path(self.params.weights_path).resolve()
        self.stega_stamp = StegaStamp(
            self.model_filepath,
            self.params.wm_length,
            self.params.width,
            self.params.height,
            self.params.alpha,
        )

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
        """Embed watermark into input image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64

        Returns
        -------
        TorchImg
            Result of ``StegaStamp.encode`` converted back to a torch image
        """
        bgr_image = torch_img2numpy_bgr(image)
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a
        # non-CPU device; detach and move to CPU first. For a plain CPU tensor
        # both calls are no-ops, so existing callers see the same array.
        message_bits = watermark_data.watermark.squeeze(0).detach().cpu().numpy()
        return numpy_bgr2torch_img(self.stega_stamp.encode(bgr_image, message_bits))

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate watermark payload data for StegaStamp watermarking algorithm.

        Returns
        -------
        TorchBitWatermarkData
            Random torch bit message with data type torch.int64, requested with
            length ``params.wm_length``. Shape is presumably (1, wm_length),
            since ``embed`` squeezes dimension 0 -- confirm against
            ``TorchBitWatermarkData.get_random``.

        Notes
        -----
        - Called automatically during embedding
        """
        return TorchBitWatermarkData.get_random(self.params.wm_length)