@@ -531,7 +531,10 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c
531
531
532
532
acc = tl .where (emars_max > acc ,
533
533
tl .log (nmars + tl .exp (acc - emars_max ) + 1e-24 ) + emars_max ,
534
- tl .log (tl .exp (emars_max - acc ) * nmars + 1.0 ) + acc
534
+ tl .where (acc != - float ("inf" ),
535
+ tl .log (tl .exp (emars_max - acc ) * nmars + 1.0 ) + acc ,
536
+ - float ("inf" )
537
+ )
535
538
)
536
539
537
540
# Increment `epars_ptr`
@@ -647,7 +650,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids,
647
650
648
651
acc = tl .where (emars_max > acc ,
649
652
tl .log (nmars + tl .exp (acc - emars_max ) + 1e-24 ) + emars_max ,
650
- tl .log (tl .exp (emars_max - acc ) * nmars + 1.0 ) + acc
653
+ tl .where (acc != - float ("inf" ),
654
+ tl .log (tl .exp (emars_max - acc ) * nmars + 1.0 ) + acc ,
655
+ - float ("inf" )
656
+ )
651
657
)
652
658
653
659
# Increment `epars_ptr`
@@ -757,7 +763,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids,
757
763
758
764
acc = tl .where (emars_max [None ,:] > acc ,
759
765
tl .log (nmars + tl .exp (acc - emars_max [None ,:]) + 1e-24 ) + emars_max [None ,:],
760
- tl .log (tl .exp (emars_max [None ,:] - acc ) * nmars + 1.0 ) + acc
766
+ tl .where (acc != - float ("inf" ),
767
+ tl .log (tl .exp (emars_max [None ,:] - acc ) * nmars + 1.0 ) + acc ,
768
+ - float ("inf" )
769
+ )
761
770
)
762
771
763
772
# Increment `epars_ptr`
@@ -1679,7 +1688,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele
1679
1688
1680
1689
if logspace_flows :
1681
1690
partial_flows_max = emars + log_n_fdm_max
1682
- acc = tl .where (log_n_fdm_max == - float ("inf" ),
1691
+ acc = tl .where (partial_flows_max == - float ("inf" ),
1683
1692
acc ,
1684
1693
tl .where (partial_flows_max > acc ,
1685
1694
tl .log (partial_flows + tl .exp (acc - partial_flows_max ) + 1e-24 ) + partial_flows_max ,
@@ -1838,7 +1847,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar
1838
1847
1839
1848
if logspace_flows :
1840
1849
partial_flows_max = emars + log_n_fdm_max [None ,:]
1841
- acc = tl .where (log_n_fdm_max [ None ,:] == - float ("inf" ),
1850
+ acc = tl .where (partial_flows_max == - float ("inf" ),
1842
1851
acc ,
1843
1852
tl .where (partial_flows_max > acc ,
1844
1853
tl .log (partial_flows + tl .exp (acc - partial_flows_max ) + 1e-24 ) + partial_flows_max ,
0 commit comments