@@ -62,6 +62,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
@@ -88,7 +89,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
         self.gather_output = gather_output
 
@@ -137,6 +139,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("down_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -156,6 +159,7 @@ def __init__(
             self.forward_type = "normal"
         self.comm_group = comm_group
 
+        # TODO: check for disable_tp
         self.tp_size = self.comm_group.world_size
         self.tp_rank = self.comm_group.rank_in_group
 
@@ -171,7 +175,8 @@ def __init__(
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -392,6 +397,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
@@ -403,6 +409,7 @@ def __init__(
             comm_group = get_tp_group()
             self.forward_type = "normal_tp"
         self.comm_group = comm_group
+        # TODO: check for disable_tp
         self.tp_rank = comm_group.rank_in_group
         self.tp_size = comm_group.world_size
 
@@ -418,7 +425,8 @@ def __init__(
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
     def forward(
         self,
@@ -498,6 +506,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         if dense_optim_enable():
             self.forward_type = "dense_optim"
@@ -511,6 +520,7 @@ def __init__(
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
+        # TODO: check for disable_tp
         tp_size = self.comm_group.world_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
@@ -537,7 +547,8 @@ def __init__(
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)
 
     def forward(
         self,
@@ -611,4 +622,4 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
+        self.disable_tp = disable_tp
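
The pattern applied in every hunk above is the same: each layer __init__ gains a keyword-only disable_tp: bool = False parameter and forwards it to the base class via super().__init__(..., disable_tp=disable_tp), where it is stored as self.disable_tp. Below is a minimal, self-contained sketch of that plumbing; BaseLinearSketch and ColumnParallelSketch are placeholder names for illustration only and are not the classes touched by this diff.

# Illustrative sketch only (not the real classes in this change): shows how a
# keyword-only `disable_tp` flag is threaded from a subclass __init__ into the
# base class, mirroring the super().__init__(..., disable_tp=disable_tp) calls.
class BaseLinearSketch:

    def __init__(self, *, return_bias: bool = True, disable_tp: bool = False):
        self.return_bias = return_bias
        # The base class records whether tensor parallelism is disabled for
        # this layer, as in the final hunk above.
        self.disable_tp = disable_tp


class ColumnParallelSketch(BaseLinearSketch):

    def __init__(self,
                 prefix: str = "",
                 *,
                 return_bias: bool = True,
                 disable_tp: bool = False):
        super().__init__(return_bias=return_bias, disable_tp=disable_tp)
        self.prefix = prefix


layer = ColumnParallelSketch("gate_up_proj", disable_tp=True)
assert layer.disable_tp is True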