@@ -54,12 +54,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
5454 and FLASH_ATTENTION
5555 ):
5656 if pool != "cls" :
57- if config .architectures [0 ].endswith ("ForMaskedLM" ):
57+ if config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
5858 return MaskedLanguageModel (
5959 model_path ,
6060 device ,
6161 datatype ,
62- pool ,
6362 trust_remote = TRUST_REMOTE_CODE ,
6463 )
6564 return DefaultModel (
@@ -70,9 +69,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
7069 return ClassificationModel (
7170 model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
7271 )
73- elif config .architectures [0 ].endswith ("ForMaskedLM" ):
72+ elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
7473 return MaskedLanguageModel (
75- model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
74+ model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
7675 )
7776 else :
7877 return DefaultModel (
@@ -99,7 +98,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
9998 )
10099 elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
101100 model_handle = MaskedLanguageModel (
102- model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
101+ model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
103102 )
104103 else :
105104 model_handle = DefaultModel (
0 commit comments