diff --git a/README.md b/README.md index 34f7f0b8f90..6f750478492 100644 --- a/README.md +++ b/README.md @@ -1,455 +1,92 @@ -
- -
 
-
- OpenMMLab website - - - HOT - - -      - OpenMMLab platform - - - TRY IT OUT - - -
-
 
+# MMDetection with OA-Mix -[![PyPI](https://img.shields.io/pypi/v/mmdet)](https://pypi.org/project/mmdet) -[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmdetection.readthedocs.io/en/latest/) -[![badge](https://github.com/open-mmlab/mmdetection/workflows/build/badge.svg)](https://github.com/open-mmlab/mmdetection/actions) -[![codecov](https://codecov.io/gh/open-mmlab/mmdetection/branch/main/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmdetection) -[![license](https://img.shields.io/github/license/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/blob/main/LICENSE) -[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/issues) -[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/issues) -[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_demo.svg)](https://openxlab.org.cn/apps?search=mmdet) - -[📘Documentation](https://mmdetection.readthedocs.io/en/latest/) | -[🛠️Installation](https://mmdetection.readthedocs.io/en/latest/get_started.html) | -[👀Model Zoo](https://mmdetection.readthedocs.io/en/latest/model_zoo.html) | -[🆕Update News](https://mmdetection.readthedocs.io/en/latest/notes/changelog.html) | -[🚀Ongoing Projects](https://github.com/open-mmlab/mmdetection/projects) | -[🤔Reporting Issues](https://github.com/open-mmlab/mmdetection/issues/new/choose) - -
- -
- -English | [简体中文](README_zh-CN.md) - -
- -
- - - - - - - - - - - - - - - - - -
- -
- +
+
## Introduction -MMDetection is an open source object detection toolbox based on PyTorch. It is -a part of the [OpenMMLab](https://openmmlab.com/) project. - -The main branch works with **PyTorch 1.8+**. - - - -
-Major features - -- **Modular Design** - - We decompose the detection framework into different components and one can easily construct a customized object detection framework by combining different modules. - -- **Support of multiple tasks out of box** - - The toolbox directly supports multiple detection tasks such as **object detection**, **instance segmentation**, **panoptic segmentation**, and **semi-supervised object detection**. - -- **High efficiency** - - All basic bbox and mask operations run on GPUs. The training speed is faster than or comparable to other codebases, including [Detectron2](https://github.com/facebookresearch/detectron2), [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) and [SimpleDet](https://github.com/TuSimple/simpledet). - -- **State of the art** - - The toolbox stems from the codebase developed by the *MMDet* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward. - The newly released [RTMDet](configs/rtmdet) also obtains new state-of-the-art results on real-time instance segmentation and rotated object detection tasks and the best parameter-accuracy trade-off on object detection. - -
- -Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmlab/mmengine) for model training and [MMCV](https://github.com/open-mmlab/mmcv) for computer vision research, which are heavily depended on by this toolbox. - -## What's New - -💎 **We have released the pre-trained weights for MM-Grounding-DINO Swin-B and Swin-L, welcome to try and give feedback.** - -### Highlight +This repository is a fork of the [mmdetection](https://github.com/open-mmlab/mmdetection) toolbox with the implementation of OA-Mix, +a novel data augmentation technique designed to improve domain generalization in object detection. +OA-Mix is part of the [Object-Aware Domain Generalization (OA-DG)](https://github.com/woojulee24/OA-DG) framework, +introduced in the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076). -**v3.3.0** was released in 5/1/2024: +This repository has been created to showcase the OA-Mix method. +The method enhances model robustness against domain shifts by generating diverse multi-domain data while preserving object annotations. -**[MM-Grounding-DINO: An Open and Comprehensive Pipeline for Unified Object Grounding and Detection](https://arxiv.org/abs/2401.02361)** +For more information on the details of OA-Mix and its use cases, +please refer to the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076), presented at AAAI 2024. -Grounding DINO is a grounding pre-training model that unifies 2d open vocabulary object detection and phrase grounding, with wide applications. However, its training part has not been open sourced. Therefore, we propose MM-Grounding-DINO, which not only serves as an open source replication version of Grounding DINO, but also achieves significant performance improvement based on reconstructed data types, exploring different dataset combinations and initialization strategies. Moreover, we conduct evaluations from multiple dimensions, including OOD, REC, Phrase Grounding, OVD, and Fine-tune, to fully excavate the advantages and disadvantages of Grounding pre-training, hoping to provide inspiration for future work. +## Example of OA-Mix -code: [mm_grounding_dino/README.md](configs/mm_grounding_dino/README.md) +Below is an example showing the results of OA-Mix:
- +
-We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes but also obtains new state-of-the-art performance on instance segmentation and rotated object detection tasks. Details can be found in the [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models are [here](configs/rtmdet). +## Performance Improvement with OA-Mix -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) +Below is a performance comparison between a baseline object detection model and the same model with OA-Mix applied: -| Task | Dataset | AP | FPS(TRT FP16 BS1 3090) | -| ------------------------ | ------- | ------------------------------------ | ---------------------- | -| Object Detection | COCO | 52.8 | 322 | -| Instance Segmentation | COCO | 44.6 | 188 | -| Rotated Object Detection | DOTA | 78.9(single-scale)/81.3(multi-scale) | 121 | +| Model | Dataset | Backbone | mAP | Gauss. | Shot | Impulse | Defocus | Glass | Motion | Zoom | Snow | Frost | Fog | Bright | Contrast | Elastic | Pixel | JPEG | mPC | +| :-------------------: | :----------: | :------: | :--: | :----: | :--: | :-----: | :-----: | :---: | :----: | :--: | :--: | :---: | :--: | :----: | :------: | :-----: | ----- | :--: | :--: | +| Faster R-CNN | Cityscapes-C | R-50-FPN | 42.2 | 0.5 | 1.1 | 1.1 | 17.2 | 16.5 | 18.3 | 2.1 | 2.2 | 12.3 | 29.8 | 32.0 | 24.1 | 40.1 | 18.7 | 15.1 | 15.4 | +| Faster R-CNN + OA-Mix | Cityscapes-C | R-50-FPN | 42.7 | 7.2 | 9.6 | 7.7 | 22.8 | 18.8 | 21.9 | 5.4 | 5.2 | 23.6 | 37.3 | 38.7 | 31.9 | 40.2 | 22.2 | 20.2 | 20.8 | -
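+In this table, the mAP column reports performance on clean images, each corruption column is the mAP under that corruption on Cityscapes-C, and mPC (mean performance under corruption) is the average of the 15 corruption columns. As a quick check, the snippet below recomputes mPC from the OA-Mix row; the numbers are copied from the table above.
+
+```python
+# Per-corruption mAP values copied from the "Faster R-CNN + OA-Mix" row above.
+corruption_map = [7.2, 9.6, 7.7, 22.8, 18.8, 21.9, 5.4, 5.2,
+                  23.6, 37.3, 38.7, 31.9, 40.2, 22.2, 20.2]
+mpc = sum(corruption_map) / len(corruption_map)
+print(round(mpc, 1))  # 20.8, matching the mPC column
+```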
- -
- -## Installation - -Please refer to [Installation](https://mmdetection.readthedocs.io/en/latest/get_started.html) for installation instructions. - -## Getting Started - -Please see [Overview](https://mmdetection.readthedocs.io/en/latest/get_started.html) for the general introduction of MMDetection. - -For detailed user guides and advanced guides, please refer to our [documentation](https://mmdetection.readthedocs.io/en/latest/): - -- User Guides - -
- - - [Train & Test](https://mmdetection.readthedocs.io/en/latest/user_guides/index.html#train-test) - - [Learn about Configs](https://mmdetection.readthedocs.io/en/latest/user_guides/config.html) - - [Inference with existing models](https://mmdetection.readthedocs.io/en/latest/user_guides/inference.html) - - [Dataset Prepare](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html) - - [Test existing models on standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/test.html) - - [Train predefined models on standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/train.html) - - [Train with customized datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/train.html#train-with-customized-datasets) - - [Train with customized models and standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/new_model.html) - - [Finetuning Models](https://mmdetection.readthedocs.io/en/latest/user_guides/finetune.html) - - [Test Results Submission](https://mmdetection.readthedocs.io/en/latest/user_guides/test_results_submission.html) - - [Weight initialization](https://mmdetection.readthedocs.io/en/latest/user_guides/init_cfg.html) - - [Use a single stage detector as RPN](https://mmdetection.readthedocs.io/en/latest/user_guides/single_stage_as_rpn.html) - - [Semi-supervised Object Detection](https://mmdetection.readthedocs.io/en/latest/user_guides/semi_det.html) - - [Useful Tools](https://mmdetection.readthedocs.io/en/latest/user_guides/index.html#useful-tools) - -
- -- Advanced Guides - -
+The model was evaluated using the [robust detection benchmark](https://github.com/bethgelab/robust-detection-benchmark), which can be run using the [test_robustness.py](tools/analysis_tools/test_robustness.py) script provided by mmdetection. - - [Basic Concepts](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#basic-concepts) - - [Component Customization](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#component-customization) - - [How to](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#how-to) +## How to use? -
+Modify the `train_pipeline` in the configuration file as follows: -We also provide object detection colab tutorial [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](demo/MMDet_Tutorial.ipynb) and instance segmentation colab tutorial [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](demo/MMDet_InstanceSeg_Tutorial.ipynb). - -To migrate from MMDetection 2.x, please refer to [migration](https://mmdetection.readthedocs.io/en/latest/migration.html). - -## Overview of Benchmark and Model Zoo - -Results and models are available in the [model zoo](docs/en/model_zoo.md). - -
- Architectures -
- - - - - - - - - - - - - - - - - -
- Object Detection - - Instance Segmentation - - Panoptic Segmentation - - Other -
- - - - - - - -
  • Contrastive Learning
  • - - -
  • Distillation
  • - -
  • Semi-Supervised Object Detection
  • - - -
    - -
    - Components -
    - - - - - - - - - - - - - - - - - -
    - Backbones - - Necks - - Loss - - Common -
    - - - - - - - -
-Some other methods are also supported in [projects using MMDetection](./docs/en/notes/projects.md).
-
-## FAQ
-
-Please refer to [FAQ](docs/en/notes/faq.md) for frequently asked questions.
+```python
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=None),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='RandomResize',
+        scale=[(2048, 800), (2048, 1024)],
+        keep_ratio=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='OAMix', version='oamix'),
+    dict(type='PackDetInputs')
+]
+```

-## Contributing
+Alternatively, you can call the `OAMix` class directly, as shown below. Note that `gt_bboxes` must be an mmdet box type such as `HorizontalBoxes` (the transform calls its `.numpy()` method) and that the transform returns the updated results dict. A fuller end-to-end sketch is given at the end of this README.

-We appreciate all contributions to improve MMDetection. Ongoing projects can be found in out [GitHub Projects](https://github.com/open-mmlab/mmdetection/projects). Welcome community users to participate in these projects. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
+```python
+import numpy as np
+from mmdet.datasets.transforms.oa_mix import OAMix
+from mmdet.structures.bbox import HorizontalBoxes

-## Acknowledgement
+# Generate a random image and a few plausible boxes (x1, y1, x2, y2)
+img = np.random.randint(0, 256, (427, 640, 3)).astype(np.uint8)
+gt_bboxes = HorizontalBoxes(
+    np.array([[30., 40., 200., 180.],
+              [250., 100., 400., 300.],
+              [60., 220., 150., 320.]], dtype=np.float32))

-MMDetection is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedbacks.
-We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new detectors.
+# Apply OA-Mix
+oamix = OAMix()
+results = oamix({'img': img, 'gt_bboxes': gt_bboxes})
+img_aug = results['img']
+```

 ## Citation

 If you use this toolbox or benchmark in your research, please cite this project.

 ```
-@article{mmdetection,
-  title   = {{MMDetection}: Open MMLab Detection Toolbox and Benchmark},
-  author  = {Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and
-             Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and
-             Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and
-             Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and
-             Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong
-             and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
-  journal= {arXiv preprint arXiv:1906.07155},
-  year={2019}
+@inproceedings{lee2024object,
+  title={Object-Aware Domain Generalization for Object Detection},
+  author={Lee, Wooju and Hong, Dasol and Lim, Hyungtae and Myung, Hyun},
+  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+  volume={38},
+  number={4},
+  pages={2947--2955},
+  year={2024}
 }
 ```

-## License
-
-This project is released under the [Apache 2.0 license](LICENSE).
+## Contact

-## Projects in OpenMMLab
+This repo is currently maintained by Wooju Lee ([@WoojuLee24](https://github.com/WoojuLee24)) and Dasol Hong ([@dazory](https://github.com/dazory)).

-- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models.
-- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
-- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark.
-- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox.
-- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
-- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
-- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
-- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
-- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
-- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
-- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
-- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
-- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
-- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
-- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
-- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
-- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
-- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
-- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
-- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
-- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
-- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
-- [MMEval](https://github.com/open-mmlab/mmeval): A unified evaluation library for multiple machine learning libraries.
-- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab.
+For questions regarding mmdetection itself, please visit the [official repository](https://github.com/open-mmlab/mmdetection).
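+The snippet below is a minimal end-to-end sketch of the direct-call path: it builds `OAMix` through the `TRANSFORMS` registry with the same settings as the provided config `configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py` and applies it to an in-memory image. The image and boxes here are dummy placeholders to substitute with your own data; the saliency scoring inside `OAMix` uses `cv2.saliency`, which is typically provided by `opencv-contrib-python`.
+
+```python
+import numpy as np
+
+import mmdet.datasets.transforms  # noqa: F401, ensures OAMix is registered
+from mmdet.registry import TRANSFORMS
+from mmdet.structures.bbox import HorizontalBoxes
+
+# Same OA-Mix settings as the provided Cityscapes config.
+oamix = TRANSFORMS.build(
+    dict(
+        type='OAMix',
+        version='oamix',
+        box_scale=(0.05, 0.3),
+        box_ratio=(3, 0.33),
+        sigma_ratio=0.2,
+        score_thresh=10))
+
+# Dummy image and ground-truth boxes (x1, y1, x2, y2) in pixel coordinates.
+results = dict(
+    img=np.random.randint(0, 256, (800, 1333, 3), dtype=np.uint8),
+    gt_bboxes=HorizontalBoxes(
+        np.array([[100., 200., 300., 400.],
+                  [500., 250., 700., 500.]], dtype=np.float32)))
+
+results = oamix(results)
+augmented = results['img']  # same shape and dtype as the input image
+```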
diff --git a/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py new file mode 100755 index 00000000000..973f072c84f --- /dev/null +++ b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py @@ -0,0 +1,64 @@ +_base_ = [ + '../_base_/models/faster-rcnn_r50_fpn.py', + '../_base_/datasets/cityscapes_detection.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py' +] + +backend_args = None +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=[(2048, 800), (2048, 1024)], + keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='OAMix', + version='oamix', + box_scale=(0.05, 0.3), + box_ratio=(3, 0.33), + sigma_ratio=0.2, + score_thresh=10), + dict(type='PackDetInputs') +] +train_dataloader = dict( + num_workers=8, dataset=dict(dataset=dict(pipeline=train_pipeline))) + +# Model +model = dict( + backbone=dict(init_cfg=None), + roi_head=dict( + bbox_head=dict( + num_classes=8, + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))) + +# optimizer +# lr is set for a batch size of 8 +optim_wrapper = dict(optimizer=dict(lr=0.01)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=8, + by_epoch=True, + # [7] yields higher performance than [6] + milestones=[7], + gamma=0.1) +] + +# actual epoch = 8 * 8 = 64 +train_cfg = dict(max_epochs=8) + +# For better, more stable performance initialize from COCO +load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (1 samples per GPU) +# TODO: support auto scaling lr +# auto_scale_lr = dict(base_batch_size=8) diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index ab3478feb00..022acc165ec 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .augment_wrappers import AutoAugment, RandAugment from .colorspace import (AutoContrast, Brightness, Color, ColorTransform, - Contrast, Equalize, Invert, Posterize, Sharpness, - Solarize, SolarizeAdd) + Contrast, Equalize, Invert, Invert4Mix, Posterize, + Sharpness, Solarize, SolarizeAdd) from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs, PackTrackInputs, ToTensor, Transpose) from .frame_sampling import BaseFrameSample, UniformRefFrameSample @@ -13,6 +13,7 @@ LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, LoadProposals, LoadTrackAnnotations) +from .oa_mix import OAMix from .text_transformers import LoadTextAnnotations, RandomSamplingNegPos from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, @@ -25,21 +26,73 @@ from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder __all__ = [ - 'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose', - 'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations', - 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip', - 'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand', - 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad', - 'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize', - 'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift', - 'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste', - 'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform', - 'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize', - 'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing', - 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', - 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', - 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', - 'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP', - 'RandomSamplingNegPos', 'LoadTextAnnotations' + 'PackDetInputs', + 'ToTensor', + 'ImageToTensor', + 'Transpose', + 'LoadImageFromNDArray', + 'LoadAnnotations', + 'LoadPanopticAnnotations', + 'LoadMultiChannelImageFromFiles', + 'LoadProposals', + 'Resize', + 'RandomFlip', + 'RandomCrop', + 'SegRescale', + 'MinIoURandomCrop', + 'Expand', + 'PhotoMetricDistortion', + 'Albu', + 'InstaBoost', + 'RandomCenterCropPad', + 'AutoAugment', + 'CutOut', + 'ShearX', + 'ShearY', + 'Rotate', + 'Color', + 'Equalize', + 'Brightness', + 'Contrast', + 'TranslateX', + 'TranslateY', + 'RandomShift', + 'Mosaic', + 'MixUp', + 'RandomAffine', + 'YOLOXHSVRandomAug', + 'CopyPaste', + 'FilterAnnotations', + 'Pad', + 'GeomTransform', + 'ColorTransform', + 'RandAugment', + 'Sharpness', + 'Solarize', + 'SolarizeAdd', + 'Posterize', + 'AutoContrast', + 'Invert', + 'Invert4Mix', + 'MultiBranch', + 'RandomErasing', + 'LoadEmptyAnnotations', + 'RandomOrder', + 'CachedMosaic', + 'CachedMixUp', + 'FixShapeResize', + 'ProposalBroadcaster', + 'InferencerLoader', + 'LoadTrackAnnotations', + 'BaseFrameSample', + 'UniformRefFrameSample', + 'PackTrackInputs', + 'PackReIDInputs', + 'FixScaleResize', + 'ResizeShortestEdge', + 'GTBoxSubOne_GLIP', + 'RandomFlip_GLIP', + 'RandomSamplingNegPos', + 'LoadTextAnnotations', + 'OAMix', ] diff --git a/mmdet/datasets/transforms/augment_wrappers.py b/mmdet/datasets/transforms/augment_wrappers.py index 19fae6efdf6..d796b561159 100644 --- 
a/mmdet/datasets/transforms/augment_wrappers.py +++ b/mmdet/datasets/transforms/augment_wrappers.py @@ -77,9 +77,9 @@ def level_to_mag(level: Optional[int], min_mag: float, max_mag: float) -> float: """Map from level to magnitude.""" if level is None: - return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1) + return round(np.random.rand() * (max_mag - min_mag) + min_mag, 2) else: - return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1) + return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 2) @TRANSFORMS.register_module() diff --git a/mmdet/datasets/transforms/colorspace.py b/mmdet/datasets/transforms/colorspace.py index e0ba2e97c7e..4b830fe560a 100644 --- a/mmdet/datasets/transforms/colorspace.py +++ b/mmdet/datasets/transforms/colorspace.py @@ -491,3 +491,33 @@ def _transform_img(self, results: dict, mag: float) -> None: """Invert the image.""" img = results['img'] results['img'] = mmcv.iminvert(img).astype(img.dtype) + + +@TRANSFORMS.register_module() +class Invert4Mix(ColorTransform): + """Invert and translate images. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 1.0. + level (int, optional): No use for Invert transformation. + Defaults to None. + min_mag (float): No use for Invert transformation. Defaults to 0.1. + max_mag (float): No use for Invert transformation. Defaults to 1.9. + """ + + def _transform_img(self, results: dict, mag: float) -> None: + """Invert the image.""" + img = results['img'] + img = mmcv.iminvert(img).astype(img.dtype) + img = mmcv.imtranslate(img, 1, 'horizontal') + img = mmcv.imtranslate(img, 1, 'vertical') + results['img'] = img.astype(img.dtype) diff --git a/mmdet/datasets/transforms/geometric.py b/mmdet/datasets/transforms/geometric.py old mode 100644 new mode 100755 index d2cd6be258f..eaa7eadcd67 --- a/mmdet/datasets/transforms/geometric.py +++ b/mmdet/datasets/transforms/geometric.py @@ -752,3 +752,360 @@ def _transform_seg(self, results: dict, mag: float) -> None: direction='vertical', border_value=self.seg_ignore_label, interpolation='nearest') + + +@TRANSFORMS.register_module() +class BBoxShearX(ShearX): + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): + cy = (bbox[1] + bbox[3]) / 2 + + shear_matrix = np.array([[1, mag, -mag * cy], [0, 1, 0]], + dtype=np.float32) + shear_img = cv2.warpAffine( + img_orig, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + img = (1.0 - mask) * img + mask * shear_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxShearY(ShearY): + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): + cx = (bbox[0] + bbox[2]) / 2 + + shear_matrix = np.float32([[1, 0, 0], [mag, 1, -mag * cx]]) + shear_img = cv2.warpAffine( + 
img_orig, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + img = (1.0 - mask) * img + mask * shear_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxRotate(Rotate): + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): + cx = (bbox[0] + bbox[2]) / 2 + cy = (bbox[1] + bbox[3]) / 2 + + translate_matrix = np.float32([[1, 0, w_img // 2 - cx], + [0, 1, h_img // 2 - cy]]) + translated_img = cv2.warpAffine( + img_orig, + translate_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + """Rotate the image.""" + rotated_img = mmcv.imrotate( + translated_img, + mag, + border_value=self.img_border_value, + interpolation=self.interpolation) + translate_matrix = np.float32([[1, 0, -w_img // 2 + cx], + [0, 1, -h_img // 2 + cy]]) + rotated_img = cv2.warpAffine( + rotated_img, + translate_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + img = (1.0 - mask) * img + mask * rotated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxTranslateX(TranslateX): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): + w = bbox[2] - bbox[0] + """Translate the image horizontally.""" + _mag = int(w * mag) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + img = (1.0 - mask) * img + mask * translated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxTranslateY(TranslateY): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): + h = bbox[3] - bbox[1] + """Translate the image horizontally.""" + _mag = int(h * mag) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + img = (1.0 - mask) * img + mask * translated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, 
results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgShearX(ShearX): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + """Shear the image horizontally.""" + mask_max = np.max(results['masks'], axis=0) + + shear_matrix = np.array([[1, mag, -mag * h_img // 2], [0, 1, 0]], + dtype=np.float32) + sheared_mask = cv2.warpAffine( + mask_max, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + sheared_img = cv2.warpAffine( + img_orig, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + union_mask = np.max([mask_max, sheared_mask], axis=0) + img = (1.0 - union_mask) * sheared_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgShearY(ShearY): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + """Shear the image vertically.""" + mask_max = np.max(results['masks'], axis=0) + shear_matrix = np.array([[1, 0, 0], [mag, 1, -mag * w_img // 2]], + dtype=np.float32) + sheared_mask = cv2.warpAffine( + mask_max, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + sheared_img = cv2.warpAffine( + img_orig, + shear_matrix, (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + union_mask = np.max([mask_max, sheared_mask], axis=0) + img = (1.0 - union_mask) * sheared_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgRotate(Rotate): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + """Rotate the image.""" + mask_max = np.max(results['masks'], axis=0) + rotated_mask = mmcv.imrotate( + mask_max, mag, border_value=0, interpolation=self.interpolation) + rotated_img = mmcv.imrotate( + img_orig, + mag, + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, rotated_mask], axis=0) + img = (1.0 - union_mask) * rotated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgTranslateX(TranslateX): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + """Translate the image horizontally.""" + _mag = w_img * mag + mask_max = np.max(results['masks'], axis=0) + translated_mask = mmcv.imtranslate( + mask_max, + _mag, + direction='horizontal', + border_value=0, + interpolation=self.interpolation) + translated_img = mmcv.imtranslate( + 
img_orig, + _mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, translated_mask], axis=0) + img = (1.0 - union_mask) * translated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgTranslateY(TranslateY): + + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + """Translate the image horizontally.""" + _mag = h_img * mag + mask_max = np.max(results['masks'], axis=0) + translated_mask = mmcv.imtranslate( + mask_max, + _mag, + direction='vertical', + border_value=0, + interpolation=self.interpolation) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, translated_mask], axis=0) + img = (1.0 - union_mask) * translated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass diff --git a/mmdet/datasets/transforms/oa_mix.py b/mmdet/datasets/transforms/oa_mix.py new file mode 100755 index 00000000000..f03cc6b00c7 --- /dev/null +++ b/mmdet/datasets/transforms/oa_mix.py @@ -0,0 +1,454 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import cv2 +import numpy as np +from mmcv.transforms import BaseTransform, Compose + +from mmdet.registry import TRANSFORMS + + +def get_transforms(version: str) -> List[Compose]: + if version == 'color': + transforms = [ + dict(type='AutoContrast'), + dict(type='Brightness'), + dict(type='Color'), + dict(type='Contrast'), + dict(type='Equalize'), + dict(type='Invert4Mix'), + dict(type='Posterize'), + dict(type='Sharpness'), + dict(type='Solarize'), + dict(type='SolarizeAdd') + ] + elif version == 'geo': + transforms = [ + dict(type='BgShearX'), + dict(type='BgShearY'), + dict(type='BgRotate'), + dict(type='BgTranslateX'), + dict(type='BgTranslateY'), + dict(type='BBoxShearX'), + dict(type='BBoxShearY'), + dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), + dict(type='BBoxTranslateY'), + ] + elif version == 'oamix': + transforms = [ + dict(type='AutoContrast'), + dict(type='Brightness'), + dict(type='Color'), + dict(type='Contrast'), + dict(type='Equalize'), + dict(type='Invert4Mix'), + dict(type='Posterize'), + dict(type='Sharpness'), + dict(type='BgShearX'), + dict(type='BgShearY'), + dict(type='BgRotate'), + dict(type='BgTranslateX'), + dict(type='BgTranslateY'), + dict(type='BBoxShearX'), + dict(type='BBoxShearY'), + dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), + dict(type='BBoxTranslateY'), + ] + else: + raise TypeError( + f'Invalid version: {version}. ' + f'Please add the version to the get_transforms function.') + transforms = [Compose(transforms) for transforms in transforms] + return transforms + + +def bbox_overlaps_np(bboxes1: np.ndarray, + bboxes2: np.ndarray, + eps: float = 1e-6) -> np.ndarray: + """Calculate overlap between two set of bboxes. + + Args: + bboxes1 (ndarray): shape (B, m, 4) in format or empty. + bboxes2 (ndarray): shape (B, n, 4) in format or empty. 
+ B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. + Returns: + ndarray: shape (m, n) if ``is_aligned`` is False else shape (m,) + """ + if len(bboxes2) == 0: + return np.zeros((bboxes1.shape[-2], 0)) + assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0) + assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0) + + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.shape[-2] + cols = bboxes2.shape[-2] + + if rows * cols == 0: + return np.zeros(batch_shape + (rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1]) + + lt = np.maximum(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = np.minimum(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = np.clip(rb - lt, a_min=0, a_max=None) + overlap = wh[..., 0] * wh[..., 1] + + union = area1[..., None] + area2[..., None, :] - overlap + + union = np.maximum(union, eps) + ious = overlap / union + return ious + + +@TRANSFORMS.register_module() +class OAMix(BaseTransform): + r"""Data augmentation method in + `Object-Aware Domain Generalization for Object Detection + `_. + + Refer to https://github.com/woojulee24/OA-DG for implementation details. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) + + Modified Keys: + + - img + - gt_bboxes + + Args: + version (str): The version of the augmentation method. + Defaults to 'oamix'. + aug_prob_coeff (float): + The coefficient of the augmentation probability. + Defaults to 1.0. + mixture_width (int): + The number of augmentation operations in the mixture. + Defaults to 3. + mixture_depth (int): + The depth of augmentation operations in the mixture. + If mixture_depth is -1, the depth is randomly sampled from [1, 4]. + Defaults to -1. + box_scale (tuple): The scale of the random bounding boxes. + Defaults to (0.01, 0.1). + box_ratio (tuple): The aspect ratio of the random bounding boxes. + Defaults to (3, 1/3). + sigma_ratio (float): The ratio of the sigma for the Gaussian blur. + Defaults to 0.3. + score_thresh (float): The threshold of the saliency score. + Defaults to 10. + """ + + def __init__(self, + version: str = 'oamix', + aug_prob_coeff: float = 1.0, + mixture_width: int = 3, + mixture_depth: int = -1, + box_scale: tuple = (0.01, 0.1), + box_ratio: tuple = (3, 0.33), + sigma_ratio: float = 0.3, + score_thresh: float = 10.0) -> None: + assert version in ['color', 'geo', 'oamix'], \ + "The version should be either 'color', 'geo', or 'oamix'."\ + 'Please add the version to the get_transforms function.' + assert aug_prob_coeff > 0, \ + 'The augmentation probability coefficient ' \ + 'should be greater than 0.' + assert isinstance( + mixture_width, int + ) and mixture_width > 0, 'The mixture width should be greater than 0.' + assert isinstance( + mixture_depth, int + ) and mixture_depth >= -1, \ + 'The mixture depth should be greater than or equal to -1.' + assert isinstance(box_scale, tuple) and len( + box_scale) == 2, 'The box scale should be a tuple of 2 elements.' + assert isinstance(box_ratio, tuple) and len( + box_ratio) == 2, 'The box ratio should be a tuple of 2 elements.' + assert 0 <= sigma_ratio <= 1, \ + 'The sigma ratio should be in the range [0, 1].' 
+ assert score_thresh >= 0, \ + 'The score threshold should be greater than or equal to 0.' + super(OAMix, self).__init__() + + self.version = version + self.transforms = get_transforms(version) + self.aug_prob_coeff = aug_prob_coeff + self.mixture_width = mixture_width + self.mixture_depth = mixture_depth + self.box_scale = box_scale + self.box_ratio = box_ratio + self.sigma_ratio = sigma_ratio + self.score_thresh = score_thresh + + def transform(self, results) -> dict: + self._test_transformations() + """The transform function.""" + img = results['img'] + gt_bboxes = results['gt_bboxes'].numpy() + gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) + + results['img'] = self.oamix(img.copy(), gt_bboxes, gt_masks) + + return results + + def oamix(self, img_orig: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """Multi-level transformation.""" + depth = self.mixture_depth \ + if self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform(img_aug, gt_bboxes, + gt_masks) + img_mix += ws[i] * img_aug + """ Object-aware mixing """ + img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, + gt_masks) + + return np.asarray(img_oamix, dtype=img_orig.dtype) + + def multilivel_transform(self, img: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: + rand_bboxes = self.get_random_bboxes(img.shape, num_bboxes=(1, 3)) + rand_masks = self.get_masks(rand_bboxes, img.shape) + + img_tmp = np.zeros_like(img, dtype=np.float32) + for rand_mask in rand_masks: + img_tmp += rand_mask * self.aug(img, gt_bboxes, gt_masks) + union_mask = np.max(rand_masks, axis=0) + img_aug = np.asarray( + img_tmp + (1.0 - union_mask) * self.aug(img, gt_bboxes, gt_masks), + dtype=img.dtype) + return img_aug + + def get_random_bboxes(self, + img_shape: Tuple, + num_bboxes: Tuple[int, int], + max_iters: int = 50, + eps: float = 1e-6) -> np.ndarray: + assert max_iters > 0, \ + 'The maximum number of iterations should be greater than 0.' 
+ + h_img, w_img, _ = img_shape + num_target_bboxes = np.random.randint(*num_bboxes) + rand_bboxes = np.zeros((0, 4)) + for i in range(max_iters): + if len(rand_bboxes) >= num_target_bboxes: + break + scale = np.random.uniform(*self.box_scale) + aspect_ratio = np.random.uniform(*self.box_ratio) + + height = scale * h_img + width = height * aspect_ratio + if width > w_img or height > h_img: + continue # Invalid bbox (out of the image) + + xmin = np.random.uniform(0, w_img - width) + ymin = np.random.uniform(0, h_img - height) + xmax = xmin + width + ymax = ymin + height + + rand_bbox = np.array([[xmin, ymin, xmax, ymax]]) + ious = bbox_overlaps_np(rand_bbox, rand_bboxes) + if np.sum(ious) > eps: + continue # Invalid bbox (overlapping with existing bboxes) + + rand_bboxes = np.concatenate([rand_bboxes, rand_bbox], axis=0) + + return rand_bboxes + + def get_masks(self, + bboxes: np.ndarray, + img_shape: tuple, + use_blur: bool = False) -> List[np.ndarray]: + """Get the masks of the bounding boxes.""" + mask_list = [] + for bbox in bboxes: + if len(bbox.shape) == 2 and bbox.shape[1] == 4: + bbox = bbox[0] + mask = np.zeros(img_shape, dtype=np.float32) + mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = 1.0 + if use_blur: + sigma_x = (bbox[2] - bbox[0]) * self.sigma_ratio / 3 * 2 + sigma_y = (bbox[3] - bbox[1]) * self.sigma_ratio / 3 * 2 + if not (sigma_x <= 0 or sigma_y <= 0): + mask = cv2.GaussianBlur( + mask, (0, 0), + sigmaX=sigma_x.item(), + sigmaY=sigma_y.item()) + mask = cv2.resize( + mask, (img_shape[1], img_shape[0]), + interpolation=cv2.INTER_LINEAR) + mask_list.append(mask) + return mask_list + + def aug(self, img: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]): + op = np.random.choice(self.transforms) + op_kwargs = { + 'img': img, + 'bboxes': gt_bboxes, + 'masks': gt_masks, + 'img_shape': img.shape + } + img_aug = op(op_kwargs)['img'] + return img_aug + + def object_aware_mixing(self, img_orig: np.ndarray, img_mix: np.ndarray, + gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: + gt_scores = self.get_saliency_scores(img_orig, gt_bboxes) + + target_indices = gt_scores < self.score_thresh + target_bboxes = gt_bboxes[target_indices] + target_masks = [gt_masks[i] for i in np.where(target_indices)[0]] + target_m = np.random.uniform(0.0, 0.5, + len(target_bboxes)).astype(np.float32) + + rand_bboxes = self.get_random_bboxes(img_orig.shape, num_bboxes=(3, 5)) + rand_masks = self.get_masks(rand_bboxes, img_orig.shape, use_blur=True) + rand_m = np.random.uniform(0.0, 1.0, + len(rand_bboxes)).astype(np.float32) + + target_bboxes = np.vstack((target_bboxes, rand_bboxes)) + target_masks.extend(rand_masks) + target_m = np.concatenate((target_m, rand_m)) + + orig = np.zeros_like(img_orig, dtype=np.float32) + aug = np.zeros_like(img_orig, dtype=np.float32) + mask_sum = np.zeros_like(img_orig, dtype=np.float32) + + for bbox, mask, m in zip(target_bboxes, target_masks, target_m): + mask_sum += mask + mask_max = np.maximum(mask_sum, mask) + mask_overlap = mask_sum - mask_max + overlap_factor = (mask - mask_overlap * 0.5) + + orig += (1.0 - m) * img_orig * overlap_factor + aug += m * img_mix * overlap_factor + mask_sum = mask_max + + img_oamix = orig + aug + + m = np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff) + + img_oamix += (1.0 - m) * img_orig * (1.0 - mask_sum) + img_oamix += m * img_mix * (1.0 - mask_sum) + img_oamix = np.clip(img_oamix, 0, 255) + + return img_oamix + + @staticmethod + def get_saliency_scores(img: np.ndarray, bboxes: 
np.ndarray) -> np.ndarray: + saliency_scores = [] + for bbox in np.asarray(bboxes, dtype=np.int32): + bbox_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + saliency = cv2.saliency.StaticSaliencySpectralResidual_create() + (success, saliency_map) = saliency.computeSaliency(bbox_img) + score = np.mean((saliency_map * 255).astype('uint8')) + saliency_scores.append(score) + return np.asarray(saliency_scores, dtype=np.uint8) + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(version={self.version}, ' \ + f'aug_prob_coeff={self.aug_prob_coeff}, ' \ + f'mixture_width={self.mixture_width}, ' \ + f'mixture_depth={self.mixture_depth}, ' \ + f'box_scale={self.box_scale}, ' \ + f'box_ratio={self.box_ratio}, ' \ + f'sigma_ratio={self.sigma_ratio}, ' \ + f'score_thresh={self.score_thresh})' + return repr_str + + @staticmethod + def _load_random_data() -> Tuple[np.ndarray, np.ndarray]: + img = np.random.randint(0, 256, (427, 640, 3)).astype(np.uint8) + gt_bboxes = np.random.randn(3, 4).astype(np.float32) + return img, gt_bboxes + + def _test_transformations(self) -> None: + img, gt_bboxes = self._load_random_data() + gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) + + transform_type_list = [ + 'AutoContrast', 'Brightness', 'Color', 'Contrast', 'Equalize', + 'Invert4Mix', 'Posterize', 'Sharpness', 'Solarize', 'SolarizeAdd', + 'BgShearX', 'BgShearY', 'BgRotate', 'BgTranslateX', 'BgTranslateY', + 'BBoxShearX', 'BBoxShearY', 'BBoxRotate', 'BBoxTranslateX', + 'BBoxTranslateY' + ] + for type in transform_type_list: + op = Compose(dict(type=type)) + op_kwargs = { + 'img': img, + 'bboxes': gt_bboxes, + 'masks': gt_masks, + 'img_shape': img.shape + } + _ = op(op_kwargs)['img'] + + return + + def _test_multilevel_transformations(self) -> None: + img_orig, gt_bboxes = self._load_random_data() + gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) + + for idx in range(10): + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * + self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """Multi-level transformation.""" + depth = self.mixture_depth \ + if self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform( + img_aug, gt_bboxes, gt_masks) + img_mix += ws[i] * img_aug + + return + + def _test_objectaware_mixing(self) -> None: + img_orig, gt_bboxes = self._load_random_data() + gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) + + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """Multi-level transformation.""" + depth = self.mixture_depth if \ + self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform(img_aug, gt_bboxes, + gt_masks) + img_mix += ws[i] * img_aug + """ Object-aware mixing """ + self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) + + return diff --git a/resources/oamix_examples.gif b/resources/oamix_examples.gif new file mode 100644 index 00000000000..a98919c7ee7 Binary files /dev/null and b/resources/oamix_examples.gif differ diff --git a/resources/oamix_examples.png b/resources/oamix_examples.png new file mode 100644 index 00000000000..2cb31b83c5b Binary files /dev/null and b/resources/oamix_examples.png differ
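For readers skimming the new `oa_mix.py`: `OAMix.oamix` mixes `mixture_width` independently augmented copies of the image with Dirichlet weights, and `object_aware_mixing` then blends augmented and original content region by region with Beta-sampled weights. The following simplified numpy sketch shows only that mixing skeleton, with a toy stand-in augmentation and the object-aware masks omitted; it is an illustration of the structure, not code from the repository.

```python
import numpy as np


def toy_color_jitter(img: np.ndarray) -> np.ndarray:
    """Stand-in for one randomly chosen OA-Mix transform."""
    gain = np.random.uniform(0.7, 1.3)
    return np.clip(img * gain, 0, 255)


def simplified_mix(img: np.ndarray, width: int = 3, depth: int = -1,
                   alpha: float = 1.0) -> np.ndarray:
    """Dirichlet-weighted mixture of `width` augmentation chains,
    followed by a Beta-weighted blend with the original image."""
    ws = np.random.dirichlet([alpha] * width)
    img_mix = np.zeros_like(img, dtype=np.float32)
    for w in ws:
        d = depth if depth > 0 else np.random.randint(1, 4)
        aug = img.astype(np.float32)
        for _ in range(d):
            aug = toy_color_jitter(aug)
        img_mix += w * aug
    m = np.random.beta(alpha, alpha)
    out = (1.0 - m) * img + m * img_mix
    return np.clip(out, 0, 255).astype(img.dtype)


out = simplified_mix(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
```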