-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy path_layers.py
60 lines (51 loc) · 2.65 KB
/
_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import keras as ks
from keras import ops
from keras.layers import Dense
class TrafoEdgeNetMessages(ks.layers.Layer):
    """Make message from edges for a linear transformation, i.e. matrix multiplication.

    The actual matrix is not a trainable weight of this layer but generated by a dense layer.
    This was proposed by `NMPNN <http://arxiv.org/abs/1704.01212>`__ .
    """

    def __init__(self, target_shape: tuple,
                 activation="linear",
                 use_bias=True,
                 kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 kernel_initializer='glorot_uniform', bias_initializer='zeros',
                 **kwargs):
        """Initialize layer.

        Args:
            target_shape (tuple): Target shape ``(F_out, F_in)`` for the message matrix. Only the
                first two entries are read.
            activation: Activation of the dense layer that generates the matrix entries.
                Default is "linear".
            use_bias (bool): Whether the generating dense layer uses a bias. Default is True.
            kernel_regularizer: Kernel regularizer forwarded to the dense layer.
            bias_regularizer: Bias regularizer forwarded to the dense layer.
            activity_regularizer: Activity regularizer forwarded to the dense layer.
            kernel_constraint: Kernel constraint forwarded to the dense layer.
            bias_constraint: Bias constraint forwarded to the dense layer.
            kernel_initializer: Kernel initializer forwarded to the dense layer.
                Default is 'glorot_uniform'.
            bias_initializer: Bias initializer forwarded to the dense layer. Default is 'zeros'.
            **kwargs: Further keyword arguments forwarded to ``ks.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.target_shape = target_shape
        self._units_out = int(target_shape[0])
        self._units_in = int(target_shape[1])
        # A single dense layer predicts all F_out * F_in matrix entries of each message at once;
        # `call` then folds its last axis into the (F_out, F_in) matrix.
        self.lay_dense = Dense(units=self._units_out * self._units_in,
                               activation=activation, use_bias=use_bias,
                               kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
                               activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint,
                               bias_constraint=bias_constraint, kernel_initializer=kernel_initializer,
                               bias_initializer=bias_initializer)

    def build(self, input_shape):
        """Build layer, including the dense sub-layer that generates the matrix entries."""
        # Build the sub-layer here so its weights exist as soon as this layer is built,
        # instead of only after the first forward pass.
        if not self.lay_dense.built:
            self.lay_dense.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs (Tensor): Message embeddings or messages `([M], F)` .

        Returns:
            Tensor: Messages in matrix for multiplication of shape `([M], F_out, F_in)` .
        """
        up_scale = self.lay_dense(inputs, **kwargs)
        # Keep every leading dimension and only split the last axis (size F_out * F_in) into the
        # matrix. For the documented 2D input this is exactly `(M, F_out, F_in)`; inputs with
        # more leading axes no longer fail with an element-count mismatch.
        new_shape = tuple(ops.shape(up_scale)[:-1]) + (self._units_out, self._units_in)
        return ops.reshape(up_scale, new_shape)

    def get_config(self):
        """Update layer config."""
        config = super().get_config()
        config.update({"target_shape": self.target_shape})
        # Re-export the constructor arguments that were forwarded to the dense sub-layer, taken
        # from that layer's own (serialized) config.
        config_dense = self.lay_dense.get_config()
        for key in ("kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                    "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"):
            if key in config_dense:
                config[key] = config_dense[key]
        return config