from dataclasses import dataclass
from typing import Optional

from imwatermark import WatermarkEncoder, WatermarkDecoder
from typing_extensions import Dict, Any

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.utils import numpy_bgr2torch_img, torch_img2numpy_bgr, resize_torch_img
from wibench.typing import TorchImg
from wibench.watermark_data import TorchBitWatermarkData


@dataclass
class InvisibleWatermarkConfig:
    """Configuration shared by the invisible-watermark wrappers.

    Attributes
    ----------
    wm_length : int
        Number of bits in the watermark message. Used as the decoder's
        expected length and as the size of randomly generated payloads.
    block_size : int
        Value forwarded as the ``block`` option to the encoder/decoder for
        every method except ``"rivaGan"``.
    scale : float
        Value placed in the middle slot of the three-element ``scales`` list
        (``[0, scale, 0]``) forwarded to the encoder/decoder for every method
        except ``"rivaGan"``.
    """

    # Length of the bit message.
    wm_length: int = 32
    # Forwarded as ``block=...``.
    block_size: int = 4
    # Forwarded as ``scales=[0, scale, 0]``.
    scale: float = 36
class InvisibleWatermarkWrapper(BaseAlgorithmWrapper):
    """Common wrapper around the invisible-watermark encoder/decoder pair.

    Holds one ``WatermarkEncoder`` and one ``WatermarkDecoder`` (configured
    for bit payloads) from the invisible-watermark framework
    (https://github.com/ShieldMnt/invisible-watermark) and implements
    embedding, extraction and payload generation once for all methods.
    Extraction uses only the marked image, not the original.

    Concrete subclasses select the method by defining the ``algorithm`` class
    attribute (the method string handed to the encoder/decoder) together with
    ``name``.

    Parameters
    ----------
    params : Optional[Dict[str, Any]]
        Keyword arguments for :class:`InvisibleWatermarkConfig`. ``None``
        (the default) or an empty dict selects the config defaults.
    """

    abstract = True

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # ``None`` replaces the former shared ``{}`` default so that no mutable
        # object is reused across calls; callers passing a dict are unaffected.
        config_kwargs = {} if params is None else params
        super().__init__(InvisibleWatermarkConfig(**config_kwargs))
        self.encoder = WatermarkEncoder()
        # NOTE(review): relies on the base initialiser exposing the config as
        # ``self.params`` - confirm against BaseAlgorithmWrapper.
        self.decoder = WatermarkDecoder(
            wm_type="bits", length=self.params.wm_length
        )

    def _upscale_to_min_size(self, image: TorchImg) -> TorchImg:
        """Resize ``image`` so that height and width are each at least 256."""
        # NOTE(review): presumably needed because invisible-watermark rejects
        # images smaller than 256x256 - confirm against the library.
        _, height, width = image.shape
        return resize_torch_img(image, [max(height, 256), max(width, 256)])

    def _method_options(self) -> Dict[str, Any]:
        """Return the extra keyword arguments for the selected method.

        ``"rivaGan"`` is called without extra options; every other method
        receives ``scales=[0, scale, 0]`` and ``block=block_size`` taken from
        the config.
        """
        if self.algorithm == "rivaGan":
            return {}
        config: InvisibleWatermarkConfig = self.params
        return {"scales": [0, config.scale, 0], "block": config.block_size}

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
        """Embed a bit watermark into an image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data : TorchBitWatermarkData
            Bit message; its ``watermark`` tensor has dimension 0 squeezed
            and is converted to a plain list before being handed to the
            encoder.

        Returns
        -------
        TorchImg
            Watermarked image resized back to the input's (H, W).
        """
        _, height, width = image.shape
        bgr_input = torch_img2numpy_bgr(self._upscale_to_min_size(image))

        bits = watermark_data.watermark.squeeze(0).tolist()
        self.encoder.set_watermark("bits", bits)

        bgr_marked = self.encoder.encode(
            bgr_input, self.algorithm, **self._method_options()
        )
        marked = numpy_bgr2torch_img(bgr_marked)
        # Undo any upscaling so the caller gets the original resolution back.
        return resize_torch_img(marked, [height, width])

    def extract(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> Any:
        """Extract the watermark from a marked image.

        Parameters
        ----------
        image : TorchImg
            Marked image tensor in (C, H, W) format
        watermark_data : TorchBitWatermarkData
            Not read by this method; accepted to keep the same signature as
            :meth:`embed`.

        Returns
        -------
        Any
            Whatever ``WatermarkDecoder.decode`` returns for the ``"bits"``
            watermark type.
        """
        bgr_input = torch_img2numpy_bgr(self._upscale_to_min_size(image))
        return self.decoder.decode(
            bgr_input, self.algorithm, **self._method_options()
        )

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate a random watermark payload of ``wm_length`` bits.

        Returns
        -------
        TorchBitWatermarkData
            Result of ``TorchBitWatermarkData.get_random(wm_length)``.
            NOTE(review): presumably a torch.int64 tensor of shape
            (1, wm_length), since :meth:`embed` squeezes dimension 0 -
            confirm against ``TorchBitWatermarkData``.
        """
        return TorchBitWatermarkData.get_random(self.params.wm_length)
class RivaGanWrapper(InvisibleWatermarkWrapper):
    """Image watermarking via RivaGAN: a deep-learning-based encoder/decoder with attention mechanism [`repository <https://github.com/DAI-Lab/RivaGAN>`__].

    Provides an interface for embedding and extracting watermarks using the RivaGAN watermarking algorithm.
    Based on the code from `here <https://github.com/ShieldMnt/invisible-watermark>`__.

    Parameters
    ----------
    params : Optional[Dict[str, Any]]
        RivaGAN algorithm configuration parameters; ``None`` (the default)
        or an empty dict selects the config defaults.
    """

    name = "riva_gan"
    algorithm = "rivaGan"

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # Hand the parent a fresh dict instead of sharing one mutable default
        # object across every call.
        super().__init__({} if params is None else params)
        # Both the encoder and the decoder have their model loaded before use.
        self.encoder.loadModel()
        self.decoder.loadModel()
class DwtDctWrapper(InvisibleWatermarkWrapper):
    """Image watermarking using frequency-domain transforms: DWT + DCT.

    Provides an interface for embedding and extracting watermarks using the frequency-domain transforms: DWT + DCT.
    Based on the code from `here <https://github.com/ShieldMnt/invisible-watermark>`__.

    Parameters
    ----------
    params : Dict[str, Any]
        DWT-DCT algorithm configuration parameters
    """

    # Identifier of this wrapper.
    name = "dwt_dct"
    # Method string handed to the invisible-watermark encoder/decoder.
    algorithm = "dwtDct"
class DwtDctSvdWrapper(InvisibleWatermarkWrapper):
    """Image frequency-domain watermarking with additional SVD processing: DWT + DCT + SVD.

    Provides an interface for embedding and extracting watermarks using the frequency-domain with additional SVD processing: DWT + DCT + SVD.
    Based on the code from the GitHub `repository <https://github.com/ShieldMnt/invisible-watermark>`__.

    Parameters
    ----------
    params : Dict[str, Any]
        DWT-DCT-SVD algorithm configuration parameters
    """

    # Identifier of this wrapper.
    name = "dwt_dct_svd"
    # Method string handed to the invisible-watermark encoder/decoder.
    algorithm = "dwtDctSvd"