From a894f93305467d74ac74596d7763cc09cf99a5b4 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 20 Jun 2024 18:22:38 +0900 Subject: [PATCH] split to modules --- MSCOCO.py | 787 ++----------------------------------------- annotation.py | 9 + base_example.py | 22 ++ caption.py | 89 +++++ category.py | 18 + const.py | 99 ++++++ image.py | 28 ++ info.py | 17 + instances.py | 226 +++++++++++++ license.py | 18 + person_keypoint.py | 194 +++++++++++ processor.py | 96 ++++++ rle.py | 0 tests/MSCOCO_test.py | 7 +- typehint.py | 30 ++ 15 files changed, 872 insertions(+), 768 deletions(-) create mode 100644 annotation.py create mode 100644 base_example.py create mode 100644 caption.py create mode 100644 category.py create mode 100644 const.py create mode 100644 image.py create mode 100644 info.py create mode 100644 instances.py create mode 100644 license.py create mode 100644 person_keypoint.py create mode 100644 processor.py create mode 100644 rle.py create mode 100644 typehint.py diff --git a/MSCOCO.py b/MSCOCO.py index b70f71d..a793567 100644 --- a/MSCOCO.py +++ b/MSCOCO.py @@ -1,56 +1,47 @@ -import abc -import json import logging import os -from collections import defaultdict -from dataclasses import asdict, dataclass from typing import ( - Any, - Dict, - Final, - Iterator, List, - Literal, Optional, Sequence, Tuple, - TypedDict, Union, get_args, ) import datasets as ds -import numpy as np from datasets.data_files import DataFilesDict -from PIL import Image -from PIL.Image import Image as PilImage -from pycocotools import mask as cocomask -from tqdm.auto import tqdm -logger = logging.getLogger(__name__) - -JsonDict = Dict[str, Any] -ImageId = int -AnnotationId = int -LicenseId = int -CategoryId = int -Bbox = Tuple[float, float, float, float] - -MscocoSplits = Literal["train", "val", "test"] +from .caption import CaptionsProcessor +from .instances import InstancesProcessor +from .person_keypoint import PersonKeypointsProcessor +from .processor import 
MsCocoProcessor +from .typehint import MscocoSplits -KEYPOINT_STATE: Final[List[str]] = ["unknown", "invisible", "visible"] +logger = logging.getLogger(__name__) -_CITATION = """ +_CITATION = """\ +@inproceedings{lin2014microsoft, + title={Microsoft coco: Common objects in context}, + author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence}, + booktitle={Computer Vision--ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part V 13}, + pages={740--755}, + year={2014}, + organization={Springer} +} """ -_DESCRIPTION = """ +_DESCRIPTION = """\ +COCO is a large-scale object detection, segmentation, and captioning dataset. """ -_HOMEPAGE = """ +_HOMEPAGE = """\ +https://cocodataset.org/#home """ -_LICENSE = """ +_LICENSE = """\ +The annotations in this dataset along with this website belong to the COCO Consortium and are licensed under a Creative Commons Attribution 4.0 License. 
""" _URLS = { @@ -90,742 +81,6 @@ }, } -CATEGORIES: Final[List[str]] = [ - "person", - "bicycle", - "car", - "motorcycle", - "airplane", - "bus", - "train", - "truck", - "boat", - "traffic light", - "fire hydrant", - "stop sign", - "parking meter", - "bench", - "bird", - "cat", - "dog", - "horse", - "sheep", - "cow", - "elephant", - "bear", - "zebra", - "giraffe", - "backpack", - "umbrella", - "handbag", - "tie", - "suitcase", - "frisbee", - "skis", - "snowboard", - "sports ball", - "kite", - "baseball bat", - "baseball glove", - "skateboard", - "surfboard", - "tennis racket", - "bottle", - "wine glass", - "cup", - "fork", - "knife", - "spoon", - "bowl", - "banana", - "apple", - "sandwich", - "orange", - "broccoli", - "carrot", - "hot dog", - "pizza", - "donut", - "cake", - "chair", - "couch", - "potted plant", - "bed", - "dining table", - "toilet", - "tv", - "laptop", - "mouse", - "remote", - "keyboard", - "cell phone", - "microwave", - "oven", - "toaster", - "sink", - "refrigerator", - "book", - "clock", - "vase", - "scissors", - "teddy bear", - "hair drier", - "toothbrush", -] - -SUPER_CATEGORIES: Final[List[str]] = [ - "person", - "vehicle", - "outdoor", - "animal", - "accessory", - "sports", - "kitchen", - "food", - "furniture", - "electronic", - "appliance", - "indoor", -] - - -@dataclass -class AnnotationInfo(object): - description: str - url: str - version: str - year: str - contributor: str - date_created: str - - @classmethod - def from_dict(cls, json_dict: JsonDict) -> "AnnotationInfo": - return cls(**json_dict) - - -@dataclass -class LicenseData(object): - url: str - license_id: LicenseId - name: str - - @classmethod - def from_dict(cls, json_dict: JsonDict) -> "LicenseData": - return cls( - license_id=json_dict["id"], - url=json_dict["url"], - name=json_dict["name"], - ) - - -@dataclass -class ImageData(object): - image_id: ImageId - license_id: LicenseId - file_name: str - coco_url: str - height: int - width: int - date_captured: str - flickr_url: str 
- - @classmethod - def from_dict(cls, json_dict: JsonDict) -> "ImageData": - return cls( - image_id=json_dict["id"], - license_id=json_dict["license"], - file_name=json_dict["file_name"], - coco_url=json_dict["coco_url"], - height=json_dict["height"], - width=json_dict["width"], - date_captured=json_dict["date_captured"], - flickr_url=json_dict["flickr_url"], - ) - - @property - def shape(self) -> Tuple[int, int]: - return (self.height, self.width) - - -@dataclass -class CategoryData(object): - category_id: int - name: str - supercategory: str - - @classmethod - def from_dict(cls, json_dict: JsonDict) -> "CategoryData": - return cls( - category_id=json_dict["id"], - name=json_dict["name"], - supercategory=json_dict["supercategory"], - ) - - -@dataclass -class AnnotationData(object): - annotation_id: AnnotationId - image_id: ImageId - - -@dataclass -class CaptionsAnnotationData(AnnotationData): - caption: str - - @classmethod - def from_dict(cls, json_dict: JsonDict) -> "CaptionsAnnotationData": - return cls( - annotation_id=json_dict["id"], - image_id=json_dict["image_id"], - caption=json_dict["caption"], - ) - - -class UncompressedRLE(TypedDict): - counts: List[int] - size: Tuple[int, int] - - -class CompressedRLE(TypedDict): - counts: bytes - size: Tuple[int, int] - - -@dataclass -class InstancesAnnotationData(AnnotationData): - segmentation: Union[np.ndarray, CompressedRLE] - area: float - iscrowd: bool - bbox: Tuple[float, float, float, float] - category_id: int - - @classmethod - def compress_rle( - cls, - segmentation: Union[List[List[float]], UncompressedRLE], - iscrowd: bool, - height: int, - width: int, - ) -> CompressedRLE: - if iscrowd: - rle = cocomask.frPyObjects(segmentation, h=height, w=width) - else: - rles = cocomask.frPyObjects(segmentation, h=height, w=width) - rle = cocomask.merge(rles) - - return rle # type: ignore - - @classmethod - def rle_segmentation_to_binary_mask( - cls, segmentation, iscrowd: bool, height: int, width: int - ) -> 
np.ndarray: - rle = cls.compress_rle( - segmentation=segmentation, iscrowd=iscrowd, height=height, width=width - ) - return cocomask.decode(rle) # type: ignore - - @classmethod - def rle_segmentation_to_mask( - cls, - segmentation: Union[List[List[float]], UncompressedRLE], - iscrowd: bool, - height: int, - width: int, - ) -> np.ndarray: - binary_mask = cls.rle_segmentation_to_binary_mask( - segmentation=segmentation, iscrowd=iscrowd, height=height, width=width - ) - return binary_mask * 255 - - @classmethod - def from_dict( - cls, - json_dict: JsonDict, - images: Dict[ImageId, ImageData], - decode_rle: bool, - ) -> "InstancesAnnotationData": - segmentation = json_dict["segmentation"] - image_id = json_dict["image_id"] - image_data = images[image_id] - iscrowd = bool(json_dict["iscrowd"]) - - segmentation_mask = ( - cls.rle_segmentation_to_mask( - segmentation=segmentation, - iscrowd=iscrowd, - height=image_data.height, - width=image_data.width, - ) - if decode_rle - else cls.compress_rle( - segmentation=segmentation, - iscrowd=iscrowd, - height=image_data.height, - width=image_data.width, - ) - ) - return cls( - # - # for AnnotationData - # - annotation_id=json_dict["id"], - image_id=image_id, - # - # for InstancesAnnotationData - # - segmentation=segmentation_mask, # type: ignore - area=json_dict["area"], - iscrowd=iscrowd, - bbox=json_dict["bbox"], - category_id=json_dict["category_id"], - ) - - -@dataclass -class PersonKeypoint(object): - x: int - y: int - v: int - state: str - - -@dataclass -class PersonKeypointsAnnotationData(InstancesAnnotationData): - num_keypoints: int - keypoints: List[PersonKeypoint] - - @classmethod - def v_keypoint_to_state(cls, keypoint_v: int) -> str: - return KEYPOINT_STATE[keypoint_v] - - @classmethod - def get_person_keypoints( - cls, flatten_keypoints: List[int], num_keypoints: int - ) -> List[PersonKeypoint]: - keypoints_x = flatten_keypoints[0::3] - keypoints_y = flatten_keypoints[1::3] - keypoints_v = flatten_keypoints[2::3] - 
assert len(keypoints_x) == len(keypoints_y) == len(keypoints_v) - - keypoints = [ - PersonKeypoint(x=x, y=y, v=v, state=cls.v_keypoint_to_state(v)) - for x, y, v in zip(keypoints_x, keypoints_y, keypoints_v) - ] - assert len([kp for kp in keypoints if kp.state != "unknown"]) == num_keypoints - return keypoints - - @classmethod - def from_dict( - cls, - json_dict: JsonDict, - images: Dict[ImageId, ImageData], - decode_rle: bool, - ) -> "PersonKeypointsAnnotationData": - segmentation = json_dict["segmentation"] - image_id = json_dict["image_id"] - image_data = images[image_id] - iscrowd = bool(json_dict["iscrowd"]) - - segmentation_mask = ( - cls.rle_segmentation_to_mask( - segmentation=segmentation, - iscrowd=iscrowd, - height=image_data.height, - width=image_data.width, - ) - if decode_rle - else cls.compress_rle( - segmentation=segmentation, - iscrowd=iscrowd, - height=image_data.height, - width=image_data.width, - ) - ) - flatten_keypoints = json_dict["keypoints"] - num_keypoints = json_dict["num_keypoints"] - keypoints = cls.get_person_keypoints(flatten_keypoints, num_keypoints) - - return cls( - # - # for AnnotationData - # - annotation_id=json_dict["id"], - image_id=image_id, - # - # for InstancesAnnotationData - # - segmentation=segmentation_mask, # type: ignore - area=json_dict["area"], - iscrowd=iscrowd, - bbox=json_dict["bbox"], - category_id=json_dict["category_id"], - # - # PersonKeypointsAnnotationData - # - num_keypoints=num_keypoints, - keypoints=keypoints, - ) - - -class LicenseDict(TypedDict): - license_id: LicenseId - name: str - url: str - - -class BaseExample(TypedDict): - image_id: ImageId - image: PilImage - file_name: str - coco_url: str - height: int - width: int - date_captured: str - flickr_url: str - license_id: LicenseId - license: LicenseDict - - -class CaptionAnnotationDict(TypedDict): - annotation_id: AnnotationId - caption: str - - -class CaptionExample(BaseExample): - annotations: List[CaptionAnnotationDict] - - -class 
CategoryDict(TypedDict): - category_id: CategoryId - name: str - supercategory: str - - -class InstanceAnnotationDict(TypedDict): - annotation_id: AnnotationId - area: float - bbox: Bbox - image_id: ImageId - category_id: CategoryId - category: CategoryDict - iscrowd: bool - segmentation: np.ndarray - - -class InstanceExample(BaseExample): - annotations: List[InstanceAnnotationDict] - - -class KeypointDict(TypedDict): - x: int - y: int - v: int - state: str - - -class PersonKeypointAnnotationDict(InstanceAnnotationDict): - num_keypoints: int - keypoints: List[KeypointDict] - - -class PersonKeypointExample(BaseExample): - annotations: List[PersonKeypointAnnotationDict] - - -class MsCocoProcessor(object, metaclass=abc.ABCMeta): - def load_image(self, image_path: str) -> PilImage: - return Image.open(image_path) - - def load_annotation_json(self, ann_file_path: str) -> JsonDict: - logger.info(f"Load annotation json from {ann_file_path}") - with open(ann_file_path, "r") as rf: - ann_json = json.load(rf) - return ann_json - - def load_licenses_data( - self, license_dicts: List[JsonDict] - ) -> Dict[LicenseId, LicenseData]: - licenses = {} - for license_dict in license_dicts: - license_data = LicenseData.from_dict(license_dict) - licenses[license_data.license_id] = license_data - return licenses - - def load_images_data( - self, - image_dicts: List[JsonDict], - tqdm_desc: str = "Load images", - ) -> Dict[ImageId, ImageData]: - images = {} - for image_dict in tqdm(image_dicts, desc=tqdm_desc): - image_data = ImageData.from_dict(image_dict) - images[image_data.image_id] = image_data - return images - - def load_categories_data( - self, - category_dicts: List[JsonDict], - tqdm_desc: str = "Load categories", - ) -> Dict[CategoryId, CategoryData]: - categories = {} - for category_dict in tqdm(category_dicts, desc=tqdm_desc): - category_data = CategoryData.from_dict(category_dict) - categories[category_data.category_id] = category_data - return categories - - def 
get_features_base_dict(self): - return { - "image_id": ds.Value("int64"), - "image": ds.Image(), - "file_name": ds.Value("string"), - "coco_url": ds.Value("string"), - "height": ds.Value("int32"), - "width": ds.Value("int32"), - "date_captured": ds.Value("string"), - "flickr_url": ds.Value("string"), - "license_id": ds.Value("int32"), - "license": { - "url": ds.Value("string"), - "license_id": ds.Value("int8"), - "name": ds.Value("string"), - }, - } - - @abc.abstractmethod - def get_features(self, *args, **kwargs) -> ds.Features: - raise NotImplementedError - - @abc.abstractmethod - def load_data(self, ann_dicts: List[JsonDict], tqdm_desc: str = "", **kwargs): - assert tqdm_desc != "", "tqdm_desc must be provided." - raise NotImplementedError - - @abc.abstractmethod - def generate_examples( - self, - image_dir: str, - images: Dict[ImageId, ImageData], - annotations: Dict[ImageId, List[CaptionsAnnotationData]], - licenses: Dict[LicenseId, LicenseData], - **kwargs, - ): - raise NotImplementedError - - -class CaptionsProcessor(MsCocoProcessor): - def get_features(self, *args, **kwargs) -> ds.Features: - features_dict = self.get_features_base_dict() - annotations = ds.Sequence( - { - "annotation_id": ds.Value("int64"), - "image_id": ds.Value("int64"), - "caption": ds.Value("string"), - } - ) - features_dict.update({"annotations": annotations}) - return ds.Features(features_dict) - - def load_data( - self, - ann_dicts: List[JsonDict], - tqdm_desc: str = "Load captions data", - **kwargs, - ) -> Dict[ImageId, List[CaptionsAnnotationData]]: - annotations = defaultdict(list) - for ann_dict in tqdm(ann_dicts, desc=tqdm_desc): - ann_data = CaptionsAnnotationData.from_dict(ann_dict) - annotations[ann_data.image_id].append(ann_data) - return annotations - - def generate_examples( - self, - image_dir: str, - images: Dict[ImageId, ImageData], - annotations: Dict[ImageId, List[CaptionsAnnotationData]], - licenses: Dict[LicenseId, LicenseData], - **kwargs, - ) -> 
Iterator[Tuple[int, CaptionExample]]: - for idx, image_id in enumerate(images.keys()): - image_data = images[image_id] - image_anns = annotations[image_id] - - assert len(image_anns) > 0 - - image = self.load_image( - image_path=os.path.join(image_dir, image_data.file_name), - ) - example = asdict(image_data) - example["image"] = image - example["license"] = asdict(licenses[image_data.license_id]) - - example["annotations"] = [] - for ann in image_anns: - example["annotations"].append(asdict(ann)) - - yield idx, example # type: ignore - - -class InstancesProcessor(MsCocoProcessor): - def get_features_instance_dict(self, decode_rle: bool): - segmentation_feature = ( - ds.Image() - if decode_rle - else { - "counts": ds.Sequence(ds.Value("int64")), - "size": ds.Sequence(ds.Value("int32")), - } - ) - return { - "annotation_id": ds.Value("int64"), - "image_id": ds.Value("int64"), - "segmentation": segmentation_feature, - "area": ds.Value("float32"), - "iscrowd": ds.Value("bool"), - "bbox": ds.Sequence(ds.Value("float32"), length=4), - "category_id": ds.Value("int32"), - "category": { - "category_id": ds.Value("int32"), - "name": ds.ClassLabel( - num_classes=len(CATEGORIES), - names=CATEGORIES, - ), - "supercategory": ds.ClassLabel( - num_classes=len(SUPER_CATEGORIES), - names=SUPER_CATEGORIES, - ), - }, - } - - def get_features(self, decode_rle: bool) -> ds.Features: - features_dict = self.get_features_base_dict() - annotations = ds.Sequence( - self.get_features_instance_dict(decode_rle=decode_rle) - ) - features_dict.update({"annotations": annotations}) - return ds.Features(features_dict) - - def load_data( # type: ignore[override] - self, - ann_dicts: List[JsonDict], - images: Dict[ImageId, ImageData], - decode_rle: bool, - tqdm_desc: str = "Load instances data", - ) -> Dict[ImageId, List[InstancesAnnotationData]]: - annotations = defaultdict(list) - ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"]) - - for ann_dict in tqdm(ann_dicts, desc=tqdm_desc): - 
ann_data = InstancesAnnotationData.from_dict( - ann_dict, images=images, decode_rle=decode_rle - ) - annotations[ann_data.image_id].append(ann_data) - - return annotations - - def generate_examples( # type: ignore[override] - self, - image_dir: str, - images: Dict[ImageId, ImageData], - annotations: Dict[ImageId, List[InstancesAnnotationData]], - licenses: Dict[LicenseId, LicenseData], - categories: Dict[CategoryId, CategoryData], - ) -> Iterator[Tuple[int, InstanceExample]]: - for idx, image_id in enumerate(images.keys()): - image_data = images[image_id] - image_anns = annotations[image_id] - - if len(image_anns) < 1: - logger.warning(f"No annotation found for image id: {image_id}.") - continue - - image = self.load_image( - image_path=os.path.join(image_dir, image_data.file_name), - ) - example = asdict(image_data) - example["image"] = image - example["license"] = asdict(licenses[image_data.license_id]) - - example["annotations"] = [] - for ann in image_anns: - ann_dict = asdict(ann) - category = categories[ann.category_id] - ann_dict["category"] = asdict(category) - example["annotations"].append(ann_dict) - - yield idx, example # type: ignore - - -class PersonKeypointsProcessor(InstancesProcessor): - def get_features(self, decode_rle: bool) -> ds.Features: - features_dict = self.get_features_base_dict() - features_instance_dict = self.get_features_instance_dict(decode_rle=decode_rle) - features_instance_dict.update( - { - "keypoints": ds.Sequence( - { - "state": ds.Value("string"), - "x": ds.Value("int32"), - "y": ds.Value("int32"), - "v": ds.Value("int32"), - } - ), - "num_keypoints": ds.Value("int32"), - } - ) - annotations = ds.Sequence(features_instance_dict) - features_dict.update({"annotations": annotations}) - return ds.Features(features_dict) - - def load_data( # type: ignore[override] - self, - ann_dicts: List[JsonDict], - images: Dict[ImageId, ImageData], - decode_rle: bool, - tqdm_desc: str = "Load person keypoints data", - ) -> Dict[ImageId, 
List[PersonKeypointsAnnotationData]]: - annotations = defaultdict(list) - ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"]) - - for ann_dict in tqdm(ann_dicts, desc=tqdm_desc): - ann_data = PersonKeypointsAnnotationData.from_dict( - ann_dict, images=images, decode_rle=decode_rle - ) - annotations[ann_data.image_id].append(ann_data) - return annotations - - def generate_examples( # type: ignore[override] - self, - image_dir: str, - images: Dict[ImageId, ImageData], - annotations: Dict[ImageId, List[PersonKeypointsAnnotationData]], - licenses: Dict[LicenseId, LicenseData], - categories: Dict[CategoryId, CategoryData], - ) -> Iterator[Tuple[int, PersonKeypointExample]]: - for idx, image_id in enumerate(images.keys()): - image_data = images[image_id] - image_anns = annotations[image_id] - - if len(image_anns) < 1: - # If there are no persons in the image, - # no keypoint annotations will be assigned. - continue - - image = self.load_image( - image_path=os.path.join(image_dir, image_data.file_name), - ) - example = asdict(image_data) - example["image"] = image - example["license"] = asdict(licenses[image_data.license_id]) - - example["annotations"] = [] - for ann in image_anns: - ann_dict = asdict(ann) - category = categories[ann.category_id] - ann_dict["category"] = asdict(category) - example["annotations"].append(ann_dict) - - yield idx, example # type: ignore - class MsCocoConfig(ds.BuilderConfig): YEARS: Tuple[int, ...] 
= ( diff --git a/annotation.py b/annotation.py new file mode 100644 index 0000000..d30cd17 --- /dev/null +++ b/annotation.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + +from .typehint import AnnotationId, ImageId + + +@dataclass +class AnnotationData(object): + annotation_id: AnnotationId + image_id: ImageId diff --git a/base_example.py b/base_example.py new file mode 100644 index 0000000..71621b3 --- /dev/null +++ b/base_example.py @@ -0,0 +1,22 @@ +from typing import TypedDict + +from .typehint import ImageId, LicenseId, PilImage + + +class LicenseDict(TypedDict): + license_id: LicenseId + name: str + url: str + + +class BaseExample(TypedDict): + image_id: ImageId + image: PilImage + file_name: str + coco_url: str + height: int + width: int + date_captured: str + flickr_url: str + license_id: LicenseId + license: LicenseDict diff --git a/caption.py b/caption.py new file mode 100644 index 0000000..7b1836e --- /dev/null +++ b/caption.py @@ -0,0 +1,89 @@ +import os +from collections import defaultdict +from dataclasses import asdict, dataclass +from typing import Dict, Iterator, List, Tuple, TypedDict + +import datasets as ds +from tqdm.auto import tqdm + +from .annotation import AnnotationData +from .base_example import BaseExample +from .image import ImageData +from .license import LicenseData +from .processor import MsCocoProcessor +from .typehint import AnnotationId, ImageId, JsonDict, LicenseId + + +@dataclass +class CaptionsAnnotationData(AnnotationData): + caption: str + + @classmethod + def from_dict(cls, json_dict: JsonDict) -> "CaptionsAnnotationData": + return cls( + annotation_id=json_dict["id"], + image_id=json_dict["image_id"], + caption=json_dict["caption"], + ) + + +class CaptionAnnotationDict(TypedDict): + annotation_id: AnnotationId + caption: str + + +class CaptionExample(BaseExample): + annotations: List[CaptionAnnotationDict] + + +class CaptionsProcessor(MsCocoProcessor): + def get_features(self, *args, **kwargs) -> ds.Features: + 
features_dict = self.get_features_base_dict() + annotations = ds.Sequence( + { + "annotation_id": ds.Value("int64"), + "image_id": ds.Value("int64"), + "caption": ds.Value("string"), + } + ) + features_dict.update({"annotations": annotations}) + return ds.Features(features_dict) + + def load_data( + self, + ann_dicts: List[JsonDict], + tqdm_desc: str = "Load captions data", + **kwargs, + ) -> Dict[ImageId, List[CaptionsAnnotationData]]: + annotations = defaultdict(list) + for ann_dict in tqdm(ann_dicts, desc=tqdm_desc): + ann_data = CaptionsAnnotationData.from_dict(ann_dict) + annotations[ann_data.image_id].append(ann_data) + return annotations + + def generate_examples( + self, + image_dir: str, + images: Dict[ImageId, ImageData], + annotations: Dict[ImageId, List[CaptionsAnnotationData]], + licenses: Dict[LicenseId, LicenseData], + **kwargs, + ) -> Iterator[Tuple[int, CaptionExample]]: + for idx, image_id in enumerate(images.keys()): + image_data = images[image_id] + image_anns = annotations[image_id] + + assert len(image_anns) > 0 + + image = self.load_image( + image_path=os.path.join(image_dir, image_data.file_name), + ) + example = asdict(image_data) + example["image"] = image + example["license"] = asdict(licenses[image_data.license_id]) + + example["annotations"] = [] + for ann in image_anns: + example["annotations"].append(asdict(ann)) + + yield idx, example # type: ignore diff --git a/category.py b/category.py new file mode 100644 index 0000000..fbd8366 --- /dev/null +++ b/category.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from .typehint import JsonDict + + +@dataclass +class CategoryData(object): + category_id: int + name: str + supercategory: str + + @classmethod + def from_dict(cls, json_dict: JsonDict) -> "CategoryData": + return cls( + category_id=json_dict["id"], + name=json_dict["name"], + supercategory=json_dict["supercategory"], + ) diff --git a/const.py b/const.py new file mode 100644 index 0000000..778dec3 --- /dev/null +++ 
b/const.py
@@ -0,0 +1,102 @@
+from typing import Final, List
+
+# Object category names. `instances.py` passes this list to `ds.ClassLabel`,
+# so the position of a name is its integer class id: do not reorder.
+CATEGORIES: Final[List[str]] = [
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
+
+# Supercategory names, used the same way for the "supercategory" class label.
+SUPER_CATEGORIES: Final[List[str]] = [
+    "person",
+    "vehicle",
+    "outdoor",
+    "animal",
+    "accessory",
+    "sports",
+    "kitchen",
+    "food",
+    "furniture",
+    "electronic",
+    "appliance",
+    "indoor",
+]
diff --git a/image.py b/image.py
new file mode 100644
index 0000000..5cf9b83
--- /dev/null
+++ b/image.py
@@ -0,0 +1,38 @@
+from dataclasses import dataclass
+
+from .typehint import ImageId, JsonDict, LicenseId
+
+
+@dataclass
+class ImageData(object):
+    """Metadata of one image, built from an image JSON dict."""
+
+    image_id: ImageId
+    license_id: LicenseId
+    file_name: str
+    coco_url: str
+    height: int
+    width: int
+    date_captured: str
+    flickr_url: str
+
+    @classmethod
+    def from_dict(cls, json_dict: JsonDict) -> "ImageData":
+        """Build an `ImageData` from an image JSON dict.
+
+        The JSON keys `id` and `license` are stored as `image_id` and
+        `license_id`; every other field keeps its JSON key name.
+
+        Raises:
+            KeyError: If `json_dict` lacks one of the keys read below.
+        """
+        return cls(
+            image_id=json_dict["id"],
+            license_id=json_dict["license"],
+            file_name=json_dict["file_name"],
+            coco_url=json_dict["coco_url"],
+            height=json_dict["height"],
+            width=json_dict["width"],
+            date_captured=json_dict["date_captured"],
+            flickr_url=json_dict["flickr_url"],
+        )
diff --git a/info.py b/info.py
new file mode 100644
index 0000000..2bfdffd
--- /dev/null
+++ b/info.py
@@ -0,0 +1,25 @@
+from dataclasses import dataclass
+
+from .typehint import JsonDict
+
+
+@dataclass
+class AnnotationInfo(object):
+    """Metadata record with the fields below, built from an info JSON dict."""
+
+    description: str
+    url: str
+    version: str
+    year: str
+    contributor: str
+    date_created: str
+
+    @classmethod
+    def from_dict(cls, json_dict: JsonDict) -> "AnnotationInfo":
+        """Build an `AnnotationInfo` by unpacking `json_dict` as keywords.
+
+        Raises:
+            TypeError: If the keys of `json_dict` do not match the fields
+                declared on this class exactly.
+        """
+        return cls(**json_dict)
diff --git a/instances.py b/instances.py
new file mode 100644
index 0000000..ec4c827
--- /dev/null
+++ b/instances.py
@@ -0,0 +1,271 @@
+import logging
+import os
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from typing import Dict, Iterator, List, Tuple, TypedDict, Union
+
+import datasets as ds
+import numpy as np
+from pycocotools import mask as cocomask
+from tqdm.auto import tqdm
+
+from .annotation import AnnotationData
+from .base_example import BaseExample
+from .category import CategoryData
+from .const import CATEGORIES, SUPER_CATEGORIES
+from .image import ImageData
+from .license import LicenseData
+from .processor import MsCocoProcessor
+from .typehint import (
+    AnnotationId,
+    Bbox,
+    CategoryDict,
+    CategoryId,
+    CompressedRLE,
+    ImageId,
+    JsonDict,
+    LicenseId,
+    UncompressedRLE,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InstancesAnnotationData(AnnotationData):
+    """One instance annotation: segmentation, area, crowd flag, bbox, category."""
+
+    segmentation: Union[np.ndarray, CompressedRLE]
+    area: float
+    iscrowd: bool
+    bbox: Tuple[float, float, float, float]
+    category_id: int
+
+    @classmethod
+    def compress_rle(
+        cls,
+        segmentation: Union[List[List[float]], UncompressedRLE],
+        iscrowd: bool,
+        height: int,
+        width: int,
+    ) -> CompressedRLE:
+        """Encode a raw segmentation as one compressed RLE via pycocotools.
+
+        Args:
+            segmentation: Polygon list or uncompressed RLE from the JSON.
+            iscrowd: When false, the RLEs returned by `frPyObjects` are
+                merged into a single RLE.
+            height: Image height passed to `frPyObjects` as `h`.
+            width: Image width passed to `frPyObjects` as `w`.
+        """
+        rle = cocomask.frPyObjects(segmentation, h=height, w=width)
+        if not iscrowd:
+            # Non-crowd input yields a list of RLEs; fold it into one.
+            rle = cocomask.merge(rle)
+        return rle  # type: ignore
+
+    @classmethod
+    def rle_segmentation_to_binary_mask(
+        cls, segmentation, iscrowd: bool, height: int, width: int
+    ) -> np.ndarray:
+        """Compress `segmentation` and decode it with `cocomask.decode`."""
+        rle = cls.compress_rle(
+            segmentation=segmentation, iscrowd=iscrowd, height=height, width=width
+        )
+        return cocomask.decode(rle)  # type: ignore
+
+    @classmethod
+    def rle_segmentation_to_mask(
+        cls,
+        segmentation: Union[List[List[float]], UncompressedRLE],
+        iscrowd: bool,
+        height: int,
+        width: int,
+    ) -> np.ndarray:
+        """Return the decoded binary mask multiplied by 255."""
+        binary_mask = cls.rle_segmentation_to_binary_mask(
+            segmentation=segmentation, iscrowd=iscrowd, height=height, width=width
+        )
+        return binary_mask * 255
+
+    @classmethod
+    def from_dict(
+        cls,
+        json_dict: JsonDict,
+        images: Dict[ImageId, ImageData],
+        decode_rle: bool,
+    ) -> "InstancesAnnotationData":
+        """Build an annotation from one annotation JSON dict.
+
+        Args:
+            json_dict: The annotation entry.
+            images: Image metadata by id; supplies the mask height/width.
+            decode_rle: If true, store the decoded mask array; otherwise
+                store the compressed RLE.
+
+        Raises:
+            KeyError: If a key read below is missing from `json_dict`, or
+                its `image_id` is not in `images`.
+        """
+        segmentation = json_dict["segmentation"]
+        image_id = json_dict["image_id"]
+        image_data = images[image_id]
+        iscrowd = bool(json_dict["iscrowd"])
+
+        segmentation_mask = (
+            cls.rle_segmentation_to_mask(
+                segmentation=segmentation,
+                iscrowd=iscrowd,
+                height=image_data.height,
+                width=image_data.width,
+            )
+            if decode_rle
+            else cls.compress_rle(
+                segmentation=segmentation,
+                iscrowd=iscrowd,
+                height=image_data.height,
+                width=image_data.width,
+            )
+        )
+        return cls(
+            #
+            # for AnnotationData
+            #
+            annotation_id=json_dict["id"],
+            image_id=image_id,
+            #
+            # for InstancesAnnotationData
+            #
+            segmentation=segmentation_mask,  # type: ignore
+            area=json_dict["area"],
+            iscrowd=iscrowd,
+            bbox=json_dict["bbox"],
+            category_id=json_dict["category_id"],
+        )
+
+
+class InstanceAnnotationDict(TypedDict):
+    """Dict form of one annotation inside a generated example."""
+
+    annotation_id: AnnotationId
+    area: float
+    bbox: Bbox
+    image_id: ImageId
+    category_id: CategoryId
+    category: CategoryDict
+    iscrowd: bool
+    segmentation: np.ndarray
+
+
+class InstanceExample(BaseExample):
+    """A generated example carrying instance annotations."""
+
+    annotations: List[InstanceAnnotationDict]
+
+
+class InstancesProcessor(MsCocoProcessor):
+    """Feature schema, loading and example generation for instance data."""
+
+    def get_features_instance_dict(self, decode_rle: bool):
+        """Return the feature spec of one annotation.
+
+        `segmentation` is an image feature when `decode_rle` is true and
+        a `counts`/`size` dict otherwise.
+        """
+        segmentation_feature = (
+            ds.Image()
+            if decode_rle
+            else {
+                "counts": ds.Sequence(ds.Value("int64")),
+                "size": ds.Sequence(ds.Value("int32")),
+            }
+        )
+        return {
+            "annotation_id": ds.Value("int64"),
+            "image_id": ds.Value("int64"),
+            "segmentation": segmentation_feature,
+            "area": ds.Value("float32"),
+            "iscrowd": ds.Value("bool"),
+            "bbox": ds.Sequence(ds.Value("float32"), length=4),
+            "category_id": ds.Value("int32"),
+            "category": {
+                "category_id": ds.Value("int32"),
+                "name": ds.ClassLabel(
+                    num_classes=len(CATEGORIES),
+                    names=CATEGORIES,
+                ),
+                "supercategory": ds.ClassLabel(
+                    num_classes=len(SUPER_CATEGORIES),
+                    names=SUPER_CATEGORIES,
+                ),
+            },
+        }
+
+    def get_features(self, decode_rle: bool) -> ds.Features:
+        """Return the base features plus an `annotations` sequence."""
+        features_dict = self.get_features_base_dict()
+        annotations = ds.Sequence(
+            self.get_features_instance_dict(decode_rle=decode_rle)
+        )
+        features_dict.update({"annotations": annotations})
+        return ds.Features(features_dict)
+
+    def load_data(  # type: ignore[override]
+        self,
+        ann_dicts: List[JsonDict],
+        images: Dict[ImageId, ImageData],
+        decode_rle: bool,
+        tqdm_desc: str = "Load instances data",
+    ) -> Dict[ImageId, List[InstancesAnnotationData]]:
+        """Parse `ann_dicts` and group the annotations by image id.
+
+        Annotations are processed in ascending `image_id` order.
+        """
+        annotations = defaultdict(list)
+        ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"])
+
+        for ann_dict in tqdm(ann_dicts, desc=tqdm_desc):
+            ann_data = InstancesAnnotationData.from_dict(
+                ann_dict, images=images, decode_rle=decode_rle
+            )
+            annotations[ann_data.image_id].append(ann_data)
+
+        return annotations
+
+    def generate_examples(  # type: ignore[override]
+        self,
+        image_dir: str,
+        images: Dict[ImageId, ImageData],
+        annotations: Dict[ImageId, List[InstancesAnnotationData]],
+        licenses: Dict[LicenseId, LicenseData],
+        categories: Dict[CategoryId, CategoryData],
+    ) -> Iterator[Tuple[int, InstanceExample]]:
+        """Yield `(index, example)` for every image that has annotations.
+
+        `index` is the position of the image in `images`, so skipped
+        images leave gaps in the yielded indices.
+        """
+        for idx, (image_id, image_data) in enumerate(images.items()):
+            # `.get` works for plain dicts as well and, unlike indexing a
+            # defaultdict, does not insert an empty list for missing ids.
+            image_anns = annotations.get(image_id, [])
+
+            if not image_anns:
+                logger.warning(f"No annotation found for image id: {image_id}.")
+                continue
+
+            image = self.load_image(
+                image_path=os.path.join(image_dir, image_data.file_name),
+            )
+            example = asdict(image_data)
+            example["image"] = image
+            example["license"] = asdict(licenses[image_data.license_id])
+
+            example["annotations"] = []
+            for ann in image_anns:
+                ann_dict = asdict(ann)
+                category = categories[ann.category_id]
+                ann_dict["category"] = asdict(category)
+                example["annotations"].append(ann_dict)
+
+            yield idx, example  # type: ignore
diff --git a/license.py b/license.py
new file mode 100644
index 0000000..9be8377
--- /dev/null
+++ b/license.py
@@ -0,0 +1,25 @@
+from dataclasses import dataclass
+
+from .typehint import JsonDict, LicenseId
+
+
+@dataclass
+class LicenseData(object):
+    """License record built from a license JSON dict."""
+
+    url: str
+    license_id: LicenseId
+    name: str
+
+    @classmethod
+    def from_dict(cls, json_dict: JsonDict) -> "LicenseData":
+        """Build a `LicenseData`; the JSON key `id` becomes `license_id`.
+
+        Raises:
+            KeyError: If `id`, `url` or `name` is missing from `json_dict`.
+        """
+        return cls(
+            license_id=json_dict["id"],
+            url=json_dict["url"],
+            name=json_dict["name"],
+        )
diff --git a/person_keypoint.py b/person_keypoint.py
new file mode 100644
index 0000000..de2bf78
--- /dev/null
+++ b/person_keypoint.py
@@ -0,0 +1,259 @@
+import os
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from typing import Dict, Final, Iterator, List, Tuple, TypedDict
+
+import datasets as ds
+from tqdm.auto import tqdm
+
+from .base_example import BaseExample
+from .category import CategoryData
+from .image import ImageData
+from .instances import (
+    InstanceAnnotationDict,
+    InstancesAnnotationData,
+    InstancesProcessor,
+)
+from .license import LicenseData
+from .typehint import CategoryId, ImageId, JsonDict, LicenseId
+
+# State name for each keypoint visibility flag `v`, indexed by `v` (0, 1, 2).
+KEYPOINT_STATE: Final[List[str]] = ["unknown", "invisible", "visible"]
+
+
+@dataclass
+class PersonKeypoint(object):
+    """One keypoint: coordinates, visibility flag `v` and its state name."""
+
+    x: int
+    y: int
+    v: int
+    state: str
+
+
+@dataclass
+class PersonKeypointsAnnotationData(InstancesAnnotationData):
+    """Instance annotation extended with person keypoints."""
+
+    num_keypoints: int
+    keypoints: List[PersonKeypoint]
+
+    @classmethod
+    def v_keypoint_to_state(cls, keypoint_v: int) -> str:
+        """Map a visibility flag to its `KEYPOINT_STATE` name.
+
+        Raises:
+            ValueError: If `keypoint_v` is not a valid index. Negative
+                flags are rejected explicitly because plain list indexing
+                would silently wrap them around.
+        """
+        if not 0 <= keypoint_v < len(KEYPOINT_STATE):
+            raise ValueError(f"Unknown keypoint visibility flag: {keypoint_v!r}")
+        return KEYPOINT_STATE[keypoint_v]
+
+    @classmethod
+    def get_person_keypoints(
+        cls, flatten_keypoints: List[int], num_keypoints: int
+    ) -> List[PersonKeypoint]:
+        """Turn a flat `[x, y, v, x, y, v, ...]` list into keypoint objects.
+
+        Args:
+            flatten_keypoints: Flat list whose length is a multiple of 3.
+            num_keypoints: Expected number of keypoints whose state is not
+                "unknown".
+
+        Raises:
+            ValueError: If the list length is not a multiple of 3, a flag
+                is invalid, or the labelled count differs from
+                `num_keypoints`. Explicit raises are used instead of
+                `assert` so the checks survive `python -O`.
+        """
+        if len(flatten_keypoints) % 3 != 0:
+            raise ValueError(
+                "Keypoints must be a flat list of (x, y, v) triplets, "
+                f"got {len(flatten_keypoints)} values."
+            )
+
+        keypoints = [
+            PersonKeypoint(x=x, y=y, v=v, state=cls.v_keypoint_to_state(v))
+            for x, y, v in zip(
+                flatten_keypoints[0::3],
+                flatten_keypoints[1::3],
+                flatten_keypoints[2::3],
+            )
+        ]
+        num_labeled = sum(kp.state != "unknown" for kp in keypoints)
+        if num_labeled != num_keypoints:
+            raise ValueError(
+                f"Expected {num_keypoints} labeled keypoints, found {num_labeled}."
+            )
+        return keypoints
+
+    @classmethod
+    def from_dict(
+        cls,
+        json_dict: JsonDict,
+        images: Dict[ImageId, ImageData],
+        decode_rle: bool,
+    ) -> "PersonKeypointsAnnotationData":
+        """Build a keypoint annotation from one annotation JSON dict.
+
+        Same as `InstancesAnnotationData.from_dict`, plus the `keypoints`
+        and `num_keypoints` entries of `json_dict`.
+
+        Raises:
+            KeyError: If a key read below is missing from `json_dict`, or
+                its `image_id` is not in `images`.
+            ValueError: If the keypoints fail validation.
+        """
+        segmentation = json_dict["segmentation"]
+        image_id = json_dict["image_id"]
+        image_data = images[image_id]
+        iscrowd = bool(json_dict["iscrowd"])
+
+        segmentation_mask = (
+            cls.rle_segmentation_to_mask(
+                segmentation=segmentation,
+                iscrowd=iscrowd,
+                height=image_data.height,
+                width=image_data.width,
+            )
+            if decode_rle
+            else cls.compress_rle(
+                segmentation=segmentation,
+                iscrowd=iscrowd,
+                height=image_data.height,
+                width=image_data.width,
+            )
+        )
+        flatten_keypoints = json_dict["keypoints"]
+        num_keypoints = json_dict["num_keypoints"]
+        keypoints = cls.get_person_keypoints(flatten_keypoints, num_keypoints)
+
+        return cls(
+            #
+            # for AnnotationData
+            #
+            annotation_id=json_dict["id"],
+            image_id=image_id,
+            #
+            # for InstancesAnnotationData
+            #
+            segmentation=segmentation_mask,  # type: ignore
+            area=json_dict["area"],
+            iscrowd=iscrowd,
+            bbox=json_dict["bbox"],
+            category_id=json_dict["category_id"],
+            #
+            # PersonKeypointsAnnotationData
+            #
+            num_keypoints=num_keypoints,
+            keypoints=keypoints,
+        )
+
+
+class KeypointDict(TypedDict):
+    """Dict form of a `PersonKeypoint`."""
+
+    x: int
+    y: int
+    v: int
+    state: str
+
+
+class PersonKeypointAnnotationDict(InstanceAnnotationDict):
+    """Instance annotation dict extended with keypoint fields."""
+
+    num_keypoints: int
+    keypoints: List[KeypointDict]
+
+
+class PersonKeypointExample(BaseExample):
+    """A generated example carrying person-keypoint annotations."""
+
+    annotations: List[PersonKeypointAnnotationDict]
+
+
+class PersonKeypointsProcessor(InstancesProcessor):
+    """Instance processor extended with keypoint features and loading."""
+
+    def get_features(self, decode_rle: bool) -> ds.Features:
+        """Return the instance features plus `keypoints`/`num_keypoints`."""
+        features_dict = self.get_features_base_dict()
+        features_instance_dict = self.get_features_instance_dict(decode_rle=decode_rle)
+        features_instance_dict.update(
+            {
+                "keypoints": ds.Sequence(
+                    {
+                        "state": ds.Value("string"),
+                        "x": ds.Value("int32"),
+                        "y": ds.Value("int32"),
+                        "v": ds.Value("int32"),
+                    }
+                ),
+                "num_keypoints": ds.Value("int32"),
+            }
+        )
+        annotations = ds.Sequence(features_instance_dict)
+        features_dict.update({"annotations": annotations})
+        return ds.Features(features_dict)
+
+    def load_data(  # type: ignore[override]
+        self,
+        ann_dicts: List[JsonDict],
+        images: Dict[ImageId, ImageData],
+        decode_rle: bool,
+        tqdm_desc: str = "Load person keypoints data",
+    ) -> Dict[ImageId, List[PersonKeypointsAnnotationData]]:
+        """Parse `ann_dicts` and group the annotations by image id.
+
+        Annotations are processed in ascending `image_id` order.
+        """
+        annotations = defaultdict(list)
+        ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"])
+
+        for ann_dict in tqdm(ann_dicts, desc=tqdm_desc):
+            ann_data = PersonKeypointsAnnotationData.from_dict(
+                ann_dict, images=images, decode_rle=decode_rle
+            )
+            annotations[ann_data.image_id].append(ann_data)
+        return annotations
+
+    def generate_examples(  # type: ignore[override]
+        self,
+        image_dir: str,
+        images: Dict[ImageId, ImageData],
+        annotations: Dict[ImageId, List[PersonKeypointsAnnotationData]],
+        licenses: Dict[LicenseId, LicenseData],
+        categories: Dict[CategoryId, CategoryData],
+    ) -> Iterator[Tuple[int, PersonKeypointExample]]:
+        """Yield `(index, example)` for every image that has annotations.
+
+        `index` is the position of the image in `images`, so skipped
+        images leave gaps in the yielded indices.
+        """
+        for idx, (image_id, image_data) in enumerate(images.items()):
+            # `.get` works for plain dicts as well and, unlike indexing a
+            # defaultdict, does not insert an empty list for missing ids.
+            image_anns = annotations.get(image_id, [])
+
+            if not image_anns:
+                # If there are no persons in the image,
+                # no keypoint annotations will be assigned.
+                continue
+
+            image = self.load_image(
+                image_path=os.path.join(image_dir, image_data.file_name),
+            )
+            example = asdict(image_data)
+            example["image"] = image
+            example["license"] = asdict(licenses[image_data.license_id])
+
+            example["annotations"] = []
+            for ann in image_anns:
+                ann_dict = asdict(ann)
+                category = categories[ann.category_id]
+                ann_dict["category"] = asdict(category)
+                example["annotations"].append(ann_dict)
+
+            yield idx, example  # type: ignore
diff --git a/processor.py b/processor.py
new file mode 100644
index 0000000..f0cfc6b
--- /dev/null
+++ b/processor.py
@@ -0,0 +1,109 @@
+import abc
+import json
+import logging
+from typing import Dict, List
+
+import datasets as ds
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .annotation import AnnotationData
+from .category import CategoryData
+from .image import ImageData
+from .license import LicenseData
+from .typehint import CategoryId, ImageId, JsonDict, LicenseId, PilImage
+
+logger = logging.getLogger(__name__)
+
+
+class MsCocoProcessor(object, metaclass=abc.ABCMeta):
+    """Shared loading helpers and abstract hooks for task processors."""
+
+    def load_image(self, image_path: str) -> PilImage:
+        """Open the image at `image_path` with Pillow."""
+        return Image.open(image_path)
+
+    def load_annotation_json(self, ann_file_path: str) -> JsonDict:
+        """Read and parse the annotation JSON file at `ann_file_path`."""
+        logger.info(f"Load annotation json from {ann_file_path}")
+        # JSON text is UTF-8 by specification; do not rely on the locale
+        # default encoding, which differs between platforms.
+        with open(ann_file_path, "r", encoding="utf-8") as rf:
+            ann_json = json.load(rf)
+        return ann_json
+
+    def load_licenses_data(
+        self, license_dicts: List[JsonDict]
+    ) -> Dict[LicenseId, LicenseData]:
+        """Parse license dicts into `LicenseData`, keyed by license id."""
+        licenses = {}
+        for license_dict in license_dicts:
+            license_data = LicenseData.from_dict(license_dict)
+            licenses[license_data.license_id] = license_data
+        return licenses
+
+    def load_images_data(
+        self,
+        image_dicts: List[JsonDict],
+        tqdm_desc: str = "Load images",
+    ) -> Dict[ImageId, ImageData]:
+        """Parse image dicts into `ImageData`, keyed by image id."""
+        images = {}
+        for image_dict in tqdm(image_dicts, desc=tqdm_desc):
+            image_data = ImageData.from_dict(image_dict)
+            images[image_data.image_id] = image_data
+        return images
+
+    def load_categories_data(
+        self,
+        category_dicts: List[JsonDict],
+        tqdm_desc: str = "Load categories",
+    ) -> Dict[CategoryId, CategoryData]:
+        """Parse category dicts into `CategoryData`, keyed by category id."""
+        categories = {}
+        for category_dict in tqdm(category_dicts, desc=tqdm_desc):
+            category_data = CategoryData.from_dict(category_dict)
+            categories[category_data.category_id] = category_data
+        return categories
+
+    def get_features_base_dict(self):
+        """Return the image-level feature spec that subclasses extend."""
+        return {
+            "image_id": ds.Value("int64"),
+            "image": ds.Image(),
+            "file_name": ds.Value("string"),
+            "coco_url": ds.Value("string"),
+            "height": ds.Value("int32"),
+            "width": ds.Value("int32"),
+            "date_captured": ds.Value("string"),
+            "flickr_url": ds.Value("string"),
+            "license_id": ds.Value("int32"),
+            "license": {
+                "url": ds.Value("string"),
+                "license_id": ds.Value("int8"),
+                "name": ds.Value("string"),
+            },
+        }
+
+    @abc.abstractmethod
+    def get_features(self, *args, **kwargs) -> ds.Features:
+        """Return the full `ds.Features` for the task."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_data(self, ann_dicts: List[JsonDict], tqdm_desc: str = "", **kwargs):
+        """Parse the task's annotation dicts; subclasses must override."""
+        assert tqdm_desc != "", "tqdm_desc must be provided."
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def generate_examples(
+        self,
+        image_dir: str,
+        images: Dict[ImageId, ImageData],
+        annotations: Dict[ImageId, List[AnnotationData]],
+        licenses: Dict[LicenseId, LicenseData],
+        **kwargs,
+    ):
+        """Yield the task's examples; subclasses must override."""
+        raise NotImplementedError
diff --git a/rle.py b/rle.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/MSCOCO_test.py b/tests/MSCOCO_test.py index 56fd48d..48815e4 100644 --- a/tests/MSCOCO_test.py +++ b/tests/MSCOCO_test.py @@ -3,12 +3,12 @@ import datasets as ds import pytest -from MSCOCO import CATEGORIES, SUPER_CATEGORIES +from const import CATEGORIES, SUPER_CATEGORIES @pytest.fixture def dataset_path() -> str: - return "MSCOCO.py" + return "./MSCOCO.py" @pytest.mark.skipif( @@ -54,7 +54,10 @@ def test_load_dataset( year=dataset_year, coco_task=coco_task, decode_rle=decode_rle, + trust_remote_code=True, ) + assert isinstance(dataset, ds.DatasetDict) + assert dataset["train"].num_rows == expected_num_train assert dataset["validation"].num_rows == expected_num_validation
diff --git a/typehint.py b/typehint.py
new file mode 100644
index 0000000..8c9c289
--- /dev/null
+++ b/typehint.py
@@ -0,0 +1,37 @@
+from typing import Annotated, Any, Dict, List, Literal, Tuple, TypedDict
+
+from PIL.Image import Image
+
+# Shared aliases used across the processor modules.
+JsonDict = Dict[str, Any]
+ImageId = int
+AnnotationId = int
+LicenseId = int
+CategoryId = int
+Bbox = Tuple[float, float, float, float]
+
+MscocoSplits = Literal["train", "val", "test"]
+
+PilImage = Annotated[Image, "Pillow Image"]
+
+
+class UncompressedRLE(TypedDict):
+    """RLE dict whose `counts` is a list of ints."""
+
+    counts: List[int]
+    size: Tuple[int, int]
+
+
+class CompressedRLE(TypedDict):
+    """RLE dict whose `counts` is a bytes string."""
+
+    counts: bytes
+    size: Tuple[int, int]
+
+
+class CategoryDict(TypedDict):
+    """Dict form of a category: id, name and supercategory."""
+
+    category_id: CategoryId
+    name: str
+    supercategory: str