Skip to content

Commit 92577ce

Browse files
committed
small fix
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 3cd632a commit 92577ce

File tree

1 file changed

+4
-5
lines changed
  • backends/python/server/text_embeddings_server/models

1 file changed

+4
-5
lines changed

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
5454
and FLASH_ATTENTION
5555
):
5656
if pool != "cls":
57-
if config.architectures[0].endswith("ForMaskedLM"):
57+
if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
5858
return MaskedLanguageModel(
5959
model_path,
6060
device,
6161
datatype,
62-
pool,
6362
trust_remote=TRUST_REMOTE_CODE,
6463
)
6564
return DefaultModel(
@@ -70,9 +69,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
7069
return ClassificationModel(
7170
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
7271
)
73-
elif config.architectures[0].endswith("ForMaskedLM"):
72+
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
7473
return MaskedLanguageModel(
75-
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
74+
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
7675
)
7776
else:
7877
return DefaultModel(
@@ -99,7 +98,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
9998
)
10099
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
101100
model_handle = MaskedLanguageModel(
102-
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
101+
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
103102
)
104103
else:
105104
model_handle = DefaultModel(

0 commit comments

Comments
 (0)