119119
120120from colt5_attention import ConditionalRoutedAttention
121121
122- from hyper_connections import HyperConnections
122+ from hyper_connections . hyper_connections_with_multi_input_streams import HyperConnections
123123
124124# other external libs
125125
@@ -995,8 +995,8 @@ def __init__(
995995 @typecheck
996996 def forward (
997997 self ,
998- * ,
999998 pairwise_repr : Float ['b n n d' ],
999+ * ,
10001000 mask : Bool ['b n' ] | None = None ,
10011001 value_residuals : tuple [Tensor , Tensor ] | None = None ,
10021002 return_values = False ,
@@ -1470,8 +1470,8 @@ def __init__(
14701470 single_transition = Transition (dim = dim_single )
14711471
14721472 layers .append (ModuleList ([
1473- pairwise_block ,
1474- init_hyper_conn (dim = dim_single , branch = single_pre_ln (pair_bias_attn )),
1473+ init_hyper_conn ( dim = dim_pairwise , branch = pairwise_block ) ,
1474+ init_hyper_conn (dim = dim_single , additional_input_paths = [( 'pairwise_repr' , dim_pairwise )], branch = single_pre_ln (pair_bias_attn )),
14751475 init_hyper_conn (dim = dim_single , branch = single_pre_ln (single_transition )),
14761476 ]))
14771477
@@ -1508,6 +1508,7 @@ def to_layers(
15081508 ) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
15091509
15101510 single_repr = self .expand_streams (single_repr )
1511+ pairwise_repr = self .expand_streams (pairwise_repr )
15111512
15121513 for _ in range (self .recurrent_depth ):
15131514
@@ -1520,7 +1521,7 @@ def to_layers(
15201521 single_transition
15211522 ) in self .layers :
15221523
1523- pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr = pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
1524+ pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
15241525
15251526 single_repr , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
15261527
@@ -1531,6 +1532,7 @@ def to_layers(
15311532 single_repr = single_transition (single_repr )
15321533
15331534 single_repr = self .reduce_streams (single_repr )
1535+ pairwise_repr = self .reduce_streams (pairwise_repr )
15341536
15351537 return single_repr , pairwise_repr
15361538
@@ -1548,7 +1550,7 @@ def pairwise_block_wrapper(layer):
15481550 @wraps (layer )
15491551 def inner (inputs , * args , ** kwargs ):
15501552 single_repr , pairwise_repr , mask , maybe_value_residual , maybe_pairwise_value_residuals = inputs
1551- pairwise_repr , pairwise_attn_values = layer (pairwise_repr = pairwise_repr , mask = mask , value_residuals = maybe_pairwise_value_residuals , return_values = True )
1553+ pairwise_repr , pairwise_attn_values = layer (pairwise_repr , mask = mask , value_residuals = maybe_pairwise_value_residuals , return_values = True )
15521554
15531555 if self .add_value_residual :
15541556 maybe_pairwise_value_residuals = default (maybe_pairwise_value_residuals , pairwise_attn_values )
@@ -1589,6 +1591,7 @@ def inner(inputs, *args, **kwargs):
15891591 wrapped_layers .append (single_transition_wrapper (single_transition ))
15901592
15911593 single_repr = self .expand_streams (single_repr )
1594+ pairwise_repr = self .expand_streams (pairwise_repr )
15921595
15931596 for _ in range (self .recurrent_depth ):
15941597 inputs = (single_repr , pairwise_repr , mask , None , None )
@@ -1599,6 +1602,7 @@ def inner(inputs, *args, **kwargs):
15991602 single_repr , pairwise_repr , * _ = inputs
16001603
16011604 single_repr = self .reduce_streams (single_repr )
1605+ pairwise_repr = self .reduce_streams (pairwise_repr )
16021606
16031607 return single_repr , pairwise_repr
16041608
0 commit comments