     choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
     help="Specify which model to run.",
 )
-parser.add_argument(
-    "--hf_auth_token",
-    type=str,
-    default=None,
-    help="Specify your own huggingface authentication tokens for models like Llama2.",
-)
 parser.add_argument(
     "--cache_vicunas",
     default=False,
@@ -460,10 +454,6 @@ def __init__(

     def get_tokenizer(self):
         kwargs = {}
-        if self.model_name == "llama2":
-            kwargs = {
-                "use_auth_token": "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"
-            }
         tokenizer = AutoTokenizer.from_pretrained(
             self.hf_model_path,
             use_fast=False,
@@ -1217,7 +1207,6 @@ def __init__(
         self,
         model_name,
         hf_model_path="TheBloke/vicuna-7B-1.1-HF",
-        hf_auth_token: str = None,
         max_num_tokens=512,
         device="cpu",
         precision="int8",
@@ -1237,17 +1226,12 @@ def __init__(
             max_num_tokens,
             extra_args_cmd=extra_args_cmd,
         )
-        if "llama2" in self.model_name and hf_auth_token == None:
-            raise ValueError(
-                "HF auth token required. Pass it using --hf_auth_token flag."
-            )
-        self.hf_auth_token = hf_auth_token
         if self.model_name == "llama2_7b":
-            self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-7b-chat-hf"
         elif self.model_name == "llama2_13b":
-            self.hf_model_path = "meta-llama/Llama-2-13b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-13b-chat-hf"
         elif self.model_name == "llama2_70b":
-            self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-70b-chat-hf"
         print(f"[DEBUG] hf model name: {self.hf_model_path}")
         self.max_sequence_length = 256
         self.device = device
@@ -1276,18 +1260,15 @@ def get_model_path(self, suffix="mlir"):
         )

     def get_tokenizer(self):
-        kwargs = {"use_auth_token": self.hf_auth_token}
         tokenizer = AutoTokenizer.from_pretrained(
             self.hf_model_path,
             use_fast=False,
-            **kwargs,
         )
         return tokenizer

     def get_src_model(self):
         kwargs = {
             "torch_dtype": torch.float,
-            "use_auth_token": self.hf_auth_token,
         }
         vicuna_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
@@ -1460,8 +1441,6 @@ def compile(self):
                 self.hf_model_path,
                 self.precision,
                 self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
             )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
@@ -1553,24 +1532,18 @@ def compile(self):
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             elif self.model_name == "llama2_70b":
                 model = SecondVicuna70B(
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             else:
                 model = SecondVicuna7B(
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
@@ -1714,7 +1687,6 @@ def generate(self, prompt, cli):
             logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
-
             if token == 2:
                 break
             res_tokens.append(token)
@@ -1809,7 +1781,6 @@ def create_prompt(model_name, history):
     )
     vic = UnshardedVicuna(
         model_name=args.model_name,
-        hf_auth_token=args.hf_auth_token,
         device=args.device,
         precision=args.precision,
         vicuna_mlir_path=vic_mlir_path,
@@ -1851,9 +1822,9 @@ def create_prompt(model_name, history):

     model_list = {
         "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
-        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
-        "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
-        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
+        "llama2_7b": "llama2_7b=>daryl149/llama-2-7b-chat-hf",
+        "llama2_13b": "llama2_7b=>daryl149/llama-2-13b-chat-hf",
+        "llama2_70b": "llama2_7b=>daryl149/llama-2-70b-chat-hf",
     }
     while True:
         # TODO: Add break condition from user input
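The net effect of the diff is that the Llama 2 checkpoints are pulled from the public daryl149 mirrors, so no Hugging Face authentication token is threaded through the pipeline anymore. As a rough illustration (a minimal sketch that reuses only the transformers calls already visible in the diff, with the 7B mirror path picked as an example), loading the tokenizer and source model now reduces to:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Public mirror path used by the patch; no use_auth_token argument is needed.
hf_model_path = "daryl149/llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(
    hf_model_path,
    use_fast=False,
)
model = AutoModelForCausalLM.from_pretrained(
    hf_model_path,
    torch_dtype=torch.float,
)

This also removes the hardcoded token that was previously embedded in get_tokenizer, so nothing secret remains in the source.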