
Commit be1c653

[CI] fix
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent 89a033e commit be1c653

File tree: 4 files changed (+60 -39 lines)

vllm_ascend/distributed/parallel_state.py

Lines changed: 10 additions & 4 deletions
@@ -4,18 +4,19 @@
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                              init_model_parallel_group)
+from vllm.utils import logger
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 
-from vllm.utils import logger
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
 _MLP_TP: Optional[GroupCoordinator] = None
 
 _LMTP: Optional[GroupCoordinator] = None
 _EMTP: Optional[GroupCoordinator] = None
 
+
 def get_mc2_group() -> GroupCoordinator:
     assert _MC2 is not None, ("mc2 group is not initialized")
     return _MC2
@@ -26,10 +27,12 @@ def get_lmhead_tp_group() -> GroupCoordinator:
         "lm head tensor parallel group is not initialized")
     return _LMTP
 
+
 def get_emtp_group() -> GroupCoordinator:
     assert _EMTP is not None, ("emtp group is not initialized")
     return _EMTP
 
+
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")
     return _MLP_TP
@@ -99,8 +102,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     if embedding_tensor_parallel_size is not None:
         group_ranks = []
         global _EMTP
-        num_embedding_tensor_parallel_groups: int = (world_size //
-                                                     embedding_tensor_parallel_size)
+        num_embedding_tensor_parallel_groups: int = (
+            world_size // embedding_tensor_parallel_size)
         for i in range(num_embedding_tensor_parallel_groups):
             ranks = list(
                 range(i * embedding_tensor_parallel_size,
@@ -110,7 +113,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
             get_world_group().local_rank,
             backend,
             group_name="emtp")
-        logger.info(f"Successfully established embedding communication parallel group with size {embedding_tensor_parallel_size}")
+        logger.info(
+            f"Successfully established embedding communication parallel group with size {embedding_tensor_parallel_size}"
+        )
+
 
 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 47 additions & 32 deletions
@@ -21,21 +21,20 @@
 from torch import nn
 from torch.nn.parameter import Parameter
 from vllm.distributed import divide, tensor_model_parallel_all_reduce
-from vllm.distributed.parallel_state import get_tp_group
-import torch.distributed as dist
+from vllm.distributed.parallel_state import get_dp_group, get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.distributed.parallel_state import get_dp_group
-from vllm.forward_context import get_forward_context
 from vllm.utils import logger
 
-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group, get_emtp_group
-from vllm_ascend.utils import lmhead_tp_enable, embedding_tp_enable
+from vllm_ascend.distributed.parallel_state import (get_emtp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
 
 
 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -150,30 +149,41 @@ def _get_masked_input_and_mask(
         input_ = vocab_mask * (input_ - valid_offset)
         return input_, ~vocab_mask
 
-    def _get_local_batch_slice(self, tensor: torch.Tensor,
-                               batch_sizes: list,
-                               local_batch_size: int,
-                               rank: int) -> torch.Tensor:
+    def _get_local_batch_slice(self, tensor: torch.Tensor, batch_sizes: list,
+                               local_batch_size: int,
+                               rank: int) -> torch.Tensor:
         end_idx = batch_sizes[rank]
         start_idx = end_idx - local_batch_size
         return tensor[start_idx:end_idx]
-    
+
     def forward(self, input_):
         if embedding_tp_enable():
-            logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
+            logger.info(
+                f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
             return self._forward_embed_tp(input_)
         else:
             return self._forward_normal(input_)
-    
+
     def _forward_embed_tp(self, input_):
-        cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-        global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu, prepend=cu_tokens_across_dp_cpu.new_zeros(1))
-        logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n ")
-        lmhead_group_batch_size = [global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks]
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        global_dp_batch_size = torch.diff(
+            cu_tokens_across_dp_cpu,
+            prepend=cu_tokens_across_dp_cpu.new_zeros(1))
+        logger.info(
+            f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n "
+        )
+        lmhead_group_batch_size = [
+            global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks
+        ]
         local_batch_size = input_.size(0)
-        gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu') for batch_size in lmhead_group_batch_size]
-        torch.distributed.all_gather(
-            gathered_input, input_, group=get_lmhead_tp_group().device_group)
+        gathered_input = [
+            torch.empty(batch_size, dtype=input_.dtype, device='npu')
+            for batch_size in lmhead_group_batch_size
+        ]
+        torch.distributed.all_gather(gathered_input,
+                                     input_,
+                                     group=get_lmhead_tp_group().device_group)
         complete_input = torch.cat(gathered_input, dim=0)
         masked_input, input_mask = self._get_masked_input_and_mask(
             complete_input, self.shard_indices.org_vocab_start_index,
@@ -182,20 +192,18 @@ def _forward_embed_tp(self, input_):
             self.shard_indices.added_vocab_start_index,
             self.shard_indices.added_vocab_end_index)
         logger.info(f"all_gather_down complete_input: {complete_input.shape}")
-        
+
         output = self.quant_method.embedding(self, masked_input.long())
         output.masked_fill_(input_mask.unsqueeze(-1), 0)
         output = tensor_model_parallel_all_reduce(output)
         # output = output[lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]-local_batch_size :lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]]
         # Extract the local batch portion from the gathered output
         lmhead_tp_group = get_lmhead_tp_group()
-        output = self._get_local_batch_slice(
-            output,
-            lmhead_group_batch_size,
-            local_batch_size,
-            lmhead_tp_group.rank_in_group
-        )
-        logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
+        output = self._get_local_batch_slice(output, lmhead_group_batch_size,
+                                             local_batch_size,
+                                             lmhead_tp_group.rank_in_group)
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
         return output
 
     def _forward_normal(self, input_):
@@ -209,16 +217,23 @@ def _forward_normal(self, input_):
             self.shard_indices.added_vocab_end_index)
         else:
             masked_input = input_
-        logger.info(f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}"
+        )
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
-        logger.info(f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}")
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}"
+        )
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
         output = tensor_model_parallel_all_reduce(output_parallel)
-        logger.info(f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}"
+        )
         return output
 
 
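The reformatted `_forward_embed_tp` path rests on simple batch-size bookkeeping: per-rank token counts are recovered by differencing the cumulative counts, the gathered tensor is the concatenation of every rank's chunk, and each rank's portion is sliced back out by its offset. Below is a minimal single-process sketch of that bookkeeping using CPU tensors and no collectives; the cumulative counts, the rank index, and the cumulative-end-offset slicing convention are illustrative assumptions, not the module's exact API.

import torch

# Hypothetical cumulative token counts across four DP ranks (CPU tensor),
# standing in for dp_metadata.cu_tokens_across_dp_cpu in the hunk above.
cu_tokens = torch.tensor([3, 7, 8, 12])

# Per-rank batch sizes, recovered the same way as global_dp_batch_size.
per_rank = torch.diff(cu_tokens, prepend=cu_tokens.new_zeros(1))
assert per_rank.tolist() == [3, 4, 1, 4]

# Simulate the result of the all-gather: each rank contributed a chunk whose
# length equals its batch size (values tag the owning rank for readability).
gathered = torch.cat(
    [torch.full((int(n),), r) for r, n in enumerate(per_rank.tolist())])

# Recover rank 2's slice from the concatenation using cumulative end offsets:
# end = cu_tokens[rank], start = end - per_rank[rank].
rank = 2
end_idx = int(cu_tokens[rank])
start_idx = end_idx - int(per_rank[rank])
local = gathered[start_idx:end_idx]
assert local.tolist() == [2]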

vllm_ascend/utils.py

Lines changed: 2 additions & 1 deletion
@@ -557,5 +557,6 @@ def get_ascend_soc_version():
 def lmhead_tp_enable() -> bool:
     return get_ascend_config().lmhead_tensor_parallel_size is not None
 
+
 def embedding_tp_enable() -> bool:
-    return get_ascend_config().embedding_tensor_parallel_size is not None
+    return get_ascend_config().embedding_tensor_parallel_size is not None
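Both helpers follow the same convention: a feature is enabled exactly when its tensor-parallel size is configured (i.e. not None). A tiny self-contained sketch of that gate, using a hypothetical config stand-in rather than the real `get_ascend_config()`:

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for the ascend config consulted by the helpers above;
# only the two fields relevant to this commit are modelled.
@dataclass
class FakeAscendConfig:
    lmhead_tensor_parallel_size: Optional[int] = None
    embedding_tensor_parallel_size: Optional[int] = None

def embedding_tp_enable(cfg: FakeAscendConfig) -> bool:
    # Same "configured means enabled" convention as in vllm_ascend/utils.py.
    return cfg.embedding_tensor_parallel_size is not None

assert not embedding_tp_enable(FakeAscendConfig())
assert embedding_tp_enable(FakeAscendConfig(embedding_tensor_parallel_size=4))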

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@
 from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
-                               lmhead_tp_enable, vllm_version_is,
-                               embedding_tp_enable)
+                               lmhead_tp_enable, vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
