111111
112112from alphafold3_pytorch .utils .model_utils import distance_to_dgram
113113
114+ # personal libraries
115+
114116from frame_averaging_pytorch import FrameAverage
115117
116118from taylor_series_linear_attention import TaylorSeriesLinearAttn
117119
118120from colt5_attention import ConditionalRoutedAttention
119121
120- import einx
121- from einops import rearrange , repeat , reduce , einsum , pack , unpack
122- from einops .layers .torch import Rearrange
122+ from hyper_connections import HyperConnections
123123
124- from tqdm import tqdm
124+ # other external libs
125125
126+ from tqdm import tqdm
126127from loguru import logger
127128
128129from importlib .metadata import version
132133from Bio .PDB .Structure import Structure
133134from Bio .PDB .StructureBuilder import StructureBuilder
134135
136+ # einstein notation related
137+
138+ import einx
139+ from einops import rearrange , repeat , reduce , einsum , pack , unpack
140+ from einops .layers .torch import Rearrange
141+
135142"""
136143global ein notation:
137144
@@ -2008,6 +2015,7 @@ def __init__(
20082015 use_linear_attn = False ,
20092016 checkpoint = False ,
20102017 add_value_residual = False ,
2018+ num_residual_streams = 1 ,
20112019 linear_attn_kwargs = dict (
20122020 heads = 8 ,
20132021 dim_head = 16
@@ -2026,6 +2034,12 @@ def __init__(
20262034
20272035 dim_single_cond = default (dim_single_cond , dim )
20282036
2037+ # hyper connections
2038+
2039+ init_hyper_conn , self .expand_streams , self .reduce_streams = HyperConnections .get_init_and_expand_reduce_stream_functions (num_residual_streams , disable = num_residual_streams == 1 )
2040+
2041+ # layers
2042+
20292043 layers = ModuleList ([])
20302044
20312045 for i in range (depth ):
@@ -2042,6 +2056,8 @@ def __init__(
20422056 ** linear_attn_kwargs
20432057 )
20442058
2059+ linear_attn = init_hyper_conn (dim = dim , branch = linear_attn )
2060+
20452061 colt5_attn = None
20462062
20472063 if use_colt5_attn :
@@ -2051,6 +2067,8 @@ def __init__(
20512067 ** colt5_attn_kwargs
20522068 )
20532069
2070+ colt5_attn = init_hyper_conn (dim = dim , branch = colt5_attn )
2071+
20542072 accept_value_residual = add_value_residual and not is_first
20552073
20562074 pair_bias_attn = AttentionPairBias (
@@ -2083,8 +2101,8 @@ def __init__(
20832101 layers .append (ModuleList ([
20842102 linear_attn ,
20852103 colt5_attn ,
2086- conditionable_pair_bias ,
2087- conditionable_transition
2104+ init_hyper_conn ( dim = dim , branch = conditionable_pair_bias ) ,
2105+ init_hyper_conn ( dim = dim , branch = conditionable_transition )
20882106 ]))
20892107
20902108 self .checkpoint = checkpoint
@@ -2112,24 +2130,21 @@ def to_checkpointed_serial_layers(
21122130 windowed_mask : Bool ['b nw w (w*2)' ] | None = None
21132131 ):
21142132
2115- inputs = (noised_repr , single_repr , pairwise_repr , mask , windowed_mask , None )
2116-
21172133 wrapped_layers = []
21182134
21192135 def efficient_attn_wrapper (fn ):
21202136 @wraps (fn )
21212137 def inner (inputs ):
21222138 noised_repr , single_repr , pairwise_repr , mask , windowed_mask , maybe_value_residual = inputs
2123- noised_repr = fn (noised_repr , mask = mask ) + noised_repr
2139+ noised_repr = fn (noised_repr , mask = mask )
21242140 return noised_repr , single_repr , pairwise_repr , mask , windowed_mask , maybe_value_residual
21252141 return inner
21262142
21272143 def attn_wrapper (fn ):
21282144 @wraps (fn )
21292145 def inner (inputs ):
21302146 noised_repr , single_repr , pairwise_repr , mask , windowed_mask , maybe_value_residual = inputs
2131- attn_out , attn_values = fn (noised_repr , cond = single_repr , pairwise_repr = pairwise_repr , mask = mask , windowed_mask = windowed_mask , value_residual = maybe_value_residual , return_values = True )
2132- noised_repr = attn_out + noised_repr
2147+ noised_repr , attn_values = fn (noised_repr , cond = single_repr , pairwise_repr = pairwise_repr , mask = mask , windowed_mask = windowed_mask , value_residual = maybe_value_residual , return_values = True )
21332148
21342149 if self .add_value_residual :
21352150 maybe_value_residual = default (maybe_value_residual , attn_values )
@@ -2141,10 +2156,12 @@ def transition_wrapper(fn):
21412156 @wraps (fn )
21422157 def inner (inputs ):
21432158 noised_repr , single_repr , pairwise_repr , mask , windowed_mask , maybe_value_residual = inputs
2144- noised_repr = fn (noised_repr , cond = single_repr ) + noised_repr
2159+ noised_repr = fn (noised_repr , cond = single_repr )
21452160 return noised_repr , single_repr , pairwise_repr , mask , windowed_mask , maybe_value_residual
21462161 return inner
21472162
2163+ # wrap layers
2164+
21482165 for linear_attn , colt5_attn , attn , transition in self .layers :
21492166
21502167 if exists (linear_attn ):
@@ -2156,10 +2173,19 @@ def inner(inputs):
21562173 wrapped_layers .append (attn_wrapper (attn ))
21572174 wrapped_layers .append (transition_wrapper (transition ))
21582175
2176+ # forward
2177+
2178+ noised_repr = self .expand_streams (noised_repr )
2179+
2180+ inputs = (noised_repr , single_repr , pairwise_repr , mask , windowed_mask , None )
2181+
21592182 for layer in wrapped_layers :
21602183 inputs = checkpoint (layer , inputs )
21612184
21622185 noised_repr , * _ = inputs
2186+
2187+ noised_repr = self .reduce_streams (noised_repr )
2188+
21632189 return noised_repr
21642190
21652191 @typecheck
@@ -2175,15 +2201,17 @@ def to_serial_layers(
21752201
21762202 value_residual = None
21772203
2204+ noised_repr = self .expand_streams (noised_repr )
2205+
21782206 for linear_attn , colt5_attn , attn , transition in self .layers :
21792207
21802208 if exists (linear_attn ):
2181- noised_repr = linear_attn (noised_repr , mask = mask ) + noised_repr
2209+ noised_repr = linear_attn (noised_repr , mask = mask )
21822210
21832211 if exists (colt5_attn ):
2184- noised_repr = colt5_attn (noised_repr , mask = mask ) + noised_repr
2212+ noised_repr = colt5_attn (noised_repr , mask = mask )
21852213
2186- attn_out , attn_values = attn (
2214+ noised_repr , attn_values = attn (
21872215 noised_repr ,
21882216 cond = single_repr ,
21892217 pairwise_repr = pairwise_repr ,
@@ -2193,15 +2221,15 @@ def to_serial_layers(
21932221 value_residual = value_residual
21942222 )
21952223
2196- noised_repr = noised_repr + attn_out
2197-
21982224 if self .add_value_residual :
21992225 value_residual = default (value_residual , attn_values )
22002226
22012227 noised_repr = transition (
22022228 noised_repr ,
22032229 cond = single_repr
2204- ) + noised_repr
2230+ )
2231+
2232+ noised_repr = self .reduce_streams (noised_repr )
22052233
22062234 return noised_repr
22072235
0 commit comments