@@ -77,8 +77,11 @@ def load_ckpt(model, optimizer, ckpt_path, load_model=False, load_opt=False, loa
77
77
def load_StudioGAN_ckpts (ckpt_dir , load_best , Gen , Dis , g_optimizer , d_optimizer , run_name , apply_g_ema , Gen_ema , ema ,
78
78
is_train , RUN , logger , global_rank , device , cfg_file ):
79
79
when = "best" if load_best is True else "current"
80
- Gen_ckpt_path = glob .glob (join (ckpt_dir , "model=G-{when}-weights-step*.pth" .format (when = when )))[0 ]
81
- Dis_ckpt_path = glob .glob (join (ckpt_dir , "model=D-{when}-weights-step*.pth" .format (when = when )))[0 ]
80
+ x = join (ckpt_dir , "model=G-{when}-weights-step=" .format (when = when ))
81
+ Gen_ckpt_path = glob .glob (glob .escape (x ) + '*.pth' )[0 ]
82
+ y = join (ckpt_dir , "model=D-{when}-weights-step=" .format (when = when ))
83
+ Dis_ckpt_path = glob .glob (glob .escape (y ) + '*.pth' )[0 ]
84
+
82
85
prev_run_name = torch .load (Dis_ckpt_path , map_location = lambda storage , loc : storage )["run_name" ]
83
86
is_freezeD = True if RUN .freezeD > - 1 else False
84
87
@@ -100,7 +103,8 @@ def load_StudioGAN_ckpts(ckpt_dir, load_best, Gen, Dis, g_optimizer, d_optimizer
100
103
is_freezeD = is_freezeD )
101
104
102
105
if apply_g_ema :
103
- Gen_ema_ckpt_path = glob .glob (join (ckpt_dir , "model=G_ema-{when}-weights-step*.pth" .format (when = when )))[0 ]
106
+ z = join (ckpt_dir , "model=G_ema-{when}-weights-step=" .format (when = when ))
107
+ Gen_ema_ckpt_path = glob .glob (glob .escape (z ) + '*.pth' )[0 ]
104
108
load_ckpt (model = Gen_ema ,
105
109
optimizer = None ,
106
110
ckpt_path = Gen_ema_ckpt_path ,
0 commit comments