Skip to content

Commit be6b554

Browse files
committed
fix blk sparse backward with -inf chs
1 parent 1b5eef7 commit be6b554

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/pyjuice/layer/sum_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele
16881688

16891689
if logspace_flows:
16901690
partial_flows_max = emars + log_n_fdm_max
1691-
acc = tl.where(log_n_fdm_max == -float("inf"),
1691+
acc = tl.where(partial_flows_max == -float("inf"),
16921692
acc,
16931693
tl.where(partial_flows_max > acc,
16941694
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
@@ -1847,7 +1847,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar
18471847

18481848
if logspace_flows:
18491849
partial_flows_max = emars + log_n_fdm_max[None,:]
1850-
acc = tl.where(log_n_fdm_max[None,:] == -float("inf"),
1850+
acc = tl.where(partial_flows_max == -float("inf"),
18511851
acc,
18521852
tl.where(partial_flows_max > acc,
18531853
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,

0 commit comments

Comments
 (0)