Degradation-Aware Unfolding Half-Shuffle Transformer for Spectral Compressive Imaging
Yuanhao Cai, Jing Lin, Haoqian Wang, Xin Yuan, Henghui Ding, Yulun Zhang, Radu Timofte, Luc Van Gool
Introduction
Hyperspectral images (HSIs) contain more spectral bands than ordinary RGB images and therefore store more detailed information. Consequently, HSIs are widely applied in image recognition fauvel2012advances ; maggiori2017recurrent ; zhang2015scene , object detection ot_1 ; ot_2 ; ot_3 , tracking fu2016exploiting ; uzkent2016real ; uzkent2017aerial , medical image processing mi_1 ; mi_2 ; mi_3 , remote sensing rs_1 ; rs_2 ; rs_3 ; rs_4 , etc. To obtain HSIs, traditional imaging systems use spectrometers to scan the scenes along the spectral or spatial dimensions, which usually takes a long time. As a result, these imaging systems fail to capture dynamic objects. Recently, snapshot compressive imaging (SCI) systems sci_1 ; sci_2 ; sci_3 ; sci_5 ; sci_6 ; yuan2015compressive ; ma2021led have been developed to capture HSIs at video rate. Among these SCI systems, coded aperture snapshot spectral imaging (CASSI) sci_2 ; tsa_net ; gehm2007single stands out for its impressive performance. CASSI uses a coded aperture and a disperser to modulate the HSI signal at different wavelengths, and then integrates all modulated signals into a 2D compressed measurement. Subsequently, HSI restoration methods are employed to solve the CASSI inverse problem, i.e., to restore the HSIs from the measurement. These methods can be divided into four categories.
(i) Model-based methods sci_2 ; sparse_1 ; desci ; non_local_1 ; non_local_2 ; gap_tv ; tra_rela_1 ; gradient rely on hand-crafted image priors, e.g., total variation gap_tv , sparsity sci_2 ; sparse_1 , low-rank desci , etc. These methods have theoretically proven properties and can be interpreted. Yet, these methods need manual parameter tweaking, which slows down reconstruction. Also, they suffer from limited representation capacity and generalization ability.
(ii) Plug-and-play (PnP) algorithms pnp_3 ; pnp_1 ; pnp_2 ; self ; zheng2021deep ; yuan2021plug plug pre-trained denoising networks into traditional model-based methods pnp_2 ; qiao2020deep to solve the HSI reconstruction problem. Nonetheless, the pre-trained networks in PnP methods are fixed without re-training, which limits their performance.
(iii) End-to-end (E2E) algorithms employ a powerful model, usually a convolutional neural network (CNN) mi_3 ; tsa_net ; hdnet ; lambda , to learn the E2E mapping function from a measurement to the desired HSIs. E2E methods enjoy the power of deep learning. However, they learn a brute-force mapping from the compressed measurement to the underlying spectral images, thereby ignoring the working principles of CASSI systems. They also lack theoretically proven properties, interpretability Yuan_review , and flexibility, because imaging models differ widely across hardware systems.
(iv) Deep unfolding methods dnu ; hssp ; gapnet ; admm-net ; gsm ; fu2021bidirectional ; herosnet adopt a multi-stage network to map the measurement into the HSI cube. Each stage usually includes two phases, i.e., linear projection followed by passing the signal through a single-stage network that learns the underlying denoiser prior. In deep unfolding methods, the network architecture is intuitively interpretable by explicitly characterizing the image priors and the system imaging model. Besides, these methods also enjoy the power of deep learning and thus have great potential. Yet, this potential has not been fully explored.
Existing deep unfolding algorithms suffer from two issues. (a) The iterative learning is closely tied to the CASSI system. However, current unfolding methods do not estimate CASSI degradation patterns and the ill-posedness degree to adjust the linear projection and denoising network in each iteration. (b) Existing deep unfolding methods are mainly CNN-based and are therefore limited in capturing non-local self-similarity and long-range dependencies, both of which are critical for HSI reconstruction.
Recently, the emerging Transformer vaswani2017attention has provided a solution to the drawbacks of CNNs. Due to its strong capability in modeling the interactions of non-local spatial regions, Transformer has been widely applied in image classification liu2021swin ; arnab2021vivit ; global_msa ; tc_2 ; tc_3 ; xcit ; crossvit , object detection de_detr ; to_1 ; liu2021swin ; DETR ; dy_detr ; to_2 ; to_5 , semantic segmentation liu2021swin ; tc_3 ; ts_1 ; cao2021swin ; ts_2 ; ts_3 ; ts_4 , human pose estimation tokenpose ; transpose ; rsn ; prtr ; th_1 ; th_2 ; th_3 , image restoration ipt ; swinir ; uformer ; vsrt ; fgst ; mst ; mst_pp ; cst ; pngan ; rformer , etc. Yet, the use of Transformer is confronted with two main issues. (a) The computational complexity of global Transformer global_msa is quadratic in the spatial dimensions. This nontrivial cost is sometimes unaffordable. (b) The receptive fields of local Transformer liu2021swin are limited to position-specific windows. As a result, some tokens with highly related contents cannot attend to each other when computing the self-attention.
To address the above problems, in this paper, we first formulate a principled Degradation-Aware Unfolding Framework (DAUF) based on maximum a posteriori (MAP) theory for HSI reconstruction. Unlike previous deep unfolding methods, DAUF implicitly estimates informative parameters from the degraded compressed measurement and the physical mask used in the modulation. DAUF then feeds these parameters, which capture key cues of CASSI degradation patterns and ill-posedness degree, into each iteration to adaptively scale the linear projection and provide noise level information for the denoising network. Second, we design a novel Half-Shuffle Transformer (HST) as the denoiser prior in each iteration. HST jointly extracts local contextual information and models non-local dependencies at a much lower computational cost than global Transformer. We achieve this by customizing a Half-Shuffle Multi-head Self-Attention (HS-MSA) mechanism that forms the basic unit of HST. More specifically, HS-MSA has two branches, i.e., a local branch and a non-local branch. The local branch computes self-attention within local windows, while the non-local branch shuffles the tokens and captures cross-window interactions. We plug HST into DAUF to establish an iterative architecture, the Degradation-Aware Unfolding Half-Shuffle Transformer (DAUHST). With the proposed techniques, DAUHST models dramatically outperform state-of-the-art (SOTA) deep unfolding methods with the same number of stages by over 4 dB, as shown in Fig. 1.
In a nutshell, our contributions can be summarized as follows:
(i) We formulate a principled MAP-based unfolding framework DAUF for HSI reconstruction.
(ii) We propose a novel Transformer HST and plug it into DAUF to establish DAUHST. To the best of our knowledge, DAUHST is the first Transformer-based deep unfolding method for HSI restoration.
(iii) DAUHST outperforms SOTA methods by a large margin while requiring lower computational and memory costs. Besides, DAUHST yields more visually pleasing results in real HSI reconstruction.
Proposed Method
2 Degradation-Aware Unfolding Framework
Previous unfolding frameworks dnu ; hssp ; gapnet ; admm-net do not estimate the CASSI degradation patterns to adjust the iterative learning. To alleviate this limitation, we formulate a principled Degradation-Aware Unfolding Framework (DAUF), as depicted in Fig. 2. DAUF starts from MAP theory. In particular, the original HSI signal $\mathbf{x}$ can be estimated by minimizing the following energy function:
$$\hat{\mathbf{x}} = \arg\min_{\mathbf{x}} \ \frac{1}{2}\|\mathbf{y} - \boldsymbol{\Phi}\mathbf{x}\|^2 + \tau R(\mathbf{x}), \tag{2}$$
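where $\frac{1}{2}\|\mathbf{y} - \boldsymbol{\Phi}\mathbf{x}\|^2$ is the data fidelity term, $R(\mathbf{x})$ is the image prior term, and $\tau$ is a hyperparameter balancing the importance of the two terms. By introducing an auxiliary variable $\mathbf{z}$, Eq. (2) can be reformulated as
$$\hat{\mathbf{x}} = \arg\min_{\mathbf{x}} \ \frac{1}{2}\|\mathbf{y} - \boldsymbol{\Phi}\mathbf{x}\|^2 + \tau R(\mathbf{z}), \quad \text{s.t.} \ \mathbf{z} = \mathbf{x}. \tag{3}$$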
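This is a constrained optimization problem. To obtain an unfolding inference, we adopt the half-quadratic splitting (HQS) algorithm for its simplicity and fast convergence. Eq. (3) is then solved by minimizing
$$\mathcal{L}_{\mu}(\mathbf{x}, \mathbf{z}) = \frac{1}{2}\|\mathbf{y} - \boldsymbol{\Phi}\mathbf{x}\|^2 + \tau R(\mathbf{z}) + \frac{\mu}{2}\|\mathbf{z} - \mathbf{x}\|^2, \tag{4}$$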
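where $\mu$ is a penalty parameter that forces $\mathbf{x}$ and $\mathbf{z}$ to approach the same fixed point. Subsequently, Eq. (4) can be solved by decoupling $\mathbf{x}$ and $\mathbf{z}$ into the following two iterative sub-problems:
$$\begin{aligned} \mathbf{x}_{k+1} &= \arg\min_{\mathbf{x}} \ \|\mathbf{y} - \boldsymbol{\Phi}\mathbf{x}\|^2 + \mu\|\mathbf{x} - \mathbf{z}_k\|^2, \\ \mathbf{z}_{k+1} &= \arg\min_{\mathbf{z}} \ \frac{\mu}{2}\|\mathbf{z} - \mathbf{x}_{k+1}\|^2 + \tau R(\mathbf{z}), \end{aligned} \tag{5}$$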
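where $k$ indexes the iteration. Note that the data fidelity term leads to a quadratic regularized least-squares problem, i.e., the $\mathbf{x}$-subproblem in Eq. (5), which has a closed-form solution:
$$\mathbf{x}_{k+1} = (\boldsymbol{\Phi}^\top\boldsymbol{\Phi} + \mu\mathbf{I})^{-1}(\boldsymbol{\Phi}^\top\mathbf{y} + \mu\mathbf{z}_k), \tag{6}$$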
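where $\mathbf{I}$ is an identity matrix. Since $\boldsymbol{\Phi}$ is a fat matrix (it has many more columns than rows), $\boldsymbol{\Phi}^\top\boldsymbol{\Phi} + \mu\mathbf{I}$ is a large matrix that is expensive to invert directly, whereas $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ is much smaller. We therefore simplify the computation of the inverse with the matrix inversion formula:
$$(\boldsymbol{\Phi}^\top\boldsymbol{\Phi} + \mu\mathbf{I})^{-1} = \mu^{-1}\mathbf{I} - \mu^{-1}\boldsymbol{\Phi}^\top(\mu\mathbf{I} + \boldsymbol{\Phi}\boldsymbol{\Phi}^\top)^{-1}\boldsymbol{\Phi}, \tag{7}$$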
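The identity in Eq. (7) can be checked numerically. The sketch below uses a small random fat matrix as a stand-in for a real CASSI sensing matrix (an assumption made only for illustration) and compares the direct inverse of the large matrix with the right-hand side of Eq. (7), which only inverts a matrix whose size equals the number of measurements.

```python
import numpy as np


def woodbury_inverse(phi: np.ndarray, mu: float) -> np.ndarray:
    """Right-hand side of Eq. (7): only an n x n matrix is inverted."""
    n, m = phi.shape
    small = np.linalg.inv(mu * np.eye(n) + phi @ phi.T)  # n x n inverse
    return (np.eye(m) - phi.T @ small @ phi) / mu


rng = np.random.default_rng(0)
phi = rng.random((6, 24))  # fat toy matrix: 6 measurements, 24 unknowns
mu = 0.5
direct = np.linalg.inv(phi.T @ phi + mu * np.eye(24))  # 24 x 24 inverse
approx = woodbury_inverse(phi, mu)
print(np.allclose(direct, approx))
```

With realistic CASSI sizes the gap between the two inverses is far larger, which is the motivation for Eq. (7).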
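By plugging Eq. (7) into Eq. (6), we can reformulate Eq. (6) as
$$\mathbf{x}_{k+1} = \mathbf{z}_k + \boldsymbol{\Phi}^\top(\boldsymbol{\Phi}\boldsymbol{\Phi}^\top + \mu\mathbf{I})^{-1}(\mathbf{y} - \boldsymbol{\Phi}\mathbf{z}_k). \tag{8}$$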
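For completeness, the algebra from Eqs. (6) and (7) to Eq. (8) can be written out. Abbreviate $\mathbf{M} = (\mu\mathbf{I} + \boldsymbol{\Phi}\boldsymbol{\Phi}^\top)^{-1}$; the only non-obvious step uses $\mathbf{I} - \mathbf{M}\boldsymbol{\Phi}\boldsymbol{\Phi}^\top = \mathbf{M}(\mu\mathbf{I} + \boldsymbol{\Phi}\boldsymbol{\Phi}^\top) - \mathbf{M}\boldsymbol{\Phi}\boldsymbol{\Phi}^\top = \mu\mathbf{M}$.

```latex
\begin{aligned}
\mathbf{x}_{k+1}
&= \big(\mu^{-1}\mathbf{I} - \mu^{-1}\boldsymbol{\Phi}^\top\mathbf{M}\boldsymbol{\Phi}\big)
   \big(\boldsymbol{\Phi}^\top\mathbf{y} + \mu\mathbf{z}_k\big) \\
&= \mathbf{z}_k - \boldsymbol{\Phi}^\top\mathbf{M}\boldsymbol{\Phi}\mathbf{z}_k
   + \mu^{-1}\boldsymbol{\Phi}^\top\big(\mathbf{I} - \mathbf{M}\boldsymbol{\Phi}\boldsymbol{\Phi}^\top\big)\mathbf{y} \\
&= \mathbf{z}_k - \boldsymbol{\Phi}^\top\mathbf{M}\boldsymbol{\Phi}\mathbf{z}_k
   + \boldsymbol{\Phi}^\top\mathbf{M}\mathbf{y} \\
&= \mathbf{z}_k + \boldsymbol{\Phi}^\top\mathbf{M}\big(\mathbf{y} - \boldsymbol{\Phi}\mathbf{z}_k\big).
\end{aligned}
```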
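In CASSI systems, $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ is a diagonal matrix, which can be written as $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top = \mathrm{diag}\{\psi_1, \dots, \psi_n\}$. Substituting it into $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top + \mu\mathbf{I}$ and $(\boldsymbol{\Phi}\boldsymbol{\Phi}^\top + \mu\mathbf{I})^{-1}$, we obtain:
$$\begin{aligned} \boldsymbol{\Phi}\boldsymbol{\Phi}^\top + \mu\mathbf{I} &= \mathrm{diag}\{\psi_1 + \mu, \dots, \psi_n + \mu\}, \\ (\boldsymbol{\Phi}\boldsymbol{\Phi}^\top + \mu\mathbf{I})^{-1} &= \mathrm{diag}\Big\{\frac{1}{\psi_1 + \mu}, \dots, \frac{1}{\psi_n + \mu}\Big\}. \end{aligned} \tag{9}$$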
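Why $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ is diagonal can be seen on a toy single-disperser CASSI model. The sketch below rests on simplifying assumptions (a random binary mask, a dispersion shift of `step` columns per band, no noise) and builds $\boldsymbol{\Phi}$ explicitly by pushing unit voxels through the forward model. Each voxel lands on exactly one measurement pixel, so every column of $\boldsymbol{\Phi}$ has at most one nonzero entry, and all off-diagonal entries of $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ vanish; the diagonal $\psi_i$ is the sum of the squared shifted masks at pixel $i$.

```python
import numpy as np


def cassi_forward(cube: np.ndarray, mask: np.ndarray, step: int) -> np.ndarray:
    """Toy CASSI: modulate every band by the mask, shift band l by step*l columns, sum."""
    bands, h, w = cube.shape
    meas = np.zeros((h, w + step * (bands - 1)))
    for l in range(bands):
        meas[:, step * l: step * l + w] += mask * cube[l]
    return meas


def dense_sensing_matrix(mask: np.ndarray, bands: int, step: int) -> np.ndarray:
    """Build Phi column by column by applying the forward model to unit voxels."""
    h, w = mask.shape
    cols = []
    for idx in range(bands * h * w):
        e = np.zeros(bands * h * w)
        e[idx] = 1.0
        cols.append(cassi_forward(e.reshape(bands, h, w), mask, step).ravel())
    return np.stack(cols, axis=1)


rng = np.random.default_rng(0)
bands, h, w, step = 3, 4, 5, 2
mask = (rng.random((h, w)) > 0.5).astype(float)
phi = dense_sensing_matrix(mask, bands, step)   # fat: 36 rows, 60 columns
gram = phi @ phi.T
off_diag = gram - np.diag(np.diag(gram))
print(np.abs(off_diag).max())
# psi_i: sum over bands of the squared, shifted mask hitting pixel i
psi = cassi_forward(np.ones((bands, h, w)), mask ** 2, step).ravel()
print(np.allclose(np.diag(gram), psi))
```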
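Let $\mathbf{y} = [y_1, \dots, y_n]^\top$ and let $[\boldsymbol{\Phi}\mathbf{z}_k]_i$ denote the $i$-th element of $\boldsymbol{\Phi}\mathbf{z}_k$. Plugging Eq. (9) into Eq. (8) gives
$$\mathbf{x}_{k+1} = \mathbf{z}_k + \boldsymbol{\Phi}^\top\Big[\frac{y_1 - [\boldsymbol{\Phi}\mathbf{z}_k]_1}{\mu + \psi_1}, \dots, \frac{y_n - [\boldsymbol{\Phi}\mathbf{z}_k]_n}{\mu + \psi_n}\Big]^\top. \tag{10}$$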
Note that $\boldsymbol{\Phi}\mathbf{z}_k$ can be computed directly from $\mathbf{z}_k$, and $\boldsymbol{\Phi}\boldsymbol{\Phi}^\top$ is pre-calculated and stored in $\{\psi_i\}_{i=1}^{n}$. Thus, through the element-wise computation in Eq. (10), $\mathbf{x}_{k+1}$ can be updated very efficiently. According to Eq. (5), the penalty parameter $\mu$ should be large enough that $\mathbf{x}$ and $\mathbf{z}$ approach approximately the same fixed point. This indicates that $\mu$ controls the convergence and output of each iteration. Thus, instead of manually tweaking $\mu$, we set it as a series of iteration-specific parameters to be automatically estimated from the CASSI system. We denote $\mu$ in the $k$-th iteration as $\mu_k$.
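The element-wise update of Eq. (10) can be sketched in a few lines. To keep the sketch short, we assume a "shifted-mask" view in which the cube has already been shifted onto the measurement grid, so $\boldsymbol{\Phi}$ reduces to a per-pixel weighted sum over bands; `phi_masks`, the toy sizes, and the scalar `mu` are illustrative assumptions. The check compares the cheap update with the dense closed form of Eq. (6).

```python
import numpy as np


def forward(z: np.ndarray, phi_masks: np.ndarray) -> np.ndarray:
    """Phi z: per-pixel weighted sum over bands; z, phi_masks: (bands, h, w)."""
    return (phi_masks * z).sum(axis=0)


def adjoint(r: np.ndarray, phi_masks: np.ndarray) -> np.ndarray:
    """Phi^T r: copy the 2D residual into every band, weighted by that band's mask."""
    return phi_masks * r[None]


def linear_projection(z, y, phi_masks, mu):
    """Eq. (10): x = z + Phi^T [(y_i - [Phi z]_i) / (mu + psi_i)]_i."""
    psi = (phi_masks ** 2).sum(axis=0)  # diagonal of Phi Phi^T, computable once
    return z + adjoint((y - forward(z, phi_masks)) / (mu + psi), phi_masks)


rng = np.random.default_rng(1)
bands, h, w, mu = 4, 3, 5, 0.1
phi_masks = rng.random((bands, h, w))
y = rng.random((h, w))
z = rng.random((bands, h, w))
x_fast = linear_projection(z, y, phi_masks, mu)

# Dense reference, Eq. (6): Phi = [diag(m_1), ..., diag(m_B)] on the band-major flattened cube.
phi = np.hstack([np.diag(m.ravel()) for m in phi_masks])
x_dense = np.linalg.solve(phi.T @ phi + mu * np.eye(phi.shape[1]),
                          phi.T @ y.ravel() + mu * z.ravel())
print(np.allclose(x_fast.ravel(), x_dense))
```

Each measurement residual is scaled by $\mu/(\mu + \psi_i) < 1$ after the projection, so a smaller $\mu$ pulls $\boldsymbol{\Phi}\mathbf{x}$ closer to $\mathbf{y}$; this is the knob that DAUF hands to the estimated per-iteration parameter.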
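Returning to Eq. (5), we also set $\tau$ as iteration-specific parameters $\tau_k$, and the $\mathbf{z}$-subproblem can be reformulated as
$$\mathbf{z}_{k+1} = \arg\min_{\mathbf{z}} \ \frac{1}{2\big(\sqrt{\tau_k/\mu_k}\big)^2}\|\mathbf{z} - \mathbf{x}_{k+1}\|^2 + R(\mathbf{z}). \tag{11}$$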
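From the perspective of Bayesian probability, Eq. (11) is equivalent to denoising the image $\mathbf{x}_{k+1}$ corrupted by Gaussian noise at level $\sqrt{\tau_k/\mu_k}$ pnp_3 . To conveniently solve Eq. (11), we treat the noise level of each iteration as a parameter to be estimated from CASSI. Let $\alpha_k = \mu_k$, let $\beta_k$ be the parameter encoding the noise level $\sqrt{\tau_k/\mu_k}$ of the $k$-th iteration, $\boldsymbol{\alpha} = [\alpha_1, \dots, \alpha_K]$, and $\boldsymbol{\beta} = [\beta_1, \dots, \beta_K]$, where $K$ is the number of iterations. Then we can formulate DAUF as an iterative scheme:
$$(\boldsymbol{\alpha}, \boldsymbol{\beta}) = \mathcal{E}(\mathbf{y}, \boldsymbol{\Phi}), \qquad \mathbf{x}_k = \mathcal{P}(\mathbf{z}_{k-1}, \alpha_k), \qquad \mathbf{z}_k = \mathcal{D}(\mathbf{x}_k, \beta_k), \qquad k = 1, \dots, K, \tag{12}$$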
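The noise level in Eq. (11) follows from a short comparison, assuming the observation model $\mathbf{x}_{k+1} = \mathbf{z} + \mathbf{n}$ with $\mathbf{n} \sim \mathcal{N}(\mathbf{0}, \sigma_k^2\mathbf{I})$ and a prior $p(\mathbf{z}) \propto \exp(-R(\mathbf{z}))$. Dividing the $\mathbf{z}$-subproblem by $\tau_k > 0$ leaves its minimizer unchanged, and matching it with the negative log-posterior gives $\sigma_k^2 = \tau_k/\mu_k$:

```latex
\begin{aligned}
\frac{\mu_k}{2}\|\mathbf{z}-\mathbf{x}_{k+1}\|^2 + \tau_k R(\mathbf{z})
  &= \tau_k\Big[\frac{1}{2\,(\tau_k/\mu_k)}\|\mathbf{z}-\mathbf{x}_{k+1}\|^2 + R(\mathbf{z})\Big], \\
-\log p(\mathbf{z}\mid\mathbf{x}_{k+1})
  &= \frac{1}{2\sigma_k^2}\|\mathbf{x}_{k+1}-\mathbf{z}\|^2 + R(\mathbf{z}) + \text{const}.
\end{aligned}
```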
where $\mathcal{E}$ denotes the parameter estimator that takes the compressed measurement $\mathbf{y}$ and the sensing matrix $\boldsymbol{\Phi}$ of the CASSI system as inputs, $\mathcal{P}$, equivalent to Eq. (10), denotes the linear projection, and $\mathcal{D}$ represents the Gaussian denoiser solving Eq. (11). $\mathbf{z}_0$ is initialized by passing the shifted $\mathbf{y}$ concatenated with $\boldsymbol{\Phi}$ through a conv1×1 (convolution with a 1×1 kernel). Fig. 2 shows the architecture of $\mathcal{E}$. It consists of a conv1×1, a strided conv3×3, a global average pooling layer, and three fully connected layers. Through $\mathcal{E}$, DAUF captures critical cues from CASSI by learning the degradation patterns and ill-posedness degree caused by the mask modulation and dispersion integration. The parameters $\boldsymbol{\alpha}$ and $\boldsymbol{\beta}$ estimated by $\mathcal{E}$ direct the iterative learning by adaptively scaling the linear projection in Eq. (10) and providing noise level information for the denoiser prior in Eq. (11).
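The scheme of Eq. (12) and the estimator layout described above can be sketched in NumPy. This is a minimal illustration under loudly stated assumptions, not the authors' implementation: the channel width, the random weights, the ReLU activations, the softplus that keeps $\boldsymbol{\alpha}$ and $\boldsymbol{\beta}$ positive, the $\boldsymbol{\Phi}^\top\mathbf{y}$ initialization (standing in for the learned conv1×1), and `toy_denoiser` (a simple smoothing blend standing in for HST) are all hypothetical. The cube is assumed to live on the shifted measurement grid, as in the shifted-mask view.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(a):
    return np.maximum(a, 0.0)


def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return np.log1p(np.exp(-np.abs(a))) + np.maximum(a, 0.0)


def strided_conv3x3(x, kernel):
    """3x3 convolution, stride 2, zero padding 1. x: (c_in, h, w); kernel: (c_out, c_in, 3, 3)."""
    _, h, w = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    ho, wo = (h - 1) // 2 + 1, (w - 1) // 2 + 1
    out = np.zeros((kernel.shape[0], ho, wo))
    for di in range(3):
        for dj in range(3):
            patch = xp[:, di:di + 2 * ho:2, dj:dj + 2 * wo:2]  # (c_in, ho, wo)
            out += np.einsum("oc,chw->ohw", kernel[:, :, di, dj], patch)
    return out


def make_estimator(c_in, k_iters, width=16):
    """Hypothetical random weights for E: conv1x1, strided conv3x3, GAP, three FC layers."""
    s = 0.1
    return {
        "conv1x1": s * rng.standard_normal((width, c_in)),
        "conv3x3": s * rng.standard_normal((width, width, 3, 3)),
        "fc": [s * rng.standard_normal((width, width)),
               s * rng.standard_normal((width, width)),
               s * rng.standard_normal((2 * k_iters, width))],
    }


def estimate_params(inp, params):
    """E: map a (c_in, h, w) input to positive alpha and beta, one value per iteration."""
    f = relu(np.einsum("oc,chw->ohw", params["conv1x1"], inp))  # conv1x1
    f = relu(strided_conv3x3(f, params["conv3x3"]))             # strided conv3x3
    v = f.mean(axis=(1, 2))                                      # global average pooling
    fc1, fc2, fc3 = params["fc"]
    v = softplus(fc3 @ relu(fc2 @ relu(fc1 @ v))) + 1e-6         # three FC layers
    k = v.size // 2
    return v[:k], v[k:]


def toy_denoiser(x, beta):
    """Hypothetical stand-in for HST: a larger beta (lower noise) keeps more of x."""
    smooth = (np.roll(x, 1, axis=2) + x + np.roll(x, -1, axis=2)) / 3.0
    keep = beta / (1.0 + beta)
    return keep * x + (1.0 - keep) * smooth


def dauf(y, phi_masks, denoise, params):
    """Unrolled Eq. (12); y: (h, w), phi_masks: (bands, h, w)."""
    alpha, beta = estimate_params(np.concatenate([y[None], phi_masks]), params)
    psi = (phi_masks ** 2).sum(axis=0)   # diagonal of Phi Phi^T
    z = phi_masks * y[None]              # z_0 = Phi^T y (stand-in initialization)
    for a_k, b_k in zip(alpha, beta):
        resid = y - (phi_masks * z).sum(axis=0)
        x = z + phi_masks * (resid / (a_k + psi))[None]  # P: Eq. (10) with mu_k = alpha_k
        z = denoise(x, b_k)                              # D: conditioned on beta_k
    return z, alpha, beta


bands, h, w, k_iters = 4, 8, 8, 3
phi_masks = (rng.random((bands, h, w)) > 0.5).astype(float)
y = rng.random((h, w))
params = make_estimator(c_in=bands + 1, k_iters=k_iters)
z, alpha, beta = dauf(y, phi_masks, toy_denoiser, params)
print(z.shape, alpha.shape, beta.shape)
```

In a trained model the estimator and denoiser weights would be learned end to end; the sketch only shows how one estimator call supplies all $K$ pairs $(\alpha_k, \beta_k)$ before the unrolled loop starts.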
3 Half-Shuffle Transformer
When designing the denoiser prior, previous deep unfolding methods dnu ; hssp ; gapnet ; admm-net mainly adopt CNNs, which are limited in capturing long-range dependencies. Directly applying local or global Transformers encounters two problems, i.e., limited receptive fields and nontrivial computational costs, respectively. To address these challenges, we propose the Half-Shuffle Transformer (HST) to play the role of the denoiser $\mathcal{D}$ in Eq. (12).
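Based only on the HS-MSA description given in the introduction (a local branch computing self-attention inside windows and a non-local branch that shuffles tokens to capture cross-window interactions), a single-head NumPy sketch might look as follows. The even channel split between the branches, the strided "shuffle" grouping, the window size `b`, and the random projection weights are our assumptions; the actual HST also contains components (normalization, feed-forward layers, multiple heads) that are not shown. Every attention group holds $b^2$ tokens, so the cost grows linearly with the number of tokens rather than quadratically as in global MSA.

```python
import numpy as np


def window_partition(x, b):
    """Local grouping: (H, W, C) -> (num_groups, b*b, C) with contiguous b x b windows."""
    H, W, C = x.shape
    x = x.reshape(H // b, b, W // b, b, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, b * b, C)


def window_merge(groups, b, H, W):
    """Inverse of window_partition."""
    C = groups.shape[-1]
    x = groups.reshape(H // b, W // b, b, b, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(H, W, C)


def shuffle_partition(x, b):
    """Shuffled grouping: each group takes b*b tokens spaced H//b rows and W//b columns apart."""
    H, W, C = x.shape
    x = x.reshape(b, H // b, b, W // b, C).transpose(1, 3, 0, 2, 4)
    return x.reshape(-1, b * b, C)


def shuffle_merge(groups, b, H, W):
    """Inverse of shuffle_partition."""
    C = groups.shape[-1]
    x = groups.reshape(H // b, W // b, b, b, C).transpose(2, 0, 3, 1, 4)
    return x.reshape(H, W, C)


def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


def group_attention(tokens, wq, wk, wv):
    """Single-head self-attention computed independently inside each group."""
    q, k, v = tokens @ wq, tokens @ wk, tokens @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


def half_shuffle_msa(x, b, w_local, w_nonlocal):
    """Half of the channels attend within windows, the other half across windows."""
    H, W, C = x.shape
    c = C // 2
    local = window_merge(group_attention(window_partition(x[..., :c], b), *w_local), b, H, W)
    cross = shuffle_merge(group_attention(shuffle_partition(x[..., c:], b), *w_nonlocal), b, H, W)
    return np.concatenate([local, cross], axis=-1)


rng = np.random.default_rng(0)
H, W, C, b = 8, 8, 8, 2
x = rng.standard_normal((H, W, C))
w_local = [rng.standard_normal((C // 2, C // 2)) for _ in range(3)]
w_nonlocal = [rng.standard_normal((C // 2, C // 2)) for _ in range(3)]
out = half_shuffle_msa(x, b, w_local, w_nonlocal)
print(out.shape)  # (8, 8, 8)
```

In this sketch a local group covers one 2×2 window, while a shuffled group draws one token from each of four different windows, which is how the non-local branch lets related tokens in separate windows attend to each other.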
Experiment
Similar to tsa_net ; hdnet ; gapnet ; gsm ; mst , we select 28 wavelengths from 450 nm to 650 nm and derive the HSI data at these wavelengths by spectral interpolation. Both simulation and real experiments are conducted.
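Spectral interpolation onto the 28 target wavelengths can be sketched with per-pixel linear interpolation. The source sampling (a hypothetical 31-band cube every 10 nm from 400 nm to 700 nm) and the uniform spacing of the 28 target bands are assumptions for illustration; the exact band centers used in the cited works are not given here.

```python
import numpy as np


def resample_bands(cube, src_wl, dst_wl):
    """Linearly interpolate every pixel spectrum of cube (bands, h, w) onto dst_wl."""
    bands, h, w = cube.shape
    flat = cube.reshape(bands, -1)
    out = np.stack([np.interp(dst_wl, src_wl, flat[:, p]) for p in range(flat.shape[1])], axis=1)
    return out.reshape(len(dst_wl), h, w)


src_wl = np.arange(400, 701, 10, dtype=float)  # hypothetical 31-band source, 10 nm spacing
dst_wl = np.linspace(450, 650, 28)             # assumed uniform spacing of the 28 target bands
cube = np.random.default_rng(0).random((src_wl.size, 4, 4))
print(resample_bands(cube, src_wl, dst_wl).shape)  # (28, 4, 4)
```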
Simulation Dataset. We adopt two datasets, i.e., CAVE cave and KAIST kaist for simulation experiments. The CAVE dataset consists of 32 HSIs with spatial size 512×512. The KAIST dataset contains 30 HSIs of spatial size 2704×3376. Following the settings of tsa_net ; hdnet ; gapnet ; gsm ; mst , the CAVE dataset is adopted as the training set while 10 scenes from the KAIST dataset are selected for testing.
Real Dataset. Five real HSIs collected by the CASSI system developed in tsa_net are used for testing.
Implementation Details. We implement DAUHST in PyTorch. All DAUHST models are trained with the Adam adam optimizer ($\beta_1$ = 0.9 and $\beta_2$ = 0.999) using the cosine annealing scheme cosine for 300 epochs on an RTX 3090 GPU. The initial learning rate is $4\times10^{-4}$. Patches with spatial sizes 256×256 and 660×660 are randomly cropped from the 3D HSI cubes with 28 channels as training samples for the simulation and real experiments, respectively. The shifting step in the dispersion is set to 2. The batch size is 5. We set the basic channel number $C$ equal to the number of spectral channels, i.e., 28, to store HSI information. The weights of the HST denoiser in different stages are unshared. Data augmentation includes random rotation and flipping. The training objective is to minimize the root mean square error (RMSE) between the reconstructed and ground-truth HSIs.
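The learning-rate schedule and the training objective can be written in closed form. The sketch below assumes the schedule is stepped once per epoch and decays to a minimum learning rate of 0 (neither is stated above); under those assumptions it is the same formula that PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` implements with `T_max=300` and `eta_min=0`.

```python
import math

import numpy as np


def cosine_lr(epoch, total_epochs=300, lr_max=4e-4, lr_min=0.0):
    """Cosine annealing: lr_min + (lr_max - lr_min) * (1 + cos(pi * t / T)) / 2."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * epoch / total_epochs))


def rmse(pred, target):
    """Root mean square error over all voxels of the reconstructed cube."""
    return float(np.sqrt(np.mean((pred - target) ** 2)))


print([round(cosine_lr(e), 6) for e in (0, 150, 300)])
print(rmse(np.ones((28, 2, 2)), np.zeros((28, 2, 2))))
```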
2 Quantitative Comparisons with State-of-the-Art Methods
Tab. 1 compares the results of DAUHST and 16 SOTA methods, including three model-based methods (TwIST twist , GAP-TV gap_tv , and DeSCI desci ), one PnP method (DIP-HSI self ), seven E2E methods (λ-Net lambda , TSA-Net tsa_net , HDNet hdnet , MST mst , MST++ mst_pp , CST cst , and BIRNAT birnat ), and five deep unfolding methods (HSSP hssp , DNU dnu , DGSMP gsm , GAP-Net gapnet , and ADMM-Net admm-net ), on 10 simulation scenes. All algorithms are tested with the same settings as gsm ; mst .
(i) Our best model DAUHST-9stg (9-stage DAUHST) yields very impressive results, i.e., 38.36 dB in PSNR and 0.967 in SSIM. DAUHST-9stg significantly outperforms two recent SOTA methods, BIRNAT birnat and MST-L mst , by 0.78 and 3.18 dB, respectively, suggesting the effectiveness of our method.
(ii) Our DAUHST models dramatically surpass SOTA methods while requiring lower computational and memory costs. For instance, compared with the Transformer-based E2E method MST, our DAUHST-2stg outperforms MST-L by 1.16 dB while costing only 68.9% (1.40 / 2.03) of the Params and 65.5% (18.44 / 28.15) of the FLOPS. Compared with CNN-based E2E methods, DAUHST-3stg surpasses HDNet, TSA-Net, and λ-Net by 2.24, 5.75, and 8.68 dB while requiring only 87.8%, 4.7%, and 3.3% of their Params and 17.6%, 24.7%, and 23.0% of their FLOPS, respectively. Compared with the RNN-based E2E method BIRNAT, our DAUHST-5stg is 0.17 dB higher while costing only 2.1% of the FLOPS and 78.2% of the Params. Fig. 1 plots the PSNR-FLOPS comparisons of DAUHST and SOTA unfolding methods. DAUHST outperforms other competitors with the same number of stages by very large margins of over 4 dB.
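To put the dB margins above in perspective, PSNR and the effect of a PSNR gain on the mean squared error can be computed directly. The sketch assumes intensities scaled to [0, 1] and computes PSNR over the whole cube; the exact per-band averaging protocol follows the cited evaluation settings and is not reproduced here. At a fixed data range, a gain of $g$ dB multiplies the MSE by $10^{-g/10}$, so a 4 dB margin corresponds to roughly 40% of the competitor's MSE.

```python
import numpy as np


def psnr(pred, target, data_range=1.0):
    """PSNR in dB, assuming intensities lie in [0, data_range]."""
    mse = np.mean((pred - target) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)


def mse_ratio_for_gain(gain_db):
    """Factor by which the MSE shrinks for a given PSNR gain at a fixed data range."""
    return 10.0 ** (-gain_db / 10.0)


target = np.zeros((28, 4, 4))
pred = target + 0.1                        # uniform error of 0.1 -> MSE = 0.01
print(round(psnr(pred, target), 2))        # 20.0
print(round(mse_ratio_for_gain(4.0), 3))   # 0.398
```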
3 Qualitative Comparisons with State-of-the-Art Methods
Simulation HSI Reconstruction. Fig. 4 depicts the simulation HSI reconstruction comparisons between our DAUHST and other SOTA methods on Scene 2 with 4 (out of 28) spectral channels. The top-right part shows zoomed-in patches of the yellow boxes in the entire HSIs (bottom). As can be observed, our DAUHST-9stg reconstructs more visually pleasing HSIs with more detailed contents, cleaner textures, and fewer artifacts while preserving the spatial smoothness of homogeneous regions. In contrast, previous methods either yield over-smooth results that compromise fine-grained structures, or introduce undesired chromatic artifacts and blotchy textures that are absent in the ground truth (GT). The top-middle part illustrates the density-wavelength spectral curves corresponding to the green boxes marked a and b in the RGB image (top-left). The spectral curves of DAUHST-9stg achieve the highest correlation and coincidence with the reference curves, showing the advantage of DAUHST in reconstructing spectrally consistent HSIs.
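A density-wavelength curve and its agreement with the reference can be computed as below. This is a hedged sketch: averaging over a rectangular box, normalizing each curve to a peak of 1, and using the Pearson coefficient as the "correlation" are our assumptions about how such curves are typically compared, and the cubes are synthetic.

```python
import numpy as np


def spectral_curve(cube, box):
    """Mean intensity per band inside box = (top, left, height, width), scaled to peak 1."""
    top, left, hh, ww = box
    curve = cube[:, top:top + hh, left:left + ww].mean(axis=(1, 2))
    return curve / curve.max()


def curve_correlation(cube_pred, cube_ref, box):
    """Pearson correlation between the reconstructed and reference spectral curves."""
    a = spectral_curve(cube_pred, box)
    b = spectral_curve(cube_ref, box)
    return float(np.corrcoef(a, b)[0, 1])


rng = np.random.default_rng(0)
ref = rng.random((28, 32, 32))
good = ref + 0.01 * rng.standard_normal(ref.shape)  # small reconstruction error
bad = rng.random((28, 32, 32))                      # unrelated cube
box = (8, 8, 6, 6)
print(curve_correlation(good, ref, box), curve_correlation(bad, ref, box))
```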
Real HSI Reconstruction. We further evaluate the effectiveness of DAUHST in real HSI reconstruction. Following the same settings as tsa_net ; gsm ; mst for a fair comparison, we re-train DAUHST-3stg with the real mask on the CAVE and KAIST datasets jointly. To simulate real imaging conditions, the training samples are also injected with 11-bit shot noise. Fig. 5 shows the visual comparisons between our DAUHST-3stg and nine SOTA methods. In the top three rows, only our DAUHST-3stg can reconstruct the flower patch corresponding to the yellow box at all wavelengths, while all other methods fail to recover the entire patch. In the bottom row, DAUHST-3stg restores more structural details and clearer contents with fewer artifacts. In contrast, other methods recover blurry images, generate incomplete responses, and are susceptible to noise corruption. This evidence suggests that DAUHST is more robust to noise and more effective in real HSI reconstruction.
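Shot-noise injection can be sketched by treating the measurement as Poisson photon counts. The details here are assumptions, not the exact procedure of the cited works: the measurement is taken to lie in [0, 1], "11-bit" is interpreted as $2^{11}$ intensity levels, and the noisy counts are rescaled back to [0, 1]. For a 256×256 patch with 28 bands and a shifting step of 2, the measurement width is 256 + 2 × 27 = 310.

```python
import numpy as np


def add_shot_noise(meas, bits=11, rng=None):
    """Poisson shot noise at 2**bits intensity levels; meas is assumed to lie in [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    levels = 2 ** bits
    counts = rng.poisson(meas * levels)  # photon counts, variance proportional to the signal
    return counts / levels


rng = np.random.default_rng(0)
clean = rng.random((256, 310))  # e.g., 256 x 256 patch, 28 bands, shifting step 2
noisy = add_shot_noise(clean, rng=rng)
print(abs(noisy.mean() - clean.mean()) < 1e-2)
```

Because the Poisson variance scales with the signal, bright regions receive stronger absolute noise than dark ones, unlike additive Gaussian noise.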
4 Ablation Study
Break-down Ablation. We adopt baseline-1, derived by removing HS-MSA and DAUF from DAUHST-3stg, to conduct the break-down ablation and study the contribution of each component. Baseline-1 is an end-to-end cascade of three single-stage networks. As shown in Tab. 2a, baseline-1 achieves 33.05 dB. Applying DAUF or HS-MSA alone improves the model by 2.32 and 2.44 dB, respectively. Exploiting DAUF and HS-MSA jointly yields a gain of 4.16 dB. These results demonstrate the effectiveness of our DAUF and HS-MSA.
Self-Attention Mechanism. To compare HS-MSA with other MSAs, we adopt baseline-2, obtained by removing HS-MSA from DAUHST-1stg, to conduct the ablation in Tab. 2b. We remove the position embedding schemes of all MSAs to avoid their impact and compare only the attention mechanisms. For fairness, we keep the Params of all MSAs the same by fixing the number of channels and heads. Baseline-2 yields 32.79 dB. We apply global MSA (G-MSA) global_msa , Swin MSA (SW-MSA) liu2021swin , Spectral-wise MSA (S-MSA) mst , and HS-MSA. Note that we downsample the input feature maps of G-MSA to avoid memory bottlenecks. As shown in Tab. 2b, HS-MSA yields the most significant improvement of 1.26 dB, which is 0.42, 0.30, and 0.23 dB higher than G-MSA, SW-MSA, and S-MSA, respectively. This superiority mainly derives from HS-MSA's ability to jointly capture local contents and non-local dependencies.
Unfolding Framework. We compare our DAUF with previous unfolding frameworks, including DNU dnu , ADMM-Net admm-net , and GAP-Net gapnet . For a fair comparison, we replace the single-stage network of DNU, ADMM-Net, and GAP-Net with our HST. A 3-stage architecture is adopted for these ablations. The results are shown in Tab. 2c. Our DAUF significantly outperforms DNU, ADMM-Net, and GAP-Net by 2.59, 1.69, and 1.63 dB, respectively, while adding only 0.05M Params and 0.94G FLOPS. This is mainly because DAUF uses the parameters estimated from the compressed measurement and physical mask in the CASSI system to direct the iterative learning. These parameters capture critical information about CASSI degradation patterns and ill-posedness degree, providing key cues for HSI reconstruction.
To study the effect of the estimated parameters $\boldsymbol{\alpha}$ and $\boldsymbol{\beta}$, we perform a break-down ablation of DAUF. We adopt DAUHST-3stg as baseline-3, except that $\boldsymbol{\alpha}$ is set as learnable parameters instead of being estimated by $\mathcal{E}$ in Eq. (12), and $\boldsymbol{\beta}$ is not fed into $\mathcal{D}$. The results are shown in Tab. 2d. Baseline-3 yields 36.49 dB. When $\boldsymbol{\alpha}$ is estimated by $\mathcal{E}$, baseline-3 improves by 0.45 dB. When $\boldsymbol{\beta}$ is fed into $\mathcal{D}$, baseline-3 gains 0.34 dB. When $\boldsymbol{\alpha}$ and $\boldsymbol{\beta}$ are exploited jointly in the iterative learning, baseline-3 achieves a significant improvement of 0.72 dB. These results verify that the estimated parameters $\boldsymbol{\alpha}$ and $\boldsymbol{\beta}$ benefit the linear projection and denoising network of deep unfolding methods.
To further analyze the roles of the estimated parameters, we visualize $\mathbf{x}_k$ and $\mathbf{z}_k$ of Eq. (12), and plot the curves of $\alpha_k$ and $\beta_k$ as they change with the iteration in Fig. 6. We observe: (i) $\mathbf{z}_0$ and $\mathbf{x}_1$ yield either blurry or noisy images, and there is a significant gap between them. Since $\mu$ in Eq. (5) penalizes the difference between $\mathbf{x}$ and $\mathbf{z}$, $\alpha_1$ is estimated to be a large value. From the linear projection of the second iteration ($k = 2$) on, the gap between $\mathbf{x}_k$ and $\mathbf{z}_{k-1}$ decreases substantially. Therefore, $\alpha_k$ are estimated to be small values when $k \geq 2$. This indicates that $\boldsymbol{\alpha}$ can adaptively scale the linear projection $\mathcal{P}$. (ii) The noise corruption is severe in the first iteration. Thus, $\beta_1$, which is inversely proportional to the noise level, is estimated to be a small value. With further iterations, the noise level decreases, and thus the estimated $\beta_k$ increases. These results demonstrate that $\boldsymbol{\beta}$ can provide information about the noise level for the denoising network $\mathcal{D}$.
Conclusion
In this paper, we remedy two issues of previous deep unfolding methods: they do not estimate informative parameters from the CASSI system to direct the iterative learning, and they are mainly CNN-based, which limits their ability to capture long-range dependencies. To cope with these challenges, we first formulate a principled MAP-based unfolding framework, DAUF, that estimates parameters from the compressed measurement and the physical mask. These parameters, which capture critical cues of CASSI degradation patterns and ill-posedness degree, are fed into each iteration to adaptively scale the linear projection and provide noise level information for the denoising network. Second, we propose a novel Transformer, HST, that jointly extracts local contents and models non-local dependencies. By plugging HST into DAUF, we derive DAUHST, the first Transformer-based unfolding method for HSI reconstruction. Comprehensive experiments show that DAUHST outperforms SOTA methods by a large margin while requiring much lower memory and computational costs.