88
99from text_embeddings_server .models .model import Model
1010from text_embeddings_server .models .default_model import DefaultModel
11+ from text_embeddings_server .models .classification_model import ClassificationModel
1112from text_embeddings_server .utils .device import get_device , use_ipex
1213
1314__all__ = ["Model" ]
@@ -43,18 +44,19 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
4344 if config .model_type == "bert" :
4445 config : BertConfig
4546 if (
46- device .type == "cuda"
47+ use_ipex ()
48+ or device .type in ["cuda" , "hpu" ]
4749 and config .position_embedding_type == "absolute"
4850 and datatype in [torch .float16 , torch .bfloat16 ]
4951 and FLASH_ATTENTION
5052 ):
5153 if pool != "cls" :
52- raise ValueError ( "FlashBert only supports cls pooling" )
53- return FlashBert (model_path , device , datatype ) # type: ignore
54- if use_ipex () or device . type == "hpu" :
55- return FlashBert (model_path , device , datatype ) # type: ignore
56-
57- return DefaultModel (model_path , device , datatype )
54+ return DefaultModel ( model_path , device , datatype , pool )
55+ return FlashBert (model_path , device , datatype )
56+ if config . architectures [ 0 ]. endswith ( "Classification" ) :
57+ return ClassificationModel (model_path , device , datatype )
58+ else :
59+ return DefaultModel (model_path , device , datatype , pool )
5860 else :
5961 if device .type == "hpu" :
6062 from habana_frameworks .torch .hpu import wrap_in_hpu_graph
@@ -63,7 +65,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
6365 )
6466
6567 adapt_transformers_to_gaudi ()
66- model_handle = DefaultModel (model_path , device , datatype )
68+ if config .architectures [0 ].endswith ("Classification" ):
69+ model_handle = ClassificationModel (model_path , device , datatype )
70+ else :
71+ model_handle = DefaultModel (model_path , device , datatype , pool )
6772 model_handle .model = wrap_in_hpu_graph (model_handle .model )
6873 return model_handle
69- return DefaultModel (model_path , device , datatype )
74+ elif use_ipex ():
75+ if config .architectures [0 ].endswith ("Classification" ):
76+ return ClassificationModel (model_path , device , datatype )
77+ else :
78+ return DefaultModel (model_path , device , datatype , pool )
0 commit comments