
Commit bc9b5cf

spelling
1 parent c2cd071 commit bc9b5cf

File tree

1 file changed: +22 -21 lines


vllm/model_executor/models/mamba2.py

Lines changed: 22 additions & 21 deletions
@@ -65,28 +65,30 @@ def __init__(self, hidden_size, eps=1e-6):
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         input_shape = hidden_states.shape
-        hidden_states = hidden_states.to(torch.float32).view(-1, hidden_states.shape[-1])
+        hidden_states = hidden_states.to(torch.float32).view(
+            -1, hidden_states.shape[-1])

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(
                 gate.to(torch.float32))

-        # Use Welford's online algorithm for caculating the variance in the
+        # Use Welford's online algorithm for caculating the variance in the
         # tensor parallel setting, as the hidden_states are sharded along the
-        # same axis as we are calculating the variance along.
+        # same axis as we are calculating the variance along.
         if self.tp_size > 1:
             # Calculate local sum and squared_sum
-            local_sums = torch.zeros((hidden_states[0], 3), hidden_state.dtype, hidden_state.device)
-            local_sums[:,0] = hidden_states.sum(-1, keep_dim=False)
-            local_sums[:,1] = hidden_states.pow(2).sum(-1, keep_dim=False)
-
+            local_sums = torch.zeros((hidden_states[0], 3), hidden_state.dtype,
+                                     hidden_state.device)
+            local_sums[:, 0] = hidden_states.sum(-1, keep_dim=False)
+            local_sums[:, 1] = hidden_states.pow(2).sum(-1, keep_dim=False)
+
             # Get global sum and squared sum
             global_sums = tensor_model_parallel_all_reduce(sum_and_squared_sum)

             # Calculate the variance
             count = hidden_size.shape(-1)
-            global_mean = global_sums[:,0] / count
-            variance = (global_sq_sum[:,1] / count) - global_mean.pow(2)
+            global_mean = global_sums[:, 0] / count
+            variance = (global_sq_sum[:, 1] / count) - global_mean.pow(2)

         else:
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
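The hunk above only reflows this code, but the comment it touches describes a concrete technique: when hidden_states is sharded along the hidden dimension across tensor-parallel ranks, each rank accumulates a local sum and sum of squares, a single all-reduce combines them, and the variance follows as E[x^2] - E[x]^2. The surrounding context lines still reference names such as sum_and_squared_sum, hidden_state and hidden_size.shape(-1) that do not obviously match their surroundings, so the following is a minimal standalone sketch of the intended pattern rather than a copy of the file. It is not vLLM's code: it swaps torch.distributed.all_reduce in for tensor_model_parallel_all_reduce, assumes the process group is already initialized, and the name sharded_variance is made up for the example.

import torch
import torch.distributed as dist


def sharded_variance(local_hidden: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Variance over the full hidden dim when each rank holds only a shard.

    local_hidden: (num_tokens, hidden_size // tp_size), already float32.
    """
    # Per-rank partial statistics: sum and sum of squares over the local shard.
    local_stats = torch.stack(
        [local_hidden.sum(-1), local_hidden.pow(2).sum(-1)], dim=-1)

    # One collective turns the partial sums into sums over the whole row.
    dist.all_reduce(local_stats, op=dist.ReduceOp.SUM)

    count = local_hidden.shape[-1] * tp_size  # full (unsharded) hidden size
    global_mean = local_stats[:, 0] / count
    # Var[x] = E[x^2] - E[x]^2
    return local_stats[:, 1] / count - global_mean.pow(2)

Packing the sum and the squared sum into one stacked tensor is what keeps this to a single collective per forward pass.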
@@ -135,13 +137,13 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.use_bias = config.use_bias

         groups_time_state_size = self.n_groups * self.ssm_state_size
-        self.conv_dim = (self.intermediate_size +
-                         2 * groups_time_state_size)
+        self.conv_dim = (self.intermediate_size + 2 * groups_time_state_size)

-        self.conv1d = MergedColumnParallelLinear(
-            self.conv_kernel_size,
-            [self.intermediate_size, groups_time_state_size, groups_time_state_size],
-            bias=self.use_conv_bias)
+        self.conv1d = MergedColumnParallelLinear(self.conv_kernel_size, [
+            self.intermediate_size, groups_time_state_size,
+            groups_time_state_size
+        ],
+                                                 bias=self.use_conv_bias)

         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
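As a hedged aside on the comment that closes this hunk: a linear-style parallel layer holds a 2-D (conv_dim, kernel_size) weight, while nn.Conv1d with groups equal to out_channels expects a 3-D (conv_dim, 1, kernel_size) weight, so unsqueezing dim 1 maps one view onto the other. The toy sizes below are invented purely for illustration and are not taken from the file.

import torch

# Toy sizes, chosen only for the example.
conv_dim, kernel_size = 8, 4

# Linear-style weight: (out_features, in_features).
linear_style = torch.randn(conv_dim, kernel_size)

# Depthwise Conv1d-style weight (groups == out_channels):
# (out_channels, in_channels // groups, kernel_size) == (conv_dim, 1, kernel_size).
conv_style = linear_style.unsqueeze(1)
assert conv_style.shape == (conv_dim, 1, kernel_size)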
@@ -150,12 +152,11 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

         # The sharded outputs are gate, hidden_states, B, C, and dt
-        self.in_proj = MergedColumnParallelLinear(
-            self.hidden_size,
-            [self.intermediate_size, self.intermediate_size,
-             groups_time_state_size, groups_time_state_size,
-             self.num_heads],
-            bias=self.use_bias)
+        self.in_proj = MergedColumnParallelLinear(self.hidden_size, [
+            self.intermediate_size, self.intermediate_size,
+            groups_time_state_size, groups_time_state_size, self.num_heads
+        ],
+                                                  bias=self.use_bias)

         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
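The comment above the in_proj change lists the sharded outputs as gate, hidden_states, B, C, and dt. Below is a small sketch of how the output of such a merged column-parallel projection is typically split back into those pieces on each rank; the helper name, the argument list, and the assumption that every width divides evenly by the tensor-parallel size are illustrative rather than taken from the file.

import torch


def split_in_proj_output(projected: torch.Tensor, intermediate_size: int,
                         groups_time_state_size: int, num_heads: int,
                         tp_size: int):
    """Split one rank's merged in_proj output back into its logical pieces."""
    sizes = [
        intermediate_size // tp_size,        # gate
        intermediate_size // tp_size,        # hidden_states
        groups_time_state_size // tp_size,   # B
        groups_time_state_size // tp_size,   # C
        num_heads // tp_size,                # dt
    ]
    gate, hidden_states, B, C, dt = torch.split(projected, sizes, dim=-1)
    return gate, hidden_states, B, C, dt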

0 commit comments
