@@ -180,7 +180,7 @@ def set_propagation_alg(self, propagation_alg: str, **kwargs):
180
180
def forward (self , inputs : torch .Tensor , input_layer_fn : Optional [Union [str ,Callable ]] = None ,
181
181
cache : Optional [dict ] = None , return_cache : bool = False , record_cudagraph : bool = False ,
182
182
apply_cudagraph : bool = True , force_use_bf16 : bool = False , force_use_fp32 : bool = False ,
183
- propagation_alg : Optional [Union [str ,Sequence [str ]]] = None , ** kwargs ):
183
+ propagation_alg : Optional [Union [str ,Sequence [str ]]] = None , _inner_layers_only : bool = False , ** kwargs ):
184
184
"""
185
185
Forward evaluation of the PC.
186
186
@@ -217,19 +217,20 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla
217
217
218
218
with torch .no_grad ():
219
219
# Input layers
220
- for idx , layer in enumerate (self .input_layer_group ):
221
- if input_layer_fn is None :
222
- layer (inputs , self .node_mars , ** kwargs )
220
+ if not _inner_layers_only :
221
+ for idx , layer in enumerate (self .input_layer_group ):
222
+ if input_layer_fn is None :
223
+ layer (inputs , self .node_mars , ** kwargs )
223
224
224
- elif isinstance (input_layer_fn , str ):
225
- assert hasattr (layer , input_layer_fn ), f"Custom input function `{ input_layer_fn } ` not found for layer type { type (layer )} ."
226
- getattr (layer , input_layer_fn )(inputs , self .node_mars , ** kwargs )
225
+ elif isinstance (input_layer_fn , str ):
226
+ assert hasattr (layer , input_layer_fn ), f"Custom input function `{ input_layer_fn } ` not found for layer type { type (layer )} ."
227
+ getattr (layer , input_layer_fn )(inputs , self .node_mars , ** kwargs )
227
228
228
- elif isinstance (input_layer_fn , Callable ):
229
- input_layer_fn (layer , inputs , self .node_mars , ** kwargs )
229
+ elif isinstance (input_layer_fn , Callable ):
230
+ input_layer_fn (layer , inputs , self .node_mars , ** kwargs )
230
231
231
- else :
232
- raise ValueError (f"Custom input function should be either a `str` or a `Callable`. Found { type (input_layer_fn )} instead." )
232
+ else :
233
+ raise ValueError (f"Custom input function should be either a `str` or a `Callable`. Found { type (input_layer_fn )} instead." )
233
234
234
235
# Inner layers
235
236
def _run_inner_layers ():
@@ -319,6 +320,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
319
320
propagation_alg : Union [str ,Sequence [str ]] = "LL" ,
320
321
logspace_flows : bool = False ,
321
322
negate_pflows : bool = False ,
323
+ _inner_layers_only : bool = False ,
322
324
** kwargs ):
323
325
"""
324
326
Backward evaluation of the PC that computes node flows as well as parameter flows.
@@ -443,19 +445,20 @@ def _run_inner_layers():
443
445
_run_inner_layers ()
444
446
445
447
# Compute backward pass for all input layers
446
- for idx , layer in enumerate (self .input_layer_group ):
447
- if input_layer_fn is None :
448
- layer .backward (inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
448
+ if not _inner_layers_only :
449
+ for idx , layer in enumerate (self .input_layer_group ):
450
+ if input_layer_fn is None :
451
+ layer .backward (inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
449
452
450
- elif isinstance (input_layer_fn , str ):
451
- assert hasattr (layer , input_layer_fn ), f"Custom input function `{ input_layer_fn } ` not found for layer type { type (layer )} ."
452
- getattr (layer , input_layer_fn )(inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
453
+ elif isinstance (input_layer_fn , str ):
454
+ assert hasattr (layer , input_layer_fn ), f"Custom input function `{ input_layer_fn } ` not found for layer type { type (layer )} ."
455
+ getattr (layer , input_layer_fn )(inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
453
456
454
- elif isinstance (input_layer_fn , Callable ):
455
- input_layer_fn (layer , inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
457
+ elif isinstance (input_layer_fn , Callable ):
458
+ input_layer_fn (layer , inputs , self .node_flows , self .node_mars , logspace_flows = logspace_flows , ** kwargs )
456
459
457
- else :
458
- raise ValueError (f"Custom input function should be either a `str` or a `Callable`. Found { type (input_layer_fn )} instead." )
460
+ else :
461
+ raise ValueError (f"Custom input function should be either a `str` or a `Callable`. Found { type (input_layer_fn )} instead." )
459
462
460
463
if return_cache :
461
464
if cache is None :
0 commit comments