Source code for wibench.metrics.base

from typing import Any
from functools import lru_cache
from abc import abstractmethod
import numpy as np
import torch
from wibench.registry import RegistryMeta
from wibench.typing import TorchImg
from wibench.utils import resize_torch_img
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.stats import binom


class BaseMetric(metaclass=RegistryMeta):
    """Root of the metric hierarchy used by the watermarking pipeline.

    The class is built through ``RegistryMeta`` and tags itself with
    ``type = "metric"``. Concrete metrics are expected to override
    ``__call__``; the version defined here only raises
    ``NotImplementedError``.
    """

    type = "metric"

    @abstractmethod
    def __call__(self, *args, **kwds):
        """Compute the metric value; must be overridden by subclasses."""
        raise NotImplementedError


class PostEmbedMetric(BaseMetric):
    """Base class for metrics evaluated right after watermark embedding.

    Metrics of this family compare the original object with the watermarked
    one and are meant to quantify:

    - quality degradation,
    - watermark perceptibility,
    - embedding distortion.

    They may also be used on PostAttackMetricsStage, where the comparison is
    made between the marked and the attacked objects.
    """

    abstract = True

    def __call__(self, *args, **kwargs) -> Any:
        """Compute the metric; concrete subclasses provide the logic."""
        raise NotImplementedError
    

class PostPipelineMetric(BaseMetric):
    """Base class for metrics exposing ``update`` / ``reset`` / ``__call__``.

    Every method defined here raises ``NotImplementedError``; subclasses are
    expected to supply the behaviour.
    NOTE(review): presumably ``update`` accumulates state over pairs of
    objects and ``__call__`` reports the aggregate once the pipeline has
    finished -- confirm against the stage that drives these metrics.
    """

    abstract = True

    def update(self, object1: Any, object2: Any) -> None:
        """Consume one pair of objects; to be implemented by subclasses."""
        raise NotImplementedError

    def reset(self) -> None:
        """Reset the metric; to be implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, *args, **kwds) -> Any:
        """Return the metric value; to be implemented by subclasses."""
        raise NotImplementedError


class PostExtractMetric(BaseMetric):
    """Base class for metrics evaluated after watermark extraction."""

    abstract = True

    def __call__(self, *args, **kwargs) -> Any:
        """Compute the metric; concrete subclasses provide the logic."""
        raise NotImplementedError


class PSNR(PostEmbedMetric):
    """Peak Signal-to-Noise Ratio between original and processed images.

    Pixel-level difference expressed in decibels; a higher value means the
    processed image is closer to the original.

    Notes
    -----
    - Range: Typically 20-50 dB for images
    - Infinite if images are identical
    """

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        *args,
        **kwargs,
    ) -> float:
        """Return the PSNR of ``img2`` measured against ``img1``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Identical tensors would give zero error; report infinity directly.
        identical = torch.equal(img1, img2)
        if identical:
            return float("inf")

        # Bring the second image to the size given by img1.shape[1:].
        target_size = list(img1.shape)[1:]
        resized = resize_torch_img(img2, target_size)

        score = psnr(img1.numpy(), resized.numpy(), data_range=1)
        return float(score)
class SSIM(PostEmbedMetric):
    """Structural Similarity Index Measure between images.

    Perceptual metric assessing structural similarity (range 0-1).

    Notes
    -----
    - value 1 indicates perfect similarity
    """

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        watermark_data: Any = None,
        *args,
        **kwargs,
    ) -> float:
        """Return the SSIM of ``img2`` measured against ``img1``.

        Parameters
        ----------
        img1 : TorchImg
            Reference image.
        img2 : TorchImg
            Image to compare; resized to ``img1.shape[1:]`` first.
        watermark_data : Any, optional
            Unused. Kept (now optional) so existing positional callers keep
            working while the metric can also be called with images only,
            like ``PSNR``.
        *args, **kwargs
            Accepted and ignored, mirroring ``PSNR``.

        Returns
        -------
        float
            SSIM computed with ``data_range=1``.
        """
        img2 = resize_torch_img(img2, list(img1.shape)[1:])
        reference = img1.numpy()
        candidate = img2.numpy()

        # A 2-D tensor has no channel dimension, so channel_axis is omitted.
        if len(img1.shape) == 2:
            return float(ssim(reference, candidate, data_range=1))

        # Otherwise axis 0 is passed to skimage as the channel axis.
        return float(ssim(reference, candidate, data_range=1, channel_axis=0))
