7
7
from threading import Thread
8
8
from datetime import datetime as dt
9
9
from safetensors .torch import save_file
10
-
10
+ import hashlib
11
11
COLOR_DARK_GREEN = '#78BA04'
12
12
COLOR_DARK_BLUE = '#4974a5'
13
13
COLOR_RED_ORANGE = '#C13515'
16
16
COLOR_GREEN = '#43CD80'
17
17
COLOR_DARK_GREEN2 = '#78BA04'
18
18
19
+ HASH_START = 0x100000
20
+ HASH_LENGTH = 0x10000
21
+
19
22
_pytorch_file_extensions = {".ckpt" }
20
23
21
24
file_ext = {("Stable Diffusion" , "*.ckpt" )}
22
25
23
26
def main ():
24
- ver = '0.0.1 '
27
+ ver = '0.0.2 '
25
28
sg .theme ('Dark Gray 15' )
26
29
app_title = f"Safe & Stable - Ckpt2Safetensors Conversion Tool-GUI - Ver { ver } "
27
30
pbar_progress_bar_key = 'progress_bar'
@@ -169,16 +172,20 @@ def process_directory(path,idx):
169
172
cpbar .progress_bar_custom (idx - 1 ,len (input_directory_path_list ),start_time ,window ,pbar_progress_bar_key ,"ckpt" )
170
173
convert_button_enable ()
171
174
172
-
173
175
def convert_to_st (checkpoint_path ):
176
+ modelhash = model_hash (checkpoint_path )
174
177
try :
175
178
with torch .no_grad ():
176
- weights = torch .load (checkpoint_path )["state_dict" ]
179
+ weights = torch .load (checkpoint_path , map_location = torch .device ('cpu' ))
180
+ if "state_dict" in weights :
181
+ weights = weights ["state_dict" ]
182
+ if "state_dict" in weights :
183
+ weights .pop ("state_dict" )
177
184
file_name = f"{ os .path .splitext (checkpoint_path )[0 ]} .safetensors"
178
185
179
- print (f'Converting { checkpoint_path } to safetensors.' )
186
+ print (f'Converting { checkpoint_path } [ { modelhash } ] to safetensors.' )
180
187
save_file (weights , file_name )
181
- print (f'Saving { file_name } .' )
188
+ print (f'Saving { file_name } [ { model_hash ( file_name ) } ] .' )
182
189
183
190
except Exception as e :
184
191
if isinstance (e , (RuntimeError , EOFError )):
@@ -191,6 +198,13 @@ def convert_to_st(checkpoint_path):
191
198
print (f'Error: { e } ' )
192
199
193
200
201
+ def model_hash (filename ):
202
+ with open (filename , "rb" ) as file :
203
+ m = hashlib .sha256 ()
204
+ file .seek (HASH_START )
205
+ m .update (file .read (HASH_LENGTH ))
206
+ return m .hexdigest ()[0 :8 ]
207
+
194
208
while True :
195
209
event , values = window .read ()
196
210
0 commit comments