Skip to content

Commit dce748e

Browse files
committed
rescale missing flows for input nodes with unnormalized parameters
1 parent 06f561e commit dce748e

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/pyjuice/layer/input_layer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,12 +609,17 @@ def eval_partition_fn(self, node_mars: torch.Tensor, params: Optional[Dict] = No
609609
else:
610610
raise NotImplementedError("CPU minibatch partition fn for input nodes is not implemented.")
611611

612-
def add_missing_flows(self, node_flows: torch.Tensor, logspace_flows: bool = False, scale: float = 1.0):
612+
def add_missing_flows(self, node_flows: torch.Tensor, node_mars: Optional[torch.Tensor] = None, logspace_flows: bool = False,
613+
scale: float = 1.0, pc_is_normalized: bool = True):
613614
"""
614615
Add missing flows specified by `node_flows` to the input node parameters.
615616
node_flows: [num_nodes]
616617
"""
617618

619+
if not pc_is_normalized:
620+
sid, eid = self._output_ind_range
621+
node_flows[sid:eid] /= torch.exp(node_mars[sid:eid])
622+
618623
if "cuda" in self.device.type:
619624
node_offset = self._output_ind_range[0]
620625
layer_num_nodes = self._output_ind_range[1] - self._output_ind_range[0]

0 commit comments

Comments
 (0)