
Commit aea003a

add flag to only run inner layers
1 parent 7574e9b commit aea003a

File tree

1 file changed: +24 -21 lines changed


src/pyjuice/model/tensorcircuit.py

Lines changed: 24 additions & 21 deletions
@@ -180,7 +180,7 @@ def set_propagation_alg(self, propagation_alg: str, **kwargs):
     def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None,
                 cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False,
                 apply_cudagraph: bool = True, force_use_bf16: bool = False, force_use_fp32: bool = False,
-                propagation_alg: Optional[Union[str,Sequence[str]]] = None, **kwargs):
+                propagation_alg: Optional[Union[str,Sequence[str]]] = None, _inner_layers_only: bool = False, **kwargs):
         """
         Forward evaluation of the PC.

@@ -217,19 +217,20 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla

         with torch.no_grad():
             # Input layers
-            for idx, layer in enumerate(self.input_layer_group):
-                if input_layer_fn is None:
-                    layer(inputs, self.node_mars, **kwargs)
+            if not _inner_layers_only:
+                for idx, layer in enumerate(self.input_layer_group):
+                    if input_layer_fn is None:
+                        layer(inputs, self.node_mars, **kwargs)

-                elif isinstance(input_layer_fn, str):
-                    assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}."
-                    getattr(layer, input_layer_fn)(inputs, self.node_mars, **kwargs)
+                    elif isinstance(input_layer_fn, str):
+                        assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}."
+                        getattr(layer, input_layer_fn)(inputs, self.node_mars, **kwargs)

-                elif isinstance(input_layer_fn, Callable):
-                    input_layer_fn(layer, inputs, self.node_mars, **kwargs)
+                    elif isinstance(input_layer_fn, Callable):
+                        input_layer_fn(layer, inputs, self.node_mars, **kwargs)

-                else:
-                    raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.")
+                    else:
+                        raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.")

             # Inner layers
             def _run_inner_layers():
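
For context, a minimal usage sketch of the new forward flag (not part of this commit). It assumes an already-compiled TensorCircuit named `pc` and a matching data batch `x`; only the `_inner_layers_only` keyword itself comes from the diff above, and the idea of reusing previously computed input-layer marginals is an assumption about the intended use.

    # `pc` is an already-compiled pyjuice TensorCircuit; `x` is a data batch (both assumed).

    # Normal forward pass: the input layers write their marginals into pc.node_mars,
    # then the inner (sum/product) layers are evaluated on top of them.
    lls = pc(x)

    # With _inner_layers_only = True, the input-layer loop is skipped, so only the
    # inner layers are re-evaluated using whatever values already sit in pc.node_mars
    # (e.g. from the call above, or written by a custom input_layer_fn).
    lls_inner = pc.forward(x, _inner_layers_only = True)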
@@ -319,6 +320,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
                  propagation_alg: Union[str,Sequence[str]] = "LL",
                  logspace_flows: bool = False,
                  negate_pflows: bool = False,
+                 _inner_layers_only: bool = False,
                  **kwargs):
         """
         Backward evaluation of the PC that computes node flows as well as parameter flows.

@@ -443,19 +445,20 @@ def _run_inner_layers():
             _run_inner_layers()

             # Compute backward pass for all input layers
-            for idx, layer in enumerate(self.input_layer_group):
-                if input_layer_fn is None:
-                    layer.backward(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)
+            if not _inner_layers_only:
+                for idx, layer in enumerate(self.input_layer_group):
+                    if input_layer_fn is None:
+                        layer.backward(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)

-                elif isinstance(input_layer_fn, str):
-                    assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}."
-                    getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)
+                    elif isinstance(input_layer_fn, str):
+                        assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}."
+                        getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)

-                elif isinstance(input_layer_fn, Callable):
-                    input_layer_fn(layer, inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)
+                    elif isinstance(input_layer_fn, Callable):
+                        input_layer_fn(layer, inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs)

-                else:
-                    raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.")
+                    else:
+                        raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.")

             if return_cache:
                 if cache is None:
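
In the same spirit, a minimal sketch for the backward flag (also not from this commit); `pc` and `x` are assumed as above, and only the `inputs` and `_inner_layers_only` keywords visible in the diff are used:

    # Full pass: compute marginals, then propagate flows through every layer.
    lls = pc(x)
    pc.backward(inputs = x)

    # With the new flag, the input-layer backward loop at the end is skipped, so
    # flows are only propagated through the inner (sum/product) layers and the
    # input layers' flows/statistics are left untouched.
    pc.backward(inputs = x, _inner_layers_only = True)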
