Commit d988cf8

Author: paulyu (committed)
[Bugfix] use tiny_llama_chat to test lora
Signed-off-by: paulyu <paulyu0307@gmail.com>
1 parent 9e538da commit d988cf8

File tree

8 files changed: +191 -118 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 4 additions & 1 deletion
@@ -111,24 +111,27 @@ jobs:
           VLLM_WORKER_MULTIPROC_METHOD: spawn
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
+            pytest -sv tests/singlecard/test_lora_quant.py
             pytest -sv tests/singlecard/test_offline_inference.py
             pytest -sv tests/ops
             pytest -sv tests/compile
           else
+            pytest -sv tests/multicard/test_lora_quant_tp.py
             pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
             pytest -sv tests/ops
             pytest -sv tests/compile
-            pytest -sv tests/lora/test_baichuan_tp.py
           fi

       - name: Run vllm-project/vllm-ascend test on V0 engine
         env:
           VLLM_USE_V1: 0
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
+            pytest -sv tests/singlecard/test_lora_quant.py
             pytest -sv tests/singlecard/test_offline_inference.py
             pytest -sv tests/ops
           else
+            pytest -sv tests/multicard/test_lora_quant_tp.py
             pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
             pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py
             pytest -sv tests/ops
tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -353,5 +353,5 @@ def prompt_template(request):
 
 
 @pytest.fixture(scope="session")
-def baichuan_lora_files():
-    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
+def tinyllama_lora_files():
+    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
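
Note: the session-scoped fixture now resolves to the TinyLlama "colorist" LoRA adapter downloaded from the Hugging Face Hub and returns its local path. A minimal sketch of how a test can consume it (illustrative only; the base model, prompt, and test name are assumptions, and the real tests in this commit go through VllmRunner instead):

# Illustrative only, not part of this commit.
import vllm
from vllm.lora.request import LoRARequest


def test_colorist_lora_smoke(tinyllama_lora_files):
    llm = vllm.LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v0.3", enable_lora=True)
    params = vllm.SamplingParams(temperature=0, max_tokens=16)
    outputs = llm.generate(
        ["<|im_start|>user\nGive me a neon pink color<|im_end|>\n"
         "<|im_start|>assistant\n"],
        params,
        # LoRARequest(name, id, local_path): the path comes from the fixture above.
        lora_request=LoRARequest("colorist", 1, tinyllama_lora_files))
    assert outputs[0].outputs[0].text.strip()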

tests/lora/__init__.py

Whitespace-only changes.

tests/lora/test_baichuan.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/lora/test_baichuan_tp.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

tests/lora/utils.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

tests/multicard/test_lora_quant_tp.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
import pytest
from vllm.distributed import cleanup_dist_env_and_memory

from tests.conftest import VllmRunner
from tests.singlecard.test_lora_quant import MODELS, do_sample


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, model):
    if model.quantization == "GPTQ":
        pytest.skip("GPTQ lora outputs are just incredibly unstable")

    # Greedy-sample with the LoRA adapter on a single card first.
    with VllmRunner(model=model.model_path,
                    quantization=model.quantization,
                    enable_lora=True,
                    max_loras=4,
                    gpu_memory_utilization=0.7,
                    max_num_seqs=16) as vllm_model_tp1:
        output_tp1 = do_sample(vllm_model_tp1, tinyllama_lora_files, lora_id=1)

    # Tear down the first engine before bringing up the TP=2 one.
    del vllm_model_tp1
    cleanup_dist_env_and_memory()

    # Repeat the same request with tensor parallelism across two cards.
    with VllmRunner(model=model.model_path,
                    quantization=model.quantization,
                    enable_lora=True,
                    max_loras=4,
                    tensor_parallel_size=2,
                    gpu_memory_utilization=0.7,
                    max_num_seqs=16) as vllm_model_tp2:
        output_tp2 = do_sample(vllm_model_tp2, tinyllama_lora_files, lora_id=1)

    del vllm_model_tp2
    cleanup_dist_env_and_memory()

    # Greedy outputs must be identical regardless of tensor-parallel size.
    assert output_tp1 == output_tp2

tests/singlecard/test_lora_quant.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import Optional

import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

from tests.conftest import VllmRunner


@dataclass
class ModelWithQuantization:
    model_path: str
    quantization: Optional[str]


MODELS: list[ModelWithQuantization] = [
    ModelWithQuantization(
        model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.3",
        quantization=None),
    # ModelWithQuantization(
    #     model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
    #     quantization="AWQ"),  # AWQ quantization is currently not supported in ROCm. (Ref: https://github.com/vllm-project/vllm/blob/f6518b2b487724b3aa20c8b8224faba5622c4e44/tests/lora/test_quant_model.py#L23)
    # ModelWithQuantization(
    #     model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
    #     quantization="GPTQ"),
]


def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              max_tokens: int = 256) -> list[str]:
    raw_prompts = [
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    ]

    def format_prompt_tuples(prompt):
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = [format_prompt_tuples(p) for p in raw_prompts]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=max_tokens,
                                          stop=["<|im_end|>"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_lora(tinyllama_lora_files, model):

    if model.quantization is None:
        expected_no_lora_output = [
            "Here are some examples of orange-brown colors",
            "I'm sorry, I don't have"
        ]
        expected_lora_output = [
            "#ff8050",
            "#ff8080",
        ]
    elif model.quantization == "AWQ":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
        ]
        expected_lora_output = [
            "#f07700: A v",
            "#f00000: A v",
        ]
    elif model.quantization == "GPTQ":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#f08800: This is",
            "#f07788 \n#",
        ]

    def expect_match(output, expected_output):
        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        if (model.quantization == "GPTQ"
                and expected_output is expected_lora_output):
            assert output != expected_no_lora_output
            for i, o in enumerate(output):
                assert o.startswith(
                    '#'), f"Expected example {i} to start with # but got {o}"
            return
        assert output == expected_output

    max_tokens = 10

    print("creating lora adapter")
    with VllmRunner(model=model.model_path,
                    quantization=model.quantization,
                    enable_lora=True,
                    max_loras=4,
                    max_model_len=400,
                    gpu_memory_utilization=0.7,
                    max_num_seqs=16) as vllm_model:
        print("no lora")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=0,
                           max_tokens=max_tokens)
        expect_match(output, expected_no_lora_output)

        print("lora 1")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=1,
                           max_tokens=max_tokens)
        expect_match(output, expected_lora_output)

        print("no lora")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=0,
                           max_tokens=max_tokens)
        expect_match(output, expected_no_lora_output)

        print("lora 2")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=2,
                           max_tokens=max_tokens)
        expect_match(output, expected_lora_output)

    print("removing lora")

    del vllm_model
    cleanup_dist_env_and_memory()
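
Note: the AWQ and GPTQ variants stay commented out in MODELS above, while their expected-output branches are kept in the test. A sketch of what re-enabling them could look like once quantized LoRA is verified on this platform (an assumption, not part of this commit; model IDs are taken from the commented-out entries):

# Hypothetical follow-up, not part of this commit.
MODELS.extend([
    ModelWithQuantization(
        model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
        quantization="AWQ"),
    ModelWithQuantization(
        model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
        quantization="GPTQ"),
])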
