Source code for wibench.datasets.diffusiondb.diffusiondb

from wibench.typing import ImageObject, PromptObject
from ..base import RangeBaseDataset
import datasets
from typing import Optional, Tuple, Generator, Union
from torchvision.transforms.functional import to_tensor
from packaging import version


class DiffusionDB(RangeBaseDataset):
    """Dataset loader for the `DiffusionDB <https://github.com/poloclub/diffusiondb>`_
    large-scale text-to-image dataset.

    Provides access to generated images and their prompts from DiffusionDB,
    with optional NSFW filtering and prompt-only retrieval modes.

    Parameters
    ----------
    subset : str
        Dataset subset name (e.g., '2m_first_5k')
    sample_range : Optional[Tuple[int, int]]
        Optional (start, end) index range to subset the dataset
    cache_dir : Optional[str]
        Directory to cache downloaded dataset files
    skip_nsfw : bool
        Whether to automatically filter out NSFW images (default True).
        Rows whose ``image_nsfw`` score is >= 1 are treated as NSFW.
    return_prompt : bool
        Whether to return prompts instead of images (default False)
    """

    dataset_path = "poloclub/diffusiondb"

    def __init__(
        self,
        subset: str = "2m_first_5k",
        sample_range: Optional[Tuple[int, int]] = None,
        cache_dir: Optional[str] = None,
        skip_nsfw: bool = True,
        return_prompt: bool = False,
    ):
        dataset_args = {
            "path": self.dataset_path,
            "name": subset,
            "cache_dir": cache_dir,
        }
        # The keyword is only passed on datasets >= 2.16.0; this check assumes
        # older releases do not accept it.
        if version.parse(datasets.__version__) >= version.parse("2.16.0"):
            dataset_args["trust_remote_code"] = True
        self.dataset = datasets.load_dataset(**dataset_args)["train"]
        self.skip_nsfw = skip_nsfw
        if not skip_nsfw:
            dataset_len = self.dataset.num_rows
        else:
            # Count only the rows generator() will not skip (image_nsfw < 1).
            dataset_len = sum(score < 1 for score in self.dataset["image_nsfw"])
        self.dataset_len = dataset_len
        # Presumably RangeBaseDataset derives self.sample_range and self.len
        # from these two values — TODO confirm against the base class.
        super().__init__(sample_range, self.dataset_len)
        self.return_prompt = return_prompt

    def __len__(self):
        return self.len

    def generator(
        self,
    ) -> Generator[Union[ImageObject, PromptObject], None, None]:
        """Yields DiffusionDB images or prompts.

        Iteration starts at raw row index ``self.sample_range.start`` and stops
        after ``self.len`` items have been yielded, or when the end of the
        underlying dataset is reached, whichever comes first. When
        ``self.skip_nsfw`` is set, rows with ``image_nsfw >= 1`` are skipped
        and do not count toward ``self.len``. Each item's id is the string
        form of its raw row index.

        NOTE(review): with ``skip_nsfw`` enabled, ``sample_range.start`` is a
        raw (unfiltered) row index while ``self.len`` counts filtered rows —
        confirm this mix is intended.

        Yields
        ------
        Union[ImageObject, PromptObject]:
            images from DiffusionDB as ImageObject, or prompts as PromptObject
            in case of `self.return_prompt = True`
        """
        source = self.dataset
        # Prompt-only mode never uses pixels. Row access decodes every image
        # column, so read from a view without it when one exists.
        if self.return_prompt and "image" in source.column_names:
            source = source.remove_columns(["image"])

        yielded = 0
        # Bounded by num_rows so running out of rows ends iteration instead of
        # raising IndexError on an out-of-range row.
        for row_idx in range(self.sample_range.start, source.num_rows):
            if yielded >= self.len:
                break
            data = source[row_idx]
            if self.skip_nsfw and data["image_nsfw"] >= 1:
                continue
            yielded += 1
            if self.return_prompt:
                yield PromptObject(str(row_idx), data["prompt"])
            else:
                yield ImageObject(str(row_idx), to_tensor(data["image"]))