# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
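# The test loads a TinyLlama base model (optionally quantized), runs the same
# prompts with and without a LoRA adapter, and compares the greedy outputs
# against expected strings for each quantization mode.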
from dataclasses import dataclass

import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

from tests.conftest import VllmRunner


@dataclass
class ModelWithQuantization:
    model_path: str
    quantization: str


MODELS: list[ModelWithQuantization]
MODELS = [
    ModelWithQuantization(
        model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.3",
        quantization=None),
    # ModelWithQuantization(
    #     model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
    #     quantization="AWQ"),
    # AWQ quantization is currently not supported in ROCm.
    # (Ref: https://github.com/vllm-project/vllm/blob/f6518b2b487724b3aa20c8b8224faba5622c4e44/tests/lora/test_quant_model.py#L23)
    # ModelWithQuantization(
    #     model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
    #     quantization="GPTQ"),
]


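# Generate greedy completions for two color prompts. lora_id=0 sends the
# requests without a LoRA adapter; any other id attaches the adapter found at
# lora_path via a LoRARequest.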
def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              max_tokens: int = 256) -> list[str]:
    raw_prompts = [
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    ]

    def format_prompt_tuples(prompt):
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = [format_prompt_tuples(p) for p in raw_prompts]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=max_tokens,
                                          stop=["<|im_end|>"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_lora(tinyllama_lora_files, model):

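    # Expected greedy outputs (first 10 tokens) for each quantization mode,
    # with and without the LoRA adapter applied.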
    if model.quantization is None:
        expected_no_lora_output = [
            "Here are some examples of orange-brown colors",
            "I'm sorry, I don't have"
        ]
        expected_lora_output = [
            "#ff8050",
            "#ff8080",
        ]
    elif model.quantization == "AWQ":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
        ]
        expected_lora_output = [
            "#f07700: A v",
            "#f00000: A v",
        ]
    elif model.quantization == "GPTQ":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#f08800: This is",
            "#f07788 \n#",
        ]

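    # Helper used below to compare generated texts against the expected
    # outputs for the current quantization mode.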
    def expect_match(output, expected_output):
        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        if (model.quantization == "GPTQ"
                and expected_output is expected_lora_output):
            assert output != expected_no_lora_output
            for i, o in enumerate(output):
                assert o.startswith(
                    '#'), f"Expected example {i} to start with # but got {o}"
            return
        assert output == expected_output

    max_tokens = 10

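    # Alternate base-model requests (lora_id=0) with two different LoRA ids to
    # verify that attaching the adapter changes the output and that dropping
    # back to lora_id=0 restores the base-model behavior.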
    print("creating lora adapter")
    with VllmRunner(model=model.model_path,
                    quantization=model.quantization,
                    enable_lora=True,
                    max_loras=4,
                    max_model_len=400,
                    gpu_memory_utilization=0.7,
                    max_num_seqs=16) as vllm_model:
        print("no lora")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=0,
                           max_tokens=max_tokens)
        expect_match(output, expected_no_lora_output)

        print("lora 1")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=1,
                           max_tokens=max_tokens)
        expect_match(output, expected_lora_output)

        print("no lora")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=0,
                           max_tokens=max_tokens)
        expect_match(output, expected_no_lora_output)

        print("lora 2")
        output = do_sample(vllm_model,
                           tinyllama_lora_files,
                           lora_id=2,
                           max_tokens=max_tokens)
        expect_match(output, expected_lora_output)

    print("removing lora")

    del vllm_model
    cleanup_dist_env_and_memory()