Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
checkpoints.zip
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# [CVPR 2020] Instance-aware Image Colorization
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)

### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)]
### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)] [[Run on Replicate](https://replicate.ai/ericsujw/instcolorization)]

<p align='center'>
<img src='imgs/teaser.png' width=1000>
Expand Down
45 changes: 45 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Cog configuration: build environment and prediction entry point.
build:
  python_version: "3.8"
  # System libraries installed into the image.
  # NOTE(review): presumably needed by opencv-python at import time — confirm.
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  # Python dependencies, all pinned to exact versions.
  python_packages:
    - "cachetools==4.1.0"
    - "chardet==3.0.4"
    - "future==0.18.2"
    - "fvcore==0.1.dev200506"
    - "idna==2.9"
    - "importlib-metadata==1.6.0"
    - "jsonpatch==1.25"
    - "jsonpointer==2.0"
    - "markdown==3.2.2"
    - "mock==4.0.2"
    - "opencv-python==4.3.0.38"
    - "portalocker==1.7.0"
    - "pyasn1==0.4.8"
    - "pyasn1-modules==0.2.8"
    - "pydot==1.4.1"
    - "requests==2.23.0"
    - "requests-oauthlib==1.3.0"
    - "rsa==4.0"
    - "tabulate==0.8.7"
    - "termcolor==1.1.0"
    - "urllib3==1.25.8"
    - "visdom==0.1.8.9"
    - "websocket-client==0.57.0"
    - "yacs==0.1.7"
    - "zipp==3.1.0"
    - "cython==0.29.22"
    - "pyyaml==5.1"
    - "dominate==2.4.0"
    - "detectron2==0.1.2"
    - "torch==1.5.0"
    - "torchvision==0.6.0"
    - "pycocotools==2.0.2"
    - "ipython==7.21.0"
    - "scikit-image==0.18.1"
  # NOTE(review): this entry is a VCS requirement ("git+https://..."), not a
  # package index URL, and pycocotools is already pinned above — confirm it is
  # needed and belongs under this key. The host is not github.com; verify it.
  python_extra_index_urls:
    - "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
  # Additional place to look for wheels: the detectron2 CPU wheel listing.
  python_find_links:
    - "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/index.html"
# Prediction entry point: class InstColorizationPredictor in predict.py.
predict: "predict.py:InstColorizationPredictor"
20 changes: 19 additions & 1 deletion download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from argparse import ArgumentParser


def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"

Expand Down Expand Up @@ -45,12 +46,29 @@ def save_response_content(response, destination):
destination = 'checkpoints.zip'
download_file_from_google_drive(file_id, destination)


elif args.mode == 'cocostuff':
print('download cocostuff training dataset')
url = "http://images.cocodataset.org/zips/train2017.zip"
response = requests.get(url, stream = True)
if isdir(join(args.dataset_dir, "cocostuff")) is False:
os.makedirs(join(args.dataset_dir, "cocostuff"))
save_response_content(response, join(args.dataset_dir, "cocostuff", "train.zip"))

elif args.mode == 'coco-weights':
os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache"

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

# Download coco weights
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
cfg.MODEL.DEVICE = "cpu"
predictor = DefaultPredictor(cfg)

else:
print('Error Mode!')
print('Error Mode!')
33 changes: 20 additions & 13 deletions models/fusion_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import torch
from torch import nn
from collections import OrderedDict
from util.image_pool import ImagePool
from util import util
Expand Down Expand Up @@ -29,34 +30,37 @@ def initialize(self, opt):

# load/define networks
num_in = opt.input_nc + opt.output_nc + 1

self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=False)
self.netG.eval()

self.netG = nn.DataParallel(self.netG)

self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf,
'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=False)
self.netGF.eval()
self.netGF = nn.DataParallel(self.netGF)

self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf,
'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=opt.classification)
self.netGComp.eval()
self.netGComp = nn.DataParallel(self.netGComp)


