
Commit c2cd071

TP=2 working

1 parent 12e0a8b commit c2cd071

File tree

1 file changed: +50 -35 lines changed

vllm/model_executor/models/mamba2.py

Lines changed: 50 additions & 35 deletions
@@ -59,20 +59,42 @@ def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.tp_size = get_tensor_model_parallel_world_size()
         set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader})

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
+        input_shape = hidden_states.shape
+        hidden_states = hidden_states.to(torch.float32).view(
+            -1, hidden_states.shape[-1])

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(
                 gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+
+        # In the tensor parallel setting the hidden_states are sharded along
+        # the same axis we reduce over, so calculate the variance from the
+        # all-reduced per-rank sums and squared sums.
+        if self.tp_size > 1:
+            # Calculate local sum and squared sum
+            local_sums = torch.zeros((hidden_states.shape[0], 2),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            local_sums[:, 0] = hidden_states.sum(-1)
+            local_sums[:, 1] = hidden_states.pow(2).sum(-1)
+
+            # Get global sum and squared sum
+            global_sums = tensor_model_parallel_all_reduce(local_sums)
+
+            # Calculate the variance over the full hidden dimension
+            count = hidden_states.shape[-1] * self.tp_size
+            global_mean = global_sums[:, 0:1] / count
+            variance = global_sums[:, 1:2] / count - global_mean.pow(2)
+        else:
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)

         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states.to(input_dtype)).view(input_shape)


 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
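Note on the sharded variance above: the sketch below is a minimal single-process check of the same arithmetic, emulating tensor_model_parallel_all_reduce with a plain sum over the shards each TP rank would hold. The sizes (tp_size, tokens, hidden) are illustrative only, not taken from the model config.

import torch

tp_size, tokens, hidden = 2, 4, 16
x = torch.randn(tokens, hidden, dtype=torch.float32)
shards = x.chunk(tp_size, dim=-1)  # what each TP rank would see

# Per-rank (sum, squared sum), then the emulated all-reduce across ranks.
local_sums = torch.stack(
    [torch.stack([s.sum(-1), s.pow(2).sum(-1)], dim=-1) for s in shards])
global_sums = local_sums.sum(0)

count = hidden  # local shard width * tp_size
global_mean = global_sums[:, 0:1] / count
variance = global_sums[:, 1:2] / count - global_mean.pow(2)

# Matches the variance computed directly on the unsharded tensor.
reference = x.pow(2).mean(-1, keepdim=True) - x.mean(-1, keepdim=True).pow(2)
assert torch.allclose(variance, reference, atol=1e-5)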
@@ -112,35 +134,33 @@ def __init__(self, config: MambaConfig, layer_idx):

         self.use_bias = config.use_bias

+        groups_time_state_size = self.n_groups * self.ssm_state_size
         self.conv_dim = (self.intermediate_size +
-                         2 * self.n_groups * self.ssm_state_size)
-        self.conv1d = ColumnParallelLinear(
-            input_size=self.conv_kernel_size,
-            output_size=self.conv_dim,
-            bias=self.use_conv_bias,
-        )
+                         2 * groups_time_state_size)
+
+        self.conv1d = MergedColumnParallelLinear(
+            self.conv_kernel_size,
+            [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+            bias=self.use_conv_bias)
+
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
         # `ColumnParallelLinear` and `set_weight_attrs`
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

+        # The sharded outputs are gate, hidden_states, B, C, and dt
         self.in_proj = MergedColumnParallelLinear(
             self.hidden_size,
-            [self.intermediate_size, self.conv_dim, self.num_heads],
+            [self.intermediate_size, self.intermediate_size,
+             groups_time_state_size, groups_time_state_size,
+             self.num_heads],
             bias=self.use_bias)

         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
-        self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

-        # time step projection (discretization) -
-        # In the forward we need to apply dt_proj without the bias,
-        # as the bias is added in the selective scan kernel.
-        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
-                                            self.intermediate_size,
-                                            bias=True,
-                                            skip_bias_add=True)
+        self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

         def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             sharded_weight_loader(param, -torch.exp(loaded_weight.float()))
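Note on the merged in_proj above: with MergedColumnParallelLinear every logical output (gate, hidden_states, B, C, dt) is column-sharded, so each rank's output is the concatenation of its own shards and can be split locally, which is why the later all-gather on B and C is dropped. A small shape sketch with hypothetical sizes (not read from a real MambaConfig):

import torch

tp_size = 2
intermediate_size, n_groups, ssm_state_size, num_heads = 4096, 8, 128, 128
groups_time_state_size = n_groups * ssm_state_size

output_sizes = [intermediate_size, intermediate_size,
                groups_time_state_size, groups_time_state_size, num_heads]
per_rank = [s // tp_size for s in output_sizes]

rank_out = torch.randn(3, sum(per_rank))  # (tokens, rank-local in_proj width)
gate, hidden_states, B, C, dt = torch.split(rank_out, per_rank, dim=-1)
print([t.shape[-1] for t in (gate, hidden_states, B, C, dt)])
# [2048, 2048, 512, 512, 64]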
@@ -190,26 +210,23 @@ def mamba_forward(self,
                 self.activation,
             )

-            hidden_states, B_C = torch.split(
+            hidden_states, B, C = torch.split(
                 hidden_states_B_C,
                 [
                     self.intermediate_size // self.tp_size,
-                    2 * groups_time_state_size // self.tp_size
+                    groups_time_state_size // self.tp_size,
+                    groups_time_state_size // self.tp_size,
                 ],
                 dim=-1,
             )

-            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
-            B, C = torch.split(
-                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
-
             A = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
             D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
-            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
+            B = B.view(batch_size, self.n_groups // self.tp_size, -1)
+            C = C.view(batch_size, self.n_groups // self.tp_size, -1)
             hidden_states_reshaped = hidden_states.view(
                 batch_size, self.num_heads // self.tp_size, self.head_dim)
             hidden_states = selective_state_update(
@@ -226,6 +243,7 @@ def mamba_forward(self,
             )
             hidden_states = hidden_states.view(
                 batch_size, self.num_heads // self.tp_size * self.head_dim)
+
             hidden_states = self.norm(hidden_states, gate)
             out = self.out_proj(hidden_states)[0][:, None, ...]
         # if no cache is found, calling the kernel
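Note on the decode-path split above: each rank now keeps n_groups // tp_size whole groups of B and C, so the view divides by the sharded group count instead of the full n_groups. A shape sketch with hypothetical sizes:

import torch

tp_size, batch_size = 2, 4
n_groups, ssm_state_size = 8, 128
groups_time_state_size = n_groups * ssm_state_size

B_local = torch.randn(batch_size, groups_time_state_size // tp_size)
B = B_local.view(batch_size, n_groups // tp_size, -1)
print(B.shape)  # torch.Size([4, 4, 128]) -> (batch, local groups, ssm_state_size)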
@@ -258,6 +276,7 @@ def mamba_forward(self,
             )

             time_step = nn.functional.softplus(time_step + self.dt_bias)
+
             # 1D Convolution
             if causal_conv1d_fn is None or self.activation not in [
                     "silu", "swish"
@@ -274,20 +293,16 @@ def mamba_forward(self,
                     activation=self.activation,
                 ).transpose(1, 2)[:, :seq_len]

-            hidden_states, B_C = torch.split(
+            hidden_states, B, C = torch.split(
                 hidden_states_B_C,
                 [
                     self.intermediate_size // self.tp_size,
-                    2 * groups_time_state_size // self.tp_size
+                    groups_time_state_size // self.tp_size,
+                    groups_time_state_size // self.tp_size,
                 ],
                 dim=-1,
             )

-            # Allgather on B and C needed
-            B_C = tensor_model_parallel_all_gather(B_C.contiguous())
-            B, C = torch.split(
-                B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
-
             # if (attention_mask is not None
             #     and attention_mask.shape[1] > 1
             #     and attention_mask.shape[0] > 1:
@@ -301,8 +316,8 @@ def mamba_forward(self,
                 hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                 time_step,
                 self.A,
-                B.view(batch_size, seq_len, self.n_groups, -1),
-                C.view(batch_size, seq_len, self.n_groups, -1),
+                B.view(batch_size, seq_len, self.n_groups // self.tp_size, -1),
+                C.view(batch_size, seq_len, self.n_groups // self.tp_size, -1),
                 chunk_size=self.chunk_size,
                 D=self.D,
                 z=None,
