@@ -13,7 +13,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 llm_model_map = {
-    "llama2_7b": {
+    "meta-llama/Llama-2-7b-chat-hf": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
         "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
@@ -258,7 +258,8 @@ def format_out(results):
 
         history.append(format_out(token))
         while (
-            format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
+            format_out(token)
+            != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
             and len(history) < self.max_tokens
         ):
             dec_time = time.time()
@@ -272,7 +273,10 @@ def format_out(results):
 
             self.prev_token_len = token_len + len(history)
 
-            if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
+            if (
+                format_out(token)
+                == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+            ):
                 break
 
         for i in range(len(history)):
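Both reformatted conditions above implement one pattern: keep decoding until the model emits the configured stop token or the history reaches max_tokens. A minimal standalone sketch of that loop shape, assuming a stop-token id of 2 (the usual </s> id for Llama-2-family tokenizers) and using stand-in decode_next/format_out helpers that are not part of this patch:

    STOP_TOKEN = 2    # assumption: </s> id for Llama-2-style tokenizers
    MAX_TOKENS = 8

    def decode_next(token: int) -> int:
        # Stand-in for one model decode step; emits the stop token last.
        return token + 1 if token + 1 < STOP_TOKEN else STOP_TOKEN

    def format_out(token: int) -> int:
        # Stand-in for the result post-processing in the real code.
        return int(token)

    token = 0
    history = [format_out(token)]
    while (
        format_out(token) != STOP_TOKEN
        and len(history) < MAX_TOKENS
    ):
        token = decode_next(token)
        history.append(format_out(token))
        if format_out(token) == STOP_TOKEN:
            break
    print(history)  # [0, 1, 2]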
@@ -306,7 +310,7 @@ def chat_hf(self, prompt):
             self.first_input = False
 
         history.append(int(token))
-        while token != llm_model_map["llama2_7b"]["stop_token"]:
+        while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
             dec_time = time.time()
             result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
             history.append(int(token))
@@ -317,7 +321,7 @@ def chat_hf(self, prompt):
 
             self.prev_token_len = token_len + len(history)
 
-            if token == llm_model_map["llama2_7b"]["stop_token"]:
+            if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
                 break
         for i in range(len(history)):
             if type(history[i]) != int:
@@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
     else:
         print(f"prompt : {InputData['prompt']}")
 
-    model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
+    model_name = (
+        InputData["model"]
+        if "model" in InputData.keys()
+        else "meta-llama/Llama-2-7b-chat-hf"
+    )
     model_path = llm_model_map[model_name]
     device = InputData["device"] if "device" in InputData.keys() else "cpu"
     precision = "fp16"
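With the new default, a request that omits "model" now resolves to the full repo id, so existing callers keep working while explicit callers pass the Hugging Face name. A minimal sketch of the default resolution; the resolve_request helper and the payloads are illustrative only:

    def resolve_request(InputData: dict) -> tuple:
        # Mirrors the defaulting logic in llm_chat_api above.
        model_name = (
            InputData["model"]
            if "model" in InputData.keys()
            else "meta-llama/Llama-2-7b-chat-hf"
        )
        device = InputData["device"] if "device" in InputData.keys() else "cpu"
        return model_name, device

    print(resolve_request({"prompt": "hi"}))
    # ('meta-llama/Llama-2-7b-chat-hf', 'cpu')
    print(resolve_request({"prompt": "hi", "device": "vulkan"}))
    # ('meta-llama/Llama-2-7b-chat-hf', 'vulkan')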