SegMamba: Long-range Sequential Modeling Mamba For 3D Medical Image Segmentation
Zhaohu Xing, Tian Ye, Yijun Yang, Guang Liu, Lei Zhu
Introduction
Extending the model’s receptive field is a critical aspect of 3D medical image segmentation. Conventional convolutional neural networks (CNNs) are not very effective at extracting global information from high-resolution 3D medical images. Hence, depth-wise convolution with a large kernel size has been proposed to model a broader range of features. 3D UX-Net introduces a new architecture built on a convolution block with a large kernel to obtain larger receptive fields. However, CNN-based methods still struggle to model relationships between voxels that are far apart.
Recently, the transformer architecture, which uses a self-attention module to extract global information, has been extensively explored for 3D medical image segmentation. TransBTS uses a 3D CNN to extract local spatial features and then applies a transformer to model global dependencies in the high-level features. UNETR employs the Vision Transformer (ViT) as its encoder to learn contextual information, which is then merged with a CNN-based decoder via skip connections at multiple resolutions. SwinUNETR uses the Swin Transformer as its encoder to extract multi-scale features and designs a multi-scale decoder that fuses the features from each encoder stage, achieving promising results in 3D medical image segmentation. However, because self-attention scales quadratically with the number of tokens, the high resolution of 3D medical images imposes a heavy computational burden on transformer-based methods and slows them down.
To overcome the challenges of long-sequence modeling, Mamba, which originates from state space models (SSMs), is designed to model long-range dependencies and to improve training and inference efficiency through a selection mechanism and a hardware-aware algorithm. Numerous studies have explored applications of Mamba in computer vision (CV). U-Mamba integrates the Mamba layer into the encoder of nnU-Net to enhance general medical image segmentation. Vision Mamba proposes the Vim block, which incorporates bidirectional SSMs for data-dependent global visual context modeling and position embeddings for location-aware visual understanding. Additionally, VMamba designs a cross-scan module (CSM) to bridge the gap between 1-D array scanning and 2-D plane traversal. In 3D medical image segmentation, the traditional transformer block has difficulty handling large feature maps, yet modeling the correlations within these high-dimensional features is necessary for stronger visual understanding. Motivated by this, we introduce SegMamba, a novel architecture that combines a U-shaped structure with Mamba to model whole-volume global features at various scales. To our knowledge, this is the first method to use Mamba specifically for 3D medical image segmentation. Compared with traditional CNN-based and transformer-based methods, SegMamba models long-range dependencies within volumetric data while maintaining high inference efficiency. Extensive experiments demonstrate the effectiveness of our method.
Method
SegMamba mainly consists of three components: 1) the Mamba encoder with multiple Mamba blocks to extract features at different scales, 2) the 3D decoder based on the convolution layer for predicting segmentation results, and 3) the skip-connections to connect the multi-scale features to the decoder for feature reuse. Fig. 1 illustrates the overview of the proposed SegMamba. We further describe the details of the encoder and decoder in this section.
Modeling global features and multi-scale features is critically important for 3D medical image segmentation. While the transformer architecture can extract global information, it incurs a significant computational burden when dealing with very long feature sequences. To reduce the sequence length, transformer-based methods such as UNETR directly down-sample the input 3D medical image from a resolution of $D \times H \times W$ to $\frac{D}{16} \times \frac{H}{16} \times \frac{W}{16}$. However, this approach limits the ability to model multi-scale features, which are essential for predicting segmentation results via the decoder. To overcome this limitation, we design the Mamba block, which replaces the self-attention module of the transformer architecture with the more efficient Mamba layer. This enables both multi-scale and global feature modeling while maintaining high efficiency during training and inference.
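The sequence-length argument above can be made concrete with a little arithmetic. The sketch assumes a hypothetical 128×128×128 input (the training crop size is not restated here), counts tokens with no down-sampling and with UNETR-style 16× down-sampling per axis, and contrasts the quadratic number of token pairs that self-attention compares with the linear number of steps of a recurrent scan. Channels and constant factors are ignored, so these are orders of magnitude, not measured costs.

```python
def num_tokens(d: int, h: int, w: int, factor: int = 1) -> int:
    """Tokens after flattening a D x H x W volume down-sampled by `factor` per axis."""
    return (d // factor) * (h // factor) * (w // factor)


def attention_pairs(n: int) -> int:
    """Self-attention compares every token with every other token: O(n^2)."""
    return n * n


def scan_steps(n: int) -> int:
    """A recurrent (SSM-style) scan visits each token once: O(n)."""
    return n


full = num_tokens(128, 128, 128)        # full resolution (hypothetical input size)
coarse = num_tokens(128, 128, 128, 16)  # 16x down-sampling along each axis

print(full, coarse)                                    # 2097152 512
print(attention_pairs(full), attention_pairs(coarse))  # 4398046511104 262144
print(scan_steps(full))                                # 2097152
```

At full resolution, attention would need about four trillion pairwise scores per layer, which is why transformer encoders shrink the volume first; a linear-time scan keeps the full-resolution sequence tractable.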
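Formally, given the input features $z^{l-1}$ of the $l$-th Mamba block, the block computes
$$\hat{z}^{l} = \mathrm{Mamba}\left(\mathrm{LN}(z^{l-1})\right) + z^{l-1}, \qquad z^{l} = \mathrm{MLP}\left(\mathrm{LN}(\hat{z}^{l})\right) + \hat{z}^{l},$$
where $\mathrm{LN}$ denotes layer normalization of the input features, $\mathrm{Mamba}$ denotes the Mamba layer, and $\mathrm{MLP}$ denotes the multi-layer perceptron that enriches the feature representation.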
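The residual structure of the Mamba block can be illustrated with a deliberately small, pure-Python sketch. It is not SegMamba's code and not the `mamba_ssm` package: the token mixer below is a bare selective scan in which the step size Δ and the matrices B and C are computed from the current token while A is a fixed negative diagonal, and it omits the input/output projections, the short convolution, and the gating branch of the real Mamba layer. All weights, sizes, and names (`ToyMambaBlock`, `selective_scan`) are hypothetical, and the input is assumed to be the flattened volume scanned in one direction.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """Normalize one token's feature vector to zero mean, unit variance (no affine params)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def softplus(v):
    """log(1 + e^v), guarded against overflow for large v."""
    return v if v > 20.0 else math.log1p(math.exp(v))


class ToyMambaBlock:
    """Hypothetical pre-norm block: z_hat = z + Mixer(LN(z)); out = z_hat + MLP(LN(z_hat)).

    The mixer is a bare selective scan, NOT the full Mamba layer: input/output
    projections, the short convolution and the gating branch are omitted.
    """

    def __init__(self, channels, state=4, hidden=8, seed=0):
        rng = random.Random(seed)

        def g():
            return rng.gauss(0.0, 0.5)

        self.channels, self.state = channels, state
        # Fixed, negative, diagonal state matrix A: one row of `state` entries per channel.
        self.A = [[-(n + 1.0) for n in range(state)] for _ in range(channels)]
        # Parameters that make delta, B and C depend on the input (the "selection").
        self.w_delta = [g() for _ in range(channels)]
        self.b_delta = [g() for _ in range(channels)]
        self.W_B = [[g() for _ in range(channels)] for _ in range(state)]
        self.W_C = [[g() for _ in range(channels)] for _ in range(state)]
        # Two-layer MLP weights (ReLU in between).
        self.W1 = [[g() for _ in range(channels)] for _ in range(hidden)]
        self.W2 = [[g() for _ in range(hidden)] for _ in range(channels)]

    def selective_scan(self, xs):
        """Causal left-to-right scan over a list of tokens (each a list of channel values)."""
        h = [[0.0] * self.state for _ in range(self.channels)]
        out = []
        for x in xs:
            B = [sum(w * v for w, v in zip(row, x)) for row in self.W_B]
            C = [sum(w * v for w, v in zip(row, x)) for row in self.W_C]
            y = []
            for c in range(self.channels):
                delta = softplus(self.w_delta[c] * x[c] + self.b_delta[c])
                for n in range(self.state):
                    # h <- exp(delta * A) * h + delta * B * x  (ZOH for A, Euler step for B)
                    h[c][n] = math.exp(delta * self.A[c][n]) * h[c][n] + delta * B[n] * x[c]
                y.append(sum(C[n] * h[c][n] for n in range(self.state)))
            out.append(y)
        return out

    def mlp(self, x):
        hidden = [max(0.0, sum(w * v for w, v in zip(row, x))) for row in self.W1]
        return [sum(w * v for w, v in zip(row, hidden)) for row in self.W2]

    def __call__(self, z):
        mixed = self.selective_scan([layer_norm(t) for t in z])
        z_hat = [[a + b for a, b in zip(t, m)] for t, m in zip(z, mixed)]
        return [[a + b for a, b in zip(t, self.mlp(layer_norm(t)))] for t in z_hat]


rng = random.Random(42)
seq = [[rng.uniform(-1.0, 1.0) for _ in range(3)] for _ in range(6)]  # 6 tokens, 3 channels
block = ToyMambaBlock(channels=3, seed=1)
out = block(seq)
```

Because the scan only carries state forward, changing the last token leaves all earlier outputs untouched; the block keeps the input shape, so such blocks can be stacked at every encoder scale.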
2 Decoder
The Mamba encoder extracts multi-scale features. Following many previous studies, we use a CNN-based decoder and skip connections to form a U-shaped network that predicts the segmentation results.
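The U-shaped data flow can be sketched by tracking tensor shapes only. Everything numeric here is hypothetical: a 128×128×128 input, four encoder stages with widths 48, 96, 192, and 384 that halve the resolution at each stage, and three output channels matching WT, TC, and ET. SegMamba's actual widths and depth are not given in this section, and the function name `u_net_shapes` is illustrative.

```python
def u_net_shapes(in_shape, channels, num_classes):
    """Track (C, D, H, W) through a hypothetical U-shaped encoder/decoder.

    Encoder stage i (i > 0) halves every spatial axis; each decoder step upsamples x2,
    concatenates the matching skip feature along channels, and fuses back to its width.
    """
    d, h, w = in_shape
    skips = []
    for i, c in enumerate(channels):
        if i > 0:
            d, h, w = d // 2, h // 2, w // 2
        skips.append((c, d, h, w))

    steps = []
    current = skips[-1]  # deepest (lowest-resolution) encoder output
    for skip in reversed(skips[:-1]):
        up = (skip[0], current[1] * 2, current[2] * 2, current[3] * 2)
        if up[1:] != skip[1:]:
            raise ValueError("upsampled features and skip features must align spatially")
        concat = (up[0] + skip[0],) + skip[1:]  # channel-wise concatenation
        current = skip                           # a conv block maps back to the skip width
        steps.append((up, concat, current))
    head = (num_classes,) + current[1:]          # final conv to per-class logits
    return skips, steps, head


skips, steps, head = u_net_shapes((128, 128, 128), (48, 96, 192, 384), num_classes=3)
for step in steps:
    print(step)
print(head)  # (3, 128, 128, 128)
```

The skip connections are what let the decoder reuse full-resolution encoder features instead of recovering detail from the coarsest stage alone.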
Experiments
The BraTS2023 dataset contains a total of 1,251 3D brain MRI volumes. Each volume includes four modalities (T1, T1Gd, T2, and T2-FLAIR) and three segmentation targets (WT: Whole Tumor, ET: Enhancing Tumor, TC: Tumor Core). All volumes have been resampled to the same voxel spacing of (1.0, 1.0, 1.0) mm.
2 Evaluation Metrics
We adopt various metrics to compare the performance of our method against other methods, based on the characteristics of different datasets.
2.1 Dice score
The Dice score is an overlap metric that measures the proportion of overlap between the prediction and the ground truth. It is calculated as
$$\mathrm{Dice}(P, G) = \frac{2\,|P \cap G|}{|P| + |G|},$$
where $P$ represents the semantic prediction, $G$ represents the ground-truth annotation, and $|\cdot|$ represents the cardinality (the number of voxels in a set).
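A minimal sketch of the Dice formula for binary masks, assuming each mask is given as a flat sequence of 0/1 voxels (for BraTS, one mask per target region: WT, TC, ET). Returning 1.0 when both masks are empty is a convention assumed here; evaluation toolkits differ on that case.

```python
def dice_score(pred, target):
    """Dice = 2|P ∩ G| / (|P| + |G|) for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for p, g in zip(pred, target) if p and g)
    total = sum(1 for p in pred if p) + sum(1 for g in target if g)
    if total == 0:
        return 1.0  # both masks empty: treated as a perfect match (assumed convention)
    return 2.0 * intersection / total


print(dice_score([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.5
```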
2.2 95% Hausdorff distance (HD95)
The HD95 metric takes the 95th percentile of the Hausdorff distance, which gives a robust evaluation of the maximum distance between the prediction and the ground truth: unlike the plain maximum, it is not dominated by a few outlier voxels. It is defined as
$$\mathrm{HD95}(P, G) = \max_{k95\%}\left[d(P, G),\ d(G, P)\right],$$
where $d(P, G)$ represents the Hausdorff distance between $P$ and $G$, and $\max_{k95\%}$ denotes the maximum value at the 95th percentile.
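The sketch below computes HD95 by brute force on two small point sets. It assumes the points are voxel coordinates (typically boundary voxels) of the prediction and ground truth, isotropic spacing so distances are in voxel units, and NumPy-style linear interpolation for the percentile. It takes the 95th percentile of each directed nearest-neighbour distance set and then the maximum over the two directions; real toolkits use distance transforms for speed, and percentile conventions may differ slightly.

```python
import math


def percentile(values, q):
    """Linear-interpolation percentile (the same rule as NumPy's default), q in [0, 100]."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def directed_distances(a, b):
    """For every point in `a`, the Euclidean distance to its nearest point in `b`."""
    return [min(math.dist(p, q) for q in b) for p in a]


def hd95(pred_points, gt_points):
    """Symmetric HD95: the larger of the two directed 95th-percentile distances."""
    d_pg = percentile(directed_distances(pred_points, gt_points), 95)
    d_gp = percentile(directed_distances(gt_points, pred_points), 95)
    return max(d_pg, d_gp)


pred = [(0, 0, 0), (1, 0, 0)]
gt = [(0, 0, 0), (1, 0, 0), (5, 0, 0)]  # one outlier voxel at x = 5
print(hd95(pred, gt))  # about 3.6, whereas the plain Hausdorff distance is 4.0
```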
3 Comparison Methods
For a thorough evaluation, SegMamba is compared with seven other state-of-the-art methods covering both CNN and transformer architectures.
3.1 CNN-based methods
We compare SegMamba with SegResNet, UX-Net, and MedNeXt. Among these, UX-Net and MedNeXt are recent 3D medical image segmentation methods.
3.2 Transformer-based methods
We also compare SegMamba with state-of-the-art transformer-based segmentation methods in medical imaging. UNETR, SwinUNETR, SwinUNETR-V2, and nnFormer are widely used transformer-based methods for 3D medical image segmentation. They use Vision Transformer or Swin Transformer structures as encoders to model global features.
To ensure a fair comparison, we use the publicly available codes of all the methods in our experiments and keep all the settings the same.
4 Implementation Details
Our model is implemented in PyTorch 2.0.1 (CUDA 11.7) and MONAI 1.2.0. During training, we use random cropping and a batch size of 2 per GPU for each dataset. Since each volume in the BraTS2023 dataset contains four modalities, we concatenate the modalities along the channel dimension at the input of the network. We use the cross-entropy loss for all experiments and an SGD optimizer with a polynomial learning-rate scheduler (initial learning rate of 1e-2, decay of 1e-5). We train for 1000 epochs on all datasets and adopt the following data augmentations: additive brightness, gamma, rotation, scaling, mirroring, and elastic deformation. All experiments are conducted on a cloud computing platform with four NVIDIA A100 GPUs. During inference, we apply test-time augmentation (TTA) techniques, namely mirror prediction and overlapped sliding-window inference with an overlap ratio of 0.5, for all three datasets. For each dataset, we randomly allocate 70% of the 3D volumes for training, 10% for validation, and the remaining 20% for testing, so that each volume appears in exactly one of the training, validation, and testing sets.
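Three of the training and inference details above can be sketched in a few lines. These are illustrations under stated assumptions, not the authors' code: the polynomial exponent of 0.9 is an assumed nnU-Net-style default (the text gives only the initial learning rate), `window_starts` is a simplified one-axis version of overlapped sliding-window placement rather than MONAI's implementation, the 128-voxel window is hypothetical because the crop size is not restated here, 240 is the in-plane size of a BraTS volume, and the floor-based rounding in `split_volumes` is an assumption.

```python
import math
import random


def poly_lr(epoch, max_epochs=1000, base_lr=1e-2, power=0.9):
    """Polynomial decay: base_lr * (1 - epoch / max_epochs) ** power (power is assumed)."""
    return base_lr * (1 - epoch / max_epochs) ** power


def window_starts(size, roi, overlap=0.5):
    """Start indices of overlapping windows along one axis (simplified sketch)."""
    if size <= roi:
        return [0]
    step = max(int(roi * (1 - overlap)), 1)  # overlap 0.5 -> stride of half a window
    count = math.ceil((size - roi) / step) + 1
    # Clamp the last window so it ends exactly at the border instead of running past it.
    return [min(i * step, size - roi) for i in range(count)]


def split_volumes(ids, seed=0, train=0.7, val=0.1):
    """Shuffle once and cut into disjoint train/val/test lists (the remainder goes to test)."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    n_train = int(len(ids) * train)
    n_val = int(len(ids) * val)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


train_ids, val_ids, test_ids = split_volumes(range(1251))
print(len(train_ids), len(val_ids), len(test_ids))  # 875 125 251
print(window_starts(240, 128))                       # [0, 64, 112]
print(poly_lr(0), poly_lr(500))
```

With an overlap of 0.5, most voxels are covered by more than one window, and the overlapping predictions are averaged at inference time; mirror TTA (averaging predictions over flipped inputs) is not shown.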
5 Quantitative Comparison to Previous Methods
The glioma segmentation results on the BraTS2023 dataset are listed in Table 1, where the Dice score and HD95 are used to evaluate performance on the three segmentation targets: WT, TC, and ET. Among the comparison methods, the CNN-based UX-Net achieves the highest average Dice score (89.69%), with an average HD95 of 4.81. Among the transformer-based methods, SwinUNETR-V2 also performs well, with an average Dice score of 89.39% and an average HD95 of 4.51. In comparison, SegMamba achieves Dice scores of 93.61%, 92.65%, and 87.71% and HD95 values of 3.37, 3.85, and 3.48 on WT, TC, and ET, respectively, outperforming all other methods. The average Dice score of SegMamba is 91.32%, an improvement of 1.63 and 1.93 percentage points over the second-place UX-Net and the third-place SwinUNETR-V2, respectively.
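The averages and margins quoted above can be reproduced from the per-target numbers. The snippet only redoes that arithmetic with the values stated in this section; the average HD95 of SegMamba is derived here and is not reported in the text.

```python
from statistics import mean

segmamba_dice = {"WT": 93.61, "TC": 92.65, "ET": 87.71}
segmamba_hd95 = {"WT": 3.37, "TC": 3.85, "ET": 3.48}
baseline_avg_dice = {"UX-Net": 89.69, "SwinUNETR-V2": 89.39}

ours = mean(segmamba_dice.values())
print(f"SegMamba average Dice: {ours:.2f}")  # 91.32
for name, score in baseline_avg_dice.items():
    # UX-Net: 1.63, SwinUNETR-V2: 1.93 (percentage points)
    print(f"gain over {name}: {ours - score:.2f} points")
print(f"SegMamba average HD95: {mean(segmamba_hd95.values()):.2f}")  # 3.57
```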