diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e883e81
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+checkpoints.zip
diff --git a/README.md b/README.md
index 938fa20..e19ec40 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 # [CVPR 2020] Instance-aware Image Colorization
 [](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)
-### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)]
+### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)] [[Run on Replicate](https://replicate.ai/ericsujw/instcolorization)]
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..1333f23
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,45 @@
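+# Cog build configuration: pins the Python environment and dependencies;
+# the predict entry point at the bottom tells Cog which class serves predictions.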
+build:
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "cachetools==4.1.0"
+ - "chardet==3.0.4"
+ - "future==0.18.2"
+ - "fvcore==0.1.dev200506"
+ - "idna==2.9"
+ - "importlib-metadata==1.6.0"
+ - "jsonpatch==1.25"
+ - "jsonpointer==2.0"
+ - "markdown==3.2.2"
+ - "mock==4.0.2"
+ - "opencv-python==4.3.0.38"
+ - "portalocker==1.7.0"
+ - "pyasn1==0.4.8"
+ - "pyasn1-modules==0.2.8"
+ - "pydot==1.4.1"
+ - "requests==2.23.0"
+ - "requests-oauthlib==1.3.0"
+ - "rsa==4.0"
+ - "tabulate==0.8.7"
+ - "termcolor==1.1.0"
+ - "urllib3==1.25.8"
+ - "visdom==0.1.8.9"
+ - "websocket-client==0.57.0"
+ - "yacs==0.1.7"
+ - "zipp==3.1.0"
+ - "cython==0.29.22"
+ - "pyyaml==5.1"
+ - "dominate==2.4.0"
+ - "detectron2==0.1.2"
+ - "torch==1.5.0"
+ - "torchvision==0.6.0"
+ - "pycocotools==2.0.2"
+ - "ipython==7.21.0"
+ - "scikit-image==0.18.1"
+ python_extra_index_urls:
+ - "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
+ python_find_links:
+ - "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/index.html"
+predict: "predict.py:InstColorizationPredictor"
diff --git a/download.py b/download.py
index 85ae6a1..794c7ab 100644
--- a/download.py
+++ b/download.py
@@ -4,6 +4,7 @@
import os
from argparse import ArgumentParser
+
def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"
@@ -45,6 +46,7 @@ def save_response_content(response, destination):
destination = 'checkpoints.zip'
download_file_from_google_drive(file_id, destination)
+
elif args.mode == 'cocostuff':
print('download cocostuff training dataset')
url = "http://images.cocodataset.org/zips/train2017.zip"
@@ -52,5 +54,21 @@ def save_response_content(response, destination):
if isdir(join(args.dataset_dir, "cocostuff")) is False:
os.makedirs(join(args.dataset_dir, "cocostuff"))
save_response_content(response, join(args.dataset_dir, "cocostuff", "train.zip"))
+
+elif args.mode == 'coco-weights':
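+ # Redirect fvcore's weight cache into checkpoints/ so the detectron2 weights
+ # live alongside the colorization checkpoints and can be bundled into the image.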
+ os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache"
+
+ from detectron2 import model_zoo
+ from detectron2.config import get_cfg
+ from detectron2.engine import DefaultPredictor
+
+ # Download coco weights
+ cfg = get_cfg()
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
+ cfg.MODEL.DEVICE = "cpu"
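+ # Constructing the predictor forces the Mask R-CNN weights to be downloaded
+ # into the cache set above; the predictor itself is discarded.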
+ predictor = DefaultPredictor(cfg)
+
else:
- print('Error Mode!')
\ No newline at end of file
+ print('Error Mode!')
diff --git a/models/fusion_model.py b/models/fusion_model.py
index 0f6ef6b..b60d2ed 100644
--- a/models/fusion_model.py
+++ b/models/fusion_model.py
@@ -1,6 +1,7 @@
import os
import torch
+from torch import nn
from collections import OrderedDict
from util.image_pool import ImagePool
from util import util
@@ -29,21 +30,24 @@ def initialize(self, opt):
# load/define networks
num_in = opt.input_nc + opt.output_nc + 1
-
+
self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=False)
self.netG.eval()
-
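+ # Wrap the nets in DataParallel so netG.module / netGComp.module exist even on
+ # CPU-only machines (setup_to_test below loads state dicts into .module).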
+ self.netG = nn.DataParallel(self.netG)
+
self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf,
'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=False)
self.netGF.eval()
+ self.netGF = nn.DataParallel(self.netGF)
self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf,
'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
use_tanh=True, classification=opt.classification)
self.netGComp.eval()
+ self.netGComp = nn.DataParallel(self.netGComp)
def set_input(self, input):
@@ -51,12 +55,12 @@ def set_input(self, input):
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.hint_B = input['hint_B'].to(self.device)
-
+
self.mask_B = input['mask_B'].to(self.device)
self.mask_B_nc = self.mask_B + self.opt.mask_cent
self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
-
+
def set_fusion_input(self, input, box_info):
AtoB = self.opt.which_direction == 'AtoB'
self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
@@ -85,29 +89,32 @@ def set_forward_without_box(self, input):
def forward(self):
(_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B)
self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list)
-
- def save_current_imgs(self, path):
- out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0)
+
+ def save_current_imgs(self, path, is_cuda=True):
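+ # On CPU-only hosts the tensors cannot be cast to torch.cuda.FloatTensor,
+ # so the caller passes is_cuda to pick the matching tensor type.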
+ if is_cuda:
+ out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0)
+ else:
+ out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.FloatTensor), self.fake_B_reg.type(torch.FloatTensor)), dim=1), self.opt), 0.0, 1.0)
out_img = np.transpose(out_img.cpu().data.numpy()[0], (1, 2, 0))
io.imsave(path, img_as_ubyte(out_img))
- def setup_to_test(self, fusion_weight_path):
+ def setup_to_test(self, fusion_weight_path, map_location):
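+ # map_location lets checkpoints saved on a CUDA device load on CPU-only hosts.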
GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
print('load Fusion model from %s' % GF_path)
- GF_state_dict = torch.load(GF_path)
-
+ GF_state_dict = torch.load(GF_path, map_location=map_location)
+
# G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
- G_state_dict = torch.load(G_path)
+ G_state_dict = torch.load(G_path, map_location=map_location)
# GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
# GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
- GComp_state_dict = torch.load(GComp_path)
+ GComp_state_dict = torch.load(GComp_path, map_location=map_location)
self.netGF.load_state_dict(GF_state_dict, strict=False)
self.netG.module.load_state_dict(G_state_dict, strict=False)
self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
self.netGF.eval()
self.netG.eval()
- self.netGComp.eval()
\ No newline at end of file
+ self.netGComp.eval()
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..ab92218
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,190 @@
+import argparse
+import multiprocessing
+import os
+import shutil
+import tempfile
+from glob import glob
+from os import listdir
+from os.path import isfile, join
+from pathlib import Path
+
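+# Point fvcore at the weight cache pre-populated by `download.py --mode coco-weights`
+# so no detectron2 download is attempted at inference time.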
+os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache"
+
+import cog
+import cv2
+import numpy as np
+import torch
+from detectron2 import model_zoo
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultPredictor
+from detectron2.utils.logger import setup_logger
+
+from fusion_dataset import Fusion_Testing_Dataset
+from models import create_model
+from options.train_options import TestOptions
+from util import util
+
+
+class InstColorizationPredictor(cog.Predictor):
+ def setup(self):
+ self.has_gpu = torch.cuda.is_available()
+
+ setup_logger()
+
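+ # Same Mask R-CNN X-101-FPN detector and 0.7 score threshold the repo's
+ # original inference pipeline uses to find instance bounding boxes.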
+ cfg = get_cfg()
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
+
+ if self.has_gpu:
+ self.device = torch.device("cuda")
+ else:
+ cfg.MODEL.DEVICE = "cpu"
+ self.device = torch.device("cpu")
+
+ self.predictor = DefaultPredictor(cfg)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+ multiprocessing.set_start_method('spawn', True)
+
+ torch.backends.cudnn.benchmark = True
+
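+ # Rebuild the options the original test scripts expect by parsing a
+ # synthetic command line instead of sys.argv.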
+ parser = argparse.ArgumentParser()
+ TestOptions().initialize(parser)
+ opt = parser.parse_args(["--name", "test_fusion", "--sample_p", "1.0", "--model", "fusion", "--fineSize", "256", "--test_img_dir", "inputs", "--results_img_dir", "results"])
+ str_ids = opt.gpu_ids.split(',')
+ opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ opt.gpu_ids.append(id)
+
+
+ if not self.has_gpu:
+ opt.gpu_ids = []
+
+ if len(opt.gpu_ids) > 0:
+ torch.cuda.set_device(opt.gpu_ids[0])
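+ # opt.A is the number of discretized ab color bins per axis
+ # (23 with the repo defaults ab_max=110, ab_quant=10).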
+ opt.A = 2 * opt.ab_max / opt.ab_quant + 1
+ opt.B = opt.A
+ opt.isTrain = False
+
+ os.makedirs(opt.test_img_dir, exist_ok=True)
+
+ self.save_img_dir = opt.results_img_dir
+ if os.path.isdir(self.save_img_dir) is False:
+ print('Create path: {0}'.format(self.save_img_dir))
+ os.makedirs(self.save_img_dir)
+ self.output_npz_dir = "{0}_bbox".format(opt.test_img_dir)
+ if os.path.isdir(self.output_npz_dir) is False:
+ os.makedirs(self.output_npz_dir)
+
+ opt.batch_size = 1
+ #dataset = Fusion_Testing_Dataset(opt)
+ #dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)
+
+ #dataset_size = len(dataset)
+ #print('#Testing images = %d' % dataset_size)
+
+ self.model = create_model(opt)
+ # model.setup_to_test('coco_finetuned_mask_256')
+ self.model.setup_to_test('coco_finetuned_mask_256_ffs', map_location=self.device)
+
+ self.opt = opt
+
+ @cog.input("input", type=Path, help="grayscale input image")
+ def predict(self, input):
+ output_dir = tempfile.mkdtemp()
+ # Derive the output name from the file's basename; joining an absolute input path would escape output_dir.
+ color_output_path = Path(os.path.join(output_dir, os.path.splitext(os.path.basename(str(input)))[0] + ".png"))
+
+ try:
+ input_dir = self.opt.test_img_dir
+ input_path = os.path.join(input_dir, os.path.basename(input))
+ shutil.copy(str(input), input_path)
+ image_list = [f for f in listdir(input_dir) if isfile(join(input_dir, f))]
+
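+ # Stage 1: detect instance bounding boxes on the grayscale input.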
+ self.get_bounding_boxes(input_dir, image_list)
+
+ dataset = Fusion_Testing_Dataset(self.opt)
+ dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)
+
+ count_empty = 0
+
+ # Colorize
+ with torch.no_grad():
+ output_paths = []
+ input_paths = []
+ for data_raw in dataset_loader:
+ # if os.path.isfile(join(save_img_dir, data_raw['file_id'][0] + '.png')) is True:
+ # continue
+ data_raw['full_img'][0] = data_raw['full_img'][0].to(self.device)
+ if data_raw['empty_box'][0] == 0:
+ data_raw['cropped_img'][0] = data_raw['cropped_img'][0].to(self.device)
+ box_info = data_raw['box_info'][0]
+ box_info_2x = data_raw['box_info_2x'][0]
+ box_info_4x = data_raw['box_info_4x'][0]
+ box_info_8x = data_raw['box_info_8x'][0]
+ cropped_data = util.get_colorization_data(data_raw['cropped_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
+ full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
+ self.model.set_input(cropped_data)
+ self.model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
+ self.model.forward()
+ else:
+ count_empty += 1
+ full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p)
+ self.model.set_forward_without_box(full_img_data)
+ output_path = join(self.save_img_dir, data_raw['file_id'][0] + '.png')
+ self.model.save_current_imgs(output_path, is_cuda=self.has_gpu)
+ output_paths.append(output_path)
+
+ input_path = glob(input_dir + "/" + data_raw["file_id"][0] + ".*")[0]
+ input_paths.append(input_path)
+
+ # Resize
+
+ if len(input_paths) > 1 or len(output_paths) > 1:
+ print("WARNING: len(input_paths): {len(input_paths)}, len(output_paths): {len(output_paths)}")
+
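+ # Restore the input resolution: resize the colorized output, then keep the
+ # original lightness (HLS channel 1) and take only hue/saturation from the model.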
+ for input_path, output_path in zip(input_paths, output_paths):
+ input_img = cv2.imread(input_path)
+ height, width, _ = input_img.shape
+ output_img = cv2.imread(output_path)
+ output_img = cv2.resize(output_img, (width, height))
+ input_hls = cv2.cvtColor(input_img, cv2.COLOR_BGR2HLS)
+ output_hls = cv2.cvtColor(output_img, cv2.COLOR_BGR2HLS)
+ output_hls[:, :, 1] = input_hls[:, :, 1]
+ output_bgr = cv2.cvtColor(output_hls, cv2.COLOR_HLS2BGR)
+
+ cv2.imwrite(str(color_output_path), output_bgr, [cv2.IMWRITE_PNG_COMPRESSION, 0])
+ finally:
+ self.cleanup()
+
+ return color_output_path
+
+ def get_bounding_boxes(self, input_dir, image_list):
+ for image_path in image_list:
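+ # Run detection on the L (lightness) channel replicated to three channels,
+ # matching the grayscale input the colorization network sees.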
+ img = cv2.imread(join(input_dir, image_path))
+ lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
+ l_channel, a_channel, b_channel = cv2.split(lab_image)
+ l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
+ outputs = self.predictor(l_stack)
+ save_path = join(self.output_npz_dir, image_path.split('.')[0])
+ pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
+ pred_scores = outputs["instances"].scores.cpu().data.numpy()
+ np.savez(save_path, bbox=pred_bbox, scores=pred_scores)
+
+ def cleanup(self):
+ clean_folder(self.opt.test_img_dir)
+ clean_folder(self.output_npz_dir)
+ clean_folder(self.save_img_dir)
+
+def clean_folder(folder):
+ for filename in os.listdir(folder):
+ file_path = os.path.join(folder, filename)
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path)
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ except Exception as e:
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
diff --git a/scripts/download_model.sh b/scripts/download_model.sh
index fc4b2b7..c35f47a 100644
--- a/scripts/download_model.sh
+++ b/scripts/download_model.sh
@@ -1,4 +1,5 @@
echo "Downloading..."
python download.py
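+# Also pre-fetch the detectron2 Mask R-CNN weights so inference needs no network access.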
+python download.py --mode coco-weights
echo "Finish download."
-unzip checkpoints.zip
\ No newline at end of file
+unzip checkpoints.zip