@@ -59,20 +59,42 @@ def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
+       self.tp_size = get_tensor_model_parallel_world_size()
        set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader})

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
-       hidden_states = hidden_states.to(torch.float32)
+       input_shape = hidden_states.shape
+       hidden_states = hidden_states.to(torch.float32).view(-1, hidden_states.shape[-1])

        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(
                gate.to(torch.float32))
-       variance = hidden_states.pow(2).mean(-1, keepdim=True)
+
+       # In the tensor parallel setting the hidden_states are sharded along
+       # the same axis we compute the variance over, so derive the variance
+       # from an all-reduced sum and sum of squares instead of a local mean.
+       if self.tp_size > 1:
+           # Calculate the local sum and squared sum
+           local_sums = torch.zeros((hidden_states.shape[0], 2),
+                                    dtype=hidden_states.dtype,
+                                    device=hidden_states.device)
+           local_sums[:, 0] = hidden_states.sum(-1)
+           local_sums[:, 1] = hidden_states.pow(2).sum(-1)
+
+           # Get the global sum and squared sum
+           global_sums = tensor_model_parallel_all_reduce(local_sums)
+
+           # Calculate the variance over the full (unsharded) hidden size
+           count = hidden_states.shape[-1] * self.tp_size
+           global_mean = global_sums[:, 0] / count
+           variance = (global_sums[:, 1] / count -
+                       global_mean.pow(2)).unsqueeze(-1)
+
+       else:
+           variance = hidden_states.pow(2).mean(-1, keepdim=True)
+
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
-       return self.weight * hidden_states.to(input_dtype)
+       return (self.weight * hidden_states.to(input_dtype)).view(input_shape)


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
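
To sanity-check the sharded variance path in the hunk above, the following standalone sketch (not part of the diff; shapes and tp_size are invented) emulates the all-reduce by summing per-shard sum / sum-of-squares statistics and compares the result against the variance computed over the full hidden dimension:

    import torch

    # Emulate tensor parallelism by chunking the hidden dimension.
    tp_size, batch, hidden = 4, 2, 64           # made-up sizes
    x = torch.randn(batch, hidden, dtype=torch.float32)
    shards = x.chunk(tp_size, dim=-1)           # what each rank would hold

    # Per-rank [sum, sum of squares], then a stand-in for the all-reduce (sum).
    local_sums = torch.stack(
        [torch.stack([s.sum(-1), s.pow(2).sum(-1)], dim=-1) for s in shards])
    global_sums = local_sums.sum(0)             # shape (batch, 2)

    count = hidden                              # full, unsharded hidden size
    global_mean = global_sums[:, 0] / count
    variance = global_sums[:, 1] / count - global_mean.pow(2)

    # Reference: variance over the full hidden dimension on a single device.
    reference = x.var(dim=-1, unbiased=False)
    assert torch.allclose(variance, reference, atol=1e-5)
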
@@ -112,35 +134,33 @@ def __init__(self, config: MambaConfig, layer_idx):

        self.use_bias = config.use_bias

+       groups_time_state_size = self.n_groups * self.ssm_state_size
        self.conv_dim = (self.intermediate_size +
-                        2 * self.n_groups * self.ssm_state_size)
-       self.conv1d = ColumnParallelLinear(
-           input_size=self.conv_kernel_size,
-           output_size=self.conv_dim,
-           bias=self.use_conv_bias,
-       )
+                        2 * groups_time_state_size)
+
+       self.conv1d = MergedColumnParallelLinear(
+           self.conv_kernel_size,
+           [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+           bias=self.use_conv_bias)
+
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

+       # The sharded outputs are gate, hidden_states, B, C, and dt
        self.in_proj = MergedColumnParallelLinear(
            self.hidden_size,
-           [self.intermediate_size, self.conv_dim, self.num_heads],
+           [self.intermediate_size, self.intermediate_size,
+            groups_time_state_size, groups_time_state_size,
+            self.num_heads],
            bias=self.use_bias)

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
-       self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

-       # time step projection (discretization) -
-       # In the forward we need to apply dt_proj without the bias,
-       # as the bias is added in the selective scan kernel.
-       self.dt_proj = ColumnParallelLinear(self.time_step_rank,
-                                           self.intermediate_size,
-                                           bias=True,
-                                           skip_bias_add=True)
+       self.dt_bias = nn.Parameter(torch.ones(self.num_heads // self.tp_size))

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            sharded_weight_loader(param, -torch.exp(loaded_weight.float()))
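
For orientation, the in_proj change above swaps a three-way output partition (gate, conv input, dt) for a five-way one, so that gate, hidden_states, B, C and dt are each column-sharded on their own and every rank can split its slice locally. A rough arithmetic check with invented sizes (this is only a sketch of the intended layout, not vLLM code):

    # Invented example sizes; per-rank widths of the fused in_proj output.
    intermediate_size = 4096
    n_groups, ssm_state_size, num_heads = 8, 128, 64
    tp_size = 4

    groups_time_state_size = n_groups * ssm_state_size
    per_rank_widths = [
        intermediate_size // tp_size,         # gate
        intermediate_size // tp_size,         # hidden_states
        groups_time_state_size // tp_size,    # B (whole groups per rank)
        groups_time_state_size // tp_size,    # C (whole groups per rank)
        num_heads // tp_size,                 # dt, one value per head
    ]
    print(per_rank_widths)                    # [1024, 1024, 256, 256, 16]

    # Across all ranks this reconstructs the full fused projection width.
    assert sum(per_rank_widths) * tp_size == (2 * intermediate_size +
                                              2 * groups_time_state_size +
                                              num_heads)
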
@@ -190,26 +210,23 @@ def mamba_forward(self,
                self.activation,
            )

-           hidden_states, B_C = torch.split(
+           hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [
                    self.intermediate_size // self.tp_size,
-                   2 * groups_time_state_size // self.tp_size
+                   groups_time_state_size // self.tp_size,
+                   groups_time_state_size // self.tp_size,
                ],
                dim=-1,
            )

-           B_C = tensor_model_parallel_all_gather(B_C.contiguous())
-           B, C = torch.split(
-               B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
-
            A = self.A[:, None, ...][:, :, None].expand(
                -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
-           B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
-           C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
+           B = B.view(batch_size, self.n_groups // self.tp_size, -1)
+           C = C.view(batch_size, self.n_groups // self.tp_size, -1)
            hidden_states_reshaped = hidden_states.view(
                batch_size, self.num_heads // self.tp_size, self.head_dim)
            hidden_states = selective_state_update(
@@ -226,6 +243,7 @@ def mamba_forward(self,
            )
            hidden_states = hidden_states.view(
                batch_size, self.num_heads // self.tp_size * self.head_dim)
+
            hidden_states = self.norm(hidden_states, gate)
            out = self.out_proj(hidden_states)[0][:, None, ...]
        # if no cache is found, calling the kernel
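
The `self.norm(hidden_states, gate)` call above is the gated RMSNorm from the first hunk. As a reference point, its single-rank (`tp_size == 1`) behaviour reduces to roughly the following sketch (invented shapes, eps value assumed):

    import torch

    def gated_rms_norm(x, gate, weight, eps=1e-6):
        # Mirrors the tp_size == 1 path: gate through SiLU, then RMS-normalize.
        input_dtype = x.dtype
        x = x.float() * torch.nn.functional.silu(gate.float())
        variance = x.pow(2).mean(-1, keepdim=True)
        return weight * (x * torch.rsqrt(variance + eps)).to(input_dtype)

    x, gate = torch.randn(2, 1024), torch.randn(2, 1024)
    out = gated_rms_norm(x, gate, weight=torch.ones(1024))
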
@@ -258,6 +276,7 @@ def mamba_forward(self,
            )

            time_step = nn.functional.softplus(time_step + self.dt_bias)
+
            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in [
                    "silu", "swish"
@@ -274,20 +293,16 @@ def mamba_forward(self,
                    activation=self.activation,
                ).transpose(1, 2)[:, :seq_len]

-           hidden_states, B_C = torch.split(
+           hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [
                    self.intermediate_size // self.tp_size,
-                   2 * groups_time_state_size // self.tp_size
+                   groups_time_state_size // self.tp_size,
+                   groups_time_state_size // self.tp_size,
                ],
                dim=-1,
            )

-           # Allgather on B and C needed
-           B_C = tensor_model_parallel_all_gather(B_C.contiguous())
-           B, C = torch.split(
-               B_C, [groups_time_state_size, groups_time_state_size], dim=-1)
-
            # if (attention_mask is not None
            # and attention_mask.shape[1] > 1
            # and attention_mask.shape[0] > 1:
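
Because B and C are now sharded by whole groups, the local split above lines up with the later per-group view and no all-gather is required; this relies on n_groups being divisible by tp_size. A toy shape check with invented sizes:

    import torch

    # One rank's view of the conv output, with invented sizes.
    batch, seq_len, tp_size = 2, 5, 4
    intermediate_size, n_groups, ssm_state_size = 4096, 8, 128
    groups_time_state_size = n_groups * ssm_state_size

    hidden_states_B_C = torch.randn(
        batch, seq_len,
        (intermediate_size + 2 * groups_time_state_size) // tp_size)

    hidden_states, B, C = torch.split(
        hidden_states_B_C,
        [intermediate_size // tp_size,
         groups_time_state_size // tp_size,
         groups_time_state_size // tp_size],
        dim=-1)

    # Each rank keeps n_groups // tp_size complete groups of the SSM state.
    B = B.view(batch, seq_len, n_groups // tp_size, -1)
    assert B.shape == (batch, seq_len, n_groups // tp_size, ssm_state_size)
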
@@ -301,8 +316,8 @@ def mamba_forward(self,
                hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                time_step,
                self.A,
-               B.view(batch_size, seq_len, self.n_groups, -1),
-               C.view(batch_size, seq_len, self.n_groups, -1),
+               B.view(batch_size, seq_len, self.n_groups // self.tp_size, -1),
+               C.view(batch_size, seq_len, self.n_groups // self.tp_size, -1),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,