
Commit 1745c8f

Remove hf_auth_token use
-- This commit removes `--hf_auth_token` uses from vicuna.py.
-- It switches the llama2 models to daryl149's HF repositories.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
1 parent bde63ee commit 1745c8f
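In practice this means the tokenizer and model weights are now pulled from daryl149's public Hugging Face repositories, so no authentication token appears anywhere in the load path. A minimal sketch of the post-commit loading, under that assumption (the repo id comes from the diff below; the standalone variable names here are illustrative, not code from this repository):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Public repo from this commit; no use_auth_token is passed anywhere.
    hf_model_path = "daryl149/llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    )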

4 files changed: +9 −56 lines changed


apps/language_models/scripts/vicuna.py

Lines changed: 6 additions & 35 deletions
@@ -110,12 +110,6 @@
     choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
     help="Specify which model to run.",
 )
-parser.add_argument(
-    "--hf_auth_token",
-    type=str,
-    default=None,
-    help="Specify your own huggingface authentication tokens for models like Llama2.",
-)
 parser.add_argument(
     "--cache_vicunas",
     default=False,
@@ -460,10 +454,6 @@ def __init__(
 
     def get_tokenizer(self):
         kwargs = {}
-        if self.model_name == "llama2":
-            kwargs = {
-                "use_auth_token": "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"
-            }
         tokenizer = AutoTokenizer.from_pretrained(
             self.hf_model_path,
             use_fast=False,
@@ -1217,7 +1207,6 @@ def __init__(
         self,
         model_name,
         hf_model_path="TheBloke/vicuna-7B-1.1-HF",
-        hf_auth_token: str = None,
         max_num_tokens=512,
         device="cpu",
         precision="int8",
@@ -1237,17 +1226,12 @@ def __init__(
             max_num_tokens,
             extra_args_cmd=extra_args_cmd,
         )
-        if "llama2" in self.model_name and hf_auth_token == None:
-            raise ValueError(
-                "HF auth token required. Pass it using --hf_auth_token flag."
-            )
-        self.hf_auth_token = hf_auth_token
         if self.model_name == "llama2_7b":
-            self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-7b-chat-hf"
         elif self.model_name == "llama2_13b":
-            self.hf_model_path = "meta-llama/Llama-2-13b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-13b-chat-hf"
         elif self.model_name == "llama2_70b":
-            self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
+            self.hf_model_path = "daryl149/llama-2-70b-chat-hf"
         print(f"[DEBUG] hf model name: {self.hf_model_path}")
         self.max_sequence_length = 256
         self.device = device
@@ -1276,18 +1260,15 @@ def get_model_path(self, suffix="mlir"):
         )
 
     def get_tokenizer(self):
-        kwargs = {"use_auth_token": self.hf_auth_token}
         tokenizer = AutoTokenizer.from_pretrained(
             self.hf_model_path,
             use_fast=False,
-            **kwargs,
         )
         return tokenizer
 
     def get_src_model(self):
         kwargs = {
             "torch_dtype": torch.float,
-            "use_auth_token": self.hf_auth_token,
         }
         vicuna_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
@@ -1460,8 +1441,6 @@ def compile(self):
                 self.hf_model_path,
                 self.precision,
                 self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
             )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
@@ -1553,24 +1532,18 @@ def compile(self):
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             elif self.model_name == "llama2_70b":
                 model = SecondVicuna70B(
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             else:
                 model = SecondVicuna7B(
                     self.hf_model_path,
                     self.precision,
                     self.weight_group_size,
-                    self.model_name,
-                    self.hf_auth_token,
                 )
             print(f"[DEBUG] generating torchscript graph")
             is_f16 = self.precision in ["fp16", "int4"]
@@ -1714,7 +1687,6 @@ def generate(self, prompt, cli):
                 logits = generated_token_op["logits"]
                 pkv = generated_token_op["past_key_values"]
                 detok = generated_token_op["detok"]
-
                 if token == 2:
                     break
                 res_tokens.append(token)
@@ -1809,7 +1781,6 @@ def create_prompt(model_name, history):
     )
     vic = UnshardedVicuna(
         model_name=args.model_name,
-        hf_auth_token=args.hf_auth_token,
         device=args.device,
         precision=args.precision,
         vicuna_mlir_path=vic_mlir_path,
@@ -1851,9 +1822,9 @@ def create_prompt(model_name, history):
 
     model_list = {
         "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
-        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
-        "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
-        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
+        "llama2_7b": "llama2_7b=>daryl149/llama-2-7b-chat-hf",
+        "llama2_13b": "llama2_7b=>daryl149/llama-2-13b-chat-hf",
+        "llama2_70b": "llama2_7b=>daryl149/llama-2-70b-chat-hf",
     }
     while True:
         # TODO: Add break condition from user input
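Since `hf_auth_token` is no longer accepted, callers construct `UnshardedVicuna` with just the remaining arguments. A usage sketch based on the updated call site above; the argument values other than `model_name` are placeholders, and the import path assumes the repository layout shown in this commit:

    # Hypothetical usage after this commit; UnshardedVicuna no longer takes hf_auth_token.
    from apps.language_models.scripts.vicuna import UnshardedVicuna

    vic = UnshardedVicuna(
        model_name="llama2_7b",  # resolves internally to daryl149/llama-2-7b-chat-hf
        device="cpu",            # placeholder value
        precision="int8",        # placeholder value
    )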

apps/language_models/src/model_wrappers/vicuna_model.py

Lines changed: 0 additions & 16 deletions
@@ -8,13 +8,9 @@ def __init__(
         model_path,
         precision="fp32",
         weight_group_size=128,
-        model_name="vicuna",
-        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
@@ -57,13 +53,9 @@ def __init__(
         model_path,
         precision="fp32",
         weight_group_size=128,
-        model_name="vicuna",
-        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
@@ -303,13 +295,9 @@ def __init__(
         model_path,
         precision="int8",
         weight_group_size=128,
-        model_name="vicuna",
-        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
@@ -596,13 +584,9 @@ def __init__(
         model_path,
         precision="fp32",
         weight_group_size=128,
-        model_name="vicuna",
-        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
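Each wrapper class in this file drops the `model_name` and `hf_auth_token` parameters and loads the checkpoint directly. Assuming these wrappers subclass `torch.nn.Module` (as their `super().__init__()` calls suggest), each constructor now reduces to roughly the following shape; the class name below is a stand-in for illustration, not one of the actual wrapper names:

    import torch
    from transformers import AutoModelForCausalLM


    class VicunaWrapper(torch.nn.Module):  # stand-in name, not from this repo
        def __init__(self, model_path, precision="fp32", weight_group_size=128):
            super().__init__()
            # No use_auth_token: the daryl149 repositories are public.
            kwargs = {"torch_dtype": torch.float32}
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **kwargs
            )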

apps/stable_diffusion/web/ui/stablelm_ui.py

Lines changed: 3 additions & 4 deletions
@@ -23,9 +23,9 @@ def user(message, history):
 past_key_values = None
 
 model_map = {
-    "llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
-    "llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
-    "llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
+    "llama2_7b": "daryl149/llama-2-7b-chat-hf",
+    "llama2_13b": "daryl149/llama-2-13b-chat-hf",
+    "llama2_70b": "daryl149/llama-2-70b-chat-hf",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
 }
 
@@ -186,7 +186,6 @@ def chat(
         vicuna_model = UnshardedVicuna(
             model_name,
             hf_model_path=model_path,
-            hf_auth_token=args.hf_auth_token,
             device=device,
             precision=precision,
             max_num_tokens=max_toks,

shark/iree_utils/compile_utils.py

Lines changed: 0 additions & 1 deletion
@@ -356,7 +356,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
 def load_vmfb_using_mmap(
     flatbuffer_blob_or_path, device: str, device_idx: int = None
 ):
-    print(f"Loading module {flatbuffer_blob_or_path}...")
     if "rocm" in device:
         device = "rocm"
     with DetailLogger(timeout=2.5) as dl:
