
Commit 3c6873d

Merge pull request #103 from the16thpythonist/master
ragged tensor gnnexplainer implementation for xai benchmarks
2 parents f64f53a + 2c0c0c8 commit 3c6873d

11 files changed: +786 −24 lines changed

kgcnn/literature/GNNExplain.py

Lines changed: 190 additions & 0 deletions
@@ -1,11 +1,201 @@
"""
"Ying et al. - GNNExplainer: Generating Explanations for Graph Neural Networks"

**Changelog**

??.??.2022 - Initial implementation

30.01.2023 - Added the class "GnnExplainer", which supports RaggedTensors and can therefore generate
multiple explanations at once, greatly improving time efficiency when explaining large batches of
predictions. However, the new class does not implement visualization of the explanations; this has to
be realized at a higher abstraction level.
"""
import time
import typing as t

import numpy as np
import tensorflow as tf
ks = tf.keras

from kgcnn.xai.base import ImportanceExplanationMethod

# Keep track of model version from commit date in literature.
# To be updated if model is changed in a significant way.
__model_version__ = "2022.05.31"


# == REDUCED, RAGGED TENSOR IMPLEMENTATION ==

class GnnExplainer(ImportanceExplanationMethod):
    """
    Implementation of "ImportanceExplanationMethod": calling an instance of this class with a model, a
    ragged input tensor and the output predictions returns the corresponding node and edge importance
    tensors, which provide an explanation by assigning each node and edge of the input graphs a 0-1
    importance value.

    By the nature of the idea behind GNNExplainer, the number of explanations produced has to be equal
    to the number of prediction targets generated by the model. Each target receives its own
    explanation.
    """
    def __init__(self,
                 channels: int,
                 epochs: int = 100,
                 learning_rate: float = 0.01,
                 node_sparsity_factor: float = 0.1,
                 edge_sparsity_factor: float = 0.1,
                 log_step: int = 10,
                 verbose: bool = True):
        super(GnnExplainer, self).__init__(channels=channels)
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.log_step = log_step
        self.verbose = verbose
        self.node_sparsity_factor = node_sparsity_factor
        self.edge_sparsity_factor = edge_sparsity_factor

    def __call__(self,
                 model: ks.models.Model,
                 x: t.Tuple[tf.RaggedTensor, tf.RaggedTensor, tf.RaggedTensor],
                 y: np.ndarray):
        """
        Given a model, the input tensors and the output array, this method returns a tuple of two
        ragged tensors, which represent the node importances and the edge importances.

        Beware that this method executes an entire training process and may take some time.

        Reference of tensor shapes ([brackets] indicate a ragged dimension):
            - V: Number of nodes in a graph
            - E: Number of edges in a graph
            - K: Number of explanation channels given in the constructor. This has to be equal to the
              number of prediction targets of the model.
            - N: Number of node attributes
            - M: Number of edge attributes
            - B: Batch size

        Args:
            x: A tuple (node_input, edge_input, edge_indices) of 3 RaggedTensors
                - node_input: Shape ([B], [V], N)
                - edge_input: Shape ([B], [E], M)
                - edge_indices: Shape ([B], [E], 2)
            y: A numpy array of shape (B, K)
            model: Any compatible keras model, which means any model that accepts the previously
                described input tensors and returns output of the previously described shape.

        Returns:
            A tuple (node_importances, edge_importances) of RaggedTensors.
                - node_importances: Shape ([B], [V], K)
                - edge_importances: Shape ([B], [E], K)
        """
        # Generally the idea of the implementation is that we use the node_input and edge_input tensors
        # as templates to generate the mask variable tensors, which match the graph dimensions but whose
        # final dimension holds the importance channels (== number of prediction targets) instead of the
        # node / edge features.

        node_input, edge_input, edge_indices = x

        # Here we reduce away the last dimension of node and edge input to get just the ragged graph
        # sizes. With multiple channels we run into a problem: we can't use the last dimension to
        # represent the different explanation channels, because later on the masks have to be multiplied
        # with the inputs. Instead, we use a workaround where for each channel we extend the batch
        # dimension, i.e. the different channels are treated as additional graphs in the batch.
        node_mask_single = tf.reduce_mean(tf.ones_like(node_input), axis=-1, keepdims=True)
        node_mask_ragged = tf.concat([node_mask_single for _ in range(self.channels)], axis=0)
        node_mask_variables = tf.Variable(node_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

        edge_mask_single = tf.reduce_mean(tf.ones_like(edge_input), axis=-1, keepdims=True)
        edge_mask_ragged = tf.concat([edge_mask_single for _ in range(self.channels)], axis=0)
        edge_mask_variables = tf.Variable(edge_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

        optimizer = ks.optimizers.Nadam(learning_rate=self.learning_rate)

        # This is a logical extension of what was previously described: since the explanation channels
        # are treated as a batch extension, the input and output values have to be modified accordingly
        # so that they have the same batch size, which simply means duplicating the values.
        x_extended = (
            tf.concat([node_input for _ in range(self.channels)], axis=0),
            tf.concat([edge_input for _ in range(self.channels)], axis=0),
            tf.concat([edge_indices for _ in range(self.channels)], axis=0),
        )
        y_extended = []
        for c in range(self.channels):
            y_mod = np.zeros_like(y)
            y_mod[:, c] = y[:, c]
            y_extended.append(y_mod)

        y_extended = np.concatenate(y_extended)

        start_time = time.time()
        for epoch in range(self.epochs):

            with tf.GradientTape() as tape:
                node_mask = tf.RaggedTensor.from_nested_row_splits(
                    node_mask_variables,
                    nested_row_splits=node_mask_ragged.nested_row_splits
                )

                edge_mask = tf.RaggedTensor.from_nested_row_splits(
                    edge_mask_variables,
                    nested_row_splits=edge_mask_ragged.nested_row_splits
                )

                out = model([
                    x_extended[0] * node_mask,
                    x_extended[1] * edge_mask,
                    x_extended[2]
                ])

                # The loss can basically be summarized as: we try to find the smallest subset of nodes
                # and edges in the input which causes the network to get as close as possible to its
                # original prediction!
                loss = tf.cast(tf.reduce_mean(tf.square(y_extended - out)), dtype=tf.float64)
                # Important detail: the reduce_sum here reduces over all the nodes / edges and is necessary!
                loss += self.node_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(node_mask), axis=1))
                loss += self.edge_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(edge_mask), axis=1))

            trainable_vars = [node_mask_variables, edge_mask_variables]
            gradients = tape.gradient(loss, trainable_vars)
            optimizer.apply_gradients(zip(gradients, trainable_vars))

            if self.verbose and epoch % self.log_step == 0:
                print(f' * epoch ({epoch}/{self.epochs}) '
                      f' - loss: {loss}'
                      f' - elapsed time: {time.time()-start_time:.2f} seconds')

        # For the training we had to treat the different explanation channels as a batch extension. As
        # per the interface, the importances have to be returned such that the different explanation
        # channels are organized into the third dimension of the tensors.

        # Sadly this does not work in a more direct fashion. We get the number of node and edge elements
        # that belong to one explanation channel, iterate in chunks of that size, and turn each chunk
        # into its own ragged tensor. At the end all of them are concatenated in the third dimension to
        # produce the desired result.
        num_elements_node = node_mask_single.flat_values.shape[0]
        num_elements_edge = edge_mask_single.flat_values.shape[0]
        node_importances_list = []
        edge_importances_list = []
        for c in range(self.channels):
            node_importances_part = tf.RaggedTensor.from_nested_row_splits(
                node_mask_variables[c*num_elements_node:(c+1)*num_elements_node, :],
                node_mask_single.nested_row_splits
            )
            node_importances_list.append(node_importances_part)

            edge_importances_part = tf.RaggedTensor.from_nested_row_splits(
                edge_mask_variables[c*num_elements_edge:(c+1)*num_elements_edge, :],
                edge_mask_single.nested_row_splits
            )
            edge_importances_list.append(edge_importances_part)

        return (
            tf.concat(node_importances_list, axis=-1),
            tf.concat(edge_importances_list, axis=-1)
        )


# == ORIGINAL IMPLEMENTATION ==

class GNNInterface:
    """An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable.
    This class is just an interface, which is used by the `GNNExplainer` and should be implemented in a subclass.
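For reference, a minimal usage sketch of the new GnnExplainer class (not part of this commit; the model, channel count and data below are hypothetical placeholders):

from kgcnn.literature.GNNExplain import GnnExplainer

# Assumed setup: `model` is a trained keras model that accepts the (node_input, edge_input,
# edge_indices) RaggedTensors in `x` and predicts targets `y` of shape (B, K). The number of
# channels must equal K, here assumed to be 2.
explainer = GnnExplainer(channels=2, epochs=100, learning_rate=0.01, verbose=False)
node_importances, edge_importances = explainer(model, x, y)
# node_importances: ([B], [V], 2), edge_importances: ([B], [E], 2) - one channel per prediction target.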

kgcnn/xai/base.py

Lines changed: 56 additions & 0 deletions
@@ -1,6 +1,10 @@
import typing as t

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks

from kgcnn.data.utils import ragged_tensor_from_nested_numpy


class AbstractExplanationMixin:
@@ -20,3 +24,55 @@ def explain_importances(self,
                            **kwargs
                            ) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]:
        raise NotImplementedError


class AbstractExplanationMethod:

    def __call__(self, model, x, y):
        raise NotImplementedError


class ImportanceExplanationMethod(AbstractExplanationMethod):

    def __init__(self,
                 channels: int):
        self.channels = channels

    def __call__(self,
                 model: ks.models.Model,
                 x: tf.Tensor,
                 y: tf.Tensor
                 ) -> t.Tuple[tf.Tensor, tf.Tensor]:
        raise NotImplementedError


class MockImportanceExplanationMethod(ImportanceExplanationMethod):
    """
    This is a mock implementation of "ImportanceExplanationMethod". It is purely for testing purposes.
    Using this method will result in randomly generated importance values for nodes and edges.
    """
    def __init__(self, channels):
        super(MockImportanceExplanationMethod, self).__init__(channels=channels)

    def __call__(self,
                 model: ks.models.Model,
                 x: t.Tuple[tf.Tensor],
                 y: t.Tuple[tf.Tensor],
                 ) -> t.Tuple[tf.Tensor, tf.Tensor]:
        node_input, edge_input, _ = x

        # I'm sure you could probably do this in tensorflow directly, but I am just going to go the
        # numpy route here because that's just easier.
        node_input = node_input.numpy()
        edge_input = edge_input.numpy()

        node_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels))
                            for v in node_input]
        edge_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels))
                            for v in edge_input]

        return (
            ragged_tensor_from_nested_numpy(node_importances),
            ragged_tensor_from_nested_numpy(edge_importances)
        )
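To illustrate the interface introduced here, a hedged sketch of how another importance explanation method could plug into it; the class below is illustrative only and not part of this commit. It follows the same numpy route as the mock above and simply assigns a constant importance to every node and edge:

import typing as t

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks

from kgcnn.xai.base import ImportanceExplanationMethod
from kgcnn.data.utils import ragged_tensor_from_nested_numpy


class UniformImportanceExplanationMethod(ImportanceExplanationMethod):
    # Hypothetical example: only demonstrates the contract
    # __call__(model, x, y) -> (node_importances, edge_importances).

    def __call__(self,
                 model: ks.models.Model,
                 x: t.Tuple[tf.RaggedTensor, tf.RaggedTensor, tf.RaggedTensor],
                 y: np.ndarray) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]:
        node_input, edge_input, _ = x
        # Build one (num_nodes, channels) array per graph and re-assemble the ragged batch afterwards.
        node_importances = [np.ones((v.shape[0], self.channels)) for v in node_input.numpy()]
        edge_importances = [np.ones((v.shape[0], self.channels)) for v in edge_input.numpy()]
        return (
            ragged_tensor_from_nested_numpy(node_importances),
            ragged_tensor_from_nested_numpy(edge_importances)
        )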

kgcnn/xai/testing.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import random
import typing as t

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks

from kgcnn.layers.conv.gat_conv import AttentionHeadGATV2
from kgcnn.layers.modules import DenseEmbedding
from kgcnn.layers.pooling import PoolingGlobalEdges
from kgcnn.data.utils import ragged_tensor_from_nested_numpy


# This is a very simple mock implementation: to test the explanation methods we need some sort of
# model as a basis, and this model acts as such.
class Model(ks.models.Model):

    def __init__(self,
                 num_targets: int = 1):
        super(Model, self).__init__()
        self.conv_layers = [
            AttentionHeadGATV2(units=64, use_edge_features=True, use_bias=True),
        ]
        self.lay_pooling = PoolingGlobalEdges(pooling_method='sum')
        self.lay_dense = DenseEmbedding(units=num_targets, activation='linear')

    def call(self, inputs, training=False):
        node_input, edge_input, edge_index_input = inputs
        x = node_input
        for lay in self.conv_layers:
            x = lay([x, edge_input, edge_index_input])

        pooled = self.lay_pooling(x)
        out = self.lay_dense(pooled)
        return out


class MockContext:

    def __init__(self,
                 num_elements: int = 10,
                 num_targets: int = 1,
                 epochs: int = 10,
                 batch_size: int = 2):
        self.num_elements = num_elements
        self.num_targets = num_targets
        self.epochs = epochs
        self.batch_size = batch_size

        self.model = Model(num_targets=num_targets)
        self.x = None
        self.y = None

    def generate_graph(self,
                       num_nodes: int,
                       num_node_attributes: int = 3,
                       num_edge_attributes: int = 1):
        remaining = list(range(num_nodes))
        random.shuffle(remaining)
        inserted = [remaining.pop(0)]
        node_attributes = [[random.random() for _ in range(num_node_attributes)] for _ in range(num_nodes)]
        edge_indices = []
        edge_attributes = []
        while len(remaining) != 0:
            i = remaining.pop(0)
            j = random.choice(inserted)
            inserted.append(i)

            edge_indices += [[i, j], [j, i]]
            edge_attribute = [1 for _ in range(num_edge_attributes)]
            edge_attributes += [edge_attribute, edge_attribute]

        return (
            np.array(node_attributes, dtype=float),
            np.array(edge_attributes, dtype=float),
            np.array(edge_indices, dtype=int)
        )

    def generate_data(self):
        node_attributes_list = []
        edge_attributes_list = []
        edge_indices_list = []
        targets_list = []
        for i in range(self.num_elements):
            num_nodes = random.randint(5, 20)
            node_attributes, edge_attributes, edge_indices = self.generate_graph(num_nodes)
            node_attributes_list.append(node_attributes)
            edge_attributes_list.append(edge_attributes)
            edge_indices_list.append(edge_indices)

            # The target value is determined deterministically here so that our network actually has
            # a chance to learn anything.
            target = np.sum(node_attributes)
            targets = [target for _ in range(self.num_targets)]
            targets_list.append(targets)

        self.x = (
            ragged_tensor_from_nested_numpy(node_attributes_list),
            ragged_tensor_from_nested_numpy(edge_attributes_list),
            ragged_tensor_from_nested_numpy(edge_indices_list)
        )

        self.y = (
            np.array(targets_list, dtype=float)
        )

    def __enter__(self):
        # This method generates random input and output data and thus populates the internal
        # attributes self.x and self.y.
        self.generate_data()

        # Using these, we train our mock model for a few very brief epochs.
        self.model.compile(
            loss=ks.losses.mean_squared_error,
            metrics=ks.metrics.mean_squared_error,
            run_eagerly=False,
            optimizer=ks.optimizers.Nadam(learning_rate=0.01),
        )
        hist = self.model.fit(
            self.x, self.y,
            batch_size=self.batch_size,
            epochs=self.epochs,
            verbose=0,
        )
        self.history = hist.history

        return self

    def __exit__(self, *args, **kwargs):
        pass
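A hedged sketch of how this mock context is meant to be combined with an explanation method in a test (assumed usage, not part of this commit; the parameter values are placeholders):

from kgcnn.xai.testing import MockContext
from kgcnn.literature.GNNExplain import GnnExplainer

# Entering the context generates random graphs and briefly trains the small GAT model, so
# mock.model, mock.x and mock.y are ready to be explained inside the with-block.
with MockContext(num_elements=10, num_targets=2, epochs=10) as mock:
    explainer = GnnExplainer(channels=2, epochs=50, verbose=False)
    node_importances, edge_importances = explainer(mock.model, mock.x, mock.y)
    # Expected shapes: node_importances ([B], [V], 2), edge_importances ([B], [E], 2).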

0 commit comments
