from dataclasses import dataclass
from typing import Optional

import numpy as np
from typing_extensions import Any

from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.typing import TorchImg
from wibench.utils import torch_img2numpy_bgr, numpy_bgr2torch_img

from .dwtsvm_marker import DWTSVMMarker
@dataclass
class WatermarkData:
    """Watermark data for the DWT_SVM watermarking algorithm."""

    # Watermark bits to embed into the image (a 0/1 integer array when
    # produced by DWTSVMWrapper.watermark_data_gen).
    watermark: np.ndarray
    # Key array passed to DWTSVMMarker.embed together with the watermark.
    key: np.ndarray
class DWTSVMWrapper(BaseAlgorithmWrapper):
    """
    Custom implementation of image watermarking algorithm described in the `paper <https://doi.org/10.1007/s00521-018-3647-2>`__.

    Parameters
    ----------
    params : dict[str, Any], optional
        May contain a value for the ``"threshold"`` parameter of the algorithm
        (default 56). The higher the threshold, the more robust the watermark
        is to attacks, but the less imperceptible it is. If omitted, an empty
        dict is used.
    """

    name = "dwt_svm"

    def __init__(self, params: Optional[dict[str, Any]] = None) -> None:
        # Create a fresh dict per instance. A `{}` default would be one shared
        # object, reused by every wrapper and forwarded to the base class.
        if params is None:
            params = {}
        super().__init__(params)
        threshold = params.get("threshold", 56)
        self.marker: DWTSVMMarker = DWTSVMMarker(threshold=threshold)

    def embed(self, image: TorchImg, watermark_data: WatermarkData) -> TorchImg:
        """Embed ``watermark_data`` into ``image``.

        The torch image is converted to a BGR numpy array, marked by
        ``DWTSVMMarker.embed`` using the watermark and key, and converted
        back to a torch image.

        Parameters
        ----------
        image : TorchImg
            Image to watermark.
        watermark_data : WatermarkData
            Watermark bits and key to embed.

        Returns
        -------
        TorchImg
            The watermarked image.
        """
        watermark = watermark_data.watermark
        key = watermark_data.key
        np_res = self.marker.embed(torch_img2numpy_bgr(image), watermark, key)
        return numpy_bgr2torch_img(np_res)

    def watermark_data_gen(self) -> WatermarkData:
        """Generate random watermark data.

        Returns
        -------
        WatermarkData
            512 random 0/1 values for the watermark and 512 for the key,
            drawn from NumPy's global random state.
        """
        # NOTE(review): 512 is presumably the payload length DWTSVMMarker
        # expects; confirm against the marker before changing it.
        wm = np.random.randint(0, 2, 512)
        key = np.random.randint(0, 2, 512)
        return WatermarkData(wm, key)