1+ import os
12import torch
23
34from loguru import logger
1314
1415__all__ = ["Model" ]
1516
17+ TRUST_REMOTE_CODE = os .getenv ("TRUST_REMOTE_CODE" , "false" ).lower () in ["true" , "1" ]
1618# Disable gradients
1719torch .set_grad_enabled (False )
1820
@@ -40,7 +42,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
4042 device = get_device ()
4143 logger .info (f"backend device: { device } " )
4244
43- config = AutoConfig .from_pretrained (model_path )
45+ config = AutoConfig .from_pretrained (model_path , trust_remote_code = TRUST_REMOTE_CODE )
4446 if config .model_type == "bert" :
4547 config : BertConfig
4648 if (
@@ -51,12 +53,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
5153 and FLASH_ATTENTION
5254 ):
5355 if pool != "cls" :
54- return DefaultModel (model_path , device , datatype , pool )
56+ return DefaultModel (
57+ model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
58+ )
5559 return FlashBert (model_path , device , datatype )
5660 if config .architectures [0 ].endswith ("Classification" ):
57- return ClassificationModel (model_path , device , datatype )
61+ return ClassificationModel (
62+ model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
63+ )
5864 else :
59- return DefaultModel (model_path , device , datatype , pool )
65+ return DefaultModel (
66+ model_path ,
67+ device ,
68+ datatype ,
69+ pool ,
70+ trust_remote = TRUST_REMOTE_CODE ,
71+ )
6072 else :
6173 if device .type == "hpu" :
6274 from habana_frameworks .torch .hpu import wrap_in_hpu_graph
@@ -66,13 +78,35 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
6678
6779 adapt_transformers_to_gaudi ()
6880 if config .architectures [0 ].endswith ("Classification" ):
69- model_handle = ClassificationModel (model_path , device , datatype )
81+ model_handle = ClassificationModel (
82+ model_path ,
83+ device ,
84+ datatype ,
85+ trust_remote = TRUST_REMOTE_CODE ,
86+ )
7087 else :
71- model_handle = DefaultModel (model_path , device , datatype , pool )
88+ model_handle = DefaultModel (
89+ model_path ,
90+ device ,
91+ datatype ,
92+ pool ,
93+ trust_remote = TRUST_REMOTE_CODE ,
94+ )
7295 model_handle .model = wrap_in_hpu_graph (model_handle .model )
7396 return model_handle
7497 elif use_ipex ():
7598 if config .architectures [0 ].endswith ("Classification" ):
76- return ClassificationModel (model_path , device , datatype )
99+ return ClassificationModel (
100+ model_path ,
101+ device ,
102+ datatype ,
103+ trust_remote = TRUST_REMOTE_CODE ,
104+ )
77105 else :
78- return DefaultModel (model_path , device , datatype , pool )
106+ return DefaultModel (
107+ model_path ,
108+ device ,
109+ datatype ,
110+ pool ,
111+ trust_remote = TRUST_REMOTE_CODE ,
112+ )
0 commit comments