                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -422,20 +420,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
@@ -444,25 +429,20 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
 
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
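
Taken together, the change removes the MLA-specific positional parameters (head dims, rotary_emb, and the q/kv_b/o projection modules) from the impl constructor, adds an optional kv_sharing_target_layer_name argument in their place, and reads the MLA objects out of **kwargs instead. Below is a minimal standalone sketch of that kwargs-extraction pattern; the class name and the values used are illustrative only and do not come from this PR.

from typing import Optional


class MLAImplSketch:
    """Illustrative only: mirrors the kwargs-based extraction in the diff above."""

    def __init__(self,
                 num_heads: int,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 **kwargs) -> None:
        self.num_heads = num_heads
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        # Required MLA arguments: indexing kwargs raises KeyError if the
        # calling layer forgets to pass one of them.
        self.q_lora_rank = kwargs['q_lora_rank']
        self.kv_lora_rank = kwargs['kv_lora_rank']
        # Optional arguments fall back to None via kwargs.get(...).
        self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)


# The attention layer now supplies the MLA arguments as keywords:
impl = MLAImplSketch(num_heads=16, q_lora_rank=None, kv_lora_rank=512)
print(impl.kv_lora_rank)  # 512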