def set_input(self, input):
    """Store one batch on ``self.device`` and derive the dependent tensors.

    ``self.opt.which_direction`` decides which of the ``'A'``/``'B'`` entries
    of ``input`` becomes ``real_A`` and which becomes ``real_B``. The hint and
    mask entries are moved to the device as well, ``mask_B_nc`` is the mask
    shifted by ``opt.mask_cent``, and ``real_B_enc`` is computed by
    ``util.encode_ab_ind`` from ``real_B`` sampled with stride 4 along its
    last two axes.
    """
    if self.opt.which_direction == 'AtoB':
        key_for_A, key_for_B = 'A', 'B'
    else:
        key_for_A, key_for_B = 'B', 'A'

    self.real_A = input[key_for_A].to(self.device)
    self.real_B = input[key_for_B].to(self.device)
    self.hint_B = input['hint_B'].to(self.device)

    self.mask_B = input['mask_B'].to(self.device)
    mask_offset = self.opt.mask_cent
    self.mask_B_nc = self.mask_B + mask_offset

    strided_real_B = self.real_B[:, :, ::4, ::4]
    self.real_B_enc = util.encode_ab_ind(strided_real_B, self.opt)

def set_fusion_input(self, input, box_info):
AtoB = self.opt.which_direction == 'AtoB'
self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
Expand Down Expand Up @@ -85,29 +89,32 @@ def set_forward_without_box(self, input):
def forward(self):
    """Run the instance network, then the fusion network.

    ``netG`` must return a 2-tuple; only its second element (the feature map)
    is used. The fusion network receives the full-image tensors, that feature
    map and ``self.box_info_list``; its output is stored in ``fake_B_reg``.
    """
    instance_result = self.netG(self.real_A, self.hint_B, self.mask_B)
    _unused, feature_map = instance_result
    fused = self.netGF(
        self.full_real_A,
        self.full_hint_B,
        self.full_mask_B,
        feature_map,
        self.box_info_list,
    )
    self.fake_B_reg = fused

def save_current_imgs(self, path):
out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0)

def save_current_imgs(self, path, is_cuda=True):
    """Write the current result to ``path`` as an image.

    ``full_real_A`` and ``fake_B_reg`` are cast to a float tensor type (the
    CUDA one when ``is_cuda`` is true, the CPU one otherwise), concatenated
    along dim 1, passed through ``util.lab2rgb`` and clamped to [0, 1]. The
    first item of the batch is then saved with ``io.imsave``.
    """
    float_type = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
    stacked = torch.cat(
        (self.full_real_A.type(float_type), self.fake_B_reg.type(float_type)),
        dim=1,
    )
    out_img = torch.clamp(util.lab2rgb(stacked, self.opt), 0.0, 1.0)
    # Axis 0 of the first batch item is moved last before saving
    # (presumably channels-first -> channels-last; confirm against lab2rgb).
    first_item = out_img.cpu().data.numpy()[0]
    out_img = np.transpose(first_item, (1, 2, 0))
    io.imsave(path, img_as_ubyte(out_img))

def setup_to_test(self, fusion_weight_path):
def setup_to_test(self, fusion_weight_path, map_location):
GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
print('load Fusion model from %s' % GF_path)
GF_state_dict = torch.load(GF_path)
GF_state_dict = torch.load(GF_path, map_location=map_location)

# G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
G_state_dict = torch.load(G_path)
G_state_dict = torch.load(G_path, map_location=map_location)

# GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
# GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
GComp_state_dict = torch.load(GComp_path)
GComp_state_dict = torch.load(GComp_path, map_location=map_location)

self.netGF.load_state_dict(GF_state_dict, strict=False)
self.netG.module.load_state_dict(G_state_dict, strict=False)
self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
self.netGF.eval()
self.netG.eval()
self.netGComp.eval()
self.netGComp.eval()
190 changes: 190 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import argparse
import multiprocessing
import os
import shutil
import tempfile
from glob import glob
from os import listdir
from os.path import isfile, join
from pathlib import Path

os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache"

import cog
import cv2
import numpy as np
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.logger import setup_logger

from fusion_dataset import Fusion_Testing_Dataset
from models import create_model
from options.train_options import TestOptions
from util import util


