Source code for wibench.algorithms.arwgan.wrapper

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from typing_extensions import Any, Dict, Optional

from wibench.module_importer import ModuleImporter
from wibench.watermark_data import TorchBitWatermarkData
from wibench.algorithms.base import BaseAlgorithmWrapper
from wibench.typing import TorchImg
from wibench.utils import (
    resize_torch_img,
    overlay_difference,
    normalize_image,
    denormalize_image
)
from wibench.download import requires_download


# Download descriptor for the ARWGAN model files; these three values are
# passed to ``requires_download`` on the wrapper class below.
URL = "https://nextcloud.ispras.ru/index.php/s/4THrnJDdjF6xGMc"
NAME = "arwgan"
REQUIRED_FILES = [
    "checkpoints",
    "options-and-config.pickle",
]

# Fallback locations used by the wrapper when the corresponding entries are
# absent from its ``params`` dict.
DEFAULT_MODULE_PATH = "./submodules/ARWGAN"  # checkout of the upstream ARWGAN code
DEFAULT_OPTIONS_PATH = "./model_files/arwgan/options-and-config.pickle"
DEFAULT_CHECKPOINT_PATH = "./model_files/arwgan/checkpoints/ARWGAN.pyt"


@dataclass
class ARWGANParams:
    """
    Configuration parameters for the `ARWGAN <https://ieeexplore.ieee.org/document/10155247>`__ watermarking algorithm.

    Attributes
    ----------
    H : int
        Height of the input image (in pixels). Determines the vertical size of image tensors
    W : int
        Width of the input image (in pixels). Determines the horizontal size of image tensors
    wm_length : int
        Length of the binary watermark message to embed (in bits)
    encoder_blocks : int
        Number of convolutional blocks in the encoder network
    encoder_channels : int
        Number of filters (channels) in each encoder block
    decoder_blocks : int
        Number of convolutional blocks in the decoder network
    decoder_channels : int
        Number of filters in each decoder block
    use_discriminator : bool
        If True, enables the use of an adversarial discriminator
    use_vgg : bool
        If True, adds a perceptual loss term computed on VGG features
    discriminator_blocks : int
        Number of convolutional blocks in the discriminator network
    discriminator_channels : int
        Number of filters in each discriminator block
    decoder_loss : float
        Weight of the decoder loss term in the total loss function.
        Controls the importance of accurate message recovery
    encoder_loss : float
        Weight of the encoder loss term in the total loss function.
        Typically regularizes visual similarity between original and encoded images
    adversarial_loss : float
        Weight of the adversarial loss term in the total loss.
        Higher values push the encoder to generate more realistic images when a discriminator is used
    enable_fp16 : bool
        If True, enables mixed precision (fp16) training/inference (default False)
    """

    # Field order is part of the positional constructor interface; keep it.
    H: int
    W: int
    wm_length: int
    encoder_blocks: int
    encoder_channels: int
    decoder_blocks: int
    decoder_channels: int
    use_discriminator: bool
    use_vgg: bool
    discriminator_blocks: int
    discriminator_channels: int
    decoder_loss: float
    encoder_loss: float
    adversarial_loss: float
    enable_fp16: bool = False
