
Commit f064eed

Add multi-lora support for Triton vLLM backend (#23)
Co-authored-by: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
1 parent a014751 commit f064eed

File tree

7 files changed (+762, -9 lines changed)

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from huggingface_hub import snapshot_download

if __name__ == "__main__":
    # download LoRA adapter weights for GemmaDoll
    snapshot_download(
        repo_id="swathijn/GemmaDoll-2b-dolly-LORA-Tune",
        local_dir="./weights/loras/GemmaDoll",
        max_workers=8,
    )
    # download LoRA adapter weights for GemmaSheep
    snapshot_download(
        repo_id="eduardo-alvarez/GemmaSheep-2B-LORA-TUNED",
        local_dir="./weights/loras/GemmaSheep",
        max_workers=8,
    )
    # download backbone weights google/gemma-2b
    snapshot_download(
        repo_id="unsloth/gemma-2b",
        local_dir="./weights/backbone/gemma-2b",
        max_workers=8,
    )
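
The adapters staged above are the ones the test below selects by the names "doll" and "sheep". For orientation, here is a minimal sketch of how the downloaded weights could be registered with the backend; the multi_lora.json name-to-path mapping, the model.json engine-argument keys (enable_lora, max_lora_rank), and the model-repository path are assumptions based on the usual vLLM-backend multi-LoRA setup and are not part of this diff.

# Sketch only: map the per-request LoRA names used by the test ("doll", "sheep")
# to the downloaded adapter directories and enable LoRA in the engine arguments.
# File names, keys, and the model-repository path are assumed, not taken from
# this commit.
import json
from pathlib import Path

model_dir = Path("model_repository/vllm_llama_multi_lora/1")  # hypothetical layout
model_dir.mkdir(parents=True, exist_ok=True)

# Per-request LoRA name -> locally downloaded adapter directory.
multi_lora = {
    "doll": "./weights/loras/GemmaDoll",
    "sheep": "./weights/loras/GemmaSheep",
}
(model_dir / "multi_lora.json").write_text(json.dumps(multi_lora, indent=4))

# Engine arguments for the backbone, with LoRA support switched on.
engine_args = {
    "model": "./weights/backbone/gemma-2b",
    "enable_lora": True,
    "max_lora_rank": 32,
}
(model_dir / "model.json").write_text(json.dumps(engine_args, indent=4))
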
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import unittest
from functools import partial
from typing import List

import tritonclient.grpc as grpcclient
from tritonclient.utils import *

sys.path.append("../../common")
from test_util import AsyncTestResultCollector, UserData, callback, create_vllm_request

PROMPTS = ["Instruct: What do you think of Computer Science?\nOutput:"]
SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}

server_enable_lora = True


class VLLMTritonLoraTest(AsyncTestResultCollector):
    def setUp(self):
        self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
        self.vllm_model_name = "vllm_llama_multi_lora"

    def _test_vllm_model(
        self,
        prompts: List[str],
        sampling_parameters,
        lora_name: List[str],
        server_enable_lora=True,
        stream=False,
        exclude_input_in_output=None,
        expected_output=None,
    ):
        assert len(prompts) == len(
            lora_name
        ), "The number of prompts and lora names should be the same"
        user_data = UserData()
        number_of_vllm_reqs = len(prompts)

        self.triton_client.start_stream(callback=partial(callback, user_data))
        for i in range(number_of_vllm_reqs):
            lora = lora_name[i] if lora_name else None
            sam_para_copy = sampling_parameters.copy()
            if lora is not None:
                sam_para_copy["lora_name"] = lora
            request_data = create_vllm_request(
                prompts[i],
                i,
                stream,
                sam_para_copy,
                self.vllm_model_name,
                exclude_input_in_output=exclude_input_in_output,
            )
            self.triton_client.async_stream_infer(
                model_name=self.vllm_model_name,
                request_id=request_data["request_id"],
                inputs=request_data["inputs"],
                outputs=request_data["outputs"],
                parameters=sampling_parameters,
            )

        for i in range(number_of_vllm_reqs):
            result = user_data._completed_requests.get()
            if type(result) is InferenceServerException:
                print(result.message())
                if server_enable_lora:
                    self.assertEqual(
                        str(result.message()),
                        f"LoRA {lora_name[i]} is not supported, we currently support ['doll', 'sheep']",
                        "InferenceServerException",
                    )
                else:
                    self.assertEqual(
                        str(result.message()),
                        "LoRA feature is not enabled.",
                        "InferenceServerException",
                    )
                self.triton_client.stop_stream()
                return

            output = result.as_numpy("text_output")
            self.assertIsNotNone(output, "`text_output` should not be None")
            if expected_output is not None:
                self.assertEqual(
                    output,
                    expected_output[i],
                    'Actual and expected outputs do not match.\n \
                    Expected "{}" \n Actual:"{}"'.format(
                        expected_output[i], output
                    ),
                )

        self.triton_client.stop_stream()

    def test_multi_lora_requests(self):
        self.triton_client.load_model(self.vllm_model_name)
        sampling_parameters = {"temperature": "0", "top_p": "1"}
        # issue the two requests separately so their responses cannot arrive out of order
        prompt_1 = ["Instruct: What do you think of Computer Science?\nOutput:"]
        lora_1 = ["doll"]
        expected_output = [
            b" I think it is a very interesting subject.\n\nInstruct: What do you"
        ]
        self._test_vllm_model(
            prompt_1,
            sampling_parameters,
            lora_name=lora_1,
            server_enable_lora=server_enable_lora,
            stream=False,
            exclude_input_in_output=True,
            expected_output=expected_output,
        )

        prompt_2 = ["Instruct: Tell me more about soccer\nOutput:"]
        lora_2 = ["sheep"]
        expected_output = [
            b" I love soccer. I play soccer every day.\nInstruct: Tell me"
        ]
        self._test_vllm_model(
            prompt_2,
            sampling_parameters,
            lora_name=lora_2,
            server_enable_lora=server_enable_lora,
            stream=False,
            exclude_input_in_output=True,
            expected_output=expected_output,
        )
        self.triton_client.unload_model(self.vllm_model_name)

    def test_none_exist_lora(self):
        self.triton_client.load_model(self.vllm_model_name)
        prompts = [
            "Instruct: What is the capital city of France?\nOutput:",
        ]
        loras = ["bactrian"]
        sampling_parameters = {"temperature": "0", "top_p": "1"}
        self._test_vllm_model(
            prompts,
            sampling_parameters,
            lora_name=loras,
            server_enable_lora=server_enable_lora,
            stream=False,
            exclude_input_in_output=True,
            expected_output=None,  # this request triggers a "LoRA not supported" error, so there is no expected output
        )
        self.triton_client.unload_model(self.vllm_model_name)

    def tearDown(self):
        self.triton_client.close()


if __name__ == "__main__":
    server_enable_lora = os.environ.get("SERVER_ENABLE_LORA", "false").lower() == "true"

    unittest.main()
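
To reproduce a single request outside the test harness, the sketch below mirrors what create_vllm_request assembles in the test: the LoRA to apply is selected per request by putting "lora_name" into the JSON sampling parameters. The tensor names ("text_input", "stream", "sampling_parameters", "text_output") follow the vLLM backend's usual I/O convention and the server URL is assumed to be localhost:8001; treat this as an illustration, not part of the commit.

# Minimal standalone client sketch (assumptions noted above): select a LoRA per
# request via the "lora_name" field of the JSON sampling parameters.
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def _on_response(result_queue, result, error):
    # Stream callback: hand either the error or the result back to the caller.
    result_queue.put(error if error is not None else result)


def infer_with_lora(prompt, lora_name, model_name="vllm_llama_multi_lora"):
    sampling_parameters = {"temperature": "0", "top_p": "1", "lora_name": lora_name}

    # Prompt, streaming flag, and JSON-encoded sampling parameters as model inputs.
    text_input = grpcclient.InferInput("text_input", [1], "BYTES")
    text_input.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))
    stream_flag = grpcclient.InferInput("stream", [1], "BOOL")
    stream_flag.set_data_from_numpy(np.array([False], dtype=bool))
    params = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
    params.set_data_from_numpy(
        np.array([json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_)
    )

    results = queue.Queue()
    client = grpcclient.InferenceServerClient(url="localhost:8001")
    client.start_stream(callback=partial(_on_response, results))
    client.async_stream_infer(
        model_name=model_name,
        inputs=[text_input, stream_flag, params],
        outputs=[grpcclient.InferRequestedOutput("text_output")],
        request_id="1",
    )
    response = results.get()
    client.stop_stream()
    client.close()
    if isinstance(response, Exception):
        raise response
    return response.as_numpy("text_output")


if __name__ == "__main__":
    print(infer_with_lora("Instruct: Tell me more about soccer\nOutput:", "sheep"))
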
