diff --git a/Makefile b/Makefile
index 1d125d6..13759c8 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
 .PHONY: check-style codestyle docker-build clean
 
 check-style:
-	bash ./bin/_check_codestyle.sh -s
+	bash ./bin/codestyle/check_codestyle.sh -s
 
 codestyle:
 	pre-commit run
diff --git a/README.md b/README.md
index 767d61e..930f427 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,12 @@ elif [[ "$DATASET" == "voc2012" ]]; then
     tar -xf VOCtrainval_11-May-2012.tar &>/dev/null
     mkdir -p ./data/origin/images/; mv VOCdevkit/VOC2012/JPEGImages/* $_
     mkdir -p ./data/origin/raw_masks; mv VOCdevkit/VOC2012/SegmentationClass/* $_
+elif [[ "$DATASET" == "dsb2018" ]]; then
+    # instance segmentation
+    # https://www.kaggle.com/c/data-science-bowl-2018
+    download-gdrive 1RCqaQZLziuq1Z4sbMpwD_WHjqR5cdPvh dsb2018_cleared_191109.tar.gz
+    tar -xf dsb2018_cleared_191109.tar.gz &>/dev/null
+    mv dsb2018_cleared_191109 ./data/origin
 fi
 ```
 
@@ -102,6 +108,11 @@ fi
 #### Data structure
 
 Make sure, that final folder with data has the required structure:
+
+
+Data structure for binary segmentation
+
+
 ```bash
 /path/to/your_dataset/
         images/
@@ -115,6 +126,66 @@ Make sure, that final folder with data has the required structure:
             ...
             mask_N
 ```
+where each `mask` is a binary image
+
+
+ 
+
+
+Data structure for semantic segmentation
+
+
+```bash
+/path/to/your_dataset/
+        images/
+            image_1
+            image_2
+            ...
+            image_N
+        raw_masks/
+            mask_1
+            mask_2
+            ...
+            mask_N
+```
+where each `mask` is an image with class encoded through colors e.g. [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) dataset where `bicycle` class is encoded with green color and `bird` with olive
+
+
+ 
+
+
+Data structure for instance segmentation
+
+
+```bash
+/path/to/your_dataset/
+        images/
+            image_1
+            image_2
+            ...
+            image_M
+        raw_masks/
+            mask_1/
+                instance_1
+                instance_2
+                ...
+                instance_N
+            mask_2/
+                instance_1
+                instance_2
+                ...
+                instance_K
+            ...
+            mask_M/
+                instance_1
+                instance_2
+                ...
+                instance_Z
+```
+where each `mask` represented as a folder with instances images (one image per instance), and masks may consisting of a different number of instances e.g. [Data Science Bowl 2018](https://www.kaggle.com/c/data-science-bowl-2018) dataset
+
+
+ 
 
 #### Data location
 
@@ -161,35 +232,32 @@ We will initialize [Unet](https://arxiv.org/abs/1505.04597) model with a pre-tra
 CUDA_VISIBLE_DEVICES=0 \
 CUDNN_BENCHMARK="True" \
 CUDNN_DETERMINISTIC="True" \
-WORKDIR=./logs \
-DATADIR=./data/origin \
-IMAGE_SIZE=256 \
-CONFIG_TEMPLATE=./configs/templates/binary.yml \
-NUM_WORKERS=4 \
-BATCH_SIZE=256 \
-bash ./bin/catalyst-binary-segmentation-pipeline.sh
+bash ./bin/catalyst-binary-segmentation-pipeline.sh \
+    --workdir ./logs \
+    --datadir ./data/origin \
+    --max-image-size 256 \
+    --config-template ./configs/templates/binary.yml \
+    --num-workers 4 \
+    --batch-size 256
 ```
 
 #### Run in docker:
 
 ```bash
-export LOGDIR=$(pwd)/logs
 docker run -it --rm --shm-size 8G --runtime=nvidia \
-   -v $(pwd):/workspace/ \
-   -v $LOGDIR:/logdir/ \
-   -v $(pwd)/data/origin:/data \
-   -e "CUDA_VISIBLE_DEVICES=0" \
-   -e "USE_WANDB=1" \
-   -e "LOGDIR=/logdir" \
-   -e "CUDNN_BENCHMARK='True'" \
-   -e "CUDNN_DETERMINISTIC='True'" \
-   -e "WORKDIR=/logdir" \
-   -e "DATADIR=/data" \
-   -e "IMAGE_SIZE=256" \
-   -e "CONFIG_TEMPLATE=./configs/templates/binary.yml" \
-   -e "NUM_WORKERS=4" \
-   -e "BATCH_SIZE=256" \
-   catalyst-segmentation ./bin/catalyst-binary-segmentation-pipeline.sh
+    -v $(pwd):/workspace/ \
+    -v $(pwd)/logs:/logdir/ \
+    -v $(pwd)/data/origin:/data \
+    -e "CUDA_VISIBLE_DEVICES=0" \
+    -e "CUDNN_BENCHMARK='True'" \
+    -e "CUDNN_DETERMINISTIC='True'" \
+    catalyst-segmentation ./bin/catalyst-binary-segmentation-pipeline.sh \
+        --workdir /logdir \
+        --datadir /data \
+        --max-image-size 256 \
+        --config-template ./configs/templates/binary.yml \
+        --num-workers 4 \
+        --batch-size 256
 ```
 
 
@@ -205,54 +273,83 @@ docker run -it --rm --shm-size 8G --runtime=nvidia \
 CUDA_VISIBLE_DEVICES=0 \
 CUDNN_BENCHMARK="True" \
 CUDNN_DETERMINISTIC="True" \
-WORKDIR=./logs \
-DATADIR=./data/origin \
-IMAGE_SIZE=256 \
-CONFIG_TEMPLATE=./configs/templates/semantic.yml \
-NUM_WORKERS=4 \
-BATCH_SIZE=256 \
-bash ./bin/catalyst-semantic-segmentation-pipeline.sh
+bash ./bin/catalyst-semantic-segmentation-pipeline.sh \
+    --workdir ./logs \
+    --datadir ./data/origin \
+    --max-image-size 256 \
+    --config-template ./configs/templates/semantic.yml \
+    --num-workers 4 \
+    --batch-size 256
 ```
 
 #### Run in docker:
 
 ```bash
-export LOGDIR=$(pwd)/logs
 docker run -it --rm --shm-size 8G --runtime=nvidia \
-   -v $(pwd):/workspace/ \
-   -v $LOGDIR:/logdir/ \
-   -v $(pwd)/data/origin:/data \
-   -e "CUDA_VISIBLE_DEVICES=0" \
-   -e "USE_WANDB=1" \
-   -e "LOGDIR=/logdir" \
-   -e "CUDNN_BENCHMARK='True'" \
-   -e "CUDNN_DETERMINISTIC='True'" \
-   -e "WORKDIR=/logdir" \
-   -e "DATADIR=/data" \
-   -e "IMAGE_SIZE=256" \
-   -e "CONFIG_TEMPLATE=./configs/templates/semantic.yml" \
-   -e "NUM_WORKERS=4" \
-   -e "BATCH_SIZE=256" \
-   catalyst-segmentation ./bin/catalyst-semantic-segmentation-pipeline.sh
+    -v $(pwd):/workspace/ \
+    -v $(pwd)/logs:/logdir/ \
+    -v $(pwd)/data/origin:/data \
+    -e "CUDA_VISIBLE_DEVICES=0" \
+    -e "CUDNN_BENCHMARK='True'" \
+    -e "CUDNN_DETERMINISTIC='True'" \
+    catalyst-segmentation ./bin/catalyst-semantic-segmentation-pipeline.sh \
+        --workdir /logdir \
+        --datadir /data \
+        --max-image-size 256 \
+        --config-template ./configs/templates/semantic.yml \
+        --num-workers 4 \
+        --batch-size 256
 ```
 
 
 
 
-The pipeline is running and you don’t have to do anything else, it remains to wait for the best model!
-
-#### Visualizations
+
+Instance segmentation pipeline
+
 
-You can use [W&B](https://www.wandb.com/) account for visualisation right after `pip install wandb`:
+#### Run in local environment:
 
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+CUDNN_BENCHMARK="True" \
+CUDNN_DETERMINISTIC="True" \
+bash ./bin/catalyst-semantic-segmentation-pipeline.sh \
+    --workdir ./logs \
+    --datadir ./data/origin \
+    --max-image-size 256 \
+    --config-template ./configs/templates/instance.yml \
+    --num-workers 4 \
+    --batch-size 256
 ```
-wandb: (1) Create a W&B account
-wandb: (2) Use an existing W&B account
-wandb: (3) Don't visualize my results
+
+#### Run in docker:
+
+```bash
+docker run -it --rm --shm-size 8G --runtime=nvidia \
+    -v $(pwd):/workspace/ \
+    -v $(pwd)/logs:/logdir/ \
+    -v $(pwd)/data/origin:/data \
+    -e "CUDA_VISIBLE_DEVICES=0" \
+    -e "CUDNN_BENCHMARK='True'" \
+    -e "CUDNN_DETERMINISTIC='True'" \
+    catalyst-segmentation ./bin/catalyst-instance-segmentation-pipeline.sh \
+        --workdir /logdir \
+        --datadir /data \
+        --max-image-size 256 \
+        --config-template ./configs/templates/instance.yml \
+        --num-workers 4 \
+        --batch-size 256
 ```
- -Tensorboard also can be used for visualisation:
+
 
-Tensorboard also can be used for visualisation:
+
+ 
+
+The pipeline is running and you don’t have to do anything else, it remains to wait for the best model!
+
+#### Visualizations
+
+Tensorboard can be used for visualisation:
 
 ```bash
 tensorboard --logdir=/catalyst.segmentation/logs
@@ -286,7 +383,7 @@ For your future experiments framework provides powerful configs allow to optimiz
 
 * Common settings of stages of training and model parameters can be found in `catalyst.segmentation/configs/_common.yml`.
     * `model_params`: detailed configuration of models, including:
-        * model, for instance `ResnetUnet`
+        * model, for instance `ResNetUnet`
         * detailed architecture description
         * using pretrained model
     * `stages`: you can configure training or inference in several stages with different hyperparameters. In our example:
diff --git a/bin/catalyst-binary-segmentation-pipeline.sh b/bin/catalyst-binary-segmentation-pipeline.sh
index cf616df..9499dac 100755
--- a/bin/catalyst-binary-segmentation-pipeline.sh
+++ b/bin/catalyst-binary-segmentation-pipeline.sh
@@ -4,7 +4,7 @@
 #author          :Sergey Kolesnikov, Yauheni Kachan
 #author_email    :scitator@gmail.com, yauheni.kachan@gmail.com
 #date            :20191016
-#version         :19.10.2
+#version         :20.03
 #==============================================================================
 
 set -e
diff --git a/bin/catalyst-instance-segmentation-pipeline.sh b/bin/catalyst-instance-segmentation-pipeline.sh
new file mode 100644
index 0000000..d4581d3
--- /dev/null
+++ b/bin/catalyst-instance-segmentation-pipeline.sh
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+#title           :catalyst-instance-segmentation-pipeline
+#description     :catalyst.dl script for instance segmentation pipeline run
+#author          :Sergey Kolesnikov, Yauheni Kachan
+#author_email    :scitator@gmail.com, yauheni.kachan@gmail.com
+#date            :20191109
+#version         :20.03
+#==============================================================================
+
+set -e
+
+usage()
+{
+  cat << USAGE >&2
+Usage: $(basename "$0") [OPTION...] [catalyst-dl run args...]
+
+  -s, --skipdata                       Skip data preparation
+  -j, --num-workers NUM_WORKERS        Number of data loading/processing workers
+  -b, --batch-size BATCH_SIZE          Mini-batch size
+  --max-image-size MAX_IMAGE_SIZE      Target size of images e.g. 256
+  --config-template CONFIG_TEMPLATE    Model config to use
+  --datadir DATADIR
+  --workdir WORKDIR
+  catalyst-dl run args                 Execute \`catalyst-dl run\` with args
+
+Example:
+  CUDA_VISIBLE_DEVICES=0 \\
+  CUDNN_BENCHMARK="True" \\
+  CUDNN_DETERMINISTIC="True" \\
+  ./bin/catalyst-instance-segmentation-pipeline.sh \\
+    --workdir ./logs \\
+    --datadir ./data/origin \\
+    --max-image-size 256 \\
+    --config-template ./configs/templates/instance.yml \\
+    --num-workers 4 \\
+    --batch-size 256
+USAGE
+  exit 1
+}
+
+
+# ---- environment variables
+
+NUM_WORKERS=${NUM_WORKERS:=4}
+BATCH_SIZE=${BATCH_SIZE:=64}
+MAX_IMAGE_SIZE=${MAX_IMAGE_SIZE:=256}
+CONFIG_TEMPLATE=${CONFIG_TEMPLATE:="./configs/templates/instance.yml"}
+DATADIR=${DATADIR:="./data/origin"}
+WORKDIR=${WORKDIR:="./logs"}
+SKIPDATA=""
+_run_args=""
+while (( "$#" )); do
+  case "$1" in
+    -j|--num-workers)
+      NUM_WORKERS=$2
+      shift 2
+      ;;
+    -b|--batch-size)
+      BATCH_SIZE=$2
+      shift 2
+      ;;
+    --max-image-size)
+      MAX_IMAGE_SIZE=$2
+      shift 2
+      ;;
+    --config-template)
+      CONFIG_TEMPLATE=$2
+      shift 2
+      ;;
+    --datadir)
+      DATADIR=$2
+      shift 2
+      ;;
+    --workdir)
+      WORKDIR=$2
+      shift 2
+      ;;
+    -s|--skipdata)
+      SKIPDATA="true"
+      shift
+      ;;
+    -h|--help)
+      usage
+      ;;
+    *)
+      _run_args="${_run_args} $1"
+      shift
+      ;;
+  esac
+done
+
+date=$(date +%y%m%d-%H%M%S)
+postfix=$(openssl rand -hex 4)
+logname="${date}-${postfix}"
+export DATASET_DIR=${WORKDIR}/dataset
+export RAW_MASKS_DIR=${DATASET_DIR}/raw_masks
+export CONFIG_DIR=${WORKDIR}/configs-${logname}
+export LOGDIR=${WORKDIR}/logdir-${logname}
+
+for dir in ${WORKDIR} ${DATASET_DIR} ${CONFIG_DIR} ${LOGDIR}; do
+  mkdir -p ${dir}
+done
+
+
+# ---- data preparation
+
+if [[ -z "${SKIPDATA}" ]]; then
+  cp -R ${DATADIR}/* ${DATASET_DIR}/
+
+  mkdir -p ${DATASET_DIR}/masks
+  python scripts/process_instance_masks.py \
+    --in-dir ${RAW_MASKS_DIR} \
+    --out-dir ${DATASET_DIR}/masks \
+    --num-workers ${NUM_WORKERS}
+
+  python scripts/image2mask.py \
+    --in-dir ${DATASET_DIR} \
+    --out-dataset ${DATASET_DIR}/dataset_raw.csv
+
+  catalyst-data split-dataframe \
+    --in-csv ${DATASET_DIR}/dataset_raw.csv \
+    --n-folds=5 --train-folds=0,1,2,3 \
+    --out-csv=${DATASET_DIR}/dataset.csv
+fi
+
+
+# ---- config preparation
+
+python ./scripts/prepare_config.py \
+  --in-template=${CONFIG_TEMPLATE} \
+  --out-config=${CONFIG_DIR}/config.yml \
+  --expdir=./src \
+  --dataset-path=${DATASET_DIR} \
+  --num-classes=2 \
+  --num-workers=${NUM_WORKERS} \
+  --batch-size=${BATCH_SIZE} \
+  --max-image-size=${MAX_IMAGE_SIZE}
+
+cp -r ./configs/_common.yml ${CONFIG_DIR}/_common.yml
+
+
+# ---- model training
+
+catalyst-dl run \
+    -C ${CONFIG_DIR}/_common.yml ${CONFIG_DIR}/config.yml \
+    --logdir ${LOGDIR} ${_run_args}
diff --git a/bin/catalyst-semantic-segmentation-pipeline.sh b/bin/catalyst-semantic-segmentation-pipeline.sh
index c93b319..2e4f0bc 100755
--- a/bin/catalyst-semantic-segmentation-pipeline.sh
+++ b/bin/catalyst-semantic-segmentation-pipeline.sh
@@ -4,7 +4,7 @@
 #author          :Sergey Kolesnikov, Yauheni Kachan
 #author_email    :scitator@gmail.com, yauheni.kachan@gmail.com
 #date            :20191016
-#version         :19.10.2
+#version         :20.03
 #==============================================================================
 
 set -e
diff --git a/bin/tests/_check_binary.sh b/bin/tests/check_binary.sh
similarity index 100%
rename from bin/tests/_check_binary.sh
rename to bin/tests/check_binary.sh
diff --git a/bin/tests/check_instance.sh b/bin/tests/check_instance.sh
new file mode 100644
index 0000000..83c0a4a
--- /dev/null
+++ b/bin/tests/check_instance.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/env bash
+
+# Cause the script to exit if a single command fails
+set -eo pipefail -v
+
+
+###################################  DATA  ####################################
+rm -rf ./data
+
+# load the data
+mkdir -p ./data
+
+download-gdrive 1RCqaQZLziuq1Z4sbMpwD_WHjqR5cdPvh dsb2018_cleared_191109.tar.gz
+tar -xf dsb2018_cleared_191109.tar.gz &>/dev/null
+mv dsb2018_cleared_191109 ./data/origin
+
+
+################################  pipeline 00  ################################
+rm -rf ./logs
+
+
+################################  pipeline 01  ################################
+CUDA_VISIBLE_DEVICES="" \
+CUDNN_BENCHMARK="True" \
+CUDNN_DETERMINISTIC="True" \
+bash ./bin/catalyst-instance-segmentation-pipeline.sh \
+  --config-template ./configs/templates/instance.yml \
+  --workdir ./logs \
+  --datadir ./data/origin \
+  --num-workers 0 \
+  --batch-size 2 \
+  --max-image-size 256 \
+  --check
+
+
+python -c """
+import pathlib
+from safitty import Safict
+
+folder = list(pathlib.Path('./logs/').glob('logdir-*'))[0]
+metrics = Safict.load(f'{folder}/checkpoints/_metrics.json')
+
+aggregated_loss = metrics.get('best', 'loss')
+iou_soft = metrics.get('best', 'iou_soft')
+iou_hard = metrics.get('best', 'iou_hard')
+
+print(aggregated_loss)
+print(iou_soft)
+print(iou_hard)
+
+assert aggregated_loss < 0.9
+assert iou_soft > 0.04
+assert iou_hard > 0.1
+"""
+
+
+################################  pipeline 99  ################################
+rm -rf ./logs
diff --git a/bin/tests/_check_semantic.sh b/bin/tests/check_semantic.sh
similarity index 100%
rename from bin/tests/_check_semantic.sh
rename to bin/tests/check_semantic.sh
diff --git a/configs/templates/instance.yml b/configs/templates/instance.yml
new file mode 100644
index 0000000..ac6d919
--- /dev/null
+++ b/configs/templates/instance.yml
@@ -0,0 +1,196 @@
+shared:
+  image_size: &image_size {{ max_image_size }}
+
+model_params:
+  num_classes: {{ num_classes }}
+
+args:
+  expdir: {{ expdir }}
+
+stages:
+
+  state_params:
+    main_metric: &reduced_metric iou_hard
+    minimize_metric: False
+
+  data_params:
+    num_workers: {{ num_workers }}
+    batch_size: {{ batch_size }}
+    per_gpu_scaling: True
+    in_csv_train: {{ dataset_path }}/dataset_train.csv
+    in_csv_valid: {{ dataset_path }}/dataset_valid.csv
+    datapath: {{ dataset_path }}
+
+  transform_params:
+    _key_value: True
+
+    train:
+      transform: A.Compose
+      transforms:
+        - &pre_transforms
+          transform: A.Compose
+          transforms:
+            - transform: A.LongestMaxSize
+              max_size: *image_size
+            - transform: A.PadIfNeeded
+              min_height: *image_size
+              min_width: *image_size
+              border_mode: 0  # cv2.BORDER_CONSTANT
+              value: 0
+        - &hard_transforms
+          transform: A.Compose
+          transforms:
+            - transform: A.ShiftScaleRotate
+              shift_limit: 0.1
+              scale_limit: 0.1
+              rotate_limit: 15
+              border_mode: 2  # cv2.BORDER_REFLECT
+            - transform: A.OneOf
+              transforms:
+                - transform: A.HueSaturationValue
+                - transform: A.ToGray
+                - transform: A.RGBShift
+                - transform: A.ChannelShuffle
+            - transform: A.RandomBrightnessContrast
+              brightness_limit: 0.5
+              contrast_limit: 0.5
+            - transform: A.RandomGamma
+            - transform: A.CLAHE
+            - transform: A.ImageCompression
+              quality_lower: 50
+        - &post_transforms
+          transform: A.Compose
+          transforms:
+            - transform: A.Normalize
+            - transform: C.ToTensor
+    valid:
+      transform: A.Compose
+      transforms:
+        - *pre_transforms
+        - *post_transforms
+    infer:
+      transform: A.Compose
+      transforms:
+        - *pre_transforms
+        - *post_transforms
+
+  criterion_params:
+    _key_value: True
+
+    bce:
+      criterion: BCEWithLogitsLoss
+    dice:
+      criterion: DiceLoss
+    iou:
+      criterion: IoULoss
+
+  callbacks_params:
+    loss_bce:
+      callback: CriterionCallback
+      input_key: mask
+      output_key: logits
+      prefix: loss_bce
+      criterion_key: bce
+      multiplier: 1.0
+    loss_dice:
+      callback: CriterionCallback
+      input_key: mask
+      output_key: logits
+      prefix: loss_dice
+      criterion_key: dice
+      multiplier: 1.0
+    loss_iou:
+      callback: CriterionCallback
+      input_key: mask
+      output_key: logits
+      prefix: loss_iou
+      criterion_key: iou
+      multiplier: 1.0
+
+    loss_aggregator:
+      callback: MetricAggregationCallback
+      prefix: &aggregated_loss loss
+      metrics: [loss_bce, loss_dice, loss_iou]
+      mode: "mean"
+      multiplier: 1.0
+
+    raw_processor:
+      callback: RawMaskPostprocessingCallback
+    instance_extractor:
+      callback: InstanceMaskPostprocessingCallback
+      watershed_threshold: 0.9
+      mask_threshold: 0.8
+      output_key: instance_mask
+      out_key_semantic: semantic_mask
+      out_key_border: border_mask
+
+    iou_soft:
+      callback: IouCallback
+      input_key: mask
+      output_key: logits
+      prefix: iou_soft
+    iou_hard:
+      callback: IouCallback
+      input_key: mask
+      output_key: logits
+      prefix: iou_hard
+      threshold: 0.5
+
+    optimizer:
+      callback: OptimizerCallback
+      loss_key: *aggregated_loss
+    scheduler:
+      callback: SchedulerCallback
+      reduced_metric: *reduced_metric
+    saver:
+      callback: CheckpointCallback
+
+  # infer:
+  #
+  #   data_params:
+  #     num_workers: {{ num_workers }}
+  #     batch_size: {{ batch_size }}
+  #     per_gpu_scaling: True
+  #     in_csv: null
+  #     in_csv_train: null
+  #     in_csv_valid: {{ dataset_path }}/dataset_valid.csv
+  #     in_csv_infer: {{ dataset_path }}/dataset_train.csv
+  #     datapath: {{ dataset_path }}
+  #
+  #   callbacks_params:
+  #     loader:
+  #       callback: CheckpointCallback
+  #
+  #     raw_processor:
+  #       callback: RawMaskPostprocessingCallback
+  #     instance_extractor:
+  #       callback: InstanceMaskPostprocessingCallback
+  #       watershed_threshold: 0.9
+  #       mask_threshold: 0.8
+  #       output_key: instance_mask
+  #       out_key_semantic: semantic_mask
+  #       out_key_border: border_mask
+  #
+  #     image_saver:
+  #       callback: OriginalImageSaverCallback
+  #       output_dir: infer
+  #     saver_mask:
+  #       callback: OverlayMaskImageSaverCallback
+  #       output_dir: infer
+  #       filename_suffix: _01_raw_mask
+  #       output_key: mask
+  #     saver_semantic:
+  #       callback: OverlayMaskImageSaverCallback
+  #       output_dir: infer
+  #       filename_suffix: _02_semantic_mask
+  #       output_key: semantic_mask
+  #     saver_border:
+  #       callback: OverlayMaskImageSaverCallback
+  #       output_dir: infer
+  #       filename_suffix: _03_border_mask
+  #       output_key: border_mask
+  #     saver_instance:
+  #       callback: OverlayMaskImageSaverCallback
+  #       output_dir: infer
+  #       filename_suffix: _04_instance_mask
+  #       output_key: instance_mask
diff --git a/pics/wandb_metrics.png b/pics/wandb_metrics.png
deleted file mode 100644
index 0dfd0a1..0000000
Binary files a/pics/wandb_metrics.png and /dev/null differ
diff --git a/requirements/requirements-docker.txt b/requirements/requirements-docker.txt
index 7141304..d558d60 100644
--- a/requirements/requirements-docker.txt
+++ b/requirements/requirements-docker.txt
@@ -1,6 +1,8 @@
 albumentations==0.4.5
 jinja2
+opencv-python>=4.1.1
 pandas>=0.22
 safitty>=1.2.3
 segmentation-models-pytorch==0.1.0
+shapely[vectorized]==1.7.0
 tqdm>=4.33.0
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 2eb9541..c4f19eb 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,5 +1,7 @@
 catalyst[cv]==20.5
 jinja2
+opencv-python>=4.1.1
 pandas>=0.22
 safitty>=1.2.3
+shapely[vectorized]==1.7.0
 tqdm>=4.33.0
diff --git a/scripts/process_instance_masks.py b/scripts/process_instance_masks.py
new file mode 100644
index 0000000..c64c7b6
--- /dev/null
+++ b/scripts/process_instance_masks.py
@@ -0,0 +1,195 @@
+from typing import List  # isort:skip
+import argparse
+from multiprocessing.pool import Pool
+from pathlib import Path
+
+import numpy as np
+from skimage import measure, morphology
+
+from catalyst.utils import (
+    get_pool,
+    has_image_extension,
+    imread,
+    mimwrite_with_meta,
+    tqdm_parallel_imap,
+)
+
+
+def build_args(parser):
+    parser.add_argument(
+        "--in-dir", type=Path, required=True, help="Raw masks folder path"
+    )
+    parser.add_argument(
+        "--out-dir",
+        type=Path,
+        required=True,
+        help="Processed masks folder path",
+    )
+    parser.add_argument("--threshold", type=float, default=0.0)
+    parser.add_argument(
+        "--n-channels",
+        type=int,
+        choices={2, 3},
+        default=2,
+        help="Number of channels in output masks",
+    )
+    parser.add_argument(
+        "--num-workers",
+        default=1,
+        type=int,
+        help="Number of workers to parallel the processing",
+    )
+
+    return parser
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    build_args(parser)
+    args = parser.parse_args()
+    return args
+
+
+def mim_interaction(mim: List[np.ndarray], threshold: float = 0) -> np.ndarray:
+    result = np.zeros_like(mim[0], dtype=np.uint8)
+    result[np.stack(mim, axis=-1).max(axis=-1) > threshold] = 255
+    return result
+
+
+def mim_color_encode(
+    mim: List[np.ndarray], threshold: float = 0
+) -> np.ndarray:
+    result = np.zeros_like(mim[0], dtype=np.uint8)
+    for index, im in enumerate(mim, start=1):
+        result[im > threshold] = index
+
+    return result
+
+
+class Preprocessor:
+    def __init__(
+        self,
+        in_dir: Path,
+        out_dir: Path,
+        threshold: float = 0.0,
+        n_channels: int = 2,
+    ):
+        """
+        Args:
+            in_dir (Path): raw masks folder path, input folder structure
+                should be following:
+                    in_path  # dir with raw masks
+                    |-- sample_1
+                    |   |-- instance_1
+                    |   |-- instance_2
+                    |   ..
+                    |   `-- instance_N
+                    |-- sample_1
+                    |   |-- instance_1
+                    |   |-- instance_2
+                    |   ..
+                    |   `-- instance_K
+                    ..
+                    `-- sample_M
+                        |-- instance_1
+                        |-- instance_2
+                        ..
+                        `-- instance_Z
+            out_dir (Path): processed masks folder path, output folder
+                structure will be following:
+                    out_path
+                        |-- sample_1.tiff  # image of shape HxWxN
+                        |-- sample_2.tiff  # image of shape HxWxK
+                        ..
+                        `-- sample_M.tiff  # image of shape HxWxZ
+            threshold (float):
+            n_channels (int): number of channels in output masks,
+                see https://www.kaggle.com/c/data-science-bowl-2018/discussion/54741  # noqa: E501, W505
+        """
+        self.in_dir = in_dir
+        self.out_dir = out_dir
+        self.threshold = threshold
+        self.n_channels = n_channels
+
+    def preprocess(self, sample: Path):
+        masks = [
+            imread(filename, grayscale=True, expand_dims=False)
+            for filename in sample.iterdir()
+            if has_image_extension(str(filename))
+        ]
+        labels = mim_color_encode(masks, self.threshold)
+
+        scaled_blobs = morphology.dilation(labels > 0, morphology.square(9))
+        watersheded_blobs = (
+            morphology.watershed(
+                scaled_blobs, labels, mask=scaled_blobs, watershed_line=True
+            )
+            > 0
+        )
+        watershed_lines = scaled_blobs ^ (watersheded_blobs)
+        scaled_watershed_lines = morphology.dilation(
+            watershed_lines, morphology.square(7)
+        )
+
+        props = measure.regionprops(labels)
+        max_area = max(p.area for p in props)
+
+        mask_without_borders = mim_interaction(masks, self.threshold)
+        borders = np.zeros_like(labels, dtype=np.uint8)
+        for y0 in range(labels.shape[0]):
+            for x0 in range(labels.shape[1]):
+                if not scaled_watershed_lines[y0, x0]:
+                    continue
+
+                if labels[y0, x0] == 0:
+                    if max_area > 4000:
+                        sz = 6
+                    else:
+                        sz = 3
+                else:
+                    if props[labels[y0, x0] - 1].area < 300:
+                        sz = 1
+                    elif props[labels[y0, x0] - 1].area < 2000:
+                        sz = 2
+                    else:
+                        sz = 3
+
+                uniq = np.unique(
+                    labels[
+                        max(0, y0 - sz) : min(labels.shape[0], y0 + sz + 1),
+                        max(0, x0 - sz) : min(labels.shape[1], x0 + sz + 1),
+                    ]
+                )
+                if len(uniq[uniq > 0]) > 1:
+                    borders[y0, x0] = 255
+                    mask_without_borders[y0, x0] = 0
+
+        if self.n_channels == 2:
+            mask = [mask_without_borders, borders]
+        elif self.n_channels == 3:
+            background = 255 - (mask_without_borders + borders)
+            mask = [mask_without_borders, borders, background]
+        else:
+            raise ValueError()
+
+        mimwrite_with_meta(
+            self.out_dir / f"{sample.stem}.tiff", mask, {"compress": 9}
+        )
+
+    def process_all(self, pool: Pool):
+        images = list(self.in_dir.iterdir())
+        tqdm_parallel_imap(self.preprocess, images, pool)
+
+
+def main(args, _=None):
+    args = args.__dict__
+    args.pop("command", None)
+    num_workers = args.pop("num_workers")
+
+    with get_pool(num_workers) as p:
+        Preprocessor(**args).process_all(p)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/setup.cfg b/setup.cfg
index 59a3765..f7b3f2d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,7 +5,7 @@
 #  - python libs (known_third_party)
 #  - dl libs (known_dl)
 #  - catalyst imports
-known_third_party = imageio,jinja2,numpy,pandas,safitty,skimage
+known_third_party = cv2,imageio,jinja2,numpy,pandas,safitty,shapely,skimage
 known_dl = albumentations,torch,torchvision
 known_first_party = catalyst
 sections=STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER
diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py
index 62a7a82..a731fac 100644
--- a/src/callbacks/__init__.py
+++ b/src/callbacks/__init__.py
@@ -1,4 +1,12 @@
 # flake8: noqa
 
-from .io import OriginalImageSaverCallback, OverlayMaskImageSaverCallback
-from .processing import RawMaskPostprocessingCallback
+from .io import (
+    InstanceCropSaverCallback,
+    OriginalImageSaverCallback,
+    OverlayMaskImageSaverCallback,
+)
+from .metrics import SegmentationMeanAPCallback
+from .processing import (
+    InstanceMaskPostprocessingCallback,
+    RawMaskPostprocessingCallback,
+)
diff --git a/src/callbacks/io.py b/src/callbacks/io.py
index 41eba5b..f6bf962 100644
--- a/src/callbacks/io.py
+++ b/src/callbacks/io.py
@@ -3,9 +3,9 @@
 import imageio
 import numpy as np
 
-from catalyst.dl import Callback, CallbackOrder, State, utils
+from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils
 
-from .utils import mask_to_overlay_image
+from .utils import crop_by_masks, mask_to_overlay_image
 
 
 class OriginalImageSaverCallback(Callback):
@@ -18,7 +18,7 @@ def __init__(
         input_key: str = "image",
         outpath_key: str = "name",
     ):
-        super().__init__(CallbackOrder.Logging)
+        super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master)
         self.output_dir = Path(output_dir)
         self.relative = relative
         self.filename_suffix = filename_suffix
@@ -83,4 +83,43 @@ def on_batch_end(self, state: State):
             imageio.imwrite(fname, image)
 
 
-__all__ = ["OriginalImageSaverCallback", "OverlayMaskImageSaverCallback"]
+class InstanceCropSaverCallback(OriginalImageSaverCallback):
+    def __init__(
+        self,
+        output_dir: str,
+        relative: bool = True,
+        filename_extension: str = ".jpg",
+        input_key: str = "image",
+        output_key: str = "mask",
+        outpath_key: str = "name",
+    ):
+        super().__init__(
+            output_dir=output_dir,
+            relative=relative,
+            filename_suffix=filename_extension,
+            input_key=input_key,
+            outpath_key=outpath_key,
+        )
+        self.output_key = output_key
+
+    def on_batch_end(self, state: State):
+        names = state.batch_in[self.outpath_key]
+        images = state.batch_in[self.input_key]
+        masks = state.batch_out[self.output_key]
+
+        images = utils.tensor_to_ndimage(images.detach().cpu())
+        for name, image, masks_ in zip(names, images, masks):
+            instances = crop_by_masks(image, masks_)
+
+            for index, crop in enumerate(instances):
+                filename = self.get_image_path(
+                    state, name, suffix=f"_instance{index:02d}"
+                )
+                imageio.imwrite(filename, crop)
+
+
+__all__ = [
+    "OriginalImageSaverCallback",
+    "OverlayMaskImageSaverCallback",
+    "InstanceCropSaverCallback",
+]
diff --git a/src/callbacks/metrics.py b/src/callbacks/metrics.py
new file mode 100644
index 0000000..aad6d04
--- /dev/null
+++ b/src/callbacks/metrics.py
@@ -0,0 +1,83 @@
+import numpy as np
+
+from catalyst.dl import MetricCallback
+
+
+def compute_ious_single_image(predicted_mask, gt_instance_masks):
+    instance_ids = np.unique(predicted_mask)
+    n_gt_instaces = gt_instance_masks.shape[0]
+
+    all_ious = []
+
+    for id_ in instance_ids:
+        if id_ == 0:
+            # Skip background
+            continue
+
+        predicted_instance_mask = predicted_mask == id_
+
+        sum_ = predicted_instance_mask.reshape(
+            1, -1
+        ) + gt_instance_masks.reshape(n_gt_instaces, -1)
+
+        intersection = (sum_ == 2).sum(axis=1)
+        union = (sum_ > 0).sum(axis=1)
+
+        ious = intersection / union
+
+        all_ious.append(ious)
+
+    all_ious = np.array(all_ious).reshape((len(all_ious), n_gt_instaces))
+
+    return all_ious
+
+
+def map_from_ious(ious: np.ndarray, iou_thresholds: np.ndarray):
+    """
+    Args:
+        ious (np.ndarray): array of shape n_pred x n_gt
+        iou_thresholds (np.ndarray):
+    """
+    n_preds = ious.shape[0]
+
+    fn_at_ious = (
+        np.max(ious, axis=0, initial=0)[None, :] < iou_thresholds[:, None]
+    )
+    fn_at_iou = np.sum(fn_at_ious, axis=1, initial=0)
+
+    tp_at_ious = (
+        np.max(ious, axis=0, initial=0)[None, :] > iou_thresholds[:, None]
+    )
+    tp_at_iou = np.sum(tp_at_ious, axis=1, initial=0)
+
+    metric_at_iou = tp_at_iou / (n_preds + fn_at_iou)
+
+    return metric_at_iou.mean()
+
+
+def mean_average_precision(outputs, targets, iou_thresholds):
+    batch_metrics = []
+    for pred, gt in zip(outputs, targets):
+        ious = compute_ious_single_image(pred, gt.numpy())
+        batch_metrics.append(map_from_ious(ious, iou_thresholds))
+    return float(np.mean(batch_metrics))
+
+
+class SegmentationMeanAPCallback(MetricCallback):
+    def __init__(
+        self,
+        input_key: str = "imasks",
+        output_key: str = "instance_mask",
+        prefix: str = "mAP",
+        iou_thresholds=(0.5, 0.55, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95),
+    ):
+        super().__init__(
+            prefix=prefix,
+            metric_fn=mean_average_precision,
+            input_key=input_key,
+            output_key=output_key,
+            iou_thresholds=np.array(iou_thresholds),
+        )
+
+
+__all__ = ["SegmentationMeanAPCallback"]
diff --git a/src/callbacks/processing.py b/src/callbacks/processing.py
index 7152850..7d9faed 100644
--- a/src/callbacks/processing.py
+++ b/src/callbacks/processing.py
@@ -1,8 +1,8 @@
 import torch
 
-from catalyst.dl import Callback, CallbackOrder, State
+from catalyst.dl import Callback, CallbackNode, CallbackOrder, State
 
-from .utils import encode_mask_with_color
+from .utils import encode_mask_with_color, label_instances
 
 
 class RawMaskPostprocessingCallback(Callback):
@@ -12,7 +12,7 @@ def __init__(
         input_key: str = "logits",
         output_key: str = "mask",
     ):
-        super().__init__(CallbackOrder.Internal)
+        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.All)
         self.threshold = threshold
         self.input_key = input_key
         self.output_key = output_key
@@ -22,8 +22,54 @@ def on_batch_end(self, state: State):
 
         output = torch.sigmoid(output).detach().cpu().numpy()
         state.batch_out[self.output_key] = encode_mask_with_color(
-            output, self.threshold
+            output, threshold=self.threshold
         )
 
 
-__all__ = ["RawMaskPostprocessingCallback"]
+class InstanceMaskPostprocessingCallback(Callback):
+    def __init__(
+        self,
+        watershed_threshold: float = 0.5,
+        mask_threshold: float = 0.5,
+        input_key: str = "logits",
+        output_key: str = "instance_mask",
+        out_key_semantic: str = None,
+        out_key_border: str = None,
+    ):
+        super().__init__(CallbackOrder.Internal, node=CallbackNode.All)
+        self.watershed_threshold = watershed_threshold
+        self.mask_threshold = mask_threshold
+        self.input_key = input_key
+        self.output_key = output_key
+        self.out_key_semantic = out_key_semantic
+        self.out_key_border = out_key_border
+
+    def on_batch_end(self, state: State):
+        output = state.batch_out[self.input_key]
+
+        output = torch.sigmoid(output).detach().cpu()
+        semantic, border = output.chunk(2, -3)
+
+        if self.out_key_semantic is not None:
+            state.batch_out[self.out_key_semantic] = encode_mask_with_color(
+                semantic.numpy(), threshold=self.mask_threshold
+            )
+
+        if self.out_key_border is not None:
+            state.batch_out[self.out_key_border] = (
+                border.squeeze(-3).numpy() > self.watershed_threshold
+            )
+
+        state.batch_out[self.output_key] = label_instances(
+            semantic,
+            border,
+            watershed_threshold=self.watershed_threshold,
+            instance_mask_threshold=self.mask_threshold,
+            downscale_factor=1,
+        )
+
+
+__all__ = [
+    "RawMaskPostprocessingCallback",
+    "InstanceMaskPostprocessingCallback",
+]
diff --git a/src/callbacks/utils.py b/src/callbacks/utils.py
index 6002c37..782ef60 100644
--- a/src/callbacks/utils.py
+++ b/src/callbacks/utils.py
@@ -1,22 +1,29 @@
-from typing import List  # isort:skip
-
+from typing import List, Tuple, Union  # isort:skip
+import cv2
 import numpy as np
+from shapely.geometry import LinearRing, MultiPoint
 from skimage.color import label2rgb
+from skimage.measure import label, regionprops
+from skimage.morphology import watershed
 
 import torch
+import torch.nn.functional as F
+
+# types
+Point = Tuple[int, int]
+Quadrangle = Tuple[Point, Point, Point, Point]
 
 
 def encode_mask_with_color(
     semantic_masks: torch.Tensor, threshold: float = 0.5
 ) -> List[np.ndarray]:
     """
-
     Args:
         semantic_masks (torch.Tensor): semantic mask batch tensor
         threshold (float): threshold for semantic masks
+
     Returns:
         List[np.ndarray]: list of semantic masks
-
     """
     batch = []
     for observation in semantic_masks:
@@ -38,3 +45,159 @@ def mask_to_overlay_image(
         (image_with_overlay * 255).clip(0, 255).round().astype(np.uint8)
     )
     return image_with_overlay
+
+
+def label_instances(
+    semantic_masks: torch.Tensor,
+    border_masks: torch.Tensor,
+    watershed_threshold: float = 0.9,
+    instance_mask_threshold: float = 0.5,
+    downscale_factor: float = 4,
+    interpolation: str = "bilinear",
+) -> List[np.ndarray]:
+    """
+    Args:
+        semantic_masks (torch.Tensor): semantic mask batch tensor
+        border_masks (torch.Tensor):  instance mask batch tensor
+        watershed_threshold (float): threshold for watershed markers
+        instance_mask_threshold (float): threshold for final instance masks
+        downscale_factor (float): mask downscaling factor
+            (to speed up processing)
+        interpolation (str): interpolation method
+
+    Returns:
+        List[np.ndarray]: list of labeled instance masks, one per batch item
+    """
+    bordered_masks = (semantic_masks - border_masks).clamp(min=0)
+
+    scaling = 1 / downscale_factor
+    semantic_masks, bordered_masks = (
+        F.interpolate(
+            mask.data.cpu(),
+            scale_factor=scaling,
+            mode=interpolation,
+            align_corners=False,
+        )
+        .squeeze(-3)
+        .numpy()
+        for mask in (semantic_masks, bordered_masks)
+    )
+
+    result: List[np.ndarray] = []
+    for semantic, bordered in zip(semantic_masks, bordered_masks):
+        watershed_marks = label(bordered > watershed_threshold, background=0)
+        instance_regions = watershed(-bordered, watershed_marks)
+
+        instance_regions[semantic < instance_mask_threshold] = 0
+
+        result.append(instance_regions)
+
+    return result
+
+
+def _is_ccw(vertices: np.ndarray):
+    return LinearRing(vertices * [[1, -1]]).is_ccw
+
+
+def get_rects_from_mask(
+    label_mask: np.ndarray, min_area_fraction=20
+) -> np.ndarray:
+    props = regionprops(label_mask)
+
+    total_h, total_w = label_mask.shape
+    total_area = total_h * total_w
+
+    result = []
+    for p in props:
+
+        if p.area / total_area < min_area_fraction:
+            continue
+
+        coords = p.coords
+        coords = coords[:, ::-1]  # row, col -> col, row
+
+        rect = MultiPoint(coords).minimum_rotated_rectangle.exterior.coords
+
+        rect = np.array(rect)[:4].astype(np.int32)
+
+        if _is_ccw(rect):
+            rect = rect[::-1]
+
+        result.append(rect)
+
+    result = np.stack(result) if result else []
+
+    return result
+
+
+def perspective_crop(
+    image: np.ndarray,
+    crop_coords: Union[Quadrangle, np.ndarray],
+    output_wh: Tuple[int, int],
+    border_color: Tuple[int, int, int] = (255, 255, 255),
+):
+    width, height = output_wh
+    target_coords = ((0, 0), (width, 0), (width, height), (0, height))
+
+    transform_matrix = cv2.getPerspectiveTransform(
+        np.array(crop_coords, dtype=np.float32),
+        np.array(target_coords, dtype=np.float32),
+    )
+
+    result = cv2.warpPerspective(
+        image,
+        transform_matrix,
+        (width, height),
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(border_color),
+    )
+
+    return result
+
+
+def perspective_crop_keep_ratio(
+    image: np.ndarray,
+    vertices: np.ndarray,
+    output_size: int = -1,
+    border_color: Tuple[int, int, int] = (255, 255, 255),
+) -> np.ndarray:
+    """
+    Crop some quadrilateral from image keeping it's aspect ratio
+
+    Args:
+        image (np.ndarray): image numpy array
+        vertices (np.ndarray): numpy array with quadrilateral vertices coords
+        output_size (int): minimal side length of output image
+            (if -1 will be actual side of image)
+        border_color (Tuple[int, int, int]):
+
+    Returns:
+        np.ndarray: image crop
+    """
+    lenghts = np.linalg.norm(vertices - np.roll(vertices, -1, 0), axis=1)
+
+    len_ab, len_bc, len_cd, len_da = lenghts.tolist()
+
+    width = (len_ab + len_cd) / 2
+    height = (len_bc + len_da) / 2
+
+    if output_size > 0:
+        scale = output_size / max(width, height)
+        width, height = (dim * scale for dim in (width, height))
+
+    width, height = round(width), round(height)
+
+    crop = perspective_crop(image, vertices, (width, height), border_color)
+
+    return crop
+
+
+def crop_by_masks(
+    image: np.ndarray, mask: np.ndarray, image_size: int = 512
+) -> List[np.ndarray]:
+    crops = [
+        perspective_crop_keep_ratio(image, rect, image_size)
+        for rect in get_rects_from_mask(mask)
+    ]
+
+    return crops
diff --git a/teamcity/binary.sh b/teamcity/binary.sh
index ce3db40..8a84e7b 100644
--- a/teamcity/binary.sh
+++ b/teamcity/binary.sh
@@ -8,4 +8,4 @@ pip install -r requirements/requirements.txt
 # @TODO: fix server issue
 pip install torch==1.4.0 torchvision==0.5.0
 
-bash ./bin/tests/_check_binary.sh
+bash ./bin/tests/check_binary.sh
diff --git a/teamcity/instance.sh b/teamcity/instance.sh
new file mode 100644
index 0000000..c1b2466
--- /dev/null
+++ b/teamcity/instance.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+# Cause the script to exit if a single command fails
+set -eo pipefail -v
+
+pip install -r requirements/requirements.txt
+
+bash ./bin/tests/check_instance.sh
diff --git a/teamcity/semantic.sh b/teamcity/semantic.sh
index 54b8635..e63a6c5 100644
--- a/teamcity/semantic.sh
+++ b/teamcity/semantic.sh
@@ -8,4 +8,4 @@ pip install -r requirements/requirements.txt
 # @TODO: fix server issue
 pip install torch==1.4.0 torchvision==0.5.0
 
-bash ./bin/tests/_check_semantic.sh
+bash ./bin/tests/check_semantic.sh