Source code for wibench.datasets.mscoco.mscoco

from wibench.typing import ImageObject
from ..base import RangeBaseDataset
from datasets import load_dataset
from typing import Optional, Tuple, Generator
from torchvision.transforms.functional import to_tensor


class MSCOCO(RangeBaseDataset):
    """Image dataset for `MS-COCO <https://cocodataset.org/>`_ (Common Objects in Context).

    The COCO 2017 images are fetched through HuggingFace Datasets from the
    repository named by ``dataset_path``. Both the validation and training
    splits are supported, and downloads can optionally be cached on disk.

    Parameters
    ----------
    split : str
        Which dataset split to load ('train' or 'val')
    sample_range : Optional[Tuple[int, int]]
        Optional (start, end) index range used to restrict the dataset to a subset
    cache_dir : Optional[str]
        Directory in which downloaded dataset files are cached
    """

    # HuggingFace Hub identifier handed to ``load_dataset``.
    dataset_path = "rafaelpadilla/coco2017"

    def __init__(
        self,
        split: str = "val",
        sample_range: Optional[Tuple[int, int]] = None,
        cache_dir: Optional[str] = None,
    ):
        hf_dataset = load_dataset(
            self.dataset_path,
            split=split,
            cache_dir=cache_dir,
        )
        # Store the dataset before initialising the base class, which is
        # given the total number of available samples.
        self.dataset = hf_dataset
        super().__init__(sample_range, len(hf_dataset))

    def __len__(self) -> int:
        """Return ``self.len`` (presumably set by ``RangeBaseDataset`` -- confirm)."""
        return self.len

    def generator(self) -> Generator[ImageObject, None, None]:
        """Yields MSCOCO images.

        Iterates over ``self.len`` consecutive samples, beginning at
        ``self.sample_range.start``.

        Yields
        ------
        ImageObject:
            images from MSCOCO as ImageObject, built from the stringified
            ``image_id`` and the image converted to RGB and then to a tensor
        """
        for offset in range(self.len):
            record = self.dataset[self.sample_range.start + offset]
            image_id = str(record["image_id"])
            pixels = to_tensor(record["image"].convert("RGB"))
            yield ImageObject(image_id, pixels)