
Commit 9168047

[main] flashcomm_v1 & RS & AG optim in Dense Model
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent dfc7eb3 commit 9168047

8 files changed: +453 −2 lines changed

vllm_ascend/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -23,5 +23,10 @@ def register():
 
 
 def register_model():
+    import vllm.envs as envs
+    import vllm_ascend.envs as envs_ascend
     from .models import register_model
+    if envs.VLLM_USE_V1 and \
+            envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM != 0:
+        import vllm_ascend.patch.platform.patch_main.patch_decorator
     register_model()

vllm_ascend/ascend_forward_context.py

Lines changed: 29 additions & 0 deletions
@@ -110,6 +110,35 @@ def set_ascend_forward_context(
         # NOTE: This cannot be set using set_forward_context
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
+
+        # set this for rope forward_oot using
+        forward_context.is_first_layer = True
+
+        # set for flashcomm_v1
+        flashcomm_v1_enabled = False
+        matmul_rs_enabled = False
+        ag_matmal_enabled = False
+        pad_size = 0
+        from vllm_ascend.attention.attention_v1 import AscendAttentionState
+        if envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 1 and \
+                attn_metadata is not None and \
+                attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
+            flashcomm_v1_enabled = True
+        if flashcomm_v1_enabled and \
+                envs_ascend.VLLM_ASCEND_ENABLE_LCOC_MATMUL_RS == 1:
+            matmul_rs_enabled = True
+        if flashcomm_v1_enabled and \
+                envs_ascend.VLLM_ASCEND_ENABLE_LCOC_AG_MATMUL == 1:
+            ag_matmal_enabled = True
+        if flashcomm_v1_enabled:
+            # num_tokens = hidden_states.size(0)
+            tp_world_size = get_tensor_model_parallel_world_size()
+            pad_size = (tp_world_size -
+                        (num_tokens % tp_world_size)) % tp_world_size
+        forward_context.pad_size = pad_size
+        forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
+        forward_context.matmul_rs_enabled = matmul_rs_enabled
+        forward_context.ag_matmal_enabled = ag_matmal_enabled
 
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
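
The padding arithmetic above rounds num_tokens up to the next multiple of the tensor-parallel world size, so that the token dimension later splits evenly across ranks for reduce-scatter/all-gather. A minimal standalone sketch of just that arithmetic (the function name and example values are illustrative, not part of the commit):

def compute_pad_size(num_tokens: int, tp_world_size: int) -> int:
    # Extra rows needed so that (num_tokens + pad) % tp_world_size == 0.
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

assert compute_pad_size(37, 8) == 3   # 37 + 3 = 40 tokens -> 5 per rank at TP=8
assert compute_pad_size(40, 8) == 0   # already divisible, no padding needed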

vllm_ascend/envs.py

Lines changed: 12 additions & 0 deletions
@@ -136,11 +136,23 @@
     # this feature is supported in A2, and eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
+    # FlashComm optimization: Enable v1 and v2 by setting this flag to 1 or 2 respectively
+    "VLLM_ASCEND_ENABLE_FLASHCOMM":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')),
+    # LcocMatmulReduceScatter optimization
+    "VLLM_ASCEND_ENABLE_LCOC_MATMUL_RS":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_LCOC_MATMUL_RS", '0')),
+    # LcocAllGatherMatmul optimization
+    "VLLM_ASCEND_ENABLE_LCOC_AG_MATMUL":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_LCOC_AG_MATMUL", '0')),
     # Whether to enable the alltoall_seq flag, this provides a basic framework on the basis of alltoall for easy expansion.
     # 0: default, normal init.
     # 1: enable moe all2all seq.
     "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
     lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
+    # Whether to enable dense model and general optimizations for better performance.
+    "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
     # Whether to enable mlp optimize when tensor parallel is enabled.
     # this feature in eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":

vllm_ascend/ops/layernorm.py

Lines changed: 24 additions & 0 deletions
@@ -18,6 +18,10 @@
 from typing import Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
@@ -44,6 +48,16 @@ def forward(
         import torch_npu
 
         if residual is not None:
+            forward_context = get_forward_context()
+            flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+            if x.size(0) != residual.size(0) and \
+                    flashcomm_v1_enabled:
+                pad_size = forward_context.pad_size
+                tp_size = get_tensor_model_parallel_world_size()
+                tp_rank = get_tensor_model_parallel_rank()
+                from vllm_ascend.utils import maybe_pad_and_chunk_tensor
+                residual = maybe_pad_and_chunk_tensor(residual, pad_size, tp_size, tp_rank, 0)
+            assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
                 residual,
@@ -69,6 +83,16 @@ def forward_oot(
 
         from vllm_ascend.utils import is_310p
         if residual is not None:
+            forward_context = get_forward_context()
+            flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+            if x.size(0) != residual.size(0) and \
+                    flashcomm_v1_enabled:
+                pad_size = forward_context.pad_size
+                tp_size = get_tensor_model_parallel_world_size()
+                tp_rank = get_tensor_model_parallel_rank()
+                from vllm_ascend.utils import maybe_pad_and_chunk_tensor
+                residual = maybe_pad_and_chunk_tensor(residual, pad_size, tp_size, tp_rank, 0)
+            assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
                 x = x + residual.to(x.dtype)
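
maybe_pad_and_chunk_tensor is imported from vllm_ascend.utils, whose hunk is not shown in this excerpt. Judging from the call sites above, it brings the residual onto the same token layout as the reduce-scattered hidden states: pad dim 0 to a multiple of tp_size, then keep only this rank's chunk. A hedged sketch under that assumption (2-D (num_tokens, hidden) input; not the commit's actual implementation):

import torch
import torch.nn.functional as F

def maybe_pad_and_chunk_tensor(x: torch.Tensor, pad_size: int, tp_size: int,
                               tp_rank: int, dim: int = 0) -> torch.Tensor:
    # Sketch assumes dim == 0 and a 2-D tensor.
    if pad_size > 0:
        # (0, 0, 0, pad_size) leaves the hidden dim alone and appends pad_size rows.
        x = F.pad(x, (0, 0, 0, pad_size))
    # Split the padded token dim evenly across TP ranks and keep this rank's slice.
    return torch.chunk(x, tp_size, dim=dim)[tp_rank]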

vllm_ascend/ops/linear.py

Lines changed: 153 additions & 1 deletion
@@ -18,7 +18,9 @@
 from typing import Optional, Union
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+import torch_npu
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -28,14 +30,19 @@
                                                ColumnParallelLinear,
                                                LinearBase,
                                                MergedColumnParallelLinear,
-                                               RowParallelLinear)
+                                               QKVParallelLinear,
+                                               RowParallelLinear,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (
     get_mlp_tensor_model_parallel_rank,
     get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod, quant_per_tensor
+from vllm_ascend.utils import all_gather_and_maybe_unpad, maybe_pad_and_reduce_scatter
 
 
 class AscendMlpColumnParallelLinear(ColumnParallelLinear):
@@ -307,3 +314,148 @@ def forward(
         if not self.return_bias:
             return output
         return output, output_bias
+
+
+class AscendDenseMergedColumnParallelLinear(MergedColumnParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        forward_context = get_forward_context()
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        ag_matmal_enabled = forward_context.ag_matmal_enabled
+        pad_size = forward_context.pad_size
+        if not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self, input_, bias)
+        # fp or bf
+        elif ag_matmal_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("AllGather_MatMul with UnquantizedLinearMethod is not implemented yet.")
+        # w8a8 quant
+        elif ag_matmal_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("AllGather_MatMul with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            input_ = all_gather_and_maybe_unpad(input_, pad_size, 0)
+            output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class AscendDenseQKVParallelLinear(QKVParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        forward_context = get_forward_context()
+        layer_num = self.prefix.split('.')[2]
+        if layer_num == '0':
+            flashcomm_v1_enabled = False
+        else:
+            flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        ag_matmal_enabled = forward_context.ag_matmal_enabled
+        pad_size = forward_context.pad_size
+        if not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self, input_, bias)
+        # fp or bf
+        elif ag_matmal_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("AllGather_MatMul with UnquantizedLinearMethod is not implemented yet.")
+        # w8a8 quant
+        elif ag_matmal_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("AllGather_MatMul with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            input_ = all_gather_and_maybe_unpad(input_, pad_size, 0)
+            output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class AscendDenseRowParallelLinear(RowParallelLinear):
+    """Linear layer with row parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        tp_rank = get_tensor_model_parallel_rank()
+        forward_context = get_forward_context()
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        matmul_rs_enabled = forward_context.matmul_rs_enabled
+        pad_size = forward_context.pad_size
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        if self.tp_size == 1 or not self.reduce_results:
+            output = self.quant_method.apply(self,
+                                             input_parallel,
+                                             bias=bias_)
+        elif not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        # fp or bf
+        elif matmul_rs_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("Matmul_ReduceScatter with UnquantizedLinearMethod is not implemented yet.")
+        # w8a8 quant
+        elif matmul_rs_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("Matmul_ReduceScatter with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = maybe_pad_and_reduce_scatter(output_parallel, pad_size, 0)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
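
all_gather_and_maybe_unpad and maybe_pad_and_reduce_scatter are likewise imported from vllm_ascend.utils and not shown in this excerpt. The call sites imply the FlashComm v1 pattern: a row-parallel layer pads the token dimension and reduce-scatters it across TP ranks instead of all-reducing, while the next column-parallel layer all-gathers the tokens back and drops the padding before its matmul (the first decoder layer's QKV projection is exempted by the layer index parsed from self.prefix). A hedged sketch of the two helpers under those assumptions (dim 0, 2-D tensors; not the commit's actual utils code):

import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size, get_tp_group,
                              tensor_model_parallel_all_gather)


def maybe_pad_and_reduce_scatter(x: torch.Tensor, pad_size: int,
                                 dim: int = 0) -> torch.Tensor:
    # Sketch assumes dim == 0 and a 2-D (num_tokens, hidden) tensor.
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))  # pad tokens so they divide evenly by TP size
    tp_size = get_tensor_model_parallel_world_size()
    out = torch.empty_like(x.chunk(tp_size, dim=0)[0])
    # Sum partial outputs across TP ranks while scattering contiguous token chunks.
    dist.reduce_scatter_tensor(out, x.contiguous(), group=get_tp_group().device_group)
    return out


def all_gather_and_maybe_unpad(x: torch.Tensor, pad_size: int,
                               dim: int = 0) -> torch.Tensor:
    # Gather every rank's token chunk back into the full padded sequence ...
    x = tensor_model_parallel_all_gather(x, dim)
    # ... then drop the padding rows appended before the reduce-scatter.
    return x[:-pad_size] if pad_size > 0 else x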