class EmbedWatermark(PostEmbedMetric):
    """Records the embedded watermark payload for reference.

    The watermark is rendered as a string of digits so that it can be stored
    in the metrics output.
    """

    name = "EmbWm"

    def __call__(self, img1: TorchImg, img2: TorchImg, watermark_data: Any):
        """Return ``watermark_data.watermark`` as a concatenated digit string.

        The images are not used.
        """
        payload = np.array(watermark_data.watermark)
        # Cast to uint8 and flatten so each element prints as a plain integer.
        symbols = payload.astype(np.uint8).flatten().tolist()
        return ''.join(map(str, symbols))


class Result(PostExtractMetric):
    """Pass the extraction result straight through to the metrics output.

    The extraction result must be convertible with ``float()``.
    """

    name = "result"

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        watermark_data: Any,
        extraction_result: Any,
    ) -> float:
        """Return ``extraction_result`` converted to ``float``."""
        value = float(extraction_result)
        return value
class BER(PostExtractMetric):
    """Bit Error Rate between original and extracted watermarks.

    Measures fraction of incorrectly recovered bits.
    """

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        watermark_data: Any,
        extraction_result: Any,
    ) -> float:
        """Return the fraction of positions where the two watermarks differ.

        Parameters
        ----------
        img1, img2 : TorchImg
            Unused by this metric.
        watermark_data : Any
            Object whose ``watermark`` attribute holds the embedded bits.
        extraction_result : Any
            Extracted bits, convertible with ``np.array``.

        Returns
        -------
        float
            Mean of the element-wise "not equal" comparison.
        """
        # Flatten both operands before comparing, as TPRxFPR and PValue do.
        # Without this an (N, 1) watermark compared with an (N,) result would
        # broadcast to an (N, N) matrix and yield a meaningless rate.
        expected = np.array(watermark_data.watermark).flatten()
        extracted = np.array(extraction_result).flatten()
        return float((expected != extracted).mean())
