Skip to content

Commit b7b13f9

Browse files
committed
add group size assertions
1 parent 26cd5f3 commit b7b13f9

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/pyjuice/layer/prod_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor,
412412
accum = 1 if accum else 0
413413
partial_eval = 1 if local_ids is not None else 0
414414

415-
assert num_edges & (num_edges - 1) == 0, "`num_edges` must be power of 2."
415+
assert num_edges & (num_edges - 1) == 0, "`num_edges` must be a power of 2."
416416

417417
# Fall back to the `torch.compile` kernel in the case where we cannot store child edges within a single block
418418
if num_edges > 1024:

src/pyjuice/model/backend/par_update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _record_par_blks(par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, g
4747
@torch.no_grad()
4848
def compile_par_update_fn(root_ns: CircuitNodes, BLOCK_SIZE: int = 32, buffer_inc_interval: int = 10000, use_numba: bool = True):
4949

50-
assert BLOCK_SIZE & (BLOCK_SIZE - 1) == 0, "`BLOCK_SIZE` must be power of 2."
50+
assert BLOCK_SIZE & (BLOCK_SIZE - 1) == 0, "`BLOCK_SIZE` must be a power of 2."
5151

5252
par_start_ids = np.zeros([buffer_inc_interval], dtype = np.int64)
5353
pflow_start_ids = np.zeros([buffer_inc_interval], dtype = np.int64)

src/pyjuice/nodes/construction.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
def inputs(var: Union[int,Sequence[int]], num_node_groups: int = 0, dist: Distribution = Distribution(),
2222
params: Optional[Tensor] = None, num_nodes: int = 0, group_size: int = 0, **kwargs):
2323

24+
assert group_size == 0 or group_size & (group_size - 1) == 0, "`group_size` must be a power of 2."
25+
2426
if num_nodes > 0:
2527
assert num_node_groups == 0, "Only one of `num_nodes` and `num_node_groups` can be set at the same time."
2628
if group_size == 0:
@@ -72,6 +74,8 @@ def multiply(nodes1: ProdNodesChs, *args, edge_ids: Optional[Tensor] = None, spa
7274
def summate(nodes1: SumNodesChs, *args, num_nodes: int = 0, num_node_groups: int = 0,
7375
edge_ids: Optional[Tensor] = None, group_size: int = 0, **kwargs):
7476

77+
assert group_size == 0 or group_size & (group_size - 1) == 0, "`group_size` must be a power of 2."
78+
7579
if num_nodes > 0:
7680
assert num_node_groups == 0, "Only one of `num_nodes` and `num_node_groups` can be set at the same time."
7781
if group_size == 0:
@@ -101,6 +105,8 @@ def summate(nodes1: SumNodesChs, *args, num_nodes: int = 0, num_node_groups: int
101105
class set_group_size(_DecoratorContextManager):
102106
def __init__(self, group_size: int = 1):
103107

108+
assert group_size & (group_size - 1) == 0, "`group_size` must be a power of 2."
109+
104110
self.group_size = group_size
105111

106112
self.original_group_size = None

0 commit comments

Comments
 (0)