Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions src/refiners/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None:
logging.debug(f"Downloading {url} to {destination}")

# get the size of the file
response = requests.get(url, stream=True)
response = requests.get(url=url, stream=True)
response.raise_for_status()
total = int(response.headers.get("content-length", 0))
chunk_size = 1024 * 1000 # 1 MiB

# create a progress bar
bar = tqdm(
Expand All @@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None:
with destination.open("wb") as f:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=1024 * 1000):
for chunk in r.iter_content(chunk_size=chunk_size):
size = f.write(chunk)
bar.update(size)
bar.close()
Expand All @@ -63,8 +64,8 @@ def __init__(
self,
repo_id: str,
filename: str,
expected_sha256: str,
revision: str = "main",
expected_sha256: str | None = None,
download_url: str | None = None,
) -> None:
"""Initialize the HubPath.
Expand All @@ -73,14 +74,14 @@ def __init__(
repo_id: The repository identifier on the hub.
filename: The filename of the file in the repository.
revision: The revision of the file on the hf hub.
expected_sha256: The sha256 hash of the file.
expected_sha256: The sha256 hash of the file, to optionally (but strongly recommended) check against the local or remote hash.
download_url: The url to download the file from, if not from the huggingface hub.
"""
self.repo_id = repo_id
self.filename = filename
self.revision = revision
self.expected_sha256 = expected_sha256.lower()
self.override_download_url = download_url
self.expected_sha256 = expected_sha256.lower() if expected_sha256 is not None else None
self.download_url = download_url

@staticmethod
def hub_location():
Expand All @@ -90,16 +91,22 @@ def hub_location():
@property
def hf_url(self) -> str:
"""Return the url to the file on the hf hub."""
assert self.override_download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
return hf_hub_url(
repo_id=self.repo_id,
filename=self.filename,
revision=self.revision,
)

@property
def hf_metadata(self) -> HfFileMetadata:
"""Return the metadata of the file on the hf hub."""
return get_hf_file_metadata(self.hf_url)

@property
def hf_cache_path(self) -> Path:
"""Download the file from the hf hub and return its path in the local hf cache."""
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
return Path(
hf_hub_download(
repo_id=self.repo_id,
Expand All @@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path:
),
)

@property
def hf_metadata(self) -> HfFileMetadata:
"""Return the metadata of the file on the hf hub."""
return get_hf_file_metadata(self.hf_url)

@property
def hf_sha256_hash(self) -> str:
"""Return the sha256 hash of the file on the hf hub."""
Expand All @@ -127,24 +129,32 @@ def local_path(self) -> Path:
return self.hub_location() / self.repo_id / self.filename

@property
def local_hash(self) -> str:
def local_sha256_hash(self) -> str:
"""Return the sha256 hash of the file in the local hub."""
assert self.local_path.is_file(), f"{self.local_path} does not exist"
# TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest when support python >= 3.11
return sha256(self.local_path.read_bytes()).hexdigest().lower()

def check_local_hash(self) -> bool:
"""Check if the sha256 hash of the file in the local hub is correct."""
if self.expected_sha256 != self.local_hash:
logging.warning(f"{self.local_path} local sha256 mismatch, {self.local_hash} != {self.expected_sha256}")
if self.expected_sha256 is None:
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
return True
elif self.expected_sha256 != self.local_sha256_hash:
logging.warning(
f"{self.local_path} local sha256 mismatch, {self.local_sha256_hash} != {self.expected_sha256}"
)
return False
else:
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_hash})")
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_sha256_hash})")
return True

def check_remote_hash(self) -> bool:
"""Check if the sha256 hash of the file on the hf hub is correct."""
if self.expected_sha256 != self.hf_sha256_hash:
if self.expected_sha256 is None:
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
return True
elif self.expected_sha256 != self.hf_sha256_hash:
logging.warning(
f"{self.local_path} remote sha256 mismatch, {self.hf_sha256_hash} != {self.expected_sha256}"
)
Expand All @@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool:
return True

def download(self) -> None:
"""Download the file from the hf hub or from the override download url."""
self.local_path.parent.mkdir(parents=True, exist_ok=True)
"""Download the file from the hf hub or from the override download url, and save it to the local hub."""
if self.local_path.is_file():
logging.warning(f"{self.local_path} already exists")
elif self.override_download_url is not None:
download_file_url(url=self.override_download_url, destination=self.local_path)
elif self.download_url is not None:
self.local_path.parent.mkdir(parents=True, exist_ok=True)
download_file_url(url=self.download_url, destination=self.local_path)
else:
# TODO: pas assez de message de log quand local_path existe pas et que ça vient du hf cache
self.local_path.parent.mkdir(parents=True, exist_ok=True)
self.local_path.symlink_to(self.hf_cache_path)
assert self.check_local_hash()

Expand Down
2 changes: 1 addition & 1 deletion src/refiners/foundationals/segment_anything/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def mask_decoder(self) -> MaskDecoder:

@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
"""Compute the emmbedding of an image.
"""Compute the embedding of an image.

Args:
image: The image to compute the embedding of.
Expand Down
2 changes: 1 addition & 1 deletion src/refiners/foundationals/swin/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self, x: Tensor):

class WindowAttention(fl.Chain):
"""
Window-based Multi-head Self-Attenion (W-MSA), optionally shifted (SW-MSA).
Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).

It has a trainable relative position bias (RelativePositionBias).

Expand Down
2 changes: 1 addition & 1 deletion tests/weight_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_path(hub: Hub, use_local_weights: bool) -> Path:
if use_local_weights:
path = hub.local_path
else:
if hub.override_download_url is not None:
if hub.download_url is not None:
pytest.skip(f"{hub.filename} is not available on Hugging Face Hub")

try:
Expand Down