diff --git a/Makefile b/Makefile index 1d125d6..13759c8 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: check-style codestyle docker-build clean check-style: - bash ./bin/_check_codestyle.sh -s + bash ./bin/codestyle/check_codestyle.sh -s codestyle: pre-commit run diff --git a/README.md b/README.md index 767d61e..930f427 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,12 @@ elif [[ "$DATASET" == "voc2012" ]]; then tar -xf VOCtrainval_11-May-2012.tar &>/dev/null mkdir -p ./data/origin/images/; mv VOCdevkit/VOC2012/JPEGImages/* $_ mkdir -p ./data/origin/raw_masks; mv VOCdevkit/VOC2012/SegmentationClass/* $_ +elif [[ "$DATASET" == "dsb2018" ]]; then + # instance segmentation + # https://www.kaggle.com/c/data-science-bowl-2018 + download-gdrive 1RCqaQZLziuq1Z4sbMpwD_WHjqR5cdPvh dsb2018_cleared_191109.tar.gz + tar -xf dsb2018_cleared_191109.tar.gz &>/dev/null + mv dsb2018_cleared_191109 ./data/origin fi ``` @@ -102,6 +108,11 @@ fi #### Data structure Make sure, that final folder with data has the required structure: + +
+Data structure for binary segmentation +

+ ```bash /path/to/your_dataset/ images/ @@ -115,6 +126,66 @@ Make sure, that final folder with data has the required structure: ... mask_N ``` +where each `mask` is a binary image + +

+
+ +
+Data structure for semantic segmentation +

+ +```bash +/path/to/your_dataset/ + images/ + image_1 + image_2 + ... + image_N + raw_masks/ + mask_1 + mask_2 + ... + mask_N +``` +where each `mask` is an image with class encoded through colors e.g. [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) dataset where `bicycle` class is encoded with green color and `bird` with olive + +

+
+ +
+Data structure for instance segmentation +

+ +```bash +/path/to/your_dataset/ + images/ + image_1 + image_2 + ... + image_M + raw_masks/ + mask_1/ + instance_1 + instance_2 + ... + instance_N + mask_2/ + instance_1 + instance_2 + ... + instance_K + ... + mask_M/ + instance_1 + instance_2 + ... + instance_Z +``` +where each `mask` represented as a folder with instances images (one image per instance), and masks may consisting of a different number of instances e.g. [Data Science Bowl 2018](https://www.kaggle.com/c/data-science-bowl-2018) dataset + +

+
#### Data location @@ -161,35 +232,32 @@ We will initialize [Unet](https://arxiv.org/abs/1505.04597) model with a pre-tra CUDA_VISIBLE_DEVICES=0 \ CUDNN_BENCHMARK="True" \ CUDNN_DETERMINISTIC="True" \ -WORKDIR=./logs \ -DATADIR=./data/origin \ -IMAGE_SIZE=256 \ -CONFIG_TEMPLATE=./configs/templates/binary.yml \ -NUM_WORKERS=4 \ -BATCH_SIZE=256 \ -bash ./bin/catalyst-binary-segmentation-pipeline.sh +bash ./bin/catalyst-binary-segmentation-pipeline.sh \ + --workdir ./logs \ + --datadir ./data/origin \ + --max-image-size 256 \ + --config-template ./configs/templates/binary.yml \ + --num-workers 4 \ + --batch-size 256 ``` #### Run in docker: ```bash -export LOGDIR=$(pwd)/logs docker run -it --rm --shm-size 8G --runtime=nvidia \ - -v $(pwd):/workspace/ \ - -v $LOGDIR:/logdir/ \ - -v $(pwd)/data/origin:/data \ - -e "CUDA_VISIBLE_DEVICES=0" \ - -e "USE_WANDB=1" \ - -e "LOGDIR=/logdir" \ - -e "CUDNN_BENCHMARK='True'" \ - -e "CUDNN_DETERMINISTIC='True'" \ - -e "WORKDIR=/logdir" \ - -e "DATADIR=/data" \ - -e "IMAGE_SIZE=256" \ - -e "CONFIG_TEMPLATE=./configs/templates/binary.yml" \ - -e "NUM_WORKERS=4" \ - -e "BATCH_SIZE=256" \ - catalyst-segmentation ./bin/catalyst-binary-segmentation-pipeline.sh + -v $(pwd):/workspace/ \ + -v $(pwd)/logs:/logdir/ \ + -v $(pwd)/data/origin:/data \ + -e "CUDA_VISIBLE_DEVICES=0" \ + -e "CUDNN_BENCHMARK='True'" \ + -e "CUDNN_DETERMINISTIC='True'" \ + catalyst-segmentation ./bin/catalyst-binary-segmentation-pipeline.sh \ + --workdir /logdir \ + --datadir /data \ + --max-image-size 256 \ + --config-template ./configs/templates/binary.yml \ + --num-workers 4 \ + --batch-size 256 ```

@@ -205,54 +273,83 @@ docker run -it --rm --shm-size 8G --runtime=nvidia \ CUDA_VISIBLE_DEVICES=0 \ CUDNN_BENCHMARK="True" \ CUDNN_DETERMINISTIC="True" \ -WORKDIR=./logs \ -DATADIR=./data/origin \ -IMAGE_SIZE=256 \ -CONFIG_TEMPLATE=./configs/templates/semantic.yml \ -NUM_WORKERS=4 \ -BATCH_SIZE=256 \ -bash ./bin/catalyst-semantic-segmentation-pipeline.sh +bash ./bin/catalyst-semantic-segmentation-pipeline.sh \ + --workdir ./logs \ + --datadir ./data/origin \ + --max-image-size 256 \ + --config-template ./configs/templates/semantic.yml \ + --num-workers 4 \ + --batch-size 256 ``` #### Run in docker: ```bash -export LOGDIR=$(pwd)/logs docker run -it --rm --shm-size 8G --runtime=nvidia \ - -v $(pwd):/workspace/ \ - -v $LOGDIR:/logdir/ \ - -v $(pwd)/data/origin:/data \ - -e "CUDA_VISIBLE_DEVICES=0" \ - -e "USE_WANDB=1" \ - -e "LOGDIR=/logdir" \ - -e "CUDNN_BENCHMARK='True'" \ - -e "CUDNN_DETERMINISTIC='True'" \ - -e "WORKDIR=/logdir" \ - -e "DATADIR=/data" \ - -e "IMAGE_SIZE=256" \ - -e "CONFIG_TEMPLATE=./configs/templates/semantic.yml" \ - -e "NUM_WORKERS=4" \ - -e "BATCH_SIZE=256" \ - catalyst-segmentation ./bin/catalyst-semantic-segmentation-pipeline.sh + -v $(pwd):/workspace/ \ + -v $(pwd)/logs:/logdir/ \ + -v $(pwd)/data/origin:/data \ + -e "CUDA_VISIBLE_DEVICES=0" \ + -e "CUDNN_BENCHMARK='True'" \ + -e "CUDNN_DETERMINISTIC='True'" \ + catalyst-segmentation ./bin/catalyst-semantic-segmentation-pipeline.sh \ + --workdir /logdir \ + --datadir /data \ + --max-image-size 256 \ + --config-template ./configs/templates/semantic.yml \ + --num-workers 4 \ + --batch-size 256 ```

-The pipeline is running and you don’t have to do anything else, it remains to wait for the best model! - -#### Visualizations +
+Instance segmentation pipeline +

-You can use [W&B](https://www.wandb.com/) account for visualisation right after `pip install wandb`: +#### Run in local environment: +```bash +CUDA_VISIBLE_DEVICES=0 \ +CUDNN_BENCHMARK="True" \ +CUDNN_DETERMINISTIC="True" \ +bash ./bin/catalyst-semantic-segmentation-pipeline.sh \ + --workdir ./logs \ + --datadir ./data/origin \ + --max-image-size 256 \ + --config-template ./configs/templates/instance.yml \ + --num-workers 4 \ + --batch-size 256 ``` -wandb: (1) Create a W&B account -wandb: (2) Use an existing W&B account -wandb: (3) Don't visualize my results + +#### Run in docker: + +```bash +docker run -it --rm --shm-size 8G --runtime=nvidia \ + -v $(pwd):/workspace/ \ + -v $(pwd)/logs:/logdir/ \ + -v $(pwd)/data/origin:/data \ + -e "CUDA_VISIBLE_DEVICES=0" \ + -e "CUDNN_BENCHMARK='True'" \ + -e "CUDNN_DETERMINISTIC='True'" \ + catalyst-segmentation ./bin/catalyst-instance-segmentation-pipeline.sh \ + --workdir /logdir \ + --datadir /data \ + --max-image-size 256 \ + --config-template ./configs/templates/instance.yml \ + --num-workers 4 \ + --batch-size 256 ``` - -Tensorboard also can be used for visualisation: +

+
+ +The pipeline is running and you don’t have to do anything else, it remains to wait for the best model! + +#### Visualizations + +Tensorboard can be used for visualisation: ```bash tensorboard --logdir=/catalyst.segmentation/logs @@ -286,7 +383,7 @@ For your future experiments framework provides powerful configs allow to optimiz * Common settings of stages of training and model parameters can be found in `catalyst.segmentation/configs/_common.yml`. * `model_params`: detailed configuration of models, including: - * model, for instance `ResnetUnet` + * model, for instance `ResNetUnet` * detailed architecture description * using pretrained model * `stages`: you can configure training or inference in several stages with different hyperparameters. In our example: diff --git a/bin/catalyst-binary-segmentation-pipeline.sh b/bin/catalyst-binary-segmentation-pipeline.sh index cf616df..9499dac 100755 --- a/bin/catalyst-binary-segmentation-pipeline.sh +++ b/bin/catalyst-binary-segmentation-pipeline.sh @@ -4,7 +4,7 @@ #author :Sergey Kolesnikov, Yauheni Kachan #author_email :scitator@gmail.com, yauheni.kachan@gmail.com #date :20191016 -#version :19.10.2 +#version :20.03 #============================================================================== set -e diff --git a/bin/catalyst-instance-segmentation-pipeline.sh b/bin/catalyst-instance-segmentation-pipeline.sh new file mode 100644 index 0000000..d4581d3 --- /dev/null +++ b/bin/catalyst-instance-segmentation-pipeline.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +#title :catalyst-instance-segmentation-pipeline +#description :catalyst.dl script for instance segmentation pipeline run +#author :Sergey Kolesnikov, Yauheni Kachan +#author_email :scitator@gmail.com, yauheni.kachan@gmail.com +#date :20191109 +#version :20.03 +#============================================================================== + +set -e + +usage() +{ + cat << USAGE >&2 +Usage: $(basename "$0") [OPTION...] [catalyst-dl run args...] + + -s, --skipdata Skip data preparation + -j, --num-workers NUM_WORKERS Number of data loading/processing workers + -b, --batch-size BATCH_SIZE Mini-batch size + --max-image-size MAX_IMAGE_SIZE Target size of images e.g. 256 + --config-template CONFIG_TEMPLATE Model config to use + --datadir DATADIR + --workdir WORKDIR + catalyst-dl run args Execute \`catalyst-dl run\` with args + +Example: + CUDA_VISIBLE_DEVICES=0 \\ + CUDNN_BENCHMARK="True" \\ + CUDNN_DETERMINISTIC="True" \\ + ./bin/catalyst-instance-segmentation-pipeline.sh \\ + --workdir ./logs \\ + --datadir ./data/origin \\ + --max-image-size 256 \\ + --config-template ./configs/templates/instance.yml \\ + --num-workers 4 \\ + --batch-size 256 +USAGE + exit 1 +} + + +# ---- environment variables + +NUM_WORKERS=${NUM_WORKERS:=4} +BATCH_SIZE=${BATCH_SIZE:=64} +MAX_IMAGE_SIZE=${MAX_IMAGE_SIZE:=256} +CONFIG_TEMPLATE=${CONFIG_TEMPLATE:="./configs/templates/instance.yml"} +DATADIR=${DATADIR:="./data/origin"} +WORKDIR=${WORKDIR:="./logs"} +SKIPDATA="" +_run_args="" +while (( "$#" )); do + case "$1" in + -j|--num-workers) + NUM_WORKERS=$2 + shift 2 + ;; + -b|--batch-size) + BATCH_SIZE=$2 + shift 2 + ;; + --max-image-size) + MAX_IMAGE_SIZE=$2 + shift 2 + ;; + --config-template) + CONFIG_TEMPLATE=$2 + shift 2 + ;; + --datadir) + DATADIR=$2 + shift 2 + ;; + --workdir) + WORKDIR=$2 + shift 2 + ;; + -s|--skipdata) + SKIPDATA="true" + shift + ;; + -h|--help) + usage + ;; + *) + _run_args="${_run_args} $1" + shift + ;; + esac +done + +date=$(date +%y%m%d-%H%M%S) +postfix=$(openssl rand -hex 4) +logname="${date}-${postfix}" +export DATASET_DIR=${WORKDIR}/dataset +export RAW_MASKS_DIR=${DATASET_DIR}/raw_masks +export CONFIG_DIR=${WORKDIR}/configs-${logname} +export LOGDIR=${WORKDIR}/logdir-${logname} + +for dir in ${WORKDIR} ${DATASET_DIR} ${CONFIG_DIR} ${LOGDIR}; do + mkdir -p ${dir} +done + + +# ---- data preparation + +if [[ -z "${SKIPDATA}" ]]; then + cp -R ${DATADIR}/* ${DATASET_DIR}/ + + mkdir -p ${DATASET_DIR}/masks + python scripts/process_instance_masks.py \ + --in-dir ${RAW_MASKS_DIR} \ + --out-dir ${DATASET_DIR}/masks \ + --num-workers ${NUM_WORKERS} + + python scripts/image2mask.py \ + --in-dir ${DATASET_DIR} \ + --out-dataset ${DATASET_DIR}/dataset_raw.csv + + catalyst-data split-dataframe \ + --in-csv ${DATASET_DIR}/dataset_raw.csv \ + --n-folds=5 --train-folds=0,1,2,3 \ + --out-csv=${DATASET_DIR}/dataset.csv +fi + + +# ---- config preparation + +python ./scripts/prepare_config.py \ + --in-template=${CONFIG_TEMPLATE} \ + --out-config=${CONFIG_DIR}/config.yml \ + --expdir=./src \ + --dataset-path=${DATASET_DIR} \ + --num-classes=2 \ + --num-workers=${NUM_WORKERS} \ + --batch-size=${BATCH_SIZE} \ + --max-image-size=${MAX_IMAGE_SIZE} + +cp -r ./configs/_common.yml ${CONFIG_DIR}/_common.yml + + +# ---- model training + +catalyst-dl run \ + -C ${CONFIG_DIR}/_common.yml ${CONFIG_DIR}/config.yml \ + --logdir ${LOGDIR} ${_run_args} diff --git a/bin/catalyst-semantic-segmentation-pipeline.sh b/bin/catalyst-semantic-segmentation-pipeline.sh index c93b319..2e4f0bc 100755 --- a/bin/catalyst-semantic-segmentation-pipeline.sh +++ b/bin/catalyst-semantic-segmentation-pipeline.sh @@ -4,7 +4,7 @@ #author :Sergey Kolesnikov, Yauheni Kachan #author_email :scitator@gmail.com, yauheni.kachan@gmail.com #date :20191016 -#version :19.10.2 +#version :20.03 #============================================================================== set -e diff --git a/bin/tests/_check_binary.sh b/bin/tests/check_binary.sh similarity index 100% rename from bin/tests/_check_binary.sh rename to bin/tests/check_binary.sh diff --git a/bin/tests/check_instance.sh b/bin/tests/check_instance.sh new file mode 100644 index 0000000..83c0a4a --- /dev/null +++ b/bin/tests/check_instance.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +# Cause the script to exit if a single command fails +set -eo pipefail -v + + +################################### DATA #################################### +rm -rf ./data + +# load the data +mkdir -p ./data + +download-gdrive 1RCqaQZLziuq1Z4sbMpwD_WHjqR5cdPvh dsb2018_cleared_191109.tar.gz +tar -xf dsb2018_cleared_191109.tar.gz &>/dev/null +mv dsb2018_cleared_191109 ./data/origin + + +################################ pipeline 00 ################################ +rm -rf ./logs + + +################################ pipeline 01 ################################ +CUDA_VISIBLE_DEVICES="" \ +CUDNN_BENCHMARK="True" \ +CUDNN_DETERMINISTIC="True" \ +bash ./bin/catalyst-instance-segmentation-pipeline.sh \ + --config-template ./configs/templates/instance.yml \ + --workdir ./logs \ + --datadir ./data/origin \ + --num-workers 0 \ + --batch-size 2 \ + --max-image-size 256 \ + --check + + +python -c """ +import pathlib +from safitty import Safict + +folder = list(pathlib.Path('./logs/').glob('logdir-*'))[0] +metrics = Safict.load(f'{folder}/checkpoints/_metrics.json') + +aggregated_loss = metrics.get('best', 'loss') +iou_soft = metrics.get('best', 'iou_soft') +iou_hard = metrics.get('best', 'iou_hard') + +print(aggregated_loss) +print(iou_soft) +print(iou_hard) + +assert aggregated_loss < 0.9 +assert iou_soft > 0.04 +assert iou_hard > 0.1 +""" + + +################################ pipeline 99 ################################ +rm -rf ./logs diff --git a/bin/tests/_check_semantic.sh b/bin/tests/check_semantic.sh similarity index 100% rename from bin/tests/_check_semantic.sh rename to bin/tests/check_semantic.sh diff --git a/configs/templates/instance.yml b/configs/templates/instance.yml new file mode 100644 index 0000000..ac6d919 --- /dev/null +++ b/configs/templates/instance.yml @@ -0,0 +1,196 @@ +shared: + image_size: &image_size {{ max_image_size }} + +model_params: + num_classes: {{ num_classes }} + +args: + expdir: {{ expdir }} + +stages: + + state_params: + main_metric: &reduced_metric iou_hard + minimize_metric: False + + data_params: + num_workers: {{ num_workers }} + batch_size: {{ batch_size }} + per_gpu_scaling: True + in_csv_train: {{ dataset_path }}/dataset_train.csv + in_csv_valid: {{ dataset_path }}/dataset_valid.csv + datapath: {{ dataset_path }} + + transform_params: + _key_value: True + + train: + transform: A.Compose + transforms: + - &pre_transforms + transform: A.Compose + transforms: + - transform: A.LongestMaxSize + max_size: *image_size + - transform: A.PadIfNeeded + min_height: *image_size + min_width: *image_size + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + - &hard_transforms + transform: A.Compose + transforms: + - transform: A.ShiftScaleRotate + shift_limit: 0.1 + scale_limit: 0.1 + rotate_limit: 15 + border_mode: 2 # cv2.BORDER_REFLECT + - transform: A.OneOf + transforms: + - transform: A.HueSaturationValue + - transform: A.ToGray + - transform: A.RGBShift + - transform: A.ChannelShuffle + - transform: A.RandomBrightnessContrast + brightness_limit: 0.5 + contrast_limit: 0.5 + - transform: A.RandomGamma + - transform: A.CLAHE + - transform: A.ImageCompression + quality_lower: 50 + - &post_transforms + transform: A.Compose + transforms: + - transform: A.Normalize + - transform: C.ToTensor + valid: + transform: A.Compose + transforms: + - *pre_transforms + - *post_transforms + infer: + transform: A.Compose + transforms: + - *pre_transforms + - *post_transforms + + criterion_params: + _key_value: True + + bce: + criterion: BCEWithLogitsLoss + dice: + criterion: DiceLoss + iou: + criterion: IoULoss + + callbacks_params: + loss_bce: + callback: CriterionCallback + input_key: mask + output_key: logits + prefix: loss_bce + criterion_key: bce + multiplier: 1.0 + loss_dice: + callback: CriterionCallback + input_key: mask + output_key: logits + prefix: loss_dice + criterion_key: dice + multiplier: 1.0 + loss_iou: + callback: CriterionCallback + input_key: mask + output_key: logits + prefix: loss_iou + criterion_key: iou + multiplier: 1.0 + + loss_aggregator: + callback: MetricAggregationCallback + prefix: &aggregated_loss loss + metrics: [loss_bce, loss_dice, loss_iou] + mode: "mean" + multiplier: 1.0 + + raw_processor: + callback: RawMaskPostprocessingCallback + instance_extractor: + callback: InstanceMaskPostprocessingCallback + watershed_threshold: 0.9 + mask_threshold: 0.8 + output_key: instance_mask + out_key_semantic: semantic_mask + out_key_border: border_mask + + iou_soft: + callback: IouCallback + input_key: mask + output_key: logits + prefix: iou_soft + iou_hard: + callback: IouCallback + input_key: mask + output_key: logits + prefix: iou_hard + threshold: 0.5 + + optimizer: + callback: OptimizerCallback + loss_key: *aggregated_loss + scheduler: + callback: SchedulerCallback + reduced_metric: *reduced_metric + saver: + callback: CheckpointCallback + + # infer: + # + # data_params: + # num_workers: {{ num_workers }} + # batch_size: {{ batch_size }} + # per_gpu_scaling: True + # in_csv: null + # in_csv_train: null + # in_csv_valid: {{ dataset_path }}/dataset_valid.csv + # in_csv_infer: {{ dataset_path }}/dataset_train.csv + # datapath: {{ dataset_path }} + # + # callbacks_params: + # loader: + # callback: CheckpointCallback + # + # raw_processor: + # callback: RawMaskPostprocessingCallback + # instance_extractor: + # callback: InstanceMaskPostprocessingCallback + # watershed_threshold: 0.9 + # mask_threshold: 0.8 + # output_key: instance_mask + # out_key_semantic: semantic_mask + # out_key_border: border_mask + # + # image_saver: + # callback: OriginalImageSaverCallback + # output_dir: infer + # saver_mask: + # callback: OverlayMaskImageSaverCallback + # output_dir: infer + # filename_suffix: _01_raw_mask + # output_key: mask + # saver_semantic: + # callback: OverlayMaskImageSaverCallback + # output_dir: infer + # filename_suffix: _02_semantic_mask + # output_key: semantic_mask + # saver_border: + # callback: OverlayMaskImageSaverCallback + # output_dir: infer + # filename_suffix: _03_border_mask + # output_key: border_mask + # saver_instance: + # callback: OverlayMaskImageSaverCallback + # output_dir: infer + # filename_suffix: _04_instance_mask + # output_key: instance_mask diff --git a/pics/wandb_metrics.png b/pics/wandb_metrics.png deleted file mode 100644 index 0dfd0a1..0000000 Binary files a/pics/wandb_metrics.png and /dev/null differ diff --git a/requirements/requirements-docker.txt b/requirements/requirements-docker.txt index 7141304..d558d60 100644 --- a/requirements/requirements-docker.txt +++ b/requirements/requirements-docker.txt @@ -1,6 +1,8 @@ albumentations==0.4.5 jinja2 +opencv-python>=4.1.1 pandas>=0.22 safitty>=1.2.3 segmentation-models-pytorch==0.1.0 +shapely[vectorized]==1.7.0 tqdm>=4.33.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2eb9541..c4f19eb 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,7 @@ catalyst[cv]==20.5 jinja2 +opencv-python>=4.1.1 pandas>=0.22 safitty>=1.2.3 +shapely[vectorized]==1.7.0 tqdm>=4.33.0 diff --git a/scripts/process_instance_masks.py b/scripts/process_instance_masks.py new file mode 100644 index 0000000..c64c7b6 --- /dev/null +++ b/scripts/process_instance_masks.py @@ -0,0 +1,195 @@ +from typing import List # isort:skip +import argparse +from multiprocessing.pool import Pool +from pathlib import Path + +import numpy as np +from skimage import measure, morphology + +from catalyst.utils import ( + get_pool, + has_image_extension, + imread, + mimwrite_with_meta, + tqdm_parallel_imap, +) + + +def build_args(parser): + parser.add_argument( + "--in-dir", type=Path, required=True, help="Raw masks folder path" + ) + parser.add_argument( + "--out-dir", + type=Path, + required=True, + help="Processed masks folder path", + ) + parser.add_argument("--threshold", type=float, default=0.0) + parser.add_argument( + "--n-channels", + type=int, + choices={2, 3}, + default=2, + help="Number of channels in output masks", + ) + parser.add_argument( + "--num-workers", + default=1, + type=int, + help="Number of workers to parallel the processing", + ) + + return parser + + +def parse_args(): + parser = argparse.ArgumentParser() + build_args(parser) + args = parser.parse_args() + return args + + +def mim_interaction(mim: List[np.ndarray], threshold: float = 0) -> np.ndarray: + result = np.zeros_like(mim[0], dtype=np.uint8) + result[np.stack(mim, axis=-1).max(axis=-1) > threshold] = 255 + return result + + +def mim_color_encode( + mim: List[np.ndarray], threshold: float = 0 +) -> np.ndarray: + result = np.zeros_like(mim[0], dtype=np.uint8) + for index, im in enumerate(mim, start=1): + result[im > threshold] = index + + return result + + +class Preprocessor: + def __init__( + self, + in_dir: Path, + out_dir: Path, + threshold: float = 0.0, + n_channels: int = 2, + ): + """ + Args: + in_dir (Path): raw masks folder path, input folder structure + should be following: + in_path # dir with raw masks + |-- sample_1 + | |-- instance_1 + | |-- instance_2 + | .. + | `-- instance_N + |-- sample_1 + | |-- instance_1 + | |-- instance_2 + | .. + | `-- instance_K + .. + `-- sample_M + |-- instance_1 + |-- instance_2 + .. + `-- instance_Z + out_dir (Path): processed masks folder path, output folder + structure will be following: + out_path + |-- sample_1.tiff # image of shape HxWxN + |-- sample_2.tiff # image of shape HxWxK + .. + `-- sample_M.tiff # image of shape HxWxZ + threshold (float): + n_channels (int): number of channels in output masks, + see https://www.kaggle.com/c/data-science-bowl-2018/discussion/54741 # noqa: E501, W505 + """ + self.in_dir = in_dir + self.out_dir = out_dir + self.threshold = threshold + self.n_channels = n_channels + + def preprocess(self, sample: Path): + masks = [ + imread(filename, grayscale=True, expand_dims=False) + for filename in sample.iterdir() + if has_image_extension(str(filename)) + ] + labels = mim_color_encode(masks, self.threshold) + + scaled_blobs = morphology.dilation(labels > 0, morphology.square(9)) + watersheded_blobs = ( + morphology.watershed( + scaled_blobs, labels, mask=scaled_blobs, watershed_line=True + ) + > 0 + ) + watershed_lines = scaled_blobs ^ (watersheded_blobs) + scaled_watershed_lines = morphology.dilation( + watershed_lines, morphology.square(7) + ) + + props = measure.regionprops(labels) + max_area = max(p.area for p in props) + + mask_without_borders = mim_interaction(masks, self.threshold) + borders = np.zeros_like(labels, dtype=np.uint8) + for y0 in range(labels.shape[0]): + for x0 in range(labels.shape[1]): + if not scaled_watershed_lines[y0, x0]: + continue + + if labels[y0, x0] == 0: + if max_area > 4000: + sz = 6 + else: + sz = 3 + else: + if props[labels[y0, x0] - 1].area < 300: + sz = 1 + elif props[labels[y0, x0] - 1].area < 2000: + sz = 2 + else: + sz = 3 + + uniq = np.unique( + labels[ + max(0, y0 - sz) : min(labels.shape[0], y0 + sz + 1), + max(0, x0 - sz) : min(labels.shape[1], x0 + sz + 1), + ] + ) + if len(uniq[uniq > 0]) > 1: + borders[y0, x0] = 255 + mask_without_borders[y0, x0] = 0 + + if self.n_channels == 2: + mask = [mask_without_borders, borders] + elif self.n_channels == 3: + background = 255 - (mask_without_borders + borders) + mask = [mask_without_borders, borders, background] + else: + raise ValueError() + + mimwrite_with_meta( + self.out_dir / f"{sample.stem}.tiff", mask, {"compress": 9} + ) + + def process_all(self, pool: Pool): + images = list(self.in_dir.iterdir()) + tqdm_parallel_imap(self.preprocess, images, pool) + + +def main(args, _=None): + args = args.__dict__ + args.pop("command", None) + num_workers = args.pop("num_workers") + + with get_pool(num_workers) as p: + Preprocessor(**args).process_all(p) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/setup.cfg b/setup.cfg index 59a3765..f7b3f2d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ # - python libs (known_third_party) # - dl libs (known_dl) # - catalyst imports -known_third_party = imageio,jinja2,numpy,pandas,safitty,skimage +known_third_party = cv2,imageio,jinja2,numpy,pandas,safitty,shapely,skimage known_dl = albumentations,torch,torchvision known_first_party = catalyst sections=STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py index 62a7a82..a731fac 100644 --- a/src/callbacks/__init__.py +++ b/src/callbacks/__init__.py @@ -1,4 +1,12 @@ # flake8: noqa -from .io import OriginalImageSaverCallback, OverlayMaskImageSaverCallback -from .processing import RawMaskPostprocessingCallback +from .io import ( + InstanceCropSaverCallback, + OriginalImageSaverCallback, + OverlayMaskImageSaverCallback, +) +from .metrics import SegmentationMeanAPCallback +from .processing import ( + InstanceMaskPostprocessingCallback, + RawMaskPostprocessingCallback, +) diff --git a/src/callbacks/io.py b/src/callbacks/io.py index 41eba5b..f6bf962 100644 --- a/src/callbacks/io.py +++ b/src/callbacks/io.py @@ -3,9 +3,9 @@ import imageio import numpy as np -from catalyst.dl import Callback, CallbackOrder, State, utils +from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils -from .utils import mask_to_overlay_image +from .utils import crop_by_masks, mask_to_overlay_image class OriginalImageSaverCallback(Callback): @@ -18,7 +18,7 @@ def __init__( input_key: str = "image", outpath_key: str = "name", ): - super().__init__(CallbackOrder.Logging) + super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master) self.output_dir = Path(output_dir) self.relative = relative self.filename_suffix = filename_suffix @@ -83,4 +83,43 @@ def on_batch_end(self, state: State): imageio.imwrite(fname, image) -__all__ = ["OriginalImageSaverCallback", "OverlayMaskImageSaverCallback"] +class InstanceCropSaverCallback(OriginalImageSaverCallback): + def __init__( + self, + output_dir: str, + relative: bool = True, + filename_extension: str = ".jpg", + input_key: str = "image", + output_key: str = "mask", + outpath_key: str = "name", + ): + super().__init__( + output_dir=output_dir, + relative=relative, + filename_suffix=filename_extension, + input_key=input_key, + outpath_key=outpath_key, + ) + self.output_key = output_key + + def on_batch_end(self, state: State): + names = state.batch_in[self.outpath_key] + images = state.batch_in[self.input_key] + masks = state.batch_out[self.output_key] + + images = utils.tensor_to_ndimage(images.detach().cpu()) + for name, image, masks_ in zip(names, images, masks): + instances = crop_by_masks(image, masks_) + + for index, crop in enumerate(instances): + filename = self.get_image_path( + state, name, suffix=f"_instance{index:02d}" + ) + imageio.imwrite(filename, crop) + + +__all__ = [ + "OriginalImageSaverCallback", + "OverlayMaskImageSaverCallback", + "InstanceCropSaverCallback", +] diff --git a/src/callbacks/metrics.py b/src/callbacks/metrics.py new file mode 100644 index 0000000..aad6d04 --- /dev/null +++ b/src/callbacks/metrics.py @@ -0,0 +1,83 @@ +import numpy as np + +from catalyst.dl import MetricCallback + + +def compute_ious_single_image(predicted_mask, gt_instance_masks): + instance_ids = np.unique(predicted_mask) + n_gt_instaces = gt_instance_masks.shape[0] + + all_ious = [] + + for id_ in instance_ids: + if id_ == 0: + # Skip background + continue + + predicted_instance_mask = predicted_mask == id_ + + sum_ = predicted_instance_mask.reshape( + 1, -1 + ) + gt_instance_masks.reshape(n_gt_instaces, -1) + + intersection = (sum_ == 2).sum(axis=1) + union = (sum_ > 0).sum(axis=1) + + ious = intersection / union + + all_ious.append(ious) + + all_ious = np.array(all_ious).reshape((len(all_ious), n_gt_instaces)) + + return all_ious + + +def map_from_ious(ious: np.ndarray, iou_thresholds: np.ndarray): + """ + Args: + ious (np.ndarray): array of shape n_pred x n_gt + iou_thresholds (np.ndarray): + """ + n_preds = ious.shape[0] + + fn_at_ious = ( + np.max(ious, axis=0, initial=0)[None, :] < iou_thresholds[:, None] + ) + fn_at_iou = np.sum(fn_at_ious, axis=1, initial=0) + + tp_at_ious = ( + np.max(ious, axis=0, initial=0)[None, :] > iou_thresholds[:, None] + ) + tp_at_iou = np.sum(tp_at_ious, axis=1, initial=0) + + metric_at_iou = tp_at_iou / (n_preds + fn_at_iou) + + return metric_at_iou.mean() + + +def mean_average_precision(outputs, targets, iou_thresholds): + batch_metrics = [] + for pred, gt in zip(outputs, targets): + ious = compute_ious_single_image(pred, gt.numpy()) + batch_metrics.append(map_from_ious(ious, iou_thresholds)) + return float(np.mean(batch_metrics)) + + +class SegmentationMeanAPCallback(MetricCallback): + def __init__( + self, + input_key: str = "imasks", + output_key: str = "instance_mask", + prefix: str = "mAP", + iou_thresholds=(0.5, 0.55, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95), + ): + super().__init__( + prefix=prefix, + metric_fn=mean_average_precision, + input_key=input_key, + output_key=output_key, + iou_thresholds=np.array(iou_thresholds), + ) + + +__all__ = ["SegmentationMeanAPCallback"] diff --git a/src/callbacks/processing.py b/src/callbacks/processing.py index 7152850..7d9faed 100644 --- a/src/callbacks/processing.py +++ b/src/callbacks/processing.py @@ -1,8 +1,8 @@ import torch -from catalyst.dl import Callback, CallbackOrder, State +from catalyst.dl import Callback, CallbackNode, CallbackOrder, State -from .utils import encode_mask_with_color +from .utils import encode_mask_with_color, label_instances class RawMaskPostprocessingCallback(Callback): @@ -12,7 +12,7 @@ def __init__( input_key: str = "logits", output_key: str = "mask", ): - super().__init__(CallbackOrder.Internal) + super().__init__(order=CallbackOrder.Internal, node=CallbackNode.All) self.threshold = threshold self.input_key = input_key self.output_key = output_key @@ -22,8 +22,54 @@ def on_batch_end(self, state: State): output = torch.sigmoid(output).detach().cpu().numpy() state.batch_out[self.output_key] = encode_mask_with_color( - output, self.threshold + output, threshold=self.threshold ) -__all__ = ["RawMaskPostprocessingCallback"] +class InstanceMaskPostprocessingCallback(Callback): + def __init__( + self, + watershed_threshold: float = 0.5, + mask_threshold: float = 0.5, + input_key: str = "logits", + output_key: str = "instance_mask", + out_key_semantic: str = None, + out_key_border: str = None, + ): + super().__init__(CallbackOrder.Internal, node=CallbackNode.All) + self.watershed_threshold = watershed_threshold + self.mask_threshold = mask_threshold + self.input_key = input_key + self.output_key = output_key + self.out_key_semantic = out_key_semantic + self.out_key_border = out_key_border + + def on_batch_end(self, state: State): + output = state.batch_out[self.input_key] + + output = torch.sigmoid(output).detach().cpu() + semantic, border = output.chunk(2, -3) + + if self.out_key_semantic is not None: + state.batch_out[self.out_key_semantic] = encode_mask_with_color( + semantic.numpy(), threshold=self.mask_threshold + ) + + if self.out_key_border is not None: + state.batch_out[self.out_key_border] = ( + border.squeeze(-3).numpy() > self.watershed_threshold + ) + + state.batch_out[self.output_key] = label_instances( + semantic, + border, + watershed_threshold=self.watershed_threshold, + instance_mask_threshold=self.mask_threshold, + downscale_factor=1, + ) + + +__all__ = [ + "RawMaskPostprocessingCallback", + "InstanceMaskPostprocessingCallback", +] diff --git a/src/callbacks/utils.py b/src/callbacks/utils.py index 6002c37..782ef60 100644 --- a/src/callbacks/utils.py +++ b/src/callbacks/utils.py @@ -1,22 +1,29 @@ -from typing import List # isort:skip - +from typing import List, Tuple, Union # isort:skip +import cv2 import numpy as np +from shapely.geometry import LinearRing, MultiPoint from skimage.color import label2rgb +from skimage.measure import label, regionprops +from skimage.morphology import watershed import torch +import torch.nn.functional as F + +# types +Point = Tuple[int, int] +Quadrangle = Tuple[Point, Point, Point, Point] def encode_mask_with_color( semantic_masks: torch.Tensor, threshold: float = 0.5 ) -> List[np.ndarray]: """ - Args: semantic_masks (torch.Tensor): semantic mask batch tensor threshold (float): threshold for semantic masks + Returns: List[np.ndarray]: list of semantic masks - """ batch = [] for observation in semantic_masks: @@ -38,3 +45,159 @@ def mask_to_overlay_image( (image_with_overlay * 255).clip(0, 255).round().astype(np.uint8) ) return image_with_overlay + + +def label_instances( + semantic_masks: torch.Tensor, + border_masks: torch.Tensor, + watershed_threshold: float = 0.9, + instance_mask_threshold: float = 0.5, + downscale_factor: float = 4, + interpolation: str = "bilinear", +) -> List[np.ndarray]: + """ + Args: + semantic_masks (torch.Tensor): semantic mask batch tensor + border_masks (torch.Tensor): instance mask batch tensor + watershed_threshold (float): threshold for watershed markers + instance_mask_threshold (float): threshold for final instance masks + downscale_factor (float): mask downscaling factor + (to speed up processing) + interpolation (str): interpolation method + + Returns: + List[np.ndarray]: list of labeled instance masks, one per batch item + """ + bordered_masks = (semantic_masks - border_masks).clamp(min=0) + + scaling = 1 / downscale_factor + semantic_masks, bordered_masks = ( + F.interpolate( + mask.data.cpu(), + scale_factor=scaling, + mode=interpolation, + align_corners=False, + ) + .squeeze(-3) + .numpy() + for mask in (semantic_masks, bordered_masks) + ) + + result: List[np.ndarray] = [] + for semantic, bordered in zip(semantic_masks, bordered_masks): + watershed_marks = label(bordered > watershed_threshold, background=0) + instance_regions = watershed(-bordered, watershed_marks) + + instance_regions[semantic < instance_mask_threshold] = 0 + + result.append(instance_regions) + + return result + + +def _is_ccw(vertices: np.ndarray): + return LinearRing(vertices * [[1, -1]]).is_ccw + + +def get_rects_from_mask( + label_mask: np.ndarray, min_area_fraction=20 +) -> np.ndarray: + props = regionprops(label_mask) + + total_h, total_w = label_mask.shape + total_area = total_h * total_w + + result = [] + for p in props: + + if p.area / total_area < min_area_fraction: + continue + + coords = p.coords + coords = coords[:, ::-1] # row, col -> col, row + + rect = MultiPoint(coords).minimum_rotated_rectangle.exterior.coords + + rect = np.array(rect)[:4].astype(np.int32) + + if _is_ccw(rect): + rect = rect[::-1] + + result.append(rect) + + result = np.stack(result) if result else [] + + return result + + +def perspective_crop( + image: np.ndarray, + crop_coords: Union[Quadrangle, np.ndarray], + output_wh: Tuple[int, int], + border_color: Tuple[int, int, int] = (255, 255, 255), +): + width, height = output_wh + target_coords = ((0, 0), (width, 0), (width, height), (0, height)) + + transform_matrix = cv2.getPerspectiveTransform( + np.array(crop_coords, dtype=np.float32), + np.array(target_coords, dtype=np.float32), + ) + + result = cv2.warpPerspective( + image, + transform_matrix, + (width, height), + borderMode=cv2.BORDER_CONSTANT, + borderValue=(border_color), + ) + + return result + + +def perspective_crop_keep_ratio( + image: np.ndarray, + vertices: np.ndarray, + output_size: int = -1, + border_color: Tuple[int, int, int] = (255, 255, 255), +) -> np.ndarray: + """ + Crop some quadrilateral from image keeping it's aspect ratio + + Args: + image (np.ndarray): image numpy array + vertices (np.ndarray): numpy array with quadrilateral vertices coords + output_size (int): minimal side length of output image + (if -1 will be actual side of image) + border_color (Tuple[int, int, int]): + + Returns: + np.ndarray: image crop + """ + lenghts = np.linalg.norm(vertices - np.roll(vertices, -1, 0), axis=1) + + len_ab, len_bc, len_cd, len_da = lenghts.tolist() + + width = (len_ab + len_cd) / 2 + height = (len_bc + len_da) / 2 + + if output_size > 0: + scale = output_size / max(width, height) + width, height = (dim * scale for dim in (width, height)) + + width, height = round(width), round(height) + + crop = perspective_crop(image, vertices, (width, height), border_color) + + return crop + + +def crop_by_masks( + image: np.ndarray, mask: np.ndarray, image_size: int = 512 +) -> List[np.ndarray]: + crops = [ + perspective_crop_keep_ratio(image, rect, image_size) + for rect in get_rects_from_mask(mask) + ] + + return crops diff --git a/teamcity/binary.sh b/teamcity/binary.sh index ce3db40..8a84e7b 100644 --- a/teamcity/binary.sh +++ b/teamcity/binary.sh @@ -8,4 +8,4 @@ pip install -r requirements/requirements.txt # @TODO: fix server issue pip install torch==1.4.0 torchvision==0.5.0 -bash ./bin/tests/_check_binary.sh +bash ./bin/tests/check_binary.sh diff --git a/teamcity/instance.sh b/teamcity/instance.sh new file mode 100644 index 0000000..c1b2466 --- /dev/null +++ b/teamcity/instance.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# Cause the script to exit if a single command fails +set -eo pipefail -v + +pip install -r requirements/requirements.txt + +bash ./bin/tests/check_instance.sh diff --git a/teamcity/semantic.sh b/teamcity/semantic.sh index 54b8635..e63a6c5 100644 --- a/teamcity/semantic.sh +++ b/teamcity/semantic.sh @@ -8,4 +8,4 @@ pip install -r requirements/requirements.txt # @TODO: fix server issue pip install torch==1.4.0 torchvision==0.5.0 -bash ./bin/tests/_check_semantic.sh +bash ./bin/tests/check_semantic.sh