Skip to content

Commit fafab08

Browse files
committed
add flag to disable buffer init
1 parent 4502f7b commit fafab08

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/pyjuice/model/tensorcircuit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
323323
logspace_flows: bool = False,
324324
negate_pflows: bool = False,
325325
_inner_layers_only: bool = False,
326+
_disable_buffer_init: bool = False,
326327
**kwargs):
327328
"""
328329
Backward evaluation of the PC that computes node flows as well as parameter flows.
@@ -351,8 +352,9 @@ def backward(self, inputs: Optional[torch.Tensor] = None,
351352

352353
## Initialize buffers for backward pass ##
353354

354-
self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0 if not logspace_flows else -float("inf"))
355-
self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0 if not logspace_flows else -float("inf"))
355+
if not _disable_buffer_init:
356+
self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0 if not logspace_flows else -float("inf"))
357+
self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0 if not logspace_flows else -float("inf"))
356358

357359
# Set root node flows
358360
def _set_root_node_flows():

0 commit comments

Comments
 (0)