Skip to content

Commit 092d6a3

Browse files
committed
Fix empty list bug caused by '=' (issue= #160)
1 parent 5641739 commit 092d6a3

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/utils/ckpt.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,11 @@ def load_ckpt(model, optimizer, ckpt_path, load_model=False, load_opt=False, loa
7777
def load_StudioGAN_ckpts(ckpt_dir, load_best, Gen, Dis, g_optimizer, d_optimizer, run_name, apply_g_ema, Gen_ema, ema,
7878
is_train, RUN, logger, global_rank, device, cfg_file):
7979
when = "best" if load_best is True else "current"
80-
Gen_ckpt_path = glob.glob(join(ckpt_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
81-
Dis_ckpt_path = glob.glob(join(ckpt_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
80+
x = join(ckpt_dir, "model=G-{when}-weights-step=".format(when=when))
81+
Gen_ckpt_path = glob.glob(glob.escape(x) + '*.pth')[0]
82+
y = join(ckpt_dir, "model=D-{when}-weights-step=".format(when=when))
83+
Dis_ckpt_path = glob.glob(glob.escape(y) + '*.pth')[0]
84+
8285
prev_run_name = torch.load(Dis_ckpt_path, map_location=lambda storage, loc: storage)["run_name"]
8386
is_freezeD = True if RUN.freezeD > -1 else False
8487

@@ -100,7 +103,8 @@ def load_StudioGAN_ckpts(ckpt_dir, load_best, Gen, Dis, g_optimizer, d_optimizer
100103
is_freezeD=is_freezeD)
101104

102105
if apply_g_ema:
103-
Gen_ema_ckpt_path = glob.glob(join(ckpt_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
106+
z = join(ckpt_dir, "model=G_ema-{when}-weights-step=".format(when=when))
107+
Gen_ema_ckpt_path = glob.glob(glob.escape(z) + '*.pth')[0]
104108
load_ckpt(model=Gen_ema,
105109
optimizer=None,
106110
ckpt_path=Gen_ema_ckpt_path,

0 commit comments

Comments
 (0)