Source code for wibench.attacks.stegastamp_inversion.stegastamp_inversion

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from ..base import BaseAttack
from wibench.typing import TorchImg
from wibench.download import requires_download
from wibench.algorithms.stega_stamp.wrapper import (
    URL,
    NAME,
    REQUIRED_FILES,
    DEFAULT_WEIGHT_PATH
)


[docs]@requires_download(URL, NAME, REQUIRED_FILES) class StegastampInversion(BaseAttack): """ Adversarial attack that inverts watermarks encoded by Stegastamp from `here <https://github.com/leiluk1/erasing-the-invisible-beige-box/tree/main>`__. This attack decodes the hidden watermark from a watermarked image, inverts it (1 - watermark), and re-encodes the inverted watermark back into the image. The process is designed to disrupt Stegastamp's watermark decoding while maintaining visual similarity to the original. Parameters ---------- stegastamp_model_path: str path to StegaStamp onnx model device_id: int ID of cuda device to run Stegastamp on **TODO**: * run with GPU tensors, see the following `link <https://github.com/microsoft/onnxruntime-inference-examples/blob/main/python/api/onnxruntime-python-api.py#L66>`__ * convert from onnx to pytorch? """ def __init__(self, stegastamp_model_path: str = DEFAULT_WEIGHT_PATH, device_id: int = 0, ) -> None: super().__init__() session_options = ort.SessionOptions() session_options.intra_op_num_threads = 1 session_options.inter_op_num_threads = 1 session_options.log_severity_level = 3 self.stegastamp_model = ort.InferenceSession( stegastamp_model_path, providers=["CUDAExecutionProvider"], provider_options=[{"device_id": str(device_id)}], sess_options=session_options, ) self.secret_len = 100 def encode(self, image: np.ndarray, secret: np.ndarray) -> np.ndarray: stegastamp, residual, decoded = self.stegastamp_model.run(output_names=None, input_feed={"image": image, "secret": secret}) return stegastamp, residual def decode(self, img: np.ndarray) -> np.ndarray: dummy_secret = np.zeros((img.shape[0], self.secret_len), dtype=np.float32) # not used by the model stegastamp, residual, decoded = self.stegastamp_model.run( output_names=None, input_feed={"image": img, "secret": dummy_secret}, ) return decoded def __call__(self, img: TorchImg) -> TorchImg: img = img.unsqueeze(0) b, c, h, w = img.shape img = F.interpolate(img, size=(400, 400), mode="bicubic", antialias=True) # [b,c,400,400] img_np = img.permute(0, 2, 3, 1).detach().cpu().numpy() # [b,400,400,3] watermarks = self.decode(img_np) inverted_mask = 1 - watermarks attacked_img, _ = self.encode(img_np, inverted_mask) attacked_img = torch.tensor(attacked_img).permute(0, 3, 1, 2).to(img.device) attacked_img = F.interpolate(attacked_img, size=(h, w), mode="bicubic") return attacked_img.squeeze(0)