Shake-Shake regularization

Xavier Gastaldi

Introduction

Deep residual nets (He et al., 2016a) were first introduced in the ILSVRC & COCO 2015 competitions (Russakovsky et al., 2015; Lin et al., 2014), where they won first place in ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation. Since then, significant effort has gone into improving their performance. Researchers have investigated the impact of increasing depth (He et al., 2016b; Huang et al., 2016a), width (Zagoruyko & Komodakis, 2016) and cardinality (Xie et al., 2016; Szegedy et al., 2016; Abdi & Nahavandi, 2016).

While residual networks are powerful models, they still overfit on small datasets. A large number of techniques have been proposed to tackle this problem, including weight decay (Nowlan & Hinton, 1992), early stopping, and dropout (Srivastava et al., 2014). While not directly presented as a regularization method, Batch Normalization (Ioffe & Szegedy, 2015) regularizes the network by computing statistics that fluctuate with each mini-batch. Similarly, Stochastic Gradient Descent (SGD) (Bottou, 1998; Sutskever et al., 2013) can be interpreted as Gradient Descent with noisy gradients, and the generalization performance of neural networks often depends on the size of the mini-batch (see Keskar et al. (2017)).

Before 2015, most computer vision classification architectures used dropout to combat overfitting, but the introduction of Batch Normalization reduced its effectiveness (see Ioffe & Szegedy (2015); Zagoruyko & Komodakis (2016); Huang et al. (2016b)). Searching for other regularization methods, researchers started to look at the possibilities specifically offered by multi-branch networks. Some of them noticed that, given the right conditions, it was possible to randomly drop some of the information paths during training (Huang et al., 2016b; Larsson et al., 2016).

Like these last 2 works, the method proposed in this document aims at improving the generalization ability of multi-branch networks by replacing the standard summation of parallel branches with a stochastic affine combination.

Data augmentation techniques have traditionally been applied to input images only. However, for a computer, there is no real difference between an input image and an intermediate representation. As a consequence, it might be possible to apply data augmentation techniques to internal representations. Shake-Shake regularization was created as an attempt to produce this sort of effect by stochastically "blending" 2 viable tensors.

2 Model description on 3-branch ResNets

Let $x_i$ denote the tensor of inputs into residual block $i$. $\mathcal{W}_i^{(1)}$ and $\mathcal{W}_i^{(2)}$ are sets of weights associated with the 2 residual units. $\mathcal{F}$ denotes the residual function, e.g. a stack of two 3x3 convolutional layers. $x_{i+1}$ denotes the tensor of outputs from residual block $i$.

A typical pre-activation ResNet with 2 residual branches would follow this equation:

$$x_{i+1} = x_i + \mathcal{F}(x_i, \mathcal{W}_i^{(1)}) + \mathcal{F}(x_i, \mathcal{W}_i^{(2)})$$

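Proposed modification: If $\alpha_i$ is a random variable following a uniform distribution between 0 and 1, then during training:

$$x_{i+1} = x_i + \alpha_i \mathcal{F}(x_i, \mathcal{W}_i^{(1)}) + (1-\alpha_i) \mathcal{F}(x_i, \mathcal{W}_i^{(2)})$$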

Following the same logic as for dropout, all $\alpha_i$ are set to the expected value of 0.5 at test time.
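The training-time update (one branch scaled by $\alpha_i$, the other by $1-\alpha_i$) and its test-time counterpart can be sketched as follows. This is a minimal illustration, not the author's implementation: plain Python lists stand in for tensors, and the hypothetical callables `branch1`/`branch2` stand in for $\mathcal{F}(\cdot, \mathcal{W}_i^{(1)})$ and $\mathcal{F}(\cdot, \mathcal{W}_i^{(2)})$. One $\alpha_i$ is drawn per call, i.e. per block and per forward pass.

```python
import random

def shake_shake_block(x, branch1, branch2, training, rng=random):
    """Combine an identity skip connection with two residual branches.

    x, branch1(x) and branch2(x) are equal-length lists of floats (a stand-in
    for tensors). During training a single scalar alpha ~ U(0, 1) scales the
    whole output of the first branch and 1 - alpha scales the second; at test
    time alpha is fixed to its expected value, 0.5.
    """
    alpha = rng.random() if training else 0.5
    f1 = branch1(x)
    f2 = branch2(x)
    return [xi + alpha * a + (1.0 - alpha) * b for xi, a, b in zip(x, f1, f2)]
```

For example, with `branch1 = lambda v: [2 * t for t in v]` and `branch2 = lambda v: [-t for t in v]`, the test-time output for `x = [1.0, 2.0]` is `[1.5, 3.0]`, while the training-time output varies with the sampled $\alpha_i$.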

This method can be seen as a form of drop-path (Larsson et al., 2016) where residual branches are scaled down instead of being completely dropped (i.e. multiplied by 0).

Replacing binary variables with enhancement or reduction coefficients is also explored in dropout variants like shakeout (Kang et al., 2016) and whiteout (Yinan et al., 2016). However, where these methods perform an element-wise multiplication between an input tensor and a noise tensor, shake-shake regularization multiplies the whole image tensor with just one scalar $\alpha_i$ (or $1-\alpha_i$).

3 Training procedure

As shown in Figure 1, all scaling coefficients are overwritten with new random numbers before each forward pass. The key to making this work is to repeat this coefficient update operation before each backward pass. This results in a stochastic blend of forward and backward flows during training.
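The training procedure above can be sketched for the combine step of a single residual block. This is a hypothetical, framework-free illustration with a hand-written backward pass: the forward pass always draws a new $\alpha$ ("Shake"), and the backward pass either draws a new coefficient $\beta$ ("Shake"), reuses $\alpha$ ("Keep") or uses 0.5 ("Even"). With a fresh $\beta$ the branch gradients are deliberately no longer the true gradients of the forward computation. In an automatic-differentiation framework, the same behaviour requires overriding the backward pass, for example with a custom `torch.autograd.Function` in PyTorch.

```python
import random

class ShakeShakeCombine:
    """Sketch of y = alpha * f1 + (1 - alpha) * f2 with a custom backward.

    The identity skip connection is left out: its gradient is simply the
    incoming gradient, unchanged.
    """

    def __init__(self, backward_mode="shake", rng=None):
        self.backward_mode = backward_mode
        self.rng = rng or random.Random()
        self.alpha = None

    def forward(self, f1, f2):
        # New random coefficient before each forward pass ("Shake" forward).
        self.alpha = self.rng.random()
        return [self.alpha * a + (1.0 - self.alpha) * b for a, b in zip(f1, f2)]

    def backward(self, grad_out):
        if self.backward_mode == "shake":
            beta = self.rng.random()   # new random number before the backward pass
        elif self.backward_mode == "keep":
            beta = self.alpha          # true gradient: reuse the forward coefficient
        elif self.backward_mode == "even":
            beta = 0.5
        else:
            raise ValueError(f"unknown backward mode: {self.backward_mode}")
        grad_f1 = [beta * g for g in grad_out]
        grad_f2 = [(1.0 - beta) * g for g in grad_out]
        return grad_f1, grad_f2, beta
```

Only "keep" reproduces ordinary backpropagation; "shake" and "even" replace the coefficient on the way back, which is what produces the stochastic blend of forward and backward flows.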

Related to this idea are the works of An (1996) and Neelakantan et al. (2015). These authors showed that adding noise to the gradient during training helps training and generalization of complicated neural networks. Shake-Shake regularization can be seen as an extension of this concept where gradient noise is replaced by a form of gradient augmentation.

4 Improving on the best single shot published results on CIFAR

The Shake-Shake code is based on fb.resnet.torch (https://github.com/facebook/fb.resnet.torch) and is available at https://github.com/xgastaldi/shake-shake. The first layer is a 3x3 Conv with 16 filters, followed by 3 stages each having 4 residual blocks. The feature map size is 32, 16 and 8 for each stage. Width is doubled when downsampling. The network ends with an 8x8 average pooling and a fully connected layer (26 layers deep in total). Residual paths have the following structure: ReLU-Conv3x3-BN-ReLU-Conv3x3-BN-Mul. The skip connections represent the identity function except during downsampling, where a slightly customized structure consisting of 2 concatenated flows is used. Each of the 2 flows has the following components: 1x1 average pooling with stride 2 followed by a 1x1 convolution. The input of one of the two flows is shifted by 1 pixel right and 1 pixel down so that the average pooling samples from a different position. The concatenation of the two flows doubles the width.

Models were trained on the CIFAR-10 (Krizhevsky, 2009) 50k training set and evaluated on the 10k test set. Standard translation and flipping data augmentation is applied to the 32x32 input image. Due to the introduced stochasticity, all models were trained for 1800 epochs. Training starts with a learning rate of 0.2, which is annealed using a cosine function without restarts (see Loshchilov & Hutter (2016)). All models were trained on 2 GPUs with a mini-batch size of 128. Other implementation details are as in fb.resnet.torch.
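The layer count, the stage widths and the learning-rate schedule described above can be checked with a little arithmetic. The helper names below are hypothetical, the first residual width of 32 is taken from the 26 2x32d model, and the schedule assumes the usual cosine annealing form from Loshchilov & Hutter (2016) with a minimum learning rate of 0 and a per-epoch update (the actual implementation may update per iteration).

```python
import math

def count_layers(blocks_per_stage=4, stages=3, convs_per_branch=2):
    # 1 initial 3x3 conv + the convs along one path through the residual
    # branches (parallel branches count once, pooling is not counted)
    # + 1 fully connected layer.
    return 1 + stages * blocks_per_stage * convs_per_branch + 1

def stage_shapes(first_width=32, sizes=(32, 16, 8)):
    # (width, feature map size) per stage; width doubles at each downsampling.
    return [(first_width * 2 ** k, s) for k, s in enumerate(sizes)]

def cosine_lr(epoch, total_epochs=1800, base_lr=0.2):
    # Cosine annealing without restart, assuming a minimum learning rate of 0.
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))
```

`count_layers()` gives 1 + 3 x 4 x 2 + 1 = 26, `stage_shapes()` gives `[(32, 32), (64, 16), (128, 8)]`, and the learning rate falls from 0.2 at epoch 0 to 0.1 at epoch 900 and to 0 at epoch 1800.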

4.1.2 Influence of Forward and Backward training procedures

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that we keep, for the backward pass, the scaling coefficients used during the forward pass. "Batch" means that, for each residual block $i$, we apply the same scaling coefficient for all the images in the mini-batch. "Image" means that, for each residual block $i$, we apply a different scaling coefficient for each image in the mini-batch (see Image level update procedure below). Models are named by combining the forward method, the backward method and the level, e.g. Shake-Keep-Image or, in abbreviated form, S-S-I (Shake-Shake-Image) and E-E-B (Even-Even-Batch).
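The "Shake"/"Even"/"Keep" and "Batch"/"Image" options can be summarized as a rule for drawing the coefficients of one residual block. This is a hypothetical sketch: the function names are illustrative and coefficients are returned as one value per image in the mini-batch.

```python
import random

def draw_coefficients(mode, level, batch_size, rng, forward=None):
    """Return one coefficient per image for a single residual block.

    mode:  "shake" (new random numbers), "even" (all 0.5) or "keep"
           (reuse `forward`, the coefficients from the forward pass).
    level: "batch" (one value shared by the whole mini-batch) or
           "image" (an independent value for each image).
    """
    if mode == "keep":
        if forward is None:
            raise ValueError("'keep' needs the forward-pass coefficients")
        return list(forward)
    if mode == "even":
        return [0.5] * batch_size
    if mode != "shake":
        raise ValueError(f"unknown mode: {mode}")
    if level == "batch":
        return [rng.random()] * batch_size
    if level == "image":
        return [rng.random() for _ in range(batch_size)]
    raise ValueError(f"unknown level: {level}")

def model_name(forward_mode, backward_mode, level):
    # Abbreviation built from the first letters, e.g. ("shake", "even", "image") -> "S-E-I".
    return "-".join(part[0].upper() for part in (forward_mode, backward_mode, level))
```

A Shake-Keep-Image step would call `draw_coefficients("shake", "image", 128, rng)` before the forward pass and pass the result as `forward` with mode `"keep"` before the backward pass.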

Let $x_0$ denote the original input mini-batch tensor of dimensions 128x3x32x32. The first dimension "stacks" 128 images of dimensions 3x32x32. Inside the second stage of a 26 2x32d model, this tensor is transformed into a mini-batch tensor $x_i$ of dimensions 128x64x16x16. Applying Shake-Shake regularization at the Image level means slicing this tensor along the first dimension and, for each of the 128 slices, multiplying the $j^{th}$ slice (of dimensions 64x16x16) with a scalar $\alpha_{i.j}$ (or $1-\alpha_{i.j}$).
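The Image level update can be sketched with each slice of the mini-batch stored as a flat list (for instance a flattened 64x16x16 feature map). The function name is hypothetical; in a tensor library the same effect is usually obtained by drawing 128 values and reshaping them to 128x1x1x1 so that they broadcast over the remaining dimensions.

```python
import random

def shake_image_level(y1, y2, rng):
    """Image-level blend of two branch outputs.

    y1, y2: mini-batches stored as lists of per-image slices, each slice a
    flat list of floats. Slice j of y1 is scaled by alpha_j and slice j of y2
    by 1 - alpha_j, so one scalar is shared by every element of a slice.
    """
    alphas = [rng.random() for _ in y1]
    out = [
        [a * u + (1.0 - a) * v for u, v in zip(s1, s2)]
        for a, s1, s2 in zip(alphas, y1, y2)
    ]
    return out, alphas
```

If the first branch outputs all ones and the second all zeros, every element of slice $j$ in the result equals $\alpha_{i.j}$, which makes the one-scalar-per-image structure visible.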

The network architecture chosen for CIFAR-100 is a ResNeXt without pre-activation (this model gives slightly better results on CIFAR-100 than the model used for CIFAR-10). Hyperparameters are the same as in Xie et al. (2016) except for the learning rate, which is annealed using a cosine function, and the number of epochs, which is increased to 1800. The network in Table 2.2 is a ResNeXt-29 2x4x64d (2 residual branches with 4 grouped convolutions, each with 64 channels). Due to the combination of the larger model (34.4M parameters) and the long training time, fewer tests were performed than on CIFAR-10.

One problem to be mindful of is the issue of alignment (see Li et al. (2016)). The method above assumes that the summation at the end of the residual blocks forces an alignment of the layers on the left and right residual branches. This can be verified by calculating the layer-wise correlation for each pairing of the first 3 layers of the left and right branches of each block.

The results are presented in Figure 4. L1R3 for residual block $i$ means the correlation between the activations of the first layer in $y_i^{(1)}$ (left branch) and the third layer in $y_i^{(2)}$ (right branch). Figure 4 shows that the correlation between the same layers on the left and right branches (i.e. L1R1, L2R2, etc.) is higher than in the other configurations, which is consistent with the assumption that the summation forces alignment.
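The LmRn labels can be made concrete with a simplified sketch that correlates whole layers. This is an assumption-laden stand-in, not the procedure of Li et al. (2016): each layer's activations are flattened into one vector, collected over the same images so that entries line up, and a Pearson correlation is computed for every left/right pair. The helper names are hypothetical, and a layer with constant activations would cause a division by zero.

```python
import math

def pearson(u, v):
    # Pearson correlation between two equal-length lists of activations.
    n = len(u)
    mu, mv = sum(u) / n, sum(v) / n
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    su = math.sqrt(sum((a - mu) ** 2 for a in u))
    sv = math.sqrt(sum((b - mv) ** 2 for b in v))
    return cov / (su * sv)

def branch_correlations(left_layers, right_layers):
    """Return {"L1R1": r, "L1R2": r, ...} for every left/right layer pair.

    left_layers[k] and right_layers[k] hold the flattened activations of
    layer k + 1 in the left and right branch of one residual block.
    """
    return {
        f"L{i + 1}R{j + 1}": pearson(l, r)
        for i, l in enumerate(left_layers)
        for j, r in enumerate(right_layers)
    }
```

Under this simplification, alignment would show up as the diagonal entries (L1R1, L2R2, L3R3) being larger than the off-diagonal ones.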

Regularization strength

This section looks at what would happen if we give, during the backward pass, a large weight to a branch that received a small weight in the forward pass (and vice-versa).

Let $\alpha_{i.j}$ be the coefficient used during the forward pass for image $j$ in residual block $i$. Let $\beta_{i.j}$ be the coefficient used during the backward pass for the same image at the same position in the network.

The first test (method 1) is to set $\beta_{i.j} = 1 - \alpha_{i.j}$. All the tests in this section were performed on CIFAR-10 using 26 2x32d models at the Image level. These models are compared to a 26 2x32d Shake-Keep-Image model. The results of M1 can be seen on the left part of Figure 5 (blue curve). The effect is drastic: the training error stays very high.

Tests M2 to M5 in Table 4 were designed to understand why Method 1 (M1) has such a strong effect. The right part of Figure 5 illustrates Table 4 graphically.

- The regularization effect seems to be linked to the relative position of $\beta_{i.j}$ compared to $\alpha_{i.j}$.
- The further away $\beta_{i.j}$ is from $\alpha_{i.j}$, the stronger the regularization effect.
- There seems to be a jump in strength when 0.5 is crossed.
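The distance between $\beta_{i.j}$ and $\alpha_{i.j}$, and how often the two fall on opposite sides of 0.5, can be estimated for the three rules whose definitions appear in this section: Shake-Keep ($\beta = \alpha$), Shake-Shake ($\beta$ drawn independently) and M1 ($\beta = 1 - \alpha$). This Monte Carlo sketch only quantifies the mismatch; it says nothing about training dynamics, and the definitions of M2 to M5 (Table 4) are not reproduced. Function names are hypothetical.

```python
import random

def backward_coefficient(strategy, alpha, rng):
    if strategy == "keep":    # Shake-Keep: beta = alpha
        return alpha
    if strategy == "shake":   # Shake-Shake: beta ~ U(0, 1), independent of alpha
        return rng.random()
    if strategy == "m1":      # method 1: beta = 1 - alpha
        return 1.0 - alpha
    raise ValueError(f"unknown strategy: {strategy}")

def mismatch_stats(strategy, n=100_000, seed=0):
    """Estimate E|beta - alpha| and P(alpha, beta on opposite sides of 0.5)
    for alpha ~ U(0, 1)."""
    rng = random.Random(seed)
    dist = 0.0
    crossed = 0
    for _ in range(n):
        alpha = rng.random()
        beta = backward_coefficient(strategy, alpha, rng)
        dist += abs(beta - alpha)
        crossed += (alpha < 0.5) != (beta < 0.5)
    return dist / n, crossed / n
```

The estimates approach the exact values: 0 and 0 for Shake-Keep, 1/3 and 1/2 for Shake-Shake, and 1/2 and 1 for M1, where $\beta$ always lands on the other side of 0.5. This ordering matches the observation that M1 regularizes most strongly.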

These insights could be useful when trying to control the strength of the regularization more precisely.

Removing skip connections / Removing Batch Normalization

One interesting question is whether the skip connection plays a role. Many deep learning systems don’t use ResNets, and making this type of regularization work without skip connections could extend the number of potential applications.

Table 5 and Figure 6 present the results of removing the skip connection. The first variant (A) is exactly like the 26 2x32d model used on CIFAR-10 but without the skip connection (i.e. 2 branches with the following components ReLU-Conv3x3-BN-ReLU-Conv3x3-BN-Mul). The second variant (B) is the same as A but with only 1 convolutional layer per branch (ReLU-Conv3x3-BN-Mul) and twice the number of blocks. Models using architecture A were tested once and models using architecture B were tested twice.
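Written with the notation of Section 2, removing the skip connection amounts to dropping the identity term $x_i$ from the block update. This is a sketch derived from the descriptions above: for architecture A, $\mathcal{F}$ is ReLU-Conv3x3-BN-ReLU-Conv3x3-BN, for architecture B it is ReLU-Conv3x3-BN (the final Mul is the scaling by the coefficient), and at the Image level $\alpha_i$ becomes $\alpha_{i.j}$ for image $j$.

```latex
\begin{aligned}
\text{training:}\quad x_{i+1} &= \alpha_i\,\mathcal{F}(x_i,\mathcal{W}_i^{(1)}) + (1-\alpha_i)\,\mathcal{F}(x_i,\mathcal{W}_i^{(2)})\\
\text{test:}\quad x_{i+1} &= 0.5\,\mathcal{F}(x_i,\mathcal{W}_i^{(1)}) + 0.5\,\mathcal{F}(x_i,\mathcal{W}_i^{(2)})
\end{aligned}
```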

The results of architecture A clearly show that shake-shake regularization can work even without a skip connection. On that particular architecture and on a 26 2x32d model, S-S-I is too strong and the model underfits. The softer effect of S-E-I works better but this could change if the capacity is increased (e.g. 64d or 96d).

The results of architecture B are actually the most surprising. The first point to notice is that the regularization no longer works. This, in itself, would indicate that the regularization happens thanks to the interaction between the 2 convolutions in each branch. The second point is that the train and test curves of the S-E-I and E-E-B models are absolutely identical. This would indicate that, for architecture B, the shake operation of the forward pass has no effect on the cost function. The third point is that even with a really different training curve, the test curve of the S-S-I model is nearly identical to the test curves of the E-E-B and S-E-I models (albeit with a smaller variance).

Finally, it would be interesting to see whether this method works without Batch Normalization. While batchnorm is commonly used on computer vision datasets, it is not necessarily the case for other types of problems (e.g. NLP). Architecture C is the same as architecture A but without Batch Normalization (i.e. no skip, 2 branches with the following structure ReLU-Conv3x3-ReLU-Conv3x3-Mul). To allow the E-E-B model to converge, the depth was reduced from 26 to 14 and the initial learning rate was set to 0.05 after a warm start at 0.025 for 1 epoch. The absence of Batch Normalization makes the model a lot more sensitive, and applying the same methods as before makes the model diverge. To soften the effect, an S-E-I model was chosen and the interval covered by $\alpha_{i.j}$ was reduced from [0,1] to [0.4,0.6]. Models using architecture C and different intervals were tested once on CIFAR-10. As shown in Table 5 and Figure 6, this method works quite well, but it is also very easy to make the model diverge (see model 14 2x32d S-E-I v3).
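Narrowing the interval only changes how $\alpha_{i.j}$ is sampled. A minimal sketch, assuming a hypothetical helper `sample_alpha`: drawing from $U(0.4, 0.6)$ keeps each branch weight within 0.1 of 0.5 instead of anywhere in [0, 1], and because the interval is symmetric around 0.5, the expected value (and therefore the test-time coefficient) stays at 0.5.

```python
import random

def sample_alpha(rng, low=0.0, high=1.0):
    # alpha ~ U(low, high): the default [0, 1] is the original setting,
    # low=0.4, high=0.6 is the narrower interval used for architecture C.
    return low + (high - low) * rng.random()
```

For example, `[sample_alpha(rng, 0.4, 0.6) for _ in range(10_000)]` stays inside [0.4, 0.6] and averages close to 0.5.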

A series of experiments seems to indicate that shake-shake regularization can combat overfitting by decorrelating the branches of multi-branch networks. This method leads to state of the art results on CIFAR datasets and could potentially improve the accuracy of architectures that do not use ResNets or Batch Normalization. While these results are encouraging, questions remain on the exact dynamics at play. Understanding these dynamics could help expand the application field to a wider variety of complex architectures.

References