
Commit 18a96e3

fix: Enhance checks around KIND_GPU and tensor parallelism (#42)
Co-authored-by: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
1 parent 2a1691a commit 18a96e3

3 files changed: +206 additions, -56 deletions
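In short, the backend now rejects a KIND_GPU model instance whose vLLM engine config requests tensor_parallel_size > 1, and pins single-GPU KIND_GPU instances to the device ID Triton assigns them. As a minimal, hypothetical client-side sketch of the new failure mode (the URL and model name are illustrative; the name follows the naming scheme used by the updated test.sh, and the server is assumed to run with --model-control-mode=explicit as in the CI script):

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

client = grpcclient.InferenceServerClient(url="localhost:8001")
try:
    # A model configured with KIND_GPU in config.pbtxt and
    # "tensor_parallel_size": 2 in model.json is now rejected at load time.
    client.load_model("vllm_opt_KIND_GPU_tp2_count1")
except InferenceServerException as e:
    # The backend's ValueError surfaces here with a pointer to KIND_MODEL.
    assert "please specify KIND_MODEL" in str(e)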

ci/L0_multi_gpu/vllm_backend/test.sh

Lines changed: 82 additions & 28 deletions
@@ -31,50 +31,104 @@ TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
 SERVER=${TRITON_DIR}/bin/tritonserver
 BACKEND_DIR=${TRITON_DIR}/backends
 SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1"
-SERVER_LOG="./vllm_multi_gpu_test_server.log"
-CLIENT_LOG="./vllm_multi_gpu_test_client.log"
 TEST_RESULT_FILE='test_results.txt'
 CLIENT_PY="./vllm_multi_gpu_test.py"
 SAMPLE_MODELS_REPO="../../../samples/model_repository"
 EXPECTED_NUM_TESTS=1
 
-rm -rf models && mkdir -p models
-cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
-sed -i '3s/^/    "tensor_parallel_size": 2,\n/' models/vllm_opt/1/model.json
+### Helpers
+function validate_file_contains() {
+    local KEY="${1}"
+    local FILE="${2}"
 
-RET=0
+    if [ -z "${KEY}" ] || [ -z "${FILE}" ]; then
+        echo "Error: KEY and FILE must be provided."
+        return 1
+    fi
 
-run_server
-if [ "$SERVER_PID" == "0" ]; then
-    cat $SERVER_LOG
-    echo -e "\n***\n*** Failed to start $SERVER\n***"
-    exit 1
-fi
+    if [ ! -f "${FILE}" ]; then
+        echo "Error: File '${FILE}' does not exist."
+        return 1
+    fi
 
-set +e
-python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
+    count=$(grep -o -w "${KEY}" "${FILE}" | wc -l)
+
+    if [ "${count}" -ne 1 ]; then
+        echo "Error: KEY '${KEY}' found ${count} times in '${FILE}'. Expected exactly once."
+        return 1
+    fi
+}
+
+function run_multi_gpu_test() {
+    export KIND="${1}"
+    export TENSOR_PARALLELISM="${2}"
+    export INSTANCE_COUNT="${3}"
+
+    # Setup a clean model repository
+    export TEST_MODEL="vllm_opt_${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}"
+    local TEST_MODEL_TRITON_CONFIG="models/${TEST_MODEL}/config.pbtxt"
+    local TEST_MODEL_VLLM_CONFIG="models/${TEST_MODEL}/1/model.json"
+
+    rm -rf models && mkdir -p models
+    cp -r "${SAMPLE_MODELS_REPO}/vllm_model" "models/${TEST_MODEL}"
+    sed -i "s/KIND_MODEL/${KIND}/" "${TEST_MODEL_TRITON_CONFIG}"
+    sed -i "3s/^/    \"tensor_parallel_size\": ${TENSOR_PARALLELISM},\n/" "${TEST_MODEL_VLLM_CONFIG}"
+    # Assert the correct kind is set in case the template config changes in the future
+    validate_file_contains "${KIND}" "${TEST_MODEL_TRITON_CONFIG}"
+
+    # Start server
+    echo "Running multi-GPU test with kind=${KIND}, tp=${TENSOR_PARALLELISM}, instance_count=${INSTANCE_COUNT}"
+    SERVER_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--server.log"
+    run_server
+    if [ "$SERVER_PID" == "0" ]; then
+        cat $SERVER_LOG
+        echo -e "\n***\n*** Failed to start $SERVER\n***"
+        exit 1
+    fi
+
+    # Run unit tests
+    set +e
+    CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log"
+    python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
 
-if [ $? -ne 0 ]; then
-    cat $CLIENT_LOG
-    echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
-    RET=1
-else
-    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
     if [ $? -ne 0 ]; then
         cat $CLIENT_LOG
-        echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+        echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
         RET=1
+    else
+        check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+        if [ $? -ne 0 ]; then
+            cat $CLIENT_LOG
+            echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+            RET=1
+        fi
     fi
-fi
-set -e
+    set -e
+
+    # Cleanup
+    kill $SERVER_PID
+    wait $SERVER_PID
+}
+
+### Test
+rm -f *.log
+RET=0
 
-kill $SERVER_PID
-wait $SERVER_PID
-rm -rf models/
+# Test the various cases of kind, tensor parallelism, and instance count
+# for different ways to run multi-GPU models with vLLM on Triton
+KINDS="KIND_MODEL KIND_GPU"
+TPS="1 2"
+INSTANCE_COUNTS="1 2"
+for kind in ${KINDS}; do
+    for tp in ${TPS}; do
+        for count in ${INSTANCE_COUNTS}; do
+            run_multi_gpu_test "${kind}" "${tp}" "${count}"
+        done
+    done
+done
 
+### Results
 if [ $RET -eq 1 ]; then
-    cat $CLIENT_LOG
-    cat $SERVER_LOG
     echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***"
 else
     echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***"
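For reference, the loop above produces eight (kind, tp, instance_count) combinations, and the client test decides which ones actually run. A small illustrative sketch (not part of the commit) that mirrors the skip/expect-failure logic of test_multi_gpu_model in vllm_multi_gpu_test.py:

# Illustration only: restates which combinations the 2-GPU client test exercises.
for kind in ("KIND_MODEL", "KIND_GPU"):
    for tp in (1, 2):
        for count in (1, 2):
            if tp * count != 2:
                outcome = "skipped (tp * count must equal 2 for this 2-GPU test)"
            elif kind == "KIND_GPU" and tp > 1:
                outcome = "expects model load to fail with 'please specify KIND_MODEL'"
            elif kind == "KIND_MODEL" and count > 1:
                outcome = "skipped (multi-instance KIND_MODEL not implemented yet)"
            else:
                outcome = "loads the model and runs sanity inference across 2 GPUs"
            print(f"{kind} tp={tp} count={count}: {outcome}")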

ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py

Lines changed: 65 additions & 7 deletions
@@ -24,6 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import os
 import sys
 import unittest
 from functools import partial
@@ -40,7 +41,6 @@ class VLLMMultiGPUTest(TestResultCollector):
     def setUp(self):
         pynvml.nvmlInit()
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
-        self.vllm_model_name = "vllm_opt"
 
     def get_gpu_memory_utilization(self, gpu_id):
         handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
@@ -56,7 +56,12 @@ def get_available_gpu_ids(self):
                 available_gpus.append(gpu_id)
         return available_gpus
 
-    def test_vllm_multi_gpu_utilization(self):
+    def _test_vllm_multi_gpu_utilization(self, model_name: str):
+        """
+        Test that loading a given vLLM model will increase GPU utilization
+        across multiple GPUs, and run a sanity check inference to confirm
+        that the loaded multi-gpu/multi-instance model is working as expected.
+        """
         gpu_ids = self.get_available_gpu_ids()
         self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU")
 
@@ -67,8 +72,8 @@ def test_vllm_multi_gpu_utilization(self):
             print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
             mem_util_before_loading_model[gpu_id] = memory_utilization
 
-        self.triton_client.load_model(self.vllm_model_name)
-        self._test_vllm_model()
+        self.triton_client.load_model(model_name)
+        self._test_vllm_model(model_name)
 
         print("=============== After Loading vLLM Model ===============")
         vllm_model_used_gpus = 0
@@ -80,7 +85,7 @@ def test_vllm_multi_gpu_utilization(self):
 
         self.assertGreaterEqual(vllm_model_used_gpus, 2)
 
-    def _test_vllm_model(self, send_parameters_as_tensor=True):
+    def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True):
         user_data = UserData()
         stream = False
         prompts = [
@@ -98,11 +103,11 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):
                 i,
                 stream,
                 sampling_parameters,
-                self.vllm_model_name,
+                model_name,
                 send_parameters_as_tensor,
             )
             self.triton_client.async_stream_infer(
-                model_name=self.vllm_model_name,
+                model_name=model_name,
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
@@ -118,6 +123,59 @@ def _test_vllm_model(self, send_parameters_as_tensor=True):
 
         self.triton_client.stop_stream()
 
+    def test_multi_gpu_model(self):
+        """
+        Tests that a multi-GPU vLLM model loads successfully on multiple GPUs
+        and can handle a few sanity check inference requests.
+
+        Multi-GPU models are currently defined here as either:
+          - a single model instance with tensor parallelism > 1
+          - multiple model instances each with tensor parallelism == 1
+
+        FIXME: This test currently skips over a few combinations that may
+        be enhanced in the future, such as:
+          - tensor parallel models with multiple model instances
+          - KIND_MODEL models with multiple model instances
+        """
+        model = os.environ.get("TEST_MODEL")
+        kind = os.environ.get("KIND")
+        tp = os.environ.get("TENSOR_PARALLELISM")
+        instance_count = os.environ.get("INSTANCE_COUNT")
+        for env_var in [model, kind, tp, instance_count]:
+            self.assertIsNotNone(env_var)
+
+        print(f"Test Matrix: {model=}, {kind=}, {tp=}, {instance_count=}")
+
+        # Only support tensor parallelism or multiple instances for now, but not both.
+        # Support for multi-instance tensor parallel models may require more
+        # special handling in the backend to better handle device assignment.
+        # NOTE: This eliminates the 1*1=1 and 2*2=4 test cases.
+        if int(tp) * int(instance_count) != 2:
+            msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test"
+            print("Skipping Test:", msg)
+            self.skipTest(msg)
+
+        # Loading a KIND_GPU model with Tensor Parallelism > 1 should fail and
+        # recommend using KIND_MODEL instead for multi-gpu model instances.
+        if kind == "KIND_GPU" and int(tp) > 1:
+            with self.assertRaisesRegex(
+                InferenceServerException, "please specify KIND_MODEL"
+            ):
+                self._test_vllm_multi_gpu_utilization(model)
+
+            return
+
+        # Loading a KIND_MODEL model with multiple instances can cause
+        # oversubscription to specific GPUs and cause a CUDA OOM if the
+        # gpu_memory_utilization settings are high without further handling
+        # of device assignment in the backend.
+        if kind == "KIND_MODEL" and int(instance_count) > 1:
+            msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time"
+            print("Skipping Test:", msg)
+            self.skipTest(msg)
+
+        self._test_vllm_multi_gpu_utilization(model)
+
     def tearDown(self):
         pynvml.nvmlShutdown()
         self.triton_client.close()
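The utilization check in this test compares per-GPU memory before and after loading the model using pynvml. As a standalone reference (assuming pynvml is installed; this snippet is not part of the diff), the underlying query looks roughly like this:

import pynvml

pynvml.nvmlInit()
try:
    for gpu_id in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        # A loaded vLLM engine shows up as a jump in `used` on each GPU it occupies.
        print(f"GPU {gpu_id}: {mem.used} / {mem.total} bytes used")
finally:
    pynvml.nvmlShutdown()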

src/model.py

Lines changed: 59 additions & 21 deletions
@@ -31,6 +31,7 @@
 from typing import Dict, List
 
 import numpy as np
+import torch
 import triton_python_backend_utils as pb_utils
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -98,12 +99,31 @@ def auto_complete_config(auto_complete_model_config):
         return auto_complete_model_config
 
     def initialize(self, args):
+        self.args = args
         self.logger = pb_utils.Logger
         self.model_config = json.loads(args["model_config"])
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, "text_output"
+        )
+        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+
+        # Prepare vLLM engine
+        self.init_engine()
+
+        # Counter to keep track of ongoing request counts
+        self.ongoing_request_count = 0
+
+        # Starting asyncio event loop to process the received requests asynchronously.
+        self._loop = asyncio.get_event_loop()
+        self._loop_thread = threading.Thread(
+            target=self.engine_loop, args=(self._loop,)
+        )
+        self._shutdown_event = asyncio.Event()
+        self._loop_thread.start()
 
-        # assert are in decoupled mode. Currently, Triton needs to use
-        # decoupled policy for asynchronously forwarding requests to
-        # vLLM engine.
+    def init_engine(self):
+        # Currently, Triton needs to use decoupled policy for asynchronously
+        # forwarding requests to vLLM engine, so assert it.
         self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
             self.model_config
         )
@@ -118,17 +138,25 @@ def initialize(self, args):
             engine_args_filepath
         ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
         with open(engine_args_filepath) as file:
-            vllm_engine_config = json.load(file)
+            self.vllm_engine_config = json.load(file)
+
+        # Validate device and multi-processing settings are currently set based on model/configs.
+        self.validate_device_config()
+
+        # Check for LoRA config and set it up if enabled
+        self.setup_lora()
 
         # Create an AsyncLLMEngine from the config from JSON
         self.llm_engine = AsyncLLMEngine.from_engine_args(
-            AsyncEngineArgs(**vllm_engine_config)
+            AsyncEngineArgs(**self.vllm_engine_config)
         )
+
+    def setup_lora(self):
         self.enable_lora = False
 
         if (
-            "enable_lora" in vllm_engine_config.keys()
-            and vllm_engine_config["enable_lora"].lower() == "true"
+            "enable_lora" in self.vllm_engine_config.keys()
+            and self.vllm_engine_config["enable_lora"].lower() == "true"
         ):
             # create Triton LoRA weights repository
             multi_lora_args_filepath = os.path.join(
@@ -146,21 +174,31 @@ def initialize(self, args):
                     f"Triton backend cannot find {multi_lora_args_filepath}."
                 )
 
-        output_config = pb_utils.get_output_config_by_name(
-            self.model_config, "text_output"
-        )
-        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
-
-        # Counter to keep track of ongoing request counts
-        self.ongoing_request_count = 0
+    def validate_device_config(self):
+        triton_kind = self.args["model_instance_kind"]
+        triton_device_id = int(self.args["model_instance_device_id"])
+        triton_instance = f"{self.args['model_name']}_{triton_device_id}"
+
+        # Triton's current definition of KIND_GPU makes assumptions that
+        # models only use a single GPU. For multi-GPU models, the recommendation
+        # is to specify KIND_MODEL to acknowledge that the model will take control
+        # of the devices made available to it.
+        # NOTE: Consider other parameters that would indicate multi-GPU in the future.
+        tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
+        if tp_size > 1 and triton_kind == "GPU":
+            raise ValueError(
+                "KIND_GPU is currently for single-GPU models, please specify KIND_MODEL "
+                "in the model's config.pbtxt for multi-GPU models"
+            )
 
-        # Starting asyncio event loop to process the received requests asynchronously.
-        self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
-            target=self.engine_loop, args=(self._loop,)
-        )
-        self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        # If KIND_GPU is specified, specify the device ID assigned by Triton to ensure that
+        # multiple model instances do not oversubscribe the same default device.
+        if triton_kind == "GPU" and triton_device_id >= 0:
+            self.logger.log_info(
+                f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
+            )
+            # vLLM doesn't currently (v0.4.2) expose device selection in the APIs
+            torch.cuda.set_device(triton_device_id)
 
     def create_task(self, coro):
         """
