Commit 7dc8f38

WIP TP==2
1 parent 06fc43e commit 7dc8f38

1 file changed: +55 −48 lines


vllm/model_executor/models/mamba2.py

@@ -6,15 +6,16 @@
 import torch
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-from mamba_ssm.ops.triton.ssd_combined import (mamba_chunk_scan_combined)
+from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -45,12 +46,20 @@ class MambaCacheParams:
     ssm_state: torch.Tensor = torch.Tensor()


+# Load weights that are sharded along axis 0
+def sharded_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
+    tp_rank = get_tensor_model_parallel_rank()
+    param.data.copy_(
+        loaded_weight.data.split(param.data.shape[0], dim=0)[tp_rank])
+
+
 class MambaRMSNormGated(torch.nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader})

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
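
Note: the new sharded_weight_loader copies only the slice of the checkpoint weight that belongs to the current tensor-parallel rank: it splits the full tensor along dim 0 into chunks of the local parameter's size and keeps the chunk indexed by the rank. A minimal standalone sketch of that slicing logic follows; the two-rank setup and the plain-torch stand-ins for vLLM's distributed helpers are illustrative assumptions, not part of the diff.

import torch

# Illustrative stand-ins for vLLM's distributed helpers (assumed, not from the diff).
TP_SIZE = 2          # tensor-parallel world size
TP_RANK = 1          # rank of the current worker

# Full checkpoint weight: e.g. 8 per-head scalars (like A, D, dt_bias).
full_weight = torch.arange(8, dtype=torch.float32)

# Local parameter is pre-allocated with the per-rank shape: num_heads // tp_size.
param = torch.nn.Parameter(torch.empty(8 // TP_SIZE))

# Same slicing as sharded_weight_loader: split along dim 0 into chunks of the
# local size and copy the chunk owned by this rank.
param.data.copy_(full_weight.split(param.data.shape[0], dim=0)[TP_RANK])

print(param.data)  # tensor([4., 5., 6., 7.]) on rank 1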
@@ -65,6 +74,7 @@ def forward(self, hidden_states, gate=None):

         return self.weight * hidden_states.to(input_dtype)

+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -81,6 +91,7 @@ def __init__(self, config: MambaConfig, layer_idx):
         super().__init__()
         self.config = config

+        self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads = config.num_heads
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
@@ -114,15 +125,14 @@ def __init__(self, config: MambaConfig, layer_idx):
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        projection_size = (self.intermediate_size + self.conv_dim +
-                           self.num_heads)
-        self.in_proj = ColumnParallelLinear(self.hidden_size,
-                                            projection_size,
-                                            bias=self.use_bias)
+        self.in_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size, self.conv_dim, self.num_heads],
+            bias=self.use_bias)

         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
-        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+        self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
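
Note: with MergedColumnParallelLinear, each of the three merged output blocks (intermediate_size, conv_dim, num_heads) is column-sharded separately, so every rank holds a [block // tp_size] slice of each block rather than a slice of the single concatenated projection. A rough sketch of the resulting per-rank split widths; the concrete sizes below are made up for illustration and the conv_dim relation is the usual Mamba2 convention, not something stated in this diff.

# Hypothetical Mamba2-like sizes, chosen only to make the arithmetic concrete.
intermediate_size = 4096
n_groups, ssm_state_size = 8, 128
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size  # 6144
num_heads = 64
tp_size = 2

# Per-rank output widths of the merged in_proj, in the same order as the split
# used later in mamba_forward: [gate | conv input (x, B, C) | dt].
per_rank_split = [
    intermediate_size // tp_size,  # 2048
    conv_dim // tp_size,           # 3072
    num_heads // tp_size,          # 32
]
assert sum(per_rank_split) == (intermediate_size + conv_dim + num_heads) // tp_size
print(per_rank_split)  # [2048, 3072, 32]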
@@ -132,33 +142,18 @@ def __init__(self, config: MambaConfig, layer_idx):
                                             bias=True,
                                             skip_bias_add=True)

-        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
-            param.data.copy_(
-                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
-                                         dim=0)[tp_rank])
-
         def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
-            weight_loader(param, -torch.exp(loaded_weight.float()))
-
-        # TODO: Figure out tensor parallelism for A and D.
-        # tp_size = get_tensor_model_parallel_world_size()
-        # In mamba.py A is
-        # self.intermediate_size // tp_size by self.ssm_state_size,
-        # D is a vector of size self.intermediate_size // tp_size
-        # For mamba2 they are much smaller
-
-        # A = torch.arange(1, self.num_heads + 1)
-        # self.A_log = nn.Parameter(torch.log(A))
-        self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.norm = MambaRMSNormGated(self.intermediate_size,
-                                      eps=self.layer_norm_epsilon)
+            sharded_weight_loader(param, -torch.exp(loaded_weight.float()))

-        self.D = nn.Parameter(torch.ones(self.num_heads))
+        self.A = nn.Parameter(torch.ones(self.num_heads // self.tp_size))
+        self.D = nn.Parameter(torch.ones(self.num_heads // self.tp_size))
+        self.norm = MambaRMSNormGated(self.intermediate_size // self.tp_size,
+                                      eps=self.layer_norm_epsilon)

-        set_weight_attrs(self.D, {"weight_loader": weight_loader})
+        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader})
         set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
+        set_weight_attrs(self.dt_bias,
+                         {"weight_loader": sharded_weight_loader})

         self.out_proj = RowParallelLinear(
             self.intermediate_size,
@@ -174,17 +169,15 @@ def mamba_forward(self,
         # set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
-        d_to_remove = (2 * self.intermediate_size +
-                       2 * self.n_groups * self.ssm_state_size +
-                       self.num_heads)
+        d_to_remove = (2 * self.intermediate_size + 2 * self.n_groups *
+                       self.ssm_state_size + self.num_heads) // self.tp_size

         if cache_params is not None and not cache_params.is_prompt:
-            in_projected_states, _ = self.in_proj(
-                hidden_states.squeeze(1))  # (B 2D)
+            in_projected_states, _ = self.in_proj(hidden_states.squeeze(1))
             d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
             split_projection_dim = [
-                d_mlp, d_mlp, self.intermediate_size, self.conv_dim,
-                self.num_heads
+                d_mlp, d_mlp, self.intermediate_size // self.tp_size,
+                self.conv_dim // self.tp_size, self.num_heads // self.tp_size
             ]
             _, _, gate, hidden_states_B_C, dt = torch.split(
                 in_projected_states, split_projection_dim, dim=-1)
@@ -197,15 +190,19 @@ def mamba_forward(self,
                 self.activation,
             )

-            hidden_states, B, C = torch.split(
+            hidden_states, B_C = torch.split(
                 hidden_states_B_C,
                 [
-                    self.intermediate_size, groups_time_state_size,
-                    groups_time_state_size
+                    self.intermediate_size // self.tp_size,
+                    2 * groups_time_state_size // self.tp_size
                 ],
                 dim=-1,
             )

+            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
+            B, C = torch.split(
+                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
+
             A = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
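
Note: each rank's split only yields a 2 * groups_time_state_size // tp_size slice of the concatenated B/C projection, but the downstream SSM kernels consume B and C at their full group width, so the slice is all-gathered across the TP group first. A rough single-process sketch of the gather semantics follows, using torch.cat over per-rank shards as a stand-in for tensor_model_parallel_all_gather (which gathers across ranks and concatenates along a dimension); the widths here are made up.

import torch

tp_size = 2
batch, full_width = 1, 12   # full_width stands for 2 * groups_time_state_size

# Each rank ends up with a (batch, full_width // tp_size) slice after the split.
shards = [torch.randn(batch, full_width // tp_size) for _ in range(tp_size)]

# Stand-in for tensor_model_parallel_all_gather: concatenate the per-rank
# slices (here along the last dim), so every rank ends up holding the
# full-width tensor before splitting it into B and C.
gathered = torch.cat(shards, dim=-1)
assert gathered.shape == (batch, full_width)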
@@ -214,7 +211,7 @@ def mamba_forward(self,
             B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
             C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
             hidden_states_reshaped = hidden_states.view(
-                batch_size, self.num_heads, self.head_dim)
+                batch_size, self.num_heads // self.tp_size, self.head_dim)
             hidden_states = selective_state_update(
                 cache_params.ssm_state,
                 hidden_states_reshaped,
@@ -227,8 +224,8 @@ def mamba_forward(self,
                 dt_bias=dt_bias,
                 dt_softplus=True,
             )
-            hidden_states = hidden_states.view(batch_size,
-                                               self.num_heads * self.head_dim)
+            hidden_states = hidden_states.view(
+                batch_size, self.num_heads // self.tp_size * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)
             out = self.out_proj(hidden_states)[0][:, None, ...]
         # if no cache is found, calling the kernel
@@ -253,7 +250,10 @@ def mamba_forward(self,

             gate, hidden_states_B_C, time_step = torch.split(
                 projected_states,
-                [self.intermediate_size, self.conv_dim, self.num_heads],
+                [
+                    self.intermediate_size // self.tp_size, self.conv_dim //
+                    self.tp_size, self.num_heads // self.tp_size
+                ],
                 dim=-1,
             )

@@ -273,14 +273,21 @@ def mamba_forward(self,
                 bias=self.conv1d.bias,
                 activation=self.activation,
             ).transpose(1, 2)[:, :seq_len]
-            hidden_states, B, C = torch.split(
+
+            hidden_states, B_C = torch.split(
                 hidden_states_B_C,
                 [
-                    self.intermediate_size, groups_time_state_size,
-                    groups_time_state_size
+                    self.intermediate_size // self.tp_size,
+                    2 * groups_time_state_size // self.tp_size
                 ],
                 dim=-1,
             )
+
+            # Allgather on B and C needed
+            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
+            B, C = torch.split(
+                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
+
             # if (attention_mask is not None
             #     and attention_mask.shape[1] > 1
             #     and attention_mask.shape[0] > 1:
