# Integrating New Models into vLLM-Ascend

This guide demonstrates how to integrate novel or customized models into vLLM-Ascend. For foundational concepts, it is highly recommended to refer to:
[Adding a New Model - vLLM Documentation](https://docs.vllm.ai/en/stable/contributing/model/)

### 1. Implementing Models Using PyTorch and Ascend Extension for PyTorch

This section provides instructions for implementing new models compatible with vLLM and vLLM-Ascend. Before starting:

1. Verify whether your model already exists in vLLM's [Model Executor](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) directory (a quick programmatic check is shown below)
2. Use existing implementations as templates to accelerate development
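
If you prefer a programmatic check, vLLM's `ModelRegistry` exposes the full list of architectures it can currently serve. A minimal sketch (assuming vLLM is installed; the architecture name queried here is just an example):

```python
from vllm import ModelRegistry

# Every architecture name vLLM can currently resolve; if your model's
# `architectures` entry from config.json appears here, an implementation
# already exists and can serve as a template.
supported = ModelRegistry.get_supported_archs()
print("DeepseekV2ForCausalLM" in supported)
```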

#### 1.1 Implementing New Models from Scratch

Follow vLLM's OPT model adaptation example for guidance:
[Implementing a Basic Model - vLLM Documentation](https://docs.vllm.ai/en/stable/contributing/model/basic.html)

Key implementation requirements:

1. Place model files in the [vllm_ascend/models/](https://github.com/vllm-project/vllm-ascend/tree/main/vllm_ascend/models) directory
2. Standard module structure for decoder-only LLMs (please check out vLLM's implementations for other kinds of models):

   - `*ModelForCausalLM` (top-level wrapper)
   - `*Model` (main architecture)
   - `*DecoderLayer` (transformer block)
   - `*Attention` & `*MLP` (specific computation units)

   `*` denotes your model's unique identifier.
3. **Critical Implementation Details**:

   - All modules **must** include a `prefix` argument in `__init__()`
   - Required interfaces:

     | Module Type         | Required Methods                                          |
     |---------------------|-----------------------------------------------------------|
     | `*ModelForCausalLM` | `get_input_embeddings`, `compute_logits`, `load_weights` |
     | `*Model`            | `get_input_embeddings`, `load_weights`                    |

4. **Attention Backend Integration**:
   Importing attention via `from vllm.attention import Attention` automatically leverages vLLM-Ascend's attention backend routing (see `get_attn_backend_cls()` in [vllm_ascend/platform.py](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/platform.py))
5. **Tensor Parallelism**:
   Use vLLM's parallel layers (`ColumnParallelLinear`, `VocabParallelEmbedding`, etc.), but note the Ascend-specific customizations implemented in the [vllm_ascend/ops/](https://github.com/vllm-project/vllm-ascend/tree/main/vllm_ascend/ops) directory (`RMSNorm`, `VocabParallelEmbedding`, etc.).

**Reference Implementation Template** (assumed path: `vllm_ascend/models/custom_model.py`):

```python
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.model_executor.sampling_metadata import SamplingMetadata

class CustomAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        # Attention also requires num_heads, head_size, scale, etc.
        # (derived from vllm_config.model_config); elided here for brevity.
        self.attn = Attention(prefix=f"{prefix}.attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement attention logic
        ...

class CustomDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = CustomAttention(vllm_config, prefix=f"{prefix}.self_attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement decoder layer (attention, MLP, residual connections)
        ...

class CustomModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            CustomDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}")
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...

class CustomModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = CustomModel(vllm_config, prefix=f"{prefix}.model")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def compute_logits(self,
                       hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...
```
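
The `load_weights` bodies above are intentionally left as stubs. One possible shape, continuing the template above and assuming your module's parameter names line up with the checkpoint's, is to delegate to vLLM's `AutoWeightsLoader` helper (a minimal sketch, not the only approach):

```python
from collections.abc import Iterable

import torch
from torch import nn
from vllm.model_executor.models.utils import AutoWeightsLoader

class CustomModelForCausalLM(nn.Module):
    # ... __init__, forward, etc. as in the template above ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # AutoWeightsLoader walks this module's parameters, matches each
        # (checkpoint name, tensor) pair to a parameter, and returns the
        # set of weight names it successfully loaded.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
```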

#### 1.2 Customizing Existing vLLM Models

For most use cases, extending an existing implementation is preferable. Inherit from the base class and override specific methods (assumed path: `vllm_ascend/models/deepseek_v2.py`):

```python
from typing import List, Optional, Union

import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
from vllm.sequence import IntermediateTensors

class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    # Define merged weights for quantization/efficiency
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Custom forward logic
        hidden_states = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds
        )
        return hidden_states
```

For a complete implementation reference, see: [vllm_ascend/models/deepseek_v2.py](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models/deepseek_v2.py)

### 2. Registering Custom Models as Out-of-Tree Plugins in vLLM

vLLM provides a plugin mechanism for registering externally implemented models without modifying its codebase. To integrate your custom model from the [`vllm_ascend/models/`](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models) directory:

1. **Import your model implementation** in [`vllm_ascend/models/__init__.py`](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models/__init__.py) using relative imports
2. **Register the model wrapper class** via `vllm.ModelRegistry.register_model()`

**Reference Registration Template** (assumed path: `vllm_ascend/models/__init__.py`):

```python
from vllm import ModelRegistry

def register_model():
    from .custom_model import CustomModelForCausalLM  # New custom model
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # Customized DeepSeek

    # For NEW architectures: register with a unique name
    ModelRegistry.register_model(
        "CustomModelForCausalLM",  # Must match config.json's 'architectures'
        "vllm_ascend.models.custom_model:CustomModelForCausalLM"
    )

    # For MODIFIED architectures: use the original name
    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",  # Original architecture identifier in vLLM
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM"
    )
```
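
For vLLM to invoke `register_model()` at startup, the package must expose it as a general plugin. vLLM discovers such plugins through the `vllm.general_plugins` entry-point group; a sketch of the corresponding packaging declaration (the entry-point name and target module here are illustrative, check vllm-ascend's own `setup.py` for the exact values):

```python
# setup.py (excerpt); entry-point name and target are illustrative
from setuptools import setup

setup(
    name="vllm_ascend",
    entry_points={
        # vLLM scans this group on startup and calls each target,
        # so register_model() runs before any architecture is resolved.
        "vllm.general_plugins": [
            "ascend_enhanced_model = vllm_ascend:register_model",
        ],
    },
)
```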

**Key Note**
The architecture identifier (the first argument to `vllm.ModelRegistry.register_model()`) must match the `architectures` field in the model's `config.json`:

```json
{
  "architectures": [
    "CustomModelForCausalLM"
  ]
}
```

```json
{
  "architectures": [
    "DeepseekV2ForCausalLM"
  ]
}
```


### 3. Verify Model Registration

#### 3.1 Overriding Existing vLLM Model Architecture

If you're registering a customized model architecture based on vLLM's existing implementation (overriding vLLM's original class), then when executing vLLM offline/online inference (using any model) you'll observe warning logs similar to the following, emitted from [`vllm/model_executor/models/registry.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py):

```
Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
```

#### 3.2 Registering New Model Architecture

If you're registering a novel model architecture not present in vLLM (creating a completely new class), the logs won't provide explicit confirmation by default. It's recommended to add the following logging statement at the end of the `register_model` method in [`vllm/model_executor/models/registry.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py):

```python
logger.warning(f"model_arch: {model_arch} has been registered here!")
```

When running vLLM offline/online inference (using any model), you should then see confirmation logs similar to:

```
model_arch: CustomModelForCausalLM has been registered here!
```

This log output confirms that your novel model architecture has been successfully registered in vLLM.
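
As a final end-to-end check, a short offline run exercises plugin discovery and should emit the registration logs above during engine construction. A minimal sketch (the model path is a placeholder for your own checkpoint directory):

```python
from vllm import LLM, SamplingParams

# Constructing the engine loads plugins and registers models, so the
# confirmation/overwrite logs appear here; the path is a placeholder.
llm = LLM(model="/path/to/your/custom-model")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```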