Skip to content

Commit 3c62491

Browse files
committed
fix: fix loading of local models
1 parent 8cebf2e commit 3c62491

File tree

3 files changed: 21 additions and 6 deletions.

src/wtpsplit_lite/_config.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,11 @@ def num_labels(self) -> int:
3838

3939
@classmethod
4040
@cache
41-
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "SubwordXLMConfig":
42-
is_local = Path.is_file(Path(pretrained_model_name_or_path))
41+
def from_pretrained(cls, pretrained_model_name_or_path: str | Path) -> "SubwordXLMConfig":
42+
model_path = Path(pretrained_model_name_or_path)
43+
is_local = model_path.is_dir() and (model_path / "config.json").is_file()
4344
model_config_filepath = Path(
44-
pretrained_model_name_or_path
45+
(model_path / "config.json")
4546
if is_local
4647
else hf_hub_download(pretrained_model_name_or_path, "config.json")
4748
)

src/wtpsplit_lite/_tokenizer.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -64,10 +64,13 @@ def __call__(
6464

6565
@classmethod
6666
@cache
67-
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "XLMRobertaTokenizerFast":
68-
is_local = Path.is_file(Path(pretrained_model_name_or_path))
67+
def from_pretrained(
68+
cls, pretrained_model_name_or_path: str | Path
69+
) -> "XLMRobertaTokenizerFast":
70+
model_path = Path(pretrained_model_name_or_path)
71+
is_local = model_path.is_dir() and (model_path / "tokenizer.json").is_file()
6972
tokenizer_filepath = Path(
70-
pretrained_model_name_or_path
73+
(model_path / "tokenizer.json")
7174
if is_local
7275
else hf_hub_download(pretrained_model_name_or_path, "tokenizer.json")
7376
)

tests/test_sat.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7+
from huggingface_hub import hf_hub_download
78
from wtpsplit import SaT as SaTOriginal
89

910
from wtpsplit_lite import SaT as SaTLite
@@ -103,3 +104,13 @@ def test_weighting(sat: SaTLite, text: str) -> None:
103104
output_lite = sat.split(text, stride=128, block_size=256, weighting="hat")
104105
reconstructed_text_lite = "".join(output_lite)
105106
assert text == reconstructed_text_lite
107+
108+
109+
def test_local() -> None:
110+
"""Test loading a local model."""
111+
model_filepath = hf_hub_download("segment-any-text/sat-3l-sm", filename="model_optimized.onnx")
112+
model_dir = Path(model_filepath).parent
113+
sat_lite = SaTLite(model_dir)
114+
text = "This is a test This is another test."
115+
output_lite = sat_lite.split(text)
116+
assert output_lite == ["This is a test ", "This is another test."]

0 commit comments

Comments
 (0)