@requires_download(URL, NAME, REQUIRED_FILES)
class ARWGANWrapper(BaseAlgorithmWrapper):
    """
    `ARWGAN <https://ieeexplore.ieee.org/document/10155247>`__: Attention-Guided Robust Image Watermarking Model Based on GAN --- Image Watermarking Algorithm.

    Provides an interface for embedding and extracting watermarks using the ARWGAN watermarking algorithm.
    Based on the code from `here <https://github.com/river-huang/ARWGAN>`__.

    Parameters
    ----------
    params : Dict[str, Any], optional
        ARWGAN algorithm configuration parameters (default None, treated as an empty dict).
        Keys read here: ``options_file_path``, ``checkpoint_file_path``, ``device``, plus
        the module-path entry consumed by ``ModuleImporter.pop_resolve_module_path``.
        The caller's dict is not modified.
    """

    name = NAME

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        # Work on a private copy: pop_resolve_module_path pops its entry from
        # the dict it is given. Previously that mutated the caller's config,
        # so a second wrapper built from the same dict lost the custom module
        # path. This also avoids a shared mutable default argument.
        params = {} if params is None else dict(params)
        module_path = ModuleImporter.pop_resolve_module_path(params, DEFAULT_MODULE_PATH)

        options_file_path = Path(params.get("options_file_path", DEFAULT_OPTIONS_PATH)).resolve()
        checkpoint_file_path = Path(params.get("checkpoint_file_path", DEFAULT_CHECKPOINT_PATH)).resolve()

        # NOTE(review): the original indentation of this method was lost. The
        # whole loading sequence (same statement order as before) is kept inside
        # the importer context so that ARWGAN modules referenced while
        # unpickling the options file remain importable. Confirm upstream.
        with ModuleImporter("ARWGAN", module_path):
            from ARWGAN.utils import load_options
            from ARWGAN.model.encoder_decoder import EncoderDecoder
            from ARWGAN.noise_layers.noiser import Noiser

            # Only the model configuration is used; the training options and
            # noise config are not needed for inference.
            _train_options, config, _noise_config = load_options(options_file_path)
            self.device = params.get("device", "cuda" if torch.cuda.is_available() else "cpu")
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source (here: the archive at URL).
            checkpoint = torch.load(checkpoint_file_path, map_location=self.device)

            arwgan_params = ARWGANParams(
                H=config.H,
                W=config.W,
                wm_length=config.message_length,
                encoder_blocks=config.encoder_blocks,
                encoder_channels=config.encoder_channels,
                decoder_blocks=config.decoder_blocks,
                decoder_channels=config.decoder_channels,
                discriminator_blocks=config.discriminator_blocks,
                discriminator_channels=config.discriminator_channels,
                decoder_loss=config.decoder_loss,
                encoder_loss=config.encoder_loss,
                adversarial_loss=config.adversarial_loss,
                enable_fp16=config.enable_fp16,
                use_discriminator=config.use_discriminator,
                use_vgg=config.use_vgg,
            )
            super().__init__(arwgan_params)
            self.params: ARWGANParams

            # Noiser built with an empty layer list, presumably so no
            # distortions are applied inside the encoder-decoder.
            noiser = Noiser([], self.device)
            self.encoder_decoder = EncoderDecoder(config, noiser)
            self.encoder_decoder.load_state_dict(checkpoint['enc-dec-model'])
            self.encoder_decoder = self.encoder_decoder.to(self.device)
            self.encoder_decoder.eval()

    def _prepare_input(self, image: TorchImg):
        """Resize ``image`` to the model's (H, W) and normalize it.

        Returns
        -------
        tuple
            ``(resized_image, normalized_resized_image)``.
        """
        resized_image = resize_torch_img(image, (self.params.H, self.params.W))
        return resized_image, normalize_image(resized_image)

    def embed(self, image: TorchImg, watermark_data: TorchBitWatermarkData) -> TorchImg:
        """Embed watermark into input image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64

        Returns
        -------
        TorchImg
            Watermarked image returned by ``overlay_difference``, which combines the
            original image, its resized copy and the encoder output (presumably
            transferring the embedded change back to the input resolution).
        """
        resized_image, model_input = self._prepare_input(image)
        with torch.no_grad():
            encoded_tensor = self.encoder_decoder.encoder(
                model_input.to(self.device),
                watermark_data.watermark.to(self.device),
            )
        denormalized_marked_image = denormalize_image(encoded_tensor.cpu())
        return overlay_difference(image, resized_image, denormalized_marked_image)

    def extract(self, image: TorchImg, watermark_data: Any) -> np.ndarray:
        """Extract watermark from marked image.

        Parameters
        ----------
        image : TorchImg
            Input image tensor in (C, H, W) format
        watermark_data: TorchBitWatermarkData
            Torch bit message with data type torch.int64 (unused by this method)

        Returns
        -------
        np.ndarray
            Integer array of decoded bits: 1 where the decoder output exceeds 0.5, else 0.
        """
        _, model_input = self._prepare_input(image)
        with torch.no_grad():
            decoded = self.encoder_decoder.decoder(model_input.to(self.device))
        # Hard-threshold decoder outputs at 0.5 into {0, 1} bits.
        return (decoded.cpu().numpy() > 0.5).astype(int)

    def watermark_data_gen(self) -> TorchBitWatermarkData:
        """Generate watermark payload data for ARWGAN watermarking algorithm.

        Returns
        -------
        TorchBitWatermarkData
            Random bit message of length ``wm_length`` (torch.int64), produced by
            ``TorchBitWatermarkData.get_random``.

        Notes
        -----
        NOTE(review): earlier docs stated the shape as ``(0, message_length)``,
        which cannot hold any bits; presumably ``(1, wm_length)`` -- confirm.
        """
        return TorchBitWatermarkData.get_random(self.params.wm_length)