@@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None:
26
26
logging .debug (f"Downloading { url } to { destination } " )
27
27
28
28
# get the size of the file
29
- response = requests .get (url , stream = True )
29
+ response = requests .get (url = url , stream = True )
30
30
response .raise_for_status ()
31
31
total = int (response .headers .get ("content-length" , 0 ))
32
+ chunk_size = 1024 * 1000 # 1 MiB
32
33
33
34
# create a progress bar
34
35
bar = tqdm (
@@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None:
45
46
with destination .open ("wb" ) as f :
46
47
with requests .get (url , stream = True ) as r :
47
48
r .raise_for_status ()
48
- for chunk in r .iter_content (chunk_size = 1024 * 1000 ):
49
+ for chunk in r .iter_content (chunk_size = chunk_size ):
49
50
size = f .write (chunk )
50
51
bar .update (size )
51
52
bar .close ()
@@ -63,8 +64,8 @@ def __init__(
63
64
self ,
64
65
repo_id : str ,
65
66
filename : str ,
66
- expected_sha256 : str ,
67
67
revision : str = "main" ,
68
+ expected_sha256 : str | None = None ,
68
69
download_url : str | None = None ,
69
70
) -> None :
70
71
"""Initialize the HubPath.
@@ -73,14 +74,14 @@ def __init__(
73
74
repo_id: The repository identifier on the hub.
74
75
filename: The filename of the file in the repository.
75
76
revision: The revision of the file on the hf hub.
76
- expected_sha256: The sha256 hash of the file.
77
+ expected_sha256: The sha256 hash of the file, to optionally check against the local or remote hash .
77
78
download_url: The url to download the file from, if not from the huggingface hub.
78
79
"""
79
80
self .repo_id = repo_id
80
81
self .filename = filename
81
82
self .revision = revision
82
- self .expected_sha256 = expected_sha256 .lower ()
83
- self .override_download_url = download_url
83
+ self .expected_sha256 = expected_sha256 .lower () if expected_sha256 is not None else None
84
+ self .download_url = download_url
84
85
85
86
@staticmethod
86
87
def hub_location ():
@@ -90,16 +91,22 @@ def hub_location():
90
91
@property
91
92
def hf_url (self ) -> str :
92
93
"""Return the url to the file on the hf hub."""
93
- assert self .override_download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
94
+ assert self .download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
94
95
return hf_hub_url (
95
96
repo_id = self .repo_id ,
96
97
filename = self .filename ,
97
98
revision = self .revision ,
98
99
)
99
100
101
+ @property
102
+ def hf_metadata (self ) -> HfFileMetadata :
103
+ """Return the metadata of the file on the hf hub."""
104
+ return get_hf_file_metadata (self .hf_url )
105
+
100
106
@property
101
107
def hf_cache_path (self ) -> Path :
102
108
"""Download the file from the hf hub and return its path in the local hf cache."""
109
+ assert self .download_url is None , f"{ self .repo_id } /{ self .filename } is not available on the hub"
103
110
return Path (
104
111
hf_hub_download (
105
112
repo_id = self .repo_id ,
@@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path:
108
115
),
109
116
)
110
117
111
- @property
112
- def hf_metadata (self ) -> HfFileMetadata :
113
- """Return the metadata of the file on the hf hub."""
114
- return get_hf_file_metadata (self .hf_url )
115
-
116
118
@property
117
119
def hf_sha256_hash (self ) -> str :
118
120
"""Return the sha256 hash of the file on the hf hub."""
@@ -127,24 +129,32 @@ def local_path(self) -> Path:
127
129
return self .hub_location () / self .repo_id / self .filename
128
130
129
131
@property
130
- def local_hash (self ) -> str :
132
+ def local_sh256_hash (self ) -> str :
131
133
"""Return the sha256 hash of the file in the local hub."""
132
134
assert self .local_path .is_file (), f"{ self .local_path } does not exist"
133
135
# TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest when support python >= 3.11
134
136
return sha256 (self .local_path .read_bytes ()).hexdigest ().lower ()
135
137
136
138
def check_local_hash (self ) -> bool :
137
139
"""Check if the sha256 hash of the file in the local hub is correct."""
138
- if self .expected_sha256 != self .local_hash :
139
- logging .warning (f"{ self .local_path } local sha256 mismatch, { self .local_hash } != { self .expected_sha256 } " )
140
+ if self .expected_sha256 is None :
141
+ logging .warning (f"{ self .repo_id } /{ self .filename } has no expected sha256 hash, skipping check" )
142
+ return True
143
+ elif self .expected_sha256 != self .local_sh256_hash :
144
+ logging .warning (
145
+ f"{ self .local_path } local sha256 mismatch, { self .local_sh256_hash } != { self .expected_sha256 } "
146
+ )
140
147
return False
141
148
else :
142
- logging .debug (f"{ self .local_path } local sha256 is correct ({ self .local_hash } )" )
149
+ logging .debug (f"{ self .local_path } local sha256 is correct ({ self .local_sh256_hash } )" )
143
150
return True
144
151
145
152
def check_remote_hash (self ) -> bool :
146
153
"""Check if the sha256 hash of the file on the hf hub is correct."""
147
- if self .expected_sha256 != self .hf_sha256_hash :
154
+ if self .expected_sha256 is None :
155
+ logging .warning (f"{ self .repo_id } /{ self .filename } has no expected sha256 hash, skipping check" )
156
+ return True
157
+ elif self .expected_sha256 != self .hf_sha256_hash :
148
158
logging .warning (
149
159
f"{ self .local_path } remote sha256 mismatch, { self .hf_sha256_hash } != { self .expected_sha256 } "
150
160
)
@@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool:
154
164
return True
155
165
156
166
def download (self ) -> None :
157
- """Download the file from the hf hub or from the override download url."""
158
- self .local_path .parent .mkdir (parents = True , exist_ok = True )
167
+ """Download the file from the hf hub or from the override download url, and save it to the local hub."""
159
168
if self .local_path .is_file ():
160
169
logging .warning (f"{ self .local_path } already exists" )
161
- elif self .override_download_url is not None :
162
- download_file_url (url = self .override_download_url , destination = self .local_path )
170
+ elif self .download_url is not None :
171
+ self .local_path .parent .mkdir (parents = True , exist_ok = True )
172
+ download_file_url (url = self .download_url , destination = self .local_path )
163
173
else :
164
- # TODO: pas assez de message de log quand local_path existe pas et que ça vient du hf cache
174
+ self . local_path . parent . mkdir ( parents = True , exist_ok = True )
165
175
self .local_path .symlink_to (self .hf_cache_path )
166
176
assert self .check_local_hash ()
167
177
0 commit comments