@@ -1430,10 +1430,18 @@ def __init__(
14301430 num_register_tokens = 0 ,
14311431 checkpoint = False ,
14321432 add_value_residual = False ,
1433+ num_residual_streams = 1 ,
14331434 pairwise_block_kwargs : dict = dict (),
14341435 pair_bias_attn_kwargs : dict = dict ()
14351436 ):
14361437 super ().__init__ ()
1438+
1439+ # residual / hyper connections
1440+
1441+ init_hyper_conn , self .expand_streams , self .reduce_streams = HyperConnections .get_init_and_expand_reduce_stream_functions (num_residual_streams , disable = num_residual_streams == 1 )
1442+
1443+ # layers
1444+
14371445 layers = ModuleList ([])
14381446
14391447 pair_bias_attn_kwargs = dict (
@@ -1463,8 +1471,8 @@ def __init__(
14631471
14641472 layers .append (ModuleList ([
14651473 pairwise_block ,
1466- single_pre_ln (pair_bias_attn ),
1467- single_pre_ln (single_transition ),
1474+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (pair_bias_attn ) ),
1475+ init_hyper_conn ( dim = dim_single , branch = single_pre_ln (single_transition ) ),
14681476 ]))
14691477
14701478 self .layers = layers
@@ -1499,6 +1507,8 @@ def to_layers(
14991507
15001508 ) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
15011509
1510+ single_repr = self .expand_streams (single_repr )
1511+
15021512 for _ in range (self .recurrent_depth ):
15031513
15041514 value_residual = None
@@ -1512,15 +1522,15 @@ def to_layers(
15121522
15131523 pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr = pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
15141524
1515- attn_out , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
1516-
1517- single_repr = single_repr + attn_out
1525+ single_repr , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
15181526
15191527 if self .add_value_residual :
15201528 value_residual = default (value_residual , attn_values )
15211529 pairwise_value_residuals = default (pairwise_value_residuals , pairwise_attn_values )
15221530
1523- single_repr = single_transition (single_repr ) + single_repr
1531+ single_repr = single_transition (single_repr )
1532+
1533+ single_repr = self .reduce_streams (single_repr )
15241534
15251535 return single_repr , pairwise_repr
15261536
@@ -1550,8 +1560,7 @@ def pair_bias_attn_wrapper(layer):
15501560 @wraps (layer )
15511561 def inner (inputs , * args , ** kwargs ):
15521562 single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1553- attn_out , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
1554- single_repr = single_repr + attn_out
1563+ single_repr , attn_values = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = maybe_value_residual )
15551564
15561565 if self .add_value_residual :
15571566 maybe_value_residual = default (maybe_value_residual , attn_values )
@@ -1563,7 +1572,7 @@ def single_transition_wrapper(layer):
15631572 @wraps (layer )
15641573 def inner (inputs , * args , ** kwargs ):
15651574 single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1566- single_repr = layer (single_repr ) + single_repr
1575+ single_repr = layer (single_repr )
15671576 return single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals
15681577 return inner
15691578
@@ -1579,6 +1588,8 @@ def inner(inputs, *args, **kwargs):
15791588 wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
15801589 wrapped_layers .append (single_transition_wrapper (single_transition ))
15811590
1591+ single_repr = self .expand_streams (single_repr )
1592+
15821593 for _ in range (self .recurrent_depth ):
15831594 inputs = (single_repr , pairwise_repr , mask , None , None )
15841595
@@ -1587,6 +1598,8 @@ def inner(inputs, *args, **kwargs):
15871598
15881599 single_repr , pairwise_repr , * _ = inputs
15891600
1601+ single_repr = self .reduce_streams (single_repr )
1602+
15901603 return single_repr , pairwise_repr
15911604
15921605 @typecheck
0 commit comments