Source code for wibench.datasets.base

from pathlib import Path
from itertools import chain
from PIL import Image
from torchvision.transforms import ToTensor
from typing_extensions import (
    Generator,
    Tuple,
    Union,
    List,
    Optional,
    Any
)
from wibench.typing import (
    Range,
    ImageObject,
    PromptObject
)
from wibench.registry import RegistryMeta


class BaseDataset(metaclass=RegistryMeta):
    """Abstract base class for all watermarking dataset implementations.

    Provides interface for dataset loading with automatic registry support.
    """
    type = "dataset"

    def __init__(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        """Return number of objects in this dataset instance.
        
        Returns
        -------
        int
            Length of currently active dataset range
        """
        raise NotImplementedError

    def generator(self) -> Generator[Any, None, None]:
        """Yield objects to process. 
        """
        raise NotImplementedError


class RangeBaseDataset(BaseDataset):
    """
    Class for datasets that support range subsets of objects.

    Parameters
    ----------
    sample_range : Optional[Tuple[int, int]]
        Optional (start, end) index range to subset the dataset (including both borders)
    len : int
        Total number of images in the full dataset
    """
    abstract = True

    def __init__(self, sample_range: Optional[Tuple[int, int]], dataset_len: int) -> None:
        if sample_range is not None:
            self.sample_range = Range(*sample_range)
            range_len = (self.sample_range.stop - self.sample_range.start) + 1
            if (self.sample_range.stop < 0) or (self.sample_range.start < 0):
                raise ValueError(
                    f"Range start or stop must be >= 0, but current values={self.sample_range}"
                )
            elif ((self.sample_range.start >= dataset_len) or (self.sample_range.stop >= dataset_len)):
                raise ValueError(
                    f"Data range {self.sample_range.start} - {self.sample_range.stop} exceeds dataset size 0 - {dataset_len - 1}"
                )
            elif (self.sample_range.start > self.sample_range.stop):
                raise ValueError(
                    f"Range start value must be <= than range stop value, but current values={self.sample_range}"
                )
            else:
                self.len = range_len
        else:
            self.sample_range = Range(*[0, dataset_len - 1])
            self.len = dataset_len


[docs]class ImageFolderDataset(RangeBaseDataset): """Concrete dataset implementation loading images from a directory. Supports common image formats with optional preloading. Parameters ---------- path : Union[Path, str] Directory path containing images preload : bool Whether to load all images into memory upfront img_ext : List[str] Image file extensions to include (default: ['png', 'jpg']) sample_range : Optional[Tuple[int, int]] Optional (start, end) index range to subset the dataset (including both borders) """ def __init__( self, path: Union[Path, str], preload: bool = False, img_ext: List[str] = ["png", "jpg"], sample_range: Optional[Tuple[int, int]] = None ) -> None: self.path = Path(path) path_gen = sorted( chain.from_iterable(self.path.glob(f"*.{ext}") for ext in img_ext) ) self.path_list = list(path_gen) self.transform = ToTensor() assert len(self.path_list) != 0, "Empty dataset, check dataset path" dataset_len = len(self.path_list) super().__init__(sample_range, dataset_len) self.images = [] if preload: self.images = [ self.transform(Image.open(img_path).convert("RGB")) for img_path in self.path_list[self.sample_range[0]: self.sample_range[1] + 1] ] super().__init__(None, len(self)) def __len__(self) -> int: """Return number of images in folder. Returns ------- int Count of discovered image files """ return self.len def generator(self) -> Generator[ImageObject, None, None]: """Yields images from directory. Yields ------ ImageObject: image name as image_id and image tensor """ if len(self.images) > 0: for path, img in zip(self.path_list[self.sample_range[0]: self.sample_range[1] + 1], self.images): yield ImageObject(path.name, img) else: for path in self.path_list[self.sample_range[0]: self.sample_range[1] + 1]: img = self.transform(Image.open(path).convert("RGB")) yield ImageObject(path.name, img)
[docs]class PromptFolderDataset(RangeBaseDataset): """Concrete dataset implementation loading prompts from a directory. Directory should contain a number of ".txt" or ".csv" files with prompts, in one file prompts are separated by `separator`. All prompts are preloaded. Parameters ---------- path : Union[Path, str] Directory path containing images prompt_ext : List[str] File extensions to include (default: ['txt', 'csv']) sample_range : Optional[Tuple[int, int]] Optional (start, end) index range to subset the dataset (including both borders). Default: None (full dataset) separator : str Separator for prompts in one file, default "\n" """ def __init__( self, path: Union[Path, str], prompt_ext: List[str] = ["txt", "csv"], sample_range: Optional[Tuple[int, int]] = None, separator: str = "\n", ) -> None: self.path = Path(path) path_gen = sorted( chain.from_iterable(self.path.glob(f"*.{ext}") for ext in prompt_ext) ) self.path_list = list(path_gen) self.prompts = [] for path in self.path_list: with open(path, "r") as f: self.prompts += f.read().split(separator) assert len(self.prompts) != 0, "Empty dataset, check dataset path" dataset_len = len(self.prompts) super().__init__(sample_range, dataset_len) def __len__(self) -> int: """Return number of prompts in folder. Returns ------- int Count of discovered prompts (one file may contain several prompts) """ return self.len def generator(self) -> Generator[PromptObject, None, None]: """Yields prompts from directory. Yields ------ PromptObject: prompt number as prompt_id and prompt as string """ for prompt_id in range(self.sample_range[0], self.sample_range[1] + 1): yield PromptObject(str(prompt_id), self.prompts[prompt_id])