Bilinear Attention Networks

Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang

Introduction

Machine learning for computer vision and natural language processing accelerates the advancement of artificial intelligence. Since vision and natural language are the major modalities of human interaction, understanding and reasoning over visual and natural language information has become a key challenge. For instance, visual question answering involves a vision-language cross-grounding problem. A machine is expected to answer questions such as "who is wearing glasses?", "is the umbrella upside down?", or "how many children are in the bed?" by exploiting visually grounded information.

For this reason, visual attention-based models have succeeded in multimodal learning tasks by identifying selective regions in a spatial map of an image defined by the model. Textual attention can also be considered along with visual attention. The attention mechanism of co-attention networks concurrently infers visual and textual attention distributions, one for each modality. Co-attention networks thus selectively attend to question words in addition to a subset of image regions. However, co-attention neglects the interaction between words and visual regions to avoid increasing computational complexity.

In this paper, we extend the idea of co-attention into bilinear attention, which considers every pair of multimodal channels, e.g., the pairs of question words and image regions. If the given question involves multiple visual concepts represented by multiple words, inference using a visual attention distribution for each word can exploit relevant information better than inference using a single compressed attention distribution.

From this background, we propose bilinear attention networks (BAN) to use a bilinear attention distribution on top of low-rank bilinear pooling. Notice that the BAN exploits bilinear interactions between two groups of input channels, while low-rank bilinear pooling extracts the joint representations for each pair of channels. Furthermore, we propose a variant of multimodal residual networks (MRN) to efficiently utilize the multiple bilinear attention maps of the BAN, unlike previous works where multiple attention maps are used by concatenating the attended features. Because the proposed residual learning method for BAN uses residual summations instead of concatenation, it can learn up to an eight-glimpse BAN in a parameter-efficient and performance-effective way. For an overview of the two-glimpse BAN, please refer to Figure 1.

Our main contributions are as follows.

We propose the bilinear attention networks (BAN) to learn and use bilinear attention distributions, on top of the low-rank bilinear pooling technique.

We propose a variant of multimodal residual networks (MRN) to efficiently utilize the multiple bilinear attention maps generated by our model. Unlike previous works, our method successfully utilizes up to 8 attention maps.

Finally, we validate our proposed method on a large and highly competitive dataset, VQA 2.0. Our model achieves a new state of the art while maintaining simplicity of model structure. Moreover, we evaluate the visual grounding of the bilinear attention map on Flickr30k Entities, outperforming previous methods, along with a 25.37% improvement in inference speed obtained by processing multi-channel inputs.

Low-rank bilinear pooling

We first review low-rank bilinear pooling and its application to attention networks, which use a single-channel input (question vector) to combine the other multi-channel input (image features) into a single-channel intermediate representation (attended feature).

Low-rank bilinear model. Previous works proposed a low-rank bilinear model to reduce the rank of the bilinear weight matrix $\mathbf{W}_i$ for regularization. To this end, $\mathbf{W}_i$ is replaced with the product of two smaller matrices $\mathbf{U}_i\mathbf{V}_i^T$, where $\mathbf{U}_i\in\mathds{R}^{N\times d}$ and $\mathbf{V}_i\in\mathds{R}^{M\times d}$. As a result, the rank of $\mathbf{W}_i$ is at most $d\leq\min(N,M)$. For the scalar output $f_i$ (bias terms are omitted without loss of generality):

$$f_i = \mathbf{x}^T\mathbf{W}_i\mathbf{y} \approx \mathbf{x}^T\mathbf{U}_i\mathbf{V}_i^T\mathbf{y} = \mathds{1}^T(\mathbf{U}_i^T\mathbf{x}\circ\mathbf{V}_i^T\mathbf{y}) \tag{1}$$

where $\mathds{1}\in\mathds{R}^{d}$ is a vector of ones and $\circ$ denotes the Hadamard product (element-wise multiplication).
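The low-rank identity described here, that $\mathbf{x}^T\mathbf{U}_i\mathbf{V}_i^T\mathbf{y}$ equals $\mathds{1}^T(\mathbf{U}_i^T\mathbf{x}\circ\mathbf{V}_i^T\mathbf{y})$, can be checked numerically. The following is a minimal NumPy sketch; the sizes, the random inputs, and the function names are illustrative assumptions, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d = 6, 5, 3  # illustrative sizes with d <= min(N, M)

x = rng.standard_normal(N)
y = rng.standard_normal(M)
U = rng.standard_normal((N, d))
V = rng.standard_normal((M, d))


def bilinear_full(x, y, U, V):
    """x^T (U V^T) y with the rank-limited weight matrix formed explicitly."""
    W = U @ V.T  # shape (N, M), rank at most d
    return x @ W @ y


def bilinear_low_rank(x, y, U, V):
    """1^T (U^T x * V^T y): the same value without ever forming W."""
    return np.sum((U.T @ x) * (V.T @ y))


f_full = bilinear_full(x, y, U, V)
f_low = bilinear_low_rank(x, y, U, V)
rank_W = np.linalg.matrix_rank(U @ V.T)
```

The second form needs only two matrix-vector products and one element-wise product, which is what makes the factorization attractive when $N$ and $M$ are large.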

Low-rank bilinear pooling. For a vector output $\mathbf{f}$, a pooling matrix $\mathbf{P}$ is introduced:

$$\mathbf{f} = \mathbf{P}^T(\mathbf{U}^T\mathbf{x}\circ\mathbf{V}^T\mathbf{y}) \tag{2}$$

where $\mathbf{P}\in\mathds{R}^{d\times c}$, $\mathbf{U}\in\mathds{R}^{N\times d}$, and $\mathbf{V}\in\mathds{R}^{M\times d}$. Introducing $\mathbf{P}$ for a vector output $\mathbf{f}\in\mathds{R}^{c}$ allows $\mathbf{U}$ and $\mathbf{V}$ to be two-dimensional tensors, significantly reducing the number of parameters.
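To make the parameter saving concrete, the sketch below computes the pooled vector $\mathbf{P}^T(\mathbf{U}^T\mathbf{x}\circ\mathbf{V}^T\mathbf{y})$ and compares parameter counts against keeping a separate pair $(\mathbf{U}_i,\mathbf{V}_i)$ for each of the $c$ outputs. The function names and the toy sizes are assumptions for illustration only.

```python
import numpy as np


def low_rank_bilinear_pooling(x, y, U, V, P):
    """Return P^T (U^T x * V^T y), a vector of length c."""
    return P.T @ ((U.T @ x) * (V.T @ y))


def param_counts(N, M, d, c):
    """Parameters with per-output factors vs. shared U, V plus pooling P."""
    per_output = c * d * (N + M)   # U_i and V_i for each of the c outputs
    shared = d * (N + M) + d * c   # one U, one V, and the pooling matrix P
    return per_output, shared


rng = np.random.default_rng(0)
N, M, d, c = 8, 7, 4, 3  # illustrative sizes
x, y = rng.standard_normal(N), rng.standard_normal(M)
U, V = rng.standard_normal((N, d)), rng.standard_normal((M, d))
P = rng.standard_normal((d, c))

f = low_rank_bilinear_pooling(x, y, U, V, P)
per_output, shared = param_counts(N, M, d, c)
```

The gap between the two counts grows linearly with $c$, since only the $d\times c$ pooling matrix depends on the output size in the shared form.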

Unitary attention networks. Attention provides an efficient mechanism to reduce the input channels by selectively utilizing given information. Assuming a multi-channel input $\mathbf{Y}$ consisting of $\phi=|\{\mathbf{y}_i\}|$ column vectors, we want to get a single channel $\hat{\mathbf{y}}$ from $\mathbf{Y}$ using the weights $\{\alpha_i\}$:

$$\hat{\mathbf{y}} = \sum_i \alpha_i\mathbf{y}_i \tag{3}$$

where the attention distribution $\alpha$ is obtained with low-rank bilinear pooling as the output of a softmax function:

$$\alpha := \mathrm{softmax}\big(\mathbf{P}^T((\mathbf{U}^T\mathbf{x}\cdot\mathds{1}^T)\circ(\mathbf{V}^T\mathbf{Y}))\big) \tag{4}$$

where $\alpha\in\mathds{R}^{G\times\phi}$, $\mathbf{P}\in\mathds{R}^{d\times G}$, $\mathbf{U}\in\mathds{R}^{N\times d}$, $\mathbf{x}\in\mathds{R}^{N}$, $\mathds{1}\in\mathds{R}^{\phi}$, $\mathbf{V}\in\mathds{R}^{M\times d}$, and $\mathbf{Y}\in\mathds{R}^{M\times\phi}$. If $G>1$, multiple glimpses (a.k.a. attention heads) are used, and $\hat{\mathbf{y}}=\bigparallel_{g=1}^{G}\sum_{i}\alpha_{g,i}\mathbf{y}_{i}$ is the concatenation of the attended outputs. Finally, the two single-channel inputs $\mathbf{x}$ and $\hat{\mathbf{y}}$ can be used to get the joint representation using another low-rank bilinear pooling for a classifier.
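A minimal NumPy sketch of unitary attention with $G$ glimpses follows, using the shapes listed above. It assumes the softmax is taken over the $\phi$ channels separately for each glimpse; the helper names and toy sizes are illustrative and not taken from any released implementation.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def unitary_attention(x, Y, U, V, P):
    """Return (alpha, y_hat).

    alpha: (G, phi), each row a distribution over the phi channels of Y.
    y_hat: (G * M,), the concatenation of the G attended vectors.
    """
    phi = Y.shape[1]
    joint = np.outer(U.T @ x, np.ones(phi)) * (V.T @ Y)  # (d, phi)
    alpha = softmax(P.T @ joint, axis=1)                 # (G, phi)
    attended = alpha @ Y.T                               # (G, M)
    return alpha, attended.reshape(-1)


rng = np.random.default_rng(0)
N, M, d, G, phi = 6, 5, 4, 2, 7  # illustrative sizes
x = rng.standard_normal(N)
Y = rng.standard_normal((M, phi))
U, V = rng.standard_normal((N, d)), rng.standard_normal((M, d))
P = rng.standard_normal((d, G))

alpha, y_hat = unitary_attention(x, Y, U, V, P)
```

Note that the attended output grows with $G$ because the glimpses are concatenated, which is the cost the residual variant later avoids.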

Bilinear attention networks

We generalize a bilinear model for two multi-channel inputs, $\mathbf{X}\in\mathds{R}^{N\times\rho}$ and $\mathbf{Y}\in\mathds{R}^{M\times\phi}$, where $\rho=|\{\mathbf{x}_{i}\}|$ and $\phi=|\{\mathbf{y}_{j}\}|$ are the numbers of channels of the two inputs, respectively. To reduce both input channels simultaneously, we introduce a bilinear attention map $\mathcal{A}\in\mathds{R}^{\rho\times\phi}$ as follows:

$$\mathbf{f}'_{k} = (\mathbf{X}^T\mathbf{U}')_{k}^T\,\mathcal{A}\,(\mathbf{Y}^T\mathbf{V}')_{k} \tag{5}$$

which can be rewritten as

$$\mathbf{f}'_{k} = \sum_{i=1}^{\rho}\sum_{j=1}^{\phi}\mathcal{A}_{i,j}\,\mathbf{X}_{i}^T(\mathbf{U}'_{k}\mathbf{V}_{k}^{\prime T})\mathbf{Y}_{j} \tag{6}$$

where $\mathbf{U}'\in\mathds{R}^{N\times K}$ and $\mathbf{V}'\in\mathds{R}^{M\times K}$, $\mathbf{X}_{i}$ and $\mathbf{Y}_{j}$ denote the $i$-th channel (column) of input $\mathbf{X}$ and the $j$-th channel (column) of input $\mathbf{Y}$, respectively, $\mathbf{U}'_{k}$ and $\mathbf{V}'_{k}$ denote the $k$-th columns of the $\mathbf{U}'$ and $\mathbf{V}'$ matrices, respectively, and $\mathcal{A}_{i,j}$ denotes the element in the $i$-th row and the $j$-th column of $\mathcal{A}$. Notice that, for each pair of channels, the rank-1 bilinear representation of two feature vectors is modeled in $\mathbf{X}_{i}^T(\mathbf{U}'_{k}\mathbf{V}_{k}^{\prime T})\mathbf{Y}_{j}$ of Equation 6 (eventually at most rank-$K$ bilinear pooling for $\mathbf{f}'\in\mathds{R}^{K}$). Then, the bilinear joint representation is $\mathbf{f}=\mathbf{P}^T\mathbf{f}'$, where $\mathbf{f}\in\mathds{R}^{C}$ and $\mathbf{P}\in\mathds{R}^{K\times C}$. For convenience, we define the bilinear attention networks as a function of two multi-channel inputs parameterized by a bilinear attention map as follows:

$$\mathbf{f} = \mathrm{BAN}(\mathbf{X},\mathbf{Y};\mathcal{A}) \tag{7}$$
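The two views of the intermediate representation, a bilinear form with $\mathcal{A}$ in the middle and a sum over all channel pairs of rank-1 bilinear terms $\mathbf{X}_i^T(\mathbf{U}'_k\mathbf{V}_k^{\prime T})\mathbf{Y}_j$ weighted by $\mathcal{A}_{i,j}$, can be compared directly. The sketch below is an illustration under assumed toy sizes and a uniform attention map; the function names are hypothetical.

```python
import numpy as np


def ban_matrix_form(X, Y, A, U, V):
    """f'_k = (X^T U')_k^T A (Y^T V')_k for every k, via one einsum."""
    XU = X.T @ U  # (rho, K)
    YV = Y.T @ V  # (phi, K)
    return np.einsum("ik,ij,jk->k", XU, A, YV)


def ban_pairwise_form(X, Y, A, U, V):
    """Sum over channel pairs of A_ij * X_i^T (U'_k V'_k^T) Y_j."""
    rho, phi, K = X.shape[1], Y.shape[1], U.shape[1]
    f = np.zeros(K)
    for k in range(K):
        W_k = np.outer(U[:, k], V[:, k])  # rank-1 weight, shape (N, M)
        for i in range(rho):
            for j in range(phi):
                f[k] += A[i, j] * (X[:, i] @ W_k @ Y[:, j])
    return f


rng = np.random.default_rng(0)
N, M, K, C, rho, phi = 6, 7, 5, 4, 3, 4  # illustrative sizes
X, Y = rng.standard_normal((N, rho)), rng.standard_normal((M, phi))
U, V = rng.standard_normal((N, K)), rng.standard_normal((M, K))
P = rng.standard_normal((K, C))
A = np.full((rho, phi), 1.0 / (rho * phi))  # placeholder attention map

f_prime_fast = ban_matrix_form(X, Y, A, U, V)
f_prime_slow = ban_pairwise_form(X, Y, A, U, V)
f = P.T @ f_prime_fast  # joint representation of length C
```

The matrix form never materializes the $N\times M$ rank-1 weights, so it is the one a practical implementation would use; the pairwise form is included only to show what is being computed.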

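Bilinear attention map. Now, we want to get the attention map similarly to Equation 4. Using the Hadamard product and matrix-matrix multiplication, the attention map $\mathcal{A}$ is defined as:

$$\mathcal{A} := \mathrm{softmax}\big(((\mathds{1}\cdot\mathbf{p}^T)\circ\mathbf{X}^T\mathbf{U})\,\mathbf{V}^T\mathbf{Y}\big) \tag{8}$$

where $\mathds{1}\in\mathds{R}^{\rho}$ and $\mathbf{p}\in\mathds{R}^{K'}$. The multiple bilinear attention maps can be extended as follows:

$$\mathcal{A}_{g} := \mathrm{softmax}\big(((\mathds{1}\cdot\mathbf{p}_{g}^T)\circ\mathbf{X}^T\mathbf{U})\,\mathbf{V}^T\mathbf{Y}\big) \tag{9}$$

where the parameters of $\mathbf{U}$ and $\mathbf{V}$ are shared across glimpses, whereas $\mathbf{p}_{g}$ is not, and $g$ denotes the glimpse index.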
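The sketch below computes one attention map per glimpse with shared $\mathbf{U}$ and $\mathbf{V}$ and a separate vector $\mathbf{p}_g$, so that each logit is $\mathbf{p}_g^T((\mathbf{U}^T\mathbf{X}_i)\circ(\mathbf{V}^T\mathbf{Y}_j))$. It assumes the softmax normalizes jointly over all $\rho\times\phi$ entries of a map; the text here does not state the normalization axis, so treat that choice, along with the names and sizes, as an assumption.

```python
import numpy as np


def bilinear_attention_maps(X, Y, U, V, p):
    """Return attention maps of shape (G, rho, phi).

    p has shape (G, K): one vector per glimpse; U and V are shared.
    """
    XU = X.T @ U  # (rho, K)
    VY = V.T @ Y  # (K, phi)
    # logits[g, i, j] = sum_k p[g, k] * XU[i, k] * VY[k, j]
    logits = np.einsum("gk,ik,kj->gij", p, XU, VY)
    flat = logits.reshape(p.shape[0], -1)
    flat = flat - flat.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(flat)
    # Assumption: one softmax over all rho * phi entries of each map.
    return (e / e.sum(axis=1, keepdims=True)).reshape(logits.shape)


rng = np.random.default_rng(0)
N, M, K, G, rho, phi = 6, 7, 5, 2, 4, 5  # illustrative sizes
X, Y = rng.standard_normal((N, rho)), rng.standard_normal((M, phi))
U, V = rng.standard_normal((N, K)), rng.standard_normal((M, K))
p = rng.standard_normal((G, K))

maps = bilinear_attention_maps(X, Y, U, V, p)
```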

Residual learning of attention. Inspired by multimodal residual networks (MRN) from Kim et al., we propose a variant of MRN to integrate the joint representations from the multiple bilinear attention maps. The $(i+1)$-th output is defined as:

$$\mathbf{f}_{i+1} = \mathrm{BAN}_{i}(\mathbf{f}_{i},\mathbf{Y};\mathcal{A}_{i})\cdot\mathds{1}^T + \mathbf{f}_{i} \tag{10}$$

where $\mathbf{f}_{0}=\mathbf{X}$ (if $N=K$) and $\mathds{1}\in\mathds{R}^{\rho}$. Here, the size of $\mathbf{f}_{i}$ stays the same as the size of $\mathbf{X}$ as successive attention maps are processed. To get the logits for a classifier, e.g., a two-layer MLP, we sum over the channel dimension of the last output $\mathbf{f}_{G}$, where $G$ is the number of glimpses.
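The residual scheme can be sketched as a loop in which each glimpse produces one joint vector, broadcasts it over the $\rho$ channels, and adds it to the running representation. This is a minimal sketch assuming $N=K=C$ so that shapes match, with the attention maps supplied as inputs; function names and sizes are illustrative.

```python
import numpy as np


def ban(X, Y, A, U, V, P):
    """Joint vector P^T f' with f'_k = (X^T U)_k^T A (Y^T V)_k."""
    f_prime = np.einsum("ik,ij,jk->k", X.T @ U, A, Y.T @ V)
    return P.T @ f_prime


def residual_ban(X, Y, attention_maps, params):
    """Apply f_{i+1} = BAN_i(f_i, Y; A_i) 1^T + f_i, then sum over channels."""
    ones = np.ones(X.shape[1])
    f = X
    for A, (U, V, P) in zip(attention_maps, params):
        f = np.outer(ban(f, Y, A, U, V, P), ones) + f  # shape stays (N, rho)
    return f.sum(axis=1)  # input to the classifier, shape (N,)


rng = np.random.default_rng(0)
N, M, rho, phi, G = 5, 6, 3, 4, 3  # illustrative sizes, with K = C = N
X, Y = rng.standard_normal((N, rho)), rng.standard_normal((M, phi))
params = [
    (rng.standard_normal((N, N)), rng.standard_normal((M, N)),
     0.1 * rng.standard_normal((N, N)))
    for _ in range(G)
]
maps = [np.full((rho, phi), 1.0 / (rho * phi)) for _ in range(G)]

logits_in = residual_ban(X, Y, maps, params)
```

Because the update is a sum rather than a concatenation, the classifier input keeps length $N$ regardless of how many glimpses are stacked.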

Time complexity. When we assume that the number of input channels is smaller than the feature sizes, $M\geq N\geq K\gg\phi\geq\rho$, the time complexity of the BAN for a single-glimpse model is $\mathcal{O}(KM\phi)$, the same as in the case of one multi-channel input, because the BAN consists of matrix chain multiplication and exploits the low-rank factorization in low-rank bilinear pooling.
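A rough multiply-add count for the matrix-chain evaluation of one glimpse illustrates why the $KM\phi$ term dominates under the stated ordering of sizes. The concrete sizes below are hypothetical and chosen only to satisfy $M\geq N\geq K\gg\phi\geq\rho$; the breakdown into three products is an assumption about how the chain is evaluated.

```python
def ban_multiply_adds(N, M, K, rho, phi):
    """Approximate multiply-add counts per step of a single-glimpse BAN."""
    return {
        "X^T U": rho * N * K,               # project the rho channels of X
        "Y^T V": phi * M * K,               # project the phi channels of Y
        "bilinear form": K * rho * phi,     # (X^T U)_k^T A (Y^T V)_k for all k
    }


# Hypothetical sizes satisfying M >= N >= K >> phi >= rho.
costs = ban_multiply_adds(N=1024, M=1024, K=1024, rho=14, phi=36)
dominant = max(costs, key=costs.get)
```

Under these sizes the projection of $\mathbf{Y}$ is the largest term, and the step that actually touches every pair of channels is the cheapest of the three.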

Related works

Multimodal factorized bilinear pooling. Yu et al. extend low-rank bilinear pooling using a rank greater than one. They remove the projection matrix $\mathbf{P}$; instead, $d$ in Equation 2 is replaced with a much smaller $k$, while $\mathbf{U}$ and $\mathbf{V}$ become three-dimensional tensors. However, this generalization was not effective for BAN, at least in our experimental setting. Please see BAN-1+MFB in Figure 2b, where the performance is not significantly improved over that of BAN-1. Furthermore, the peak GPU memory consumption is larger due to its model structure, which hinders the use of a multiple-glimpse BAN.

Experiments

Flickr30k Entities. To evaluate visual grounding by the bilinear attention maps, we use Flickr30k Entities, which consists of 31,783 images and 244,035 annotations. Multiple entities (phrases) in a sentence for an image are mapped to boxes on the image to indicate the correspondences between them. The task is to localize the corresponding box for each entity. In this way, the visual grounding of textual information is quantitatively measured. Following the evaluation metric, a prediction for a given entity is correct if the predicted box has an intersection over union (IoU) greater than or equal to 0.5 with one of the ground-truth boxes. This metric is called Recall@1. If K predictions are permitted to find at least one correct box, it is called Recall@K. We report Recall@1, 5, and 10 to compare with state-of-the-art methods (R@K in Table 4). The upper bound of performance depends on the performance of object detection if the detector proposes the candidate boxes for the prediction.
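The metric can be sketched in a few lines of plain Python. The box format `(x1, y1, x2, y2)`, the function names, and the toy boxes are assumptions for illustration; this is not the official evaluation script.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def is_hit(ranked_boxes, gt_boxes, k, threshold=0.5):
    """True if any of the top-k predictions overlaps a ground-truth box enough."""
    return any(iou(p, g) >= threshold for p in ranked_boxes[:k] for g in gt_boxes)


def recall_at_k(predictions, ground_truths, k):
    """Fraction of entities localized correctly within their top-k boxes."""
    hits = sum(is_hit(p, g, k) for p, g in zip(predictions, ground_truths))
    return hits / len(predictions)


# Two toy entities: the first is right at rank 1, the second only at rank 2.
predictions = [
    [(0, 0, 2, 2)],
    [(10, 10, 12, 12), (0, 0, 4, 4)],
]
ground_truths = [
    [(0, 0, 2, 2)],
    [(0, 0, 4, 5)],
]
r1 = recall_at_k(predictions, ground_truths, k=1)
r2 = recall_at_k(predictions, ground_truths, k=2)
```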

Preprocessing

Question embedding. For VQA, we get a question embedding $\mathbf{X}^{T}\in\mathds{R}^{14\times N}$ using GloVe word embeddings and the outputs of a Gated Recurrent Unit (GRU) at every time step, up to the first 14 tokens, following the previous work. Questions shorter than 14 words are end-padded with zero vectors. For Flickr30k Entities, we use the full length of the sentences (82 is the maximum) to get all entities. We mark the token positions at the end of each annotated phrase. Then, we select a subset of the output channels of the GRU using these positions, so that the number of channels equals the number of entities in a sentence. The word embeddings and the GRU are fine-tuned in training.
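The two preprocessing steps described here, end-padding to 14 tokens and selecting GRU outputs at phrase-end positions, amount to simple array operations. The sketch below uses NumPy arrays as stand-ins for the embedding and GRU outputs; the function names and the toy shapes are hypothetical.

```python
import numpy as np


def pad_or_truncate(embeddings, max_len=14):
    """Keep the first max_len token vectors and end-pad with zero vectors.

    embeddings has shape (T, E); the result has shape (max_len, E).
    """
    T, E = embeddings.shape
    out = np.zeros((max_len, E), dtype=embeddings.dtype)
    out[: min(T, max_len)] = embeddings[:max_len]
    return out


def select_entity_channels(gru_outputs, phrase_end_positions):
    """Pick one GRU output per annotated phrase, at the phrase's last token.

    gru_outputs has shape (T, N); the result has one row per entity.
    """
    return gru_outputs[np.asarray(phrase_end_positions)]


short_question = np.ones((3, 4))              # 3 tokens, toy embedding size 4
padded = pad_or_truncate(short_question)      # shape (14, 4)

gru_outputs = np.arange(12.0).reshape(6, 2)   # 6 time steps, toy hidden size 2
entities = select_entity_channels(gru_outputs, [1, 4])  # 2 entities
```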

Nonlinearity and classifier

Nonlinearity. We use ReLU to give nonlinearity to BAN:

Hyperparameters and regularization

Regularization. For the test split of VQA, both the train and validation splits are used for training. We augment the training data with a subset of the Visual Genome dataset, following the procedure of previous works. Accordingly, we adjust the model capacity by increasing all of $N$, $C$, and $K$ to 1,280, and $G=8$ glimpses are used. For Flickr30k Entities, we use the same test split as the previous methods, without additional hyperparameter tuning from the VQA experiments.

VQA results and discussions

Comparison with other attention methods. Unitary attention has an architecture similar to that of Kim et al., where a question embedding vector is used to calculate the attentional weights for multiple image features of an image. Co-attention uses the same mechanism as Yu et al., similar to Lu et al. and Xu and Saenko, where multiple question embeddings are combined into a single embedding vector using a self-attention mechanism, and then unitary visual attention is applied. Table 2 confirms that bilinear attention is significantly better than the other attention methods. Co-attention is slightly better than simple unitary attention. In Figure 2a, co-attention (green) suffers from overfitting more severely than the other methods, while bilinear attention (blue) is more regularized than the others. In Figure 2b, BAN is the most parameter-efficient among the various attention methods. Notice that the four-glimpse BAN utilizes its parameters more parsimoniously than the one-glimpse BAN does.

Residual learning of attention

Ablation study. An interesting property of residual learning is robustness toward arbitrary ablations. To see the relative contributions, we observe the learning curve of validation scores when incremental ablation is performed. First, we train {1,2,4,8,12}-glimpse models using the training split. Then, we evaluate each model on the validation split using only the first $N$ attention maps. Hence, the intermediate representation $\mathbf{f}_{N}$ is fed directly into the classifier instead of $\mathbf{f}_{G}$. As shown in Figure 2c, the accuracy gain of the first glimpse is the highest, and the gain then decreases smoothly as the number of used glimpses increases.

Entropy of attention. We analyze the information entropy of the attention distributions in a four-glimpse BAN. As shown in Figure 2d, the mean entropy of each attention map on the validation split converges to a different level. This result is repeatedly observed in models with other numbers of glimpses. We speculate that the multiple attention maps do not contribute equally, as in voting by committee, but instead perform residual learning through multi-step attention. We argue that this is a novel observation for stacked attention networks that use residual learning.
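For reference, the entropy of one attention map and its mean over a set of examples can be computed as below. The sketch assumes Shannon entropy in nats over all entries of a map that sums to one; the unit and the function names are assumptions, since the text does not specify them.

```python
import math


def attention_entropy(attention_map):
    """Shannon entropy (nats) of a map given as nested lists summing to 1."""
    return -sum(p * math.log(p) for row in attention_map for p in row if p > 0)


def mean_entropy(attention_maps):
    """Average entropy over a collection of maps, e.g. one per example."""
    return sum(attention_entropy(a) for a in attention_maps) / len(attention_maps)


uniform = [[1 / 6] * 3 for _ in range(2)]   # spread evenly over 2 x 3 pairs
peaked = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # all mass on one pair

h_uniform = attention_entropy(uniform)
h_peaked = attention_entropy(peaked)
h_mean = mean_entropy([uniform, peaked])
```

A uniform map over $\rho\times\phi$ pairs attains the maximum $\log(\rho\phi)$, while a map concentrated on a single pair has zero entropy, so different converged levels indicate glimpses of different sharpness.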

Qualitative analysis

The visualization for a two-glimpse BAN is shown in Figure 3. The question is “what color are the pants of the guy skateboarding”. The question word and the content words (what, pants, guy, and skateboarding) and the skateboarder’s pants in the image are attended. Notice that box 2 (orange) captures the sitting man’s pants at the bottom.

Flickr30k Entities results and discussions

To examine the capability of the bilinear attention map to capture vision-language interactions, we conduct experiments on Flickr30k Entities. Our experiments show that BAN outperforms the previous state of the art on the phrase localization task by a large margin of 4.48%, at a high inference speed.

Performance. In Table 4, we compare with previous approaches. Using our bilinear attention map to predict the boxes for the phrase entities in a sentence achieves a new state of the art with 69.69% for Recall@1. This result is remarkable considering that BAN does not use any additional features such as box size, color, segmentation, or pose estimation. Note that both Query-Adaptive RCNN and our off-the-shelf object detector are based on Faster RCNN and pre-trained on Visual Genome. Unlike Query-Adaptive RCNN, the parameters of our object detector are fixed, and the detector is only used to extract 10-100 visual features and the corresponding box proposals.

Type. In Table 6 (included in Appendix), we report the results for each type of Flickr30k Entities. Notice that clothing and body parts are significantly improved to 74.95% and 47.23%, respectively.

Speed. Faster inference is achieved by taking advantage of multi-channel inputs in our BAN. Unlike previous methods, BAN is able to infer multiple entities in a sentence at once, since they can be prepared as a multi-channel input. Therefore, the number of forward passes needed for inference is significantly decreased. In our experiment, BAN takes 0.67 ms/entity, whereas the setting that uses a single entity per example takes 0.84 ms/entity, a 25.37% improvement. We emphasize that this property is novel to our model, which considers every interaction among vision-language multi-channel inputs.
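The reported percentage follows from the two timings when the gain is measured relative to the faster (multi-entity) setting. The short check below makes the arithmetic explicit; the variable names are illustrative.

```python
ban_ms_per_entity = 0.67      # multi-channel input, several entities per pass
single_ms_per_entity = 0.84   # one entity per forward pass

# Gain measured relative to the BAN timing (equivalently, throughput ratio - 1).
improvement_pct = (single_ms_per_entity - ban_ms_per_entity) / ban_ms_per_entity * 100

# For comparison: the same gap as a reduction relative to the slower setting.
time_saved_pct = (single_ms_per_entity - ban_ms_per_entity) / single_ms_per_entity * 100
```

The first quantity rounds to 25.37, matching the figure in the text; expressed as a reduction in time per entity, the same gap is about 20.24%.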

Visualization. Figure 4 shows three examples from the test split of Flickr30k Entities. Entities that have visual properties, for instance, a yellow tennis suit and white tennis shoes in Figure 4a, and a denim shirt in Figure 4b, are correctly localized. However, a relatively small object (e.g., a cigarette in Figure 4b) and an entity that requires semantic inference (e.g., a male conductor in Figure 4c) are localized incorrectly.

Conclusions

BAN gracefully extends unitary attention networks by exploiting bilinear attention maps, where the joint representations of multimodal multi-channel inputs are extracted using low-rank bilinear pooling. Although BAN considers every pair of multimodal input channels, the computational cost remains on the same order of magnitude, since BAN consists of matrix chain multiplication for efficient computation. The proposed residual learning of attention efficiently uses up to eight bilinear attention maps while keeping the size of the intermediate features constant. We believe our BAN offers a new opportunity to learn richer joint representations for multimodal multi-channel inputs, which appear in many real-world problems.

We would like to thank Kyoung-Woon On, Bohyung Han, Hyeonwoo Noh, Sungeun Hong, Jaesun Park, and Yongseok Choi for helpful comments and discussion. Jin-Hwa Kim was supported by 2017 Google Ph.D. Fellowship in Machine Learning and Ph.D. Completion Scholarship from College of Humanities, Seoul National University. This work was funded by the Korea government (IITP-2017-0-01772-VTT, IITP-R0126-16-1072-SW.StarLab, 2018-0-00622-RMI, KEIT-10060086-RISF). The part of computing resources used in this study was generously shared by Standigm Inc.


A Variants of BAN

We append a computed 300-dimensional word embedding to each 300-dimensional GloVe word embedding. The computation is as follows: 1) we choose two arbitrary words $w_{i}$ and $w_{j}$ from each question that can be found in the VQA and Visual Genome datasets or from each caption in the MS COCO dataset. 2) we increase the value of $\mathbf{A}_{i,j}$ by one, where $\mathbf{A}\in\mathds{R}^{V^{\prime}\times V^{\prime}}$ is an association matrix initialized with zeros. Notice that $i$ and $j$ can index words outside the model vocabulary $V$; the size of the vocabulary in this computation is denoted by $V^{\prime}$. 3) to penalize highly frequent words, each row of $\mathbf{A}$ is divided by the number of sentences (questions or captions) that contain the corresponding word. 4) each row is normalized by the sum of its elements. 5) we calculate $\mathbf{W}^{\prime}=\mathbf{A}\cdot\mathbf{W}$, where $\mathbf{W}\in\mathds{R}^{V^{\prime}\times E}$ is a GloVe word embedding matrix and $E$ is the size of the word embedding, i.e., 300. Therefore, $\mathbf{W}^{\prime}\in\mathds{R}^{V^{\prime}\times E}$ stands for the mixed word embeddings of semantically close words. 6) finally, we select the $V$ rows of $\mathbf{W}^{\prime}$ corresponding to the vocabulary in our model and append these rows to the previous word embeddings, which gives 600-dimensional word embeddings in total. The input size of the GRU is increased to 600 to match these word embeddings. These word embeddings are fine-tuned.
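The six steps above can be sketched on a toy corpus as follows. This is an illustration under several assumptions that the text leaves open: every ordered pair of distinct words in a sentence is counted once, the model vocabulary is taken to be the first $V$ indices, and words are represented directly by integer indices. The function name and sizes are hypothetical.

```python
import itertools
import numpy as np


def mixed_embeddings(sentences, W, model_vocab_size):
    """Return (augmented embeddings of shape (V, 2E), normalized matrix A).

    sentences: lists of word indices in [0, V'); W: (V', E) base embeddings.
    """
    v_prime = W.shape[0]
    A = np.zeros((v_prime, v_prime))
    doc_freq = np.zeros(v_prime)
    for tokens in sentences:
        words = sorted(set(tokens))
        doc_freq[words] += 1                        # sentences containing each word
        for i, j in itertools.permutations(words, 2):
            A[i, j] += 1                            # steps 1-2: co-occurrence counts
    seen = doc_freq > 0
    # Step 3: divide each row by its word's sentence count. As literally
    # described, this per-row scaling is cancelled by the row normalization
    # in step 4; it is kept here only to mirror the text.
    A[seen] = A[seen] / doc_freq[seen][:, None]
    row_sums = A.sum(axis=1, keepdims=True)
    A = np.divide(A, row_sums, out=np.zeros_like(A), where=row_sums > 0)  # step 4
    W_mixed = A @ W                                 # step 5
    V = model_vocab_size
    return np.concatenate([W[:V], W_mixed[:V]], axis=1), A  # step 6


rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3))          # toy V' = 4 and E = 3 instead of 300
sentences = [[0, 1, 2], [1, 2, 3], [0, 3]]
augmented, A = mixed_embeddings(sentences, W, model_vocab_size=3)
```

Each row of the normalized matrix is a distribution over co-occurring words, so the appended half of every embedding is a weighted average of the base embeddings of that word's neighbors.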

As a result, this variant significantly improves the performance to 66.03 ($\pm$0.12), compared with 65.72 ($\pm$0.11), which is obtained by augmenting with the same 300-dimensional GloVe word embeddings (so the number of parameters is controlled). In this experiment, we use a four-glimpse BAN and evaluate on the validation split. The standard deviation is calculated from three randomly initialized models, and the means are reported. The result on the test-dev split can be found in Table 3 as BAN+Glove.

A.2 Integrating counting module

The counting module is proposed to improve the performance related to counting tasks. This module is a neural network component that produces a dense representation from the spatial information of detected objects, i.e., the top-left and bottom-right positions of the $\phi$ proposed objects (rectangles), denoted by $\mathbf{S}\in\mathds{R}^{4\times\phi}$. The interface of the counting module is defined as:

The BAN integrated with the counting module is defined as:

As a result, this variant significantly improves the counting performance from 54.92 ($\pm$0.30) to 58.21 ($\pm$0.49), while the overall performance improves from 65.81 ($\pm$0.09) to 66.01 ($\pm$0.14) in a controlled experiment using a vanilla four-glimpse BAN. The definition of the subset of counting questions comes from the previous work. The result on the test-dev split can be found in Table 3 as BAN+Glove+Counter; note that the previous embedding variant is also applied in this model.

A.3 Integrating multimodal factorized bilinear (MFB) pooling

This generalization was not effective for BAN. In Figure 2b, the performance of BAN-1+MFB is not significantly different from that of BAN-1. Furthermore, a larger $K^{\prime}$ increases the peak GPU memory consumption, which hinders the use of multiple glimpses for the BAN.