Skip to content

Commit ed78105

Browse files
authored
Merge pull request #27 from Schrodinger-Hat/feat/25
feat(25): make models optionals
2 parents 3dfa415 + 7041206 commit ed78105

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

ImageGoNord/GoNord.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
import ffmpeg
1313
import uuid
1414
import shutil
15+
import requests
1516

16-
import torch
17-
import skimage.io as io
18-
import skimage.color as convertor
19-
import torchvision.transforms as transforms
17+
try:
18+
import torch
19+
import skimage.color as convertor
20+
import torchvision.transforms as transforms
21+
except ImportError:
22+
# AI feature disabled
23+
pass
2024

2125

2226
try:
@@ -31,7 +35,12 @@
3135
from ImageGoNord.utility.quantize import quantize_to_palette
3236
import ImageGoNord.utility.palette_loader as pl
3337
from ImageGoNord.utility.ConvertUtility import ConvertUtility
34-
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder
38+
39+
try:
40+
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder
41+
except ImportError:
42+
# AI feature disabled
43+
pass
3544

3645

3746
class NordPaletteFile:
@@ -158,6 +167,8 @@ class GoNord(object):
158167
TRANSPARENCY_TOLERANCE = 190
159168
MAX_THREADS = 10
160169

170+
PALETTE_NET_REPO_FOLDER = 'https://github.com/Schrodinger-Hat/ImageGoNord-pip/raw/master/ImageGoNord/models/PaletteNet/'
171+
161172
AVAILABLE_PALETTE = []
162173
PALETTE_DATA = {}
163174

@@ -425,6 +436,16 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo
425436
pixels[row, col] = tuple(colors_list)
426437
return pixels
427438

439+
def load_and_save_models(self):
440+
rd_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'RD.state_dict.pt')
441+
fe_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'FE.state_dict.pt')
442+
443+
with open(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt', "wb") as f:
444+
f.write(fe_model.content)
445+
446+
with open(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt', "wb") as f:
447+
f.write(rd_model.content)
448+
428449
def convert_image_by_model(self, image, use_model_cpu=False):
429450
"""
430451
Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image
@@ -444,8 +465,14 @@ def convert_image_by_model(self, image, use_model_cpu=False):
444465
FE = FeatureEncoder() # torch.Size([64, 3, 3, 3])
445466
RD = RecoloringDecoder() # torch.Size([530, 256, 3, 3])
446467

447-
FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
448-
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
468+
if (
469+
os.path.exists(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt')
470+
and os.path.exists(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt')
471+
):
472+
FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
473+
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
474+
else:
475+
self.load_and_save_models()
449476

450477
if use_model_cpu:
451478
FE.to("cpu")
@@ -472,7 +499,8 @@ def convert_image_by_model(self, image, use_model_cpu=False):
472499
try:
473500
pal_np = np.array(palette).reshape(1,6,3)/255
474501
except:
475-
print("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: ", len(palette), "! I'll take the first 6!")
502+
# this feature is limited to 6 colours
503+
# we're taking the first six
476504
pal_np = np.array(palette[0:6]).reshape(1,6,3)/255
477505

478506
pal = torch.Tensor((convertor.rgb2lab(pal_np) - [50,0,0] ) / [50,128,128]).unsqueeze(0)
@@ -518,7 +546,10 @@ def convert_image(self, image, save_path='', use_model=False, use_model_cpu=Fals
518546
is_rgba = (image.mode == 'RGBA')
519547

520548
if use_model:
521-
image = self.convert_image_by_model(image, use_model_cpu)
549+
if torch != None:
550+
image = self.convert_image_by_model(image, use_model_cpu)
551+
else:
552+
print('Please install the dependencies required for the AI feature: pip install image-go-nord[AI]')
522553
else:
523554
if not parallel_threading:
524555
self.converted_loop(is_rgba, pixels, original_pixels, image.size[0], image.size[1])

ImageGoNord/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# gonord version
2-
__version__ = "1.0.2"
2+
__version__ = "1.1.0"
33

44
from ImageGoNord.GoNord import *

index.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
# go_nord.add_file_to_palette(NordPaletteFile.AURORA)
2323
# go_nord.add_file_to_palette(NordPaletteFile.FROST)
2424

25-
# image = go_nord.open_image("images/valley.jpg")
26-
# go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)
27-
28-
output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')
29-
print(output_path)
25+
image = go_nord.open_image("images/valley.jpg")
26+
go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)
3027
exit()
28+
# output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')
3129

3230
image = go_nord.open_image("images/test.jpg")
3331
resized_img = go_nord.resize_image(image)

setup.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name="image-go-nord",
9-
version="1.0.2",
9+
version="1.1.0",
1010
description="A tool to convert any RGB image or video to any theme or color palette input by the user",
1111
long_description=README,
1212
long_description_content_type="text/markdown",
@@ -17,7 +17,7 @@
1717
author_email="schrodinger.hat.show@gmail.com",
1818
license="AGPL-3.0",
1919
classifiers=[
20-
'Development Status :: 5 - Production/Stable', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
20+
'Development Status :: 5 - Production/Stable',
2121
'Intended Audience :: Developers',
2222
'Topic :: Software Development :: Build Tools',
2323
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
@@ -30,8 +30,11 @@
3030
"Bug Reports": "https://github.com/Schrodinger-Hat/ImageGoNord-pip/issues",
3131
},
3232
packages=find_packages(),
33-
package_data={'': ['*.txt', 'palettes/*.txt', 'models/*.pt', '*.pt', '*.state_dict.*']},
33+
package_data={'': ['*.txt', 'palettes/*.txt']},
3434
include_package_data=True,
35-
install_requires=["Pillow", "ffmpeg-python", "numpy", "torch", "scikit-image", "torchvision"],
35+
install_requires=["Pillow", "ffmpeg-python", "numpy", "requests"],
36+
extras_require = {
37+
'AI': ["torch", "scikit-image", "torchvision"]
38+
},
3639
python_requires=">=3.5"
3740
)

0 commit comments

Comments
 (0)