Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model
Lianghui Zhu, Bencheng Liao, Qian Zhang, Xinlong Wang, Wenyu Liu, Xinggang Wang
Introduction
Recent research advancements have led to a surge of interest in the state space model (SSM). Originating from the classic Kalman filter model, modern SSMs excel at capturing long-range dependencies and benefit from parallel training. Some SSM-based methods, such as the linear state-space layers (LSSL), structured state space sequence model (S4), diagonal state space (DSS), and S4D, have been proposed to process sequence data across a wide range of tasks and modalities, particularly for modeling long-range dependencies. They are efficient on long sequences because they can be computed as convolutions with near-linear complexity. 2-D SSM, SGConvNeXt, and ConvSSM combine SSM with CNN or Transformer architectures to process 2-D data. The recent work Mamba incorporates time-varying parameters into the SSM and proposes a hardware-aware algorithm to enable very efficient training and inference. The superior scaling performance of Mamba indicates that it is a promising alternative to Transformer in language modeling. Nevertheless, a generic pure-SSM-based backbone network has not been explored for processing visual data, such as images and videos.
Vision Transformers (ViTs) have achieved great success in visual representation learning, excelling in large-scale self-supervised pre-training and high performance on downstream tasks. Compared with convolutional neural networks, the core advantage of ViT is that it provides each image patch with data/patch-dependent global context through self-attention. This differs from convolutional networks, which use the same parameters, i.e., the convolutional filters, for all positions. Another advantage is modality-agnostic modeling: an image is treated as a sequence of patches without 2D inductive bias, which makes ViT the preferred architecture for multimodal applications. At the same time, the self-attention mechanism in Transformers poses challenges in terms of speed and memory usage when dealing with long-range visual dependencies, e.g., processing high-resolution images.
Motivated by the success of Mamba in language modeling, it is appealing that we can also transfer this success from language to vision, i.e., to design a generic and efficient visual backbone with the advanced SSM method. However, there are two challenges for Mamba, i.e., unidirectional modeling and lack of positional awareness. To address these challenges, we propose the Vision Mamba (Vim) model, which incorporates the bidirectional SSMs for data-dependent global visual context modeling and position embeddings for location-aware visual recognition. We first split the input image into patches and linearly project them as vectors to Vim. Image patches are treated as the sequence data in Vim blocks, which efficiently compresses the visual representation with the proposed bidirectional selective state space. Furthermore, the position embedding in Vim block provides the awareness for spatial information, which enables Vim to be more robust in dense prediction tasks. In the current stage, we train the Vim model on the supervised image classification task using the ImageNet dataset and then use the pretrained Vim as the backbone to perform sequential visual representation learning for downstream dense prediction tasks, i.e., semantic segmentation, object detection, and instance segmentation. Like Transformers, Vim can be pretrained on large-scale unsupervised visual data for better visual representation. Thanks to the better efficiency of Mamba, the large-scale pretraining of Vim can be achieved with lower computational cost.
Compared with other SSM-based models for vision tasks, Vim is a pure-SSM-based method and models images in a sequence manner, which is more promising for a generic and efficient backbone. Thanks to the bidirectional compressing modeling with positional awareness, Vim is the first pure-SSM-based model to handle dense prediction tasks. Compared with the most convincing Transformer-based model, i.e., DeiT, Vim achieves superior performance on ImageNet classification. Furthermore, Vim is more efficient in terms of GPU memory and inference time for high-resolution images. The efficiency in terms of memory and speed empowers Vim to directly perform sequential visual representation learning without relying on 2D priors (such as the 2D local window in ViTDet) for high-resolution visual understanding tasks while achieving higher accuracy than DeiT.
Our main contributions can be summarized as follows:
We propose Vision Mamba (Vim), which incorporates bidirectional SSM for data-dependent global visual context modeling and position embeddings for location-aware visual understanding.
Without the need of attention, the proposed Vim has the same modeling power as ViT while it only has subquadratic-time computation and linear memory complexity. Specifically, Vim is 2.8× faster than DeiT and saves 86.8% GPU memory when performing batch inference to extract features on images at the resolution of 1248×1248.
We conduct extensive experiments on ImageNet classification and dense prediction downstream tasks. The results demonstrate that Vim achieves superior performance compared to the well-established and highly-optimized plain vision Transformer, i.e., DeiT.
Related Work
Architectures for generic vision backbone. In the early era, ConvNets served as the de-facto standard network design for computer vision. Many convolutional neural architectures have been proposed as the vision backbone for various visual applications. The pioneering work, Vision Transformer (ViT), changed the landscape. It treats an image as a sequence of flattened 2D patches and directly applies a pure Transformer architecture. The surprising results of ViT on image classification and its scaling ability encouraged a lot of follow-up works. One line of works focuses on hybrid architecture designs by introducing 2D convolutional priors into ViT. PVT proposes a pyramid structure Transformer. Swin Transformer applies self-attention within shifted windows. Another line of works focuses on improving traditional 2D ConvNets with more advanced settings. ConvNeXt reviews the design space and proposes pure ConvNets, which can be as scalable as ViT and its variants. RepLKNet proposes to scale up the kernel size of existing ConvNets to bring improvements.
Though these dominant follow-up works demonstrate superior performance and better efficiency on ImageNet and various downstream tasks by introducing 2D priors, with the surge of large-scale visual pretraining and multi-modality applications, the vanilla Transformer-style model has returned to the center stage of computer vision. Its advantages of larger modeling capacity, unified multi-modality representation, friendliness to self-supervised learning, etc., make it the preferred architecture. However, the number of visual tokens is limited due to the quadratic complexity of the Transformer. There are plenty of works addressing this long-standing and prominent challenge, but few of them focus on visual applications. Recently, LongViT built an efficient Transformer architecture for computational pathology applications via dilated attention. The linear computation complexity of LongViT allows it to encode extremely long visual sequences. In this work, we draw inspiration from Mamba and explore building a pure-SSM-based model as a generic vision backbone without using attention, while preserving the sequential, modality-agnostic modeling merit of ViT.
State space models for long sequence modeling. The Structured State-Space Sequence (S4) model was proposed as a novel alternative to CNNs or Transformers for modeling long-range dependencies. Its promising property of scaling linearly in sequence length attracted further exploration. The S5 layer introduces MIMO SSM and an efficient parallel scan into the S4 layer. The H3 layer nearly closes the performance gap between SSMs and Transformer attention in language modeling. The Gated State Space layer builds on S4 by introducing more gating units to improve expressivity. Recently, Mamba proposed a data-dependent SSM layer and built a generic language model backbone that outperforms Transformers at various sizes on large-scale real data and enjoys linear scaling in sequence length. In this work, we explore transferring the success of Mamba to vision, i.e., building a generic vision backbone purely upon SSM without attention.
State space models for visual applications. Prior work uses 1D S4 to handle long-range temporal dependencies for video classification, and 1D S4 has been further extended to handle multi-dimensional data including 2D images and 3D videos. The TranS4mer model combines the strengths of S4 and self-attention, achieving state-of-the-art performance for movie scene detection. A selectivity mechanism introduced to S4 largely improves its performance on long-form video understanding with a much lower memory footprint. Another work supplants attention mechanisms with a more scalable SSM-based backbone to generate high-resolution images and process fine-grained representations under affordable computation. U-Mamba is a hybrid CNN-SSM architecture that handles long-range dependencies in biomedical image segmentation. The above works either apply SSM to specific visual applications or build a hybrid architecture by combining SSM with convolution or attention. Different from them, we build a pure-SSM-based model, which can be adopted as a generic vision backbone.
Method
The goal of Vision Mamba (Vim) is to introduce the advanced state space model (SSM), i.e., Mamba, to computer vision. This section begins with a description of the preliminaries of SSM. It is followed by an overview of Vim. We then detail how the Vim block processes input token sequences and proceed to illustrate the architecture details of Vim. The section concludes with an analysis of the efficiency of the proposed Vim.
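S4 and Mamba are discrete versions of the continuous system (Eq. (1)) $h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$, $y(t) = \mathbf{C}h(t)$, which maps an input $x(t)$ to an output $y(t)$ through a hidden state $h(t)$. They include a timescale parameter $\mathbf{\Delta}$ to transform the continuous parameters $\mathbf{A}$, $\mathbf{B}$ into discrete parameters $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$. The commonly used method for this transformation is zero-order hold (ZOH), which is defined as follows:
$$\overline{\mathbf{A}} = \exp(\mathbf{\Delta}\mathbf{A}), \qquad \overline{\mathbf{B}} = (\mathbf{\Delta}\mathbf{A})^{-1}\left(\exp(\mathbf{\Delta}\mathbf{A}) - \mathbf{I}\right)\cdot\mathbf{\Delta}\mathbf{B}. \tag{2}$$
After the discretization of $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, the discretized version of Eq. (1) using a step size $\mathbf{\Delta}$ can be rewritten as:
$$h_t = \overline{\mathbf{A}}h_{t-1} + \overline{\mathbf{B}}x_t, \qquad y_t = \mathbf{C}h_t. \tag{3}$$
At last, the models compute the output through a global convolution:
$$\overline{\mathbf{K}} = \left(\mathbf{C}\overline{\mathbf{B}},\ \mathbf{C}\overline{\mathbf{A}}\,\overline{\mathbf{B}},\ \dots,\ \mathbf{C}\overline{\mathbf{A}}^{M-1}\overline{\mathbf{B}}\right), \qquad \mathbf{y} = \mathbf{x} * \overline{\mathbf{K}}, \tag{4}$$
where $M$ is the length of the input sequence $\mathbf{x}$ and $\overline{\mathbf{K}} \in \mathbb{R}^{M}$ is a structured convolutional kernel.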
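The relationship between zero-order-hold discretization, the step-by-step recurrence, and the global convolution can be checked numerically. The following is a minimal sketch, not the authors' implementation: it assumes a diagonal state matrix (so the matrix exponential and inverse reduce to element-wise operations), and the function names `zoh_discretize`, `ssm_recurrent`, and `ssm_convolutional` are hypothetical.

```python
import math

def zoh_discretize(a_diag, b, delta):
    """Zero-order hold for a diagonal A (assumption): returns (A_bar, B_bar).

    For a diagonal entry a: A_bar = exp(delta * a) and
    B_bar = (delta * a)^-1 * (exp(delta * a) - 1) * delta * b = (A_bar - 1) / a * b.
    """
    a_bar = [math.exp(delta * a) for a in a_diag]
    b_bar = [(ab - 1.0) / a * bn for ab, a, bn in zip(a_bar, a_diag, b)]
    return a_bar, b_bar

def ssm_recurrent(a_bar, b_bar, c, x):
    """Recurrent form: h_t = A_bar * h_{t-1} + B_bar * x_t, y_t = C . h_t."""
    h = [0.0] * len(a_bar)
    ys = []
    for x_t in x:
        h = [ab * hn + bb * x_t for ab, bb, hn in zip(a_bar, b_bar, h)]
        ys.append(sum(cn * hn for cn, hn in zip(c, h)))
    return ys

def ssm_convolutional(a_bar, b_bar, c, x):
    """Convolutional form: build the kernel K_k = C A_bar^k B_bar, then y = x * K."""
    m = len(x)
    kernel = [
        sum(cn * ab ** k * bb for cn, ab, bb in zip(c, a_bar, b_bar))
        for k in range(m)
    ]
    # Causal convolution: y_t = sum_{k=0..t} K_k * x_{t-k}
    return [sum(kernel[k] * x[t - k] for k in range(t + 1)) for t in range(m)]

# Example values chosen only for illustration.
a_bar, b_bar = zoh_discretize([-1.0, -2.0], [1.0, 0.5], delta=0.1)
signal = [1.0, 0.0, 2.0, -1.0]
y_rec = ssm_recurrent(a_bar, b_bar, [0.3, -0.7], signal)
y_conv = ssm_convolutional(a_bar, b_bar, [0.3, -0.7], signal)
```

With fixed (time-invariant) parameters the two forms agree up to floating-point error, which is why such models can train with a convolution and infer with a recurrence.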
3.2 Vision Mamba
Here, $\mathbf{Vim}$ denotes the proposed Vision Mamba block, $L$ is the number of layers, and $\mathbf{Norm}$ is the normalization layer.
3.3 Vim Block
The original Mamba block is designed for the 1-D sequence, which is not suitable for vision tasks requiring spatial-aware understanding. In this section, we introduce the Vim block, which incorporates the bidirectional sequence modeling for the vision tasks. The Vim block is shown in Fig. 2.
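Specifically, we present the operations of the Vim block in Algo. 1. The input token sequence $\mathbf{T}_{l-1}$ is first normalized by the normalization layer. Next, we linearly project the normalized sequence to $\mathbf{x}$ and $\mathbf{z}$ with dimension size $E$. Then, we process $\mathbf{x}$ in the forward and backward directions (the subscript $o$ denotes the direction). For each direction, we first apply a 1-D convolution to $\mathbf{x}$ to get $\mathbf{x}'_o$. We then linearly project $\mathbf{x}'_o$ to $\mathbf{B}_o$, $\mathbf{C}_o$, and $\mathbf{\Delta}_o$, respectively. $\mathbf{\Delta}_o$ is then used to obtain the discretized $\overline{\mathbf{A}}_o$ and $\overline{\mathbf{B}}_o$. Finally, we compute $\mathbf{y}_{forward}$ and $\mathbf{y}_{backward}$ through the SSM. $\mathbf{y}_{forward}$ and $\mathbf{y}_{backward}$ are then gated by $\mathbf{z}$ and added together to get the output token sequence $\mathbf{T}_l$.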
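The data flow of the block described above can be sketched in plain Python. This is an illustrative toy, not the authors' code: the kernel size of 2 for the depthwise convolution, the SiLU and softplus nonlinearities, the simplified discretization ($\overline{A} = \exp(\Delta A)$, $\overline{B} = \Delta B$), the residual connection, the random initialization, and all names (`vim_block`, `selective_scan`, `init_params`) are assumptions made for the example.

```python
import math
import random

def silu(v):
    return v / (1.0 + math.exp(-v))

def softplus(v):
    return math.log1p(math.exp(v))

def linear(w, vec):
    """Apply a weight matrix (list of rows) to a vector."""
    return [sum(wi * vi for wi, vi in zip(row, vec)) for row in w]

def rand_matrix(rng, rows, cols):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

def layer_norm(vec, eps=1e-5):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def selective_scan(xs, p):
    """One direction: causal depthwise conv, data-dependent (B, C, Delta), SSM recurrence."""
    e_dim, n_dim = len(p["a"]), len(p["a"][0])
    h = [[0.0] * n_dim for _ in range(e_dim)]
    ys = []
    for t in range(len(xs)):
        prev = xs[t - 1] if t > 0 else [0.0] * e_dim
        # Depthwise causal Conv1d (kernel size 2, an assumption) followed by SiLU.
        xc = [silu(k0 * cur + k1 * q) for (k0, k1), cur, q in zip(p["conv"], xs[t], prev)]
        b_t = linear(p["w_b"], xc)                              # B_o, size N
        c_t = linear(p["w_c"], xc)                              # C_o, size N
        delta = [softplus(d) for d in linear(p["w_delta"], xc)]  # Delta_o, size E
        y_t = []
        for e in range(e_dim):
            for n in range(n_dim):
                a_bar = math.exp(delta[e] * p["a"][e][n])  # discretized A
                b_bar = delta[e] * b_t[n]                  # simplified discretized B
                h[e][n] = a_bar * h[e][n] + b_bar * xc[e]
            y_t.append(sum(c_t[n] * h[e][n] for n in range(n_dim)))
        ys.append(y_t)
    return ys

def init_params(rng, d, e, n):
    def direction():
        return {
            "conv": [(rng.uniform(-0.5, 0.5), rng.uniform(-0.5, 0.5)) for _ in range(e)],
            "w_b": rand_matrix(rng, n, e),
            "w_c": rand_matrix(rng, n, e),
            "w_delta": rand_matrix(rng, e, e),
            "a": [[-(k + 1.0) for k in range(n)] for _ in range(e)],  # negative for stability
        }
    return {"w_x": rand_matrix(rng, e, d), "w_z": rand_matrix(rng, e, d),
            "w_out": rand_matrix(rng, d, e), "fwd": direction(), "bwd": direction()}

def vim_block(tokens, p):
    """Norm -> project to x and z -> forward and backward scans -> gate by z -> project back."""
    normed = [layer_norm(t) for t in tokens]
    xs = [linear(p["w_x"], t) for t in normed]
    zs = [linear(p["w_z"], t) for t in normed]
    y_fwd = selective_scan(xs, p["fwd"])
    y_bwd = selective_scan(xs[::-1], p["bwd"])[::-1]  # scan the reversed sequence, restore order
    out = []
    for t, tok in enumerate(tokens):
        gated = [(yf + yb) * silu(z) for yf, yb, z in zip(y_fwd[t], y_bwd[t], zs[t])]
        out.append([r + o for r, o in zip(tok, linear(p["w_out"], gated))])  # residual (assumed)
    return out
```

Because the backward branch scans the reversed sequence, the output at the first position depends on the last token, which a forward-only scan cannot provide.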
3.4 Architecture Details
In summary, the hyper-parameters of our architecture are listed as follows: $L$, the number of blocks; $D$, the hidden state dimension; $E$, the expanded state dimension; and $N$, the SSM dimension.
Following ViT and DeiT, we first employ a 16×16 kernel size projection layer to get a 1-D sequence of non-overlapping patch embeddings. Subsequently, we directly stack $L$ Vim blocks. By default, we set the number of blocks $L$ to 24 and the SSM dimension $N$ to 16. To align with the model sizes of the DeiT series, we set the hidden state dimension $D$ to 192 and the expanded state dimension $E$ to 384 for the tiny-size variant. For the small-size variant, we set $D$ to 384 and $E$ to 768.
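The hyper-parameters above can be collected into a small configuration object. This is a sketch for illustration only: the class name `VimConfig`, its field names, and the helper `sequence_length` are hypothetical, and the 224-pixel input in the usage line is an assumed example resolution rather than a value taken from this section.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class VimConfig:
    num_blocks: int = 24     # number of stacked Vim blocks
    hidden_dim: int = 192    # hidden state dimension
    expanded_dim: int = 384  # expanded state dimension
    ssm_dim: int = 16        # SSM dimension
    patch_size: int = 16     # 16x16 non-overlapping patches

    def sequence_length(self, image_size: int, with_class_token: bool = True) -> int:
        """Number of tokens for a square image split into non-overlapping patches."""
        if image_size % self.patch_size != 0:
            raise ValueError("image size must be a multiple of the patch size")
        per_side = image_size // self.patch_size
        return per_side * per_side + (1 if with_class_token else 0)

# Tiny and small variants as described in the text.
VIM_TINY = VimConfig()
VIM_SMALL = VimConfig(hidden_dim=384, expanded_dim=768)

# Example: token count for an assumed 224x224 input (patches plus one class token).
tiny_tokens = VIM_TINY.sequence_length(224)
```

Both variants keep the expanded dimension at twice the hidden dimension, so only the width changes between them.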
3.5 Efficiency Analysis
Traditional SSM-based methods leverage the fast Fourier transform to accelerate the convolution operation shown in Eq. (4). For data-dependent methods, such as Mamba, the SSM operation in Line 11 of Algo. 1 is no longer equivalent to a convolution. To address this problem, Mamba and the proposed Vim choose a modern-hardware-friendly way to ensure efficiency. The key idea of this optimization is to avoid the IO and memory bottlenecks of modern hardware accelerators (GPUs).
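IO-Efficiency. The high bandwidth memory (HBM) and SRAM are two important components of GPUs. Among them, SRAM has a larger bandwidth and HBM has a bigger memory size. The standard implementation of Vim’s SSM operation with HBM requires a number of memory IOs on the order of $O(\mathtt{B}\mathtt{M}\mathtt{E}\mathtt{N})$, where $\mathtt{B}$ is the batch size and $\mathtt{M}$ is the sequence length. Inspired by Mamba, Vim first reads in $O(\mathtt{B}\mathtt{M}\mathtt{E} + \mathtt{E}\mathtt{N})$ bytes of memory $(\mathbf{\Delta}_o, \mathbf{A}_o, \mathbf{B}_o, \mathbf{C}_o)$ from slow HBM to fast SRAM. Then, Vim gets the discrete $\overline{\mathbf{A}}_o$, $\overline{\mathbf{B}}_o$ of size $(\mathtt{B}, \mathtt{M}, \mathtt{E}, \mathtt{N})$ in SRAM. Last, Vim performs the SSM operations in SRAM and writes the output of size $(\mathtt{B}, \mathtt{M}, \mathtt{E})$ back to HBM. This method helps to reduce IOs from $O(\mathtt{B}\mathtt{M}\mathtt{E}\mathtt{N})$ to $O(\mathtt{B}\mathtt{M}\mathtt{E} + \mathtt{E}\mathtt{N})$.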
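To give a feeling for the size of this saving, the sketch below counts tensor elements moved between HBM and SRAM under a simplified cost model. It assumes, following Mamba's analysis, that the naive scheme materializes the discretized parameters of size batch × length × expanded dimension × state dimension in HBM, while the fused scheme only moves terms of size batch × length × expanded dimension plus expanded dimension × state dimension. The function names and the example sizes are hypothetical, and constant factors are ignored.

```python
def naive_io(batch: int, seq_len: int, expanded_dim: int, ssm_dim: int) -> int:
    """Element count when the discretized parameters are materialized in HBM."""
    return batch * seq_len * expanded_dim * ssm_dim

def fused_io(batch: int, seq_len: int, expanded_dim: int, ssm_dim: int) -> int:
    """Element count when discretization and the scan are fused in SRAM."""
    return batch * seq_len * expanded_dim + expanded_dim * ssm_dim

# Example: one image, 197 tokens, expanded dimension 384, state dimension 16.
naive = naive_io(1, 197, 384, 16)
fused = fused_io(1, 197, 384, 16)
reduction = naive / fused  # approaches the state dimension for long sequences
```

Under this model the reduction factor approaches the SSM state dimension as the sequence grows, since the second term of the fused count does not depend on sequence length.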
Memory-Efficiency. To avoid out-of-memory problems and achieve lower memory usage when dealing with long sequences, Vim chooses the same recomputation method as Mamba. The intermediate states of size $(\mathtt{B}, \mathtt{M}, \mathtt{E}, \mathtt{N})$ that are needed to calculate the gradient are not stored; Vim recomputes them in the backward pass of the network. For intermediate activations such as the outputs of activation functions and convolutions, Vim also recomputes them to reduce the GPU memory requirement, as these activation values take a lot of memory but are fast to recompute.
Computation-Efficiency. The SSM in the Vim block (Line 11 in Algo. 1) and self-attention in the Transformer both play a key role in providing global context adaptively. Given a visual sequence of length $M$ with hidden dimension $D$ and the default setting $E = 2D$, the computation complexity of global self-attention is quadratic in the sequence length $M$, whereas that of the SSM is linear in $M$ ($N$ is a fixed parameter, set to 16 by default). This computational efficiency makes Vim scalable for gigapixel applications with large sequence lengths.
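The difference between quadratic and linear scaling can be made concrete with rough operation counts. The sketch below is an assumption-laden estimate: for self-attention it uses the commonly cited count of $4MD^2 + 2M^2D$ for a global multi-head self-attention layer, and for the SSM it assumes a constant number of operations (here 4, a hypothetical constant) per token, channel, and state entry. The function names are illustrative.

```python
def self_attention_flops(m: int, d: int) -> int:
    """Common estimate for global self-attention: projections (4*M*D^2) + attention (2*M^2*D)."""
    return 4 * m * d * d + 2 * m * m * d

def ssm_scan_flops(m: int, e: int, n: int, ops_per_state: int = 4) -> int:
    """Assumed cost model: a fixed number of ops per (token, channel, state) entry."""
    return ops_per_state * m * e * n

def tokens_for(image_size: int, patch_size: int = 16) -> int:
    """Sequence length for a square image with non-overlapping patches."""
    return (image_size // patch_size) ** 2

# Tiny-size setting from the text: D = 192, E = 2D = 384, N = 16.
for size in (512, 1248):
    m = tokens_for(size)
    print(size, m, self_attention_flops(m, 192), ssm_scan_flops(m, 384, 16))
```

Doubling the sequence length doubles the SSM estimate but more than doubles the self-attention estimate, and the gap widens as resolution grows.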
Experiment
Settings. We benchmark Vim on the ImageNet-1K dataset , which contains 1.28M training images and 50K validation images from 1,000 categories. All models are trained on the training set, and top-1 accuracy on the validation set is reported. For fair comparisons, our training settings mainly follow DeiT . Specifically, we apply random cropping, random horizontal flipping, label-smoothing regularization, mixup, and random erasing as data augmentations. When training on input images, we employ AdamW with a momentum of , a total batch size of , and a weight decay of to optimize models. We train the Vim models for epochs using a cosine schedule, initial learning rate, and EMA. During testing, we apply a center crop on the validation set to crop out images. Experiments are performed on 8 A800 GPUs.
Long Sequence Fine-tuning. To make full use of the efficient long sequence modeling power of Vim, we continue to fine-tune Vim with a long sequence setting for 30 epochs after ImageNet pretraining. Specifically, we set a patch extraction stride of while keeping the patch size unchanged, a constant learning rate of , and a weight decay of .
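Reducing the patch extraction stride below the patch size makes neighbouring patches overlap and lengthens the token sequence. The sketch below shows the arithmetic under stated assumptions: square images, no padding, and the usual sliding-window count. The function name `num_patches`, the 224-pixel image, and the stride of 8 are hypothetical values chosen for illustration, not settings taken from the text.

```python
def num_patches(image_size: int, patch_size: int, stride: int) -> int:
    """Number of patches from a square image using a sliding window without padding."""
    if stride <= 0 or image_size < patch_size:
        raise ValueError("stride must be positive and the image at least one patch wide")
    per_side = (image_size - patch_size) // stride + 1
    return per_side * per_side

# Non-overlapping patches: stride equals the patch size.
baseline = num_patches(224, 16, 16)
# Overlapping patches with an illustrative smaller stride.
long_sequence = num_patches(224, 16, 8)
```

Halving the stride in this example roughly quadruples the sequence length, which is where a backbone with linear scaling in sequence length becomes attractive.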
Results. Tab. 1 compares Vim with ConvNet-based, Transformer-based, and SSM-based backbone networks. Compared to the ConvNet-based ResNet, Vim demonstrates superior performance. For example, when the parameters are roughly similar, the top-1 accuracy of Vim-Small reaches 80.5, which is 4.3 points higher than that of ResNet50. Compared with the conventional self-attention-based ViT, Vim outperforms it by considerable margins in terms of both parameter numbers and classification accuracy. When compared to the highly-optimized ViT variant, i.e., DeiT, Vim surpasses it at different scales with comparable parameter numbers: 3.9 points higher for Vim-Tiny over DeiT-Tiny, and 0.7 points higher for Vim-Small over DeiT-Small. Compared with the SSM-based S4ND-ViT-B, Vim achieves higher top-1 accuracy with 3× fewer parameters. After long sequence fine-tuning, Vim-Tiny† and Vim-S† both achieve higher results. Among them, Vim-S† even achieves results similar to DeiT-B. The results demonstrate that Vim can be adapted to longer sequence modeling easily and extract stronger visual representations.
Fig. 1 (b) and (c) compare the FPS and GPU memory of tiny-size Vim and DeiT. Vim demonstrates better efficiency in speed and memory as image resolution grows. Specifically, when the image size is 512×512, Vim achieves similar FPS and memory as DeiT. As the image size grows to 1248×1248, Vim is 2.8× faster than DeiT and saves 86.8% GPU memory. The pronounced superiority of Vim’s linear scaling in sequence length makes it ready for high-resolution downstream vision applications and long-sequence multi-modality applications.
4.2 Semantic Segmentation
Settings. We conduct experiments for semantic segmentation on the ADE20K and use UperNet as the segmentation framework. We provide detailed settings in Sec. B.
Results. As shown in Tab. 2, Vim consistently outperforms DeiT across different scales: 1.8 mIoU higher for Vim-Ti over DeiT-Ti, and 0.9 mIoU higher for Vim-S over DeiT-S. Compared to the ResNet-101 backbone, our Vim-S achieves the same segmentation performance with nearly 2× fewer parameters.
To further evaluate the efficiency for downstream tasks, i.e., segmentation, detection, and instance segmentation, we combine the backbones with a commonly used feature pyramid network (FPN) module and benchmark their FPS and GPU memory. As shown in Fig. 4 and Fig. 3, the efficiency curves show comparison results similar to those of the pure backbone (Fig. 1), even though we append a heavy FPN to the backbones. The exceptional linear scaling performance is attributed to our proposed efficient backbone Vim, which builds the foundation for learning gigapixel-level visual representations (e.g., for aerial images, medical images, and computational pathology) in an end-to-end manner without the need for multi-stage encoding.
4.3 Object Detection and Instance Segmentation
Settings. We conduct experiments for object detection and instance segmentation on the COCO 2017 dataset and use ViTDet as the basic framework. We provide detailed settings in Sec. B.
Results. Tab. 3 compares Vim-Ti with DeiT-Ti using the Cascade Mask R-CNN framework. Vim-Ti surpasses DeiT-Ti by 1.3 box AP and 1.1 mask AP. For middle-size and large-size objects, Vim-Ti outperforms DeiT-Ti by 1.6 $\text{AP}^{\text{box}}_{m}$/1.3 $\text{AP}^{\text{mask}}_{m}$ and 1.4 $\text{AP}^{\text{box}}_{l}$/1.8 $\text{AP}^{\text{mask}}_{l}$, demonstrating better long-range context learning than DeiT (Fig. 5).
We highlight that the accuracy superiority is non-trivial since DeiT is equipped with window attention while Vim works in a pure sequence modeling manner. Specifically, to perform representation learning on high-resolution images (i.e., 1024×1024), we follow ViTDet and modify the DeiT backbone with the use of 2D window attention, which injects a 2D prior and breaks the sequential modeling nature of the Transformer. Thanks to the efficiency illustrated in Sec. 3.5, Fig. 1 and Fig. 3, we can directly apply Vim on 1024×1024 input images and learn sequential visual representations for object detection and instance segmentation without the need for 2D priors in the backbone.
4.4 Ablation Study
Bidirectional SSM. We ablate the key bidirectional design of Vim, using ImageNet-1K classification and the Segmenter semantic segmentation framework on ADE20K. To fully evaluate the power of learned representation on ImageNet, we use a simple Segmenter head with only 2 layers to perform transfer learning on semantic segmentation. We study these bidirectional strategies:
None. We directly adopt the Mamba block to process visual sequence with only the forward direction.
Bidirectional Sequence. During training, we randomly flip the visual sequence. This works like data augmentation.
Bidirectional Block. We pair the stacked blocks. The first block of each pair processes visual sequence in the forward direction and the second block of each pair processes in the backward direction.
Bidirectional SSM. We add an extra SSM for each block to process the visual sequence in the backward direction.
Bidirectional SSM + Conv1d. Based on Bidirectional SSM, we further add a backward Conv1d before the backward SSM (Fig. 2).
As shown in Tab. 4, directly adopting the Mamba block achieves good performance in classification. However, its unidirectional processing is unnatural for images and poses challenges in downstream dense prediction. Specifically, the preliminary bidirectional strategy of using Bidirectional Block achieves 7 points lower top-1 accuracy on classification. Yet, it outperforms the vanilla unidirectional Mamba block by 1.3 mIoU on semantic segmentation. By adding the extra backward SSM and Conv1d, we achieve higher classification accuracy (73.9 top-1 acc vs. 73.2 top-1 acc) and substantially better segmentation (35.9 mIoU vs. 32.3 mIoU). We use the strategy of Bidirectional SSM + Conv1d as the default setting in our Vim block.
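The difference between the unidirectional and bidirectional strategies can be illustrated with a toy scalar scan. This is a deliberately simplified sketch, not the Vim block itself: it uses a fixed decay instead of data-dependent parameters, and the function names are hypothetical. It shows that a forward-only scan gives early tokens no access to later ones, while adding a backward scan (or pairing a forward and a backward block) does.

```python
def scan(xs, decay=0.5):
    """Forward linear recurrence: h_t = decay * h_{t-1} + x_t."""
    h, out = 0.0, []
    for x in xs:
        h = decay * h + x
        out.append(h)
    return out

def backward_scan(xs, decay=0.5):
    """Same recurrence run from the last token to the first."""
    return scan(xs[::-1], decay)[::-1]

def bidirectional_ssm(xs, decay=0.5):
    """'Bidirectional SSM' style: forward and backward scans in one block, summed."""
    return [f + b for f, b in zip(scan(xs, decay), backward_scan(xs, decay))]

def bidirectional_block_pair(xs, decay=0.5):
    """'Bidirectional Block' style: a forward block followed by a backward block."""
    return backward_scan(scan(xs, decay), decay)

# An impulse at the last position: does the first position ever see it?
impulse = [0.0, 0.0, 0.0, 1.0]
forward_only = scan(impulse)
both_directions = bidirectional_ssm(impulse)
paired_blocks = bidirectional_block_pair(impulse)
```

In the forward-only result the first position stays at zero, whereas both bidirectional variants propagate the impulse back to it, which is the kind of global context dense prediction relies on.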
Classification Design. We ablate the classification design of Vim, benchmarking on ImageNet-1K classification. We study the following classification strategies:
Mean pool. We adopt mean pooling on the output feature from the last Vim block and perform classification on this pooled feature.
Max pool. We first adapt the classification head on each token of the visual sequence and then perform max pooling on the sequence to get the classification prediction result.
Head class token. Following DeiT, we concatenate the class token at the head of the visual sequence and perform classification.
Double class token. Based on the head class token strategy, we additionally add a class token at the tail of the visual sequence.
Middle class token. We add a class token at the middle of the visual sequence and then perform classification on the final middle class token.
As shown in Tab. 5, the middle class token strategy can fully exploit the recurrent nature of SSM and the central object prior in ImageNet, achieving the best top-1 accuracy of 76.1.
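The three class-token strategies differ only in where the token is placed in the patch sequence and which positions are read out for classification. The sketch below illustrates this with plain lists; the function name `insert_class_token`, the strategy labels, and the choice to return the read-out indices are assumptions made for the example.

```python
def insert_class_token(tokens, cls_token, strategy):
    """Return (sequence, class_token_positions) for a given placement strategy."""
    if strategy == "head":
        return [cls_token] + tokens, [0]
    if strategy == "middle":
        mid = len(tokens) // 2
        return tokens[:mid] + [cls_token] + tokens[mid:], [mid]
    if strategy == "double":
        # One class token at the head and one at the tail of the sequence.
        return [cls_token] + tokens + [cls_token], [0, len(tokens) + 1]
    raise ValueError(f"unknown strategy: {strategy}")

# Example with four placeholder patch tokens.
patches = ["p0", "p1", "p2", "p3"]
sequence, positions = insert_class_token(patches, "CLS", "middle")
```

With a middle class token, a bidirectional scan reaches the token after traversing half the patches from each side, which is one way to read the text's remark about exploiting the recurrent nature of the SSM.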
Conclusion and Future Work
We have proposed Vision Mamba (Vim) to explore the very recent efficient state space model, i.e., Mamba, as generic vision backbones. Unlike prior state space models for vision tasks which use hybrid architecture or equivalent global 2D convolutional kernel, Vim learns visual representation in the sequence modeling manner and does not introduce image-specific inductive biases. Thanks to the proposed bidirectional state space modeling, Vim achieves data-dependent global visual context and enjoys the same modeling power as Transformer, while having lower computation complexity. Benefiting from the hardware-aware designs of Mamba, the inference speed and memory usage of Vim are significantly better than ViTs when processing high-resolution images. Experiment results on standard computer vision benchmarks have verified the modeling power and high efficiency of Vim, showing that Vim has great potential to be the next-generation vision backbone.
In future work, Vim, with its bidirectional SSM modeling and position embeddings, is suitable for unsupervised tasks such as masked image modeling pretraining, and its architecture, which is similar to Mamba, enables multimodal tasks such as CLIP-style pretraining. Based on the pretrained Vim weights, it is straightforward to explore the usefulness of Vim for analyzing high-resolution medical images, remote sensing images, and long videos, which can be regarded as downstream tasks.
Acknowledgement
We would like to acknowledge Tianheng Cheng, Yuxin Fang, Shusheng Yang, Bo Jiang, and Jingfeng Yao for their helpful feedback on the draft.
References
Appendix A Visualization
Appendix B Additional Setting
Settings for Semantic Segmentation. We conduct experiments for semantic segmentation on the ADE20K dataset. ADE20K contains 150 fine-grained semantic categories, with 20K, 2K, and 3K images for training, validation, and testing, respectively. We choose UperNet as our base framework. In training, we employ AdamW with a weight decay of , and a total batch size of to optimize models. The employed training schedule uses an initial learning rate of , linear learning rate decay, a linear warmup of iterations, and a total training of K iterations. The data augmentations follow common settings, including random horizontal flipping, random re-scaling within the ratio range , and random photometric distortion. During evaluation, we rescale the image to have a shorter side of .
Settings for Object Detection and Instance Segmentation. We conduct experiments for object detection and instance segmentation on the COCO 2017 dataset. The COCO 2017 dataset contains 118K images for training, 5K images for validating, and 20K images for testing. We use the canonical Cascade Mask R-CNN as the base framework. For ViT-based backbones, we apply extra configurations (e.g., interleaved window & global attention) to handle the high-resolution images following ViTDet. For SSM-based Vim, we directly use it without any modifications. Other training and evaluation settings are just the same. During training, we employ AdamW with a weight decay of , and a total batch size of to optimize models. The employed training schedule uses an initial learning rate of , linear learning rate decay, and a total training of K iterations. The data augmentations use large-scale jitter data augmentation to 1024×1024 input images. During evaluation, we rescale the image to have a shorter side of 1024.