diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e883e81 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +checkpoints.zip diff --git a/README.md b/README.md index 938fa20..e19ec40 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # [CVPR 2020] Instance-aware Image Colorization [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb) -### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)] +### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)] [[Run on Replicate](https://replicate.ai/ericsujw/instcolorization)]

diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..1333f23 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,45 @@ +build: + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "cachetools==4.1.0" + - "chardet==3.0.4" + - "future==0.18.2" + - "fvcore==0.1.dev200506" + - "idna==2.9" + - "importlib-metadata==1.6.0" + - "jsonpatch==1.25" + - "jsonpointer==2.0" + - "markdown==3.2.2" + - "mock==4.0.2" + - "opencv-python==4.3.0.38" + - "portalocker==1.7.0" + - "pyasn1==0.4.8" + - "pyasn1-modules==0.2.8" + - "pydot==1.4.1" + - "requests==2.23.0" + - "requests-oauthlib==1.3.0" + - "rsa==4.0" + - "tabulate==0.8.7" + - "termcolor==1.1.0" + - "urllib3==1.25.8" + - "visdom==0.1.8.9" + - "websocket-client==0.57.0" + - "yacs==0.1.7" + - "zipp==3.1.0" + - "cython==0.29.22" + - "pyyaml==5.1" + - "dominate==2.4.0" + - "detectron2==0.1.2" + - "torch==1.5.0" + - "torchvision==0.6.0" + - "pycocotools==2.0.2" + - "ipython==7.21.0" + - "scikit-image==0.18.1" + python_extra_index_urls: + - "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI" + python_find_links: + - "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/index.html" +predict: "predict.py:InstColorizationPredictor" diff --git a/download.py b/download.py index 85ae6a1..794c7ab 100644 --- a/download.py +++ b/download.py @@ -4,6 +4,7 @@ import os from argparse import ArgumentParser + def download_file_from_google_drive(id, destination): URL = "https://docs.google.com/uc?export=download" @@ -45,6 +46,7 @@ def save_response_content(response, destination): destination = 'checkpoints.zip' download_file_from_google_drive(file_id, destination) + elif args.mode == 'cocostuff': print('download cocostuff training dataset') url = "http://images.cocodataset.org/zips/train2017.zip" @@ -52,5 +54,21 @@ def save_response_content(response, destination): if isdir(join(args.dataset_dir, "cocostuff")) is False: os.makedirs(join(args.dataset_dir, "cocostuff")) save_response_content(response, join(args.dataset_dir, "cocostuff", "train.zip")) + +elif args.mode == 'coco-weights': + os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache" + + from detectron2 import model_zoo + from detectron2.config import get_cfg + from detectron2.engine import DefaultPredictor + + # Download coco weights + cfg = get_cfg() + cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml") + cfg.MODEL.DEVICE = "cpu" + predictor = DefaultPredictor(cfg) + else: - print('Error Mode!') \ No newline at end of file + print('Error Mode!') diff --git a/models/fusion_model.py b/models/fusion_model.py index 0f6ef6b..b60d2ed 100644 --- a/models/fusion_model.py +++ b/models/fusion_model.py @@ -1,6 +1,7 @@ import os import torch +from torch import nn from collections import OrderedDict from util.image_pool import ImagePool from util import util @@ -29,21 +30,24 @@ def initialize(self, opt): # load/define networks num_in = opt.input_nc + opt.output_nc + 1 - + self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf, 'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, use_tanh=True, classification=False) self.netG.eval() - + self.netG = nn.DataParallel(self.netG) + self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf, 'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, use_tanh=True, classification=False) self.netGF.eval() + self.netGF = nn.DataParallel(self.netGF) self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf, 'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, use_tanh=True, classification=opt.classification) self.netGComp.eval() + self.netGComp = nn.DataParallel(self.netGComp) def set_input(self, input): @@ -51,12 +55,12 @@ def set_input(self, input): self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.hint_B = input['hint_B'].to(self.device) - + self.mask_B = input['mask_B'].to(self.device) self.mask_B_nc = self.mask_B + self.opt.mask_cent self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt) - + def set_fusion_input(self, input, box_info): AtoB = self.opt.which_direction == 'AtoB' self.full_real_A = input['A' if AtoB else 'B'].to(self.device) @@ -85,29 +89,32 @@ def set_forward_without_box(self, input): def forward(self): (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list) - - def save_current_imgs(self, path): - out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0) + + def save_current_imgs(self, path, is_cuda=True): + if is_cuda: + out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0) + else: + out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.FloatTensor), self.fake_B_reg.type(torch.FloatTensor)), dim=1), self.opt), 0.0, 1.0) out_img = np.transpose(out_img.cpu().data.numpy()[0], (1, 2, 0)) io.imsave(path, img_as_ubyte(out_img)) - def setup_to_test(self, fusion_weight_path): + def setup_to_test(self, fusion_weight_path, map_location): GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path) print('load Fusion model from %s' % GF_path) - GF_state_dict = torch.load(GF_path) - + GF_state_dict = torch.load(GF_path, map_location=map_location) + # G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path) - G_state_dict = torch.load(G_path) + G_state_dict = torch.load(G_path, map_location=map_location) # GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net # GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path) - GComp_state_dict = torch.load(GComp_path) + GComp_state_dict = torch.load(GComp_path, map_location=map_location) self.netGF.load_state_dict(GF_state_dict, strict=False) self.netG.module.load_state_dict(G_state_dict, strict=False) self.netGComp.module.load_state_dict(GComp_state_dict, strict=False) self.netGF.eval() self.netG.eval() - self.netGComp.eval() \ No newline at end of file + self.netGComp.eval() diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..ab92218 --- /dev/null +++ b/predict.py @@ -0,0 +1,190 @@ +import argparse +import multiprocessing +import os +import shutil +import tempfile +from glob import glob +from os import listdir +from os.path import isfile, join +from pathlib import Path + +os.environ["FVCORE_CACHE"] = "checkpoints/fvcore_cache" + +import cog +import cv2 +import numpy as np +import torch +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.engine import DefaultPredictor +from detectron2.utils.logger import setup_logger + +from fusion_dataset import Fusion_Testing_Dataset +from models import create_model +from options.train_options import TestOptions +from util import util + + +class InstColorizationPredictor(cog.Predictor): + def setup(self): + self.has_gpu = torch.cuda.is_available() + + setup_logger() + + cfg = get_cfg() + cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml") + + if self.has_gpu: + self.device = torch.device("cuda") + else: + cfg.MODEL.DEVICE = "cpu" + self.device = torch.device("cpu") + + self.predictor = DefaultPredictor(cfg) + + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + multiprocessing.set_start_method('spawn', True) + + torch.backends.cudnn.benchmark = True + + parser = argparse.ArgumentParser() + TestOptions().initialize(parser) + opt = parser.parse_args(["--name", "test_fusion", "--sample_p", "1.0", "--model", "fusion", "--fineSize", "256", "--test_img_dir", "inputs", "--results_img_dir", "results"]) + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + + + if not self.has_gpu: + opt.gpu_ids = [] + + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + opt.A = 2 * opt.ab_max / opt.ab_quant + 1 + opt.B = opt.A + opt.isTrain = False + + os.makedirs(opt.test_img_dir, exist_ok=True) + + self.save_img_dir = opt.results_img_dir + if os.path.isdir(self.save_img_dir) is False: + print('Create path: {0}'.format(self.save_img_dir)) + os.makedirs(self.save_img_dir) + self.output_npz_dir = "{0}_bbox".format(opt.test_img_dir) + if os.path.isdir(self.output_npz_dir) is False: + os.makedirs(self.output_npz_dir) + + opt.batch_size = 1 + #dataset = Fusion_Testing_Dataset(opt) + #dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2) + + #dataset_size = len(dataset) + #print('#Testing images = %d' % dataset_size) + + self.model = create_model(opt) + # model.setup_to_test('coco_finetuned_mask_256') + self.model.setup_to_test('coco_finetuned_mask_256_ffs', map_location=self.device) + + self.opt = opt + + @cog.input("input", type=Path, help="grayscale input image") + def predict(self, input): + output_dir = tempfile.mkdtemp() + color_output_path = Path(os.path.join(output_dir, str(input).split(".")[0] + ".png")) + + try: + input_dir = self.opt.test_img_dir + input_path = os.path.join(input_dir, os.path.basename(input)) + shutil.copy(str(input), input_path) + image_list = [f for f in listdir(input_dir) if isfile(join(input_dir, f))] + + self.get_bounding_boxes(input_dir, image_list) + + dataset = Fusion_Testing_Dataset(self.opt) + dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0) + + count_empty = 0 + + # Colorize + with torch.no_grad(): + output_paths = [] + input_paths = [] + for data_raw in dataset_loader: + # if os.path.isfile(join(save_img_dir, data_raw['file_id'][0] + '.png')) is True: + # continue + data_raw['full_img'][0] = data_raw['full_img'][0].to(self.device) + if data_raw['empty_box'][0] == 0: + data_raw['cropped_img'][0] = data_raw['cropped_img'][0].to(self.device) + box_info = data_raw['box_info'][0] + box_info_2x = data_raw['box_info_2x'][0] + box_info_4x = data_raw['box_info_4x'][0] + box_info_8x = data_raw['box_info_8x'][0] + cropped_data = util.get_colorization_data(data_raw['cropped_img'], self.opt, ab_thresh=0, p=self.opt.sample_p) + full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p) + self.model.set_input(cropped_data) + self.model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x]) + self.model.forward() + else: + count_empty += 1 + full_img_data = util.get_colorization_data(data_raw['full_img'], self.opt, ab_thresh=0, p=self.opt.sample_p) + self.model.set_forward_without_box(full_img_data) + output_path = join(self.save_img_dir, data_raw['file_id'][0] + '.png') + self.model.save_current_imgs(output_path, is_cuda=self.has_gpu) + output_paths.append(output_path) + + input_path = glob(input_dir + "/" + data_raw["file_id"][0] + ".*")[0] + input_paths.append(input_path) + + # Resize + + if len(input_paths) > 1 or len(output_paths) > 1: + print("WARNING: len(input_paths): {len(input_paths)}, len(output_paths): {len(output_paths)}") + + for input_path, output_path in zip(input_paths, output_paths): + input_img = cv2.imread(input_path) + height, width, _ = input_img.shape + output_img = cv2.imread(output_path) + output_img = cv2.resize(output_img, (width, height)) + input_hls = cv2.cvtColor(input_img, cv2.COLOR_BGR2HLS) + output_hls = cv2.cvtColor(output_img, cv2.COLOR_BGR2HLS) + output_hls[:, :, 1] = input_hls[:, :, 1] + output_bgr = cv2.cvtColor(output_hls, cv2.COLOR_HLS2BGR) + + cv2.imwrite(str(color_output_path), output_bgr, [cv2.IMWRITE_PNG_COMPRESSION, 0]) + finally: + self.cleanup() + + return color_output_path + + def get_bounding_boxes(self, input_dir, image_list): + for image_path in image_list: + img = cv2.imread(join(input_dir, image_path)) + lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) + l_channel, a_channel, b_channel = cv2.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + outputs = self.predictor(l_stack) + save_path = join(self.output_npz_dir, image_path.split('.')[0]) + pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy() + pred_scores = outputs["instances"].scores.cpu().data.numpy() + np.savez(save_path, bbox = pred_bbox, scores = pred_scores) + + def cleanup(self): + clean_folder(self.opt.test_img_dir) + clean_folder(self.output_npz_dir) + clean_folder(self.save_img_dir) + +def clean_folder(folder): + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Failed to delete %s. Reason: %s' % (file_path, e)) diff --git a/scripts/download_model.sh b/scripts/download_model.sh index fc4b2b7..c35f47a 100644 --- a/scripts/download_model.sh +++ b/scripts/download_model.sh @@ -1,4 +1,5 @@ echo "Downloading..." python download.py +python download.py --mode coco-weights echo "Finish download." -unzip checkpoints.zip \ No newline at end of file +unzip checkpoints.zip