Mask-Attention-Free Transformer for 3D Instance Segmentation

Xin Lai, Yuhui Yuan, Ruihang Chu, Yukang Chen, Han Hu, Jiaya Jia

Introduction

3D point clouds can now be collected conveniently, and they have benefited various applications such as autonomous driving, robotics, and augmented reality. 3D instance segmentation is a fundamental task in this setting, yet it also poses great challenges, such as geometric occlusion and semantic ambiguity.

Many works have been proposed for the 3D instance segmentation task. Grouping-based methods rely on heuristic clustering algorithms such as DBSCAN or Breadth-First Search (BFS) to generate instance proposals. They therefore require careful hyper-parameter tuning and are prone to incorrectly segmenting instances that are close to each other. Recently, transformer-based methods have developed a fully end-to-end pipeline. Through transformer decoder layers, a fixed number of object queries iteratively attend to global features and directly output instance predictions. Because these methods adopt one-to-one bipartite matching during training, they require no post-processing for duplicate removal such as NMS. Moreover, they employ mask attention, which uses the instance masks predicted in the previous layer to guide the cross-attention.

However, we point out that current transformer-based methods suffer from slow convergence. As shown in Fig. 1, the baseline model converges slowly and lags behind our method by a large margin, particularly in the early stages of training. Looking deeper, we find that the issue is likely caused by the low recall of the initial instance masks. Specifically, as shown in Fig. 2 (a), the initial instance masks are produced from the similarity map between the initial object queries and the point-wise mask features. Since the initial object queries are unstable in early training, the recall of the resulting initial instance masks is substantially lower than ours, as shown in Fig. 3, especially at the beginning of training (i.e., the 32nd epoch). These low-quality initial instance masks increase the training difficulty and thereby slow down convergence.
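To make the low-recall argument concrete, the sketch below shows how initial masks can be derived from a query-to-point similarity map and how their recall might be measured. It is a minimal illustration under our own assumptions: a sigmoid threshold of 0.5 binarizes the similarity map, a ground-truth instance counts as recalled when some predicted mask reaches IoU of at least 0.5 with it, and all function names are hypothetical rather than taken from any released code.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def initial_masks(queries, point_feats, thresh=0.5):
    """Binary mask per query: point i is in the mask if sigmoid(<q, f_i>) > thresh (assumed threshold)."""
    return [
        [sigmoid(sum(a * b for a, b in zip(q, f))) > thresh for f in point_feats]
        for q in queries
    ]


def mask_iou(a, b):
    """IoU between two boolean masks defined over the same points."""
    inter = sum(x and y for x, y in zip(a, b))
    union = sum(x or y for x, y in zip(a, b))
    return inter / union if union else 0.0


def mask_recall(pred_masks, gt_masks, iou_thresh=0.5):
    """Fraction of ground-truth instances covered by at least one predicted mask (assumed IoU criterion)."""
    if not gt_masks:
        return 0.0
    hits = sum(any(mask_iou(p, g) >= iou_thresh for p in pred_masks) for g in gt_masks)
    return hits / len(gt_masks)
```

With poorly trained queries, many ground-truth instances have no overlapping initial mask, so the recall stays low and the later layers have nothing useful to refine.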

Given the low recall of the initial instance masks, we abandon the mask attention design and instead construct an auxiliary center regression task to guide cross-attention, as depicted in Fig. 2 (b). To enable center regression, we develop a series of position-aware designs. Firstly, we maintain a set of learnable position queries, each of which denotes the position of its corresponding content query. They are densely distributed over the 3D space, and we require each query to attend to its local region. As a result, the queries can easily capture the objects in a scene with a higher recall, which is crucial in reducing training difficulty and accelerating convergence.

In addition, we design a contextual relative position encoding for cross-attention. Compared to the mask attention used in previous works, our solution is more flexible, since the attention weights are adjusted by relative positions instead of hard masking. Furthermore, we iteratively update the position queries to obtain more accurate representations. Finally, we incorporate the center distance between predictions and ground truths into both the bipartite matching and the loss.

In summary, our contributions are three-fold.

We observe that existing transformer-based methods suffer from the low recall of initial instance masks, which causes training difficulty and slow convergence.

Instead of relying on mask attention, we construct an auxiliary center regression task to overcome the low-recall issue and design a series of position-aware components accordingly. Our approach manifests faster convergence and demonstrates higher performance.

Experiments show our approach achieves a new state-of-the-art result and demonstrates superior performance on various datasets including ScanNetv2, ScanNet200, and S3DIS.

Related Work

3D instance segmentation is a fundamental task for 3D recognition. Existing solutions can be categorized into detection-based, grouping-based, and the newly emerging transformer-based paradigms. Detection-based approaches first detect bounding boxes and then segment the fine-grained instance mask within each box. In contrast, grouping-based approaches employ clustering algorithms to group points into a set of instance clusters. Before clustering, they either shift 3D points towards their associated object centers to form a more compact distribution, or transform points into a high-dimensional feature space. Further, a series of works leverage semantic priors to suppress noisy points from other categories, or use advanced grouping strategies. With these designs, the grouping-based paradigm achieved leading performance across various evaluation benchmarks for a long time. Recently, the transformer-based paradigm has become another option and swiftly set a new state of the art. Compared with previous methods, it offers an elegant pipeline that directly outputs instance predictions. It relies on the transformer decoder and mask attention to aggregate information from global features. In this work, we construct an auxiliary center regression task to assist cross-attention. Although existing grouping-based methods also predict center offsets, ours are not used to generate instance proposals; instead, they overcome the low-recall issue and provide positional priors for cross-attention.
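For contrast with our approach, the following toy sketch illustrates the grouping-based recipe described above: each point is shifted by a center offset (given here, predicted by a network in practice) and points are then grouped by breadth-first search within a fixed radius. The radius, the quadratic neighbor search, and the function names are illustrative assumptions, not any specific published implementation. The example also shows why the radius is a sensitive hyper-parameter: without shifting, two adjacent objects merge into one cluster.

```python
import math
from collections import deque


def shift_points(points, offsets):
    """Move each point by its (predicted) offset towards the object center."""
    return [tuple(p + o for p, o in zip(pt, off)) for pt, off in zip(points, offsets)]


def bfs_cluster(points, radius, min_size=1):
    """Group points whose chained neighbor distance is <= radius (O(n^2) search for clarity)."""
    n = len(points)
    visited = [False] * n
    clusters = []
    for seed in range(n):
        if visited[seed]:
            continue
        visited[seed] = True
        queue = deque([seed])
        cluster = []
        while queue:
            i = queue.popleft()
            cluster.append(i)
            for j in range(n):  # a real implementation would query a spatial index instead
                if not visited[j] and math.dist(points[i], points[j]) <= radius:
                    visited[j] = True
                    queue.append(j)
        if len(cluster) >= min_size:
            clusters.append(sorted(cluster))
    return clusters
```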

2.2 Vision Transformer

Transformer has become a fundamental model in the vision area, thanks to its flexibility and its power to model various scenarios using attention mechanisms. Recently, many works rely on the self-attention in transformers to develop fundamental vision models. Besides, DETR proposes a fully end-to-end pipeline for object detection. It utilizes transformer decoders to dynamically aggregate features from images and uses one-to-one bipartite matching for ground-truth assignment, yielding an elegant pipeline. To address the notorious slow convergence of DETR, follow-up approaches propose deformable attention, impose strong priors, or reduce the search space in cross-attention. Further, other methods present several ways to stabilize matching and training. Moreover, masked attention is proposed to impose semantic priors that accelerate training for segmentation tasks. Recently, several works have developed transformer models tailored for 3D point clouds. Following this line of research, we observe the low recall of initial instance masks and present solutions that circumvent the use of mask attention.

Method

We first review previous methods and present the overview of our method in Sec. 3.1. Then, we elaborate on the details of our position-aware designs in Sec. 3.2.

Recently, Mask3D and SPFormer present a fully end-to-end pipeline, which allows the object queries to directly output instance predictions. Through transformer decoders, a fixed number of object queries aggregate information from the global features (either multi-scale voxel features or superpoint features) extracted by the backbone. Moreover, similar to Mask2Former, they adopt mask attention and rely on the instance masks to guide the cross-attention. Specifically, the cross-attention is masked with the instance masks predicted in the previous decoder layer, so that the queries only need to consider the masked features. However, as shown in Fig. 3, the recall of the initial instance masks is low in the early stages of training. This prevents the subsequent layers from achieving high-quality results and thus increases training difficulty.
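The hard masking described above can be sketched for a single query and a single head as follows. This is a minimal illustration, not the Mask3D or SPFormer code: keys outside the previous layer's mask receive a logit of negative infinity and therefore zero weight, so an object that the initial mask misses is invisible to the query. The fallback of attending to all keys when a mask is empty is our assumption, modeled on common Mask2Former-style implementations.

```python
import math


def masked_cross_attention(query, keys, values, mask):
    """Single-head cross-attention for one query; keys where mask is False get zero weight."""
    d = len(query)
    logits = [sum(a * b for a, b in zip(query, k)) / math.sqrt(d) for k in keys]
    scores = [s if keep else -math.inf for s, keep in zip(logits, mask)]
    if all(s == -math.inf for s in scores):
        # Assumed fallback: an empty mask would make softmax undefined, so attend to all keys.
        scores = logits
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0 for masked keys
    z = sum(weights)
    weights = [w / z for w in weights]
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return out, weights
```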

Ours.

3.2 Position-aware Designs

To effectively support the center regression task and improve the recall of initial instance masks, we propose a series of position-aware designs as follows.

It is notable that the initial position queries are densely spread throughout the 3D space. Also, every query aggregates features from its local region. This design choice makes it easier for the initial queries to capture the objects in a scene with a high recall, as shown in Fig. 3. It overcomes the low-recall issue caused by initial instance masks, and consequently reduces the training complexity of the subsequent layers.
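A minimal sketch of dense position queries follows, under assumptions of our own: positions are initialized uniformly in a normalized cube [0, 1]^3 (in training they would be learnable parameters), mapped into each scene's axis-aligned bounding box, and "coverage" is our hypothetical proxy for recall, namely the fraction of object centers that have some query within a given radius. The function names and the radius criterion are illustrative only.

```python
import math
import random


def init_position_queries(num_queries, seed=0):
    """Hypothetical initialization: uniform positions in the normalized cube [0, 1]^3.
    During training these would be learnable parameters updated by gradient descent."""
    rng = random.Random(seed)
    return [[rng.random() for _ in range(3)] for _ in range(num_queries)]


def to_scene_coords(norm_positions, scene_min, scene_max):
    """Map normalized positions into the axis-aligned bounding box of the input scene."""
    return [
        [lo + p * (hi - lo) for p, lo, hi in zip(pos, scene_min, scene_max)]
        for pos in norm_positions
    ]


def coverage(centers, query_positions, radius):
    """Fraction of object centers with at least one position query within `radius`."""
    if not centers:
        return 0.0
    hits = sum(any(math.dist(c, q) <= radius for q in query_positions) for c in centers)
    return hits / len(centers)
```

Because each query only needs to find objects near its own position, a dense spread of queries keeps this coverage high even before the content queries are well trained.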

Relative Position Encoding.

where $s$ denotes the quantization size and $L$ denotes the length of the position encoding table. We add $\frac{L}{2}$ to ensure that the discrete relative positions are non-negative.
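Since the defining equation is not reproduced here, the sketch below is only a hedged illustration of a contextual RPE of this kind. We assume per-axis floor quantization of the relative position with $s = 0.1$ m and $L = 48$ (the values from the implementation details), clamping out-of-range indices to the table bounds, separate query-side and key-side tables per axis, and a logit bias of the form $\mathbf{f}^{q} \cdot \mathbf{r}^{q} + \mathbf{f}^{k} \cdot \mathbf{r}^{k}$ summed over the x, y, and z axes. The table layout and function names are hypothetical.

```python
import math


def rpe_index(p_q, p_k, s=0.1, L=48):
    """Per-axis table index: floor((p_q - p_k) / s) + L // 2, clamped to [0, L - 1] (clamping assumed)."""
    idx = []
    for a, b in zip(p_q, p_k):
        i = math.floor((a - b) / s) + L // 2
        idx.append(min(max(i, 0), L - 1))
    return idx


def contextual_rpe_bias(f_q, f_k, p_q, p_k, table_q, table_k, s=0.1, L=48):
    """Attention-logit bias f_q . r_q + f_k . r_k, summed over the three axes.
    table_q[axis][i] / table_k[axis][i] are hypothetical d-dim learnable embeddings."""
    bias = 0.0
    for axis, i in enumerate(rpe_index(p_q, p_k, s, L)):
        bias += sum(a * b for a, b in zip(f_q, table_q[axis][i]))
        bias += sum(a * b for a, b in zip(f_k, table_k[axis][i]))
    return bias
```

In such a design the bias would be added to the usual scaled dot-product logit before the softmax, so distant keys are down-weighted rather than removed, which is what allows RPE to behave like a soft mask.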

It is worth noting that RPE offers greater flexibility and robustness to errors than mask attention. In essence, RPE acts like a soft mask that adjusts attention weights flexibly, instead of hard masking. Another advantage of RPE is that it integrates semantic information (e.g., object size and class) and can thus harvest local information selectively. This is achieved through the interaction between the relative positions and the semantic features (i.e., $\mathbf{f}^{q}$ and $\mathbf{f}^{k}$).

Iterative Refinement.

Since the content queries are updated in every decoder layer, it is suboptimal to keep the position queries frozen throughout the decoding process. Additionally, the initial position queries are static, so it is beneficial to adapt them to the specific input scene in the subsequent layers. To that end, we iteratively refine the position queries based on the content queries. Specifically, as shown in Fig. 4 (b), we leverage an MLP to predict a center offset $\Delta p_{t}$ from the updated content query $\mathcal{Q}^{c}_{t+1}$. We then add it to the previous position query $\hat{\mathcal{Q}}^{p}_{t}$ as

$$\hat{\mathcal{Q}}^{p}_{t+1} = \hat{\mathcal{Q}}^{p}_{t} + \Delta p_{t}.$$
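The refinement step can be sketched as below. The two-layer ReLU MLP, the parameter names (`w1`, `b1`, `w2`, `b2`), and the choice of applying the offset directly in the same coordinate space as the position query are our assumptions; the text only specifies that an MLP maps the updated content query to a center offset that is added to the previous position query.

```python
def linear(x, weight, bias):
    """y = W x + b with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def offset_mlp(content_query, params):
    """Hypothetical two-layer MLP predicting a 3-D center offset Delta p_t from Q^c_{t+1}."""
    hidden = relu(linear(content_query, params["w1"], params["b1"]))
    return linear(hidden, params["w2"], params["b2"])


def refine_position_query(pos_query, updated_content_query, params):
    """Q^p_{t+1} = Q^p_t + Delta p_t, applied once per decoder layer."""
    delta = offset_mlp(updated_content_query, params)
    return [p + d for p, d in zip(pos_query, delta)]
```

Applied after each decoder layer, this lets a query's position drift from its static initial location towards the center of the object it is tracking in the current scene.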

Center Matching & Loss.

To eliminate the need for duplicate removal methods such as non-maximum suppression (NMS), bipartite matching is adopted during training. Existing works rely on semantic predictions and binary masks to match the ground truths.

In contrast, to support center regression, we also incorporate the center distance into bipartite matching. Since we require the queries to attend only to a local region, it is critical to ensure that they match only with nearby ground-truth objects. To achieve this, we adapt the matching cost formulation as follows
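As the matching equation is not reproduced here, the following is only a hedged sketch of a weighted matching cost with a center term, plus a one-to-one assignment. The default weights (0.5, 1.0, 1.0, 0.5) are the ScanNet values for $(\lambda_{cls}, \lambda_{bce}, \lambda_{dice}, \lambda_{center})$ from the implementation details; the individual cost forms (negative class probability, mean binary cross-entropy, soft Dice, and L1 center distance) are our assumptions borrowed from common DETR-style practice. The exhaustive search is for toy sizes only; practical code would use a Hungarian solver such as `scipy.optimize.linear_sum_assignment`.

```python
import itertools
import math


def bce_cost(prob, gt):
    """Mean binary cross-entropy between predicted mask probabilities and a 0/1 mask."""
    eps = 1e-6
    return -sum(g * math.log(p + eps) + (1 - g) * math.log(1 - p + eps)
                for p, g in zip(prob, gt)) / len(gt)


def dice_cost(prob, gt):
    inter = sum(p * g for p, g in zip(prob, gt))
    return 1.0 - (2 * inter + 1) / (sum(prob) + sum(gt) + 1)


def center_cost(c_pred, c_gt):
    return sum(abs(a - b) for a, b in zip(c_pred, c_gt))  # L1 distance (assumed form)


def matching_cost(pred, gt, w=(0.5, 1.0, 1.0, 0.5)):
    w_cls, w_bce, w_dice, w_center = w
    return (w_cls * -pred["cls_prob"][gt["label"]]
            + w_bce * bce_cost(pred["mask"], gt["mask"])
            + w_dice * dice_cost(pred["mask"], gt["mask"])
            + w_center * center_cost(pred["center"], gt["center"]))


def bipartite_match(preds, gts, w=(0.5, 1.0, 1.0, 0.5)):
    """Exhaustive one-to-one assignment minimizing total cost; assumes len(preds) >= len(gts).
    Returns (pred_index, gt_index) pairs."""
    cost = [[matching_cost(p, g, w) for g in gts] for p in preds]
    best, best_perm = math.inf, None
    for perm in itertools.permutations(range(len(preds)), len(gts)):
        total = sum(cost[q][j] for j, q in enumerate(perm))
        if total < best:
            best, best_perm = total, perm
    return [(q, j) for j, q in enumerate(best_perm)]
```

When class and mask costs tie, the center term decides the assignment, steering each query towards the nearby ground-truth object.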

Experiment

This section first provides an overview of the experimental setup in Sec. 4.1. We then present the 3D instance segmentation results in Sec. 4.2. Additionally, we conduct an extensive ablation study in Sec. 4.3. Furthermore, we showcase the object detection results and visual comparisons in Sections 4.4 and 4.5, respectively. Code and models will be made publicly available.

For both ScanNetv2 and ScanNet200, we follow previous works and use a 5-layer U-Net as the backbone, with the initial channel number set to 32. Unless otherwise specified, we use point coordinates and colors as the input features. We use 6 Transformer decoder layers with 8 attention heads, and the hidden and feed-forward dimensions are set to 256 and 1024, respectively. We adopt Fourier absolute position encoding with the temperature set to 10,000. The quantization size for RPE is set to 0.1m, and the length of the RPE table is 48. Unless otherwise specified, we choose as the baseline model, since it has achieved the best performance on the ScanNetv2 val set so far. For the S3DIS dataset, following Mask3D, we use Res16UNet34C as the backbone and employ 4 decoders to attend to the coarsest four scales; this is repeated 3 times with shared parameters. The decoder hidden and feed-forward dimensions are set to 128 and 1024, respectively.
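The text does not spell out the encoding formula, so the sketch below assumes the DETR-style sinusoidal variant with temperature 10,000, applied independently to each normalized coordinate. The number of features per axis (64 here, giving 192 values) and the scaling of coordinates to $[0, 2\pi]$ are our assumptions; how the result is padded or projected to the 256-dimensional hidden size is not specified.

```python
import math


def sine_position_encoding(xyz, feats_per_axis=64, temperature=10000.0):
    """DETR-style sinusoidal encoding of a normalized 3-D position (assumed variant).
    Returns 3 * feats_per_axis values with interleaved sin/cos pairs."""
    assert feats_per_axis % 2 == 0
    half = feats_per_axis // 2
    enc = []
    for coord in xyz:
        x = coord * 2 * math.pi  # normalized coordinate in [0, 1] mapped to [0, 2*pi]
        for i in range(half):
            freq = temperature ** (2 * i / feats_per_axis)  # geometric frequency schedule
            enc.append(math.sin(x / freq))
            enc.append(math.cos(x / freq))
    return enc
```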

Datasets.

We use the ScanNetv2, ScanNet200, and S3DIS datasets for evaluation. All of them are challenging large-scale indoor scene datasets.

The ScanNetv2 dataset comprises 1201 scenes for training, and an additional 312 and 100 indoor scenes for validation and testing, respectively. The scenes are captured with RGB-D cameras and annotated with 20 semantic labels, 18 of which are instance classes. The ScanNet200 dataset adopts the same point cloud data, but it offers more diverse annotations, covering 200 classes, 198 of which are instance classes.

The S3DIS dataset contains 271 rooms in 6 areas of three buildings, and 13 semantic categories are annotated. Following previous works, the scenes in Area 5 are used for validation and the others are for training.

Implementation Details.

We use one RTX 3090 GPU for training on ScanNet and ScanNet200, and one A100 GPU on S3DIS. Following previous works, we use the AdamW optimizer with the learning rate and weight decay set to 0.0001 and 0.05, respectively. We adopt a poly scheduler on ScanNet and ScanNet200, and a one-cycle scheduler on S3DIS. The batch size is set to 4. For the weights of the matching costs and losses, ($\lambda_{cls}$, $\lambda_{bce}$, $\lambda_{dice}$, $\lambda_{center}$) are set to (0.5, 1.0, 1.0, 0.5) on ScanNet and ScanNet200, and (2.0, 5.0, 1.0, 0.5) on S3DIS. The voxel size is set to 0.02m. We limit the number of points to 250,000; if a scene exceeds this limit, we iteratively crop it with cubic windows until the number of points falls below the limit. During inference, we select the top 100 instances with the highest scores and set the minimum number of points per instance to 100.
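Two of the settings above can be sketched in a few lines. The poly exponent of 0.9 is a common default that we assume, since the text does not state it. For inference, we assume the top-100 selection is applied before the minimum-size filter; the text does not fix the order, and the instance representation (a score paired with point indices) is hypothetical.

```python
def poly_lr(base_lr, step, max_steps, power=0.9):
    """Poly learning-rate schedule: base_lr * (1 - step / max_steps) ** power (power assumed)."""
    return base_lr * (1.0 - step / max_steps) ** power


def postprocess(instances, top_k=100, min_points=100):
    """instances: list of (score, point_indices). Keep the top_k highest-scoring instances,
    then drop any with fewer than min_points points (order of the two steps assumed)."""
    ranked = sorted(instances, key=lambda inst: inst[0], reverse=True)[:top_k]
    return [inst for inst in ranked if len(inst[1]) >= min_points]
```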

4.2 Instance Segmentation Results

We present the instance segmentation results on the ScanNetv2 test and val sets in Tables 1 and 2, respectively. Our method achieves a considerable increase in mAP compared to previous works. Since mAP averages over IoU thresholds from 50% to 95%, this gain suggests a superior ability to capture fine-grained details and produce high-quality instance masks. While Mask3D slightly outperforms our model in terms of mAP50, this is potentially due to its stronger backbone (i.e., Res16UNet34C, with twice as many parameters as ours) and its DBSCAN post-processing. Despite this, our approach performs significantly better than Mask3D on the ScanNetv2 val set, as seen in Table 2.

ScanNet200.

Table 3 presents our comparison with previous state-of-the-art methods on the val set of ScanNet200. Our method achieves a significant improvement over the other methods, so the same conclusion holds on this challenging dataset. It is important to note that previous works employ mask attention, while our approach does not. This verifies that our auxiliary center regression task successfully replaces mask attention.

S3DIS.

As shown in Table 4, we evaluate our method on S3DIS Area 5, where it outperforms previous works. This further confirms the superiority of our method.

4.3 Ablation Study

We conduct an extensive ablation study to verify each component of our method as follows.

The position query aims to provide an explicit center representation for its content query counterpart. Making it learnable allows the model to learn an optimal initial spatial distribution. We notice that some previous works adopt non-parametric initial queries, where Furthest Point Sampling (FPS) is used to sample a number of points, which are then transformed into position encodings via a Fourier transformation followed by an MLP. We make comparisons in Table 5. The results show that the learnable position query combined with a zero-initialized content query performs best. A potential reason why 'FPS' lags behind 'learnable' is that the latter learns an optimal spatial distribution.
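For reference, the non-parametric 'FPS' alternative starts from furthest point sampling, sketched below in its standard greedy form. Starting from index 0 is an arbitrary choice on our part, and the subsequent Fourier encoding and MLP are omitted. Unlike learnable position queries, the sampled positions depend entirely on the input scene rather than on a distribution learned across the training set.

```python
import math


def furthest_point_sampling(points, num_samples, start=0):
    """Greedy FPS: repeatedly pick the point farthest from the already-selected set."""
    num_samples = min(num_samples, len(points))
    chosen = [start]
    # distance from every point to its nearest selected point
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(chosen) < num_samples:
        nxt = max(range(len(points)), key=lambda i: min_dist[i])
        chosen.append(nxt)
        min_dist = [min(d, math.dist(p, points[nxt])) for d, p in zip(min_dist, points)]
    return chosen
```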

Moreover, to show the pattern of the learnable position query, we visualize the spatial distribution of center coordinates of the matched ground truths for a query in Fig. 6. It shows that each query consistently attends to a local region.

Relative Position Encoding.

We compare various position encodings employed in previous works, such as the Fourier Absolute Position Encoding (APE) and the content query-conditioned APE. Specifically, the latter uses an MLP to project the content query into a $d$-dim diagonal matrix, which then transforms the original absolute position encoding into a new one. It incorporates semantic information into the position encoding but does not consider relative relations. As shown in Table 6, RPE outperforms the others, which implies that both semantic information and relative relations are beneficial. We also notice that training collapses if no position encoding is applied, which shows that the positional prior is crucial in our framework.
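The content query-conditioned APE described above can be sketched as follows. Multiplying by a diagonal matrix is simply an element-wise product, so the sketch scales each APE channel by a value predicted from the content query. A single linear layer stands in for the MLP here, which is a simplifying assumption, and the names are hypothetical. Note that nothing in this computation depends on the key's position relative to the query, which is the missing relative relation that RPE supplies.

```python
def linear(x, weight, bias):
    """y = W x + b with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def conditioned_ape(content_query, ape, weight, bias):
    """Scale each channel of the absolute position encoding by a diagonal predicted
    from the content query (one linear layer stands in for the MLP)."""
    diag = linear(content_query, weight, bias)  # the d diagonal entries
    return [s * a for s, a in zip(diag, ape)]
```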

Iterative Refinement.

We remove the iterative refinement and freeze the position queries in all decoder layers, which causes a performance drop of 0.9% mAP, as shown in the first row of Table 7. This verifies the effectiveness of iterative refinement.

Center Matching & Loss.

Moreover, to show the importance of center matching and center loss, we also conduct ablation studies in Table 7. We first remove the center matching while keeping the center loss (second row) and find that performance drops by 1.7% mAP. We then keep the center matching and remove the center loss; performance decreases by 1.6% mAP, as shown in the third row. When both are absent, we observe an even larger drop (2.0% mAP) in the fourth row. These results reveal that both center matching and center loss are important to our framework.

4.4 Object Detection Results

Instance segmentation predictions can easily be converted into bounding box predictions by taking the minimum and maximum coordinates of the points in each predicted instance mask. We empirically find that the object detection results derived this way perform significantly better than previous methods tailored for 3D object detection in terms of mAP50, as shown in Table 8. This finding also suggests that our approach outputs high-quality instance segmentation results with fewer artifacts.
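The mask-to-box conversion described above amounts to an axis-aligned bounding box over the masked points. The sketch below shows it together with a 3-D box IoU of the kind used when scoring detections at mAP50 (a prediction counts as a match at IoU of at least 0.5); the function names are our own.

```python
import math


def mask_to_box(points, mask):
    """Axis-aligned box (mins, maxs) over the points selected by a boolean instance mask."""
    selected = [p for p, keep in zip(points, mask) if keep]
    if not selected:
        return None
    mins = tuple(min(c) for c in zip(*selected))
    maxs = tuple(max(c) for c in zip(*selected))
    return mins, maxs


def box_volume(lo, hi):
    return math.prod(h - l for l, h in zip(lo, hi))


def box_iou_3d(a, b):
    """IoU of two axis-aligned 3-D boxes given as (mins, maxs)."""
    (amin, amax), (bmin, bmax) = a, b
    inter = 1.0
    for lo1, hi1, lo2, hi2 in zip(amin, amax, bmin, bmax):
        inter *= max(0.0, min(hi1, hi2) - max(lo1, lo2))
    union = box_volume(amin, amax) + box_volume(bmin, bmax) - inter
    return inter / union if union > 0 else 0.0
```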

4.5 Visual Comparison

We visually compare our approach with previous state-of-the-art methods in Fig. 7. More examples are given in the supplementary material. The visualizations demonstrate that our method tends to recognize the classes of instances correctly, implying that it generates higher-quality instance segmentation results.

Conclusion

In this work, we have presented a mask-attention-free transformer for the 3D instance segmentation task. We first identify the low recall of the initial masks in existing works, which adds training difficulty and slows down convergence. We therefore avoid mask attention and instead propose an auxiliary center regression task to guide the cross-attention. To support center regression, we develop a series of position-aware designs. A dense distribution of position queries is learned to achieve a higher recall of the perceived instances. In addition, relative position encoding and iterative refinement are designed to further boost performance. Each component is verified to be effective.

Acknowledgements

This work was supported in part by the Research Grants Council under the Areas of Excellence scheme grant AoE/E-601/22-R and Shenzhen Science and Technology Program KQTD20210811090149095.

References