Source code for wibench.attacks.liif.liif_attack

import torch
from .models import make
from .utils import make_coord
from wibench.typing import TorchImg
from wibench.download import requires_download
from ..base import BaseAttack


URL = "https://nextcloud.ispras.ru/index.php/s/n2jSWZi4L8mAmEL"
NAME = "liif"
REQUIRED_FILES = ["rdn-liif.pth"]

DEFAULT_MODEL_PATH = "./model_files/liif/rdn-liif.pth"


def batched_predict(model, inp, coord, cell, bsize):
    with torch.no_grad():
        model.gen_feat(inp)
        n = coord.shape[1]
        ql = 0
        preds = []
        while ql < n:
            qr = min(ql + bsize, n)
            pred = model.query_rgb(coord[:, ql:qr, :], cell[:, ql:qr, :])
            preds.append(pred)
            ql = qr
        pred = torch.cat(preds, dim=1)
    return pred


[docs]@requires_download(URL, NAME, REQUIRED_FILES) class LIIFAttack(BaseAttack): """ Attack using Local Implicit Image Function (`LIIF <https://github.com/yinboc/liif>`__) for image super-resolution. Reconstructs images through an implicit neural representation that learns continuous image functions. The attack queries the LIIF model at specific coordinates to generate a modified version of the input image, effectively applying learned upsampling/denoising. """ def __init__( self, device: str = "cuda" if torch.cuda.is_available() else "cpu", model_path: str = DEFAULT_MODEL_PATH, ) -> None: super().__init__() self.device = device self.model_path = model_path self.model = make( torch.load(model_path, map_location=torch.device(self.device))[ "model" ], load_sd=True, ).to(self.device) def __call__(self, img: TorchImg) -> TorchImg: if len(img.shape) < 4: img = img.unsqueeze(0) b, c, h, w = img.shape coord = make_coord((h, w)).to(self.device) cell = torch.ones_like(coord) cell[:, 0] *= 2 / h cell[:, 1] *= 2 / w pred = batched_predict( self.model, ((img - 0.5) / 0.5).to(self.device), coord.unsqueeze(0).repeat(b, 1, 1), cell.unsqueeze(0).repeat(b, 1, 1), bsize=30000, )[0] pred = ( (pred * 0.5 + 0.5) .clamp(0, 1) .view(b, h, w, 3) .permute(0, 3, 1, 2) .to(self.device) ) return pred.squeeze(0).cpu()