
Commit cd304a3

Merge branch 'main' of https://github.com/Tractables/pyjuice into normalize
2 parents 4f268a8 + 0a4915e commit cd304a3

File tree

4 files changed: +32 -9 lines changed


README.md

Lines changed: 14 additions & 0 deletions
@@ -240,3 +240,17 @@ sum_ns = juice.summate(prod_ns1, prod_ns2, num_node_blocks = num_nodes // 4, blo
 ```
 
 The above is equivalent to considering the input nodes to be concatenated into a single vector of nodes, and then defining the edges correspondingly.
+
+## Citation
+
+If you find PyJuice useful, please consider citing us:
+
+```
+@inproceedings{liu2024scaling,
+  title = {Scaling Tractable Probabilistic Circuits: A Systems Perspective},
+  author = {Liu, Anji and Ahmed, Kareem and Van den Broeck, Guy},
+  booktitle = {Proceedings of the 41st International Conference on Machine Learning (ICML)},
+  month = {jul},
+  year = {2024}
+}
+```

src/pyjuice/layer/compilation.py

Lines changed: 3 additions & 3 deletions
@@ -268,7 +268,7 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
         use_cuda = False
 
     if use_cuda:
-        device = torch.device("cuda:0")
+        device = torch.device(f"cuda:{torch.cuda.current_device()}")
     else:
         device = torch.device("cpu")
 
@@ -703,7 +703,7 @@ def sum_layer_backward_compilation(nodes, cs2parns, n_partition_ids, n_id_in_par
         use_cuda = False
 
     if use_cuda:
-        device = torch.device("cuda:0")
+        device = torch.device(f"cuda:{torch.cuda.current_device()}")
     else:
         device = torch.device("cpu")
 
@@ -912,7 +912,7 @@ def prod_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
         assert block_size == 1
 
     if use_cuda:
-        device = torch.device("cuda:0")
+        device = torch.device(f"cuda:{torch.cuda.current_device()}")
     else:
         device = torch.device("cpu")
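Not part of the commit itself, but a minimal sketch (the device index `1` is purely illustrative) of why querying `torch.cuda.current_device()` is preferable to hard-coding `cuda:0` on multi-GPU machines:

```python
import torch

# Hard-coding "cuda:0" ignores a GPU selected earlier via torch.cuda.set_device,
# whereas torch.cuda.current_device() follows that selection.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)  # e.g. a job pinned to the second GPU

    hard_coded = torch.device("cuda:0")                            # always GPU 0
    current = torch.device(f"cuda:{torch.cuda.current_device()}")  # cuda:1 here

    print(hard_coded, current)  # cuda:0 cuda:1
```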

src/pyjuice/layer/input_layer.py

Lines changed: 1 addition & 1 deletion
@@ -499,7 +499,7 @@ def sample(self, samples: torch.Tensor, node_flows: torch.Tensor, missing_mask:
                 nv_block_size = triton.next_power_of_2(self.num_vars_per_node),
                 batch_size = batch_size,
                 BLOCK_SIZE = BLOCK_SIZE,
-                seed = seed if seed is not None else random.randint(0, 1e8)
+                seed = seed if seed is not None else random.randint(0, int(1e8))
             )
 
         else:
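For context (not from the repo): `random.randint` expects integer bounds, and the float literal `1e8` is rejected on recent Python versions, which is what the added `int(...)` cast avoids.

```python
import random

random.randint(0, int(1e8))  # OK: both bounds are integers
# random.randint(0, 1e8)     # TypeError on recent Python versions: non-integer
#                            # bounds were deprecated in 3.10 and removed in 3.12.
```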

src/pyjuice/layer/sum_layer.py

Lines changed: 14 additions & 5 deletions
@@ -531,7 +531,10 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c
 
         acc = tl.where(emars_max > acc,
             tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max,
-            tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
+            tl.where(acc != -float("inf"),
+                tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc,
+                -float("inf")
+            )
         )
 
         # Increment `epars_ptr`
@@ -647,7 +650,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids,
 
         acc = tl.where(emars_max > acc,
             tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max,
-            tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
+            tl.where(acc != -float("inf"),
+                tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc,
+                -float("inf")
+            )
         )
 
         # Increment `epars_ptr`
@@ -757,7 +763,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids,
 
         acc = tl.where(emars_max[None,:] > acc,
             tl.log(nmars + tl.exp(acc - emars_max[None,:]) + 1e-24) + emars_max[None,:],
-            tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc
+            tl.where(acc != -float("inf"),
+                tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc,
+                -float("inf")
+            )
         )
 
         # Increment `epars_ptr`
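Not part of the diff, but a small PyTorch sketch (variable names mirror the kernel; values are illustrative) of the corner case the new inner `tl.where` handles: when both the accumulator and the incoming maximum are -inf, the unguarded log-sum-exp update yields NaN instead of leaving the accumulator at -inf.

```python
import torch

neg_inf = -float("inf")

def lse_update(acc, emars_max, nmars, guarded):
    # Running log-sum-exp accumulation mirroring the kernel's update rule:
    # acc <- log(exp(acc) + nmars * exp(emars_max)), shifted by the larger maximum.
    new_max_branch = torch.log(nmars + torch.exp(acc - emars_max) + 1e-24) + emars_max
    acc_max_branch = torch.log(torch.exp(emars_max - acc) * nmars + 1.0) + acc
    if guarded:
        # New behavior: if the accumulator is still -inf, keep it at -inf
        # instead of evaluating exp(-inf - (-inf)), which is NaN.
        acc_max_branch = torch.where(acc != neg_inf, acc_max_branch,
                                     torch.full_like(acc, neg_inf))
    return torch.where(emars_max > acc, new_max_branch, acc_max_branch)

acc = torch.tensor([neg_inf])        # nothing accumulated yet
emars_max = torch.tensor([neg_inf])  # current block contributes no mass either
nmars = torch.tensor([0.0])

print(lse_update(acc, emars_max, nmars, guarded=False))  # tensor([nan])
print(lse_update(acc, emars_max, nmars, guarded=True))   # tensor([-inf])
```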
@@ -1679,7 +1688,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele
 
         if logspace_flows:
             partial_flows_max = emars + log_n_fdm_max
-            acc = tl.where(log_n_fdm_max == -float("inf"),
+            acc = tl.where(partial_flows_max == -float("inf"),
                 acc,
                 tl.where(partial_flows_max > acc,
                     tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
@@ -1838,7 +1847,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar
 
         if logspace_flows:
             partial_flows_max = emars + log_n_fdm_max[None,:]
-            acc = tl.where(log_n_fdm_max[None,:] == -float("inf"),
+            acc = tl.where(partial_flows_max == -float("inf"),
                 acc,
                 tl.where(partial_flows_max > acc,
                     tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
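Similarly, a tiny illustration (not from the repo) of why the backward guard now tests `partial_flows_max` rather than `log_n_fdm_max` alone: the sum `emars + log_n_fdm_max` is -inf whenever either term is -inf, so the skip condition also covers elements whose `emars` is -inf.

```python
import torch

neg_inf = -float("inf")

emars = torch.tensor([neg_inf, 0.5])              # element marginals (log-space)
log_n_fdm_max = torch.tensor([0.3, neg_inf])      # per-block flow maxima (log-space)

partial_flows_max = emars + log_n_fdm_max
print(partial_flows_max)                # tensor([-inf, -inf])
print(log_n_fdm_max == neg_inf)         # tensor([False,  True])  old guard misses lane 0
print(partial_flows_max == neg_inf)     # tensor([ True,  True])  new guard skips both lanes
```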
