Source code for wibench.attacks.bm3d.bm3d

import torch
from wibench.typing import TorchImg    
from ..base import BaseAttack
from bm3d import bm3d_rgb


[docs]class BM3DDenoising(BaseAttack): """ For more information, please refer to the following `page <https://webpages.tuni.fi/foi/GCF-BM3D/>`__. """ name = "bm3d" def __init__(self) -> None: super().__init__() def __call__(self, img: TorchImg) -> TorchImg: img = img.unsqueeze(0) distorted_image = [] for im in img.unbind(dim=0): # for each image in the batch img_denoised = bm3d_rgb(im.permute(1, 2, 0).cpu().numpy(), 0.1) # use standard deviation as 0.1, 0.05 also works img_denoised = torch.tensor(img_denoised).permute(2, 0, 1) distorted_image.append(img_denoised) distorted_image = torch.stack(distorted_image).clip(0, 1).to(device=img.device, dtype=img.dtype) return distorted_image.squeeze(0)