Learning to Compose Dynamic Tree Structures for Visual Contexts
Kaihua Tang, Hanwang Zhang, Baoyuan Wu, Wenhan Luo, Wei Liu
Introduction
Objects are not alone. They are placed in a visual context: a coherent object configuration that arises because objects co-vary with each other. Extensive studies in cognitive science show that our brains inherently exploit visual contexts to understand cluttered visual scenes comprehensively. For example, even though the girl’s leg and the horse are not fully observed in Figure 1, we can still infer “girl riding horse”. Inspired by this, modeling visual contexts is also indispensable in many modern computer vision systems. For example, state-of-the-art CNN architectures capture the context by convolutions of various receptive fields and encode it into a multi-scale feature map pyramid. Such pixel-level visual context (or local context) arguably plays one of the key roles in closing the performance gap in “mid-level” vision between humans and machines, e.g., in R-CNN based object detection, instance segmentation, and FCN based semantic segmentation.
Modeling visual contexts explicitly at the object level has also been shown effective in “high-level” vision tasks such as image captioning and visual Q&A. In fact, the visual context serves as a powerful inductive bias that connects objects in a particular layout for high-level reasoning. For example, the spatial layout of “person” on “horse” is useful for determining the relationship “ride”, which is in turn informative for localizing the “person” if we want to answer “who is riding on the horse?”. However, those works assume that the context is a scene graph, whose detection per se is a high-level task and not yet reliable. Without high-quality scene graphs, we have to use a prior layout structure. As shown in Figure 1, two popular structures are chains and fully-connected graphs, where the context is encoded by sequential models such as bidirectional LSTMs for chains and CRF-RNN for graphs.
However, these two prior structures are sub-optimal. First, chains are oversimplified and may only capture simple spatial information or co-occurrence bias; though fully-connected graphs are complete, they do not discriminate between hierarchical relations, e.g., “helmet affiliated to head”, and parallel relations, e.g., “girl on horse”; in addition, dense connections could also lead to message passing saturation in the subsequent context encoding. Second, visual contexts are inherently content-/task-driven, e.g., object layouts should vary from content to content and from question to question. Therefore, fixed chains and graphs are incompatible with the dynamic nature of visual contexts.
In this paper, we propose a model dubbed VCTree, which pioneers the composition of dynamic tree structures that encode object-level visual context for high-level visual reasoning tasks, such as scene graph generation (SGG) and visual Q&A (VQA). Given a set of object proposals in an image (e.g., obtained from Faster-RCNN), we maintain a trainable task-specific score matrix of the objects, where each entry indicates the contextual validity of an object pair. Then, a maximum spanning tree can be trimmed from the score matrix, e.g., the multi-branch trees shown in Figure 1. This dynamic structure represents a “hard” hierarchical layout bias of which objects should gain more contextual information from others, e.g., objects on the person’s head are most informative given the question “what is on the little girl’s head?”, while the whole body of the person is more important given the question “Is the girl sitting on the horse correctly?”. To avoid the saturation issue caused by an arbitrary number of densely connected children, we further morph the multi-branch trees into equivalent left-child right-sibling binary trees, where the left branches (red) indicate hierarchical relations and the right branches (blue) indicate parallel relations, and then use TreeLSTM to encode the context.
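The left-child right-sibling conversion mentioned above can be sketched in a few lines. This is a minimal illustration, not the authors' code: it assumes a hypothetical representation in which the multi-branch tree is a dict mapping each node index to the ordered list of its children (for example, the order in which children were attached), and returns, for every node, its (left, right) pair in the binary tree.

```python
from typing import Dict, List, Optional, Tuple

def to_lcrs(children: Dict[int, List[int]], root: int) -> Dict[int, Tuple[Optional[int], Optional[int]]]:
    """Convert a multi-branch tree into a left-child right-sibling binary tree.

    children: hypothetical mapping node -> ordered list of its children.
    Returns node -> (left, right), where `left` is the first child
    (hierarchical relation) and `right` is the next sibling (parallel relation).
    """
    binary: Dict[int, Tuple[Optional[int], Optional[int]]] = {}

    def visit(node: int, next_sibling: Optional[int]) -> None:
        kids = children.get(node, [])
        first_child = kids[0] if kids else None
        binary[node] = (first_child, next_sibling)
        for idx, child in enumerate(kids):
            # each child's right branch points to the following sibling, if any
            sibling = kids[idx + 1] if idx + 1 < len(kids) else None
            visit(child, sibling)

    visit(root, None)  # the root has no siblings
    return binary
```

For a root 0 with children 1, 2, 3 and a grandchild 4 under node 1, node 0 keeps only node 1 as its left child, while nodes 2 and 3 become a chain of right branches hanging from node 1; every node ends up with at most two children, which is what lets a binary TreeLSTM replace a Child-Sum one.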
As the above VCTree construction is discrete and non-differentiable, we develop a hybrid learning strategy using REINFORCE for tree structure exploration and supervised learning for context encoding and its subsequent tasks. In particular, the evaluation result (Recall for SGG and Accuracy for VQA) of the supervised task can be exploited as a “critic” function that guides the “action” of tree construction. We evaluate VCTree on two benchmarks: Visual Genome for SGG and VQA2.0 for VQA. For SGG, we achieve a new state-of-the-art on all three standard tasks, i.e., Scene Graph Generation, Scene Graph Classification, and Predicate Classification; for VQA, we achieve competitive single-model performance. In particular, VCTree helps high-level vision models fight against dataset bias. For example, we achieve a 4.1% absolute gain over MOTIFS in the proposed Mean Recall@100 metric for Predicate Classification, and observe a larger improvement on the VQA2.0 balanced pair subset than on the normal validation set. Qualitative results also show that VCTree composes interpretable structures.
Related Work
Visual Context Structures. Despite the consensus on the value of visual contexts, existing context models are diversified into a variety of implicit or explicit approaches. Implicit models directly encode surrounding pixels into multi-scale feature maps, e.g., dilated convolution presents an efficient way to enlarge the receptive field and is applicable to various dense prediction tasks; feature pyramid structures combine low-resolution contextual features with high-resolution detailed features, facilitating object detection with rich semantics. Explicit models incorporate contextual cues through object connections. However, such methods group objects into fixed layouts, i.e., chains or graphs.
Learning to Compose Structures. Learning to compose structures is becoming popular in NLP for sentence representation, e.g., Cho et al. applied a gated recursive convolutional neural network (grConv) to control the bottom-up feature flow for a dynamic structure; Choi et al. combined TreeLSTM with Gumbel-Softmax, allowing task-specific tree structures to be learned automatically from plain text. Yet, only a few works compose visual structures for images. Conventional approaches construct a statistical dependency graph/tree for the entire dataset based on object categories or exemplars. Those statistical methods cannot put the objects of each image in a context as a whole to reason in a content-/task-specific fashion. Socher et al. constructed a bottom-up tree structure to parse images; however, their tree structure learning is supervised, while ours is reinforced and does not require ground-truth trees.
Visual Reasoning Tasks. The Scene Graph Generation (SGG) task is derived from Visual Relationship Detection (VRD). Early work on VRD treats objects as isolated individuals, while SGG considers each image as a whole. Along with the widely used message passing mechanism, a variety of context models have been exploited in SGG to refine local predictions through rich global contexts, making it the best testbed for comparing different context models. Visual Question Answering (VQA), as a high-level task, bridges the gap between computer vision and natural language processing. State-of-the-art VQA models rely on bag-of-object visual attention, which can be considered a trivial context structure. In contrast, we propose to learn a tree context structure that is dynamic with respect to visual content and questions.
Approach
VCTree construction aims to learn a score matrix $S$, which approximates the task-dependent validity between each object pair. Two principles guide the formulation of this matrix: (1) inherent object correlations should be maintained, e.g., “man wears helmet” in Figure 2; (2) task-related object pairs should have higher scores than irrelevant ones, e.g., given the question “what is on the man’s head?”, the “man-helmet” pair should be more important than the “man-motorcycle” and “helmet-motorcycle” pairs. Therefore, we define each element of $S$ as the product of the object correlation $f_{obj}$ and the pairwise task-dependency $f_q$:
$$S_{ij} = f(x_i, x_j, q) = f_{obj}(x_i, x_j) \cdot f_q(x_i, x_j, q), \quad f_{obj}(x_i, x_j) = \sigma\big(\mathrm{MLP}(x_i, x_j)\big), \quad f_q(x_i, x_j, q) = \sigma\big(g(x_i, q)\big) \cdot \sigma\big(g(x_j, q)\big) \qquad (1)$$
where $x_i$ and $x_j$ are the features of objects $i$ and $j$; $\sigma$ is the sigmoid function; $q$ is the task feature, e.g., the question feature encoded by a GRU in VQA; MLP is a multi-layer perceptron; $g(x_i, q)$ is the object-task correlation in VQA, which will be introduced in Section 3.4. In SGG, the entire $f_q$ is set to 1, as we assume that each object pair contributes equally without the question prior. We pretrain $f_{obj}$ on Visual Genome to obtain a reasonable binary prior of whether two objects are related. Yet, such a pretrained model is not perfect due to the lack of a coherent graph-level constraint or question prior, so it will be further fine-tuned in Section 3.5.
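The product form of the score matrix can be illustrated with NumPy. This is a sketch under stated assumptions, not the authors' implementation: the pairwise MLP is a hypothetical two-layer ReLU network on the concatenated pair features, the object-task logits `g` are passed in precomputed, and the matrix is symmetrized by averaging because the text only says that $S$ is treated as symmetric.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def pair_mlp(x_i, x_j, W1, b1, w2, b2):
    """Hypothetical 2-layer MLP on the concatenated pair [x_i; x_j] -> scalar logit."""
    h = np.maximum(0.0, W1 @ np.concatenate([x_i, x_j]) + b1)  # ReLU hidden layer
    return float(w2 @ h + b2)

def score_matrix(X, mlp_params, g=None):
    """S_ij = f_obj(x_i, x_j) * f_q(x_i, x_j, q).

    X: (n, d) object features; mlp_params: (W1, b1, w2, b2) for pair_mlp.
    g: optional (n,) object-task logits g(x_i, q); None is the SGG setting, where f_q = 1.
    """
    n = X.shape[0]
    S = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:  # no self-pairs; the diagonal stays 0
                S[i, j] = sigmoid(pair_mlp(X[i], X[j], *mlp_params))
    S = 0.5 * (S + S.T)  # assumption: symmetrize so S is an undirected adjacency matrix
    if g is not None:
        gate = sigmoid(np.asarray(g, dtype=float))
        S = S * np.outer(gate, gate)  # f_q = sigma(g(x_i, q)) * sigma(g(x_j, q))
    return S
```

Because each question gate lies in (0, 1), the task term can only down-weight pairs, so pairs involving question-irrelevant objects lose validity relative to the question-agnostic prior $f_{obj}$.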
Considering $S$ as a symmetric adjacency matrix, we can obtain a maximum spanning tree using Prim’s algorithm, with the root (source node) chosen as $\arg\max_i \sum_{j \neq i} S_{ij}$. In a nutshell, as illustrated in Figure 3, we construct the tree recursively by connecting the node from the pool to the tree node with which it has the highest validity. Note that during the tree structure exploration in Section 3.5, each $t$-th step of the above tree construction is sampled from all possible choices in a multinomial distribution with probability in proportion to the validity. The resultant tree is multi-branch and is merely a sparse graph with only one kind of connection, which is still unable to discriminate between hierarchical and parallel relations in the subsequent context encoding. To this end, we convert the multi-branch tree into an equivalent binary tree, i.e., VCTree, by changing non-leftmost edges into right branches, as in Figure 1. In this fashion, the right branches (blue) indicate parallel contexts, and the left ones (red) indicate hierarchical contexts. Such a binary tree structure achieves significant improvements in our SGG and VQA experiments compared to its multi-branch alternative.
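The tree construction described above, with both its greedy and its sampled variant, can be sketched as follows. This is a minimal NumPy illustration assuming a dense symmetric score matrix `S`; the function name and the (root, parent, order) return format are hypothetical choices for this example, and the quadratic candidate enumeration favors readability over speed.

```python
import numpy as np

def build_tree(S, sample=False, rng=None):
    """Prim's algorithm for a maximum spanning tree over a symmetric score matrix S.

    Returns (root, parent, order): parent[i] is the tree parent of node i (-1 for the
    root) and order lists the nodes in the order they were attached to the tree.
    """
    n = S.shape[0]
    off_diag = S * (1.0 - np.eye(n))               # ignore self-scores S_ii
    root = int(np.argmax(off_diag.sum(axis=1)))    # root = argmax_i sum_{j != i} S_ij
    parent = [-1] * n
    order, pool = [root], set(range(n)) - {root}
    while pool:
        # every edge from a tree node t to a pool node p is a candidate
        cands = [(t, p) for t in order for p in sorted(pool)]
        w = np.array([off_diag[t, p] for t, p in cands], dtype=float)
        if sample:
            # exploration: draw an edge with probability proportional to its validity
            rng = rng if rng is not None else np.random.default_rng()
            probs = w / w.sum() if w.sum() > 0 else np.full(len(w), 1.0 / len(w))
            k = int(rng.choice(len(cands), p=probs))
        else:
            k = int(np.argmax(w))                  # greedy Prim step: most valid edge
        t, p = cands[k]
        parent[p] = t
        order.append(p)
        pool.remove(p)
    return root, parent, order
```

With `sample=False` this is the greedy tree; with `sample=True` each step is a multinomial draw over all tree-to-pool edges, which is the exploration mode used during reinforcement learning. Grouping each node's children in attachment order yields the multi-branch tree that is then converted to a binary one.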
3.2 TreeLSTM Context Encoding
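Given the above constructed VCTree, we adopt BiTreeLSTM as our context encoder:
$$D = \mathrm{BiTreeLSTM}\big(\{z_i\}_{i=1,2,\dots,n}\big) \qquad (2)$$
where $z_i$ is the input node feature, which will be specified in each task, and $D = [d_1, d_2, \dots, d_n]$ is the encoded object-level visual context. Each $d_i = [\overrightarrow{h}_i; \overleftarrow{h}_i]$ is the concatenation of the hidden states from both TreeLSTM directions:
$$\overrightarrow{h}_i = \mathrm{TreeLSTM}\big(z_i, \overrightarrow{h}_p\big), \qquad \overleftarrow{h}_i = \mathrm{TreeLSTM}\big(z_i, [\overleftarrow{h}_l; \overleftarrow{h}_r]\big) \qquad (3)$$
where $\rightarrow$ and $\leftarrow$ denote the top-down and bottom-up directions, respectively, and $p$, $l$, $r$ denote the parent, left child, and right child of node $i$.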
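The two traversal orders of the bidirectional encoder can be made concrete with a short NumPy sketch. To keep it brief, simple tanh cells (hypothetical stand-ins, not TreeLSTM cells) are used in both directions; the point is the data flow: the bottom-up pass visits children before their parent, the top-down pass visits the parent first, and each node's output is the concatenation of its two hidden states. The `left`/`right` dicts describe the binary tree after the left-child right-sibling conversion.

```python
import numpy as np

def encode_bidirectional(Z, root, left, right, Wb, Ub, Wt, Ut):
    """Compute d_i = [h_td_i ; h_bu_i] over a binary tree.

    Z: (n, d_in) node inputs z_i; left/right: dict node -> child index or None.
    Wb, Ub / Wt, Ut: weights of the bottom-up / top-down tanh stand-in cells.
    """
    n, hdim = Z.shape[0], Wb.shape[0]
    zero = np.zeros(hdim)
    h_bu = np.zeros((n, hdim))
    h_td = np.zeros((n, hdim))

    def bottom_up(i):
        # post-order: h_bu_i depends on [h_bu_left ; h_bu_right]
        hl = bottom_up(left[i]) if left[i] is not None else zero
        hr = bottom_up(right[i]) if right[i] is not None else zero
        h_bu[i] = np.tanh(Wb @ Z[i] + Ub @ np.concatenate([hl, hr]))
        return h_bu[i]

    def top_down(i, h_parent):
        # pre-order: h_td_i depends on the binary-tree parent's h_td
        h_td[i] = np.tanh(Wt @ Z[i] + Ut @ h_parent)
        for child in (left[i], right[i]):
            if child is not None:
                top_down(child, h_td[i])

    bottom_up(root)
    top_down(root, zero)
    return np.concatenate([h_td, h_bu], axis=1)  # row i is d_i
```

Swapping the tanh cells for the TreeLSTM cells of Appendix A changes only the two update lines; the traversal logic stays the same.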
3.3 Scene Graph Generation Model
Now we detail the implementation of Eq. (2) and how to decode its output for the SGG task, as illustrated in Figure 4.
3.4 Visual Question Answering Model
Now we detail the implementation of Eq. (2) for VQA and illustrate our VQA model in Figure 5.
Finally, we fuse as the final VQA feature and feed it into the softmax classifier.
3.5 Hybrid Learning
Due to the discrete nature of VCTree construction, the score matrix $S$ is not fully differentiable with respect to the end-task loss. Inspired by prior work, we use a hybrid learning strategy that combines reinforcement learning, i.e., policy gradient for the parameters $\theta$ of $f$ in the tree construction, and supervised learning for the remaining parameters. Suppose a layout $T$, i.e., a constructed VCTree, is sampled from $\pi(T \mid I, q; \theta)$, i.e., the construction procedure in Section 3.1, where $I$ is the given image and $q$ is the task, e.g., the question in VQA. To avoid clutter, we drop $I$ and $q$. Then, we define the reinforcement learning loss as:
$$L_r(\theta) = -\mathbb{E}_{T \sim \pi(T;\theta)}\left[r(T)\right]$$
where $L_r$ aims to minimize the negative expected reward $r(T)$, which can be an end-task evaluation metric such as Recall@100 for SGG or Accuracy for VQA. The gradient is then $\nabla_\theta L_r = -\mathbb{E}_{T \sim \pi(T;\theta)}\left[r(T)\,\nabla_\theta \log \pi(T;\theta)\right]$. Since it is impractical to enumerate all possible layouts, we use Monte-Carlo sampling to estimate the gradient:
$$\nabla_\theta L_r \approx -\frac{1}{M}\sum_{m=1}^{M} r(T^m)\,\nabla_\theta \log \pi(T^m;\theta) \qquad (8)$$
where we set $M$ to 1 in our implementation.
To reduce the gradient variance, we apply a self-critic baseline $b = r(\hat{T})$, where $\hat{T}$ is the greedily constructed tree without sampling, so the original reward $r(T^m)$ in Eq. (8) is replaced by $r(T^m) - b$. We observe faster convergence than with a traditional moving-average baseline.
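The self-critical policy gradient can be sketched end to end on a toy parameterization. This is an illustrative sketch only: instead of the MLP parameters of $f$, it assumes the scores are $S = \sigma(\theta)$ for a directly learned logit matrix `theta` (a hypothetical simplification), keeps the root deterministic (so no gradient flows through it), and takes `reward_fn` as a hypothetical callback that scores a tree given as a parent list. Each Prim step picks an edge with probability $S_k / \sum_c S_c$, whose log-derivative is computed in closed form.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_tree(theta, greedy, rng):
    """Prim-style construction from logits theta (n x n).

    Returns the parent list and d log pi(T) / d theta for the chosen edges.
    """
    S = sigmoid(theta)
    np.fill_diagonal(S, 0.0)                          # no self-edges
    n = S.shape[0]
    root = int(np.argmax(S.sum(axis=1)))              # deterministic root, no gradient
    parent = [-1] * n
    in_tree, pool = [root], set(range(n)) - {root}
    grad = np.zeros_like(theta, dtype=float)
    while pool:
        cands = [(t, p) for t in in_tree for p in sorted(pool)]
        w = np.array([S[t, p] for t, p in cands])
        total = w.sum()
        k = int(np.argmax(w)) if greedy else int(rng.choice(len(cands), p=w / total))
        # log pi(step) = log w_k - log sum_c w_c, and dw/dtheta = w * (1 - w)
        for (a, b), s in zip(cands, w):
            grad[a, b] -= s * (1.0 - s) / total
        t, p = cands[k]
        grad[t, p] += 1.0 - w[k]
        parent[p] = t
        in_tree.append(p)
        pool.remove(p)
    return parent, grad

def self_critical_grad(theta, reward_fn, rng, M=1):
    """Monte-Carlo estimate of grad L_r, using the greedy tree's reward as baseline."""
    greedy_parent, _ = sample_tree(theta, greedy=True, rng=rng)
    b = reward_fn(greedy_parent)                      # self-critic baseline b = r(T_hat)
    g = np.zeros_like(theta, dtype=float)
    for _ in range(M):
        parent, grad_logp = sample_tree(theta, greedy=False, rng=rng)
        g -= (reward_fn(parent) - b) * grad_logp      # -(r(T) - b) * grad log pi(T)
    return g / M
```

A gradient-descent step `theta -= lr * self_critical_grad(theta, reward_fn, rng)` raises the probability of sampled trees that beat the greedy tree and lowers it for those that do worse; when a sample scores exactly as well as the greedy tree, it contributes no update, which is where the variance reduction comes from.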
The overall hybrid learning alternates between supervised learning and reinforcement learning: we first train the supervised end task on the pretrained $f_{obj}$; then we fix the end-task model as the reward function to learn our reinforcement policy network; after that, we update the supervised end task with the new $f$. The latter two stages are alternated twice in our model.
Experiments on Scene Graph Generation
Dataset. Visual Genome (VG) is a popular benchmark for SGG. It contains 108,077 images with tens of thousands of unique object and predicate categories, yet most categories have very limited instances. Therefore, previous works proposed various VG splits that remove rare categories. We adopted the most popular split, which selects the top-150 object categories and the top-50 predicate categories by frequency. The entire dataset is divided into a training set and a test set by 70% and 30%, respectively. We further picked 5,000 images from the training set as the validation set for hyper-parameter tuning.
Protocols. We followed three conventional protocols to evaluate our SGG model: (1) Scene Graph Generation (SGGen): given an image, detect object bounding boxes and their categories, and predict their relationships; (2) Scene Graph Classification (SGCls): given ground-truth object bounding boxes in an image, predict the object categories and their relationships; (3) Predicate Classification (PredCls): given the object categories and their bounding boxes in the image, predict their relationships.
Metrics. Since the annotation in VG is incomplete and biased, we followed the conventional Recall@K (R@K, K = 20, 50, 100) as the evaluation metric. However, it is well known that SGG models trained on biased datasets such as VG perform poorly on less frequent categories. To this end, we introduced a balanced metric called Mean Recall (mR@K). It calculates the recall on each predicate category independently and then averages the results, so each category contributes equally. Such a metric reduces the influence of some common yet meaningless predicates, e.g., “on”, “of”, and gives equal attention to infrequent predicates, e.g., “riding”, “carrying”, which are more valuable to high-level reasoning.
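The difference between R@K and mR@K is easiest to see on a toy example. The sketch below is a simplification, not the official evaluation code: it pools all ground-truth triplets of the test set and takes, for each one, a flag saying whether it was matched in the top-K predictions; the usual protocol computes recall per image and averages over images, which is omitted here.

```python
from collections import defaultdict

def recall_and_mean_recall(gt_predicates, hit_flags):
    """gt_predicates: predicate label of each ground-truth triplet (pooled over images).
    hit_flags: parallel list of bools, True if that triplet is matched in the top-K.
    Returns (R@K, mR@K, per-predicate recall).
    """
    hits, totals = defaultdict(int), defaultdict(int)
    for pred, hit in zip(gt_predicates, hit_flags):
        totals[pred] += 1
        hits[pred] += int(hit)
    recall = sum(hits.values()) / sum(totals.values())       # every triplet weighs the same
    per_class = {p: hits[p] / totals[p] for p in totals}
    mean_recall = sum(per_class.values()) / len(per_class)   # every predicate weighs the same
    return recall, mean_recall, per_class
```

With eight “on” triplets that are all recalled and two “riding” triplets of which one is recalled, R@K is 9/10 = 0.9, while mR@K is (1.0 + 0.5) / 2 = 0.75: the frequent predicate hides the weak performance on the rare one in R@K but not in mR@K.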
4.2 Implementation Details
We adopted Faster-RCNN with a VGG backbone to detect object bounding boxes and extract RoI features. Since the performance of SGG highly depends on the underlying detector, we used the same set of detector parameters as previous work for fair comparison. Object correlations $f_{obj}$ in Eq. (1) are pretrained on ground-truth bounding boxes with class-agnostic relationships (i.e., foreground/background relationships), using all possible symmetric pairs without sampling. In SGGen, the top-64 object proposals were selected after non-maximum suppression (NMS) with an IoU threshold of 0.3. We set the background/foreground ratio for predicate classification to 3 and capped the number of training samples at 64 (retaining all foreground pairs if possible). Our model is optimized by SGD with momentum, using learning rate and batch size for supervised learning, and for reinforcement learning.
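The relation-pair sampling rule above admits more than one reading. The sketch below assumes one of them: keep all foreground pairs (randomly subsampled only if they alone exceed the cap), then add background pairs up to three times the foreground count without exceeding 64 pairs in total. The function name and arguments are hypothetical.

```python
import random

def sample_relation_pairs(fg_pairs, bg_pairs, max_samples=64, bg_fg_ratio=3, seed=None):
    """Assumed reading of the rule: all foreground pairs if possible, then up to
    bg_fg_ratio background pairs per foreground pair, max_samples pairs in total."""
    rng = random.Random(seed)
    fg = list(fg_pairs)
    if len(fg) > max_samples:              # too many foreground pairs: subsample them
        fg = rng.sample(fg, max_samples)
    # note: with no foreground pairs this also returns no background pairs;
    # how such images are handled is not specified in the text
    n_bg = min(len(bg_pairs), bg_fg_ratio * len(fg), max_samples - len(fg))
    bg = rng.sample(list(bg_pairs), n_bg)
    return fg, bg
```

For example, an image with 10 foreground pairs yields 10 + 30 = 40 training pairs, while one with 20 foreground pairs hits the cap with 20 + 44 = 64.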
4.3 Ablation Studies
We investigated the influence of different structure construction policies; the results are reported in the bottom half of Table 1. The ablative methods are: (1) Chain: sorting all the objects by $\sum_{j} S_{ij}$ and then constructing a chain, which is different from the left-to-right ordered chain in MOTIFS; (2) Overlap: iteratively constructing a binary tree by selecting the node with the largest number of overlapping objects as the parent, and dividing the remaining nodes into left/right sub-trees by the relative positions of their bounding boxes; (3) Multi-Branch: the maximum spanning tree generated from the score matrix $S$, using Child-Sum TreeLSTM to incorporate context; (4) VCTree-SL: the proposed VCTree trained by supervised learning; (5) VCTree-HL: the complete version of VCTree, trained by hybrid learning for structure exploration as in Section 3.5. Since Multi-Branch is significantly worse than VCTree, as shown below, we did not conduct the hybrid learning experiment on Multi-Branch. We observe that VCTree performs better than the other structures, and it is further improved by hybrid learning for structure exploration.
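The Overlap baseline can be sketched as a short recursion. This is one possible reading, not the authors' code: boxes are (x1, y1, x2, y2) tuples, ties in the overlap count go to the first node in the current subset, and the “relative position” split is assumed to compare each box's horizontal center with the parent's (left if smaller, right otherwise). The output is a nested (node, left_subtree, right_subtree) tuple.

```python
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

def overlaps(a: Box, b: Box) -> bool:
    """True if the two boxes share a region of positive area."""
    return min(a[2], b[2]) > max(a[0], b[0]) and min(a[3], b[3]) > max(a[1], b[1])

def build_overlap_tree(boxes: List[Box], idx: Optional[List[int]] = None):
    """Return a nested (node, left_subtree, right_subtree) tuple, or None if empty."""
    if idx is None:
        idx = list(range(len(boxes)))
    if not idx:
        return None
    # parent = node overlapping the most other nodes in this subset
    counts = [sum(overlaps(boxes[i], boxes[j]) for j in idx if j != i) for i in idx]
    parent = idx[counts.index(max(counts))]
    cx = (boxes[parent][0] + boxes[parent][2]) / 2
    rest = [i for i in idx if i != parent]
    # assumption: split the rest by horizontal center relative to the parent's center
    left = [i for i in rest if (boxes[i][0] + boxes[i][2]) / 2 < cx]
    right = [i for i in rest if (boxes[i][0] + boxes[i][2]) / 2 >= cx]
    return (parent, build_overlap_tree(boxes, left), build_overlap_tree(boxes, right))
```

Unlike VCTree, this structure is fixed by geometry alone: it ignores the learned score matrix and the task, which is exactly the contrast the ablation is designed to measure.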
4.4 Comparisons with State-of-the-Art Methods
Comparing Methods. We compared VCTree with state-of-the-art methods in Table 1: (1) VRD and FREQ are methods that do not use visual contexts. (2) AssocEmbed assembles implicit contextual features with a stacked hourglass backbone. (3) IMP, TFR, MOTIFS, and Graph-RCNN are explicit context models with a variety of structures.
Quantitative Analysis. From Table 1, compared with the previous state-of-the-art MOTIFS, the proposed VCTree achieves the best performance. Interestingly, the Overlap tree and the Multi-Branch tree are better than the other non-tree context models. From Table 2, the proposed VCTree-HL shows larger absolute gains in PredCls under mR@100, which indicates that our model learns non-trivial visual context, i.e., not merely the class distribution bias exploited by FREQ and partially by MOTIFS. Note that MOTIFS is even worse than its FREQ baseline under mR@100.
Qualitative Analysis. To better understand what context is learned by VCTree, we visualized statistics of the left-/right-branch nodes of nodes classified as “street” in Figure 6. The left pie (hierarchical relations) shows that the node categories are long-tailed, i.e., the top-10 categories cover 73% of the instances, while the right pie (parallel relations) is more uniformly distributed. This demonstrates that VCTree successfully captures the two types of context. More qualitative examples of VCTrees and their generated scene graphs can be found in Figure 7. The common errors are generally synonymous labels, e.g., “jeans” vs. “pants” and “man” vs. “person”, and over-interpretation, e.g., the “tail” of the bottom-left “dog” is labeled as “leg” because it appears where a “leg” should be.
Experiments on Visual Q&A
Datasets. We evaluated the proposed VQA model on VQA2.0. Compared with VQA1.0, VQA2.0 has more question-image pairs for training (443,757) and validation (214,354), and all question-answer pairs are balanced by making sure the same question can have different answers. In VQA2.0, the ground-truth accuracy of a candidate answer $a$ is the average of $\min\big(1, \#\{\text{humans who answered } a\}/3\big)$ over all $\binom{10}{9}$ subsets of 9 of the 10 annotators. Question-answer pairs are organized in three answer types, i.e., “Yes/No”, “Number”, and “Other”. There are also 65 question types determined by prefix words, which we used to generate question-guided gates. We also tested our models on a balanced subset of the validation set, called Balanced Pairs, which requires the same question on different images with two different yet perfect (with a 1.0 ground-truth score) answers. Since Balanced Pairs strictly removes question-related bias, it reflects the ability of a context model to distinguish subtle differences between images.
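The VQA2.0 ground-truth accuracy (the average over all 9-annotator subsets of the 10 human answers, with at most 1 point per subset and 1/3 point per matching human) can be checked with a small worked example. This sketch compares answer strings exactly; the official evaluation additionally normalizes answers (punctuation, articles, number words), which is left out here.

```python
from itertools import combinations

def vqa_accuracy(answer, human_answers):
    """Average of min(1, #matches / 3) over all 10-choose-9 subsets of human answers."""
    if len(human_answers) != 10:
        raise ValueError("VQA2.0 provides 10 human answers per question")
    scores = []
    for subset in combinations(human_answers, 9):   # leave one annotator out each time
        matches = sum(a == answer for a in subset)
        scores.append(min(1.0, matches / 3.0))
    return sum(scores) / len(scores)
```

If exactly 3 of the 10 annotators answered “yes”, leaving out one of those 3 gives 2/3 and leaving out any of the other 7 gives 1, so the accuracy of “yes” is (3 × 2/3 + 7 × 1) / 10 = 0.9; an answer given by 4 or more annotators always scores 1.0.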
5.2 Implementation Details
We employed simple text preprocessing for questions and answers, which converts all characters to lower case and removes special characters. Questions were encoded with a vocabulary of size 13,758 without trimming. Answers used a vocabulary of 3,000 selected by frequency. For fair comparison, we used the same bottom-up features as previous methods, which contain 10 to 100 object proposals per image extracted by Faster-RCNN. We used the same Faster-RCNN detector to pretrain $f_{obj}$. Since candidate answers are represented by probabilities rather than one-hot vectors in VQA2.0, we allowed the cross-entropy loss to take soft categories, i.e., the probabilities of the ground-truth candidate answers. We used the Adam optimizer with learning rate and batch size , decayed at a ratio of 0.5 every 20 epochs.
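The soft-category cross-entropy and the step learning-rate decay can be written down directly. This NumPy sketch assumes the soft targets are the per-answer ground-truth scores used as-is (whether they are renormalized to sum to 1 is not stated, so that is left to the caller), and it takes the base learning rate as a parameter because its value is not given above.

```python
import numpy as np

def soft_cross_entropy(logits, soft_targets):
    """-sum_a t_a * log softmax(logits)_a, averaged over the batch.

    logits, soft_targets: (batch, num_answers); targets are soft scores, not one-hot.
    """
    z = logits - logits.max(axis=1, keepdims=True)            # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-(soft_targets * log_probs).sum(axis=1).mean())

def step_lr(base_lr, epoch, gamma=0.5, step=20):
    """Learning rate multiplied by `gamma` every `step` epochs."""
    return base_lr * gamma ** (epoch // step)
```

With a one-hot target this reduces to the ordinary cross-entropy; with soft targets, partially correct answers (e.g., one given by only two annotators) still receive proportional credit instead of being treated as wrong.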
5.3 Ablation Studies
In addition to the five structure construction policies introduced in Section 4.3, we also implemented a fully-connected graph structure using the message passing mechanism. As shown in Table 3, the proposed VCTree-HL outperforms all the context models on all three answer types.
We further evaluated the above context models on the VQA2.0 balanced pair subset (the last column of Table 3) and found that the absolute gains of VCTree-HL over the other structures are even larger than those on the original validation set. Meanwhile, as reported in prior work, different architectures or hyper-parameters in non-contextual VQA models normally yield smaller improvements on the balanced pair subset than on the overall validation set. This suggests that VCTree indeed uses better context structures to alleviate the question-answer bias in VQA.
5.4 Comparisons with State-of-the-Art Methods
Comparing Methods. Tables 4 and 5 report the single-model performance of various state-of-the-art methods on both the test-dev and test-standard sets. For fair comparison, the reported methods all use the same Faster-RCNN features as ours.
Quantitative Analysis. The proposed VCTree-HL shows the best overall performance on both test-dev and test-standard. Note that although Count has overall performance close to our VCTree, it mainly improves the “Number” task with an elaborately designed model, while the proposed VCTree is a more general solution.
Qualitative Analysis. We visualized several examples of VCTree-HL on the validation set. They illustrate that the proposed VCTree is able to learn dynamic structures with interpretability, e.g., in Figure 7, given the right middle image and the question “Is there any snow on the trees?”, the generated VCTree locates the “tree” and then searches for the “snow”, while with the question “What sport is the man doing?”, the “man” appears as the root.
Conclusions
In this paper, we proposed a dynamic tree structure called VCTree to capture task-specific visual contexts, which can be encoded to support two high-level vision tasks: SGG and VQA. By exploiting VCTree, we observed consistent performance gains in SGG on Visual Genome and in VQA on VQA2.0, compared to models with or without visual contexts. Besides, to justify that VCTree learns non-trivial contexts, we conducted additional experiments against the category bias in SGG and the question-answer bias in VQA, respectively. In the future, we intend to study the potential of a dynamic forest as the underlying context structure.
Appendix A Bidirectional TreeLSTM
In this section, we introduce the details of the bidirectional TreeLSTM used to encode the object-level visual contexts. For the bottom-up direction, we employ the $N$-ary TreeLSTM for binary trees, i.e., VCTrees and Overlap Trees, and the normalized Child-Sum TreeLSTM for Multi-Branch Trees. For the top-down direction, since each node has only one parent, the TreeLSTM is similar to the traditional LSTM.
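According to the definition of the $N$-ary TreeLSTM, it can be applied to tree structures with at most $N$ ordered branches per node. In our work, we adopt the binary ($N = 2$) TreeLSTM as our bottom-up TreeLSTM for the proposed binary tree structures, i.e., VCTrees and Overlap Trees. It can be formulated as follows:
$$\begin{aligned}
i_j &= \sigma\big(W^{(i)} z_j + U^{(i)}_{L} h_{jL} + U^{(i)}_{R} h_{jR} + b^{(i)}\big)\\
f_{jk} &= \sigma\big(W^{(f)} z_j + U^{(f)}_{kL} h_{jL} + U^{(f)}_{kR} h_{jR} + b^{(f)}\big), \quad k \in \{L, R\}\\
o_j &= \sigma\big(W^{(o)} z_j + U^{(o)}_{L} h_{jL} + U^{(o)}_{R} h_{jR} + b^{(o)}\big)\\
u_j &= \tanh\big(W^{(u)} z_j + U^{(u)}_{L} h_{jL} + U^{(u)}_{R} h_{jR} + b^{(u)}\big)\\
c_j &= i_j \odot u_j + f_{jL} \odot c_{jL} + f_{jR} \odot c_{jR}\\
h_j &= o_j \odot \tanh(c_j)
\end{aligned}$$
where $z_j$ is the input feature of node $j$, and $(h_{jL}, c_{jL})$ and $(h_{jR}, c_{jR})$ are the hidden and cell states of its left and right children.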
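A binary TreeLSTM cell in the N-ary form of Tai et al. (2015), with N = 2, can be sketched in NumPy as below. It is an illustrative sketch, not the authors' implementation: weights are randomly initialized, the gate rows are stacked in an arbitrary order (i, o, u, f), the forget-gate input weights and bias are shared by both children while each child gets its own recurrent weights, and a missing child is assumed to contribute zero hidden and cell states.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class BinaryTreeLSTMCell:
    """N-ary TreeLSTM cell with N = 2 (ordered left/right children)."""

    def __init__(self, d_in, d_h, rng):
        self.d_h = d_h
        # input weights, rows stacked as i, o, u, f (f shared by both children)
        self.W = rng.normal(scale=0.1, size=(4 * d_h, d_in))
        # recurrent weights on [h_L; h_R], rows stacked as i, o, u, f_L, f_R
        self.U = rng.normal(scale=0.1, size=(5 * d_h, 2 * d_h))
        self.b = np.zeros(4 * d_h)

    def __call__(self, z, left=None, right=None):
        """z: node input; left/right: (h, c) of a child, or None if the child is missing."""
        zero = (np.zeros(self.d_h), np.zeros(self.d_h))
        (h_l, c_l), (h_r, c_r) = left or zero, right or zero
        wi, wo, wu, wf = np.split(self.W @ z + self.b, 4)
        ui, uo, uu, uf_l, uf_r = np.split(self.U @ np.concatenate([h_l, h_r]), 5)
        i, o, u = sigmoid(wi + ui), sigmoid(wo + uo), np.tanh(wu + uu)
        f_l, f_r = sigmoid(wf + uf_l), sigmoid(wf + uf_r)   # one forget gate per child
        c = i * u + f_l * c_l + f_r * c_r
        h = o * np.tanh(c)
        return h, c
```

Because the left and right children use different recurrent weights, the cell can treat hierarchical (left) and parallel (right) context differently, which a Child-Sum cell cannot do.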
A.2 Child-Sum TreeLSTM for Multi-Branch Trees
The Child-Sum TreeLSTM can handle tree structures in which each node has an arbitrary number of children. Therefore, we adopt it as the bottom-up TreeLSTM of the context encoder for the Multi-Branch Trees in the ablation studies. For each node $j$ of a Multi-Branch Tree, we define $\mathcal{C}(j)$ as the set of its children. Compared with the original paper, we replace the Child-Sum with a Child-Mean in our implementation for better normalization, which is formulated as:
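A Child-Mean variant of the Child-Sum TreeLSTM cell (Tai et al., 2015) can be sketched as follows. This is a sketch under an explicit assumption: only the aggregated hidden state $\tilde{h}_j$ is changed from a sum to a mean over $\mathcal{C}(j)$, while the memory cell keeps the usual sum of per-child gated cells; whether the cell term is also averaged is not stated in the text. Weights are randomly initialized for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class ChildMeanTreeLSTMCell:
    """Child-Sum TreeLSTM with the children's hidden-state sum replaced by a mean."""

    def __init__(self, d_in, d_h, rng):
        self.d_h = d_h
        self.W = rng.normal(scale=0.1, size=(4 * d_h, d_in))     # rows: i, o, u, f
        self.U_iou = rng.normal(scale=0.1, size=(3 * d_h, d_h))  # acts on h_tilde
        self.U_f = rng.normal(scale=0.1, size=(d_h, d_h))        # acts on each child's h_k
        self.b = np.zeros(4 * d_h)

    def __call__(self, z, children):
        """z: node input; children: list of (h, c) pairs of any length (possibly empty)."""
        wi, wo, wu, wf = np.split(self.W @ z + self.b, 4)
        if children:
            h_tilde = np.mean([h for h, _ in children], axis=0)  # Child-Mean instead of Child-Sum
        else:
            h_tilde = np.zeros(self.d_h)
        ui, uo, uu = np.split(self.U_iou @ h_tilde, 3)
        i, o, u = sigmoid(wi + ui), sigmoid(wo + uo), np.tanh(wu + uu)
        c = i * u
        for h_k, c_k in children:
            c = c + sigmoid(wf + self.U_f @ h_k) * c_k           # per-child forget gate
        h = o * np.tanh(c)
        return h, c
```

Averaging keeps the magnitude of $\tilde{h}_j$ independent of how many children a node has, which is the normalization issue that nodes with many children would otherwise cause; the cell is also invariant to the order of its children, so it cannot distinguish hierarchical from parallel relations.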
A.3 Top-Down TreeLSTM
We use the traditional LSTM as the top-down TreeLSTM for all the VCTrees, Overlap Trees, and Multi-Branch Trees, because each node has at most one parent. The only difference from the traditional LSTM is that our structures are trees rather than chains, so the previous hidden state of node $j$ comes from its parent.
For the proposed VCTree, we also experimented with assigning different learnable matrices to the hidden states from left-branch parents and right-branch parents. However, this did not yield significant improvements in the end tasks, so we employ the traditional LSTM as our top-down LSTM for efficiency.
Appendix B Quantitative Analysis
We also report more detailed results of the proposed Mean Recall (mR@K) in Table 6. The proposed VCTree-HL shows the best performance among all the ablative structures. Note that MOTIFS has a lower mR@100 than the FREQ baseline in SGCls and PredCls, which means that MOTIFS is even worse at predicting infrequent predicate categories. However, its mR@20 and mR@50 are higher than those of FREQ in SGCls and PredCls, which indicates that MOTIFS separates foreground relationships from background ones better than FREQ.
B.2 Predicate Recall Analysis
To better visualize the improvement of the proposed VCTree-HL on infrequent predicate categories, we rank all predicate categories by frequency and show the PredCls Recall@100 of MOTIFS and VCTree-HL for each of the top-35 categories independently in Figure 8. We observe significant improvements on the less frequent but more semantically meaningful predicates.
Appendix C Qualitative Analysis
We further investigated more misclassified results of the proposed VCTree-HL. The corresponding tree structures and generated scene graphs are reported in Figure 9. We observed three types of interesting misclassification: (1) in image (a) of Figure 9, the proposed VCTree-HL predicts the more appropriate predicates “in front of” and “behind” rather than the original “near”; (2) in images (b) and (d), the ground-truth triplets “man in snow” and “window near building” are improper, while our method predicts more appropriate predicates; (3) in images (c) and (d), objects isolated from the scene graph (considering only R@20 predicates) are more easily misclassified.
C.2 Visual Question Answering
More constructed VCTrees for VQA2.0 are visualized in Figure 10. The dynamic tree structures vary with the question, which allows the objects in an image to incorporate different contextual cues according to each question. The proposed VCTree also helps us understand how the model predicts the answer to a question given the image, e.g., in image (a) of Figure 10, given the question “does this dog have a collar?”, our model first focuses on the collar-like object rather than the dog; in image (b) of Figure 10, given the question “what sport is being played?”, our model focuses on the sportsman rather than the playground to answer the question.