
Commit 782148d

feat: introduce vLLM compatibility layer for improved model execution
- Added a new module `vllm_compat.py` to manage vLLM-specific configuration and group coordination for tensor parallelism.
- Integrated vLLM context management into `ModelRunnerBase` to ensure proper configuration during model loading.
- Updated linear-layer operations to use the new vLLM group coordination, enhancing distributed execution.
- Implemented graph-capture support in the multi-block model runner, optimizing CUDA stream management for performance.
1 parent 313883e commit 782148d

4 files changed: 131 additions & 43 deletions
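
Taken together, the four changes below wire vLLM into three points of the runner lifecycle: model construction, tensor-parallel collectives, and CUDA graph capture, with explicit teardown. A minimal sketch of that flow, assuming a Diffulex `Config` plus stand-in `build_model`/`capture_fn` callables (only the `diffulex.vllm_compat` calls come from this commit):

```python
import torch

from diffulex.vllm_compat import (
    get_vllm_tp_group,
    reset_vllm_compat_state,
    vllm_current_config,
    vllm_graph_capture,
)


def run_lifecycle(config, build_model, capture_fn):
    # 1. Build the model while a minimal VllmConfig is installed, so any
    #    vLLM layers constructed here see Diffulex's TP/DP sizes.
    with vllm_current_config(config):
        model = build_model(config)

    # 2. Collectives: linear layers prefer the cached vLLM coordinator and
    #    fall back to torch.distributed when it is None.
    tp_group = get_vllm_tp_group()

    # 3. Capture a CUDA graph inside vLLM's capture-side contexts.
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    pool = torch.cuda.graph_pool_handle()
    with vllm_graph_capture(stream, pool) as capture_stream:
        with torch.cuda.graph(graph, pool=pool, stream=capture_stream):
            capture_fn(model)
    stream.synchronize()

    # 4. Teardown mirrors ModelRunnerBase.exit(): drop the coordinator
    #    before resetting Diffulex's own parallel state.
    reset_vllm_compat_state()
    return graph, tp_group
```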


diffulex/engine/model_runner.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 from diffulex.model import AutoModelForDiffusionLM
 from diffulex.engine.strategy_registry import DiffulexStrategyRegistry
 from diffulex.logger import get_logger
+from diffulex.vllm_compat import reset_vllm_compat_state, vllm_current_config


 logger = get_logger(__name__)
@@ -100,7 +101,8 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
         torch.set_default_dtype(self.default_dtype)
         torch.set_default_device(f"cuda:{device_id}")

-        self.model = self.load_model(config)
+        with vllm_current_config(config):
+            self.model = self.load_model(config)
         self.sampler = self.load_sampler(config)
         self.allocate_kv_cache()
         self.warmup_model()
@@ -146,6 +148,7 @@ def exit(self):
             dist.destroy_process_group()
         except Exception:
             logger.debug("Failed to destroy process group on rank %s.", self.rank, exc_info=True)
+        reset_vllm_compat_state()
         reset_parallel_state()

     def start_worker_loop(self):
```
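
The `vllm_current_config` wrapper matters because vLLM layer modules read a process-global "current vLLM config" at construction time. A small sketch of what becomes observable inside the block (`get_current_vllm_config` is vLLM's accessor for that global; the print is illustrative only):

```python
from vllm.config import get_current_vllm_config

from diffulex.vllm_compat import vllm_current_config


def show_effective_parallel_config(config) -> None:
    # Inside the context, vLLM modules constructed here observe the
    # tensor/data-parallel sizes Diffulex derived from `config`.
    with vllm_current_config(config):
        vllm_config = get_current_vllm_config()
        print(vllm_config.parallel_config.tensor_parallel_size)
```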

diffulex/layer/linear.py

Lines changed: 2 additions & 40 deletions
```diff
@@ -6,49 +6,11 @@
 import torch.distributed as dist

 from diffulex.distributed.parallel_state import fetch_parallel_state
-
-_VLLM_TP_GROUP = None
-_VLLM_TP_GROUP_FAILED = False
-
-
-def _tp_group_ranks_for_vllm() -> list[list[int]]:
-    state = fetch_parallel_state()
-    tp_size = int(state.tp_size)
-    if tp_size <= 1:
-        return [[int(state.global_rank)]]
-    dp_size = int(state.dp_size)
-    return [
-        list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
-        for dp_rank in range(dp_size)
-    ]
-
-
-def _get_vllm_tp_group():
-    global _VLLM_TP_GROUP, _VLLM_TP_GROUP_FAILED
-    if _VLLM_TP_GROUP is not None:
-        return _VLLM_TP_GROUP
-    if _VLLM_TP_GROUP_FAILED:
-        return None
-
-    try:
-        from vllm.distributed.parallel_state import GroupCoordinator, set_custom_all_reduce
-
-        set_custom_all_reduce(True)
-        _VLLM_TP_GROUP = GroupCoordinator(
-            group_ranks=_tp_group_ranks_for_vllm(),
-            local_rank=int(torch.cuda.current_device()),
-            torch_distributed_backend=dist.get_backend(),
-            use_device_communicator=True,
-            group_name="tp",
-        )
-        return _VLLM_TP_GROUP
-    except Exception:
-        _VLLM_TP_GROUP_FAILED = True
-        return None
+from diffulex.vllm_compat import get_vllm_tp_group


 def tp_all_reduce(x: torch.Tensor, group) -> torch.Tensor:
-    vllm_tp_group = _get_vllm_tp_group()
+    vllm_tp_group = get_vllm_tp_group()
     if vllm_tp_group is not None:
         return vllm_tp_group.all_reduce(x)
     dist.all_reduce(x, group=group)
```
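
A typical caller is a row-parallel linear layer, where each TP rank holds a partial matmul result that must be summed across ranks. A hedged usage sketch, assuming the elided tail of `tp_all_reduce` (outside this hunk) returns the reduced tensor after the in-place `dist.all_reduce` fallback:

```python
import torch

from diffulex.layer.linear import tp_all_reduce


def row_parallel_output(x_partial: torch.Tensor, tp_process_group) -> torch.Tensor:
    # With vLLM present this routes through GroupCoordinator.all_reduce
    # (out-of-place, custom kernels); otherwise torch.distributed reduces
    # x_partial in place on the supplied process group.
    return tp_all_reduce(x_partial, tp_process_group)
```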

diffulex/strategy_template/multi_block/engine/model_runner.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@
 from diffulex.engine.status import DllmReqStatus
 from diffulex.engine.model_runner import ModelRunnerBase
 from diffulex.logger import get_logger
+from diffulex.vllm_compat import vllm_graph_capture

 logger = get_logger(__name__)

@@ -147,8 +148,9 @@ def run_once() -> None:

         torch.cuda.synchronize()
         self._graph_capture_barrier()
-        with torch.cuda.graph(graph, pool=pool, stream=stream):
-            run_once()
+        with vllm_graph_capture(stream, pool) as capture_stream:
+            with torch.cuda.graph(graph, pool=pool, stream=capture_stream):
+                run_once()
         stream.synchronize()
         return graph
```
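
The wrapper only changes the capture phase; replay is untouched. For context, a minimal capture-then-replay sketch under the same pattern (hypothetical `run_once` that reads from and writes into static buffers):

```python
import torch

from diffulex.vllm_compat import vllm_graph_capture


def capture_and_replay(run_once, static_input: torch.Tensor) -> torch.cuda.CUDAGraph:
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    pool = torch.cuda.graph_pool_handle()

    # Capture: when the vLLM coordinator is live, its graph_capture() context
    # supplies the capture stream; otherwise the raw stream is used directly.
    with vllm_graph_capture(stream, pool) as capture_stream:
        with torch.cuda.graph(graph, pool=pool, stream=capture_stream):
            run_once()
    stream.synchronize()

    # Replay: refresh the static input buffer, then relaunch the recorded kernels.
    static_input.copy_(torch.randn_like(static_input))
    graph.replay()
    return graph
```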

diffulex/vllm_compat.py

Lines changed: 121 additions & 0 deletions
```diff
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+from contextlib import contextmanager, nullcontext
+from typing import Iterator
+
+import torch
+import torch.distributed as dist
+
+from diffulex.config import Config
+from diffulex.distributed.parallel_state import fetch_parallel_state
+from diffulex.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+_VLLM_TP_GROUP = None
+_VLLM_TP_GROUP_FAILED = False
+
+
+def _tp_group_ranks_for_vllm() -> list[list[int]]:
+    state = fetch_parallel_state()
+    tp_size = int(state.tp_size)
+    if tp_size <= 1:
+        return [[int(state.global_rank)]]
+    dp_size = int(state.dp_size)
+    return [
+        list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
+        for dp_rank in range(dp_size)
+    ]
+
+
+def get_vllm_tp_group():
+    """Return a vLLM GroupCoordinator matching Diffulex TP ranks.
+
+    dInfer routes custom all-reduce and CUDA graph capture through the vLLM /
+    sglang group coordinator. Diffulex owns process-group initialization, so we
+    build the coordinator on top of the already initialized torch distributed
+    group and fall back silently when vLLM is unavailable.
+    """
+    global _VLLM_TP_GROUP, _VLLM_TP_GROUP_FAILED
+    if _VLLM_TP_GROUP is not None:
+        return _VLLM_TP_GROUP
+    if _VLLM_TP_GROUP_FAILED:
+        return None
+
+    try:
+        from vllm.distributed.parallel_state import GroupCoordinator, set_custom_all_reduce
+
+        set_custom_all_reduce(True)
+        _VLLM_TP_GROUP = GroupCoordinator(
+            group_ranks=_tp_group_ranks_for_vllm(),
+            local_rank=int(torch.cuda.current_device()),
+            torch_distributed_backend=dist.get_backend(),
+            use_device_communicator=True,
+            group_name="tp",
+        )
+        return _VLLM_TP_GROUP
+    except Exception:
+        _VLLM_TP_GROUP_FAILED = True
+        logger.debug("Failed to initialize vLLM TP coordinator.", exc_info=True)
+        return None
+
+
+@contextmanager
+def vllm_current_config(config: Config) -> Iterator[None]:
+    """Temporarily install a minimal vLLM config during module construction."""
+    try:
+        from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+
+        parallel_config = ParallelConfig(
+            tensor_parallel_size=int(config.tensor_parallel_size),
+            pipeline_parallel_size=1,
+            data_parallel_size=int(config.data_parallel_size),
+            enable_expert_parallel=bool(int(config.expert_parallel_size) > 1),
+            disable_custom_all_reduce=False,
+            distributed_timeout_seconds=int(config.distributed_timeout_seconds),
+        )
+        with set_current_vllm_config(VllmConfig(parallel_config=parallel_config)):
+            yield
+    except Exception:
+        logger.debug("Using Diffulex model init without vLLM current config.", exc_info=True)
+        yield
+
+
+@contextmanager
+def vllm_graph_capture(stream: torch.cuda.Stream, pool) -> Iterator[torch.cuda.Stream]:
+    """Enter vLLM graph-capture side contexts when available."""
+    try:
+        from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
+
+        set_graph_pool_id(pool)
+    except Exception:
+        logger.debug("Failed to set vLLM graph pool id.", exc_info=True)
+
+    group = get_vllm_tp_group()
+    if group is None or not hasattr(group, "graph_capture"):
+        with torch.cuda.stream(stream):
+            yield stream
+        return
+
+    try:
+        context = getattr(group, "graph_capture")()
+        with context as graph_context:
+            yield getattr(graph_context, "stream", stream)
+    except Exception:
+        logger.debug("vLLM graph_capture context failed; using raw CUDA stream.", exc_info=True)
+        with torch.cuda.stream(stream):
+            yield stream
+
+
+def reset_vllm_compat_state() -> None:
+    global _VLLM_TP_GROUP, _VLLM_TP_GROUP_FAILED
+    group = _VLLM_TP_GROUP
+    _VLLM_TP_GROUP = None
+    _VLLM_TP_GROUP_FAILED = False
+    if group is not None and hasattr(group, "destroy"):
+        try:
+            group.destroy()
+        except Exception:
+            logger.debug("Failed to destroy vLLM TP coordinator.", exc_info=True)
+
```

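Every vLLM import in the new module sits behind a `try`/`except`, so on a machine without vLLM (or before `torch.distributed` is initialized) the layer degrades to a no-op shim. A quick sanity sketch of that fallback path, assuming a single process with no distributed init:

```python
from diffulex.vllm_compat import get_vllm_tp_group, reset_vllm_compat_state

# The first call attempts coordinator construction, catches the failure, and
# caches it; later calls short-circuit to None via _VLLM_TP_GROUP_FAILED.
assert get_vllm_tp_group() is None
assert get_vllm_tp_group() is None

# Reset clears both the cached coordinator and the failure flag, so a
# restarted worker can retry coordinator construction from scratch.
reset_vllm_compat_state()
```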