class InstColorizationPredictor(cog.Predictor):
    """Cog predictor that colorizes one grayscale image per request.

    ``setup`` builds a detectron2 ``DefaultPredictor`` (used to obtain bounding
    boxes) and the fusion colorization model. ``predict`` stages the uploaded
    image in the working folders, runs both models and returns the path of the
    colorized PNG.
    """

    def setup(self):
        """Create the detector, the parsed options and the fusion model.

        Sets ``has_gpu``, ``device``, ``predictor``, ``save_img_dir``,
        ``output_npz_dir``, ``model`` and ``opt`` on the instance and creates
        the input, result and bounding-box folders if they are missing.
        """
        self.has_gpu = torch.cuda.is_available()

        setup_logger()

        detector_config = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(detector_config))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(detector_config)

        if self.has_gpu:
            self.device = torch.device("cuda")
        else:
            # Without CUDA the detector has to be told explicitly to use the CPU.
            cfg.MODEL.DEVICE = "cpu"
            self.device = torch.device("cpu")

        self.predictor = DefaultPredictor(cfg)

        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        multiprocessing.set_start_method('spawn', True)

        torch.backends.cudnn.benchmark = True

        parser = argparse.ArgumentParser()
        TestOptions().initialize(parser)
        opt = parser.parse_args(["--name", "test_fusion", "--sample_p", "1.0", "--model", "fusion", "--fineSize", "256", "--test_img_dir", "inputs", "--results_img_dir", "results"])

        # Keep only non-negative GPU ids, and none at all on a CPU-only host.
        parsed_gpu_ids = [int(str_id) for str_id in opt.gpu_ids.split(',')]
        if self.has_gpu:
            opt.gpu_ids = [gpu_id for gpu_id in parsed_gpu_ids if gpu_id >= 0]
        else:
            opt.gpu_ids = []

        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])
        opt.A = 2 * opt.ab_max / opt.ab_quant + 1
        opt.B = opt.A
        opt.isTrain = False

        os.makedirs(opt.test_img_dir, exist_ok=True)

        self.save_img_dir = opt.results_img_dir
        if not os.path.isdir(self.save_img_dir):
            print('Create path: {0}'.format(self.save_img_dir))
            os.makedirs(self.save_img_dir)
        self.output_npz_dir = "{0}_bbox".format(opt.test_img_dir)
        os.makedirs(self.output_npz_dir, exist_ok=True)

        opt.batch_size = 1

        self.model = create_model(opt)
        self.model.setup_to_test('coco_finetuned_mask_256_ffs', map_location=self.device)

        self.opt = opt

    @cog.input("input", type=Path, help="grayscale input image")
    def predict(self, input):
        """Colorize ``input`` and return the path of the resulting PNG.

        The image is copied into ``opt.test_img_dir``, bounding boxes are
        computed for it, the fusion model produces a colorized image, and that
        image is resized to the input's size with the input's lightness channel
        copied back in. The working folders are emptied afterwards, whether or
        not the run succeeded.

        Raises:
            RuntimeError: if the dataset yielded no image to colorize.
        """
        output_dir = tempfile.mkdtemp()
        # Use only the file stem: joining the full (possibly absolute) input
        # path would make os.path.join drop output_dir and write the result
        # next to, or on top of, the input file.
        color_output_path = Path(output_dir) / (Path(str(input)).stem + ".png")

        try:
            input_dir = self.opt.test_img_dir
            staged_path = os.path.join(input_dir, os.path.basename(input))
            shutil.copy(str(input), staged_path)
            image_list = [f for f in listdir(input_dir) if isfile(join(input_dir, f))]

            self.get_bounding_boxes(input_dir, image_list)

            dataset = Fusion_Testing_Dataset(self.opt)
            dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)

            # Colorize
            output_paths = []
            input_paths = []
            with torch.no_grad():
                for data_raw in dataset_loader:
                    data_raw['full_img'][0] = data_raw['full_img'][0].to(self.device)
                    if data_raw['empty_box'][0] == 0:
                        data_raw['cropped_img'][0] = data_raw['cropped_img'][0].to(self.device)
                        box_info = data_raw['box_info'][0]
                        box_info_2x = data_raw['box_info_2x'][0]
                        box_info_4x = data_raw['box_info_4x'][0]
                        box_info_8x = data_raw['box_info_8x'][0]
                        cropped_data = util.get_colorization_data(data_raw['cropped_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
                        full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
                        self.model.set_input(cropped_data)
                        self.model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
                        self.model.forward()
                    else:
                        # No boxes for this image: colorize the full image only.
                        full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
                        self.model.set_forward_without_box(full_img_data)
                    output_path = join(self.save_img_dir, data_raw['file_id'][0] + '.png')
                    self.model.save_current_imgs(output_path, is_cuda=self.has_gpu)
                    output_paths.append(output_path)

                    source_path = glob(input_dir + "/" + data_raw["file_id"][0] + ".*")[0]
                    input_paths.append(source_path)

            if not output_paths:
                raise RuntimeError("No image was colorized for input: {0}".format(input))

            # Resize
            if len(input_paths) > 1 or len(output_paths) > 1:
                print(f"WARNING: len(input_paths): {len(input_paths)}, len(output_paths): {len(output_paths)}")

            for source_path, output_path in zip(input_paths, output_paths):
                input_img = cv2.imread(source_path)
                height, width, _ = input_img.shape
                output_img = cv2.imread(output_path)
                output_img = cv2.resize(output_img, (width, height))
                input_hls = cv2.cvtColor(input_img, cv2.COLOR_BGR2HLS)
                output_hls = cv2.cvtColor(output_img, cv2.COLOR_BGR2HLS)
                # Keep the colorized hue/saturation but restore the input's
                # lightness channel (index 1 in HLS).
                output_hls[:, :, 1] = input_hls[:, :, 1]
                output_bgr = cv2.cvtColor(output_hls, cv2.COLOR_HLS2BGR)

                cv2.imwrite(str(color_output_path), output_bgr, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        finally:
            self.cleanup()

        return color_output_path

    def get_bounding_boxes(self, input_dir, image_list):
        """Run the detector on each image and save its boxes and scores.

        For every file name in ``image_list`` (relative to ``input_dir``) the
        L channel of the image is replicated to three channels and passed to
        ``self.predictor``. The predicted boxes and scores are written with
        ``np.savez`` (arrays ``bbox`` and ``scores``) into
        ``self.output_npz_dir``, named after the part of the file name before
        its first dot.

        Raises:
            ValueError: if an image cannot be read.
        """
        for image_name in image_list:
            image_file = join(input_dir, image_name)
            img = cv2.imread(image_file)
            if img is None:
                # cv2.imread signals failure by returning None.
                raise ValueError('Could not read image: {0}'.format(image_file))
            lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
            l_channel, _, _ = cv2.split(lab_image)
            l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
            outputs = self.predictor(l_stack)
            save_path = join(self.output_npz_dir, image_name.split('.')[0])
            pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
            pred_scores = outputs["instances"].scores.cpu().data.numpy()
            np.savez(save_path, bbox=pred_bbox, scores=pred_scores)

    def cleanup(self):
        """Empty the input, bounding-box and result working folders."""
        clean_folder(self.opt.test_img_dir)
        clean_folder(self.output_npz_dir)
        clean_folder(self.save_img_dir)

def clean_folder(folder):
    """Delete everything inside ``folder`` while keeping the folder itself.

    Files and symbolic links are unlinked; directories are removed
    recursively. A failure on one entry is printed and does not stop the
    remaining entries from being processed.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            removable_as_file = os.path.isfile(target) or os.path.islink(target)
            if removable_as_file:
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print('Failed to delete %s. Reason: %s' % (target, err))
3 changes: 2 additions & 1 deletion scripts/download_model.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
echo "Downloading..."
python download.py
python download.py --mode coco-weights
echo "Finish download."
unzip checkpoints.zip
unzip checkpoints.zip