Skip to content

Commit 432a6ff

Browse files
author
Laurent
committed
make refiners.conversion.utils.Hub.expected_sha256 optional
1 parent 1f70f43 commit 432a6ff

File tree

2 files changed

+33
-23
lines changed

2 files changed

+33
-23
lines changed

src/refiners/conversion/utils.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None:
2626
logging.debug(f"Downloading {url} to {destination}")
2727

2828
# get the size of the file
29-
response = requests.get(url, stream=True)
29+
response = requests.get(url=url, stream=True)
3030
response.raise_for_status()
3131
total = int(response.headers.get("content-length", 0))
32+
chunk_size = 1024 * 1000 # ~1 MB (1,024,000 bytes; slightly less than 1 MiB = 1,048,576 bytes)
3233

3334
# create a progress bar
3435
bar = tqdm(
@@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None:
4546
with destination.open("wb") as f:
4647
with requests.get(url, stream=True) as r:
4748
r.raise_for_status()
48-
for chunk in r.iter_content(chunk_size=1024 * 1000):
49+
for chunk in r.iter_content(chunk_size=chunk_size):
4950
size = f.write(chunk)
5051
bar.update(size)
5152
bar.close()
@@ -63,8 +64,8 @@ def __init__(
6364
self,
6465
repo_id: str,
6566
filename: str,
66-
expected_sha256: str,
6767
revision: str = "main",
68+
expected_sha256: str | None = None,
6869
download_url: str | None = None,
6970
) -> None:
7071
"""Initialize the HubPath.
@@ -73,14 +74,14 @@ def __init__(
7374
repo_id: The repository identifier on the hub.
7475
filename: The filename of the file in the repository.
7576
revision: The revision of the file on the hf hub.
76-
expected_sha256: The sha256 hash of the file.
77+
expected_sha256: The sha256 hash of the file, to optionally check against the local or remote hash.
7778
download_url: The url to download the file from, if not from the huggingface hub.
7879
"""
7980
self.repo_id = repo_id
8081
self.filename = filename
8182
self.revision = revision
82-
self.expected_sha256 = expected_sha256.lower()
83-
self.override_download_url = download_url
83+
self.expected_sha256 = expected_sha256.lower() if expected_sha256 is not None else None
84+
self.download_url = download_url
8485

8586
@staticmethod
8687
def hub_location():
@@ -90,16 +91,22 @@ def hub_location():
9091
@property
9192
def hf_url(self) -> str:
9293
"""Return the url to the file on the hf hub."""
93-
assert self.override_download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
94+
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
9495
return hf_hub_url(
9596
repo_id=self.repo_id,
9697
filename=self.filename,
9798
revision=self.revision,
9899
)
99100

101+
@property
102+
def hf_metadata(self) -> HfFileMetadata:
103+
"""Return the metadata of the file on the hf hub."""
104+
return get_hf_file_metadata(self.hf_url)
105+
100106
@property
101107
def hf_cache_path(self) -> Path:
102108
"""Download the file from the hf hub and return its path in the local hf cache."""
109+
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
103110
return Path(
104111
hf_hub_download(
105112
repo_id=self.repo_id,
@@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path:
108115
),
109116
)
110117

111-
@property
112-
def hf_metadata(self) -> HfFileMetadata:
113-
"""Return the metadata of the file on the hf hub."""
114-
return get_hf_file_metadata(self.hf_url)
115-
116118
@property
117119
def hf_sha256_hash(self) -> str:
118120
"""Return the sha256 hash of the file on the hf hub."""
@@ -127,24 +129,32 @@ def local_path(self) -> Path:
127129
return self.hub_location() / self.repo_id / self.filename
128130

129131
@property
130-
def local_hash(self) -> str:
132+
def local_sh256_hash(self) -> str:
131133
"""Return the sha256 hash of the file in the local hub."""
132134
assert self.local_path.is_file(), f"{self.local_path} does not exist"
133135
# TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest once Python >= 3.11 is the minimum supported version
134136
return sha256(self.local_path.read_bytes()).hexdigest().lower()
135137

136138
def check_local_hash(self) -> bool:
137139
"""Check if the sha256 hash of the file in the local hub is correct."""
138-
if self.expected_sha256 != self.local_hash:
139-
logging.warning(f"{self.local_path} local sha256 mismatch, {self.local_hash} != {self.expected_sha256}")
140+
if self.expected_sha256 is None:
141+
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
142+
return True
143+
elif self.expected_sha256 != self.local_sh256_hash:
144+
logging.warning(
145+
f"{self.local_path} local sha256 mismatch, {self.local_sh256_hash} != {self.expected_sha256}"
146+
)
140147
return False
141148
else:
142-
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_hash})")
149+
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_sh256_hash})")
143150
return True
144151

145152
def check_remote_hash(self) -> bool:
146153
"""Check if the sha256 hash of the file on the hf hub is correct."""
147-
if self.expected_sha256 != self.hf_sha256_hash:
154+
if self.expected_sha256 is None:
155+
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
156+
return True
157+
elif self.expected_sha256 != self.hf_sha256_hash:
148158
logging.warning(
149159
f"{self.local_path} remote sha256 mismatch, {self.hf_sha256_hash} != {self.expected_sha256}"
150160
)
@@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool:
154164
return True
155165

156166
def download(self) -> None:
157-
"""Download the file from the hf hub or from the override download url."""
158-
self.local_path.parent.mkdir(parents=True, exist_ok=True)
167+
"""Download the file from the hf hub or from the override download url, and save it to the local hub."""
159168
if self.local_path.is_file():
160169
logging.warning(f"{self.local_path} already exists")
161-
elif self.override_download_url is not None:
162-
download_file_url(url=self.override_download_url, destination=self.local_path)
170+
elif self.download_url is not None:
171+
self.local_path.parent.mkdir(parents=True, exist_ok=True)
172+
download_file_url(url=self.download_url, destination=self.local_path)
163173
else:
164-
# TODO: not enough log messages when local_path does not exist and the file comes from the hf cache
174+
self.local_path.parent.mkdir(parents=True, exist_ok=True)
165175
self.local_path.symlink_to(self.hf_cache_path)
166176
assert self.check_local_hash()
167177

tests/weight_paths.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_path(hub: Hub, use_local_weights: bool) -> Path:
3131
if use_local_weights:
3232
path = hub.local_path
3333
else:
34-
if hub.override_download_url is not None:
34+
if hub.download_url is not None:
3535
pytest.skip(f"{hub.filename} is not available on Hugging Face Hub")
3636

3737
try:

0 commit comments

Comments
 (0)