Skip to content

Commit 75041ac

Browse files
committed
Merge remote-tracking branch 'origin/main' into normalize
2 parents cd304a3 + 7111264 commit 75041ac

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/pyjuice/layer/sum_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,7 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele
26372637

26382638
elflows_max = tl.max(elflows, axis = 1)
26392639
eflows = tl.log(tl.sum(tl.exp(elflows - elflows_max[:,None,:]), axis = 1)) + elflows_max
2640+
eflows = tl.where((elflows_max == -float("inf")) | (emars == -float("inf")), -float("inf"), eflows)
26402641
else:
26412642
if propagation_alg_id == 0:
26422643
eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1)
@@ -2741,7 +2742,6 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to
27412742
)
27422743

27432744
else:
2744-
27452745
for pid_m_start in range(0, grid[1], 32768):
27462746

27472747
pid_m_end = min(pid_m_start + 32768, grid[1])

0 commit comments

Comments
 (0)