|
21 | 21 | def inputs(var: Union[int,Sequence[int]], num_node_groups: int = 0, dist: Distribution = Distribution(),
|
22 | 22 | params: Optional[Tensor] = None, num_nodes: int = 0, group_size: int = 0, **kwargs):
|
23 | 23 |
|
| 24 | + assert group_size == 0 or group_size & (group_size - 1) == 0, "`group_size` must be a power of 2." |
| 25 | + |
24 | 26 | if num_nodes > 0:
|
25 | 27 | assert num_node_groups == 0, "Only one of `num_nodes` and `num_node_groups` can be set at the same time."
|
26 | 28 | if group_size == 0:
|
@@ -72,6 +74,8 @@ def multiply(nodes1: ProdNodesChs, *args, edge_ids: Optional[Tensor] = None, spa
|
72 | 74 | def summate(nodes1: SumNodesChs, *args, num_nodes: int = 0, num_node_groups: int = 0,
|
73 | 75 | edge_ids: Optional[Tensor] = None, group_size: int = 0, **kwargs):
|
74 | 76 |
|
| 77 | + assert group_size == 0 or group_size & (group_size - 1) == 0, "`group_size` must be a power of 2." |
| 78 | + |
75 | 79 | if num_nodes > 0:
|
76 | 80 | assert num_node_groups == 0, "Only one of `num_nodes` and `num_node_groups` can be set at the same time."
|
77 | 81 | if group_size == 0:
|
@@ -101,6 +105,8 @@ def summate(nodes1: SumNodesChs, *args, num_nodes: int = 0, num_node_groups: int
|
101 | 105 | class set_group_size(_DecoratorContextManager):
|
102 | 106 | def __init__(self, group_size: int = 1):
|
103 | 107 |
|
| 108 | + assert group_size & (group_size - 1) == 0, "`group_size` must be a power of 2." |
| 109 | + |
104 | 110 | self.group_size = group_size
|
105 | 111 |
|
106 | 112 | self.original_group_size = None
|
|
0 commit comments