
[Feature] Enable inference support for Deepseekr1-w8a8-MTP #1834


Open
wants to merge 7 commits into base: main
23 changes: 20 additions & 3 deletions vllm_ascend/models/deepseek_mtp.py
@@ -28,8 +28,8 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
@@ -40,6 +40,20 @@
from .deepseek_v2 import CustomDeepseekV2DecoderLayer


class CustomDeepSeekShareHead(SharedHead):

def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
nn.Module.__init__(self)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))
Collaborator:

Why do we only need the prefix here? Is the behavior of the w8a8 weights different from the others?

Author:

The prefix is used by the is_layer_skipped function in quant_config.py; the behavior is the same.

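For context, a minimal sketch of what that lookup can do with the prefix. This is an illustration only: the key layout of the quant description and the "FLOAT" marker are assumptions, not the actual is_layer_skipped_ascend body. The point is that the fully qualified name (e.g. shared_head.head) is what lets the config decide whether a layer keeps floating-point weights.

```python
from typing import Dict

# Hypothetical helper, not the vllm-ascend implementation: assumes the w8a8
# quant description maps each weight name to a dtype tag such as "W8A8" or
# "FLOAT", and that missing entries default to unquantized.
def is_layer_skipped(prefix: str, quant_description: Dict[str, str]) -> bool:
    """Return True if the layer named by `prefix` should keep float weights."""
    weight_key = f"{prefix}.weight"
    return quant_description.get(weight_key, "FLOAT") == "FLOAT"

# Example description (names illustrative):
quant_description = {
    "model.layers.61.shared_head.head.weight": "FLOAT",  # skipped -> unquantized
    "model.layers.0.self_attn.q_proj.weight": "W8A8",     # quantized
}
print(is_layer_skipped("model.layers.61.shared_head.head", quant_description))  # True
print(is_layer_skipped("model.layers.0.self_attn.q_proj", quant_description))   # False
```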


class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

def __init__(
@@ -61,7 +75,10 @@ def __init__(
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.shared_head = CustomDeepSeekShareHead(config=config,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "shared_head"))
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
model_config,
cache_config,
4 changes: 3 additions & 1 deletion vllm_ascend/models/deepseek_v2.py
@@ -861,7 +861,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
27 changes: 26 additions & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -34,6 +34,8 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs

@@ -46,7 +48,7 @@
@register_quantization_config(ASCEND_QUATIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend

This class is a general class that parse quantization configs
that are supported on ascend hardware.
"""
@@ -107,6 +109,12 @@ def get_quant_method(self, layer: torch.nn.Module,
return AscendUnquantizedFusedMoEMethod()
return AscendFusedMoEMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
return AscendEmbeddingMethod(self, prefix,
self.packed_modules_mapping)
return None

def is_layer_skipped_ascend(
@@ -319,3 +327,20 @@ def apply(
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)


class AscendEmbeddingMethod(AscendLinearMethod):
"""Embedding method for Ascend quantization.

This class calls AscendQuantizer to search for specific quantization
implementations supported on ascend hardware for Embedding methods.

Args:
quant_config: The Ascend quantization config.
"""

def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description, prefix, packed_modules_mapping)
self.quant_method = self.quantizer.build_linear_method()
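
Taken together, the changes above thread the module prefix from the model definition down to the quant config. A small sketch of that prefix flow, assuming vLLM's maybe_prefix simply joins non-empty prefixes with a dot; the layer index is illustrative, and this is a simplification rather than the literal code:

```python
def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behavior of vLLM's helper: join non-empty prefixes with a dot.
    return f"{prefix}.{name}" if prefix else name

# The MTP predictor layer builds its shared head with a prefix like
# "model.layers.61.shared_head" (index illustrative), so CustomDeepSeekShareHead
# registers its ParallelLMHead under:
print(maybe_prefix("model.layers.61.shared_head", "head"))  # model.layers.61.shared_head.head

# The main model's head gets its name the same way:
print(maybe_prefix("", "lm_head"))  # lm_head

# Since ParallelLMHead subclasses VocabParallelEmbedding, both names now reach
# the new VocabParallelEmbedding branch in AscendQuantConfig.get_quant_method,
# which returns UnquantizedEmbeddingMethod when the layer is skipped in the
# w8a8 quant description and AscendEmbeddingMethod otherwise.
```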