Skip to content

Commit d4e1ebe

Browse files
authored
Merge pull request #325 from anth-volk/fix/huggingface-download-error
Set local directory when downloading datasets from Hugging Face
2 parents cd0f149 + 0b8dc71 commit d4e1ebe

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
changed:
4+
- Explicitly set local directory when downloading datasets from Hugging Face

policyengine_core/data/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,4 +501,5 @@ def download_from_huggingface(
501501
repo=f"{owner_name}/{model_name}",
502502
repo_filename=file_name,
503503
version=version,
504+
local_dir=self.file_path.parent,
504505
)

policyengine_core/tools/hugging_face.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313

1414

1515
def download_huggingface_dataset(
16-
repo: str, repo_filename: str, version: str = None
16+
repo: str,
17+
repo_filename: str,
18+
version: str = None,
19+
local_dir: str | None = None,
1720
):
1821
"""
1922
Download a dataset from the Hugging Face Hub.
@@ -22,6 +25,7 @@ def download_huggingface_dataset(
2225
repo (str): The Hugging Face repo name, in format "{org}/{repo}".
2326
repo_filename (str): The filename of the dataset.
2427
version (str, optional): The version of the dataset. Defaults to None.
28+
local_dir (str, optional): The local directory to save the dataset to. Defaults to None.
2529
"""
2630
# Attempt connection to Hugging Face model_info endpoint
2731
# (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info)
@@ -52,6 +56,7 @@ def download_huggingface_dataset(
5256
filename=repo_filename,
5357
revision=version,
5458
token=authentication_token,
59+
local_dir=local_dir,
5560
)
5661

5762

tests/core/tools/test_hugging_face.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def test_download_public_repo(self):
1515
test_repo = "test_repo"
1616
test_filename = "test_filename"
1717
test_version = "test_version"
18+
test_dir = "test_dir"
1819

1920
with patch(
2021
"policyengine_core.tools.hugging_face.hf_hub_download"
@@ -29,14 +30,15 @@ def test_download_public_repo(self):
2930
)
3031

3132
download_huggingface_dataset(
32-
test_repo, test_filename, test_version
33+
test_repo, test_filename, test_version, test_dir
3334
)
3435

3536
mock_download.assert_called_with(
3637
repo_id=test_repo,
3738
repo_type="model",
3839
filename=test_filename,
3940
revision=test_version,
41+
local_dir=test_dir,
4042
token=None,
4143
)
4244

@@ -45,6 +47,7 @@ def test_download_private_repo(self):
4547
test_repo = "test_repo"
4648
test_filename = "test_filename"
4749
test_version = "test_version"
50+
test_dir = "test_dir"
4851

4952
with patch(
5053
"policyengine_core.tools.hugging_face.hf_hub_download"
@@ -61,21 +64,23 @@ def test_download_private_repo(self):
6164
mock_token.return_value = "test_token"
6265

6366
download_huggingface_dataset(
64-
test_repo, test_filename, test_version
67+
test_repo, test_filename, test_version, test_dir
6568
)
6669
mock_download.assert_called_with(
6770
repo_id=test_repo,
6871
repo_type="model",
6972
filename=test_filename,
7073
revision=test_version,
7174
token=mock_token.return_value,
75+
local_dir=test_dir,
7276
)
7377

7478
def test_download_private_repo_no_token(self):
7579
"""Test handling of private repo with no token"""
7680
test_repo = "test_repo"
7781
test_filename = "test_filename"
7882
test_version = "test_version"
83+
test_dir = "test_dir"
7984

8085
with patch(
8186
"policyengine_core.tools.hugging_face.hf_hub_download"
@@ -93,7 +98,7 @@ def test_download_private_repo_no_token(self):
9398

9499
with pytest.raises(Exception):
95100
download_huggingface_dataset(
96-
test_repo, test_filename, test_version
101+
test_repo, test_filename, test_version, test_dir
97102
)
98103
mock_download.assert_not_called()
99104

0 commit comments

Comments
 (0)