@@ -6,15 +6,16 @@
 import torch
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-from mamba_ssm.ops.triton.ssd_combined import (mamba_chunk_scan_combined)
+from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -45,12 +46,20 @@ class MambaCacheParams:
     ssm_state: torch.Tensor = torch.Tensor()


+# Load weights that are sharded along axis 0
+def sharded_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
+    tp_rank = get_tensor_model_parallel_rank()
+    param.data.copy_(
+        loaded_weight.data.split(param.data.shape[0], dim=0)[tp_rank])
+
+
 class MambaRMSNormGated(torch.nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader})

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
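The new `sharded_weight_loader` relies on the parameter already being allocated at its per-rank size, so `param.data.shape[0]` is the chunk length used to slice the full checkpoint tensor along axis 0. A minimal single-process sketch of that splitting logic follows; `tp_rank` is passed in explicitly here instead of coming from `get_tensor_model_parallel_rank()`, and the sizes are made up for illustration:

```python
import torch

def shard_along_axis0(param: torch.Tensor, loaded_weight: torch.Tensor,
                      tp_rank: int) -> None:
    # param is pre-allocated at its per-rank size, so param.shape[0] is the
    # chunk length; copy this rank's chunk of the full checkpoint tensor.
    param.copy_(loaded_weight.split(param.shape[0], dim=0)[tp_rank])

# Toy check: a "weight" of 8 heads sharded across tp_size = 2 ranks.
full_weight = torch.arange(8, dtype=torch.float32)   # checkpoint tensor
rank1_param = torch.empty(4)                          # per-rank parameter
shard_along_axis0(rank1_param, full_weight, tp_rank=1)
assert torch.equal(rank1_param, torch.tensor([4., 5., 6., 7.]))
```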
@@ -65,6 +74,7 @@ def forward(self, hidden_states, gate=None):

         return self.weight * hidden_states.to(input_dtype)

+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -81,6 +91,7 @@ def __init__(self, config: MambaConfig, layer_idx):
         super().__init__()
         self.config = config

+        self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads = config.num_heads
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
@@ -114,15 +125,14 @@ def __init__(self, config: MambaConfig, layer_idx):
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        projection_size = (self.intermediate_size + self.conv_dim +
-                           self.num_heads)
-        self.in_proj = ColumnParallelLinear(self.hidden_size,
-                                            projection_size,
-                                            bias=self.use_bias)
+        self.in_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size, self.conv_dim, self.num_heads],
+            bias=self.use_bias)

         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
-        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+        self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
@@ -132,33 +142,18 @@ def __init__(self, config: MambaConfig, layer_idx):
            bias=True,
            skip_bias_add=True)

-        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
-            param.data.copy_(
-                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
-                                         dim=0)[tp_rank])
-
         def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
-            weight_loader(param, -torch.exp(loaded_weight.float()))
-
-        # TODO: Figure out tensor parallelism for A and D.
-        # tp_size = get_tensor_model_parallel_world_size()
-        # In mamba.py A is
-        # self.intermediate_size // tp_size by self.ssm_state_size,
-        # D is a vector of size self.intermediate_size // tp_size
-        # For mamba2 they are much smaller
-
-        # A = torch.arange(1, self.num_heads + 1)
-        # self.A_log = nn.Parameter(torch.log(A))
-        self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.norm = MambaRMSNormGated(self.intermediate_size,
-                                      eps=self.layer_norm_epsilon)
+            sharded_weight_loader(param, -torch.exp(loaded_weight.float()))

-        self.D = nn.Parameter(torch.ones(self.num_heads))
+        self.A = nn.Parameter(torch.ones(self.num_heads // self.tp_size))
+        self.D = nn.Parameter(torch.ones(self.num_heads // self.tp_size))
+        self.norm = MambaRMSNormGated(self.intermediate_size // self.tp_size,
+                                      eps=self.layer_norm_epsilon)

-        set_weight_attrs(self.D, {"weight_loader": weight_loader})
+        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader})
         set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
+        set_weight_attrs(self.dt_bias,
+                         {"weight_loader": sharded_weight_loader})

         self.out_proj = RowParallelLinear(
             self.intermediate_size,
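For orientation, here is a rough sketch of the per-rank widths this constructor ends up with, using made-up Mamba2-style dimensions (none of these numbers come from the PR): the merged `in_proj` output, `dt_bias`, `A`, `D`, and the gated norm all shrink by the tensor-parallel world size, while the row-parallel `out_proj` consumes the sharded intermediate size.

```python
# Hypothetical dimensions, for illustration only.
intermediate_size = 8192
n_groups, ssm_state_size = 8, 128
num_heads = 128
tp_size = 4  # assumed tensor-parallel world size

conv_dim = intermediate_size + 2 * n_groups * ssm_state_size   # 10240

# Per-rank output widths of the merged in_proj (gate, conv input, dt):
print([intermediate_size // tp_size,   # 2048
       conv_dim // tp_size,            # 2560
       num_heads // tp_size])          # 32

# Head-wise parameters are sharded the same way:
print(num_heads // tp_size)            # dt_bias, A, D -> 32 entries each
print(intermediate_size // tp_size)    # MambaRMSNormGated width -> 2048
```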
@@ -174,17 +169,15 @@ def mamba_forward(self,
         # set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
-        d_to_remove = (2 * self.intermediate_size +
-                       2 * self.n_groups * self.ssm_state_size +
-                       self.num_heads)
+        d_to_remove = (2 * self.intermediate_size + 2 * self.n_groups *
+                       self.ssm_state_size + self.num_heads) // self.tp_size

         if cache_params is not None and not cache_params.is_prompt:
-            in_projected_states, _ = self.in_proj(
-                hidden_states.squeeze(1))  # (B 2D)
+            in_projected_states, _ = self.in_proj(hidden_states.squeeze(1))
             d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
             split_projection_dim = [
-                d_mlp, d_mlp, self.intermediate_size, self.conv_dim,
-                self.num_heads
+                d_mlp, d_mlp, self.intermediate_size // self.tp_size,
+                self.conv_dim // self.tp_size, self.num_heads // self.tp_size
             ]
             _, _, gate, hidden_states_B_C, dt = torch.split(
                 in_projected_states, split_projection_dim, dim=-1)
@@ -197,15 +190,19 @@ def mamba_forward(self,
                 self.activation,
             )

-            hidden_states, B, C = torch.split(
+            hidden_states, B_C = torch.split(
                 hidden_states_B_C,
                 [
-                    self.intermediate_size, groups_time_state_size,
-                    groups_time_state_size
+                    self.intermediate_size // self.tp_size,
+                    2 * groups_time_state_size // self.tp_size
                 ],
                 dim=-1,
             )

+            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
+            B, C = torch.split(
+                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
+
             A = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
@@ -214,7 +211,7 @@ def mamba_forward(self,
             B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
             C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
             hidden_states_reshaped = hidden_states.view(
-                batch_size, self.num_heads, self.head_dim)
+                batch_size, self.num_heads // self.tp_size, self.head_dim)
             hidden_states = selective_state_update(
                 cache_params.ssm_state,
                 hidden_states_reshaped,
@@ -227,8 +224,8 @@ def mamba_forward(self,
                 dt_bias=dt_bias,
                 dt_softplus=True,
             )
-            hidden_states = hidden_states.view(batch_size,
-                                               self.num_heads * self.head_dim)
+            hidden_states = hidden_states.view(
+                batch_size, self.num_heads // self.tp_size * self.head_dim)
             hidden_states = self.norm(hidden_states, gate)
             out = self.out_proj(hidden_states)[0][:, None, ...]
         # if no cache is found, calling the kernel
@@ -253,7 +250,10 @@ def mamba_forward(self,

             gate, hidden_states_B_C, time_step = torch.split(
                 projected_states,
-                [self.intermediate_size, self.conv_dim, self.num_heads],
+                [
+                    self.intermediate_size // self.tp_size, self.conv_dim //
+                    self.tp_size, self.num_heads // self.tp_size
+                ],
                 dim=-1,
             )

@@ -273,14 +273,21 @@ def mamba_forward(self,
                 bias=self.conv1d.bias,
                 activation=self.activation,
             ).transpose(1, 2)[:, :seq_len]
-            hidden_states, B, C = torch.split(
+
+            hidden_states, B_C = torch.split(
                 hidden_states_B_C,
                 [
-                    self.intermediate_size, groups_time_state_size,
-                    groups_time_state_size
+                    self.intermediate_size // self.tp_size,
+                    2 * groups_time_state_size // self.tp_size
                 ],
                 dim=-1,
             )
+
+            # Allgather on B and C needed
+            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
+            B, C = torch.split(
+                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
+

             # if (attention_mask is not None
             #     and attention_mask.shape[1] > 1
             #     and attention_mask.shape[0] > 1:
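Both the decode path and the prefill path above end with the same pattern: each rank's column-parallel `in_proj` yields only a `2 * groups_time_state_size // tp_size` wide slice of the B/C projections, while the SSM kernels consume full-width `B` and `C`, so the slices are reassembled with `tensor_model_parallel_all_gather` before the final split. Below is a single-process sketch of that reshuffle, standing in for the collective with a plain `torch.cat` over per-rank slices; the batch size, `n_groups`, `ssm_state_size`, and `tp_size` are made-up values.

```python
import torch

batch, n_groups, ssm_state_size, tp_size = 2, 4, 8, 2
groups_time_state_size = n_groups * ssm_state_size        # 32

# What each rank holds after its sharded in_proj + conv: a B/C slice.
per_rank_width = 2 * groups_time_state_size // tp_size    # 32
rank_slices = [torch.randn(batch, per_rank_width) for _ in range(tp_size)]

# The all-gather concatenates the per-rank slices along the last dim,
# so every rank ends up holding the full-width tensor.
B_C = torch.cat(rank_slices, dim=-1)                       # shape (2, 64)

# The diff then splits the gathered tensor into full-width B and C.
B, C = torch.split(
    B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
print(B.shape, C.shape)                                    # (2, 32) (2, 32)
```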