Skip to content

Commit de024ec

Browse files
kit1980qgallouedec
andauthored
Use weights_only for load (huggingface#1933)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 2fbc0f4 commit de024ec

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

examples/scripts/ddpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(self, *, dtype, model_id, model_filename):
9393
cached_path = hf_hub_download(model_id, model_filename)
9494
except EntryNotFoundError:
9595
cached_path = os.path.join(model_id, model_filename)
96-
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
96+
state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
9797
self.mlp.load_state_dict(state_dict)
9898
self.dtype = dtype
9999
self.eval()

tests/test_peft_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_save_pretrained_peft(self):
138138
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
139139
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
140140
assert os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
141-
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin")
141+
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin", weights_only=True)
142142
# check that only keys that starts with `v_head` are in the dict
143143
assert all(
144144
k.startswith("v_head") for k in maybe_v_head.keys()

trl/models/auxiliary_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, *, dtype, model_id, model_filename):
6060
cached_path = hf_hub_download(model_id, model_filename)
6161
except EntryNotFoundError:
6262
cached_path = os.path.join(model_id, model_filename)
63-
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
63+
state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
6464
self.mlp.load_state_dict(state_dict)
6565
self.dtype = dtype
6666
self.eval()

0 commit comments

Comments
 (0)