Source code for wibench.metrics.clip_iqa.clip_iqa

from typing import Any, Tuple, Union
from torchmetrics.multimodal import CLIPImageQualityAssessment
from wibench.typing import TorchImg
from wibench.metrics.base import PostEmbedMetric
import torch


[docs]class CLIP_IQA(PostEmbedMetric): """Exploring CLIP for Assessing the Look and Feel of Images `[paper] <https://arxiv.org/abs/2207.12396>`__. The implementation is taken from the `repository <https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_iqa.html>`__. Initialization Parameters ------------------------- prompts : Tuple[Union[str, Tuple[str]]] List of text prompts for assessing the visual quality of an image (default ("quality",)) device : str Device to run the model on ('cuda', 'cpu') Call Parameters --------------- img1 : TorchImg Input image tensor in (C, H, W) format img2 : TorchImg Input image tensor in (C, H, W) format watermark_data : Any Not used, can be anything Notes ----- - The watermark_data field is required for the pipeline to work correctly """ def __init__(self, prompts: Tuple[Union[str, Tuple[str]]] = ("quality",), device: str = "cuda" if torch.cuda.is_available() else "cpu") -> None: self.device = device self.metric = CLIPImageQualityAssessment(prompts=prompts).to(self.device) def __call__( self, img1: TorchImg, img2: TorchImg, watermark_data: Any, ) -> float: return float(self.metric(img2.unsqueeze(0).to(self.device)))