Skip to content

Commit e40fe7e

Browse files
committed
added optional disjoint to gather.py
1 parent 724bdab commit e40fe7e

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

kgcnn/layers/gather.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self,
4242
split_axis: Union[int, None] = 2,
4343
split_indices: list = None,
4444
concat_axis: Union[int, None] = 2,
45+
allow_disjoint_implementation: bool = True,
4546
**kwargs):
4647
r"""Initialize layer.
4748
@@ -50,13 +51,15 @@ def __init__(self,
5051
split_axis (int): The axis to split indices to gather embeddings. Default is None.
5152
split_indices (list): List of indices to split from gathered tensor. Default is None.
5253
concat_axis (int): The axis which concatenates embeddings. Default is 2.
54+
allow_disjoint_implementation (bool): Whether to allow (preferred) disjoint implementation.
5355
"""
5456
super(GatherEmbedding, self).__init__(**kwargs)
5557
self.concat_axis = concat_axis
5658
self.axis = axis
5759
self.split_axis = split_axis
5860
self.split_indices = split_indices
5961
self.node_indexing = "sample"
62+
self.allow_disjoint_implementation = allow_disjoint_implementation
6063

6164
if self.concat_axis is not None and self.split_axis is None:
6265
raise ValueError("Can only concat `list` of gathered tensors. Require `split_axis` not None.")
@@ -111,7 +114,7 @@ def call(self, inputs, **kwargs):
111114
tf.RaggedTensor: Gathered node embeddings that match the number of edges of shape `(batch, [M], 2*F)`
112115
"""
113116
# Old disjoint implementation that could be faster.
114-
if self._is_disjoint_possible(inputs, **kwargs):
117+
if self._is_disjoint_possible(inputs, **kwargs) and self.allow_disjoint_implementation:
115118
return self._disjoint_implementation(inputs, **kwargs)
116119

117120
# For arbitrary gather from ragged tensor use tf.gather with batch_dims=1.
@@ -136,7 +139,7 @@ def get_config(self):
136139
config = super(GatherEmbedding, self).get_config()
137140
config.update({
138141
"concat_axis": self.concat_axis, "axis": self.axis, "split_axis": self.split_axis,
139-
"split_indices": self.split_indices
142+
"split_indices": self.split_indices, "allow_disjoint_implementation": self.allow_disjoint_implementation
140143
})
141144
return config
142145

0 commit comments

Comments
 (0)