@@ -65,28 +65,30 @@ def __init__(self, hidden_size, eps=1e-6):
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         input_shape = hidden_states.shape
-        hidden_states = hidden_states.to(torch.float32).view(-1, hidden_states.shape[-1])
+        hidden_states = hidden_states.to(torch.float32).view(
+            -1, hidden_states.shape[-1])
 
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(
                 gate.to(torch.float32))
 
-        # Use Welford's online algorithm for caculating the variance in the
+        # Use Welford's online algorithm for calculating the variance in the
         # tensor parallel setting, as the hidden_states are sharded along the
-        # same axis as we are calculating the variance along.
+        # same axis as we are calculating the variance along.
         if self.tp_size > 1:
             # Calculate local sum and squared_sum
-            local_sums = torch.zeros((hidden_states[0], 3), hidden_state.dtype, hidden_state.device)
-            local_sums[:,0] = hidden_states.sum(-1, keep_dim=False)
-            local_sums[:,1] = hidden_states.pow(2).sum(-1, keep_dim=False)
-
+            local_sums = torch.zeros((hidden_states.shape[0], 3),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            local_sums[:, 0] = hidden_states.sum(-1, keepdim=False)
+            local_sums[:, 1] = hidden_states.pow(2).sum(-1, keepdim=False)
+
             # Get global sum and squared sum
             global_sums = tensor_model_parallel_all_reduce(local_sums)
 
             # Calculate the variance
             count = self.tp_size * hidden_states.shape[-1]
-            global_mean = global_sums[:,0] / count
-            variance = (global_sq_sum[:,1] / count) - global_mean.pow(2)
+            global_mean = global_sums[:, 0] / count
+            variance = (global_sums[:, 1] / count) - global_mean.pow(2)
 
         else:
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
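The sharded branch above derives the variance from all-reduced first and second moments, i.e. Var[x] = E[x^2] - E[x]^2 over the full (unsharded) hidden dimension. A minimal standalone sketch of that identity, with made-up shard sizes (tp_size, tokens, local_width) standing in for the real tensor-parallel setup:

    import torch

    # Emulate tp_size ranks, each holding a (tokens, local_width) shard.
    tp_size, tokens, local_width = 4, 3, 16
    shards = [torch.randn(tokens, local_width) for _ in range(tp_size)]
    full = torch.cat(shards, dim=-1)

    # What the all-reduce would yield: per-row sums and squared sums.
    global_sum = sum(s.sum(-1) for s in shards)
    global_sq_sum = sum(s.pow(2).sum(-1) for s in shards)

    count = tp_size * local_width          # full hidden size
    mean = global_sum / count
    variance = global_sq_sum / count - mean.pow(2)

    torch.testing.assert_close(variance, full.var(-1, unbiased=False),
                               rtol=1e-5, atol=1e-5)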
@@ -135,13 +137,13 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.use_bias = config.use_bias
 
         groups_time_state_size = self.n_groups * self.ssm_state_size
-        self.conv_dim = (self.intermediate_size +
-                         2 * groups_time_state_size)
+        self.conv_dim = (self.intermediate_size + 2 * groups_time_state_size)
 
-        self.conv1d = MergedColumnParallelLinear(
-            self.conv_kernel_size,
-            [self.intermediate_size, groups_time_state_size, groups_time_state_size],
-            bias=self.use_conv_bias)
+        self.conv1d = MergedColumnParallelLinear(self.conv_kernel_size, [
+            self.intermediate_size, groups_time_state_size,
+            groups_time_state_size
+        ],
+                                                 bias=self.use_conv_bias)
 
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
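For reference (a standalone sketch with made-up sizes, not code from this change): a depthwise nn.Conv1d with groups == channels stores its weight as (channels, 1, kernel_size), while the column-parallel linear above holds a 2-D (channels, kernel_size) weight, which is the gap the unsqueeze(1) in the next hunk bridges:

    import torch
    import torch.nn as nn

    channels, kernel_size = 8, 4  # illustrative sizes only
    conv = nn.Conv1d(channels, channels, kernel_size, groups=channels)
    print(conv.weight.shape)                 # torch.Size([8, 1, 4])

    linear_style = torch.empty(channels, kernel_size)
    print(linear_style.unsqueeze(1).shape)   # torch.Size([8, 1, 4])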
@@ -150,12 +152,11 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
         # The sharded outputs are gate, hidden_states, B, C, and dt
-        self.in_proj = MergedColumnParallelLinear(
-            self.hidden_size,
-            [self.intermediate_size, self.intermediate_size,
-             groups_time_state_size, groups_time_state_size,
-             self.num_heads],
-            bias=self.use_bias)
+        self.in_proj = MergedColumnParallelLinear(self.hidden_size, [
+            self.intermediate_size, self.intermediate_size,
+            groups_time_state_size, groups_time_state_size, self.num_heads
+        ],
+                                                  bias=self.use_bias)
 
         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
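To make the comment about the sharded outputs concrete, here is a rough sketch (with assumed, unsharded sizes; torch.split is used only for illustration) of how one row of the merged in_proj output decomposes into gate, hidden_states, B, C and dt; per tensor-parallel rank each width would be its sharded fraction:

    import torch

    # Assumed full sizes, mirroring the merged output list above.
    intermediate_size, groups_time_state_size, num_heads = 4096, 1024, 128
    split_sizes = [
        intermediate_size, intermediate_size,
        groups_time_state_size, groups_time_state_size, num_heads,
    ]

    projected = torch.randn(2, sum(split_sizes))  # (num_tokens, merged width)
    gate, hidden_states, B, C, dt = torch.split(projected, split_sizes, dim=-1)
    print([t.shape[-1] for t in (gate, hidden_states, B, C, dt)])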