@@ -42,6 +42,7 @@ def __init__(self,
42
42
split_axis : Union [int , None ] = 2 ,
43
43
split_indices : list = None ,
44
44
concat_axis : Union [int , None ] = 2 ,
45
+ allow_disjoint_implementation : bool = True ,
45
46
** kwargs ):
46
47
r"""Initialize layer.
47
48
@@ -50,13 +51,15 @@ def __init__(self,
50
51
split_axis (int): The axis to split indices to gather embeddings. Default is None.
51
52
split_indices (list): List of indices to split from gathered tensor. Default is None.
52
53
concat_axis (int): The axis which concatenates embeddings. Default is 2.
54
+ allow_disjoint_implementation (bool): Whether to allow (preferred) disjoint implementation.
53
55
"""
54
56
super (GatherEmbedding , self ).__init__ (** kwargs )
55
57
self .concat_axis = concat_axis
56
58
self .axis = axis
57
59
self .split_axis = split_axis
58
60
self .split_indices = split_indices
59
61
self .node_indexing = "sample"
62
+ self .allow_disjoint_implementation = allow_disjoint_implementation
60
63
61
64
if self .concat_axis is not None and self .split_axis is None :
62
65
raise ValueError ("Can only concat `list` of gathered tensors. Require `split_axis` not None." )
@@ -111,7 +114,7 @@ def call(self, inputs, **kwargs):
111
114
tf.RaggedTensor: Gathered node embeddings that match the number of edges of shape `(batch, [M], 2*F)`
112
115
"""
113
116
# Old disjoint implementation that could be faster.
114
- if self ._is_disjoint_possible (inputs , ** kwargs ):
117
+ if self ._is_disjoint_possible (inputs , ** kwargs ) and self . allow_disjoint_implementation :
115
118
return self ._disjoint_implementation (inputs , ** kwargs )
116
119
117
120
# For arbitrary gather from ragged tensor use tf.gather with batch_dims=1.
@@ -136,7 +139,7 @@ def get_config(self):
136
139
config = super (GatherEmbedding , self ).get_config ()
137
140
config .update ({
138
141
"concat_axis" : self .concat_axis , "axis" : self .axis , "split_axis" : self .split_axis ,
139
- "split_indices" : self .split_indices
142
+ "split_indices" : self .split_indices , "allow_disjoint_implementation" : self . allow_disjoint_implementation
140
143
})
141
144
return config
142
145
0 commit comments