Self-Attention Graph Pooling

Junhyun Lee, Inyeop Lee, Jaewoo Kang

Introduction

The advent of deep learning has led to extensive improvements in technology used to recognize and utilize patterns in data (LeCun et al., 2015). In particular, convolutional neural networks (CNNs) successfully leverage the properties of data such as images, speech, and video on Euclidean domains (grid structure) (Hinton et al., 2012; Krizhevsky et al., 2012; He et al., 2016; Karpathy et al., 2014). CNNs consist of convolutional layers and downsampling (pooling) layers. The convolutional and pooling layers exploit the shift-invariance (also known as stationarity) and compositionality of grid-structured data (Simoncelli & Olshausen, 2001; Bronstein et al., 2017). As a result, CNNs perform well with a relatively small number of parameters.

In many fields, however, a large amount of data lies in non-Euclidean domains, typically in the form of graphs. For example, social networks, biological networks, and molecular structures can be represented by the nodes and edges of graphs (Lazer et al., 2009; Davidson et al., 2002; Duvenaud et al., 2015). Consequently, there have been many attempts to extend CNNs to the non-Euclidean domain. Most previous studies have redefined the convolution and pooling layers to process graph data.

To define graph convolution, studies have used spectral (Bruna et al., 2014; Henaff et al., 2015; Defferrard et al., 2016; Kipf & Welling, 2016) and non-spectral (Monti et al., 2017; Hamilton et al., 2017; Xu et al., 2018a; Veličković et al., 2018; Morris et al., 2018) methods. Graph convolution has achieved outstanding performance in a variety of fields, including recommender systems (van den Berg et al., 2017; Yao & Li, 2018; Monti et al., 2017), chemical research (You et al., 2018; Zitnik et al., 2018), and natural language processing (Bastings et al., 2017; Peng et al., 2018; Yao et al., 2018), as well as in many other tasks reported by Zhou et al.

There are fewer methods for graph pooling than for graph convolution. Earlier studies adopted pooling methods that consider only graph topology (Defferrard et al., 2016; Rhee et al., 2018). With growing interest in graph pooling, several improved methods have been proposed (Dai et al., 2016; Duvenaud et al., 2015; Gilmer et al., 2017b; Zhang et al., 2018b). They utilize node features to obtain a smaller graph representation. Recently, Ying et al., Gao & Ji, and Cangea et al. proposed innovative pooling methods that can learn hierarchical representations of graphs. These methods allow Graph Neural Networks (GNNs) to obtain scaled-down graphs after pooling in an end-to-end fashion.

However, the above pooling methods have room for improvement. For example, the differentiable hierarchical pooling method of Ying et al. has quadratic storage complexity, and its number of parameters depends on the number of nodes. Gao & Ji and Cangea et al. addressed the complexity issue, but their methods do not take graph topology into account.

Here, we propose SAGPool, a Self-Attention Graph Pooling method for GNNs in the context of hierarchical graph pooling. Our method can learn hierarchical representations in an end-to-end fashion using relatively few parameters. The self-attention mechanism is used to distinguish between the nodes that should be dropped and the nodes that should be retained. Because the self-attention mechanism uses graph convolution to calculate attention scores, both node features and graph topology are taken into account. In short, SAGPool combines the advantages of the previous methods and is the first method to use self-attention for graph pooling and achieve high performance. The code is available on GitHub: https://github.com/inyeoplee77/SAGPool

Related Work

GNNs have drawn considerable attention due to their state-of-the-art performance on tasks in the graph domain. Studies on GNNs focus on extending the convolution and pooling operations, which are the main components of CNNs, to graphs.

Convolution operations on graphs can be defined in either the spectral or the non-spectral domain. Spectral approaches redefine the convolution operation in the Fourier domain, using spectral filters based on the graph Laplacian. Kipf & Welling proposed a layer-wise propagation rule that simplifies the Chebyshev-expansion approximation of spectral filters (Defferrard et al., 2016). Non-spectral approaches instead aim to define a convolution operation that works directly on graphs: rather than defining convolution in the Fourier domain, they generally let each central node aggregate features from its adjacent nodes when its features are passed to the next layer. Hamilton et al. proposed GraphSAGE, which learns node embeddings through sampling and aggregation. While GraphSAGE operates on a fixed-size neighborhood, the Graph Attention Network (GAT) (Veličković et al., 2018), based on attention mechanisms (Bahdanau et al., 2014), computes node representations over entire neighborhoods. Both approaches have improved performance on graph-related tasks.

2.2 Graph Pooling

Pooling layers enable CNN models to reduce the number of parameters by scaling down the size of representations, and thus to avoid overfitting. To generalize CNNs to graphs, a pooling method for GNNs is necessary. Graph pooling methods can be grouped into three categories: topology-based, global, and hierarchical pooling.

Topology-based pooling Earlier works used graph coarsening algorithms rather than neural networks. Spectral clustering algorithms use eigendecomposition to obtain coarsened graphs. However, alternatives were needed because of the time complexity of eigendecomposition. Graclus (Dhillon et al., 2007) computes clustered versions of given graphs without eigenvectors, exploiting the mathematical equivalence between a general spectral clustering objective and a weighted kernel k-means objective. Even recent GNN models (Defferrard et al., 2016; Rhee et al., 2018) employ Graclus as a pooling module.

Global pooling Unlike the previous methods, global pooling methods consider graph features. Global pooling methods use summation or neural networks to pool all the node representations in each layer. Because global pooling methods collect all the representations, graphs with different structures can be processed. Gilmer et al. viewed GNNs as message passing schemes and proposed a general framework for graph classification in which entire-graph representations are obtained with the Set2Set (Vinyals et al., 2015) method. SortPool (Zhang et al., 2018b) sorts node embeddings according to the structural roles of a graph and feeds the sorted embeddings to the next layers.
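A minimal sketch of the two global pooling styles described above, assuming node embeddings are given as plain Python lists with one row per node. `sum_pool` collapses any number of nodes into a single vector, and `sort_pool` is a SortPool-style step that orders the nodes and keeps exactly K rows, so both outputs have a fixed size regardless of the input graph. Both function names are hypothetical, and the tie-breaking rule in `sort_pool` (compare the preceding channels when the last channel ties) is an assumption of this sketch.

```python
from typing import List

Matrix = List[List[float]]  # one row per node, one column per feature channel


def sum_pool(h: Matrix) -> List[float]:
    """Global sum pooling: add every node's feature vector into one graph vector."""
    return [sum(column) for column in zip(*h)]


def sort_pool(h: Matrix, k: int) -> Matrix:
    """SortPool-style step: order nodes by the last channel (ties broken by the
    preceding channels), largest first, then truncate or zero-pad to k rows."""
    ordered = sorted(h, key=lambda row: tuple(reversed(row)), reverse=True)
    kept = [list(row) for row in ordered[:k]]
    width = len(h[0])  # assumes a non-empty graph
    kept += [[0.0] * width for _ in range(k - len(kept))]
    return kept


h = [[1.0, 0.25], [0.5, 0.75], [2.0, 0.5]]
print(sum_pool(h))      # [3.5, 1.5]
print(sort_pool(h, 2))  # [[0.5, 0.75], [2.0, 0.5]]
```

Zero-padding keeps the output shape fixed when a graph has fewer than K nodes, which is what lets the sorted rows feed a fixed-size downstream layer.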

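Hierarchical pooling Ying et al. proposed DiffPool, a differentiable pooling method that uses a GNN to learn a soft cluster assignment matrix $S^{(l)}$ in each layer $l$ and uses it to coarsen the graph:

$$S^{(l)} = \mathrm{softmax}\big(\mathrm{GNN}_l(A^{(l)}, X^{(l)})\big), \quad A^{(l+1)} = S^{(l)\top} A^{(l)} S^{(l)} \tag{1}$$

where $X$ denotes the node feature matrix and $A$ is the adjacency matrix.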

Cangea et al. utilized gPool (Gao & Ji, 2019) and achieved performance comparable to that of DiffPool. gPool requires a storage complexity of $\mathcal{O}(|V|+|E|)$, whereas DiffPool requires $\mathcal{O}(k|V|^{2})$, where $V$, $E$, and $k$ denote vertices, edges, and pooling ratio, respectively. gPool uses a learnable vector $p$ to calculate projection scores and then uses the scores to select the top-ranked nodes. Projection scores are obtained by the dot product between $p$ and the features of all the nodes. The scores indicate how much of each node's information can be retained. The following equation roughly describes the pooling procedure in gPool.

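$$y = \frac{Xp}{\lVert p \rVert}, \quad \mathrm{idx} = \text{top-rank}(y, \lceil kN \rceil), \quad A' = A_{\mathrm{idx},\mathrm{idx}} \tag{2}$$

where $N$ is the number of nodes and top-rank returns the indices of the $\lceil kN \rceil$ largest scores. As Equation (2) shows, the graph topology does not affect the projection scores: $y$ depends only on $X$ and $p$.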
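To make the point about topology concrete, here is a small sketch of the gPool selection step in Equation (2), written with plain Python lists. `gpool_select` is a hypothetical name, and the feature gating that gPool applies to the retained nodes is deliberately omitted because Equation (2) omits it too. Running it with two different adjacency matrices shows that the selected indices do not change; only the pooled adjacency does.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def gpool_select(x: Matrix, adj: Matrix, p: List[float],
                 ratio: float) -> Tuple[List[int], List[float], Matrix]:
    """Score nodes by projecting X onto p, then keep the top ceil(ratio * N)."""
    norm = math.sqrt(sum(v * v for v in p))
    # y_i = <x_i, p> / ||p||; adj is never consulted while scoring.
    y = [sum(a * b for a, b in zip(row, p)) / norm for row in x]
    n_keep = math.ceil(ratio * len(x))
    idx = sorted(range(len(x)), key=lambda i: y[i], reverse=True)[:n_keep]
    # A' = A[idx, idx]: keep only the edges among the selected nodes.
    adj_out = [[adj[i][j] for j in idx] for i in idx]
    return idx, y, adj_out


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [0.5, 0.0]]
p = [3.0, 4.0]  # ||p|| = 5
ring = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]  # cycle 0-1-2-3-0
star = [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]  # node 0 is the hub

idx_ring, _, a_ring = gpool_select(x, ring, p, 0.5)
idx_star, _, a_star = gpool_select(x, star, p, 0.5)
print(idx_ring, idx_star)  # [2, 1] [2, 1]
```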

To further improve graph pooling, we propose SAGPool, which uses both features and topology to yield hierarchical representations with reasonable time and space complexity.

Proposed Method

The key point of SAGPool is that it uses a GNN to provide self-attention scores. In Section 3.1, we describe the mechanism of SAGPool and its variants. Model architectures for the evaluations are described in Section 3.2. The SAGPool layer and the model architectures are illustrated in Figure 1 and Figure 2, respectively.

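Node selection The nodes with the highest self-attention scores $Z$ are retained:

$$\mathrm{idx} = \text{top-rank}(Z, \lceil kN \rceil), \quad Z_{mask} = Z_{\mathrm{idx}} \tag{4}$$

where top-rank is the function that returns the indices of the top $\lceil kN \rceil$ values, $k$ is the pooling ratio, $N$ is the number of nodes, $\cdot_{\mathrm{idx}}$ is an indexing operation, and $Z_{mask}$ is the feature attention mask.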
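The node-selection step can be sketched in a few lines, assuming one attention score per node in a plain list. `top_rank` and `node_selection` are hypothetical names for this illustration, and breaking ties by original node order is an assumption (Python's sort is stable). The ceiling matters: with $N = 7$ and $k = 0.5$, four nodes are kept, and any positive ratio keeps at least one node.

```python
import math
from typing import List, Tuple


def top_rank(z: List[float], ratio: float) -> List[int]:
    """Indices of the ceil(ratio * N) largest scores, highest first (stable on ties)."""
    n_keep = math.ceil(ratio * len(z))
    return sorted(range(len(z)), key=lambda i: z[i], reverse=True)[:n_keep]


def node_selection(z: List[float], ratio: float) -> Tuple[List[int], List[float]]:
    """idx = top-rank(Z, ceil(kN)); Z_mask = Z[idx]."""
    idx = top_rank(z, ratio)
    z_mask = [z[i] for i in idx]
    return idx, z_mask


z = [0.1, 0.9, -0.3, 0.5, 0.7, 0.0, 0.2]  # one attention score per node (N = 7)
idx, z_mask = node_selection(z, 0.5)    # ceil(0.5 * 7) = 4 nodes are kept
print(idx)     # [1, 4, 3, 6]
print(z_mask)  # [0.9, 0.7, 0.5, 0.2]
```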

Graph pooling An input graph is processed by the operation notated as masking in Figure 1:

$$X' = X_{\mathrm{idx},:}, \quad X_{out} = X' \odot Z_{mask}, \quad A_{out} = A_{\mathrm{idx},\mathrm{idx}} \tag{5}$$

where $X_{\mathrm{idx},:}$ is the row-wise (i.e., node-wise) indexed feature matrix, $\odot$ is the broadcasted elementwise product, and $A_{\mathrm{idx},\mathrm{idx}}$ is the row-wise and column-wise indexed adjacency matrix. $X_{out}$ and $A_{out}$ are the new feature matrix and the corresponding adjacency matrix, respectively.
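The masking step can be illustrated with a short sketch, assuming the indices and mask come from node selection and that matrices are nested lists. `mask_graph` is a hypothetical name. The example uses a path graph 0-1-2 and drops the middle node; because $A_{\mathrm{idx},\mathrm{idx}}$ only keeps edges among retained nodes, nodes 0 and 2 end up disconnected, which is the expected behaviour of this indexing rather than a bug.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def mask_graph(x: Matrix, adj: Matrix, idx: List[int],
               z_mask: List[float]) -> Tuple[Matrix, Matrix]:
    """X_out = X[idx, :] * Z_mask (broadcast over features); A_out = A[idx, idx]."""
    # Gather the kept rows and scale each row by that node's attention score.
    x_out = [[v * s for v in x[i]] for i, s in zip(idx, z_mask)]
    # Keep only the rows and columns of the retained nodes.
    adj_out = [[adj[i][j] for j in idx] for i in idx]
    return x_out, adj_out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # path graph 0 - 1 - 2
x_out, adj_out = mask_graph(x, adj, idx=[2, 0], z_mask=[0.5, -1.0])
print(x_out)    # [[2.5, 3.0], [-1.0, -2.0]]
print(adj_out)  # [[0, 0], [0, 0]]
```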

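In general form, any kind of GNN can be used to compute the attention scores:

$$Z = \sigma\big(\mathrm{GNN}(X, A)\big) \tag{6}$$

where $\sigma$ is the activation function, $X$ denotes the node feature matrix, and $A$ is the adjacency matrix.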

There are several ways to calculate attention scores using not only adjacent nodes but also multi-hop connected nodes. In Equations (7) and (8), we illustrate two examples that use two-hop connections: augmenting the edges and stacking GNN layers. Adding the square of the adjacency matrix creates edges between two-hop neighbors:

$$Z = \sigma\big(\mathrm{GNN}(X, A + A^{2})\big) \tag{7}$$
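A small worked example of the edge augmentation $A + A^{2}$, in plain Python. `matmul` and `augment_two_hop` are hypothetical helper names. Entry $(i, j)$ of $A^{2}$ counts the length-2 walks from $i$ to $j$, so on the path 0-1-2 the endpoints become connected. Note that $A^{2}$ also puts each node's degree on the diagonal (self-loops) and can produce weights above 1; whether those should be kept or binarized is not specified in the text, so this sketch leaves them as they are.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def augment_two_hop(adj: Matrix) -> Matrix:
    """A + A^2: entry (i, j) of A^2 counts the length-2 walks from i to j."""
    sq = matmul(adj, adj)
    n = len(adj)
    return [[adj[i][j] + sq[i][j] for j in range(n)] for i in range(n)]


path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # 0 - 1 - 2
print(augment_two_hop(path))  # [[1, 1, 1], [1, 2, 1], [1, 1, 1]]
```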

Stacking GNN layers allows two-hop nodes to be aggregated indirectly. In this case, the nonlinearity and the number of parameters of the SAGPool layer increase:

$$Z = \sigma\Big(\mathrm{GNN}_2\big(\sigma(\mathrm{GNN}_1(X, A)), A\big)\Big) \tag{8}$$

Equations (7) and (8) can be extended to multi-hop connections.

Another variant averages multiple attention scores. The average attention score is obtained from $M$ GNNs as follows:

$$Z = \frac{1}{M} \sum_{m=1}^{M} \sigma\big(\mathrm{GNN}_m(X, A)\big) \tag{9}$$

In this paper, the models using Equations (7), (8), and (9) are referred to as SAGPool$_{\text{augmentation}}$, SAGPool$_{\text{serial}}$, and SAGPool$_{\text{parallel}}$, respectively.
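The averaging in Equation (9) can be sketched as follows, treating each GNN as any callable that maps $(X, A)$ to one raw score per node. This is an illustration under assumptions: `parallel_scores` is a hypothetical name, tanh is assumed as $\sigma$, and the two scorers in the example are toy stand-ins rather than real GNNs. The activation is applied to each GNN's output before averaging, as in the equation.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]
Scorer = Callable[[Matrix, Matrix], List[float]]  # stand-in for GNN_m(X, A)


def parallel_scores(x: Matrix, adj: Matrix, gnns: Sequence[Scorer]) -> List[float]:
    """Z = (1/M) * sum_m tanh(GNN_m(X, A)), with tanh assumed as sigma."""
    m = len(gnns)
    activated = [[math.tanh(v) for v in gnn(x, adj)] for gnn in gnns]
    return [sum(column) / m for column in zip(*activated)]


# Two toy scorers (hypothetical): one reads node features, one reads degrees.
def feature_sum(x: Matrix, adj: Matrix) -> List[float]:
    return [sum(row) for row in x]


def degree(x: Matrix, adj: Matrix) -> List[float]:
    return [float(sum(row)) for row in adj]


x = [[0.0], [1.0]]
adj = [[0, 1], [1, 0]]
z = parallel_scores(x, adj, [feature_sum, degree])
print(z)  # [tanh(1) / 2, tanh(1)]
```

With $M = 1$ this reduces to Equation (6), so the parallel variant can be seen as an ensemble over scorers.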

3.2 Model Architecture

According to Lipton & Steinhardt, if numerous modifications are made to a model, it may be difficult to identify which modification contributes to improving performance. For a fair comparison, we adopted the model architectures from Zhang et al. and Cangea et al., and compared the baselines and our method using the same architectures.

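Convolution layer As mentioned in Section 2.1, there are many definitions of graph convolution. Other types of graph convolution may improve performance, but we use the widely used graph convolution proposed by Kipf & Welling for all the models:

$$h^{(l+1)} = \sigma\big(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\big) \tag{10}$$

where $h^{(l)}$ is the node representation in the $l$-th layer, $\tilde{A} = A + I_N$ is the adjacency matrix with self-connections, $\tilde{D}$ is the degree matrix of $\tilde{A}$, $\Theta$ is the weight matrix, and $\sigma$ is the activation function. Equation (10) is the same as Equation (3), except for the dimension of $\Theta$.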
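Equation (10) can be sketched as a dense, unbatched layer in plain Python. This is an illustration, not the authors' PyTorch implementation: `gcn_layer` and `relu` are hypothetical names, ReLU is an assumed default activation, and matrices are nested lists. Since the text says Equation (3) differs only in the shape of $\Theta$, passing a single-column `theta` with `act=math.tanh` returns one score per node, which is how this sketch would produce attention scores (tanh being an assumption here).

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def relu(v: float) -> float:
    return max(v, 0.0)


def gcn_layer(h: Matrix, adj: Matrix, theta: Matrix,
              act: Callable[[float], float] = relu) -> Matrix:
    """act(D~^-1/2 (A + I) D~^-1/2 h Theta), computed densely."""
    n = len(adj)
    # A~ = A + I adds self-loops, so every degree is at least 1 (no division by zero).
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_tilde]
    a_hat = [[inv_sqrt_deg[i] * a_tilde[i][j] * inv_sqrt_deg[j] for j in range(n)]
             for i in range(n)]
    f_in, f_out = len(theta), len(theta[0])
    # h Theta: project every node from f_in to f_out features.
    h_theta = [[sum(h[i][f] * theta[f][o] for f in range(f_in)) for o in range(f_out)]
               for i in range(n)]
    # Propagate: each node takes a degree-normalized sum over itself and its neighbours.
    return [[act(sum(a_hat[i][j] * h_theta[j][o] for j in range(n))) for o in range(f_out)]
            for i in range(n)]


h = [[1.0], [3.0]]
adj = [[0, 1], [1, 0]]                   # a single edge 0 - 1
print(gcn_layer(h, adj, theta=[[2.0]]))  # approximately [[4.0], [4.0]]
scores = gcn_layer(h, adj, theta=[[1.0]], act=math.tanh)  # one score per node
```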

Readout layer Inspired by the JK-net architecture (Xu et al., 2018b), Cangea et al. proposed a readout layer that aggregates node features into a fixed-size representation. The summarized output feature of the readout layer is as follows:

$$s = \frac{1}{N} \sum_{i=1}^{N} x_i \,\Big\|\, \max_{i=1}^{N} x_i \tag{11}$$

where $N$ is the number of nodes, $x_i$ is the feature vector of the $i$-th node, the maximum is taken element-wise, and $\|$ denotes concatenation.
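A minimal sketch of the mean-and-max readout, plus the summation of per-block readouts used by the hierarchical architecture, assuming plain Python lists. `readout` and `hierarchical_summary` are hypothetical names, and the summation assumes every block outputs the same feature width. Because the readout is a fixed-size $2F$ vector, graphs with different numbers of nodes map to vectors of the same length.

```python
from typing import List

Matrix = List[List[float]]


def readout(x: Matrix) -> List[float]:
    """Concatenate the mean and the element-wise max over all nodes."""
    n = len(x)
    mean = [sum(column) / n for column in zip(*x)]
    maximum = [max(column) for column in zip(*x)]
    return mean + maximum  # list concatenation plays the role of ||


def hierarchical_summary(block_outputs: List[Matrix]) -> List[float]:
    """Sum the readouts of every block element-wise (same feature width assumed)."""
    readouts = [readout(x) for x in block_outputs]
    return [sum(values) for values in zip(*readouts)]


print(readout([[1.0, 4.0], [3.0, 2.0]]))  # [2.0, 3.0, 3.0, 4.0]
```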

Global pooling architecture We implemented the global pooling architecture proposed by Zhang et al. As shown in Figure 2, the global pooling architecture consists of three graph convolutional layers whose outputs are concatenated. Node features are aggregated in the readout layer, which follows the pooling layer. The graph feature representations are then passed to the linear layer for classification.

Hierarchical pooling architecture In this setting, we implemented the hierarchical pooling architecture from the recent hierarchical pooling study of Cangea et al. As shown in Figure 2, the architecture consists of three blocks, each of which contains a graph convolutional layer and a graph pooling layer. The output of each block is summarized by the readout layer, and the sum of the readout outputs is fed to the linear layer for classification.

Experiments

We evaluate the global pooling and hierarchical pooling methods on the graph classification task. In Section 4.1, we discuss the datasets used for evaluation. Section 4.3 describes how we train the models. The methods compared in the experiments are introduced in Sections 4.4 and 4.5.

Five datasets with a large number of graphs (more than 1k graphs each) were selected from the benchmark datasets (Kersting et al., 2016). The statistics of the datasets are summarized in Table 1.

D&D (Dobson & Doig, 2003; Shervashidze et al., 2011) contains graphs of protein structures. A node represents an amino acid, and two nodes are connected by an edge if they are less than 6 Å apart. A label denotes whether a protein is an enzyme or a non-enzyme. PROTEINS (Dobson & Doig, 2003; Borgwardt et al., 2005) is also a set of proteins, where nodes are secondary structure elements. Two nodes are connected by an edge if they are neighbors in the amino acid sequence or close to each other in 3D space. NCI (Wale et al., 2008) is a biological dataset used for anticancer activity classification. In this dataset, each graph represents a chemical compound, with nodes and edges representing atoms and chemical bonds, respectively. NCI1 and NCI109 are commonly used as benchmark datasets for graph classification. FRANKENSTEIN (Orsini et al., 2015) is a set of molecular graphs (Costa & Grave, 2010) with continuous-valued node features. A label denotes whether a molecule is a mutagen or a non-mutagen.

4.2 Evaluation of GNNs

The same early stopping criterion and hyperparameter selection strategy are used for all the models to ensure a fair comparison.

4.3 Training Procedures

Shchur et al. demonstrate that different data splits can affect the performance of GNN models. In our experiments, we evaluated the pooling methods over 20 random seeds using 10-fold cross-validation, and the final accuracy of each method on each dataset was computed from the resulting 200 test results. Ten percent of the training data was used for validation during training. We used the Adam optimizer (Kingma & Ba, 2014) and the same early stopping criterion, patience, and hyperparameter selection strategy for both the global pooling architecture and the hierarchical pooling architecture. Training was stopped if the validation loss did not improve for 50 epochs, with a maximum of 100k epochs, as done in (Shchur et al., 2018). The optimal hyperparameters were obtained by grid search; the search ranges are summarized in Table 2.
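The protocol above can be sketched with two small pieces: the seed-by-fold grid that yields 200 test results, and a patience-based stopping rule applied to a sequence of validation losses. `early_stop` is a hypothetical helper that only inspects losses produced elsewhere by training, and reading "did not improve for 50 epochs" as 50 consecutive epochs without a new minimum is an assumption; implementations can differ by one in how they count.

```python
from itertools import product
from typing import Iterable, Tuple


def early_stop(val_losses: Iterable[float], patience: int = 50,
               max_epochs: int = 100_000) -> Tuple[int, int]:
    """Return (last_epoch_run, best_epoch): stop once the validation loss has not
    improved for `patience` consecutive epochs, or when max_epochs is reached."""
    best_loss, best_epoch, epoch = float("inf"), -1, -1
    for epoch, loss in zip(range(max_epochs), val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            break
    return epoch, best_epoch


# 20 random seeds x 10 folds = 200 test results per method and dataset.
runs = list(product(range(20), range(10)))
print(len(runs))  # 200

losses = [5.0, 4.0, 3.0] + [3.5] * 100  # improves until epoch 2, then plateaus
print(early_stop(losses))  # (52, 2)
```

The model weights from `best_epoch`, not from the last epoch run, would normally be used for testing.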

4.4 Baselines

We consider the following four pooling methods as baselines: Set2Set, SortPool, DiffPool, and gPool. DiffPool, gPool, and SAGPool$_h$ were compared using the hierarchical pooling architecture, while Set2Set, SortPool, and SAGPool$_g$ were compared using the global pooling architecture. We used the same hyperparameter search strategy for all the baselines and SAGPool. The hyperparameters are summarized in Table 2.

Set2Set (Vinyals et al., 2015) requires an additional hyperparameter: the number of processing steps for the LSTM (Hochreiter & Schmidhuber, 1997) module. We use 10 processing steps for all the experiments. We assume that the readout layer is unnecessary because the LSTM module produces graph embeddings that are invariant to the order of nodes.

SortPool (Zhang et al., 2018b) is a recent global pooling method that uses sorting for pooling. The number of nodes $K$ is set such that 60% of graphs have more than $K$ nodes. In the global pooling setting, SAGPool$_g$ has the same number of output nodes $K$ as SortPool.
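One way to apply the rule for $K$ is sketched below, assuming the node counts of all graphs in a dataset are available. `sortpool_k` is a hypothetical helper, and this is our reading of the rule rather than the reference SortPool code; with repeated node counts the exact share of graphs above $K$ can deviate from 60%. The ceiling is computed in integer arithmetic to avoid floating-point rounding.

```python
from typing import Sequence


def sortpool_k(node_counts: Sequence[int], percent_above: int = 60) -> int:
    """Pick K so that (about) `percent_above` percent of graphs have more than K nodes."""
    ordered = sorted(node_counts)
    n = len(ordered)
    n_above = -(-n * percent_above // 100)  # ceil(n * percent / 100) in integer arithmetic
    return ordered[max(n - n_above - 1, 0)]


counts = [7, 3, 10, 1, 5, 9, 2, 8, 4, 6]  # node counts of ten toy graphs
k = sortpool_k(counts)
print(k)                           # 4
print(sum(c > k for c in counts))  # 6 of 10 graphs, i.e. 60%
```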

DiffPool (Ying et al., 2018) is the first end-to-end trainable graph pooling method that can produce hierarchical representations of graphs. We did not use batch normalization for DiffPool, since it is unrelated to the pooling method itself. For the hyperparameter search, the pooling ratio ranges from 0.25 to 0.5 for two reasons: in the reference implementation, the cluster size is set to 25% of the maximum number of nodes, and DiffPool$_h$ causes an out-of-memory error when the pooling ratio is larger than 0.5.

gPool (Gao & Ji, 2019) selects top-ranked nodes for pooling, which makes it similar to our method. The comparison between our method and gPool demonstrates that considering topology can help improve performance on the graph classification task.

4.5 Variations of SAGPool

As mentioned in Section 3.1, three kinds of SAGPool variations are used to obtain the attention scores $Z$. In our experiments, we compared the variants on two datasets. First, any kind of GNN can be applied in Equation (6); we compared the three most widely used GNNs (SAGPool$_{\text{Cheb}}$, SAGPool$_{\text{SAGE}}$, SAGPool$_{\text{GAT}}$). Second, we made the following modifications to SAGPool so that it can consider two-hop connections: edge augmentation (SAGPool$_{\text{augmentation}}$) in Equation (7) and a stack of GNN layers (SAGPool$_{\text{serial}}$) in Equation (8). Last, multiple GNNs calculate attention scores, which are averaged to obtain the final attention score (SAGPool$_{\text{parallel}}$). We evaluated $M=2$ and $M=4$ using Equation (9). The results are summarized in Table 4.

4.6 Summary of Results

The results are summarized in Tables 3 and 4. Accuracies and standard deviations are given in percent. Compared with the other global pooling methods, SAGPool generally performs well, and especially well on D&D and PROTEINS. In the experiments, SAGPool outperformed the hierarchical pooling methods on all the datasets. We also compared variants of SAGPool with the hierarchical pooling architecture on two benchmark datasets. The performance of the variants varied, and the results show that SAGPool has the potential to improve performance further. A detailed analysis of the experimental results is provided in the next section.

Analysis

In this section, we provide an analysis of the experimental results. In Section 5.1, we compare global pooling and hierarchical pooling. Section 5.2 explains how SAGPool addresses the shortcomings of gPool. In Sections 5.3 and 5.4, we compare the efficiency of SAGPool with that of DiffPool. We provide an analysis of the SAGPool variants in Section 5.5.

It is difficult to determine whether the global or the hierarchical pooling architecture is universally better for graph classification. Because the global pooling architecture $POOL_g$ (SAGPool$_g$, SortPool$_g$, Set2Set$_g$) minimizes the loss of information, it performs better than the hierarchical pooling architecture $POOL_h$ (SAGPool$_h$, gPool$_h$, DiffPool$_h$) on datasets with fewer nodes (NCI1, NCI109, FRANKENSTEIN). However, $POOL_h$ is more effective on datasets with a large number of nodes (D&D, PROTEINS) because it efficiently extracts useful information from large-scale graphs. Therefore, it is important to use the pooling architecture that best suits the given data. Nonetheless, SAGPool tends to perform well with both architectures.

5.2 Effect of Considering Graph Topology

5.3 Sparse Implementation

Manipulating graph data with sparse matrices is important for GNNs because the adjacency matrix is usually sparse. When graph convolution is computed with a dense matrix, the computational complexity of the multiplication $AX$ is $\mathcal{O}(|V|^{2})$, where $A$ is the adjacency matrix, $X$ is the node feature matrix, and $V$ denotes vertices. Pooling with a dense matrix causes memory-efficiency problems, as noted by Cangea et al. (2018). If a sparse matrix is used for the same operation, however, the computational complexity is reduced to $\mathcal{O}(|E|)$, where $E$ denotes edges. Because SAGPool is a sparse pooling method, it can reduce its computational complexity, unlike DiffPool, which is a dense pooling method. Sparsity also affects space complexity: since SAGPool uses a GNN to obtain attention scores, it requires $\mathcal{O}(|V|+|E|)$ storage for sparse pooling, whereas dense pooling methods need $\mathcal{O}(|V|^{2})$.
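The complexity difference can be seen in a small sketch that computes $AX$ two ways in plain Python: a dense product that visits all $|V|^2$ entries of $A$, and an edge-list product that does one update per stored non-zero. `dense_ax` and `sparse_ax` are hypothetical names for this illustration; a real implementation would use a sparse tensor library, but the arithmetic is the same and both produce identical results.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Edge = Tuple[int, int, float]  # (row i, column j, weight A[i][j])


def dense_ax(adj: Matrix, x: Matrix) -> Matrix:
    """Dense A @ X: touches all |V|^2 entries of A, zeros included."""
    n, f = len(adj), len(x[0])
    return [[sum(adj[i][j] * x[j][c] for j in range(n)) for c in range(f)] for i in range(n)]


def sparse_ax(edges: List[Edge], x: Matrix) -> Matrix:
    """Edge-list A @ X: one update per stored non-zero, i.e. O(|E|) work per feature."""
    out = [[0.0] * len(x[0]) for _ in x]
    for i, j, w in edges:
        for c, value in enumerate(x[j]):
            out[i][c] += w * value
    return out


adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # path graph 0 - 1 - 2
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edges = [(i, j, adj[i][j]) for i in range(3) for j in range(3) if adj[i][j]]
print(len(edges))                                # 4 stored entries instead of 9
print(sparse_ax(edges, x) == dense_ax(adj, x))   # True
```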

5.4 Relation with the Number of Nodes

In DiffPool, the cluster size has to be defined when the model is constructed, because a GNN produces the assignment matrix $S$ as stated in Equation (1). According to the reference implementation, the cluster size has to be proportional to the maximum number of nodes. These requirements can lead to two problems. First, the number of parameters depends on the maximum number of nodes, as shown in Figure 3. Second, it is difficult to choose the right cluster size when the number of nodes varies greatly. For example, in D&D only 10 out of 1178 graphs have more than 1000 nodes, while the maximum number of nodes is 5748 and the minimum is 30. With a pooling ratio of 10%, the cluster size is 574, which enlarges rather than shrinks most graphs after pooling. In SAGPool, on the other hand, the number of parameters is independent of the cluster size, and the cluster size adapts to the number of input nodes.
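The arithmetic in this example can be reproduced with two hypothetical helpers, taking the pooling ratio as an integer percentage to avoid floating-point rounding. Truncation is assumed for DiffPool's cluster count because it reproduces the 574 quoted above; the real reference code may round differently. DiffPool's count is fixed once from the largest graph, whereas SAGPool computes $\lceil kN \rceil$ for each input graph.

```python
def diffpool_clusters(max_nodes: int, percent: int) -> int:
    """DiffPool-style: one cluster count for the whole model, fixed from the largest graph."""
    return max_nodes * percent // 100  # truncation assumed


def sagpool_kept(num_nodes: int, percent: int) -> int:
    """SAGPool-style: ceil(k * N), computed per input graph."""
    return -(-num_nodes * percent // 100)


for n in (30, 1000, 5748):
    print(n, diffpool_clusters(5748, 10), sagpool_kept(n, 10))
# 30 574 3
# 1000 574 100
# 5748 574 575
```

For the smallest graph (30 nodes), DiffPool would map 30 nodes onto 574 clusters, while SAGPool keeps 3 nodes.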

5.5 Comparison of the SAGPool Variants

To investigate the potential of our method, we evaluated SAGPool variants on two datasets. SAGPool can be modified by changing the type of GNN, by considering two-hop connections, or by averaging the attention scores of multiple GNNs. As shown in Table 4, graph classification performance varies depending on the dataset and on the type of GNN used in SAGPool. We used two techniques to consider two-hop connections. The attention scores obtained by two sequential GNN layers (SAGPool$_{\text{serial}}$) reflect the information of two-hop neighbors. The other technique adds the square of the adjacency matrix to itself, resulting in a new adjacency matrix with two-hop connectivity; this matrix can be processed by SAGPool$_{\text{augmentation}}$ without any modification to the SAGPool layer. The information of two-hop neighbors may help improve performance. The last variant of SAGPool averages the attention scores from multiple GNNs. We found that choosing the right $M$ for the dataset can help achieve stable performance.

5.6 Limitations

We retain a fixed percentage (pooling ratio $k$) of nodes to handle input graphs of various sizes, as in previous studies (Gao & Ji, 2019; Cangea et al., 2018). In SAGPool, the pooling ratio cannot be parameterized to find the optimal value for each graph. To address this limitation, we used binary classification to decide which nodes to preserve, but this did not completely solve the issue.

Conclusion

In this paper, we proposed SAGPool, a novel graph pooling method based on self-attention. Our method has the following features: hierarchical pooling, consideration of both node features and graph topology, reasonable complexity, and end-to-end representation learning. SAGPool uses a consistent number of parameters regardless of the input graph size. Extensions of our work may include learnable pooling ratios to obtain optimal cluster sizes for each graph, and studying the effects of multiple attention masks in each pooling layer, where final representations can be derived by aggregating different hierarchical representations. Our experiments were run on an NVIDIA TITAN Xp GPU. We implemented all the baselines and SAGPool using PyTorch (Paszke et al., 2017) and the geometric deep learning extension library provided by Fey & Lenssen.

References