MemNet: A Persistent Memory Network for Image Restoration

Ying Tai, Jian Yang, Xiaoming Liu, Chunyan Xu

1 Introduction

Image denoising, single-image super-resolution (SISR) and JPEG deblocking are three classical image restoration tasks. Image denoising aims to recover a clean image from a noisy observation, commonly assuming additive white Gaussian noise with standard deviation $\sigma$; SISR recovers a high-resolution (HR) image from a low-resolution (LR) image; and JPEG deblocking removes the blocking artifacts caused by JPEG compression.

Recently, owing to their powerful learning ability, very deep convolutional neural networks (CNNs) have been widely used to tackle image restoration tasks. Kim et al. construct a 20-layer CNN named VDSR for SISR and adopt residual learning to ease training difficulty. To control the number of parameters in very deep models, the authors further introduce a recursive layer and propose a Deeply-Recursive Convolutional Network (DRCN). To mitigate training difficulty, Mao et al. introduce symmetric skip connections into a 30-layer convolutional auto-encoder network named RED for image denoising and SISR. Moreover, Zhang et al. propose a denoising convolutional neural network (DnCNN) to tackle image denoising, SISR and JPEG deblocking simultaneously.

Conventional plain CNNs, e.g., VDSR, DRCN and DnCNN (Fig. 1(a)), adopt a single-path feed-forward architecture, in which each state is mainly influenced by its direct predecessor; we call this short-term memory. Some CNN variants, e.g., RED and ResNet (Fig. 1(b)), have skip connections that pass information across several layers. In these networks, apart from the short-term memory, a state is also influenced by one specific earlier state, which we call restricted long-term memory. In essence, recent evidence suggests that the mammalian brain may protect previously acquired knowledge in neocortical circuits. However, none of the above CNN models has such a mechanism to achieve persistent memory. As depth grows, they suffer from a lack of long-term memory.

To address this issue, we propose a very deep persistent memory network (MemNet), which introduces a memory block to explicitly mine persistent memory through an adaptive learning process. In MemNet, a Feature Extraction Net (FENet) first extracts features from the low-quality image. Then, several memory blocks are stacked with a densely connected structure to solve the image restoration task. Finally, a Reconstruction Net (ReconNet) learns the residual image, rather than the direct mapping, to ease training.

As the key component of MemNet, a memory block contains a recursive unit and a gate unit. Inspired by the neuroscience finding that recursive connections are ubiquitous in the neocortex, the recursive unit learns multi-level representations of the current state under different receptive fields (blue circles in Fig. 1(c)), which can be seen as the short-term memory. The short-term memory generated by the recursive unit and the long-term memory generated by the previous memory blocks (green arrow in Fig. 1(c); for the first memory block, the long-term memory comes from the output of FENet) are concatenated and sent to the gate unit, a non-linear function that maintains persistent memory. Further, we present an extended multi-supervised MemNet, which fuses all intermediate predictions of the memory blocks to boost performance.

In summary, the main contributions of this work include:

◇ A memory block that accomplishes the gating mechanism to help bridge long-term dependencies. In each memory block, the gate unit adaptively learns different weights for different memories, which control how much of the long-term memory should be retained and how much of the short-term memory should be stored.

◇ A very deep end-to-end persistent memory network (80 convolutional layers) for image restoration. The densely connected structure helps compensate for mid/high-frequency signals and also ensures maximum information flow between memory blocks. To the best of our knowledge, it is by far the deepest network for image restoration.

◇ The same MemNet structure achieves state-of-the-art performance in image denoising, super-resolution and JPEG deblocking. Owing to its strong learning ability, MemNet can be trained to handle different levels of corruption with a single model.

2 Related Work

The success of AlexNet on ImageNet started the era of deep learning for vision, and popular networks such as GoogLeNet, Highway Network and ResNet reveal that network depth is of crucial importance.

As an early attempt, Jain et al. proposed a simple CNN to recover a clean natural image from a noisy observation and achieved performance comparable to wavelet methods. As the pioneering CNN model for SISR, the super-resolution convolutional neural network (SRCNN) predicts the nonlinear LR-HR mapping via a fully convolutional network and significantly outperforms classical shallow methods. The same authors further proposed an extended CNN model, named Artifacts Reduction Convolutional Neural Networks (ARCNN), to effectively handle JPEG compression artifacts.

To incorporate task-specific priors, Wang et al. adopted a cascaded sparse coding network to fully exploit the natural sparsity of images. Another work proposes a deep dual-domain approach that combines prior knowledge of the JPEG compression scheme with the practice of dual-domain sparse coding. Guo et al. also proposed a dual-domain convolutional network that jointly learns a very deep network in both the DCT and pixel domains.

Recently, very deep CNNs have become popular for image restoration. Kim et al. stacked 20 convolutional layers to exploit large contextual information, using residual learning and adjustable gradient clipping to speed up training. Zhang et al. introduced batch normalization into the DnCNN model to jointly handle several image restoration tasks. DRCN reduces model complexity through recursion and introduces recursive supervision and skip connections to mitigate the training difficulty. Using symmetric skip connections, Mao et al. proposed a very deep convolutional auto-encoder network for image denoising and SISR. Very recently, Lai et al. proposed LapSRN to address both the speed and the accuracy of SISR; it operates on LR images directly and progressively reconstructs the sub-band residuals of HR images. Tai et al. proposed the deep recursive residual network (DRRN) to address the problems of model parameters and accuracy; it recursively learns the residual unit in a multi-path model.

3 MemNet for Image Restoration

Our MemNet consists of three parts: a feature extraction net (FENet), multiple stacked memory blocks and finally a reconstruction net (ReconNet) (Fig. 2). Let $\mathbf{x}$ and $\mathbf{y}$ denote the input and output of MemNet. Specifically, a convolutional layer is used in FENet to extract features from the noisy or blurry input image,

$$B_0 = f_{ext}(\mathbf{x}), \tag{1}$$

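where $f_{ext}$ denotes the feature extraction function and $B_0$ is the extracted feature to be sent to the first memory block. Supposing $M$ memory blocks are stacked to act as the feature mapping, we have

$$B_m = \mathcal{M}_m(B_{m-1}) = \mathcal{M}_m(\mathcal{M}_{m-1}(\cdots(\mathcal{M}_1(B_0))\cdots)), \tag{2}$$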

where $\mathcal{M}_m$ denotes the $m$-th memory block function, and $B_{m-1}$ and $B_m$ are the input and output of the $m$-th memory block, respectively. Finally, instead of learning the direct mapping from the low-quality image to the high-quality image, our model uses a convolutional layer in ReconNet to reconstruct the residual image. Therefore, our basic MemNet can be formulated as

$$\mathbf{y} = \mathcal{D}(\mathbf{x}) = f_{rec}(\mathcal{M}_M(\mathcal{M}_{M-1}(\cdots(\mathcal{M}_1(f_{ext}(\mathbf{x})))\cdots))) + \mathbf{x}, \tag{3}$$

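where $f_{rec}$ denotes the reconstruction function and $\mathcal{D}$ denotes the function of our basic MemNet. Given a training set $\{\mathbf{x}^{(i)}, \tilde{\mathbf{x}}^{(i)}\}_{i=1}^{N}$, where $N$ is the number of training patches and $\tilde{\mathbf{x}}^{(i)}$ is the ground-truth high-quality patch of the low-quality patch $\mathbf{x}^{(i)}$, the loss function of the basic MemNet with parameter set $\Theta$ is

$$\mathcal{L}(\Theta) = \frac{1}{2N}\sum_{i=1}^{N}\left\|\tilde{\mathbf{x}}^{(i)} - \mathcal{D}(\mathbf{x}^{(i)})\right\|^2. \tag{4}$$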
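
As a minimal sketch of the basic pipeline described above (FENet, $M$ stacked memory blocks, ReconNet, and a global residual connection that adds the input back), the Python snippet below composes toy stand-in functions on plain lists. The function names and the averaging stand-in block are hypothetical illustrations, not the authors' Caffe implementation; each block receives all earlier states to mirror the long-term memory that feeds its gate unit.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[List[Vector]], Vector]


def basic_memnet_forward(
    x: Vector,
    f_ext: Callable[[Vector], Vector],
    blocks: Sequence[Block],
    f_rec: Callable[[Vector], Vector],
) -> Vector:
    """Toy basic MemNet: y = f_rec(B_M) + x, with B_0 = f_ext(x)."""
    states: List[Vector] = [f_ext(x)]  # B_0 from the feature extraction net
    for block in blocks:
        # Block m sees B_0..B_{m-1}; its direct input is states[-1] = B_{m-1}.
        states.append(block(states))
    residual = f_rec(states[-1])  # ReconNet predicts only the residual image
    return [xi + ri for xi, ri in zip(x, residual)]


def mean_of_states(states: List[Vector]) -> Vector:
    """Hypothetical stand-in for a memory block: element-wise mean of all states."""
    n = len(states)
    return [sum(values) / n for values in zip(*states)]


y = basic_memnet_forward(
    [1.0, 2.0],
    f_ext=lambda v: [2.0 * t for t in v],
    blocks=[mean_of_states, mean_of_states],
    f_rec=lambda v: [0.25 * t for t in v],
)
print(y)  # [1.5, 3.0]
```

Because ReconNet only predicts the residual, a reconstruction function that outputs all zeros returns the input unchanged, so the network only has to learn the correction to the low-quality image.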

3.2 Memory Block

We now present the details of our memory block. The memory block contains a recursive unit and a gate unit.

Recursive Unit is used to model a non-linear function that acts like a recursive synapse in the brain. Here, we use a residual building block, which was introduced in ResNet and shows powerful learning ability for object recognition, as a recursion in the recursive unit. A residual building block in the $m$-th memory block is formulated as

$$H_m^r = \mathcal{R}(H_m^{r-1}) = \mathcal{F}(H_m^{r-1}, W_m) + H_m^{r-1}, \tag{5}$$

where $H_m^{r-1}$ and $H_m^r$ are the input and output of the $r$-th residual building block, respectively. When $r = 1$, $H_m^0 = B_{m-1}$. $\mathcal{F}$ denotes the residual function, $W_m$ is the weight set to be learned and $\mathcal{R}$ denotes the function of the residual building block. Specifically, each residual function contains two convolutional layers with the pre-activation structure,

$$\mathcal{F}(H_m^{r-1}, W_m) = W_m^2\,\tau(W_m^1\,\tau(H_m^{r-1})), \tag{6}$$

where $\tau$ denotes the activation function, consisting of batch normalization followed by ReLU, and $W_m^i$, $i = 1, 2$, are the weights of the $i$-th convolutional layer. Bias terms are omitted for simplicity.
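
The pre-activation residual building block can be sketched on a 1-D toy feature vector, assuming dense matrices in place of the 3×3 convolutions and plain ReLU as $\tau$ (batch normalization is omitted for brevity). These simplifications and the helper names are ours; the point is the ordering of activation, weights, activation, weights, followed by the identity shortcut.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def relu(v: Vector) -> Vector:
    return [max(0.0, t) for t in v]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product (toy stand-in for a convolution)."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def residual_building_block(h: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """H^r = W^2 tau(W^1 tau(H^{r-1})) + H^{r-1}, with tau simplified to ReLU."""
    residual = matvec(w2, relu(matvec(w1, relu(h))))
    return [r + t for r, t in zip(residual, h)]


identity = [[1.0, 0.0], [0.0, 1.0]]
half = [[0.5, 0.0], [0.0, 0.5]]
print(residual_building_block([1.0, -2.0], identity, half))  # [1.5, -2.0]
```

The negative entry reaches the output unchanged through the shortcut even though the ReLU zeroes it inside the residual branch.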

Then, several recursions are performed to generate multi-level representations under different receptive fields. We call these representations the short-term memory. Supposing there are $R$ recursions in the recursive unit, the $r$-th recursion can be formulated as

$$H_m^r = \mathcal{R}_m^{(r)}(B_{m-1}) = \mathcal{R}_m(\mathcal{R}_m(\cdots(\mathcal{R}_m(B_{m-1}))\cdots)), \tag{7}$$

where $r$-fold operations of $\mathcal{R}_m$ are performed and $\{H_m^r\}_{r=1}^{R}$ are the multi-level representations of the recursive unit. These representations are concatenated as the short-term memory: $B_m^{short} = [H_m^1, H_m^2, \ldots, H_m^R]$. In addition, the long-term memory coming from the previous memory blocks can be constructed as $B_m^{long} = [B_0, B_1, \ldots, B_{m-1}]$. The two types of memories are then concatenated as the input to the gate unit,

$$B_m^{gate} = [B_m^{long}, B_m^{short}]. \tag{8}$$
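
A small sketch of the recursive unit and the memory bookkeeping, under our own simplifying assumptions: one Python list stands for one 64-channel group of feature maps, concatenation along the channel axis is list concatenation, and the same block function (that is, the same weights $W_m$) is reused at every recursion. The helper names are hypothetical. The last function counts the input feature maps $L_m$ of the $m$-th gate unit, assuming every state carries 64 maps, as stated in the training setting.

```python
from typing import Callable, List

Vector = List[float]


def recursive_unit(b_prev: Vector, block: Callable[[Vector], Vector], num_recursions: int) -> List[Vector]:
    """Apply one shared residual block R times and keep every output H^1..H^R."""
    outputs: List[Vector] = []
    h = b_prev  # H^0 = B_{m-1}
    for _ in range(num_recursions):
        h = block(h)  # identical weights at each recursion
        outputs.append(h)
    return outputs


def gate_input(long_term: List[Vector], short_term: List[Vector]) -> List[Vector]:
    """[B_long, B_short]: long-term states first, then short-term states."""
    return long_term + short_term


def gate_input_channels(m: int, num_recursions: int, filters: int = 64) -> int:
    """L_m = filters * (m + R): m long-term states B_0..B_{m-1} plus R recursion outputs."""
    return filters * (m + num_recursions)


short = recursive_unit([0.0], lambda v: [t + 1.0 for t in v], 3)
print(short)                       # [[1.0], [2.0], [3.0]]
print(gate_input([[9.0]], short))  # [[9.0], [1.0], [2.0], [3.0]]
print(gate_input_channels(1, 6), gate_input_channels(6, 6))  # 448 768
```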

Gate Unit is used to achieve persistent memory through an adaptive learning process. In this paper, we adopt a $1\times 1$ convolutional layer to accomplish the gating mechanism, which can learn adaptive weights for different memories,

$$B_m = f_m^{gate}(B_m^{gate}) = W_m^{gate}\,\tau(B_m^{gate}), \tag{9}$$

where $f_m^{gate}$ and $B_m$ denote the function of the $1\times 1$ convolutional layer (parameterized by $W_m^{gate}$) and the output of the $m$-th memory block, respectively. As a result, the weights for the long-term memory control how much of the previous states should be retained, and the weights for the short-term memory decide how much of the current state should be stored. Therefore, the $m$-th memory block can be formulated as

$$B_m = \mathcal{M}_m(B_{m-1}) = f_m^{gate}\left([B_0, \ldots, B_{m-1}, \mathcal{R}_m^{(1)}(B_{m-1}), \ldots, \mathcal{R}_m^{(R)}(B_{m-1})]\right). \tag{10}$$
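
To show why a $1\times 1$ convolution acts as a gate that weights whole feature maps, the sketch below implements it on nested Python lists of shape [channels][height][width]. It assumes ReLU as a stand-in for the pre-activation $\tau$ (batch normalization omitted) and uses hypothetical helper names; it is an illustration, not the Caffe layer used in the paper.

```python
from typing import List

FeatureMaps = List[List[List[float]]]  # [channels][height][width]
Weights = List[List[float]]            # [output map][input map]


def conv1x1(x: FeatureMaps, weights: Weights) -> FeatureMaps:
    """Each output map is a weighted sum of input maps; one scalar per (output, input) pair."""
    height, width = len(x[0]), len(x[0][0])
    return [
        [[sum(w_row[l] * x[l][i][j] for l in range(len(x))) for j in range(width)]
         for i in range(height)]
        for w_row in weights
    ]


def gate_unit(long_term: FeatureMaps, short_term: FeatureMaps, weights: Weights) -> FeatureMaps:
    """B_m = W_gate * tau([B_long, B_short]), with tau simplified to ReLU."""
    concat = long_term + short_term
    activated = [[[max(0.0, v) for v in row] for row in fmap] for fmap in concat]
    return conv1x1(activated, weights)


long_term = [[[1.0, 2.0], [3.0, 4.0]]]
short_term = [[[10.0, 10.0], [10.0, -10.0]]]
print(gate_unit(long_term, short_term, [[0.5, 0.25]]))  # [[[3.0, 3.5], [4.0, 2.0]]]
```

With $L_m$ input maps and 64 output maps, such a gate holds $64 L_m$ weights regardless of the image size, because the same scalar multiplies every pixel of a given input map.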

3.3 Multi-Supervised MemNet

To further explore the features at different states, and inspired by prior work, we send the output of each memory block to the same reconstruction net $\hat{f}_{rec}$ to generate

$$\mathbf{y}_m = \hat{f}_{rec}(B_m) + \mathbf{x}, \tag{11}$$

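where $\{\mathbf{y}_m\}_{m=1}^{M}$ are the intermediate predictions. All of the predictions are supervised during training and used to compute the final output via weighted averaging: $\mathbf{y} = \sum_{m=1}^{M} w_m \cdot \mathbf{y}_m$ (Fig. 3). The optimal weights $\{w_m\}_{m=1}^{M}$ are learned automatically during training, and the final output of the ensemble is also supervised. The loss function of our multi-supervised MemNet can be formulated as

$$\mathcal{L}(\Theta) = \frac{\alpha}{2N}\sum_{i=1}^{N}\left\|\tilde{\mathbf{x}}^{(i)} - \sum_{m=1}^{M} w_m \cdot \mathbf{y}_m^{(i)}\right\|^2 + \frac{1-\alpha}{2MN}\sum_{m=1}^{M}\sum_{i=1}^{N}\left\|\tilde{\mathbf{x}}^{(i)} - \mathbf{y}_m^{(i)}\right\|^2, \tag{12}$$

where $\tilde{\mathbf{x}}^{(i)}$ is the ground-truth patch of the $i$-th of $N$ training patches and $\alpha$ denotes the loss weight.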
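
The weighted-average ensemble and a multi-supervised objective can be sketched as follows. We assume squared-error terms of the form $\frac{\alpha}{2N}\sum_i\|\tilde{\mathbf{x}}^{(i)}-\mathbf{y}^{(i)}\|^2 + \frac{1-\alpha}{2MN}\sum_m\sum_i\|\tilde{\mathbf{x}}^{(i)}-\mathbf{y}_m^{(i)}\|^2$, where $\tilde{\mathbf{x}}^{(i)}$ is the ground truth of the $i$-th of $N$ patches, and we hold the ensemble weights $w_m$ fixed, whereas the real model learns them. The function names are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def squared_error(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((p - q) ** 2 for p, q in zip(a, b))


def ensemble(preds: Sequence[Vector], weights: Sequence[float]) -> Vector:
    """y = sum_m w_m * y_m over the M intermediate predictions."""
    return [sum(w * p[k] for w, p in zip(weights, preds)) for k in range(len(preds[0]))]


def multi_supervised_loss(
    targets: Sequence[Vector],
    preds_per_sample: Sequence[Sequence[Vector]],
    weights: Sequence[float],
    alpha: float,
) -> float:
    """alpha-weighted loss on the ensemble plus (1 - alpha) spread over the M intermediate outputs."""
    n, m = len(targets), len(weights)
    final_term = sum(squared_error(t, ensemble(p, weights)) for t, p in zip(targets, preds_per_sample))
    inter_term = sum(squared_error(t, y_m) for t, p in zip(targets, preds_per_sample) for y_m in p)
    return alpha / (2 * n) * final_term + (1 - alpha) / (2 * m * n) * inter_term


M = 2
alpha = 1 / (M + 1)  # the value reported in the training setting
loss = multi_supervised_loss([[1.0]], [[[0.0], [2.0]]], [0.5, 0.5], alpha)
print(round(loss, 6))  # 0.333333
```

With $\alpha = 1/(M+1)$, the ensemble term and each intermediate term receive the same coefficient $1/(2N(M+1))$, because $(1-\alpha)/M = 1/(M+1)$.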

3.4 Dense Connections for Image Restoration

Now we analyze why the long-term dense connections in MemNet may benefit image restoration. In very deep networks, some of the mid/high-frequency information can get lost at later layers during a typical feed-forward CNN process, and dense connections from previous layers can compensate for such loss and further enhance high-frequency signals. To verify our intuition, we train an 80-layer MemNet without long-term connections, denoted MemNet_NL, and compare it with the original MemNet. Both networks have 6 memory blocks, leading to 6 intermediate outputs, and each memory block contains 6 recursions. Fig. 4(a) shows the 4th and 6th outputs of both networks. We compute their power spectra, center them, estimate spectral densities for a continuous set of frequency ranges from low to high by placing concentric circles, and plot the densities of the four outputs in Fig. 4(b).
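
The spectral analysis can be sketched with a naive 2-D DFT. This is our own reading of the procedure, with assumptions we add: a centered power spectrum $|\mathrm{DFT}|^2$, equal-width radius bins around the center as the concentric rings, and the mean power per ring as the density. The $O(n^4)$ DFT is only meant for tiny test images; a real analysis would use an FFT library.

```python
import cmath
import math
from typing import List

Image = List[List[float]]


def centered_power_spectrum(img: Image) -> Image:
    """|DFT|^2 of a small image, with zero frequency moved to (h // 2, w // 2)."""
    h, w = len(img), len(img[0])
    spec = [[0.0] * w for _ in range(h)]
    for u in range(h):
        for v in range(w):
            acc = 0j
            for y in range(h):
                for x in range(w):
                    acc += img[y][x] * cmath.exp(-2j * math.pi * (u * y / h + v * x / w))
            # Same index shift as an "fftshift": frequency (u, v) -> centered position.
            spec[(u + h // 2) % h][(v + w // 2) % w] = abs(acc) ** 2
    return spec


def ring_densities(spec: Image, num_rings: int) -> List[float]:
    """Mean power inside equal-width concentric rings, ordered from low to high frequency."""
    h, w = len(spec), len(spec[0])
    cy, cx = h // 2, w // 2
    max_r = math.hypot(cy, cx) or 1.0  # avoid division by zero for a 1x1 image
    sums = [0.0] * num_rings
    counts = [0] * num_rings
    for y in range(h):
        for x in range(w):
            ring = min(int(math.hypot(y - cy, x - cx) / max_r * num_rings), num_rings - 1)
            sums[ring] += spec[y][x]
            counts[ring] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


flat = [[1.0] * 4 for _ in range(4)]
checker = [[(-1.0) ** (x + y) for x in range(4)] for y in range(4)]
print(ring_densities(centered_power_spectrum(flat), 2))     # power only in the low ring
print(ring_densities(centered_power_spectrum(checker), 2))  # power only in the high ring
```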

We further plot the differences of these densities in Fig. 4(c). From left to right, the first case indicates that the earlier layer does contain some mid-frequency information that the later layers lose. The 2nd case verifies that, with dense connections, the later layer absorbs the information carried from the previous layers and even generates more mid-frequency information. The 3rd case suggests that in earlier layers the frequencies are similar between the two models. The last case demonstrates that MemNet recovers more high-frequency information than the version without long-term connections.

4 Discussions

Difference to Highway Network First, we discuss how the memory block accomplishes the gating mechanism and present the difference between MemNet and Highway Network, a very deep CNN model that uses a gate unit to regulate information flow.

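To avoid information attenuation in very deep plain networks, Highway Network, inspired by LSTM, introduced bypassing layers along with gate units, i.e.,

$$\mathbf{b} = \mathcal{A}(\mathbf{a}) \cdot \mathcal{T}(\mathbf{a}) + \mathbf{a} \cdot (1 - \mathcal{T}(\mathbf{a})), \tag{13}$$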

where $\mathbf{a}$ and $\mathbf{b}$ are the input and output, and $\mathcal{A}$ and $\mathcal{T}$ are two non-linear transform functions. $\mathcal{T}$ is the transform gate that controls how much of the information produced by $\mathcal{A}$ is stored to the output, and $1-\mathcal{T}$ is the carry gate that decides how much of the input is retained in the output.

In MemNet, the short-term and long-term memories are concatenated, and the $1\times 1$ convolutional layer adaptively learns the weights for different memories. Compared to Highway Network, which learns a specific weight for each pixel, our gate unit learns a specific weight for each feature map, which has two advantages: (1) it reduces model parameters and complexity; (2) it is less prone to overfitting.
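
For contrast with the per-feature-map gate of MemNet, a Highway layer can be sketched element-wise as $\mathbf{b} = \mathcal{A}(\mathbf{a})\cdot\mathcal{T}(\mathbf{a}) + \mathbf{a}\cdot(1-\mathcal{T}(\mathbf{a}))$. We assume a sigmoid transform gate, as in the Highway Network formulation, and use hypothetical toy functions for $\mathcal{A}$ and the gate logits. The difference shows up in the code: the gate value is recomputed from the input at every element, whereas a MemNet gate weight is a single learned scalar per feature map shared by all pixels.

```python
import math
from typing import Callable, List

Vector = List[float]


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def highway_layer(
    a: Vector,
    transform: Callable[[Vector], Vector],
    gate_logits: Callable[[Vector], Vector],
) -> Vector:
    """b = A(a) * T(a) + a * (1 - T(a)), with an input-dependent gate per element."""
    gates = [sigmoid(g) for g in gate_logits(a)]  # transform gate T(a)
    return [h * t + x * (1.0 - t) for h, x, t in zip(transform(a), a, gates)]


b = highway_layer([1.0, -4.0], lambda v: [2.0 * x for x in v], lambda v: [0.0 for _ in v])
print(b)  # [1.5, -6.0]
```

A very negative gate logit drives $\mathcal{T}$ toward 0, so the carry gate passes the input through almost unchanged.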

Difference to DRCN There are three main differences between MemNet and DRCN. The first is the design of the basic module in the network: in DRCN the basic module is a convolutional layer, while in MemNet it is a memory block that achieves persistent memory. The second is that in DRCN the weights of the basic modules (i.e., the convolutional layers) are shared, while in MemNet the weights of the memory blocks are different. The third is that there are no dense connections among the basic modules in DRCN, which results in a chain structure, while in MemNet there are long-term dense connections among the memory blocks, leading to a multi-path structure that not only helps information flow across the network but also encourages gradient backpropagation during training. Benefiting from this good information flow, MemNet can be easily trained without the multi-supervision strategy, which is imperative for training DRCN.

Difference to DenseNet Another work related to MemNet is DenseNet, which also builds upon a densely connected principle. In general, DenseNet deals with object recognition, while MemNet is proposed for image restoration. In addition, DenseNet adopts the densely connected structure in a local way (i.e., inside a dense block), while MemNet adopts it in a global way (i.e., across the memory blocks). In Secs. 3.4 and 5.2, we analyze and demonstrate that the long-term dense connections in MemNet indeed play an important role in image restoration tasks.

5 Experiments

Datasets For image denoising, following prior work, we use 300 images from the Berkeley Segmentation Dataset (BSD), known as the train and val sets, to generate image patches as the training set. Two popular benchmarks, a dataset with 14 common images and the BSD test set with 200 images, are used for evaluation. We generate the input noisy patch by adding Gaussian noise with one of three noise levels ($\sigma = 30$, 50 and 70) to the clean patch.

For SISR, following the experimental setting of prior work, we use a training set of 291 images, where 91 images are from Yang et al. and the other 200 are from the BSD train set. For testing, four benchmark datasets, Set5, Set14, BSD100 and Urban100, are used. Three scale factors are evaluated: ×2, ×3 and ×4. The input LR image is generated by first bicubic downsampling and then bicubic upsampling the HR image with a given scale.

For JPEG deblocking, the same training set as for image denoising is used. As in prior work, Classic5 and LIVE1 are adopted as the test datasets. Two JPEG quality factors are used, i.e., 10 and 20, and the JPEG deblocking input is generated by compressing the image with a given quality factor using the MATLAB JPEG encoder.

Training Setting Following prior work, for image denoising the grayscale image is used, while for SISR and JPEG deblocking the luminance component is fed into the model. The input image size can be arbitrary due to the fully convolutional architecture. Considering both training time and storage complexity, training images are split into 31×31 patches with a stride of 21. The output of MemNet is the estimated high-quality patch with the same resolution as the input low-quality patch. We follow prior work for data augmentation. For each task, we train a single model for all levels of corruption. For example, for image denoising, noise augmentation is used: images with different noise levels are all included in the training set. Similarly, for super-resolution and JPEG deblocking, scale and quality augmentation are used, respectively.
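
Assuming patches are taken on a regular grid with the stated size (31) and stride (21), and that border pixels which do not fill a whole patch are dropped (the exact border handling is not specified), the number of training patches per image can be computed as follows. The 481×321 size in the usage lines is the standard BSD image size; the helper names are hypothetical.

```python
from typing import List, Tuple


def patches_per_axis(size: int, patch: int = 31, stride: int = 21) -> int:
    """Number of full patch positions along one axis, without padding."""
    if size < patch:
        return 0
    return (size - patch) // stride + 1


def patch_origins(height: int, width: int, patch: int = 31, stride: int = 21) -> List[Tuple[int, int]]:
    """Top-left corners (row, col) of all full patches on a regular grid."""
    return [
        (row, col)
        for row in range(0, height - patch + 1, stride)
        for col in range(0, width - patch + 1, stride)
    ]


print(patches_per_axis(481), patches_per_axis(321))  # 22 14
print(len(patch_origins(321, 481)))                  # 308 patches before augmentation
```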

We use Caffe to implement two 80-layer MemNet networks, the basic and the multi-supervised versions. In both architectures, 6 memory blocks, each containing 6 recursions, are constructed (i.e., M6R6). Specifically, in the multi-supervised MemNet, 6 predictions are generated and used to compute the final output. $\alpha$ balances the different regularizations and is empirically set to $\alpha = 1/(M+1)$.

The objective functions in Eqn. 4 and Eqn. 12 are optimized via mini-batch stochastic gradient descent (SGD) with backpropagation. We set the mini-batch size of SGD to 64, the momentum parameter to 0.9 and the weight decay to $10^{-4}$. All convolutional layers have 64 filters. Except for the 1×1 convolutional layers in the gate units, the kernel size of all other convolutional layers is 3×3. We use the method in  for weight initialization. The initial learning rate is set to 0.1 and then divided by 10 every 20 epochs. Training an 80-layer basic MemNet with 91 images for SISR takes roughly 5 days on 1 Tesla P40 GPU. Due to space constraints and the availability of more recent baselines, we focus on SISR in Secs. 5.2, 5.4 and 5.6, and cover all three tasks in Secs. 5.3 and 5.5.
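
The optimization settings above can be sketched as a step learning-rate schedule plus one SGD step with momentum and weight decay. The update rule follows the common convention of adding weight decay to the gradient before the momentum update; treat that exact form, the 0-indexed epoch counting and the helper names as our assumptions rather than the authors' code.

```python
from typing import List, Tuple

Vector = List[float]


def learning_rate(epoch: int, base_lr: float = 0.1, factor: float = 10.0, step: int = 20) -> float:
    """Start at base_lr and divide by `factor` every `step` epochs."""
    return base_lr / factor ** (epoch // step)


def sgd_momentum_step(
    w: Vector, velocity: Vector, grad: Vector, lr: float,
    momentum: float = 0.9, weight_decay: float = 1e-4,
) -> Tuple[Vector, Vector]:
    """v <- momentum * v - lr * (grad + weight_decay * w);  w <- w + v."""
    new_v = [momentum * v - lr * (g + weight_decay * x) for x, v, g in zip(w, velocity, grad)]
    new_w = [x + v for x, v in zip(w, new_v)]
    return new_w, new_v


print([learning_rate(e) for e in (0, 19, 20, 40)])  # 0.1, 0.1, then about 0.01 and 0.001
```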

5.2 Ablation Study

Tab. 1 presents the ablation study on the effects of long-term and short-term connections. Compared to MemNet, MemNet_NL removes the long-term connections (green curves in Fig. 3), and MemNet_NS removes the short-term connections (black curves from the first $R-1$ recursions to the gate unit in Fig. 1; the connection from the last recursion to the gate unit is kept to avoid breaking the interaction between the recursive unit and the gate unit). The three networks have the same depth (80) and filter number (64). Long-term dense connections are very important, since MemNet significantly outperforms MemNet_NL. Further, MemNet achieves better performance than MemNet_NS, which reveals that the short-term connections are also useful for image restoration but less powerful than the long-term connections. The reason is that the long-term connections skip many more layers than the short-term ones, and can thus carry some mid/high-frequency signals from very early layers to later layers, as described in Sec. 3.4.

5.3 Gate Unit Analysis

We now illustrate how our gate unit affects different kinds of memories. Inspired by prior work, we adopt a weight norm as an approximation of the dependency of the current layer on its preceding layers, calculated from the corresponding weights of all filters w.r.t. each feature map: $v_m^l = \sqrt{\sum_{i=1}^{64}\left(W_m^{gate}(1,1,l,i)\right)^2}$, $l = 1, 2, \ldots, L_m$, where $L_m$ is the number of input feature maps for the $m$-th gate unit, $l$ denotes the feature map index, $W_m^{gate}$ stores the weights with size $1\times 1\times L_m\times 64$, and $v_m^l$ is the weight norm of the $l$-th feature map for the $m$-th gate unit. Basically, the larger the norm, the stronger the dependency on that particular feature map. For better visualization, we normalize the norms to the range 0 to 1. Fig. 5 presents the norm of the filter weights $\{v_m^l\}_{m=1}^{6}$ vs. feature map index $l$. We have three observations: (1) Different tasks have different norm distributions. (2) The average and variance of the weight norms become smaller as the memory block index increases. (3) In general, the short-term memories from the last recursion in the recursive unit (the last 64 elements in each curve) contribute more than the other two kinds of memories, and in later memory blocks the long-term memories seem to play a more important role in recovering useful signals than the short-term memories from the first $R-1$ recursions.
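
The weight-norm measure can be computed directly from the gate weights. In this sketch the $1\times 1\times L_m\times 64$ tensor is stored as an $L_m\times 64$ nested list (the two singleton spatial axes dropped), and we assume min-max scaling for the normalization to the range 0 to 1, since the exact normalization is not specified. The helper names are hypothetical.

```python
import math
from typing import List


def gate_weight_norms(w_gate: List[List[float]]) -> List[float]:
    """v^l = sqrt(sum_i W(1, 1, l, i)^2): one norm per input feature map l."""
    return [math.sqrt(sum(w * w for w in row)) for row in w_gate]


def min_max_normalize(values: List[float]) -> List[float]:
    """Rescale to [0, 1]; a constant input maps to all zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Three input maps, two output filters (the real gate has 64 filters).
norms = gate_weight_norms([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
print(norms)                     # [5.0, 0.0, 1.0]
print(min_max_normalize(norms))  # [1.0, 0.0, 0.2]
```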

5.4 Comparison with Non-Persistent CNN Models

In this subsection, we compare MemNet with three existing non-persistent CNN models, i.e., VDSR, DRCN and RED, to demonstrate the superiority of our persistent memory structure. VDSR and DRCN are two representative networks with the plain structure, and RED is representative of networks with skip connections. Tab. 2 presents the published results of these models along with their training details. Since the training details differ among these works, we choose DRCN as a baseline, as it achieves good performance using the fewest training images. However, unlike DRCN, which widens its network to increase the parameters (filter number: 256 vs. 64), we deepen MemNet by stacking more memory blocks (depth: 20 vs. 80). Using the fewest training images (91), filter number (64) and relatively few model parameters (667K), our basic MemNet already achieves higher PSNR than the prior networks. Keeping the setting unchanged, our multi-supervised MemNet further improves the performance. With more training images (291), MemNet significantly outperforms the state of the art.

Since we aim to address the long-term dependency problem in networks, we intend to make MemNet very deep. However, MemNet is also able to balance model complexity and accuracy. Fig. 6 presents the PSNR of different intermediate predictions in MemNet (e.g., MemNet_M3 denotes the prediction of the 3rd memory block) for scale ×3 on Set5, where the colorbar indicates the inference time (in seconds) for processing a 288×288 image on a P40 GPU. Results of VDSR and DRCN are cited from their papers. RED is omitted here because its large number of parameters would reduce the visual contrast among the other methods. MemNet already achieves comparable results at the 3rd prediction using far fewer parameters, and significantly outperforms the state of the art with a slight increase in model complexity.

5.5 Comparisons with State-of-the-Art Models

We compare the multi-supervised 80-layer MemNet with the state of the art in three restoration tasks.

Image Denoising Tab. 3 presents quantitative results on two benchmarks, with results cited from prior work. For the BSD200 dataset, following the setting in RED, the original image is resized to half its size. As we can see, MemNet achieves the best performance in all cases. It should be noted that, for each test image, RED rotates and mirror-flips the kernels and performs inference multiple times; the outputs are then averaged to obtain the final result, a strategy its authors claim leads to better performance. In contrast, MemNet does not perform any post-processing. For qualitative comparisons, we use the public code of PCLR, PGPD and WNNM. The results are shown in Fig. 7. As we can see, MemNet handles Gaussian noise better than the previous state-of-the-art methods.

Super-Resolution Tab. 4 summarizes quantitative results on four benchmarks, citing the results of prior methods. MemNet outperforms prior methods in almost all cases. Since LapSRN does not report results for scale ×3, we use the symbol '-' instead. Fig. 8 shows visual comparisons for SISR. SRCNN, VDSR and DnCNN are compared using their public code. MemNet recovers relatively sharper edges, while the others produce blurry results.

JPEG Deblocking Tab. 5 shows the JPEG deblocking results on Classic5 and LIVE1, citing the results from prior work. Our network significantly outperforms the other methods, and deeper networks do improve performance compared to the shallow one, e.g., ARCNN. Fig. 9 shows the JPEG deblocking results of these three methods, generated by their corresponding public code. MemNet effectively removes the blocking artifacts and recovers higher-quality images than the previous methods.

5.6 Comparison on Different Network Depths

Finally, we compare different network depths, obtained by stacking different numbers of memory blocks or recursions. Specifically, we test four network structures: M4R6, M6R6, M6R8 and M10R10, with depths 54, 80, 104 and 212, respectively. Tab. 6 shows the SISR performance of these networks on Set5 with scale factor ×3. The results verify that deeper is still better: the deepest network, M10R10, achieves 34.23 dB, an improvement of 0.14 dB over M6R6.
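
The four reported depths are consistent with counting one convolutional layer for FENet, two 3×3 convolutions per recursion, one 1×1 gate convolution per memory block and one layer for ReconNet. The formula below is our inference from those numbers, not one stated in the text.

```python
def memnet_depth(num_blocks: int, num_recursions: int) -> int:
    """1 (FENet) + M * (2R + 1) (recursions and gate per block) + 1 (ReconNet)."""
    return 1 + num_blocks * (2 * num_recursions + 1) + 1


for m, r in [(4, 6), (6, 6), (6, 8), (10, 10)]:
    print(f"M{m}R{r}: {memnet_depth(m, r)}")
# M4R6: 54, M6R6: 80, M6R8: 104, M10R10: 212
```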

6 Conclusions

In this paper, a very deep end-to-end persistent memory network (MemNet) is proposed for image restoration, where a memory block accomplishes the gating mechanism for tackling the long-term dependency problem in previous CNN architectures. In each memory block, a recursive unit is adopted to learn multi-level representations as the short-term memory. Both the short-term memory from the recursive unit and the long-term memories from the previous memory blocks are sent to a gate unit, which adaptively learns different weights for different memories. We use the same MemNet structure to handle image denoising, super-resolution and JPEG deblocking simultaneously. Comprehensive benchmark evaluations demonstrate the superiority of MemNet over the state of the art.

References