Skip to content

Commit c4086d7

Browse files
committed
rebase + fix some tests
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 484fc83 commit c4086d7

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def make_tensors(config: BatchedMMConfig):
3131
A = torch.randn(
3232
(config.num_experts, config.max_tokens_per_expert, config.K),
3333
device="cuda",
34-
dtype=config.dtype)
34+
dtype=config.dtype) / 10
3535
B = torch.randn((config.num_experts, config.N, config.K),
3636
device="cuda",
3737
dtype=config.dtype)

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ def flatten_tp_across_dp(dp_rank: int):
155155
and vllm_parallel_config.enable_expert_parallel)
156156

157157
dp_size = dp_size_
158-
dp_rank = get_dp_group().rank_in_group
158+
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
159159
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
160160

161161
if not use_ep:
@@ -299,6 +299,7 @@ def get_or_create(self, **kwargs):
299299
# TODO (varun): Add support to switch to intranode
300300
# when all communications are within the same
301301
# node.
302+
logger.debug("Create AllToAll %s", kwargs)
302303
instance = pplx.AllToAll.internode(**kwargs)
303304
self._cache[key] = instance
304305
return instance

0 commit comments

Comments
 (0)