[PERF] Symmetric memory allreduce #20759

Open

wants to merge 8 commits into base: main
2 changes: 1 addition & 1 deletion docs/design/v1/multiprocessing.md
@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.

There are other miscellaneous places hard-coding the use of `spawn`:

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>

Related PRs:
104 changes: 104 additions & 0 deletions tests/distributed/test_symm_mem_allreduce.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import ray
import torch
import torch.distributed as dist

from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group)
from vllm.platforms import current_platform

from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment, multi_process_parallel)

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024


@ray.remote(num_gpus=1, max_calls=1)
def symm_mem_allreduce(
monkeypatch: pytest.MonkeyPatch,
tp_size,
pp_size,
rank,
distributed_init_port,
):
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)

dtype = torch.bfloat16

cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.")

inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
pytest.skip(
"SymmMemCommunicator isn't used for this world and input size."
)

original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None

group = get_tensor_model_parallel_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem,
atol=2.5,
rtol=0.1)

# Test tensor_model_parallel_all_reduce which should use symm_mem
inp_tensor_parallel = torch.randint(-23,
1, (test_size_elements, ),
dtype=dtype,
device=device)
original_inp_tensor_parallel = inp_tensor_parallel.clone()
out_tensor_parallel = tensor_model_parallel_all_reduce(
inp_tensor_parallel)
dist.all_reduce(original_inp_tensor_parallel, group=group)
torch.testing.assert_close(out_tensor_parallel,
original_inp_tensor_parallel,
atol=2.5,
rtol=0.1)


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2, 4])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.parametrize("test_target", [symm_mem_allreduce])
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
pipeline_parallel_size, test_target):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")

# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
test_target)
2 changes: 1 addition & 1 deletion tools/check_pickle_imports.py
@@ -38,7 +38,7 @@
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',
vllm/distributed/device_communicators/all_reduce_utils.py
@@ -23,6 +23,40 @@

logger = init_logger(__name__)

MiB = 1024 * 1024

# Maximum input sizes (in bytes) per world size for the custom all-reduce and
# symmetric-memory all-reduce paths when symmetric memory is enabled,
# keyed by SM architecture (compute capability).
CUSTOM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 1 * MiB, # 1 MB
6: MiB // 2, # 512 KB
8: MiB // 4, # 256 KB
},
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: MiB, # 1 MB
}
}

SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
"10.0": {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}
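
# Illustrative lookup (not part of this diff): these tables are keyed by the
# compute-capability string and then by world size. For example, on SM 9.0
# with 8 ranks, custom all-reduce is capped at
# CUSTOM_ALL_REDUCE_MAX_SIZES["9.0"][8] (256 KiB), while the symmetric-memory
# path handles inputs up to SYMM_MEM_ALL_REDUCE_MAX_SIZES["9.0"][8] (64 MiB).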


def producer(batch_src: Sequence[int],
producer_queue,
15 changes: 15 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -44,6 +44,8 @@ def __init__(self,
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@

self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)

if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@@ -105,6 +114,12 @@ def all_reduce(self, input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and not symm_mem_comm.disabled and \
symm_mem_comm.should_use_symm_mem(input_):
Comment on lines +118 to +119
Contributor
medium

The check not symm_mem_comm.disabled is redundant because should_use_symm_mem already performs this check. Removing the redundant check will make the code more concise.

Suggested change
if symm_mem_comm is not None and not symm_mem_comm.disabled and \
symm_mem_comm.should_use_symm_mem(input_):
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):

out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
11 changes: 9 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -10,8 +10,8 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -109,6 +109,13 @@ def __init__(self,
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
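            # (Illustrative note, not part of this diff) Effect of the cap
            # above: with VLLM_ALLREDUCE_USE_SYMM_MEM enabled, custom
            # all-reduce only handles messages below the per-architecture
            # limit, so larger inputs fall through to SymmMemCommunicator in
            # CudaCommunicator.all_reduce.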

cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
115 changes: 115 additions & 0 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform

try:
import torch.distributed._symmetric_memory as torch_symm_mem

symm_mem_available = True
except ImportError:
symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:

# World sizes where multi-mem all-reduce performs the best
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}

def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True

if not symm_mem_available:
return

if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
Contributor
high

The SymmMemCommunicator is hardcoded to use torch.bfloat16, limiting its use with models using other dtypes. Consider initializing buffers based on the input tensor's dtype during the first all_reduce call to increase flexibility.
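
A rough sketch of that suggestion (hypothetical helper `_maybe_allocate_buffer`, not part of this PR; it assumes `self.buffer` starts as `None` and that all ranks call `all_reduce` with the same dtype so the collective rendezvous stays matched):

    def _maybe_allocate_buffer(self, dtype: torch.dtype) -> None:
        # Lazily (re)allocate the symmetric buffer for the requested dtype
        # instead of hardcoding torch.bfloat16 in __init__.
        if self.buffer is not None and self.dtype == dtype:
            return
        self.dtype = dtype
        self.buffer = torch_symm_mem.empty(
            self.max_size // dtype.itemsize,
            device=self.device,
            dtype=dtype,
        )
        # Rendezvous is collective; every rank in the group must reach it.
        torch_symm_mem.rendezvous(self.buffer, self.group.group_name)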

self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()

if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES \
or self.device_capability not in self._WORLD_SIZES_MULTIMEM:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False

def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size

def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
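
For orientation, standalone use of the communicator looks roughly like this (a sketch assuming an already-initialized process group, CUDA device, and a bfloat16 tensor; `cpu_group`, `rank`, and `tensor` are placeholders, and the dispatch mirrors CudaCommunicator.all_reduce above rather than introducing new API):

    comm = SymmMemCommunicator(group=cpu_group,
                               device=torch.device(f"cuda:{rank}"))
    if comm.should_use_symm_mem(tensor):
        # Fast path: copy into the symmetric buffer and reduce there.
        reduced = comm.all_reduce(tensor)
    else:
        # Fall back to the regular NCCL all-reduce.
        dist.all_reduce(tensor)
        reduced = tensor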
6 changes: 5 additions & 1 deletion vllm/envs.py
@@ -139,6 +139,7 @@
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False


def get_default_cache_root():
@@ -620,7 +621,7 @@ def get_vllm_port() -> Optional[int]:
("1", "true")),

# By default, vLLM will check the peer-to-peer capability itself,
# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/all_reduce_utils.py#L101-L108 for details. # noqa
# If this env var is set to 1, vLLM will skip the peer-to-peer check,
# and trust the driver's peer-to-peer capability report.
"VLLM_SKIP_P2P_CHECK":
@@ -964,6 +965,9 @@ def get_vllm_port() -> Optional[int]:
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
}

# --8<-- [end:env-vars-definition]
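
For reference, a minimal end-to-end way to exercise the new path might look like this (illustrative only; the model name and parallel size are arbitrary, and the flag must be set before vLLM initializes its device communicators):

    import os

    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"  # opt in before vLLM starts

    from vllm import LLM

    # Tensor parallelism >= 2 so the TP all-reduce path is actually exercised.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
    outputs = llm.generate("Symmetric memory all-reduce says hello.")
    print(outputs[0].outputs[0].text)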