@@ -3183,6 +3183,87 @@ def initialize_kv_cache_tensors_deepseek_mla(
 
         return kv_caches
 
+    def _initialize_kv_cache_tensors_310p(
+            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+        kv_cache_sizes = {}
+        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+            assert len(kv_cache_tensor.shared_by) == 1, (
+                "KV cache tensor shared by multiple layers is not supported in "
+                "310p NPU.")
+            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
+        kv_caches: dict[str, torch.Tensor] = {}
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                tensor_size = kv_cache_sizes[layer_name]
+                assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+
+                # `num_blocks` is the number of blocks the model runner can use.
+                # `kv_cache_config.num_blocks` is the number of blocks that
+                # KVCacheManager may allocate.
+                # Since different NPUs may have different numbers of layers and
+                # different memory capacities, `num_blocks` can be different on
+                # different NPUs, and `kv_cache_config.num_blocks` is set to
+                # the min of all `num_blocks`. Verify it here.
+                assert num_blocks >= kv_cache_config.num_blocks
+
+                # TODO: remove this after the OOM issue is located and fixed;
+                # otherwise, some models may encounter OOM issues.
+                if isinstance(kv_cache_spec, FullAttentionSpec):
+                    if self.vllm_config.additional_config.get(
+                            "kv_cache_dtype", None) == 'int8':
+                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    elif hasattr(attn_backend, "get_supported_block_size"
+                                 ) and self.use_hybrid_blocks:
+                        block_size = attn_backend.get_supported_block_size()[0]
+
+                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks * block_size_chunk, block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    else:
+                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    dtype = kv_cache_spec.dtype
+
+                    if "attn" in layer_name:
+                        # for self_attn, sliding window attn
+                        if self.vllm_config.kv_transfer_config is None:
+                            k_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            v_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            k_cache = torch_npu.npu_format_cast(
+                                k_tensor, ACL_FORMAT)
+                            v_cache = torch_npu.npu_format_cast(
+                                v_tensor, ACL_FORMAT)
+
+                            kv_caches[layer_name] = (k_cache, v_cache)
+                        else:
+                            raise ValueError(
+                                "KV cache transfer is not supported for 310p.")
+                else:
+                    raise ValueError("Unknown KV cache spec type.")
+
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+
+        return kv_caches
+
     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
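
The block accounting in the new 310p helper can be sanity-checked with plain arithmetic. The sketch below is illustrative only: every size in it is an assumed number rather than a value from this change, and the page size is modelled as 2 (K and V) x block_size x num_kv_heads x head_size x dtype bytes, which is roughly how a full-attention spec sizes one block.

# Illustrative arithmetic only; all sizes below are assumptions, not values
# taken from this change.
block_size = 128          # tokens per block (assumed)
num_kv_heads = 8          # assumed
head_size = 128           # assumed
dtype_bytes = 2           # fp16/bf16
# One page holds K and V for one block of one layer (assumed layout).
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes

tensor_size = 1000 * page_size_bytes   # per-layer allocation (assumed)
assert tensor_size % page_size_bytes == 0
num_blocks = tensor_size // page_size_bytes          # 1000

# Hybrid-block path: if the backend only supports a smaller block size,
# the same bytes are re-expressed as more, smaller blocks.
backend_block_size = 64                               # assumed
block_size_chunk = block_size // backend_block_size   # 2
assert (num_blocks * block_size_chunk) * backend_block_size \
    == num_blocks * block_size                        # token capacity unchanged

# kv_cache_config.num_blocks is the minimum across all workers, so any
# individual layer must be able to hold at least that many blocks.
kv_cache_config_num_blocks = 900                      # assumed minimum
assert num_blocks >= kv_cache_config_num_blocks
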
@@ -3194,6 +3275,10 @@ def initialize_kv_cache_tensors(
             Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
         """
+
+        if is_310p():
+            return self._initialize_kv_cache_tensors_310p(kv_cache_config)
+
         # init kv cache tensors
         kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                               Optional[torch.Tensor]]] = {}
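
For reference, a rough sketch of the per-layer shape handling in the 310p path. It assumes the backend's get_kv_cache_shape returns a combined (2, num_blocks, block_size, num_kv_heads, head_size) layout (an assumption, not something this diff states); under that assumption the 310p path drops the leading K/V dimension and allocates the two halves as separate tensors, which torch_npu.npu_format_cast would then convert to ACL_FORMAT on a real device. The cast, the sizes, and the layer name are omitted or made up here so the snippet runs anywhere.

import torch

# Assumed combined layout: (2, num_blocks, block_size, num_kv_heads, head_size).
kv_cache_shape = (2, 4, 128, 8, 128)   # tiny, illustrative sizes
dtype = torch.float16

# 310p path (sketch): one tensor for K and one for V, each shaped like a
# single half of the combined layout. On a 310P NPU, npu_format_cast to
# ACL_FORMAT would follow; it is skipped so this runs on CPU.
k_cache = torch.zeros(kv_cache_shape[1:], dtype=dtype)
v_cache = torch.zeros(kv_cache_shape[1:], dtype=dtype)

# Hypothetical layer name; the real names come from the attention groups.
kv_caches = {"model.layers.0.self_attn.attn": (k_cache, v_cache)}
assert k_cache.shape == (4, 128, 8, 128)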