"""
"Ying et al. - GNNExplainer: Generating Explanations for Graph Neural Networks"

**Changelog**

??.??.2022 - Initial implementation

30.01.2023 - Added the class "GnnExplainer", which supports RaggedTensors and can thus generate multiple
explanations at once, greatly improving time efficiency when explaining large batches of predictions.
However, the new class does not implement visualization of the explanations; this will have to be
realized on a higher abstraction level.
"""
import time
import typing as t

import numpy as np
import tensorflow as tf
ks = tf.keras

from kgcnn.xai.base import ImportanceExplanationMethod

# Keep track of model version from commit date in literature.
# To be updated if model is changed in a significant way.
__model_version__ = "2022.05.31"


# == REDUCED, RAGGED TENSOR IMPLEMENTATION ==

class GnnExplainer(ImportanceExplanationMethod):
    """
    Implementation of "ImportanceExplanationMethod": calling an instance of this class with a model,
    a ragged input tensor and the corresponding output predictions returns the node and edge importance
    tensors, which provide an explanation by assigning each node and edge of the input graphs a 0-1
    importance value.

    Due to the basic idea behind GNNExplainer, the number of explanations produced has to be equal to
    the number of prediction targets generated by the model: each target receives its own explanation.
    """
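    # Minimal usage sketch (illustration only, not part of the original module): assuming "model" is a
    # compatible keras model with two prediction targets and node_input, edge_input, edge_indices are
    # ragged input tensors as described in __call__ below, an explanation could be generated roughly
    # like this:
    #
    #     explainer = GnnExplainer(channels=2, epochs=100, learning_rate=0.01)
    #     node_importances, edge_importances = explainer(
    #         model, (node_input, edge_input, edge_indices), y_predictions
    #     )
    #
    # where y_predictions is a hypothetical (B, 2) numpy array of the predictions to be explained.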
    def __init__(self,
                 channels: int,
                 epochs: int = 100,
                 learning_rate: float = 0.01,
                 node_sparsity_factor: float = 0.1,
                 edge_sparsity_factor: float = 0.1,
                 log_step: int = 10,
                 verbose: bool = True):
        super(GnnExplainer, self).__init__(channels=channels)
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.log_step = log_step
        self.verbose = verbose
        self.node_sparsity_factor = node_sparsity_factor
        self.edge_sparsity_factor = edge_sparsity_factor

    def __call__(self,
                 model: ks.models.Model,
                 x: t.Tuple[tf.RaggedTensor, tf.RaggedTensor, tf.RaggedTensor],
                 y: np.ndarray):
        """
        Given a model, the input tensor and the output array, this method will return a tuple of two
        ragged tensors, which represent the node importances and the edge importances.

        Beware that this method executes an entire training process and may take some time.

        Reference of tensor shapes. [Brackets] indicate a ragged dimension
        - V: Number of nodes in a graph
        - E: Number of edges in a graph
        - K: Number of explanation channels given in the constructor. This has to be equal to the
          number of prediction targets generated by the model.
        - N: Number of node attributes
        - M: Number of edge attributes
        - B: Batch size

        Args:
            model: Any compatible keras model, i.e. any model which accepts the input tensors described
                below and returns an output of the shape described for y.
            x: A tuple (node_input, edge_input, edge_indices) of 3 RaggedTensors
                - node_input: Shape ([B], [V], N)
                - edge_input: Shape ([B], [E], M)
                - edge_indices: Shape ([B], [E], 2)
            y: A numpy array of shape (B, K)

        Returns:
            A tuple (node_importances, edge_importances) of RaggedTensors.
            - node_importances: Shape ([B], [V], K)
            - edge_importances: Shape ([B], [E], K)
        """
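        # Hypothetical shape example (illustration only, not from the original source): with a batch of
        # B=32 graphs, K=2 prediction targets, N=16 node attributes and M=4 edge attributes, node_input
        # has shape ([32], [V], 16), edge_input ([32], [E], 4), edge_indices ([32], [E], 2) and y has
        # shape (32, 2); the returned importances then have shapes ([32], [V], 2) and ([32], [E], 2).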
        # Generally the idea of the implementation is that we use the node_input and edge_input tensors
        # as templates to generate the mask variable tensors, which match the graph dimensions but whose
        # final dimension is used to represent the number of importance channels (== number of
        # prediction targets) instead of the node / edge features.

        node_input, edge_input, edge_indices = x

        # Here we reduce away the last dimension of node and edge input to get just the ragged graph
        # sizes. This creates a problem with multiple channels though: we can't actually use the last
        # dimension to represent the different explanation channels, because later on the masks have to
        # be multiplied with the inputs. Instead, we use a workaround where for each channel we extend
        # the batch dimension, i.e. we treat the different channels as if they were additional graphs
        # alongside the original ones.
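        # To illustrate the resulting shapes (hypothetical example with K = self.channels): the
        # node_mask_single built below has shape ([B], [V], 1), node_mask_ragged has shape
        # ([B*K], [V], 1), and its flat values become the trainable mask variables.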
        node_mask_single = tf.reduce_mean(tf.ones_like(node_input), axis=-1, keepdims=True)
        node_mask_ragged = tf.concat([node_mask_single for _ in range(self.channels)], axis=0)
        node_mask_variables = tf.Variable(node_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

        edge_mask_single = tf.reduce_mean(tf.ones_like(edge_input), axis=-1, keepdims=True)
        edge_mask_ragged = tf.concat([edge_mask_single for _ in range(self.channels)], axis=0)
        edge_mask_variables = tf.Variable(edge_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

        optimizer = ks.optimizers.Nadam(learning_rate=self.learning_rate)

        # This is a logical extension of what was described above: since we treat the different
        # explanation channels as an extension of the batch dimension, we have to modify the input and
        # output values accordingly so that they match this larger batch size, which simply means
        # duplicating them once per channel.
        x_extended = (
            tf.concat([node_input for _ in range(self.channels)], axis=0),
            tf.concat([edge_input for _ in range(self.channels)], axis=0),
            tf.concat([edge_indices for _ in range(self.channels)], axis=0),
        )
        # Each duplicated copy of the targets keeps only the values of its own channel; all other
        # entries are set to zero.
        y_extended = []
        for c in range(self.channels):
            y_mod = np.zeros_like(y)
            y_mod[:, c] = y[:, c]
            y_extended.append(y_mod)

        y_extended = np.concatenate(y_extended)
        start_time = time.time()
        for epoch in range(self.epochs):

            with tf.GradientTape() as tape:
                node_mask = tf.RaggedTensor.from_nested_row_splits(
                    node_mask_variables,
                    nested_row_splits=node_mask_ragged.nested_row_splits
                )

                edge_mask = tf.RaggedTensor.from_nested_row_splits(
                    edge_mask_variables,
                    nested_row_splits=edge_mask_ragged.nested_row_splits
                )

                out = model([
                    x_extended[0] * node_mask,
                    x_extended[1] * edge_mask,
                    x_extended[2]
                ])
                # The loss can basically be summarized as: we try to find the smallest subset of nodes
                # and edges in the input which still causes the network to get as close as possible to
                # its original prediction!
                loss = tf.cast(tf.reduce_mean(tf.square(y_extended - out)), dtype=tf.float64)
                # Important detail: the reduce_sum here reduces over all the nodes / edges of each graph
                # and is necessary!
                loss += self.node_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(node_mask), axis=1))
                loss += self.edge_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(edge_mask), axis=1))
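                # Written as a formula (added for reference, using the notation from the docstring):
                #   loss = MSE(y_extended, out)
                #          + node_sparsity_factor * mean over graphs( sum over nodes |node_mask| )
                #          + edge_sparsity_factor * mean over graphs( sum over edges |edge_mask| )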

            trainable_vars = [node_mask_variables, edge_mask_variables]
            gradients = tape.gradient(loss, trainable_vars)
            optimizer.apply_gradients(zip(gradients, trainable_vars))

            if self.verbose and epoch % self.log_step == 0:
                print(f' * epoch ({epoch}/{self.epochs}) '
                      f' - loss: {loss}'
                      f' - elapsed time: {time.time()-start_time:.2f} seconds')

        # For the training we had to treat the different explanation channels as a batch extension. As
        # per the interface, however, we need to return the importances such that the different
        # explanation channels are organized into the third dimension of the tensors.

        # Sadly this does not work in a more direct fashion. We get the number of node and edge elements
        # that belong to one explanation channel, iterate over the flat values in chunks of that size
        # and turn each of those chunks into its own ragged importance tensor. At the end we concatenate
        # all of them along the third dimension to produce the desired result.
        num_elements_node = node_mask_single.flat_values.shape[0]
        num_elements_edge = edge_mask_single.flat_values.shape[0]
        node_importances_list = []
        edge_importances_list = []
        for c in range(self.channels):
            node_importances_part = tf.RaggedTensor.from_nested_row_splits(
                node_mask_variables[c*num_elements_node:(c+1)*num_elements_node, :],
                node_mask_single.nested_row_splits
            )
            node_importances_list.append(node_importances_part)

            edge_importances_part = tf.RaggedTensor.from_nested_row_splits(
                edge_mask_variables[c*num_elements_edge:(c+1)*num_elements_edge, :],
                edge_mask_single.nested_row_splits
            )
            edge_importances_list.append(edge_importances_part)

        return (
            tf.concat(node_importances_list, axis=-1),
            tf.concat(edge_importances_list, axis=-1)
        )


# == ORIGINAL IMPLEMENTATION ==

class GNNInterface:
    """An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable.
    This class is just an interface, which is used by the `GNNExplainer` and should be implemented in a subclass.