Skip to content

Commit 3c21663

Browse files
committed
fix triton bug in sampling
1 parent 0832766 commit 3c21663

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/pyjuice/queries/sample.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ def count_prod_nch_kernel(nids, cids, element_samples, ind_ch_count, ind_nids, i
197197
tl.store(ind_nid_offs + offs_sample, local_nid_offs, mask = mask_sample)
198198
tl.store(ind_mask + offs_sample, partition_id, mask = mask_sample)
199199

200+
# Handle triton bug.. (otherwise `local_nids` will be wrong)
201+
local_nids = tl.load(ind_nids + offs_sample, mask = mask_sample, other = 0)
202+
200203
# Offset for children
201204
offs_child = tl.arange(0, BLOCK_C)
202205
mask_child = offs_child < num_edges

tests/queries/cond_sample_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,5 @@ def test_cond_sample():
129129

130130

131131
if __name__ == "__main__":
132+
torch.manual_seed(2389)
132133
test_cond_sample()

0 commit comments

Comments
 (0)