12
12
import ffmpeg
13
13
import uuid
14
14
import shutil
15
+ import requests
15
16
16
- import torch
17
- import skimage .io as io
18
- import skimage .color as convertor
19
- import torchvision .transforms as transforms
17
+ try :
18
+ import torch
19
+ import skimage .color as convertor
20
+ import torchvision .transforms as transforms
21
+ except ImportError :
22
+ # AI feature disabled
23
+ pass
20
24
21
25
22
26
try :
31
35
from ImageGoNord .utility .quantize import quantize_to_palette
32
36
import ImageGoNord .utility .palette_loader as pl
33
37
from ImageGoNord .utility .ConvertUtility import ConvertUtility
34
- from ImageGoNord .utility .model import FeatureEncoder ,RecoloringDecoder
38
+
39
+ try :
40
+ from ImageGoNord .utility .model import FeatureEncoder ,RecoloringDecoder
41
+ except ImportError :
42
+ # AI feature disabled
43
+ pass
35
44
36
45
37
46
class NordPaletteFile :
@@ -158,6 +167,8 @@ class GoNord(object):
158
167
TRANSPARENCY_TOLERANCE = 190
159
168
MAX_THREADS = 10
160
169
170
+ PALETTE_NET_REPO_FOLDER = 'https://github.com/Schrodinger-Hat/ImageGoNord-pip/raw/master/ImageGoNord/models/PaletteNet/'
171
+
161
172
AVAILABLE_PALETTE = []
162
173
PALETTE_DATA = {}
163
174
@@ -425,6 +436,16 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo
425
436
pixels [row , col ] = tuple (colors_list )
426
437
return pixels
427
438
439
+ def load_and_save_models (self ):
440
+ rd_model = requests .get (self .PALETTE_NET_REPO_FOLDER + 'RD.state_dict.pt' )
441
+ fe_model = requests .get (self .PALETTE_NET_REPO_FOLDER + 'FE.state_dict.pt' )
442
+
443
+ with open (os .path .dirname (palette_net .__file__ ) + '/FE.state_dict.pt' , "wb" ) as f :
444
+ f .write (fe_model .content )
445
+
446
+ with open (os .path .dirname (palette_net .__file__ ) + '/RD.state_dict.pt' , "wb" ) as f :
447
+ f .write (rd_model .content )
448
+
428
449
def convert_image_by_model (self , image , use_model_cpu = False ):
429
450
"""
430
451
Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image
@@ -444,8 +465,14 @@ def convert_image_by_model(self, image, use_model_cpu=False):
444
465
FE = FeatureEncoder () # torch.Size([64, 3, 3, 3])
445
466
RD = RecoloringDecoder () # torch.Size([530, 256, 3, 3])
446
467
447
- FE .load_state_dict (torch .load (pkg_resources .open_binary (palette_net , "FE.state_dict.pt" )))
448
- RD .load_state_dict (torch .load (pkg_resources .open_binary (palette_net , "RD.state_dict.pt" )))
468
+ if (
469
+ os .path .exists (os .path .dirname (palette_net .__file__ ) + '/FE.state_dict.pt' )
470
+ and os .path .exists (os .path .dirname (palette_net .__file__ ) + '/RD.state_dict.pt' )
471
+ ):
472
+ FE .load_state_dict (torch .load (pkg_resources .open_binary (palette_net , "FE.state_dict.pt" )))
473
+ RD .load_state_dict (torch .load (pkg_resources .open_binary (palette_net , "RD.state_dict.pt" )))
474
+ else :
475
+ self .load_and_save_models ()
449
476
450
477
if use_model_cpu :
451
478
FE .to ("cpu" )
@@ -472,7 +499,8 @@ def convert_image_by_model(self, image, use_model_cpu=False):
472
499
try :
473
500
pal_np = np .array (palette ).reshape (1 ,6 ,3 )/ 255
474
501
except :
475
- print ("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: " , len (palette ), "! I'll take the first 6!" )
502
+ # this feature is limited to 6 colours
503
+ # we're taking the first six
476
504
pal_np = np .array (palette [0 :6 ]).reshape (1 ,6 ,3 )/ 255
477
505
478
506
pal = torch .Tensor ((convertor .rgb2lab (pal_np ) - [50 ,0 ,0 ] ) / [50 ,128 ,128 ]).unsqueeze (0 )
@@ -518,7 +546,10 @@ def convert_image(self, image, save_path='', use_model=False, use_model_cpu=Fals
518
546
is_rgba = (image .mode == 'RGBA' )
519
547
520
548
if use_model :
521
- image = self .convert_image_by_model (image , use_model_cpu )
549
+ if torch != None :
550
+ image = self .convert_image_by_model (image , use_model_cpu )
551
+ else :
552
+ print ('Please install the dependencies required for the AI feature: pip install image-go-nord[AI]' )
522
553
else :
523
554
if not parallel_threading :
524
555
self .converted_loop (is_rgba , pixels , original_pixels , image .size [0 ], image .size [1 ])
0 commit comments