Multimodal Token Fusion for Vision Transformers
Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang
Introduction
Transformers were initially studied in the natural language community as non-recurrent sequence models and were soon extended to benefit vision-language tasks. Recently, numerous studies have further adopted transformers for computer vision tasks with well-adapted architectures and optimization schedules. As a result, vision transformer variants have shown great potential in many single-modal vision tasks, such as classification, segmentation, detection, and image generation.
Yet, up to the time of this work, attempts to extend vision transformers to multimodal data remain scarce. When multimodal data with complicated alignment relations are introduced, designing the fusion scheme of the model architecture becomes challenging. The key question is how and where features from different modalities should interact. There are a few methods for transformer-based vision-language fusion, e.g., VL-BERT and ViLT. In these methods, vision and language tokens are directly concatenated before each transformer layer, making the overall architecture very similar to the original transformer. Such fusion is usually alignment-agnostic, meaning that inter-modal alignments are not explicitly utilized. We also apply similar fusion methods to multimodal vision tasks (Sec. 4). Unfortunately, this intuitive transformer fusion brings little gain and may even perform worse than its single-modal counterpart, mainly because the inter-modal interaction is not fully exploited. There are also several attempts at fusing multiple vision modalities. For example, TransFuser leverages transformer modules to connect CNN backbones for images and LiDAR points. Different from existing work, we aim for an effective and general method that combines multiple single-modal transformers while inserting inter-modal alignments into the models.
This work improves learning from multimodal data by leveraging inter-modal alignments. Such alignments are naturally available in many vision tasks; e.g., given camera intrinsics/extrinsics, world-space points can be projected onto their corresponding pixels on the camera plane. Unlike alignment-agnostic fusion (Sec. 3.1), alignment-aware fusion explicitly involves the alignment relations of different modalities. Yet, since inter-modal projections are introduced into the transformer, alignment-aware fusion may greatly alter the original model structure and data flow, which can undermine the success of single-modal architecture designs or of the attention learned during pretraining. One may then have to determine the “correct” layers/tokens/channels for multimodal projection and fusion, and also re-design the architecture or re-tune the optimization settings for the new model. To avoid these challenges and inherit most of the original single-modal design, we propose multimodal token fusion, termed TokenFusion, which adaptively and effectively fuses multiple single-modal transformers.
The basic idea of TokenFusion is to prune multiple single-modal transformers and then re-utilize the pruned units for multimodal fusion. We apply individual pruning to each single-modal transformer, and each pruned unit is substituted by projected alignment features from other modalities. This fusion scheme is expected to have a limited impact on the original single-modal transformers, as it maintains the relative attention relations of the important units. TokenFusion also makes it easy for multimodal transformers to inherit parameters from single-modal pretraining, e.g., on ImageNet.
To demonstrate the advantage of the proposed method, we consider a wide range of tasks, including multimodal image translation, RGB-depth semantic segmentation, and 3D object detection based on images and point clouds, covering up to four public datasets and seven different modalities. TokenFusion achieves state-of-the-art performance on these tasks, demonstrating its effectiveness and generality. Specifically, TokenFusion achieves 64.9% and 70.8% mAP@0.25 for 3D object detection on the challenging SUN RGB-D and ScanNetV2 benchmarks, respectively.
Related Work
Transformers in computer vision. Transformers were originally designed for NLP, stacking multi-head self-attention and feed-forward MLP layers to capture long-range correlations between words. Recently, the vision transformer (ViT) revealed the great potential of transformer-based models in large-scale image classification. Transformers have since had a profound impact on many other computer vision tasks, such as segmentation, detection, image generation, and video processing.
Fusion for vision transformers. Deep fusion with multimodal data has been an essential topic, as it can boost performance by leveraging multiple input sources, and it may further unleash the power of transformers. Yet it is challenging to combine multiple off-the-shelf single-modal transformers while guaranteeing that the combination does not compromise their elaborate single-modal designs. Some works process consecutive video frames with transformers for spatial-temporal alignment, capturing fine-grained patterns by correlating multiple frames. Regarding multimodal data, some methods utilize the dynamic property of transformer modules to combine CNN backbones for fusing infrared/visible images or LiDAR points. Another work extends the coarse-to-fine experience of CNN fusion methods to transformers for image processing tasks. Transformers have also been adopted to combine hyperspectral images by simple feature concatenation. Another method inserts intermediate tokens between image patches and audio spectrogram patches as bottlenecks to implicitly learn inter-modal alignments. These works, however, differ from ours, since we aim to build a general fusion pipeline for combining off-the-shelf vision transformers without re-designing their structures or re-tuning their optimization settings, while explicitly leveraging inter-modal alignment relations.
Methodology
This section presents the proposed methodology. We first introduce two naïve multimodal fusion methods for vision transformers in Sec. 3.1. Given the limitations of both intuitive methods, we then propose multimodal token fusion in Sec. 3.2 and residual positional alignment in Sec. 3.3. We elaborate the fusion designs for homogeneous and heterogeneous modalities in Sec. 3.4 and Sec. 3.5, respectively, to evaluate the effectiveness and generality of our method.
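We use $M$ transformers for $M$ input modalities and denote $f_m(\mathbf{x}_m)$ as the final prediction of the $m$-th transformer. Given the token feature $\mathbf{e}_m^l$ of the $m$-th modality, the $l$-th layer computes (residual connections are omitted for brevity)

$$\hat{\mathbf{e}}_m^l = \mathrm{MSA}\big(\mathrm{LN}(\mathbf{e}_m^l)\big), \qquad \mathbf{e}_m^{l+1} = \mathrm{MLP}\big(\mathrm{LN}(\hat{\mathbf{e}}_m^l)\big), \tag{1}$$

where MSA, MLP, and LN denote multi-head self-attention, the multi-layer perceptron, and layer normalization, respectively, and $\hat{\mathbf{e}}_m^l$ represents the output of MSA.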
During multimodal fusion for vision tasks, the alignment relations of different modalities may be explicitly available. For example, pixel positions are often used to determine the image-depth correlation; and camera intrinsics/extrinsics are important in projecting 3D points to images. Based on the involvement of alignment information, we consider two kinds of transformer fusion methods as below.
Alignment-agnostic fusion does not explicitly use the alignment relations among modalities; it expects the alignment to be learned implicitly from large amounts of data. A common alignment-agnostic method, widely applied in vision-language models, directly concatenates multimodal input tokens. Similarly, the input feature of the $l$-th layer is the token-wise concatenation of different modalities, $\mathbf{e}^l = [\mathbf{e}_1^l; \cdots; \mathbf{e}_M^l]$. Although alignment-agnostic fusion is simple and requires minimal modification to the original transformer model, it cannot directly benefit from known multimodal alignment relations.
Alignment-aware fusion explicitly utilizes inter-modal alignments, for instance by selecting tokens that correspond to the same pixel or 3D coordinate. Suppose $\mathbf{x}_m^n$ is the $n$-th token of the $m$-th modality input $\mathbf{x}_m$, where $n = 1, \cdots, N_m$. We define the “token projection” from the $m$-th modality to the $m'$-th modality as

$$\mathrm{Proj}^{T}_{m'}(\mathbf{x}_m^n) = h(\mathbf{x}_m^n), \tag{2}$$

where $h$ can simply be an identity function (for homogeneous modalities) or a shallow multi-layer perceptron (for heterogeneous modalities). Considering all $N_m$ tokens, we can conveniently define the “modality projection” as the concatenation of token projections:

$$\mathrm{Proj}^{M}_{m'}(\mathbf{x}_m) = \big[\mathrm{Proj}^{T}_{m'}(\mathbf{x}_m^1); \cdots; \mathrm{Proj}^{T}_{m'}(\mathbf{x}_m^{N_m})\big]. \tag{3}$$
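A minimal sketch of Eqs. 2 and 3 in plain Python, assuming tokens are lists of floats and that the alignment step (deciding which token corresponds to which) has already been done. The names `identity`, `make_shallow_mlp`, `token_projection`, and `modality_projection` are hypothetical, and the two-layer ReLU MLP is only one possible choice of shallow MLP for heterogeneous modalities.

```python
from typing import Callable, List, Sequence

Token = List[float]
Projection = Callable[[Sequence[float]], Token]


def identity(token: Sequence[float]) -> Token:
    """Projection for homogeneous modalities: the token is reused unchanged."""
    return list(token)


def make_shallow_mlp(w1, b1, w2, b2) -> Projection:
    """Hypothetical two-layer ReLU MLP for heterogeneous modalities.

    Weight matrices are nested lists of shape [out_dim][in_dim], so the
    output dimension may differ from the input dimension.
    """
    def linear(x, w, b):
        return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]

    def mlp(token: Sequence[float]) -> Token:
        hidden = [max(0.0, v) for v in linear(token, w1, b1)]
        return linear(hidden, w2, b2)

    return mlp


def token_projection(token: Sequence[float], h: Projection) -> Token:
    """Eq. 2: project a single (already aligned) token into the target modality."""
    return h(token)


def modality_projection(tokens: Sequence[Sequence[float]], h: Projection) -> List[Token]:
    """Eq. 3: token-wise concatenation of the projections of all N_m tokens."""
    return [token_projection(t, h) for t in tokens]


# Usage: a 3-dim modality projected into a 1-dim target embedding space.
mlp = make_shallow_mlp([[1, 0, 0], [0, 1, 0]], [0, 0], [[1, 1]], [0.5])
print(modality_projection([[1.0, 2.0, 3.0], [-1.0, 2.0, 0.0]], mlp))  # [[3.5], [2.5]]
```

The MLP variant is what allows transformers with different hidden embedding dimensions to exchange tokens; the identity variant adds no parameters at all.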
Eq. 3 only depicts the fusion strategy on the input side. We can also perform middle-layer or multi-layer fusion across different modality-specific models by projecting and aggregating feature embeddings, which may enable more diversified and accurate feature interactions. However, with the growing complexity of transformer-based models, searching for optimal fusion strategies (e.g., which layers and tokens to project and aggregate) even for just two modalities (e.g., 2D and 3D detection transformers) can become an extremely hard problem. To tackle this issue, we propose multimodal token fusion in Sec. 3.2.
3.2 Multimodal Token Fusion
As described in Sec. 1, multimodal token fusion (TokenFusion) first prunes single-modal transformers and then re-utilizes the pruned units for fusion. In this way, the informative units of the original single-modal transformers are expected to be largely preserved, while multimodal interactions are introduced to boost performance.
As shown in prior work, tokens of vision transformers can be pruned hierarchically while maintaining performance. Similarly, we can select less informative tokens by adopting a scoring function $s^l(\cdot)$, which dynamically predicts the importance of tokens for the $l$-th layer and the $m$-th modality. To enable back-propagation through $s^l$, we re-formulate the MSA output in Eq. 1 as

$$\hat{\mathbf{e}}_m^l = \mathrm{MSA}\big(\mathrm{LN}(\mathbf{e}_m^l)\cdot s^l(\mathbf{e}_m^l)\big). \tag{4}$$
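We use $\mathcal{L}^m$ to denote the task-specific loss for the $m$-th modality. To prune uninformative tokens, we further add a token-wise pruning loss (an $\ell_1$-norm) on $s^l(\mathbf{e}_m^l)$. The overall loss function for optimization is thus

$$\mathcal{L} = \sum_{m=1}^{M}\Big(\mathcal{L}^m + \lambda\sum_{l=1}^{L}\big|s^l(\mathbf{e}_m^l)\big|\Big), \tag{5}$$

where $\lambda$ is a hyper-parameter for balancing the losses and $L$ is the number of layers.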
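To make Eqs. 4 and 5 concrete, here is a small plain-Python sketch. As an illustration only, it assumes the scoring function is a single linear layer followed by a sigmoid so that scores lie in (0, 1); the text above does not fix this form, and all helper names are hypothetical.

```python
import math
from typing import List, Sequence

Token = List[float]


def token_scores(tokens: Sequence[Sequence[float]], w: Sequence[float], b: float) -> List[float]:
    """Hypothetical s^l: linear layer + sigmoid, one importance score per token."""
    scores = []
    for t in tokens:
        logit = sum(wi * xi for wi, xi in zip(w, t)) + b
        scores.append(1.0 / (1.0 + math.exp(-logit)))
    return scores


def reweight_for_msa(normed_tokens: Sequence[Sequence[float]],
                     scores: Sequence[float]) -> List[Token]:
    """Eq. 4: multiply each LN-normalized token by its score before MSA,
    so the scoring function lies on the gradient path of the task loss."""
    return [[s * v for v in t] for t, s in zip(normed_tokens, scores)]


def total_loss(task_losses: Sequence[float],
               scores_per_modality: Sequence[Sequence[Sequence[float]]],
               lam: float) -> float:
    """Eq. 5: per modality, task loss plus lam times the l1 norm of the token
    scores of every layer. scores_per_modality[m][l] holds layer l's scores."""
    total = 0.0
    for task_loss, layer_scores in zip(task_losses, scores_per_modality):
        l1 = sum(abs(s) for scores in layer_scores for s in scores)
        total += task_loss + lam * l1
    return total


# Usage: two modalities, one layer each, two tokens per layer.
print(total_loss([1.0, 2.0], [[[0.5, 0.25]], [[0.25, 0.0]]], lam=0.5))  # 3.5
```

The $\ell_1$ term pushes scores toward zero, so only tokens that the task loss actually relies on keep high scores; the low-scoring ones become candidates for substitution.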
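For the feature $\mathbf{e}_m^l$, token-wise pruning dynamically detects unimportant tokens, which are then substituted with projected tokens from another modality $m'$:

$$\mathbf{e}_m^l \leftarrow \mathbf{e}_m^l \odot \mathbb{I}_{s^l(\mathbf{e}_m^l)\ge\theta} + \mathrm{Proj}^{M}_{m}(\mathbf{e}_{m'}^l) \odot \mathbb{I}_{s^l(\mathbf{e}_m^l)<\theta}, \tag{6}$$

where $\mathbb{I}$ is a token-wise indicator of the subscript condition, $\theta$ is a score threshold, $\odot$ denotes element-wise multiplication, and $\mathrm{Proj}^{M}_{m}$ is the modality projection of Eq. 3 from modality $m'$ into modality $m$. If there are only two input modalities, $m'$ is simply the modality other than $m$. With more than two modalities, we pre-allocate the tokens into $M-1$ parts, each bound to one of the other modalities. More details of this pre-allocation are described in Sec. 3.4.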
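The substitution rule of Eq. 6 can be sketched as follows for the two-modality homogeneous case. The sketch assumes tokens are already aligned one-to-one by position (as for pixel-aligned RGB and depth) and uses an identity projection by default; `substitute_tokens` is a hypothetical helper and `theta=0.02` is an arbitrary illustrative value. For heterogeneous modalities, `proj` would be a shallow MLP and the correspondence would come from an explicit alignment instead of shared positions.

```python
from typing import Callable, List, Sequence

Token = List[float]


def substitute_tokens(e_m: Sequence[Sequence[float]],
                      e_other: Sequence[Sequence[float]],
                      scores: Sequence[float],
                      theta: float,
                      proj: Callable[[Sequence[float]], Token] = list) -> List[Token]:
    """Eq. 6 sketch: keep tokens of modality m whose score is >= theta and
    replace the others with the projected token at the same position in m'."""
    if not (len(e_m) == len(e_other) == len(scores)):
        raise ValueError("tokens must be aligned one-to-one by position")
    fused = []
    for tok, other, s in zip(e_m, e_other, scores):
        fused.append(list(tok) if s >= theta else proj(other))
    return fused


rgb = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
depth = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]
# Bi-directional fusion: each modality is pruned with its own scores, and both
# directions read the unfused features of the current layer.
rgb_fused = substitute_tokens(rgb, depth, scores=[0.9, 0.01, 0.5], theta=0.02)
depth_fused = substitute_tokens(depth, rgb, scores=[0.01, 0.6, 0.7], theta=0.02)
print(rgb_fused)    # [[1.0, 1.0], [8.0, 8.0], [3.0, 3.0]]
print(depth_fused)  # [[1.0, 1.0], [8.0, 8.0], [7.0, 7.0]]
```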
3.3 Residual Positional Alignment
Directly substituting tokens risks completely undermining their original positional information, so the model may remain unaware of the alignment of the projected features from another modality. To mitigate this problem, we adopt Residual Positional Alignment (RPA), which leverages Positional Embeddings (PEs) for multimodal alignment. As depicted in Fig. 1 and Fig. 2 (detailed later), the key idea of RPA is to inject the same PEs into subsequent layers. Moreover, back-propagation through the PEs stops after the first layer: only the gradients of the PEs at the first layer are retained, while the PEs injected at the remaining layers receive no gradients throughout training. In this way, the PEs serve to align multimodal tokens regardless of whether the original token has been substituted. In summary, even if a token is substituted, we still retain its original PEs, which are added to the projected feature from another modality.
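A plain-Python sketch of RPA under simplifying assumptions: tokens and PEs are lists of floats, and there is no autograd, so the stop-gradient is represented only by a returned flag. In an autograd framework such as PyTorch, the PEs re-injected at layers after the first would typically be passed through `.detach()`. The function name `rpa_layer_input` is hypothetical.

```python
from typing import List, Sequence, Tuple

Token = List[float]


def rpa_layer_input(fused_tokens: Sequence[Sequence[float]],
                    pe: Sequence[Sequence[float]],
                    layer: int) -> Tuple[List[Token], bool]:
    """Add the same positional embeddings to the (possibly substituted) tokens
    of a layer. Also returns whether these PEs should receive gradients:
    only at the first layer (layer == 0); later copies act as constants."""
    if len(fused_tokens) != len(pe):
        raise ValueError("one positional embedding per token position")
    tokens = [[t + p for t, p in zip(tok, pos)] for tok, pos in zip(fused_tokens, pe)]
    return tokens, layer == 0


pe_rgb = [[0.5, 0.0], [0.25, 0.0], [0.125, 0.0]]
# Position 1 was substituted by a depth token but keeps the RGB PE of position 1.
fused = [[1.0, 1.0], [8.0, 8.0], [3.0, 3.0]]
tokens, pe_trainable = rpa_layer_input(fused, pe_rgb, layer=3)
print(tokens, pe_trainable)  # [[1.5, 1.0], [8.25, 8.0], [3.125, 3.0]] False
```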
3.4 Homogeneous Modalities
In common setups of either a generation task (multimodal image-to-image translation) or a dense prediction task (RGB-depth semantic segmentation), the homogeneous vision modalities are typically pixel-aligned, such that pixels at the same position in the RGB or depth input share the same label. We expect this property to allow transformer-based models to benefit from joint learning. Hence, we adopt shared parameters in both MSA and MLP layers for different modalities, yet rely on modality-specific layer normalizations to decouple the normalization process, since different modalities may differ drastically in their statistical means and variances. In this scenario, we simply set the projection function in Eq. 6 to the identity, and we always substitute each pruned token with the token at the same position in the other modality.
An overall illustration of TokenFusion for fusing homogeneous modalities is depicted in Fig. 1. For two input modalities, we adopt bi-directional projection and apply token-wise pruning on each modality. The token substitution is then performed according to Eq. 6. When there are $M > 2$ modalities, we also apply token-wise pruning on all modalities, with an additional pre-allocation strategy that selects $m'$ based on $m$ in Eq. 6. Specifically, for the $m$-th modality, we randomly pre-allocate its $N$ tokens into $M-1$ groups of equal size. This pre-allocation is carried out before training begins, and the obtained groups remain fixed throughout training. We denote the group allocation as $a_{m,m'}^n \in \{0, 1\}$, where $a_{m,m'}^n = 1$ indicates that if the $n$-th token of the $m$-th modality is pruned, it will be substituted by the corresponding token of the $m'$-th modality, and $a_{m,m'}^n = 0$ otherwise. Having obtained the pre-allocation strategy for $M$ modalities, Eq. 6 can be further developed into a more specific form, where $s^l$, $\theta$, and $\mathrm{Proj}^{M}_{m}$ are as in Eq. 6 and $\mathbf{a}_{m,m'}$ collects $a_{m,m'}^n$ over all tokens:

$$\mathbf{e}_m^l \leftarrow \mathbf{e}_m^l \odot \mathbb{I}_{s^l(\mathbf{e}_m^l)\ge\theta} + \sum_{m'=1,\, m'\neq m}^{M} \mathbf{a}_{m,m'} \odot \mathrm{Proj}^{M}_{m}(\mathbf{e}_{m'}^l) \odot \mathbb{I}_{s^l(\mathbf{e}_m^l)<\theta}. \tag{7}$$
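The pre-allocation for $M > 2$ modalities can be sketched as below. This is an illustration under stated assumptions: modalities are indexed from 0, the token count $N$ is divisible by $M-1$ so that groups have equal size, projections are identities (homogeneous case), and a seeded `random.Random` makes the allocation reproducible so it stays fixed throughout training. All function names are hypothetical.

```python
import random
from typing import List, Sequence


def preallocate_groups(num_tokens: int, num_modalities: int, m: int, seed: int = 0) -> List[int]:
    """For modality m, randomly split token indices into M-1 equal-size groups,
    one per other modality. source[n] = m' means token n of modality m is
    replaced by token n of modality m' if it gets pruned."""
    others = [k for k in range(num_modalities) if k != m]
    if num_tokens % len(others) != 0:
        raise ValueError("equal group sizes require N to be divisible by M-1")
    group_size = num_tokens // len(others)
    order = list(range(num_tokens))
    random.Random(seed).shuffle(order)
    source = [-1] * num_tokens
    for g, other in enumerate(others):
        for n in order[g * group_size:(g + 1) * group_size]:
            source[n] = other
    return source


def allocation_mask(source: Sequence[int], other: int) -> List[int]:
    """Binary allocation vector a_{m,m'}: 1 where token n is bound to modality m'."""
    return [1 if s == other else 0 for s in source]


def substitute_with_groups(e_all, m, scores, theta, source):
    """Eq. 7 sketch (identity projection): keep token n of modality m if its
    score is >= theta, otherwise take token n of its pre-allocated modality."""
    return [list(e_all[m][n]) if scores[n] >= theta else list(e_all[source[n]][n])
            for n in range(len(scores))]


# Usage: 3 modalities, 6 tokens; modality 0 borrows from modalities 1 and 2.
source = preallocate_groups(num_tokens=6, num_modalities=3, m=0, seed=0)
e_all = [[[float(10 * k + n)] for n in range(6)] for k in range(3)]
fused = substitute_with_groups(e_all, m=0, scores=[0.0] * 6, theta=0.02, source=source)
```

Because every token has exactly one allocated source modality, the masks $\mathbf{a}_{m,m'}$ over all $m'$ sum to one at each position, which is what lets the sum in Eq. 7 pick a single replacement per pruned token.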
3.5 Heterogeneous Modalities
In this section, we further explore how TokenFusion handles heterogeneous modalities, whose inputs exhibit quite different data formats and large structural discrepancies, e.g., different numbers of layers or embedding dimensions in the transformer architectures. A concrete example is learning 3D object detection (based on point clouds) and 2D object detection (based on images) simultaneously with different transformers. Although transformer-based models already exist for 3D and 2D object detection respectively, a fast and effective method to combine these models and tasks is still lacking.
An overall structure of TokenFusion for fusing heterogeneous modalities is depicted in Fig. 2. Different from the homogeneous case, we approximate the token projection function in Eq. 2 with a shallow multi-layer perceptron (MLP), since transformers for heterogeneous modalities may have different hidden embedding dimensions. For 3D object detection with a 3D point cloud and a 2D image, we project each point onto the corresponding image based on camera intrinsics and extrinsics. Likewise, we project 3D object labels onto the images to obtain the corresponding 2D object labels. We train two standalone transformers with unshared parameters in an end-to-end manner. For 3D object detection with point clouds as input, we follow the architecture used in Group-Free, where sampled seed points and learned proposal points are considered as input tokens and sent to the transformer to predict 3D bounding boxes and object categories. For 2D object detection with images as input, we follow the framework of YOLOS, which sends image patches and object queries to the transformer to predict 2D bounding boxes together with their associated object categories.
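As an illustration of how a 3D point can be tied to an image token, the sketch below uses the standard pinhole camera model without lens distortion: camera coordinates $R\mathbf{p} + \mathbf{t}$ (extrinsics) are mapped to pixels by the intrinsic matrix $K$ followed by division by depth, and the pixel is then assigned to a row-major grid of square patches. The patch size, image grid, camera values, and helper names are hypothetical choices for the example, not settings from the paper.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def matvec(a: Matrix, x: Sequence[float]) -> List[float]:
    """Matrix-vector product on nested lists."""
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def project_point(p_world: Sequence[float], K: Matrix, R: Matrix,
                  t: Sequence[float]) -> Tuple[float, float]:
    """Pinhole projection: p_cam = R p_world + t, then (u, v) = (K p_cam) / depth."""
    p_cam = [c + ti for c, ti in zip(matvec(R, p_world), t)]
    if p_cam[2] <= 0:
        raise ValueError("point lies behind the camera")
    x, y, w = matvec(K, p_cam)
    return x / w, y / w


def pixel_to_patch_index(u: float, v: float, patch_size: int, grid_width: int) -> int:
    """Row-major index of the image patch token that contains pixel (u, v)."""
    return int(v // patch_size) * grid_width + int(u // patch_size)


K = [[100.0, 0.0, 50.0], [0.0, 100.0, 50.0], [0.0, 0.0, 1.0]]
R = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
t = [0.0, 0.0, 0.0]
u, v = project_point([1.0, 2.0, 4.0], K, R, t)
token = pixel_to_patch_index(u, v, patch_size=16, grid_width=14)
print((u, v), token)  # (75.0, 100.0) 88
```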
Experiments
To evaluate the effectiveness of the proposed TokenFusion, we conduct comprehensive experiments on both homogeneous and heterogeneous modalities and compare with state-of-the-art (SOTA) methods. The experiments cover seven different modalities and four application scenarios, and are implemented with PyTorch and MindSpore.
The task of multimodal image-to-image translation aims at generating a target image modality from different image modalities as input (e.g., Normal+Depth→RGB). We evaluate TokenFusion on this task using the Taskonomy dataset, a large-scale indoor scene dataset containing about 4 million indoor images captured from 600 buildings. Taskonomy provides over 10 multimodal representations in addition to each RGB image, such as depth (Euclidean or z-buffering), normal, shade, texture, edge, and principal curvature. The resolution of each representation is . To facilitate comparison with existing fusion methods, we adopt the same sampling strategy as , resulting in 1,000 high-quality multimodal images for training and 500 for validation.
Our implementation contains two transformers as the generator and discriminator, respectively. We provide configuration details in our supplementary materials. The resolution of the generator/discriminator input or the generator prediction is . We adopt two architecture settings, a tiny (Ti) version with layers and a small (S) version with layers, which differ only in the number of layers. The learning rates of both transformers are set to . We adopt overlapped patches in both transformers inspired by .
In our experiments for this task, we adopt shared transformers for all input modalities with individual layer normalizations (LNs) that compute the means and variances of each modality separately. Specifically, the parameters of the linear patch projection, all linear projections in MSA (e.g., for keys and queries), and the MLP are shared across modalities. This mechanism greatly reduces the total model size and, as discussed in the supplementary materials, even achieves better performance than using individual transformers. In addition, we adopt shared positional embeddings for different modalities. We let the sparsity weight in Eq. 10 and the threshold in Sec. 3.4 for all these experiments.
Our evaluation metrics include FID/KID for RGB predictions and MAE/MSE for other predictions. These metrics are introduced in the supplementary materials.
Results. In Table 1, we compare with extensive baseline methods and a SOTA method under the same data settings. All methods adopt the learned ensemble over the two predictions that correspond to the two modality branches. In addition, all predictions have the same resolution for a fair comparison. Since most existing methods are based on CNNs, we further provide two baselines for transformer-based models: one without feature fusion (using only the ensemble for late fusion) and one with feature fusion. By comparison, our TokenFusion surpasses all other methods by large margins. For example, in the Shade+Texture→RGB task, our TokenFusion (S) achieves FID/KID scores, remarkably better than the current SOTA method CEN, with a 29.8% relative decrease in FID.
In the supplementary materials, we consider up to four input modalities to evaluate our group allocation strategy.
Visualization and analysis. We provide qualitative results in Fig. 3, where we choose challenging samples for comparison. The predictions of TokenFusion show more natural patterns and are richer in colors and details. In Fig. 4, we further visualize which tokens TokenFusion learns to fuse under our sparsity constraints. We observe that the fused tokens follow specific regularities. For example, the texture modality tends to preserve its advantage at detailed boundaries while seeking facial tokens from the shade modality. In this sense, TokenFusion combines complementary properties of different modalities.
4.2 RGB-Depth Semantic Segmentation
We then evaluate TokenFusion on another homogeneous scenario, semantic segmentation with RGB and depth as input, a common multimodal task for which numerous methods have been proposed. We choose the typical indoor datasets NYUDv2 and SUN RGB-D. For NYUDv2, we follow the standard 795/654 train/test split and predict the standard 40 classes. SUN RGB-D is one of the most challenging large-scale indoor datasets, and we adopt the standard 5,285/5,050 train/test split with 37 semantic classes.
Our models include TokenFusion (tiny) and TokenFusion (small), whose single-modal backbones follow the B2 and B3 settings of SegFormer, respectively. Both versions are initialized with parameters pretrained on ImageNet-1k following . As in Sec. 4.1, we adopt shared transformers and positional embeddings for RGB and depth inputs with individual LNs. We let the sparsity weight in Eq. 10 and the threshold in Sec. 3.4 for all these experiments.
Results. Results in Table 2 show that current transformer-based models equipped with our TokenFusion surpass SOTA CNN-based models. Note that we choose relatively light backbone settings (B2 and B3, as mentioned in Sec. 4.2). We expect that larger backbones (e.g., B5) would yield better performance.
4.3 Vision and Point Cloud 3D Object Detection
We further apply TokenFusion to heterogeneous modalities, specifically to 3D object detection, which has received great attention. We leverage 3D point clouds and 2D images to learn 3D and 2D detection, respectively, and both are learned simultaneously. We expect 2D learning to boost its 3D counterpart.
We adopt SUN RGB-D and ScanNetV2 datasets. For SUN RGB-D, we follow the same train/test splits as in Sec. 4.2 and detect the 10 most common classes. For ScanNetV2, we adopt the 1,201/312 scans as train/test splits to detect the 18 object classes. All these settings (splits and detected target classes) follow current works for a fair comparison. Note that different from SUN RGB-D, ScanNetV2 provides multi-view images for each scene alongside the point cloud. We randomly sample 10 frames per scene from the scannet-frames-25k samples provided in .
Our architectures for 3D detection and 2D detection follow GF and YOLOS , respectively. We adopt the “L6, O256” or “L12, O512” versions of GF for the 3D detection branch. We combine GF with the tiny (Ti) and small (S) versions of YOLOS, respectively, and adopt mAP@0.25 and mAP@0.5 as evaluation metrics following .
Visualizations. Fig. 5 compares detection results obtained with TokenFusion for multimodal interaction against those obtained with individual learning. We observe that TokenFusion benefits 3D detection. For example, with the help of images, models with TokenFusion can locate 3D objects even when point data are sparse or missing (second row). In addition, images also help when the points of two objects largely overlap (first row). These observations demonstrate the advantages of TokenFusion.
Ablation Study
$\ell_1$-norm and token fusion. In Table 5, we demonstrate the advantages of the $\ell_1$-norm and token fusion. We additionally conduct experiments with random token fusion. We observe that applying the $\ell_1$-norm alone has little effect on performance, yet it is essential for revealing which tokens to fuse. Our token fusion together with the $\ell_1$-norm achieves much better performance than the random fusion baselines.
Evaluation of RPA. Table 6 evaluates RPA, proposed in Sec. 3.3. Results indicate that using RPA alone, without token fusion, does not noticeably affect performance, but RPA is important for alignment when combined with token fusion, especially for the 3D detection task.
Conclusion
We propose TokenFusion, an adaptive method generally applicable to fusing vision transformers with homogeneous or heterogeneous modalities. TokenFusion detects uninformative tokens and re-utilizes them to strengthen the interaction of other informative multimodal tokens. Alignment relations of different modalities are explicitly utilized through our residual positional alignment and inter-modal projection. TokenFusion surpasses state-of-the-art methods on a variety of tasks, demonstrating its superiority and generality for multimodal fusion.
Acknowledgement
This work is funded by Major Project of the New Generation of Artificial Intelligence (No. 2018AAA0102900) and the Sino-German Collaborative Research Project Crossmodal Learning (NSFC 62061136001/DFG TRR169). We gratefully acknowledge the support of MindSpore, CANN and Ascend AI Processor used for this research.
Appendix
Appendix A Additional Results
Multiple input modalities. In Table 7, we further evaluate TokenFusion with one to four input modalities. When there are more than two input modalities, we adopt the group allocation strategy proposed in Sec. 3.4 of our main paper. Performance consistently improves as more modalities are used, and TokenFusion is again noticeably better than CEN, suggesting its ability to absorb information from more modalities.
Network sharing. As mentioned in Sec. 3.4 of our main paper, we adopt shared parameters in both Multi-head Self-Attention (MSA) and Multi-Layer Perceptron (MLP) layers for fusing homogeneous modalities, and rely on modality-specific Layer Normalization (LN) layers to decouple the normalization process. This network sharing technique, used in our multimodal image-to-image translation (Sec. 4.1) and RGB-depth semantic segmentation (Sec. 4.2) experiments, greatly reduces the model size and also enables the reuse of attention weights across modalities. In Table 8, we further conduct ablation studies to demonstrate the effectiveness of our network sharing scheme. The comparison indicates that our default setting (i.e., shared MSA and MLP, individual LN) achieves a win-win: besides being more storage-efficient, it also achieves better results than using individual MSA and MLP on both tasks. Note that further sharing the LN layers leads to a performance drop, especially on the image-to-image translation task. In addition, we adopt shared Positional Embeddings (PEs) by default for fusing homogeneous modalities, and we observe that sharing or unsharing PEs achieves comparable performance in practice.
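Combining TokenFusion with channel-wise fusion. TokenFusion detects uninformative tokens and re-utilizes them for multimodal fusion. We may further combine TokenFusion with an orthogonal channel-wise pruning method that automatically detects uninformative channels. Different from the token-wise fusion in TokenFusion, channel-wise fusion is not conditioned on input features. Inspired by CEN, we leverage the scaling factors $\boldsymbol{\gamma}$ of layer normalization (LN) to perform channel-wise pruning, and apply sparsity constraints on $\boldsymbol{\gamma}$. LN in transformers normalizes its input $\mathbf{x}$ as

$$\mathrm{LN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \boldsymbol{\beta},$$

where $\mu$ and $\sigma^2$ are the mean and variance of $\mathbf{x}$ over the channel dimension, $\epsilon$ is a small constant, and $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are the learnable scaling and shifting factors.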
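To prune uninformative channels, we add a channel-wise pruning loss to the main loss in Eq. (5) (main paper). The overall loss function is

$$\mathcal{L} = \sum_{m=1}^{M}\Big(\mathcal{L}^m + \lambda_1\sum_{l=1}^{L}\big|s^l(\mathbf{e}_m^l)\big| + \lambda_2\sum_{l=1}^{L}\big|\boldsymbol{\gamma}_m^l\big|\Big),$$

where $\lambda_1$ and $\lambda_2$ are hyper-parameters for balancing different losses, and $\boldsymbol{\gamma}_m^l$ is a vector of length $C$ (the number of channels) representing the scaling factors of LN at the $l$-th layer of the $m$-th modality.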
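A plain-Python sketch of the combined loss, assuming scalar task losses, per-layer token-score lists, and per-layer LN scaling vectors $\boldsymbol{\gamma}_m^l$. `pruned_channels` illustrates one hypothetical way to read off uninformative channels by thresholding $|\gamma|$; the threshold is an assumption, and how those channels are then fused (e.g., CEN-style exchange) is not shown.

```python
import math
from typing import List, Sequence


def layer_norm(x: Sequence[float], gamma: Sequence[float], beta: Sequence[float],
               eps: float = 1e-5) -> List[float]:
    """LN over the channels of one token: gamma * (x - mu) / sqrt(var + eps) + beta."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    denom = math.sqrt(var + eps)
    return [g * (v - mu) / denom + b for v, g, b in zip(x, gamma, beta)]


def combined_loss(task_losses, token_scores, ln_gammas, lam1: float, lam2: float) -> float:
    """Per modality: task loss + lam1 * l1(token scores) + lam2 * l1(LN gammas).
    token_scores[m][l] and ln_gammas[m][l] are the lists for layer l of modality m."""
    total = 0.0
    for task, scores, gammas in zip(task_losses, token_scores, ln_gammas):
        token_l1 = sum(abs(s) for layer in scores for s in layer)
        channel_l1 = sum(abs(g) for layer in gammas for g in layer)
        total += task + lam1 * token_l1 + lam2 * channel_l1
    return total


def pruned_channels(gamma: Sequence[float], threshold: float) -> List[int]:
    """Hypothetical read-out: channels whose |gamma| falls below a threshold."""
    return [c for c, g in enumerate(gamma) if abs(g) < threshold]


# A channel with gamma = 0 outputs only its shift beta, whatever the input.
print(layer_norm([1.0, 3.0], [0.0, 1.0], [0.5, 0.0], eps=0.0))  # [0.5, 1.0]
print(combined_loss([1.0], [[[0.5, 0.5]]], [[[1.0, -0.5, 0.0]]], 0.5, 0.25))  # 1.875
```

The usage lines show why sparsity on $\boldsymbol{\gamma}$ marks channels as uninformative: once a scaling factor reaches zero, that channel no longer carries input-dependent information.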
We let for RGB-depth segmentation experiments. Results in Table 9 demonstrate that TokenFusion can be combined with channel-wise fusion to further improve performance. For example, segmentation on NYUDv2 with both token-wise and channel-wise fusion achieves an additional 0.5 mIoU gain over TokenFusion alone. More detailed studies of this combined framework, the relation between the overall pruning rate and the fusion performance gain, and the extension to heterogeneous modalities are left for future work.
Additional visualizations. In Fig. 6, we provide another group of visualizations that depict the fused tokens under the sparsity constraints during training. We observe that fused tokens follow the regularities mentioned in our main paper, e.g., the texture modality preserves its advantage at boundaries while seeking facial tokens from the shade modality.
Inference speed. In Table 10, we test the real inference speed (single V100, 256 GB RAM) with different numbers of input frames for 3D detection. We observe that the additional time cost is mild, partly because the added YOLOS-Ti is a light model (with only three attention heads).
Appendix B More Details of Image Translation
In this part, we discuss the implementation details of our image-to-image translation task. Our implementation contains two transformers as the generator and discriminator, respectively. The resolution of the generator/discriminator input or the generator prediction is . Specifically, the discriminator of our model is similar to , and adopts five stages with two layers each, where the embedding dimensions and head numbers gradually double from to and from to , respectively. The generator is composed of nine stages, where the first five have the same configurations as the discriminator, and the last four stages mirror the configurations of its first four stages in reverse order.
We adopt four evaluation metrics: Mean Square Error (MSE), Mean Absolute Error (MAE), Fréchet Inception Distance (FID), and Kernel Inception Distance (KID). Here we briefly introduce FID and KID. FID, proposed by , compares the statistics of generated samples against real samples. It fits a Gaussian distribution to the hidden activations of InceptionNet for each compared image set and then computes the Fréchet distance (also known as the Wasserstein-2 distance) between those Gaussians. Lower FID is better, corresponding to generated images that are more similar to real ones. KID, developed by , is similar to FID but uses the squared Maximum Mean Discrepancy (MMD) between Inception representations with a polynomial kernel. Unlike FID, KID has a simple unbiased estimator, making it more reliable, especially when the number of Inception feature channels is much larger than the number of images. Lower KID indicates more visual similarity between real and generated images. In our implementation of KID, the hidden representations are derived from the Inception-v3 pool3 layer.
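The two metrics can be sketched in plain Python on precomputed feature vectors (no Inception network is involved here). `kid_unbiased` implements the unbiased estimator of squared MMD with the cubic polynomial kernel $k(\mathbf{x}, \mathbf{y}) = (\mathbf{x}^\top\mathbf{y}/d + 1)^3$ commonly used for KID, where $d$ is the feature dimension. `fid_diagonal` is a deliberate simplification that assumes diagonal covariances, in which case $\mathrm{Tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2})$ reduces to a sum of squared differences of standard deviations; the real FID uses full covariance matrices. Both function names are hypothetical.

```python
import math
from typing import Sequence

Feature = Sequence[float]


def poly_kernel(x: Feature, y: Feature) -> float:
    """Cubic polynomial kernel (x . y / d + 1)^3 used by KID."""
    d = len(x)
    return (sum(a * b for a, b in zip(x, y)) / d + 1.0) ** 3


def kid_unbiased(xs: Sequence[Feature], ys: Sequence[Feature]) -> float:
    """Unbiased estimate of squared MMD between two feature sets.
    Diagonal terms are excluded within each set, so the estimate can be negative."""
    m, n = len(xs), len(ys)
    if m < 2 or n < 2:
        raise ValueError("need at least two samples per set")
    k_xx = sum(poly_kernel(xs[i], xs[j]) for i in range(m) for j in range(m) if i != j)
    k_yy = sum(poly_kernel(ys[i], ys[j]) for i in range(n) for j in range(n) if i != j)
    k_xy = sum(poly_kernel(x, y) for x in xs for y in ys)
    return k_xx / (m * (m - 1)) + k_yy / (n * (n - 1)) - 2.0 * k_xy / (m * n)


def fid_diagonal(mu1: Feature, var1: Feature, mu2: Feature, var2: Feature) -> float:
    """Frechet distance between Gaussians with diagonal covariances:
    ||mu1 - mu2||^2 + sum_i (sigma1_i - sigma2_i)^2."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var1, var2))
    return mean_term + cov_term


print(kid_unbiased([[0.0], [0.0]], [[2.0], [2.0]]))           # 124.0
print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [3.0, 4.0], [4.0, 1.0]))  # 26.0
```

The unbiased estimator can dip below zero for tiny or nearly identical sample sets, which is expected behavior rather than a bug; in practice KID is averaged over many random subsets of features.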