Skip to content

Commit b933107

Browse files
committed
0.0.2
fix #2 added to pytorch_lightning requirements file added support for nai files
1 parent 493855e commit b933107

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ share/python-wheels/
2828
MANIFEST
2929
media/ori
3030
media/re
31-
31+
tests
3232
# PyInstaller
3333
# Usually these files are written by a python script from a template
3434
# before PyInstaller builds the exe, so as to inject date/other infos into it.

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
torch
2+
pytorch-lightning==1.8.3.post1
23
torchsde==0.2.5
34
safetensors==0.2.5
45
pysimplegui==4.60.4

run_app_gui.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from threading import Thread
88
from datetime import datetime as dt
99
from safetensors.torch import save_file
10-
10+
import hashlib
1111
COLOR_DARK_GREEN = '#78BA04'
1212
COLOR_DARK_BLUE = '#4974a5'
1313
COLOR_RED_ORANGE = '#C13515'
@@ -16,12 +16,15 @@
1616
COLOR_GREEN = '#43CD80'
1717
COLOR_DARK_GREEN2 = '#78BA04'
1818

19+
HASH_START = 0x100000
20+
HASH_LENGTH = 0x10000
21+
1922
_pytorch_file_extensions = {".ckpt"}
2023

2124
file_ext = {("Stable Diffusion", "*.ckpt")}
2225

2326
def main():
24-
ver = '0.0.1'
27+
ver = '0.0.2'
2528
sg.theme('Dark Gray 15')
2629
app_title = f"Safe & Stable - Ckpt2Safetensors Conversion Tool-GUI - Ver {ver}"
2730
pbar_progress_bar_key = 'progress_bar'
@@ -169,16 +172,20 @@ def process_directory(path,idx):
169172
cpbar.progress_bar_custom(idx-1,len(input_directory_path_list),start_time,window,pbar_progress_bar_key,"ckpt")
170173
convert_button_enable()
171174

172-
173175
def convert_to_st(checkpoint_path):
176+
modelhash = model_hash(checkpoint_path)
174177
try:
175178
with torch.no_grad():
176-
weights = torch.load(checkpoint_path)["state_dict"]
179+
weights = torch.load(checkpoint_path, map_location=torch.device('cpu'))
180+
if "state_dict" in weights:
181+
weights = weights["state_dict"]
182+
if "state_dict" in weights:
183+
weights.pop("state_dict")
177184
file_name = f"{os.path.splitext(checkpoint_path)[0]}.safetensors"
178185

179-
print(f'Converting {checkpoint_path} to safetensors.')
186+
print(f'Converting {checkpoint_path} [{modelhash}] to safetensors.')
180187
save_file(weights, file_name)
181-
print(f'Saving {file_name}.')
188+
print(f'Saving {file_name} [{model_hash(file_name)}].')
182189

183190
except Exception as e:
184191
if isinstance(e, (RuntimeError, EOFError)):
@@ -191,6 +198,13 @@ def convert_to_st(checkpoint_path):
191198
print(f'Error: {e}')
192199

193200

201+
def model_hash(filename):
202+
with open(filename, "rb") as file:
203+
m = hashlib.sha256()
204+
file.seek(HASH_START)
205+
m.update(file.read(HASH_LENGTH))
206+
return m.hexdigest()[0:8]
207+
194208
while True:
195209
event, values = window.read()
196210

0 commit comments

Comments
 (0)