Skip to content

Commit a4ac1d6

Browse files
authored
Merge pull request #245 from jrzaurin/update-dependencies
Adjusted tests to torch 2.6 and higher
2 parents 790e1ec + 4fc4c5f commit a4ac1d6

File tree

4 files changed

+29
-11
lines changed

4 files changed

+29
-11
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ gensim
66
spacy
77
opencv-contrib-python>=4.9.0.80
88
tqdm
9-
torch >= 2.0.0, <2.6.0
9+
torch >= 2.0.0
1010
torchvision >= 0.15.0
1111
einops
1212
wrapt

tests/test_model_functioning/test_miscellaneous.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ def test_save_and_load():
255255
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
256256
wide_weights = model.wide.wide_linear.weight.data
257257
trainer.save("tests/test_model_functioning/model_dir/")
258-
n_model = torch.load("tests/test_model_functioning/model_dir/wd_model.pt")
258+
n_model = torch.load(
259+
"tests/test_model_functioning/model_dir/wd_model.pt", weights_only=False
260+
)
259261
n_wide_weights = n_model.wide.wide_linear.weight.data
260262
assert torch.allclose(wide_weights, n_wide_weights)
261263

tests/test_model_functioning/test_save_optimizer.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def test_save_one_optimizer(save_state_dict):
6969
model_filename="model_and_optimizer.pt",
7070
)
7171

72-
checkpoint = torch.load(os.path.join(save_path, "model_and_optimizer.pt"))
72+
checkpoint = torch.load(
73+
os.path.join(save_path, "model_and_optimizer.pt"), weights_only=False
74+
)
7375

7476
if save_state_dict:
7577
new_model = WideDeep(wide=wide, deeptabular=tab_mlp)
@@ -83,7 +85,9 @@ def test_save_one_optimizer(save_state_dict):
8385
else:
8486
# This else statement is mostly testing that it runs, as it does not
8587
# involved loading a state_dict
86-
saved_objects = torch.load(os.path.join(save_path, "model_and_optimizer.pt"))
88+
saved_objects = torch.load(
89+
os.path.join(save_path, "model_and_optimizer.pt"), weights_only=False
90+
)
8791
new_model = saved_objects["model"]
8892
new_optimizer = saved_objects["optimizer"]
8993

@@ -123,7 +127,9 @@ def test_save_multiple_optimizers(save_state_dict):
123127
model_filename="model_and_optimizer.pt",
124128
)
125129

126-
checkpoint = torch.load(os.path.join(save_path, "model_and_optimizer.pt"))
130+
checkpoint = torch.load(
131+
os.path.join(save_path, "model_and_optimizer.pt"), weights_only=False
132+
)
127133

128134
if save_state_dict:
129135
new_model = WideDeep(wide=wide, deeptabular=tab_mlp)
@@ -140,7 +146,9 @@ def test_save_multiple_optimizers(save_state_dict):
140146
else:
141147
# This else statement is mostly testing that it runs, as it does not
142148
# involved loading a state_dict
143-
saved_objects = torch.load(os.path.join(save_path, "model_and_optimizer.pt"))
149+
saved_objects = torch.load(
150+
os.path.join(save_path, "model_and_optimizer.pt"), weights_only=False
151+
)
144152
new_model = saved_objects["model"]
145153
new_optimizers = saved_objects["optimizer"]
146154
new_wide_opt = new_optimizers._optimizers["wide"]

tests/test_self_supervised/test_ss_miscellaneous.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def test_save_and_load(model_type):
110110
save_state_dict=False,
111111
model_filename="ss_model.pt",
112112
)
113-
new_model = torch.load("tests/test_self_supervised/model_dir/ss_model.pt")
113+
new_model = torch.load(
114+
"tests/test_self_supervised/model_dir/ss_model.pt", weights_only=False
115+
)
114116

115117
if model_type == "mlp":
116118
new_col_embed_module = new_model.encoder.cat_embed.embed_layers.emb_layer_col1
@@ -177,7 +179,8 @@ def test_save_model_and_optimizer(model_type, save_state_dict):
177179
)
178180

179181
checkpoint = torch.load(
180-
os.path.join("tests/test_self_supervised/model_dir/", "model_and_optimizer.pt")
182+
os.path.join("tests/test_self_supervised/model_dir/", "model_and_optimizer.pt"),
183+
weights_only=False,
181184
)
182185

183186
if save_state_dict:
@@ -204,7 +207,8 @@ def test_save_model_and_optimizer(model_type, save_state_dict):
204207
saved_objects = torch.load(
205208
os.path.join(
206209
"tests/test_self_supervised/model_dir/", "model_and_optimizer.pt"
207-
)
210+
),
211+
weights_only=False,
208212
)
209213
new_optimizer = saved_objects["optimizer"]
210214

@@ -275,11 +279,15 @@ def test_save_and_load_dict(model_type): # noqa: C901
275279

276280
if model_type == "mlp":
277281
trainer2.ed_model.load_state_dict(
278-
torch.load("tests/test_self_supervised/model_dir/ss_model.pt")
282+
torch.load(
283+
"tests/test_self_supervised/model_dir/ss_model.pt", weights_only=False
284+
)
279285
)
280286
elif model_type == "transformer":
281287
trainer2.cd_model.load_state_dict(
282-
torch.load("tests/test_self_supervised/model_dir/ss_model.pt")
288+
torch.load(
289+
"tests/test_self_supervised/model_dir/ss_model.pt", weights_only=False
290+
)
283291
)
284292

285293
if model_type == "mlp":

0 commit comments

Comments
 (0)