Explainability Techniques for Graph Convolutional Networks
Federico Baldassarre, Hossein Azizpour
Introduction
Many concepts from chemistry, life science, and physics are naturally represented in the graph domain. Most Machine Learning (ML) methods, however, were devised to operate on Euclidean space. One of the most successful ML techniques, Deep Learning, has recently been generalized to operate on graph domains. These Graph Networks (GN) have achieved remarkable performance in various applications thanks to their consistency with the native data representation, with examples from biochemistry (Duvenaud et al., 2015; Kearnes et al., 2016; Fout et al., 2017; Zitnik et al., 2018), physics (Battaglia et al., 2016; Chang et al., 2017; Gilmer et al., 2017; Watters et al., 2017; Sanchez-Gonzalez et al., 2018), visual recognition (Qi et al., 2017, 2018; Narasimhan et al., 2018), and natural language processing (Bastings et al., 2017; Beck et al., 2018).
ML algorithms become more trustworthy to humans when the basis for their decisions can be explained in human terms. Interpretability is also useful for diagnosing biases, designing datasets and gaining insight into governing laws. While several explanation methods have been developed for standard deep networks, their applicability to GNs has received little study; this is the focus of this work.
In this work, we assume the general form of GNs as defined in (Battaglia et al., 2018). Regarding explanation algorithms, we consider two main classes: a) gradient-based, such as Sensitivity Analysis (Baehrens et al., 2010; Simonyan et al., 2014) and Guided Backpropagation (Springenberg et al., 2015), and b) decomposition-based, such as Layer-wise Relevance Propagation (Bach et al., 2015) and Taylor decomposition (Montavon et al., 2017). We base the discussion on a toy dataset and a chemistry task. We hope this work will lay the groundwork for this important topic and suggest future directions and applications.
The contributions of this work can be summarized as:
to the best of our knowledge, this is the first work to focus on explainability techniques for GNs
we highlight and identify the challenges and future directions for explaining GNs
we compare two main classes of explainability methods on predictions for graph-based tasks
Our PyTorch (Paszke et al., 2017) implementation of GNs equipped with different explanation algorithms is available at github.com/baldassarreFe/graph-network-explainability.
Related works
This work is closely related to Graph Networks and explanation techniques developed for standard networks.
Graph Networks Graphs can be embedded in Euclidean spaces using neural networks (Perozzi et al., 2014; Hamilton et al., 2017; Kipf & Welling, 2016). This representation preserves the relational structure of the graph while enjoying the properties of Euclidean space. The embedding can then be processed further, e.g. with more interpretable linear models (Duvenaud et al., 2015). On the other hand, deep learning algorithms that operate end-to-end in the graph domain can leverage the structure to predict at the vertex (Hu et al., 2018), edge (Qi et al., 2018), or global (Zitnik et al., 2018) level. Several variants of GN have been proposed, starting from the early work of (Scarselli et al., 2009). This model was later extended with gating (Li et al., 2016), convolutions in the spectral (Bruna et al., 2014) and spatial (Kipf & Welling, 2017) domains, skip connections (Rahimi et al., 2018) and, more recently, attention (Velickovic et al., 2018). In this work, we focus on the GN as described in (Battaglia et al., 2018), which tries to generalize all variants while remaining conceptually simple.
Model interpretation and explanation of predictions ML methods are ubiquitous in several industries, with deep networks achieving impressive performance. Where humans are involved, expectations of high levels of safety, reliability, and fairness arise. Driven by this demand, several techniques have been developed to increase the transparency of the most successful models. At a high level, these techniques can be divided into those that attempt to interpret the model as a whole (Simonyan et al., 2014; Nguyen et al., 2016) and those that try to explain individual predictions made by a model (Sung, 1998; Bach et al., 2015); for insight on this grouping, refer to (Gilpin et al., 2018). This paper focuses on the latter, for which several techniques have been developed on standard deep networks. These include works based on variation analysis (Baehrens et al., 2010; Simonyan et al., 2014; Springenberg et al., 2015; Ribeiro et al., 2016; Smilkov et al., 2017; Bordes et al., 2018) and on output decomposition (Montavon et al., 2017; Shrikumar et al., 2017; Kindermans et al., 2018). Other works combine the two principles (Sundararajan et al., 2017) or invert the decisions (Zeiler & Fergus, 2014; Mahendran & Vedaldi, 2015; Carlsson et al., 2017). Furthermore, (Dhurandhar et al., 2018) uses contrastive explanations, and (Zhang et al., 2018) identifies minimal changes in the input that yield a desired prediction. In this work, we evaluate variation- and decomposition-based techniques in the context of GNs.
Explanation for GNs To the best of our knowledge, explanation techniques for GNs have not been the focus of any prior work. In (Duvenaud et al., 2015), GNs are used to learn molecular fingerprints and predict chemical properties. Chemically grounded insights are then extracted from the model by heuristic inspection of its results (as described in the authors' rebuttal; NIPS reviews, 2015). In our experiments, we show that GNs trained end-to-end on those problems achieve similar performance while making it possible to explain individual predictions.
Method
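GNs as described in (Battaglia et al., 2018) use a message-passing algorithm that aggregates local information, similar to convolutions in CNNs. Graphs can contain features on edges $e_k$, on nodes $v_i$, and at the graph level $u$. At every layer of computation, the graph is updated using three update functions $\phi$ and three aggregation functions $\rho$:

$$
\begin{aligned}
e'_k &= \phi^e\left(e_k, v_{r_k}, v_{s_k}, u\right), & \bar{e}'_i &= \rho^{e \to v}\left(E'_i\right), \\
v'_i &= \phi^v\left(\bar{e}'_i, v_i, u\right), & \bar{e}' &= \rho^{e \to u}\left(E'\right), \\
u' &= \phi^u\left(\bar{e}', \bar{v}', u\right), & \bar{v}' &= \rho^{v \to u}\left(V'\right),
\end{aligned}
\tag{1}
$$

where $s_k$ and $r_k$ represent the sender and receiver nodes of the $k$-th edge, and the sets $E'_i$, $E'$, $V'$ represent, respectively, the edges incident to node $i$ (i.e. with $r_k = i$), all edges updated by $\phi^e$, and all nodes updated by $\phi^v$. Each processing layer leaves the structure of the graph unchanged, updating only the features of the graph and not its topology. The output of the network can represent a single quantity of interest (e.g. the solubility of a molecule) or a graph with individual predictions for nodes and edges. In this work, all $\phi$ are linear transformations followed by ReLU activations, and all $\rho$ are sum/mean/max pooling operations.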
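The layer in Eq. 1 can be sketched in plain Python as below. This is a minimal illustration, not our PyTorch implementation: features are plain lists, each $\phi$ is assumed to be a single linear map followed by ReLU applied to the concatenation of its arguments, every $\rho$ is sum pooling, and the names `gn_layer`, `linear_relu` and `sum_pool` are hypothetical.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Params = Tuple[List[Vector], Vector]  # (weight matrix stored as rows, bias)


def linear_relu(params: Params, x: Vector) -> Vector:
    """phi: ReLU(W x + b)."""
    weight, bias = params
    return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(weight, bias)]


def sum_pool(vectors: List[Vector], dim: int) -> Vector:
    """rho: element-wise sum; an empty set (e.g. a node without incoming edges) gives zeros."""
    out = [0.0] * dim
    for v in vectors:
        out = [o + vi for o, vi in zip(out, v)]
    return out


def gn_layer(edges, senders, receivers, nodes, u, params: Dict[str, Params]):
    """One message-passing layer; senders/receivers (the topology) are left untouched."""
    # Edge update: e'_k = phi_e(e_k, v_{r_k}, v_{s_k}, u), arguments concatenated.
    new_edges = [linear_relu(params["edge"], e + nodes[r] + nodes[s] + u)
                 for e, s, r in zip(edges, senders, receivers)]
    edge_dim = len(params["edge"][1])
    node_dim = len(params["node"][1])
    # Node update: pool the updated edges whose receiver is node i, then phi_v.
    new_nodes = []
    for i, v in enumerate(nodes):
        e_bar_i = sum_pool([e for e, r in zip(new_edges, receivers) if r == i], edge_dim)
        new_nodes.append(linear_relu(params["node"], e_bar_i + v + u))
    # Global update: pool all updated edges and nodes, then phi_u.
    e_bar = sum_pool(new_edges, edge_dim)
    v_bar = sum_pool(new_nodes, node_dim)
    new_u = linear_relu(params["global"], e_bar + v_bar + u)
    return new_edges, new_nodes, new_u


# Usage: 1-d features everywhere, so phi_e sees 4 inputs and phi_v, phi_u see 3.
params = {
    "edge": ([[1.0, 1.0, 1.0, 0.0]], [0.0]),   # e' = ReLU(e + v_r + v_s)
    "node": ([[1.0, 1.0, 0.0]], [0.0]),        # v' = ReLU(e_bar_i + v)
    "global": ([[0.0, 1.0, 0.0]], [0.0]),      # u' = ReLU(v_bar)
}
new_edges, new_nodes, new_u = gn_layer(
    edges=[[0.5], [0.0]], senders=[0, 1], receivers=[1, 2],
    nodes=[[1.0], [2.0], [3.0]], u=[0.0], params=params)
print(new_edges, new_nodes, new_u)  # [[3.5], [5.0]] [[1.0], [5.5], [8.0]] [14.5]
```

Node 0 has no incoming edge, so its aggregated message is the zero vector and only its own feature survives the update.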
3.2 Explainability
Sensitivity Analysis (SA) produces local explanations for the prediction of a differentiable function $f(x)$ using the squared norm of its gradient w.r.t. the inputs (Gevrey et al., 2003): $\mathcal{S}_i(x) = \lVert \partial f(x) / \partial x_i \rVert^2$ for each input $x_i$. The saliency map produced with this method describes the extent to which variations in the input would produce a change in the output.
Guided Backpropagation (GBP) also constructs a saliency map using gradients (Springenberg et al., 2015). Unlike SA, negative gradients are clipped during backpropagation, which concentrates the explanation on the features that have an excitatory effect on the output.
Layer-wise Relevance Propagation (LRP) (Bach et al., 2015) instead decomposes the prediction into relevance scores that are redistributed backward, layer by layer, down to the input features (see Appendix A.2 for the propagation rules).
All three methods rely on a backward pass through the network to propagate gradients/relevance from the output and accumulate it at the input. Since the computational graph of a GN can become complex and non-sequential, we take advantage of PyTorch's capability to track operations and implement these algorithms on top of its autograd module.
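To make the difference between SA and GBP concrete, the sketch below computes both by hand for a hypothetical two-layer ReLU network $y = w_2^\top \mathrm{ReLU}(W_1 x)$ in plain Python, rather than through PyTorch autograd. The function names and weights are illustrative assumptions: SA squares each input's gradient, while GBP additionally zeroes negative signals at the ReLU during the backward pass.

```python
from typing import List

Vector = List[float]


def forward(W1: List[Vector], w2: Vector, x: Vector):
    """Return the hidden pre-activations and the scalar output y = w2 . ReLU(W1 x)."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    y = sum(w * max(0.0, p) for w, p in zip(w2, pre))
    return pre, y


def input_gradient(W1: List[Vector], w2: Vector, x: Vector, guided: bool = False) -> Vector:
    pre, _ = forward(W1, w2, x)
    grad_pre = []
    for p, g in zip(pre, w2):          # dy/dh_j = w2[j]
        g = g if p > 0 else 0.0        # ReLU derivative: pass only where the unit was active
        if guided:
            g = max(0.0, g)            # GBP: also clip negative gradients flowing backward
        grad_pre.append(g)
    # Chain rule through the linear layer: dy/dx_i = sum_j W1[j][i] * grad_pre[j]
    return [sum(W1[j][i] * grad_pre[j] for j in range(len(W1))) for i in range(len(x))]


def sensitivity_analysis(W1, w2, x) -> Vector:
    """SA saliency: squared gradient for each input element."""
    return [g ** 2 for g in input_gradient(W1, w2, x)]


def guided_backprop(W1, w2, x) -> Vector:
    return input_gradient(W1, w2, x, guided=True)


W1 = [[1.0, 1.0], [1.0, -1.0]]
w2 = [1.0, -2.0]       # the second hidden unit inhibits the output
x = [2.0, 1.0]
print(sensitivity_analysis(W1, w2, x))  # [1.0, 9.0]
print(guided_backprop(W1, w2, x))       # [1.0, 1.0]
```

SA credits `x[1]` through both hidden units, including the path that suppresses the inhibitory unit, whereas GBP keeps only the excitatory path.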
Experiments
To evaluate different explainability methods on GNs, we consider a toy graph problem and a chemistry regression task. Task-specific comments can be found in this section, followed by a more general discussion in Section 5.
4.1 Infection
In this toy problem, the input graph represents a group of individuals, each of whom is either sick or healthy, and either immune or not immune to a certain disease. People are connected by directed edges representing their relationships, each characterized as virtual or not. The disease spreads according to a simple rule: a sick node infects the neighbors to which it is connected through a non-virtual edge, unless the target node is immune. The network is trained to predict the state of every node of the graph after one step of the spread; we then evaluate the correctness of the explanations produced with Sensitivity Analysis, Guided Backpropagation and Layer-wise Relevance Propagation against the logical infection dynamics. Details about the dataset, the network and the training procedure are in Appendix B.1.
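The ground-truth dynamics that the explanations are checked against can be written as a one-step update. The sketch below assumes boolean node states, edges given as hypothetical `(sender, receiver, virtual)` tuples, and that sick nodes remain sick; it reads only the states from before the step, so infections do not chain within a single step.

```python
from typing import List, Tuple

Edge = Tuple[int, int, bool]  # (sender, receiver, is_virtual)


def infection_step(sick: List[bool], immune: List[bool], edges: List[Edge]) -> List[bool]:
    """Sick nodes stay sick (assumption); a node becomes sick if a currently sick
    sender reaches it through a non-virtual edge and it is not immune."""
    new_sick = list(sick)
    for sender, receiver, virtual in edges:
        if sick[sender] and not virtual and not immune[receiver]:
            new_sick[receiver] = True
    return new_sick


sick = [True, False, False, False]
immune = [False, False, True, False]
edges = [(0, 1, False),   # infects node 1
         (0, 2, False),   # node 2 is immune
         (0, 3, True),    # virtual edge: no contagion
         (1, 3, False)]   # node 1 was still healthy at the start of the step
print(infection_step(sick, immune, edges))  # [True, True, False, False]
```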
When tasked to explain the prediction for a single node, all three techniques identify the relevant nodes/edges of the input graph (Fig. 2). We note however that the explanations produced by variation-based methods tend to diverge from how a human would intuitively describe the process in terms of cause and effect, while LRP results are more natural. Appendix C.1 provides a detailed case-based visualization of the explanations, down to the individual features.
4.2 Solubility
We train a GN to predict the aqueous solubility of organic compounds from their molecular graph, as in (Duvenaud et al., 2015). Our multi-layer GN matches their performance while remaining simple. Details about the dataset, the network and the training procedure are in Appendix B.2.
When explaining the predictions of the network, LRP attributes positive relevance to features known to correlate with solubility, such as R-OH groups on the outside of the molecule, and negative relevance to features that typically indicate low solubility, such as repeated non-polar aromatic rings (Fig. 3). Similar observations are made in (Duvenaud et al., 2015), although by manual inspection of high-scoring predictions.
Note that LRP was originally introduced to explain classification predictions, whereas here it is applied to a regression task. A discussion on how to interpret these explanations can be found in Appendix A.3.
Discussion
Recent methods for explainability have been developed within image or text domains (Bach et al., 2015; Springenberg et al., 2015; Ribeiro et al., 2016). With the experiments presented in this work, we intend to highlight some key differences of the graph domain that require special consideration to produce meaningful explanations.
Images can be seen as graphs with a regular grid topology whose features are attributed only to nodes. In this context, an explanation can take the form of a heatmap over the image, highlighting relevant pixels and, implicitly, their local connections. For graphs with irregular connectivity, edges acquire a more prominent role that can be missed when using image-based explanation techniques. For example, in a graph where edge features are absent or all identical (not informative), neither gradients nor relevance would be propagated back to these connections, even though the very presence of a connection between two nodes is a source of information that should be taken into account for explanations. We propose to take advantage of the structure-preserving property of graph convolution and aggregate explanations at multiple steps of message passing, arguing that the importance of connections should emerge from the intermediate steps (Fig. 4).
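A minimal sketch of this idea, under strong simplifying assumptions: scalar node states, edges without any features, a hypothetical message $m_k = v_{s_k}$ summed at the receiver with a ReLU update, no global feature $u$, and a readout equal to the final state of node 0. The importance of edge $k$ is taken as the SA-style squared sensitivity of the output to its message at every step, estimated with central finite differences and summed over steps; the function names are illustrative and this is not the procedure implemented in our code.

```python
from typing import Dict, List, Optional, Tuple


def propagate(x: List[float], edges: List[Tuple[int, int]], num_layers: int,
              perturb: Optional[Dict[Tuple[int, int], float]] = None) -> float:
    """Run num_layers message-passing steps; perturb[(layer, k)] is added to message k."""
    perturb = perturb or {}
    v = list(x)
    for layer in range(num_layers):
        agg = [0.0] * len(v)
        for k, (s, r) in enumerate(edges):
            agg[r] += v[s] + perturb.get((layer, k), 0.0)   # message m_k = v[sender]
        v = [max(0.0, vi + ai) for vi, ai in zip(v, agg)]
    return v[0]  # hypothetical readout: prediction for node 0


def edge_importance(x, edges, num_layers, eps=1e-4) -> List[float]:
    """Sum over steps of the squared output sensitivity to each intermediate message."""
    scores = [0.0] * len(edges)
    for layer in range(num_layers):
        for k in range(len(edges)):
            up = propagate(x, edges, num_layers, {(layer, k): eps})
            down = propagate(x, edges, num_layers, {(layer, k): -eps})
            scores[k] += ((up - down) / (2 * eps)) ** 2
    return scores


edges = [(1, 0), (2, 1), (0, 3)]   # chain 2 -> 1 -> 0, plus an outgoing edge 0 -> 3
scores = edge_importance([1.0, 1.0, 1.0, 1.0], edges, num_layers=2)
print([round(s, 6) for s in scores])  # [2.0, 1.0, 0.0]
```

Although no edge carries a feature, the edge into node 0 is used at both steps and scores highest, the two-hop edge matters only at the first step, and the outgoing edge gets no importance.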
5.2 Pooling
Architectural choice In standard neural networks, pooling operations are commonly used to aggregate features. In message-passing GNs, pooling is used to aggregate edge and node features at a local and global level, without modifying the topology of the graph (Eq. 1). The choice of pooling function in a GN is closely related to the learning problem: e.g. sum pooling is suited to counting at a global level, while max pooling can be used for identifying local properties.
Explanations The choice of aggregation also influences the explanations obtained for a prediction. Sum and mean pooling propagate gradients/relevance to all their inputs, possibly identifying all sources of signal. Max pooling, instead, considers only one of its inputs and disregards the others, regardless of their magnitude, which can lead to incomplete explanations (e.g. when multiple neighboring sick nodes could each explain an infection). To counter this, LRP proposes to approximate max pooling with a $p$-norm during relevance propagation, but this approach can over-disperse relevance to unimportant inputs. We suggest that the backward pass through max pooling should be approached as a search that redistributes relevance only to those inputs that would result in a similar prediction if chosen as the maximum (Fig. 5).
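The three backward rules for max pooling discussed above can be compared on a toy receiver with three incoming messages. The sketch assumes non-negative messages, a hypothetical downstream decision function `infected` standing in for the rest of the network, and $p = 4$ for the $p$-norm; `search_relevance` is our illustrative reading of the proposed search, not an established LRP rule.

```python
from typing import Callable, List


def argmax_relevance(inputs: List[float], relevance: float) -> List[float]:
    """Plain max pooling: all relevance goes to the single maximal input."""
    out = [0.0] * len(inputs)
    out[max(range(len(inputs)), key=lambda i: inputs[i])] = relevance
    return out


def pnorm_relevance(inputs: List[float], relevance: float, p: float = 4.0) -> List[float]:
    """p-norm approximation: relevance proportional to x_i ** p (non-negative inputs)."""
    weights = [x ** p for x in inputs]
    total = sum(weights)
    if total == 0.0:
        return [0.0] * len(inputs)
    return [relevance * w / total for w in weights]


def search_relevance(inputs: List[float], relevance: float,
                     downstream: Callable[[float], float], tol: float = 1e-3) -> List[float]:
    """Share relevance among inputs that, if chosen as the maximum, give (almost) the same prediction."""
    target = downstream(max(inputs))
    chosen = [i for i, x in enumerate(inputs) if abs(downstream(x) - target) <= tol]
    return [relevance / len(chosen) if i in chosen else 0.0 for i in range(len(inputs))]


def infected(message: float) -> float:
    """Hypothetical downstream decision: infection if the pooled message exceeds 0.5."""
    return 1.0 if message > 0.5 else 0.0


messages = [0.9, 0.3, 1.0]   # two sick neighbors and one weak, irrelevant signal
print(argmax_relevance(messages, 1.0))            # [0.0, 0.0, 1.0]
print(search_relevance(messages, 1.0, infected))  # [0.5, 0.0, 0.5]
print(pnorm_relevance(messages, 1.0))             # every input gets a share, even the weak one
```

The argmax rule ignores the second sick neighbor, the $p$-norm leaks a small share to the irrelevant message, and the search splits relevance between the two inputs that would each produce the same infection.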
5.3 Heterogeneous Graph Features
Images are usually represented as a matrix of continuous RGB pixel values, while graphs are often employed for domains that require a mixed encoding of continuous, binary and categorical features that are semantically meaningful (Fout et al., 2017; Kearnes et al., 2016; Sanchez-Gonzalez et al., 2018). Thus, rather than aggregating the explanation at the node/edge level, it can be of higher interest to evaluate the importance of individual features. For this reason, visualizations based on graph heatmaps might be insufficient. We suggest a more detailed visualization in Appendix C.
5.4 Perturbation-based evaluation
Images and graphs can be considered as points in very high-dimensional spaces, belonging to complex and structured manifolds (Tenenbaum et al., 2000). The commonly used representation of an image introduces a high degree of redundancy, so that changing the value of a single pixel minimally affects the content and meaning associated with the image. Under this observation, an explanation can be quantitatively evaluated by progressively "graying out" pixels in order of importance and measuring how this affects the prediction (Bach et al., 2015). Graph representations, on the other hand, tend to be less redundant, and the structure of the graph is a constituent part of its identity; therefore, small alterations of nodes/edges can drastically alter the meaning of the graph. In our chemistry problem, for example, replacing an atom or bond would fundamentally change a molecule or invalidate it. As a viable strategy, one could rely on domain-specific knowledge to perform such changes while remaining semantically close to the original. Alternatively, one could learn a bijective mapping of graphs onto a manifold with meaningful neighborhoods to conduct such an evaluation. In Appendix C.2 we present a hand-crafted example of progressively eliminating atoms from a molecule in order of importance using domain knowledge.
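For images, the evaluation described above amounts to a perturbation curve. The sketch below assumes a generic `model` callable on a flat feature vector and a constant `baseline` replacement value (the "gray" value); both are illustrative. Applied naively to the one-hot atom or bond features of a molecule, the same loop would produce invalid graphs, which is precisely the limitation discussed here.

```python
from typing import Callable, List


def perturbation_curve(model: Callable[[List[float]], float], x: List[float],
                       relevance: List[float], baseline: float = 0.0) -> List[float]:
    """Replace inputs with `baseline` in order of decreasing relevance and record
    the prediction after each step (the first entry is the unperturbed prediction)."""
    order = sorted(range(len(x)), key=lambda i: relevance[i], reverse=True)
    x = list(x)
    predictions = [model(x)]
    for i in order:
        x[i] = baseline
        predictions.append(model(x))
    return predictions


weights = [3.0, -1.0, 2.0]


def linear_model(v: List[float]) -> float:
    return sum(w * vi for w, vi in zip(weights, v))


x = [1.0, 1.0, 1.0]
relevance = [w * xi for w, xi in zip(weights, x)]   # exact decomposition for a linear model
print(perturbation_curve(linear_model, x, relevance))  # [4.0, 1.0, -1.0, 0.0]
```

A good explanation makes the prediction drop quickly in the first steps; removing the negatively relevant input last moves the prediction back up.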
Conclusion
As an expository paper, in this work we introduced and focused on an important problem with impactful applications: we analyzed the major existing explanation techniques in the context of Graph Networks. We further conducted a case-based analysis on two simple but complementary tasks, as well as some important high-level discussions on design choices when explaining a GN's decision. Finally, we provided an implementation of five different explanation techniques using PyTorch autograd, which can be readily used for any definition of GN. In tandem with the high-level technical novelty, we hope these contributions open up fruitful discussions at the workshop and pave the road for future development of specific techniques for GN explanation in real-world applications.
Acknowledgement. Federico Baldassarre is partially funded by Swedish Research Council project 2017-04609
References
Appendix A Explainability techniques
Explanations for image-based tasks usually aggregate the importance of the input features at the pixel level, e.g. by taking an average over the individual color channels. This is done under the reasonable assumption that spatial locations are the smallest unit of input that can still be interpreted by humans. The tasks considered in this work make use of node/edge features that are heterogeneous and individually interpretable. Therefore, we choose to present the explanations at the feature level, rather than aggregating at node or edge level. Furthermore, we observed that the sign of the gradients produced with Sensitivity Analysis can provide additional context to the explanation. For this reason, the visualizations in this appendix will make use of the gradient "as is" and not of its squared norm. Overall, we observe that explanations produced by variation-based methods tend to diverge from how a human would intuitively describe the process in terms of causes and effects. Decomposition-based methods result instead in more natural explanations. We posit that the decomposition of the output signal makes LRP more suitable for the categorical distribution of the relevant features on both nodes and edges.
A.2 Layer-wise Relevance Propagation rules
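Layer-wise Relevance Propagation (LRP) is a signal decomposition method introduced in (Bach et al., 2015), where the authors mainly propose two rules. Let $z_{ij} = x_i w_{ij}$ be the contribution of input $i$ to neuron $j$ of the next layer, $z_j = \sum_i z_{ij} + b_j$, and $R_j$ the relevance of neuron $j$. The former, known as the $\alpha\beta$-rule:

$$
R_i = \sum_j \left( \alpha \frac{z_{ij}^+}{\sum_{i'} z_{i'j}^+} - \beta \frac{z_{ij}^-}{\sum_{i'} z_{i'j}^-} \right) R_j, \qquad \alpha - \beta = 1,
$$

where $z^+$ and $z^-$ denote the positive and negative parts of the contributions. The latter, known as the $\epsilon$-rule:

$$
R_i = \sum_j \frac{z_{ij}}{z_j + \epsilon \, \mathrm{sign}(z_j)} R_j,
$$

where $\epsilon$ is a small number to avoid division by zero.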
We found the former ($\alpha\beta$-rule) to be quite unstable in the presence of zeros in the input or in the weights, since the denominators of its positive and negative terms can then vanish; this situation occurs often when using one-hot encoding of categorical features and L1 regularization for the weights. Therefore, although the $\alpha\beta$-rule should allow for more flexibility in tuning the ratio of positive and negative relevance, we chose the simpler $\epsilon$-stabilized rule.
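The $\epsilon$-rule for a single fully connected layer fits in a few lines. In this plain-Python sketch (an illustration, not the autograd-based implementation), `W[j][i]` is assumed to be the weight from input $i$ to output $j$ and the layer values are hypothetical; with zero biases, the input relevances sum to the output relevance up to the stabilizer $\epsilon$, and an input equal to zero receives no relevance whatever its weight.

```python
from typing import List


def lrp_epsilon(x: List[float], W: List[List[float]], b: List[float],
                R_out: List[float], eps: float = 1e-6) -> List[float]:
    """Redistribute R_out to the inputs with R_i = sum_j z_ij / (z_j + eps * sign(z_j)) * R_j."""
    R_in = [0.0] * len(x)
    for j, row in enumerate(W):
        z = [w * xi for w, xi in zip(row, x)]            # contributions z_ij = x_i w_ij
        z_j = sum(z) + b[j]
        denom = z_j + (eps if z_j >= 0 else -eps)        # stabilizer keeps denom away from 0
        for i, z_ij in enumerate(z):
            R_in[i] += z_ij / denom * R_out[j]
    return R_in


x = [1.0, 2.0, 0.0]              # the last input is zero, e.g. an inactive one-hot entry
W = [[1.0, 1.0, 5.0], [2.0, 1.0, 1.0]]
b = [0.0, 0.0]
R_out = [3.0, 1.0]
R_in = lrp_epsilon(x, W, b, R_out)
print([round(r, 4) for r in R_in])  # [1.5, 2.5, 0.0]
```

Note that $\epsilon$ only guards against exact zeros: when $z_j$ is merely close to zero, the rule can still assign very large relevances of opposite signs.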
A.3 LRP for regression
Layer-wise Relevance Propagation was initially developed as an explanation technique for classification tasks. In the context of our Solubility experiment, we extend its application to a regression task. Since the prediction target is now a continuous variable, the explanations produced by LRP can be interpreted as "How much does this feature of this atom/bond contribute, positively or negatively, to the final predicted value?". Also note that, due to the use of bias terms in our networks, the conservation property of LRP does not hold in full: some relevance will inevitably be attributed to the biases, which are internal parameters of the model and therefore not interpretable.
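For a single linear output with the $\epsilon$-rule and $\epsilon \to 0$, starting from $R = y$, each input receives exactly its contribution $w_i x_i$, which can be positive or negative, while the bias keeps the remainder. The sketch below, with hypothetical weights, shows the conservation gap that a bias introduces.

```python
from typing import List, Tuple


def linear_regression_relevance(w: List[float], b: float,
                                x: List[float]) -> Tuple[float, List[float]]:
    """Return the prediction y = w . x + b and the LRP relevance of each input.
    With R_out = y and eps -> 0, R_i = (w_i x_i / y) * y = w_i x_i."""
    y = sum(wi * xi for wi, xi in zip(w, x)) + b
    relevance = [wi * xi for wi, xi in zip(w, x)]
    return y, relevance


y, relevance = linear_regression_relevance(w=[2.0, -1.0], b=0.5, x=[1.0, 3.0])
print(y, relevance, sum(relevance))  # -0.5 [2.0, -3.0] -1.0
```

The input relevances sum to $y - b$ rather than $y$: the missing $0.5$ is the share absorbed by the bias.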
Appendix B Experiment details
B.1 Infection
The feature vectors $e_k$ and $v_i$ encode edge and node features, respectively. Both include uninformative features that the network should learn to ignore and which should be attributed no importance by the explanation techniques (Fig. 6). Notably, binary features are encoded as $\{-1, 1\}$ rather than $\{0, 1\}$: while this does not affect variation-based methods (SA and GBP), it facilitates the propagation of relevance to the input when LRP is used, since an input equal to zero receives no relevance. The synthetic dataset used for training contains graphs with 30 or fewer nodes generated with the Barabási-Albert algorithm. The datasets used for validation and testing contain graphs of up to 60 nodes and different percentages of sick and immune nodes.
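Why the encoding matters can be seen on a single linear unit: the SA gradient depends on the weights only, while the LRP contribution $z_i = w_i x_i$ of a feature encoded as 0 vanishes whatever its weight. The weights and the two-feature node below are hypothetical.

```python
from typing import List


def gradients(w: List[float], x: List[float]) -> List[float]:
    """d(w . x)/dx_i = w_i, independent of the input encoding."""
    return list(w)


def contributions(w: List[float], x: List[float]) -> List[float]:
    """LRP-style contributions z_i = w_i * x_i of a single linear unit."""
    return [wi * xi for wi, xi in zip(w, x)]


w = [1.5, -2.0]                   # hypothetical weights for the features [sick, immune]
node_01 = [0.0, 0.0]              # healthy, not immune, encoded with {0, 1}
node_pm = [-1.0, -1.0]            # the same node encoded with {-1, 1}
print(gradients(w, node_01) == gradients(w, node_pm))  # True
print(contributions(w, node_01) == [0.0, 0.0])         # True: no relevance can reach these inputs
print(contributions(w, node_pm))                       # [-1.5, 2.0]
```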
Architecture and training
The network used for the Infection task makes use of a single layer of graph processing as in Eq. 1, without graph-level features. The update functions for the edges and the nodes are shallow multi-layer perceptrons with ReLU activations, and we use sum or max pooling to aggregate the edges incident to a node. We use the Adam optimizer (Kingma & Ba, 2014) to minimize the binary cross-entropy between the per-node predictions and the ground truth. Multiple choices of hyperparameters, such as learning rate, number of neurons in the hidden layers and L1 regularization, yield similar outcomes. Both sum and max pooling perform well, but the former fails in some corner cases (Fig. 17).
B.2 Solubility
The Solubility dataset is the same as in (Duvenaud et al., 2015), consisting of around 1000 organic molecules represented as SMILES strings, together with their measured solubility in water. The molecules are represented as graphs with atoms as nodes (with degree, number of hydrogens, implicit valence and atom type as features) and bonds as edges (with bond type, whether the bond is conjugated and whether it is in a ring as features).
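A hand-written featurization consistent with this description is sketched below for methanol (CH3-OH), with hydrogens kept implicit. The vocabularies, the ordering of the one-hot blocks and the choice of representing each bond as two directed edges are assumptions for illustration; in practice these values would be read from the parsed SMILES string rather than typed in.

```python
from typing import List

ATOM_TYPES = ["C", "N", "O", "S", "Cl"]              # hypothetical atom vocabulary
COUNTS = [0, 1, 2, 3, 4]                             # allowed degrees / H counts / valences
BOND_TYPES = ["single", "double", "triple", "aromatic"]


def one_hot(value, choices) -> List[float]:
    return [1.0 if value == c else 0.0 for c in choices]


def atom_features(atom_type: str, degree: int, num_hs: int, implicit_valence: int) -> List[float]:
    return (one_hot(degree, COUNTS) + one_hot(num_hs, COUNTS)
            + one_hot(implicit_valence, COUNTS) + one_hot(atom_type, ATOM_TYPES))


def bond_features(bond_type: str, conjugated: bool, in_ring: bool) -> List[float]:
    return one_hot(bond_type, BOND_TYPES) + [float(conjugated), float(in_ring)]


# Methanol: heavy atoms C (3 implicit H) and O (1 implicit H), one single C-O bond.
nodes = [atom_features("C", degree=1, num_hs=3, implicit_valence=3),
         atom_features("O", degree=1, num_hs=1, implicit_valence=1)]
bond = bond_features("single", conjugated=False, in_ring=False)
edges, senders, receivers = [bond, bond], [0, 1], [1, 0]   # one directed edge per direction
print(len(nodes[0]), len(edges[0]))  # 20 6
```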
Architecture and training
As optimization objective, we use the mean squared error between the measured log-solubility and the global features of the output graph of a multi-layer GN, where each layer updates the graph as in Eq. 1. Using multiple layers of graph convolution allows the network to aggregate information at progressively larger scales, starting from the local neighborhood and extending to wider groups of atoms. Dropout is applied at the output of every linear transformation to counteract overfitting. We tested multiple combinations of hyperparameters and obtained results comparable to (Duvenaud et al., 2015) using 3-5 hidden graph layers with a dimensionality of 64, 128 or 256 and sum/mean pooling for all aggregation operations. Max pooling performed much worse, probably due to the nature of the task.
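The growth of the receptive field with depth can be checked with a breadth-first search: ignoring the global feature $u$ (which can carry information from the whole graph), after $L$ layers an atom can only be influenced by atoms at most $L$ bonds away. The sketch uses a hypothetical six-atom chain, the carbon skeleton of hexane.

```python
from collections import deque
from typing import List, Tuple


def receptive_field(num_atoms: int, bonds: List[Tuple[int, int]],
                    atom: int, num_layers: int) -> List[int]:
    """Atoms within num_layers bonds of `atom`, i.e. those that can reach it
    through num_layers rounds of message passing."""
    neighbors = [[] for _ in range(num_atoms)]
    for a, b in bonds:
        neighbors[a].append(b)
        neighbors[b].append(a)
    dist = {atom: 0}
    queue = deque([atom])
    while queue:
        current = queue.popleft()
        if dist[current] == num_layers:
            continue                      # do not expand beyond the layer budget
        for nb in neighbors[current]:
            if nb not in dist:
                dist[nb] = dist[current] + 1
                queue.append(nb)
    return sorted(dist)


chain = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
for layers in (1, 3, 5):
    print(layers, receptive_field(6, chain, atom=0, num_layers=layers))
# 1 [0, 1]
# 3 [0, 1, 2, 3]
# 5 [0, 1, 2, 3, 4, 5]
```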
Appendix C Additional results
C.1 Infection
In the following pages we present an in-depth comparison between the three explainability methods we experimented with. We consider a graph with multiple sources of infection and immune nodes. The network, that uses max pooling to aggregate information from the incoming edges, correctly predicts the state of every node after one step of infection propagation (Fig. 8). In the figures that follow, we present a visualization of the explanations produced for three nodes of the graph: one that becomes infected (Fig. 9), one that receives no infection from its neighbors (Fig. 10) and one that is immune (Fig. 11). Refer to the captions for observations specific to every example.
Aggregation: max vs. sum comparison
What follows is an overview of explanations obtained for smaller graphs. For every input graph, we show two predictions: one made by a GN that uses max pooling and one made by a GN that uses sum pooling. The predictions are followed by a visualization of the explanations produced with Sensitivity Analysis, Guided Backpropagation and Layer-wise Relevance Propagation (in this order). For each explainability method, we represent the values of the gradient/relevance as a heatmap over the individual features of every node/edge, as well as with a graphical representation of the graph. Refer to the captions for observations specific to every example.
C.2 Solubility
As mentioned in the discussion (Sec. 5), choosing to model molecules as graphs yields a non-redundant and highly structured representation. As a consequence, the ability to slightly alter a molecule by performing small steps in Euclidean space is lost. This makes it hard to verify that explanations correspond to how the trained network predicts solubility. In fact, it is not possible to automatically alter a molecule according to the importance of its atoms/bonds and still obtain a valid molecule. In this case, it is necessary to apply domain-specific knowledge and identify which changes are viable in the space of valid molecules. In Figure 7 we show a trivial example where we use LRP to identify important atoms/bonds of a molecule and progressively remove them to reduce the predicted solubility.
Solubility explanations
In Figures 18 and 19 we present a visualization of the explanations for the predicted solubility of glucose (moderately soluble) and 4-hexylresorcinol (moderately insoluble). The explanations are produced by applying LRP and propagating positive and negative relevance to the individual features of the atoms/bonds.