class TPRxFPR(PostExtractMetric):
    """True Positive Rate at fixed False Positive Rate threshold.

    Robustness metric for watermark detection systems.

    Parameters
    ----------
    fpr_rate : float
        Target false positive rate (e.g., 0.01 for 1% FPR)

    Notes
    -----
    - Uses binomial distribution for threshold calculation
    - Caches thresholds for efficiency
    - Binary classification metric
    """

    name = "TPR@xFPR"

    def __init__(self, fpr_rate: float):
        self.fpr_rate = fpr_rate

    @staticmethod
    @lru_cache(maxsize=128)
    def _threshold_for(fpr_rate: float, num_bits: int) -> int:
        """Return the smallest match count whose false positive rate is below ``fpr_rate``.

        The cache lives on a static helper keyed by ``(fpr_rate, num_bits)``
        rather than on the bound method, so it neither keeps metric instances
        alive nor returns a stale value if ``fpr_rate`` is changed later.

        Raises
        ------
        ValueError
            If no threshold in ``1..num_bits`` achieves the requested rate.
        """
        if num_bits >= 1:
            candidates = np.arange(1, num_bits + 1)
            # P(X >= t) for X ~ Binomial(num_bits, 0.5), for all t at once.
            # sf is used instead of 1 - cdf because it stays accurate in the
            # far tail.
            tail = binom.sf(candidates - 1, num_bits, 0.5)
            accepted = np.flatnonzero(tail < fpr_rate)
            if accepted.size:
                return int(candidates[accepted[0]])
        raise ValueError(f"Cannot achieve FPR rate {fpr_rate} with {num_bits} bits")

    def bits_threshold(self, num_bits: int) -> int:
        """Return the minimum number of matching bits counted as a detection.

        Parameters
        ----------
        num_bits : int
            Length of the watermark message.

        Raises
        ------
        ValueError
            If ``self.fpr_rate`` cannot be achieved with ``num_bits`` bits.
        """
        return self._threshold_for(self.fpr_rate, num_bits)

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        watermark_data: Any,
        extraction_result: Any,
    ) -> int:
        """Return 1 if the extraction counts as a detection, otherwise 0.

        Parameters
        ----------
        img1, img2 : TorchImg
            Unused by this metric.
        watermark_data : Any
            Object whose ``watermark`` attribute holds the embedded bits.
            Only read for multi-bit results.
        extraction_result : Any
            Either a ``float`` p-value (zero-bit method) or the extracted
            bits.
        """
        if isinstance(extraction_result, float):
            # zero-bit method returns p-value, fpr rate is considered as decision threshold
            return int(self.fpr_rate > extraction_result)

        expected = np.array(watermark_data.watermark).flatten()
        extracted = np.array(extraction_result).flatten()
        # Use the flattened size so the message length always agrees with the
        # element-wise match count, including for nested sequences.
        threshold = self.bits_threshold(int(expected.size))
        return int((expected == extracted).sum() >= threshold)
class PValue(PostExtractMetric):
    """P-value of extraction result.

    P-value denotes probability to observe the same result as in case of
    extraction from not watermarked object.

    Notes
    -----
    - For zero-bit methods we assume that extraction function returns p-value
      itself.
    - For multi-bit methods p-value is calculated as probability to get the
      same number of mismatched bits or less than observed in case of a random
      message with uniform i.i.d. bit values.
    - Lower p-value stands for more confident "content is watermarked"
      decision.
    """

    name = "p-value"

    def __call__(
        self,
        img1: TorchImg,
        img2: TorchImg,
        watermark_data: Any,
        extraction_result: Any,
    ) -> float:
        """Return the p-value of ``extraction_result``.

        Parameters
        ----------
        img1, img2 : TorchImg
            Unused by this metric.
        watermark_data : Any
            Object whose ``watermark`` attribute holds the embedded bits.
            Only read for multi-bit results.
        extraction_result : Any
            Either a ``float`` p-value (zero-bit method), which is returned
            unchanged, or the extracted bits.

        Returns
        -------
        float
            ``P(X >= matched_bits)`` for ``X ~ Binomial(num_bits, 0.5)``.
        """
        # Handle the zero-bit case first so the watermark attribute is not
        # touched when it is not needed.
        if isinstance(extraction_result, float):
            # zero-bit method returns p-value
            return extraction_result

        expected = np.array(watermark_data.watermark).flatten()
        extracted = np.array(extraction_result).flatten()
        matched_bits = int((expected == extracted).sum())

        # The survival function avoids the cancellation in 1 - cdf, which
        # collapses to 0.0 (or rounding noise) once the true p-value drops
        # below roughly 1e-16 -- exactly the region of strong detections.
        return float(binom.sf(matched_bits - 1, int(expected.size), 0.5))
class ExtractedWatermark(PostExtractMetric):
    """Records the extracted watermark payload for analysis.

    The extraction result is rendered as a string of digits so that it can
    be stored in the metrics output.
    """

    name = "ExtWm"

    def __call__(self, img1: TorchImg, img2: TorchImg, watermark_data: Any, extraction_result):
        """Return ``extraction_result`` as a concatenated digit string.

        The images and ``watermark_data`` are not used.
        """
        payload = np.array(extraction_result)
        # Cast to uint8 and flatten so each element prints as a plain integer.
        symbols = payload.astype(np.uint8).flatten().tolist()
        return ''.join(map(str, symbols))