Hierarchical Graph Representation Learning with Differentiable Pooling

Rex Ying, Jiaxuan You, Christopher Morris, Xiang Ren, William L. Hamilton, Jure Leskovec

Introduction

In recent years there has been a surge of interest in developing graph neural networks (GNNs), general deep learning architectures that can operate over graph-structured data, such as social network data or graph-based representations of molecules. The general approach with GNNs is to view the underlying graph as a computation graph and learn neural network primitives that generate individual node embeddings by passing, transforming, and aggregating node feature information across the graph. The generated node embeddings can then be used as input to any differentiable prediction layer, e.g., for node classification or link prediction, and the whole model can be trained in an end-to-end fashion.

However, a major limitation of current GNN architectures is that they are inherently flat, as they only propagate information across the edges of the graph and are unable to infer and aggregate the information in a hierarchical way. For example, in order to successfully encode the graph structure of organic molecules, one would ideally want to encode the local molecular structure (e.g., individual atoms and their direct bonds) as well as the coarse-grained structure of the molecular graph (e.g., groups of atoms and bonds representing functional units in a molecule). This lack of hierarchical structure is especially problematic for the task of graph classification, where the goal is to predict the label associated with an entire graph. When applying GNNs to graph classification, the standard approach is to generate embeddings for all the nodes in the graph and then to globally pool all these node embeddings together, e.g., using a simple summation or a neural network that operates over sets. This global pooling approach ignores any hierarchical structure that might be present in the graph, and it prevents researchers from building effective GNN models for predictive tasks over entire graphs.

Here we propose DiffPool, a differentiable graph pooling module that can be adapted to various graph neural network architectures in a hierarchical and end-to-end fashion (Figure 1). DiffPool allows for developing deeper GNN models that can learn to operate on hierarchical representations of a graph. We develop a graph analogue of the spatial pooling operation in CNNs, which allows deep CNN architectures to iteratively operate on coarser and coarser representations of an image. The challenge in the GNN setting, compared to standard CNNs, is that graphs contain no natural notion of spatial locality, i.e., one cannot simply pool together all nodes in an “$m\times m$ patch” on a graph, because the complex topological structure of graphs precludes any straightforward, deterministic definition of a “patch”. Moreover, unlike image data, graph data sets often contain graphs with varying numbers of nodes and edges, which makes defining a general graph pooling operator even more challenging.

In order to solve the above challenges, we require a model that learns how to cluster together nodes to build a hierarchical multi-layer scaffold on top of the underlying graph. Our approach DiffPool learns a differentiable soft assignment at each layer of a deep GNN, mapping nodes to a set of clusters based on their learned embeddings. In this framework, we generate deep GNNs by “stacking” GNN layers in a hierarchical fashion (Figure 1): the input nodes at the layer $l$ GNN module correspond to the clusters learned at the layer $l-1$ GNN module. Thus, each layer of DiffPool coarsens the input graph more and more, and DiffPool is able to generate a hierarchical representation of any input graph after training. We show that DiffPool can be combined with various GNN approaches, resulting in an average 7% gain in accuracy and a new state of the art on four out of five benchmark graph classification tasks. Finally, we show that DiffPool can learn interpretable hierarchical clusters that correspond to well-defined communities in the input graphs.

Related Work

Our work builds upon a rich line of recent research on graph neural networks and graph classification.

General graph neural networks. A wide variety of graph neural network (GNN) models have been proposed in recent years, including methods inspired by convolutional neural networks, recurrent neural networks, recursive neural networks, and loopy belief propagation. Most of these approaches fit within the framework of “neural message passing” proposed by Gilmer et al. In the message passing framework, a GNN is viewed as a message passing algorithm where node representations are iteratively computed from the features of their neighbor nodes using a differentiable aggregation function. Hamilton et al. provide a conceptual review of recent advancements in this area, and Bronstein et al. outline connections to spectral graph convolutions.

Graph classification with graph neural networks. GNNs have been applied to a wide variety of tasks, including node classification, link prediction, graph classification, and chemoinformatics. In the context of graph classification, the task that we study here, a major challenge in applying GNNs is going from node embeddings, which are the output of GNNs, to a representation of the entire graph. Common approaches to this problem include simply summing up or averaging all the node embeddings in a final layer, introducing a “virtual node” that is connected to all the nodes in the graph, or aggregating the node embeddings using a deep learning architecture that operates over sets. However, all of these methods have the limitation that they do not learn hierarchical representations (i.e., all the node embeddings are globally pooled together in a single layer), and thus are unable to capture the natural structures of many real-world graphs. Some recent approaches have also proposed applying CNN architectures to the concatenation of all the node embeddings, but this requires specifying (or learning) a canonical ordering over nodes, which is in general very difficult and at least as hard as solving graph isomorphism.

Lastly, there are some recent works that learn hierarchical graph representations by combining GNNs with deterministic graph clustering algorithms, following a two-stage approach. However, unlike these previous approaches, we seek to learn the hierarchical structure in an end-to-end fashion, rather than relying on a deterministic graph clustering subroutine.

Proposed Method

The key idea of DiffPool is that it enables the construction of deep, multi-layer GNN models by providing a differentiable module to hierarchically pool graph nodes. In this section, we outline the DiffPool module and show how it is applied in a deep GNN architecture.

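Graph neural networks. In this work, we build upon graph neural networks in order to learn useful representations for graph classification in an end-to-end fashion. We represent a graph with $n$ nodes by its adjacency matrix $A\in\{0,1\}^{n\times n}$ and its node feature matrix $F\in\mathbb{R}^{n\times d}$. In particular, we consider GNNs that employ the following general “message-passing” architecture:

$$H^{(k)} = M\left(A, H^{(k-1)}; \theta^{(k)}\right), \tag{1}$$

where $H^{(k)}\in\mathbb{R}^{n\times d}$ are the node embeddings computed after $k$ steps of the GNN, and $M$ is the message propagation function, which depends on the adjacency matrix, the trainable parameters $\theta^{(k)}$, and the node embeddings $H^{(k-1)}$ generated in the previous message-passing step. The input embeddings are initialized with the node features, $H^{(0)}=F$.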

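There are many possible implementations of the propagation function $M$. For example, one popular variant of GNNs, the Graph Convolutional Networks (GCNs) of Kipf et al., implements $M$ using a combination of linear transformations and ReLU non-linearities:

$$H^{(k)} = M\left(A, H^{(k-1)}; W^{(k)}\right) = \operatorname{ReLU}\left(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(k-1)}W^{(k)}\right), \tag{2}$$

where $\tilde{A}=A+I$, $\tilde{D}$ is the diagonal degree matrix with $\tilde{D}_{ii}=\sum_{j}\tilde{A}_{ij}$, and $W^{(k)}\in\mathbb{R}^{d\times d}$ is a trainable weight matrix. A full GNN module runs $K$ iterations of Equation (1) to produce the output node embeddings $Z=H^{(K)}\in\mathbb{R}^{n\times d}$; in the following, we abstract away its internal structure and write $Z=\textrm{GNN}(A,X)$ for an arbitrary GNN module applied to an adjacency matrix $A$ and node features $X$.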
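The GCN propagation step, $\operatorname{ReLU}(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}HW)$ with $\tilde{A}=A+I$, can be sketched in a few lines. This is a minimal plain-Python sketch, not the authors' implementation: dense matrices are stored as lists of lists, and the helper names `matmul` and `gcn_layer` are our own.

```python
import math


def matmul(X, Y):
    """Dense product of two matrices stored as lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def gcn_layer(A, H, W):
    """One GCN propagation step: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    n = len(A)
    # Add self-loops so every node also keeps its own features.
    A_tilde = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Degrees of the self-looped graph (row sums of A_tilde).
    deg = [sum(row) for row in A_tilde]
    # Symmetric normalization: entry (i, j) is divided by sqrt(deg_i * deg_j).
    A_hat = [[A_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
             for i in range(n)]
    # Aggregate neighbor features, apply the linear transform, then the ReLU.
    out = matmul(matmul(A_hat, H), W)
    return [[max(0.0, v) for v in row] for row in out]


# A 3-node path graph 0 - 1 - 2 with a single constant input feature.
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
H = [[1.0], [1.0], [1.0]]
W = [[1.0]]
print(gcn_layer(A, H, W))
```

On this path graph the middle node, which has two neighbors, receives the largest output value, and the two end nodes receive equal values by symmetry.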

2 Differentiable Pooling via Learned Assignments

Our proposed approach, DiffPool, addresses the above challenges by learning a cluster assignment matrix over the nodes using the output of a GNN model. The key intuition is that we stack $L$ GNN modules and learn to assign nodes to clusters at layer $l$ in an end-to-end fashion, using embeddings generated from a GNN at layer $l-1$. Thus, we are using GNNs both to extract node embeddings that are useful for graph classification and to extract node embeddings that are useful for hierarchical pooling. Using this construction, the GNNs in DiffPool learn to encode a general pooling strategy that is useful for a large set of training graphs. We first describe how the DiffPool module pools nodes at each layer given an assignment matrix; following this, we discuss how we generate the assignment matrix using a GNN architecture.

Suppose that $S^{(l)}\in\mathbb{R}^{n_l\times n_{l+1}}$ has already been computed, i.e., that we have computed the assignment matrix at the $l$-th layer of our model, where $n_l$ denotes the number of nodes (or clusters) at layer $l$. We denote the input adjacency matrix at this layer as $A^{(l)}$ and the input node embedding matrix at this layer as $Z^{(l)}$. Given these inputs, the DiffPool layer $(A^{(l+1)},X^{(l+1)})=\textsc{DiffPool}(A^{(l)},Z^{(l)})$ coarsens the input graph, generating a new coarsened adjacency matrix $A^{(l+1)}$ and a new matrix of embeddings $X^{(l+1)}$ for each of the nodes/clusters in this coarsened graph. In particular, we apply the following two equations:

$$X^{(l+1)} = {S^{(l)}}^{T}Z^{(l)}\in\mathbb{R}^{n_{l+1}\times d}, \tag{3}$$

$$A^{(l+1)} = {S^{(l)}}^{T}A^{(l)}S^{(l)}\in\mathbb{R}^{n_{l+1}\times n_{l+1}}. \tag{4}$$

Equation (3) takes the node embeddings $Z^{(l)}$ and aggregates these embeddings according to the cluster assignments $S^{(l)}$, generating embeddings for each of the $n_{l+1}$ clusters. Similarly, Equation (4) takes the adjacency matrix $A^{(l)}$ and generates a coarsened adjacency matrix denoting the connectivity strength between each pair of clusters.

Through Equations (3) and (4), the DiffPool layer coarsens the graph: the next layer adjacency matrix $A^{(l+1)}$ represents a coarsened graph with $n_{l+1}$ nodes or cluster nodes, where each individual cluster node in the new coarsened graph corresponds to a cluster of nodes in the graph at layer $l$. Note that $A^{(l+1)}$ is a real matrix and represents a fully connected edge-weighted graph; each entry $A^{(l+1)}_{ij}$ can be viewed as the connectivity strength between cluster $i$ and cluster $j$. Similarly, the $i$-th row of $X^{(l+1)}$ corresponds to the embedding of cluster $i$. Together, the coarsened adjacency matrix $A^{(l+1)}$ and cluster embeddings $X^{(l+1)}$ can be used as input to another GNN layer, a process which we describe in detail below.
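To make the coarsening step concrete, here is a minimal plain-Python sketch of Equations (3) and (4), $X^{(l+1)}={S^{(l)}}^{T}Z^{(l)}$ and $A^{(l+1)}={S^{(l)}}^{T}A^{(l)}S^{(l)}$. The function name `coarsen` and the toy graph are illustrative assumptions. The example uses a hard (one-hot) assignment so the result is easy to read; DiffPool itself produces soft assignments.

```python
def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def coarsen(A, Z, S):
    """Return (A_next, X_next) = (S^T A S, S^T Z)."""
    St = transpose(S)
    X_next = matmul(St, Z)             # one pooled embedding per cluster
    A_next = matmul(matmul(St, A), S)  # connectivity strength between clusters
    return A_next, X_next


# Path graph 0 - 1 - 2 - 3; nodes {0, 1} form cluster 0 and {2, 3} form cluster 1.
A = [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]]
Z = [[1], [2], [3], [4]]
S = [[1, 0], [1, 0], [0, 1], [0, 1]]
A_next, X_next = coarsen(A, Z, S)
print(A_next)  # [[2, 1], [1, 2]]
print(X_next)  # [[3], [7]]
```

Each diagonal entry of `A_next` counts the single within-cluster edge twice (once per direction), while the off-diagonal entry 1 records the one edge (1, 2) that crosses the two clusters. With soft assignments, these entries become real-valued connectivity strengths.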

Learning the assignment matrix. In the following we describe the architecture of DiffPool, i.e., how DiffPool generates the assignment matrix $S^{(l)}$ and embedding matrices $Z^{(l)}$ that are used in Equations (3) and (4). We generate these two matrices using two separate GNNs that are both applied to the input cluster node features $X^{(l)}$ and coarsened adjacency matrix $A^{(l)}$. The embedding GNN at layer $l$ is a standard GNN module applied to these inputs:

$$Z^{(l)} = \textrm{GNN}_{l,\text{embed}}\left(A^{(l)},X^{(l)}\right), \tag{5}$$

i.e., we take the adjacency matrix between the cluster nodes at layer $l$ (from Equation 4) and the pooled features for the clusters (from Equation 3) and pass these matrices through a standard GNN to get new embeddings $Z^{(l)}$ for the cluster nodes. In contrast, the pooling GNN at layer $l$ uses the input cluster features $X^{(l)}$ and cluster adjacency matrix $A^{(l)}$ to generate an assignment matrix:

$$S^{(l)} = \operatorname{softmax}\left(\textrm{GNN}_{l,\text{pool}}\left(A^{(l)},X^{(l)}\right)\right), \tag{6}$$

where the softmax function is applied in a row-wise fashion. The output dimension of $\textrm{GNN}_{l,\text{pool}}$ corresponds to a pre-defined maximum number of clusters in layer $l$ and is a hyperparameter of the model.

Note that these two GNNs consume the same input data but have distinct parameterizations and play separate roles: the embedding GNN generates new embeddings for the input nodes at this layer, while the pooling GNN generates a probabilistic assignment of the input nodes to $n_{l+1}$ clusters.
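A minimal plain-Python sketch of Equations (5) and (6) follows, showing two GNNs that read the same inputs $(A^{(l)},X^{(l)})$ but have separate weights. As an assumption made for brevity, each GNN is a single hypothetical layer $\operatorname{ReLU}((A+I)XW)$; any GNN module could take its place. The row-wise softmax turns each row of $S^{(l)}$ into a probability distribution over the $n_{l+1}$ clusters. All function names and weight values are illustrative.

```python
import math


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def simple_gnn(A, X, W):
    """Hypothetical one-layer GNN: ReLU((A + I) X W)."""
    n = len(A)
    A_tilde = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
    return [[max(0.0, v) for v in row] for row in matmul(matmul(A_tilde, X), W)]


def row_softmax(M):
    """Softmax applied independently to each row."""
    out = []
    for row in M:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def embed_and_assign(A, X, W_embed, W_pool):
    """Same inputs, separate parameters: Z from the embedding GNN, S from the pooling GNN."""
    Z = simple_gnn(A, X, W_embed)              # n_l x d node embeddings
    S = row_softmax(simple_gnn(A, X, W_pool))  # n_l x n_{l+1} soft assignments
    return Z, S


A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_embed = [[0.5, -0.2], [0.1, 0.3]]
W_pool = [[1.0, -1.0], [-0.5, 0.5]]  # output width 2 = maximum number of clusters
Z, S = embed_and_assign(A, X, W_embed, W_pool)
print(S)
```

The number of columns of `W_pool` plays the role of the pre-defined maximum number of clusters: changing it changes the width of `S` without touching the embedding branch.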

In the base case, the inputs to Equations (5) and (6) at layer $l=0$ are simply the input adjacency matrix $A$ and the node features $F$ for the original graph. At the penultimate layer $L-1$ of a deep GNN model using DiffPool, we set the assignment matrix $S^{(L-1)}$ to be a vector of $1$'s, i.e., all nodes at the final layer $L$ are assigned to a single cluster, generating a final embedding vector corresponding to the entire graph. This final output embedding can then be used as feature input to a differentiable classifier (e.g., a softmax layer), and the entire system can be trained end-to-end using stochastic gradient descent.
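Setting $S^{(L-1)}$ to a column of ones turns the last coarsening step into a global readout. The plain-Python sketch below (the function name `final_readout` is hypothetical) applies the coarsening formulas $S^{T}Z$ and $S^{T}AS$ with an all-ones $n\times 1$ assignment: the graph embedding becomes the column-wise sum of the node embeddings, and the $1\times 1$ coarsened adjacency becomes the total edge weight.

```python
def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def final_readout(A, Z):
    """Coarsen with S = ones(n, 1): every node joins a single cluster."""
    S = [[1.0] for _ in range(len(Z))]
    St = transpose(S)
    graph_embedding = matmul(St, Z)[0]             # 1^T Z: sum of node embeddings
    total_weight = matmul(matmul(St, A), S)[0][0]  # 1^T A 1: sum of all entries of A
    return graph_embedding, total_weight


A = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
Z = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
emb, weight = final_readout(A, Z)
print(emb)     # [9.0, 12.0]
print(weight)  # 4.0
```

In other words, the final DiffPool layer reduces to sum pooling over the remaining (cluster) embeddings, and the resulting vector is what is passed to the classifier.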

Permutation invariance. Note that in order to be useful for graph classification, the pooling layer should be invariant under node permutations. For DiffPool we get the following positive result, which shows that any deep GNN model based on DiffPool is permutation invariant, as long as the component GNNs are permutation equivariant, i.e., permuting the input nodes permutes the rows of the GNN output in the same way.

Proposition 1. Let $P\in\{0,1\}^{n\times n}$ be any permutation matrix. Then $\textsc{DiffPool}(A,X)=\textsc{DiffPool}(PAP^{T},PX)$, where $Z$ and $S$ are computed from the inputs by Equations (5) and (6), as long as $\textrm{GNN}(PAP^{T},PX)=P\,\textrm{GNN}(A,X)$ (i.e., as long as the GNN method used is permutation equivariant).

Proof. By the assumption on the GNN modules, Equation (5) maps the permuted inputs $(PAP^{T},PX)$ to $PZ$, and Equation (6) maps them to $PS$, since the row-wise softmax commutes with permuting rows. Since any permutation matrix is orthogonal, $P^{T}P=I$, so Equation (3) gives $(PS)^{T}PZ=S^{T}Z$ and Equation (4) gives $(PS)^{T}PAP^{T}PS=S^{T}AS$, which finishes the proof. ∎
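The permutation-invariance result can also be checked numerically on a small example. The plain-Python sketch below (all names are our own) builds one DiffPool layer from two copies of a hypothetical permutation-equivariant GNN, $\operatorname{ReLU}((A+I)XW)$, relabels the nodes with a permutation, and confirms that the coarsened adjacency and cluster embeddings are unchanged up to floating-point rounding. A check on one example only illustrates the result; it does not replace the proof.

```python
import math


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def simple_gnn(A, X, W):
    """Hypothetical GNN, ReLU((A + I) X W); it is permutation equivariant."""
    n = len(A)
    A_tilde = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
    return [[max(0.0, v) for v in row] for row in matmul(matmul(A_tilde, X), W)]


def row_softmax(M):
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def diffpool_layer(A, X, W_embed, W_pool):
    """Z and S from two GNNs, then (S^T A S, S^T Z)."""
    Z = simple_gnn(A, X, W_embed)
    S = row_softmax(simple_gnn(A, X, W_pool))
    St = transpose(S)
    return matmul(matmul(St, A), S), matmul(St, Z)


def permute(A, X, perm):
    """Return (P A P^T, P X), where new node i is old node perm[i]."""
    n = len(perm)
    A_p = [[A[perm[i]][perm[j]] for j in range(n)] for i in range(n)]
    X_p = [X[perm[i]] for i in range(n)]
    return A_p, X_p


def max_abs_diff(M1, M2):
    return max(abs(a - b) for r1, r2 in zip(M1, M2) for a, b in zip(r1, r2))


A = [[0, 1, 1, 0], [1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 1, 0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.2]]
W_embed = [[0.3, -0.7], [0.9, 0.4]]
W_pool = [[1.0, -0.5], [-0.2, 0.8]]

A1, X1 = diffpool_layer(A, X, W_embed, W_pool)
A2, X2 = diffpool_layer(*permute(A, X, [2, 0, 3, 1]), W_embed, W_pool)
print(max_abs_diff(A1, A2), max_abs_diff(X1, X2))  # both close to 0
```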

3 Auxiliary Link Prediction Objective and Entropy Regularization

In practice, it can be difficult to train the pooling GNN (Equation 6) using only gradient signal from the graph classification task. Intuitively, we have a non-convex optimization problem, and it can be difficult to push the pooling GNN away from spurious local minima early in training. To alleviate this issue, we train the pooling GNN with an auxiliary link prediction objective, which encodes the intuition that nearby nodes should be pooled together. In particular, at each layer $l$, we minimize $L_{\text{LP}}=\|A^{(l)}-S^{(l)}{S^{(l)}}^{T}\|_{F}$, where $\|\cdot\|_{F}$ denotes the Frobenius norm. Note that the adjacency matrix $A^{(l)}$ at deeper layers is a function of lower-level assignment matrices and changes during training.
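As a minimal sketch of the auxiliary objective $L_{\text{LP}}=\|A-SS^{T}\|_{F}$ (plain Python; `link_pred_loss` is a hypothetical name), the loss compares each adjacency entry with the inner product of the two nodes' assignment rows, $(SS^{T})_{ij}$, which is large when nodes $i$ and $j$ are assigned to the same clusters. The toy example contrasts an assignment that groups adjacent nodes with one that groups non-adjacent nodes.

```python
import math


def link_pred_loss(A, S):
    """Frobenius norm of A - S S^T."""
    n = len(A)
    total = 0.0
    for i in range(n):
        for j in range(n):
            s_ij = sum(a * b for a, b in zip(S[i], S[j]))  # (S S^T)_{ij}
            total += (A[i][j] - s_ij) ** 2
    return math.sqrt(total)


# Two disjoint edges: 0 - 1 and 2 - 3.
A = [[0, 1, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 0, 1],
     [0, 0, 1, 0]]
S_good = [[1, 0], [1, 0], [0, 1], [0, 1]]  # pools each edge's endpoints together
S_bad = [[1, 0], [0, 1], [1, 0], [0, 1]]   # pools non-adjacent nodes together
print(link_pred_loss(A, S_good))  # 2.0
print(link_pred_loss(A, S_bad))   # about 3.464
```

The diagonal contributes in both cases because $A$ here has no self-loops while $(SS^{T})_{ii}=1$ for one-hot rows; the remaining gap comes entirely from pooling decisions that disagree with the edges.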

Another important characteristic of the pooling GNN (Equation 6) is that the output cluster assignment for each node should generally be close to a one-hot vector, so that the membership for each cluster or subgraph is clearly defined. We therefore regularize the entropy of the cluster assignment by minimizing $L_{\text{E}}=\frac{1}{n}\sum_{i=1}^{n}H(S_{i})$, where $H$ denotes the entropy function and $S_{i}$ is the $i$-th row of $S$.
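The entropy regularizer $L_{\text{E}}=\frac{1}{n}\sum_{i}H(S_i)$ is small when rows are nearly one-hot and large when they are spread out. A minimal plain-Python sketch, assuming the natural logarithm and the usual convention $0\log 0=0$ (function names are our own):

```python
import math


def entropy(p):
    """Shannon entropy (natural log) of a probability vector; 0 log 0 is taken as 0."""
    return -sum(x * math.log(x) for x in p if x > 0)


def entropy_reg(S):
    """L_E = (1/n) * sum_i H(S_i) over the rows of the assignment matrix."""
    return sum(entropy(row) for row in S) / len(S)


one_hot = [[1.0, 0.0], [0.0, 1.0]]
uniform = [[0.5, 0.5], [0.5, 0.5]]
print(entropy_reg(one_hot))  # 0.0
print(entropy_reg(uniform))  # log(2), about 0.693
```

Minimizing this term pushes each node toward a single cluster, which is what makes the learned memberships readable.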

During training, $L_{\text{LP}}$ and $L_{\text{E}}$ from each layer are added to the classification loss. In practice we observe that training with the side objectives takes longer to converge but nevertheless achieves better performance and more interpretable cluster assignments.

Experiments

We evaluate the benefits of DiffPool against a number of state-of-the-art graph classification approaches, with the goal of answering the following questions:

Q1: How does DiffPool compare to other pooling methods proposed for GNNs (e.g., using sort pooling or the Set2Set method)?

Q2: How does DiffPool combined with GNNs compare to the state of the art for the graph classification task, including both GNNs and kernel-based methods?

Q3: Does DiffPool compute meaningful and interpretable clusters on the input graphs?

Data sets. To probe the ability of DiffPool to learn complex hierarchical structures from graphs in different domains, we evaluate on a variety of relatively large graph data sets chosen from benchmarks commonly used in graph classification. We use the protein data sets Enzymes, Proteins, and D&D, the social network data set Reddit-Multi-12k, and the scientific collaboration data set Collab. See Appendix A for statistics and properties. For all these data sets, we perform 10-fold cross-validation to evaluate model performance and report the accuracy averaged over 10 folds.

We also evaluate two simplified variants of DiffPool. DiffPool-Det is a DiffPool model where assignment matrices are generated using a deterministic graph clustering algorithm.

DiffPool-NoLP is a variant of DiffPool where the link prediction side objective is turned off.

In the performance comparison on graph classification, we consider baselines based upon GNNs (combined with different pooling methods) as well as state-of-the-art kernel-based approaches.

GraphSage with global mean-pooling. Other GNN variants are omitted, as GraphSAGE empirically obtained higher performance on this task.

Structure2Vec (S2V) is a state-of-the-art graph representation learning algorithm that combines a latent variable model with GNNs. It uses global mean pooling.

Edge-conditioned filters in CNN for graphs (ECC) incorporates edge information into the GCN model and performs pooling using a graph coarsening algorithm.

PatchySan defines a receptive field (neighborhood) for each node and, using a canonical node ordering, applies convolutions on linear sequences of node embeddings.

Set2Set replaces the global mean-pooling in the traditional GNN architectures with the aggregation used in Set2Set. Set2Set aggregation has been shown to perform better than mean pooling in previous work. We use GraphSage as the base GNN model.

SortPool applies a GNN architecture and then performs a single layer of soft pooling followed by 1D convolution on sorted node embeddings.

For all the GNN baselines, we use the 10-fold cross-validation numbers reported by the original authors when possible. For the GraphSage and Set2Set baselines, we use the same base implementation and hyperparameter sweeps as in our DiffPool approach. When baseline approaches did not have the necessary published numbers, we contacted the original authors and used their code (if available) to run the model, performing a hyperparameter search based on the original authors’ guidelines.

Kernel-based algorithms. We use the Graphlet, the Shortest-Path, the Weisfeiler-Lehman (WL), and the Weisfeiler-Lehman Optimal Assignment (WL-OA) kernels as kernel baselines. For each kernel, we computed the normalized Gram matrix. We computed the classification accuracies using the $C$-SVM implementation of LibSVM, using 10-fold cross validation. The $C$ parameter was selected from $\{10^{-3},10^{-2},\dotsc,10^{2},10^{3}\}$ by 10-fold cross validation on the training folds. Moreover, for WL and WL-OA we additionally selected the number of iterations from $\{0,\dots,5\}$.

2 Results for Graph Classification

Table 1 compares the performance of DiffPool to these state-of-the-art graph classification baselines. These results provide positive answers to our motivating questions Q1 and Q2: we observe that our DiffPool approach obtains the highest average performance among all pooling approaches for GNNs, improves upon the base GraphSage architecture by an average of 6.27%, and achieves state-of-the-art results on 4 out of 5 benchmarks. Interestingly, our simplified model variant, DiffPool-Det, achieves state-of-the-art performance on the Collab benchmark. This is because many collaboration graphs in Collab show only single-layer community structures, which can be captured well with a pre-computed graph clustering algorithm. One caveat is that, despite the significant performance improvement, DiffPool can be unstable to train, and there is significant variation in accuracy across different runs, even with the same hyperparameter setting. We observe that adding the link prediction objective makes training more stable and reduces the standard deviation of accuracy across different runs.

Differentiable Pooling on Structure2Vec. DiffPool can be applied to other GNN architectures besides GraphSage to capture hierarchical structure in the graph data. To further support answering Q1, we also applied DiffPool to Structure2Vec (S2V). We ran experiments using S2V with a three-layer architecture, as reported in prior work. In the first variant, one DiffPool layer is applied after the first layer of S2V, and two more S2V layers are stacked on top of the output of DiffPool. The second variant applies one DiffPool layer after each of the first and second layers of S2V. In both variants, the S2V model is used to compute the embedding matrix, while a GraphSage model is used to compute the assignment matrix.

The results in terms of classification accuracy are summarized in Table 2. We observe that DiffPool significantly improves the performance of S2V on both Enzymes and D&D data sets. Similar performance trends are also observed on other data sets. The results demonstrate that DiffPool is a general strategy to pool over hierarchical structure that can benefit different GNN architectures.

Running time. Although applying DiffPool requires additional computation of an assignment matrix, we observed that DiffPool did not incur substantial additional running time in practice. This is because each DiffPool layer reduces the size of graphs by extracting a coarser representation of the graph, which speeds up the graph convolution operation in the next layer. Concretely, we found that GraphSage with DiffPool was 12× faster than the GraphSage model with Set2Set pooling, while still achieving significantly higher accuracy on all benchmarks.

3 Analysis of Cluster Assignment in DiffPool

Hierarchical cluster structure. To address Q3, we investigated the extent to which DiffPool learns meaningful node clusters by visualizing the cluster assignments in different layers. Figure 2 shows such a visualization of node assignments in the first and second layers on a graph from the Collab data set, where node color indicates its cluster membership. Node cluster membership is determined by taking the argmax of its cluster assignment probabilities. We observe that even when learning cluster assignment based solely on the graph classification objective, DiffPool can still capture the hierarchical community structure. We also observe a significant improvement in membership assignment quality with the auxiliary link prediction objective.
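The argmax step used for the visualization can be written as follows. This is an illustrative plain-Python sketch with a hypothetical function name, where `S` holds one row of soft assignment probabilities per node and the returned index would determine the node's color.

```python
def hard_membership(S):
    """Cluster index for each node: argmax of its row of assignment probabilities.
    Ties are broken in favor of the lowest cluster index."""
    return [max(range(len(row)), key=lambda c: row[c]) for row in S]


S = [[0.7, 0.2, 0.1],
     [0.1, 0.1, 0.8],
     [0.4, 0.5, 0.1]]
print(hard_membership(S))  # [0, 2, 1]
```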

Dense vs. sparse subgraph structure. In addition, we observe that DiffPool learns to collapse nodes into soft clusters in a non-uniform way, with a tendency to collapse densely-connected subgraphs into clusters. Since GNNs can efficiently perform message passing on dense, clique-like subgraphs (due to their small diameters), pooling together nodes in such a dense subgraph is not likely to lead to any loss of structural information. This intuitively explains why collapsing dense subgraphs is a useful pooling strategy for DiffPool. In contrast, sparse subgraphs may contain many interesting structures, including path-, cycle- and tree-like structures, and given the high diameter induced by sparsity, GNN message passing may fail to capture these structures. Thus, by separately pooling distinct parts of a sparse subgraph, DiffPool can learn to capture the meaningful structures present in sparse graph regions (e.g., as in Figure 2).
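The diameter argument can be made concrete with a short plain-Python sketch (function names are our own). Each message-passing round moves information one hop, so a GNN needs as many rounds as a subgraph's diameter before every node in it can be influenced by every other node. A breadth-first search shows that a six-node clique has diameter 1, whereas a six-node path has diameter 5.

```python
from collections import deque


def eccentricity(adj, src):
    """Largest BFS distance from src in a connected graph given as adjacency lists."""
    dist = {src: 0}
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return max(dist.values())


def diameter(adj):
    """Longest shortest-path length over all node pairs."""
    return max(eccentricity(adj, s) for s in adj)


def clique(n):
    return {i: [j for j in range(n) if j != i] for i in range(n)}


def path(n):
    return {i: [j for j in (i - 1, i + 1) if 0 <= j < n] for i in range(n)}


print(diameter(clique(6)))  # 1
print(diameter(path(6)))    # 5
```

A single message-passing round already mixes information across the whole clique, which is consistent with the observation that collapsing dense subgraphs loses little structure.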

Assignment for nodes with similar representations. Since the assignment network computes the soft cluster assignment based on features of input nodes and their neighbors, nodes with both similar input features and similar neighborhood structure will have similar cluster assignments. In fact, one can construct synthetic cases where two nodes, although far apart, have exactly the same neighborhood structure and the same features for themselves and all their neighbors. In this case the pooling network is forced to assign them to the same cluster, which differs from the concept of pooling in other architectures such as image ConvNets. In some cases we do observe that disconnected nodes are pooled together.
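The following plain-Python sketch illustrates this limitation with a hypothetical one-layer pooling GNN, $\operatorname{softmax}(\operatorname{ReLU}((A+I)XW))$; the weights are arbitrary illustrative values. The graph consists of two disconnected three-node paths with matching features, so node $i$ and node $i+3$ are indistinguishable to the GNN and receive identical assignment rows, even though no path connects them.

```python
import math


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def simple_gnn(A, X, W):
    """Hypothetical one-layer GNN: ReLU((A + I) X W)."""
    n = len(A)
    A_tilde = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
    return [[max(0.0, v) for v in row] for row in matmul(matmul(A_tilde, X), W)]


def row_softmax(M):
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Two disconnected paths 0 - 1 - 2 and 3 - 4 - 5.
n = 6
A = [[0] * n for _ in range(n)]
for a, b in [(0, 1), (1, 2), (3, 4), (4, 5)]:
    A[a][b] = A[b][a] = 1
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] * 2  # rows 3-5 repeat rows 0-2
W_pool = [[0.8, -0.3, 0.1], [-0.4, 0.6, 0.2]]
S = row_softmax(simple_gnn(A, X, W_pool))
for i in range(3):
    print(S[i] == S[i + 3])  # True: identical assignment rows
```

Because the two components never exchange messages, the only way the pooling GNN could separate node $i$ from node $i+3$ is through distinguishing input features, which is why the identifiability assumption matters.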

In practice we rely on an identifiability assumption similar to Theorem 1 in GraphSAGE, where nodes are identifiable via their features. This holds in many real datasets. However, some chemistry molecular graph datasets contain many nodes that are structurally similar, and the assignment network is observed to pool together nodes that are far away. The auxiliary link prediction objective is also observed to help discourage nodes that are far away from being pooled together. Furthermore, it is possible to use more sophisticated GNN aggregation functions, such as high-order moments, to distinguish nodes that are similar in structure and feature space. The overall framework remains unchanged.

Sensitivity of the Pre-defined Maximum Number of Clusters. We found that the assignment varies according to the depth of the network and $C$, the maximum number of clusters. With larger $C$, the pooling GNN can model more complex hierarchical structure. The trade-off is that very large $C$ results in more noise and less efficiency. Although the value of $C$ is a pre-defined parameter, the pooling network learns to use an appropriate number of clusters through end-to-end training. In particular, some clusters might not be used by the assignment matrix: the column corresponding to an unused cluster has low values for all nodes. This is observed in Figure 2(c), where nodes are assigned predominantly to 3 clusters.
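One way to inspect this behavior is to look for columns of the assignment matrix that never receive much probability mass. The plain-Python sketch below does so; the function name, the example matrix, and the threshold of 0.1 are hypothetical choices for illustration, not values from the paper.

```python
def unused_clusters(S, threshold=0.1):
    """Indices of clusters whose column in S stays below `threshold` for every node.
    The default threshold is an illustrative choice."""
    num_clusters = len(S[0])
    return [c for c in range(num_clusters) if max(row[c] for row in S) < threshold]


# Hypothetical soft assignments with C = 4 available clusters.
S = [[0.80, 0.10, 0.05, 0.05],
     [0.10, 0.85, 0.03, 0.02],
     [0.05, 0.10, 0.80, 0.05]]
print(unused_clusters(S))  # [3]
```

Here three of the four available clusters carry the assignments, mirroring the kind of pattern described for Figure 2(c); in practice the threshold would need to be chosen relative to $C$ and the sharpness of the learned assignments.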

Conclusion

We introduced a differentiable pooling method for GNNs that is able to extract the complex hierarchical structure of real-world graphs. By using the proposed pooling layer in conjunction with existing GNN models, we achieved new state-of-the-art results on several graph classification benchmarks. Interesting future directions include learning hard cluster assignments to further reduce computational cost in higher layers while also ensuring differentiability, and applying the hierarchical pooling method to other downstream tasks that require modeling of the entire graph structure.

Acknowledgement

This research has been supported in part by DARPA SIMPLEX, Stanford Data Science Initiative, Huawei, JD and Chan Zuckerberg Biohub. Christopher Morris is funded by the German Science Foundation (DFG) within the Collaborative Research Center SFB 876 “Providing Information by Resource-Constrained Data Analysis”, project A6 “Resource-efficient Graph Mining”. The authors also thank Marinka Zitnik for help in visualizing the high-level illustration of the proposed methods.

References