Commit ee73fc4

[docs] Update guidance on how to implement and register new models in vLLM-Ascend
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 6003afa commit ee73fc4

1 file changed: docs/source/model_registration.md (+237 −0)
**Integrating New Models into vLLM-Ascend**

This guide demonstrates how to integrate novel or customized models into vLLM-Ascend. For foundational concepts, it is highly recommended to refer to:
[Adding a New Model - vLLM Documentation](https://docs.vllm.ai/en/stable/contributing/model/)

### 1. Implementing Models Using PyTorch and Ascend Extension for PyTorch

This section provides instructions for implementing new models compatible with vLLM and vLLM-Ascend. Before starting:

1. Verify whether your model already exists in vLLM's [model executor](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) directory (a quick programmatic check is sketched below)
2. Use existing implementations as templates to accelerate development
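
To check quickly whether an architecture name is already known to vLLM, you can also query the model registry directly; a minimal sketch, assuming the `ModelRegistry.get_supported_archs()` helper available in recent vLLM releases:

```python
from vllm import ModelRegistry

# List every architecture name currently registered with vLLM. If your model's
# `architectures` entry from config.json appears here, an in-tree
# implementation already exists and can serve as a template.
print(sorted(ModelRegistry.get_supported_archs()))
```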

#### 1.1 Implementing New Models from Scratch

Follow vLLM's OPT model adaptation example for guidance:
[Implementing a Basic Model - vLLM Documentation](https://docs.vllm.ai/en/stable/contributing/model/basic.html)

Key implementation requirements:

1. Place model files in the [vllm_ascend/models/](https://github.com/vllm-project/vllm-ascend/tree/main/vllm_ascend/models) directory
2. Standard module structure for decoder-only LLMs (please check out vLLM's implementations for other kinds of models):

   - `*ModelForCausalLM` (top-level wrapper)
   - `*Model` (main architecture)
   - `*DecoderLayer` (transformer block)
   - `*Attention` & `*MLP` (specific computation units)

   `*` denotes your model's unique identifier.

3. **Critical Implementation Details**:

   - All modules **must** include a `prefix` argument in `__init__()`
   - Required interfaces:

     | Module Type | Required Methods |
     |----------------------|-------------------------------------------|
     | `*ModelForCausalLM` | `get_input_embeddings`, `compute_logits`, `load_weights` |
     | `*Model` | `get_input_embeddings`, `load_weights` |

4. **Attention Backend Integration**:
   Importing attention via `from vllm.attention import Attention` automatically leverages vLLM-Ascend's attention backend routing (see `get_attn_backend_cls()` in [vllm_ascend/platform.py](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/platform.py)).

5. **Tensor Parallelism**:
   Use vLLM's parallel layers (`ColumnParallelLinear`, `VocabParallelEmbedding`, etc.), but note the Ascend-specific customizations implemented in the [vllm_ascend/ops/](https://github.com/vllm-project/vllm-ascend/tree/main/vllm_ascend/ops) directory (`RMSNorm`, `VocabParallelEmbedding`, etc.). A minimal MLP sketch using these layers follows the reference template below.

**Reference Implementation Template** (assumed path: `vllm_ascend/models/custom_model.py`):

```python
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.model_executor.sampling_metadata import SamplingMetadata

class CustomAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        # In a real model, Attention also takes num_heads, head_size, scale,
        # cache_config, quant_config, etc.; they are omitted in this skeleton.
        self.attn = Attention(prefix=f"{prefix}.attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement attention logic
        ...

class CustomDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = CustomAttention(vllm_config, prefix=f"{prefix}.self_attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement decoder layer
        ...

class CustomModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            CustomDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}")
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...

class CustomModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = CustomModel(vllm_config, prefix=f"{prefix}.model")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def compute_logits(self,
                       hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...
```
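
The template above leaves the `*MLP` module unimplemented. As an illustration of item 5 (tensor parallelism), here is a minimal sketch of a SwiGLU-style MLP built from vLLM's parallel linear layers, closely following the in-tree Llama implementation; the class name `CustomMLP` and the `hidden_size`/`intermediate_size` parameters are placeholders for your model's configuration:

```python
import torch
from torch import nn
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)

class CustomMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int,
                 quant_config=None, prefix: str = ""):
        super().__init__()
        # Column-parallel projection producing gate and up outputs in one GEMM;
        # each tensor-parallel rank holds a shard of the output dimension.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False,
            quant_config=quant_config, prefix=f"{prefix}.gate_up_proj")
        # Row-parallel projection back to hidden_size; partial results are
        # all-reduced across tensor-parallel ranks.
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False,
            quant_config=quant_config, prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)  # parallel linears return (output, bias)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
```

Reusing vLLM's layer abstractions like this is preferable to hand-written projections, since Ascend-specific customizations in `vllm_ascend/ops/` apply to those shared layers rather than to per-model code.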

#### 1.2 Customizing Existing vLLM Models

For most use cases, extending an existing implementation is preferable. Inherit from the base class and override specific methods (assumed path: `vllm_ascend/models/deepseek_v2.py`):

```python
from typing import List, Optional, Union

import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
from vllm.sequence import IntermediateTensors

class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    # Define merged weights for quantization/efficiency
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Custom forward logic
        hidden_states = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds
        )
        return hidden_states
```

For a complete implementation reference, see: [vllm_ascend/models/deepseek_v2.py](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models/deepseek_v2.py)

### 2. Registering Custom Models as Out-of-Tree Plugins in vLLM

vLLM provides a plugin mechanism for registering externally implemented models without modifying its codebase. To integrate your custom model from the [`vllm_ascend/models/`](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models) directory:

1. **Import your model implementation** in [`vllm_ascend/models/__init__.py`](https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/models/__init__.py) using relative imports
2. **Register the model wrapper class** via `vllm.ModelRegistry.register_model()`

**Reference Registration Template** (assumed path: `vllm_ascend/models/__init__.py`):

```python
from vllm import ModelRegistry

def register_model():
    from .custom_model import CustomModelForCausalLM  # New custom model
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # Customized DeepSeek

    # For NEW architectures: register with a unique name
    ModelRegistry.register_model(
        "CustomModelForCausalLM",  # Must match config.json's 'architectures'
        "vllm_ascend.models.custom_model:CustomModelForCausalLM"
    )

    # For MODIFIED architectures: use the original name
    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",  # Original architecture identifier in vLLM
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM"
    )
```

**Key Note**
The architecture identifier (the first argument to `vllm.ModelRegistry.register_model()`) must match the `architectures` field in the model's `config.json`:

```json
{
    "architectures": [
        "CustomModelForCausalLM"
    ]
}
```

```json
{
    "architectures": [
        "DeepseekV2ForCausalLM"
    ]
}
```
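
For context, vLLM discovers and invokes registration hooks like `register_model()` at startup through Python entry points in the `vllm.general_plugins` group. The sketch below shows how such a hook might be declared in a package's `setup.py`; the entry-point name `ascend_enhanced_model` and the module path are illustrative, so check vLLM-Ascend's own packaging for the exact declaration:

```python
# setup.py (excerpt) -- an illustrative general-plugin declaration.
from setuptools import setup

setup(
    name="vllm-ascend",
    # ...
    entry_points={
        # At startup, vLLM loads every hook in this group and calls it,
        # which in turn runs the ModelRegistry.register_model() calls above.
        "vllm.general_plugins": [
            "ascend_enhanced_model = vllm_ascend:register_model",
        ],
    },
)
```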

### 3. Verify Model Registration

#### 3.1 Overriding Existing vLLM Model Architecture

If you're registering a customized model architecture based on an existing vLLM implementation (overriding vLLM's original class), then whenever you run vLLM offline/online inference (with any model) you'll observe a warning log similar to the following, emitted from [`vllm/model_executor/models/registry.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py):

```
Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
```

#### 3.2 Registering New Model Architecture

If you're registering a novel model architecture not present in vLLM (creating a completely new class), the logs won't provide explicit confirmation by default. It's recommended to add the following logging statement at the end of the `register_model` method in [`vllm/model_executor/models/registry.py`](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py):

```python
logger.warning(f"model_arch: {model_arch} has been registered here!")
```

When running vLLM offline/online inference (with any model), you should then see a confirmation log similar to:

```
model_arch: CustomModelForCausalLM has been registered here!
```

This log output confirms that your novel model architecture has been successfully registered in vLLM.
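
Either check can be triggered with any offline inference run; a minimal sketch, where the model path is a placeholder for whichever model you normally run (the registration logs are emitted during engine startup regardless of which model is loaded):

```python
from vllm import LLM, SamplingParams

# Placeholder path: any runnable model works, since plugin registration
# happens during engine startup.
llm = LLM(model="/path/to/any/model")

# A single generation call is enough to exercise startup and surface
# the registration logs described above.
outputs = llm.generate(["Hello, Ascend!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```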
