Elastic Expert Parallel Initial Support #20775

Merged on Jul 19, 2025 (35 commits)
Commits
All 35 commits are by ruisearch42.

5c78497  eep basic (Jul 9, 2025)
cbab40d  fixes (Jul 9, 2025)
ea87424  clean up placement_group functions (Jul 10, 2025)
cb183c2  cleanup (Jul 10, 2025)
4b543ea  cleanup (Jul 10, 2025)
a9cbefe  cleanup (Jul 10, 2025)
637aca2  cleanup (Jul 10, 2025)
26af1a8  ray cleanup (Jul 10, 2025)
a804a0f  reorg dir (Jul 10, 2025)
f3c0360  minor refactor (Jul 10, 2025)
6507536  fix repeated scale up (Jul 10, 2025)
11feb5c  move nvshmem.patch (Jul 11, 2025)
7799fb2  Merge branch 'main' into eep_m1 (Jul 13, 2025)
2ec9ddc  factor out RayDPClient (Jul 13, 2025)
07e6719  use middleware (Jul 14, 2025)
a28f04a  Merge branch 'main' into eep_m1 (Jul 14, 2025)
0aec946  rename (Jul 14, 2025)
91799cd  msgspec for SCALE_DP (Jul 15, 2025)
86bc80d  int32 (Jul 15, 2025)
cb9e71c  Merge branch 'main' into eep_m1 (Jul 15, 2025)
d504cbb  up (Jul 15, 2025)
499fe95  fix CI (Jul 16, 2025)
c544b36  assert ray backend (Jul 16, 2025)
f44775a  MAX_EXPERT_REDUNDANCY (Jul 16, 2025)
ac505d6  SCALE_DP & port alloc (Jul 16, 2025)
a1f13b2  fix (Jul 17, 2025)
329c445  add install files (Jul 17, 2025)
9250d78  update serve.sh (Jul 17, 2025)
81b14bb  update bench.sh (Jul 17, 2025)
37c897f  address comments (Jul 18, 2025)
be969ec  comments (Jul 18, 2025)
32b96be  refactor reinitialize_distributed (Jul 18, 2025)
0ab9675  Merge branch 'main' into eep_m1 (Jul 18, 2025)
e59fd3a  single scale API (Jul 18, 2025)
cbd9966  up (Jul 18, 2025)
57 changes: 57 additions & 0 deletions examples/online_serving/elastic_ep/bench.sh
@@ -0,0 +1,57 @@
#!/bin/bash

MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
HOST="localhost"
PORT=8006
NUM_PROMPTS=20
REQUEST_RATE=5

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME="$2"
            shift 2
            ;;
        --local-model)
            MODEL_NAME=$LOCAL_MODEL_PATH
            shift
            ;;
        --host)
            HOST="$2"
            shift 2
            ;;
        --port)
            PORT="$2"
            shift 2
            ;;
        --num-prompts)
            NUM_PROMPTS="$2"
            shift 2
            ;;
        --request-rate)
            REQUEST_RATE="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --model MODEL_NAME    Set model name or path (default: deepseek-ai/DeepSeek-V2-Lite)"
            echo "  --local-model         Use local model path (convenience option)"
            echo "  --host HOST           Set host address (default: localhost)"
            echo "  --port PORT           Set port number (default: 8006)"
            echo "  --num-prompts NUM     Set number of prompts (default: 20)"
            echo "  --request-rate RATE   Set request rate in requests/s (default: 5)"
            echo "  -h, --help            Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use -h or --help for usage information"
            exit 1
            ;;
    esac
done

vllm bench serve \
    --model $MODEL_NAME \
    --host $HOST \
    --port $PORT \
    --num-prompts $NUM_PROMPTS \
    --request-rate $REQUEST_RATE
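
For reference, a typical invocation of this benchmark script against a server brought up with the serve script below might look as follows (the flag values here are illustrative, not required):

# Benchmark an already-running server on localhost:8006 with a heavier load
# than the script defaults.
./bench.sh --host localhost --port 8006 --num-prompts 100 --request-rate 10
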
53 changes: 53 additions & 0 deletions examples/online_serving/elastic_ep/scale.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import sys

import requests


def scale(host, port, new_dp_size):
    url = f"http://{host}:{port}/scale_elastic_ep"
    payload = {"new_data_parallel_size": new_dp_size}
    headers = {"Content-Type": "application/json"}

    print(f"Sending scale request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)

        print(f"Status Code: {response.status_code}")
        print(f"Response: {response.text}")

        if response.status_code == 200:
            print("Scale up/down request successful!")
            return True
        else:
            print("Scale up/down request failed!")
            return False

    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Test scale up/down functionality")
    parser.add_argument("--host", default="localhost", help="API server host")
    parser.add_argument("--port", type=int, default=8006, help="API server port")
    parser.add_argument(
        "--new-dp-size", type=int, default=2, help="New data parallel size"
    )

    args = parser.parse_args()

    success = scale(args.host, args.port, args.new_dp_size)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
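
The same request can also be issued without the helper script; a minimal equivalent using curl (assuming a server listening on the default localhost:8006 used above, and an illustrative target size of 3) would be:

# POST the new data-parallel size to the elastic-EP scaling endpoint.
curl -X POST http://localhost:8006/scale_elastic_ep \
    -H "Content-Type: application/json" \
    -d '{"new_data_parallel_size": 3}'
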
72 changes: 72 additions & 0 deletions examples/online_serving/elastic_ep/serve_deepseek_v2.sh
@@ -0,0 +1,72 @@
#!/bin/bash

HOST="0.0.0.0"
PORT=8006
DATA_PARALLEL_SIZE=4
REDUNDANT_EXPERTS=0
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"

while [[ $# -gt 0 ]]; do
    case $1 in
        --dp)
            DATA_PARALLEL_SIZE="$2"
            shift 2
            ;;
        --re)
            REDUNDANT_EXPERTS="$2"
            shift 2
            ;;
        --host)
            HOST="$2"
            shift 2
            ;;
        --port)
            PORT="$2"
            shift 2
            ;;
        --model)
            MODEL_NAME="$2"
            shift 2
            ;;
        --local-model)
            MODEL_NAME=$LOCAL_MODEL_PATH
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --dp SIZE             Set data parallel size (default: 4)"
            echo "  --re SIZE             Set redundant experts (default: 0)"
            echo "  --host HOST           Set host address (default: 0.0.0.0)"
            echo "  --port PORT           Set port number (default: 8006)"
            echo "  --model MODEL_NAME    Set model name or path"
            echo "  -h, --help            Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use -h or --help for usage information"
            exit 1
            ;;
    esac
done

echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"

export RAY_DEDUP_LOGS=0
export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1

vllm serve $MODEL_NAME \
    --data-parallel-size $DATA_PARALLEL_SIZE \
    --data-parallel-size-local $DATA_PARALLEL_SIZE \
    --data-parallel-backend ray \
    --enforce-eager \
    --enable-expert-parallel \
    --enable-eplb \
    --num-redundant-experts $REDUNDANT_EXPERTS \
    --trust-remote-code \
    --host $HOST \
    --port $PORT
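
Putting the three example scripts together, a minimal end-to-end elastic-EP session could look like the sketch below (the sizes and port are illustrative, and the server must finish starting before the scale and benchmark steps):

# 1. Start the server with 4 data-parallel ranks on the Ray backend.
./serve_deepseek_v2.sh --dp 4 --port 8006

# 2. Scale the running engine down to 2 data-parallel ranks.
python scale.py --port 8006 --new-dp-size 2

# 3. Benchmark the resized deployment.
./bench.sh --port 8006 --num-prompts 20 --request-rate 5
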
92 changes: 92 additions & 0 deletions tools/ep_kernels/elastic_ep/eep_nvshmem.patch
Review comment (Collaborator): FYI @mnicely - this PR is important for autoscaling large-scale distributed MoE inference. It would be great to upstream any changes necessary for changing the world_size.

Reply: Thanks for the ping. I'll bring to the team.

@@ -0,0 +1,92 @@
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up

fix double free
---
src/host/init/init.cu | 10 ++++++++++
.../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++
src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++
3 files changed, 30 insertions(+)

diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
#include "internal/host/nvshmemi_types.h"
#include "internal/host/shared_memory.h"
#include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"

extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
/* Multi-init Multi-fini*/
nvshmemi_state = NULL;
nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+ // eep-dev
+ nvshmemi_mem_p2p_transport::destroy_instance();
+ nvshmemi_mem_remote_transport::destroy_instance();
+ free(nvshmemi_default_session);
+ nvshmemi_default_session = nullptr;
+ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
nvshmemi_is_device_state_ready = false;
} else
nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final {
return p2p_objref_;
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (p2p_objref_ != nullptr) {
+ delete p2p_objref_;
+ p2p_objref_ = nullptr;
+ }
+ }

void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);

@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final {
}
}

+ // eep-dev
+ static void destroy_instance(void) {
+ if (remote_objref_ != nullptr) {
+ delete remote_objref_;
+ remote_objref_ = nullptr;
+ }
+ }
+
int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
/* On-demand registration and release of memory */
int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
// Discover the network for bootstrap, if not done previously.
// This code needs to be stateful to be able to be called multiple times by the caller
BOOTSTRAP_CHECK(bootstrap_net_init());
+ // eep-dev
+ if (handle->pre_init_ops != nullptr) {
+ BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+ handle->pre_init_ops = nullptr;
+ }
if (handle->pre_init_ops == nullptr) {
BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
--
2.43.0
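
The patch above is applied on top of DeepEP's own nvshmem.patch; the install script below automates this, but the layering can be sketched roughly as follows (run inside an NVSHMEM 3.2.5 source tree; the patch paths are illustrative):

git init                                               # initialize a repo so the patches can be applied with git
git apply -vvv nvshmem.patch                           # DeepEP patch (removes the GDRCOPY requirement)
git apply --reject --whitespace=fix eep_nvshmem.patch  # this patch: clean up global state on finalize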

86 changes: 86 additions & 0 deletions tools/ep_kernels/elastic_ep/install_eep_libraries.sh
Review comment (Collaborator): We need to get nvshmem + deepep built in the vLLM image.

Reply (Author): thanks, can we do it as a follow up?

Reply (Collaborator): yeah that was just a sidenote, not something for this PR

@@ -0,0 +1,86 @@
#!/bin/bash

set -ex

# Default workspace directory
WORKSPACE=$(pwd)/eep_kernels_workspace
INSTALL_NVSHMEM=true

# Parse command line arguments
while getopts "w:n" opt; do
    case $opt in
        w)
            WORKSPACE="$OPTARG"
            ;;
        n)
            INSTALL_NVSHMEM=false
            ;;
        \?)
            echo "Invalid option: -$OPTARG" >&2
            exit 1
            ;;
    esac
done

if [ ! -d "$WORKSPACE" ]; then
    mkdir -p $WORKSPACE
fi


# install dependencies if not installed
pip3 install cmake torch ninja

# build nvshmem
pushd $WORKSPACE
# Download and patch the NVSHMEM source unless -n was given
if [ "$INSTALL_NVSHMEM" = true ]; then
    mkdir -p nvshmem_src
    wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
    tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
    pushd nvshmem_src
    wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
    git init
    git apply -vvv nvshmem.patch
    git apply --reject --whitespace=fix ../../eep_nvshmem.patch
else
Review comment on lines +38 to +45 (Collaborator): Could you upgrade to 3.3.9, since it has the performance improvements from the DeepEP patch? (BTW please double check performance as well, if you have the bandwidth to do so)

deepseek-ai/DeepEP#267 (comment)

Reply (Author): thanks. Can we do it as a follow up?

Right now in this initial PR we only support PPLX. And the version 3.2.5-1 is consistent with the current DeepEP installation script.

The DeepEP nvshmem.patch is applied now for a few reasons: 1) we will support DeepEP eventually; 2) it is consistent with the current DeepEP installation script; 3) it removes the need for GDRCOPY, without the patch the nvshmem compilation fails.

Reply: Currently, we only need our nvshmem patch that clears out all global communication states during nvshmem_finalize so we can create a new communication group with a new set of participant GPUs.

    pushd nvshmem_src
fi

# assume CUDA_HOME is set correctly
if [ -z "$CUDA_HOME" ]; then
    echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
    exit 1
fi

# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1

export NVSHMEM_SHMEM_SUPPORT=0
export NVSHMEM_UCX_SUPPORT=0
export NVSHMEM_USE_NCCL=0
export NVSHMEM_PMIX_SUPPORT=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_IBRC_SUPPORT=0
export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLES=0
export NVSHMEM_MPI_SUPPORT=0
export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
export NVSHMEM_BUILD_TXZ_PACKAGE=0

cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
cmake --build $WORKSPACE/nvshmem_build/ --target install

popd

export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH

# build and install pplx kernels; requires PyTorch to be installed
pushd $WORKSPACE
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
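
Typical invocations of this installer (the workspace path is illustrative) look like:

# Build patched NVSHMEM plus pplx-kernels into a custom workspace.
./install_eep_libraries.sh -w /path/to/eep_workspace

# Reuse an already-downloaded and patched NVSHMEM source tree
# (skip the download/patch step) and rebuild.
./install_eep_libraries.sh -n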

13 changes: 13 additions & 0 deletions vllm/config.py
@@ -2004,6 +2004,19 @@ def has_unfinished_dp(dp_group: "ProcessGroup",
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
                                  kv_cache_memory: int) -> int:
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory],
                              dtype=torch.int64,
                              device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs