Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

Sergey Ioffe

Introduction

Batch Normalization (“batchnorm”) has recently become a part of the standard toolkit for training deep networks. By normalizing activations, batch normalization helps stabilize the distributions of internal activations as the model trains. Batch normalization also makes it possible to use significantly higher learning rates, and reduces the sensitivity to initialization. These effects help accelerate the training, sometimes dramatically so. Batchnorm has been successfully used to enable state-of-the-art architectures such as residual networks.

Batchnorm works on minibatches in stochastic gradient training, and uses the mean and variance of the minibatch to normalize the activations. Specifically, consider a particular node in the deep network, producing a scalar value for each input example. Given a minibatch $\mathcal{B}$ of $m$ examples, consider the values of this node, $x_1 \ldots x_m$. Then batchnorm takes the form:

$$\hat{x}_i \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}}$$

where $\mu_{\mathcal{B}}$ is the sample mean of $x_1 \ldots x_m$, and $\sigma_{\mathcal{B}}^2$ is the sample variance (in practice, a small $\epsilon$ is added to it for numerical stability). It is clear that the normalized activations corresponding to an input example will depend on the other examples in the minibatch. This is undesirable during inference, and therefore the mean and variance computed over all training data can be used instead. In practice, the model usually maintains moving averages of minibatch means and variances, and during inference uses those in place of the minibatch statistics.
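To make the dependence concrete, here is a minimal sketch in plain Python; the function names, the toy values, and $\epsilon = 10^{-3}$ are illustrative assumptions, not taken from the paper. The same input value is normalized inside two different minibatches and lands on opposite sides of zero, whereas normalizing with fixed statistics depends on the example alone.

```python
import math


def batchnorm_minibatch(xs, eps=1e-3):
    """Normalize each value using the mean and variance of its own minibatch."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / m
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


def normalize_with_fixed_stats(x, mu, var, eps=1e-3):
    """Inference-style normalization: fixed statistics, so the result depends on x alone."""
    return (x - mu) / math.sqrt(var + eps)


# The same example, x = 2.0, placed first in two different minibatches.
in_a = batchnorm_minibatch([2.0, 0.0, 1.0, 3.0])[0]
in_b = batchnorm_minibatch([2.0, 5.0, 6.0, 7.0])[0]
print(in_a > 0, in_b < 0)  # True True
print(normalize_with_fixed_stats(2.0, mu=1.5, var=1.25))
```

The first minibatch happens to have mean 1.5 and variance 1.25, so normalizing with those numbers as fixed statistics reproduces `in_a`; any other minibatch would give a different value for the same input.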

While it appears to make sense to replace the minibatch statistics with whole-data ones during inference, this changes the activations in the network. In particular, this means that the upper layers (whose inputs are normalized using the minibatch) are trained on representations different from those computed in inference (when the inputs are normalized using the population statistics). When the minibatch size is large and its elements are i.i.d. samples from the training distribution, this difference is small, and can in fact aid generalization. However, minibatch-wise normalization may have significant drawbacks:

For small minibatches, the estimates of the mean and variance become less accurate. These inaccuracies are compounded with depth, and reduce the quality of resulting models. Moreover, as each example is used to compute the variance used in its own normalization, the normalization operation is less well approximated by an affine transform, which is what is used in inference.

Non-i.i.d. minibatches can have a detrimental effect on models with batchnorm. For example, in a metric learning scenario, it is common to bias the minibatch sampling to include sets of examples that are known to be related. For instance, for a minibatch of size 32, we may randomly select 16 labels, then choose 2 examples for each of those labels. Without batchnorm, the loss computed for the minibatch decouples over the examples, and the intra-batch dependence introduced by our sampling mechanism may, at worst, increase the variance of the minibatch gradient. With batchnorm, however, the examples interact at every layer, which may cause the model to overfit to the specific distribution of minibatches and suffer when used on individual examples.
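The sampling scheme described above can be sketched as follows. This is a hypothetical helper written for illustration: it draws labels with replacement (as in the non-i.i.d. experiment described under Results) and assumes every label has at least `per_label` examples.

```python
import random
from collections import Counter


def sample_paired_minibatch(examples_by_label, num_labels=16, per_label=2, rng=None):
    """Hypothetical sampler: draw `num_labels` labels (with replacement), then
    `per_label` distinct examples for each drawn label."""
    rng = rng if rng is not None else random.Random()
    all_labels = sorted(examples_by_label)
    batch = []
    for _ in range(num_labels):
        label = rng.choice(all_labels)
        for example in rng.sample(examples_by_label[label], per_label):
            batch.append((label, example))
    return batch


# Toy data: 1000 labels, each with 10 hypothetical image ids.
data = {label: [f"img_{label}_{k}" for k in range(10)] for label in range(1000)}
batch = sample_paired_minibatch(data, rng=random.Random(0))
label_counts = Counter(label for label, _ in batch)
print(len(batch))                                      # 32
print(all(c % 2 == 0 for c in label_counts.values()))  # True: labels arrive in pairs
```

Every example in such a minibatch has at least one same-label partner, which is exactly the intra-batch structure that batchnorm can latch onto.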

The dependence of the batch-normalized activations on the entire minibatch makes batchnorm powerful, but it is also the source of its drawbacks. Several approaches have been proposed to alleviate this. However, unlike batchnorm, which can be easily applied to an existing model, these methods may require careful analysis of nonlinearities and may change the class of functions representable by the model. Weight normalization presents an alternative, but does not offer guarantees about the activations and gradients when the model contains arbitrary nonlinearities, or contains layers without such normalization. Furthermore, weight normalization has been shown to benefit from mean-only batch normalization, which, like batchnorm, results in different outputs during training and inference. Another alternative is to use a separate and fixed minibatch to compute the normalization parameters, but this makes the training more expensive, and does not guarantee that the activations outside the fixed minibatch are normalized.

In this paper we propose Batch Renormalization, a new extension to batchnorm. Our method ensures that the activations computed in the forward pass of the training step depend only on a single example and are identical to the activations computed in inference. This significantly improves the training on non-i.i.d. or small minibatches, compared to batchnorm, without incurring extra cost.

Prior Work: Batch Normalization

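We are interested in stochastic gradient optimization of deep networks. The task is to minimize the loss, which decomposes over training examples:

$$\Theta = \arg\min_{\Theta} \frac{1}{N} \sum_{i=1}^{N} \ell_i(\Theta)$$

where $\ell_i$ is the loss for the $i$-th training example and $\Theta$ is the vector of model weights. At each training step, a minibatch of $m$ examples is used to compute the gradient

$$\frac{1}{m} \sum_{i=1}^{m} \frac{\partial \ell_i(\Theta)}{\partial \Theta},$$

which the optimizer uses to adjust $\Theta$.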

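Consider a particular node $x$ in a deep network. We observe that $x$ depends on all the model parameters that are used for its computation, and when those change, the distribution of $x$ also changes. Since $x$ itself affects the loss through all the layers above it, this change in distribution complicates the training of the layers above. This has been referred to as internal covariate shift. Batch Normalization addresses it by considering the values of $x$ in a minibatch $\mathcal{B} = \{x_{1 \ldots m}\}$. It then normalizes them as follows:

$$\mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_i$$

$$\sigma_{\mathcal{B}} \leftarrow \sqrt{\frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2 + \epsilon}$$

$$\hat{x}_i \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}}$$

$$y_i \leftarrow \gamma \hat{x}_i + \beta \equiv \mathrm{BN}(x_i)$$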

Here $\gamma$ and $\beta$ are trainable parameters (learned using the same procedure, such as stochastic gradient descent, as all the other model weights), and $\epsilon$ is a small constant. Crucially, the sample mean $\mu_{\mathcal{B}}$ and sample standard deviation $\sigma_{\mathcal{B}}$ are part of the model architecture, are themselves functions of the model parameters, and as such participate in backpropagation. The backpropagation formulas for batchnorm are easy to derive by the chain rule and are given in .
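A minimal forward-pass sketch of the transform above for one scalar node, in plain Python. It only computes values; in a real framework $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}$ would be part of the differentiated graph, as the paragraph stresses. The name `batchnorm_train` and the default $\epsilon$ are assumptions.

```python
import math


def batchnorm_train(xs, gamma, beta, eps=1e-3):
    """Training-mode batchnorm for one scalar node over a minibatch `xs`."""
    m = len(xs)
    mu_b = sum(xs) / m                                                # sample mean
    sigma_b = math.sqrt(sum((x - mu_b) ** 2 for x in xs) / m + eps)   # sample std (+ eps)
    x_hat = [(x - mu_b) / sigma_b for x in xs]                        # normalized values
    ys = [gamma * xh + beta for xh in x_hat]                          # learned scale and shift
    return ys, mu_b, sigma_b


ys, mu_b, sigma_b = batchnorm_train([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=0.5)
print(mu_b)  # 2.5
```

Because the normalized values have zero mean, the outputs `ys` average to $\beta$ whatever the inputs are; only $\gamma$ and $\beta$ control the output moments.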

During inference, the standard practice is to normalize the activations using the moving averages $\mu$, $\sigma^2$ instead of the minibatch mean $\mu_{\mathcal{B}}$ and variance $\sigma_{\mathcal{B}}^2$:

$$y_{\text{inference}} = \gamma \cdot \frac{x - \mu}{\sigma} + \beta$$

which depends only on a single input example rather than requiring a whole minibatch.
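A sketch of how the moving averages might be maintained during training and then used at inference. The decay value, the initial statistics, and the class name `MovingStats` are assumptions for illustration; the key property is that `batchnorm_infer` is an affine function of `x` alone.

```python
import math


class MovingStats:
    """Hypothetical helper: exponentially-decayed averages of minibatch mean and variance."""

    def __init__(self, decay=0.99):
        self.decay = decay
        self.mean = 0.0
        self.var = 1.0

    def update(self, xs):
        m = len(xs)
        mu_b = sum(xs) / m
        var_b = sum((x - mu_b) ** 2 for x in xs) / m
        self.mean = self.decay * self.mean + (1.0 - self.decay) * mu_b
        self.var = self.decay * self.var + (1.0 - self.decay) * var_b


def batchnorm_infer(x, stats, gamma, beta, eps=1e-3):
    """Inference-mode batchnorm: depends only on x and the stored statistics."""
    return gamma * (x - stats.mean) / math.sqrt(stats.var + eps) + beta


stats = MovingStats()
for _ in range(2000):            # pretend every training minibatch looks like this one
    stats.update([1.0, 3.0, 5.0])
print(round(stats.mean, 4), round(stats.var, 4))  # approaches 3.0 and 8/3
```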

It is natural to ask whether we could simply use the moving averages $\mu$, $\sigma$ to perform the normalization during training, since this would remove the dependence of the normalized activations on the other examples in the minibatch. This, however, has been observed to lead to the model blowing up. As argued in , such use of moving averages would cause the gradient optimization and the normalization to counteract each other. For example, the gradient step may increase a bias or scale the convolutional weights, in spite of the fact that the normalization would cancel the effect of these changes on the loss. This would result in unbounded growth of model parameters without actually improving the loss. It is thus crucial to use the minibatch moments, and to backpropagate through them.
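The following toy simulation illustrates the argument. It is a sketch under simplifying assumptions: one scalar bias $b$, a loss equal to minus the mean normalized activation, $\sigma$ held fixed, and hand-derived gradients. When the normalization treats the moving mean $\mu$ as a constant, every gradient step pushes $b$ upward, the moving average then absorbs the shift, and $b$ grows without bound while the loss stays bounded. When the gradient flows through the minibatch mean $\mu_{\mathcal{B}}$, the gradient with respect to $b$ is zero.

```python
def simulate(steps, backprop_through_minibatch_mean, lr=0.1, alpha=0.1):
    """Toy model: activations x_i = u_i + b, centered by a mean, scaled by a fixed sigma."""
    u = [0.0, 1.0, 2.0, 3.0]   # fixed pre-bias activations (toy data)
    b = 0.0                    # trainable bias
    mu = sum(u) / len(u)       # moving average of the minibatch mean
    sigma = 1.0                # held fixed to keep the example small
    for _ in range(steps):
        mu_b = sum(ui + b for ui in u) / len(u)
        # Loss = -mean((x_i - center) / sigma).
        if backprop_through_minibatch_mean:
            # center = mu_b shifts together with b, so d loss / d b = -(1 - 1) / sigma = 0.
            grad_b = 0.0
        else:
            # center = mu is treated as a constant, so d loss / d b = -1 / sigma.
            grad_b = -1.0 / sigma
        b -= lr * grad_b
        mu += alpha * (mu_b - mu)   # the moving average catches up with the shift
    # Loss as seen through the moving-average normalization after training.
    loss = -sum((ui + b - mu) / sigma for ui in u) / len(u)
    return b, loss


print(simulate(1000, backprop_through_minibatch_mean=False))  # b near 100, loss near -1
print(simulate(1000, backprop_through_minibatch_mean=True))   # b stays at 0.0
```

In the first case the gap between the minibatch mean and the moving mean settles at a constant, so the loss stops changing even though $b$ keeps increasing by `lr` every step.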

Batch Renormalization

With batchnorm, the activations in the network differ between training and inference, since the normalization is done differently in the two models. Here, we aim to rectify this, while retaining the benefits of batchnorm.

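Let us observe that if we have a minibatch and normalize a particular node $x$ using either the minibatch statistics or their moving averages, then the results of these two normalizations are related by an affine transform. Specifically, let $\mu$ be an estimate of the mean of $x$, and $\sigma$ be an estimate of its standard deviation, computed perhaps as a moving average over the last several minibatches. Then, we have:

$$\frac{x_i - \mu}{\sigma} = \frac{x_i - \mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}} \cdot r + d, \quad \text{where } r = \frac{\sigma_{\mathcal{B}}}{\sigma}, \quad d = \frac{\mu_{\mathcal{B}} - \mu}{\sigma}.$$

If $\sigma = \mathbb{E}[\sigma_{\mathcal{B}}]$ and $\mu = \mathbb{E}[\mu_{\mathcal{B}}]$, then $\mathbb{E}[r] = 1$ and $\mathbb{E}[d] = 0$, so the correction is the identity in expectation. Batch Renorm normalizes with the minibatch statistics and then applies this correction; $r$ and $d$ are computed from the minibatch but are treated as constants by the optimizer, so no gradient flows through them.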

In practice, it is beneficial to train the model for a certain number of iterations with batchnorm alone, without the correction, then ramp up the amount of allowed correction. We do this by imposing bounds on $r$ and $d$, which initially constrain them to $1$ and $0$, respectively, and then are gradually relaxed. Specifically, $r$ is clipped to $[1/r_{\text{max}}, r_{\text{max}}]$ and $d$ to $[-d_{\text{max}}, d_{\text{max}}]$.
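The affine relation between the two normalizations, and the effect of the bounds, can be checked numerically with a short plain-Python sketch. The moving averages `mu`, `sigma` and the data are made-up values, `renormalize` is a hypothetical helper, and the code assumes $r$ is clipped to $[1/r_{\text{max}}, r_{\text{max}}]$ and $d$ to $[-d_{\text{max}}, d_{\text{max}}]$. With $r_{\text{max}} = 1$, $d_{\text{max}} = 0$ it reduces to ordinary batchnorm, and with bounds loose enough that nothing is clipped it reproduces $(x_i - \mu)/\sigma$.

```python
import math


def clip(value, lo, hi):
    return max(lo, min(hi, value))


def renormalize(xs, mu, sigma, r_max, d_max, eps=1e-3):
    """Normalize with minibatch statistics, then apply the clipped correction r, d."""
    m = len(xs)
    mu_b = sum(xs) / m
    sigma_b = math.sqrt(sum((x - mu_b) ** 2 for x in xs) / m + eps)
    r = clip(sigma_b / sigma, 1.0 / r_max, r_max)   # r = sigma_B / sigma, clipped
    d = clip((mu_b - mu) / sigma, -d_max, d_max)    # d = (mu_B - mu) / sigma, clipped
    return [(x - mu_b) / sigma_b * r + d for x in xs]


xs = [0.5, 1.5, 2.0, 4.0]
mu, sigma = 1.8, 1.4   # made-up moving averages
as_batchnorm = renormalize(xs, mu, sigma, r_max=1.0, d_max=0.0)   # r = 1, d = 0
relaxed = renormalize(xs, mu, sigma, r_max=3.0, d_max=5.0)        # nothing clipped here
target = [(x - mu) / sigma for x in xs]
print(all(abs(a - b) < 1e-9 for a, b in zip(relaxed, target)))   # True
```

Note that $\epsilon$ does not break the identity: it enters $\sigma_{\mathcal{B}}$ in both the division and in $r$, so it cancels.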

Algorithm 1 presents Batch Renormalization. Unlike batchnorm, where the moving averages are computed during training but used only for inference, Batch Renorm does use $\mu$ and $\sigma$ during training to perform the correction. We use a fairly high rate of update $\alpha$ for these averages, to ensure that they benefit from averaging multiple batches but do not become stale relative to the model parameters. We explicitly update the exponentially-decayed moving averages $\mu$ and $\sigma$, and optimize the rest of the model using gradient optimization, with the gradients calculated via backpropagation.
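Below is a plain-Python sketch of one Batch Renorm training step for a single scalar node, written from the description in this section rather than copied from Algorithm 1: the clipping ranges, the placement of $\epsilon$ inside $\sigma_{\mathcal{B}}$, and all function names are assumptions. The backward pass applies the chain rule through $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}$ while holding $r$ and $d$ constant, and the moving averages are updated explicitly rather than by gradient descent.

```python
import math


def clip(value, lo, hi):
    return max(lo, min(hi, value))


def batch_renorm_forward(xs, gamma, beta, mu, sigma, alpha, r_max, d_max, eps=1e-3):
    """One training step for a scalar node: outputs, updated averages, and a backprop cache."""
    m = len(xs)
    mu_b = sum(xs) / m
    sigma_b = math.sqrt(eps + sum((x - mu_b) ** 2 for x in xs) / m)
    # Correction terms: computed from the minibatch, but treated as constants
    # (no gradient flows through them) in batch_renorm_backward below.
    r = clip(sigma_b / sigma, 1.0 / r_max, r_max)
    d = clip((mu_b - mu) / sigma, -d_max, d_max)
    x_hat = [(x - mu_b) / sigma_b * r + d for x in xs]
    ys = [gamma * xh + beta for xh in x_hat]
    # Moving averages are updated explicitly, not by gradient descent.
    new_mu = mu + alpha * (mu_b - mu)
    new_sigma = sigma + alpha * (sigma_b - sigma)
    cache = (xs, mu_b, sigma_b, r, d, x_hat, gamma)
    return ys, new_mu, new_sigma, cache


def batch_renorm_backward(dys, cache):
    """Gradients w.r.t. xs, gamma and beta, given dL/dy for each example."""
    xs, mu_b, sigma_b, r, _d, x_hat, gamma = cache
    m = len(xs)
    dx_hat = [dy * gamma for dy in dys]
    dsigma_b = sum(g * (x - mu_b) for g, x in zip(dx_hat, xs)) * (-r / sigma_b ** 2)
    dmu_b = sum(dx_hat) * (-r / sigma_b)
    dxs = [g * r / sigma_b + dsigma_b * (x - mu_b) / (m * sigma_b) + dmu_b / m
           for g, x in zip(dx_hat, xs)]
    dgamma = sum(dy * xh for dy, xh in zip(dys, x_hat))
    dbeta = sum(dys)
    return dxs, dgamma, dbeta


def batch_renorm_infer(x, gamma, beta, mu, sigma):
    """Inference uses only the moving averages, exactly as in batchnorm."""
    return gamma * (x - mu) / sigma + beta


old_mu, old_sigma = 0.5, 1.0
xs = [0.3, -1.2, 2.5, 0.9, 1.7]
ys, mu, sigma, cache = batch_renorm_forward(xs, 1.5, 0.2, old_mu, old_sigma,
                                            alpha=0.01, r_max=3.0, d_max=5.0)
# Nothing is clipped here, so training outputs match inference with the pre-update averages.
print(all(abs(y - batch_renorm_infer(x, 1.5, 0.2, old_mu, old_sigma)) < 1e-9
          for x, y in zip(xs, ys)))  # True
```

In a framework with automatic differentiation, the same gradients would normally be obtained by wrapping $r$ and $d$ in a stop-gradient operation instead of writing the backward pass by hand.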

Batch Renormalization shares many of the beneficial properties of batchnorm, such as insensitivity to initialization and ability to train efficiently with large learning rates. Unlike batchnorm, our method ensures that all layers are trained on internal representations that will be actually used during inference.

Results

To evaluate Batch Renormalization, we applied it to the problem of image classification. Our baseline model is Inception v3, trained on 1000 classes from the ImageNet training set, and evaluated on the ImageNet validation data. In the baseline model, batchnorm was used after convolution and before the ReLU. To apply Batch Renorm, we simply swapped it into the model in place of batchnorm. Both methods normalize each feature map over examples as well as over spatial locations. We fix the scale $\gamma = 1$, since it could be propagated through the ReLU and absorbed into the next layer.
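Normalizing each feature map over examples and spatial locations means pooling all $N \times H \times W$ values of a channel into one set of statistics. A minimal sketch, assuming an NHWC layout stored as nested Python lists (the layout and the function name are illustrative assumptions):

```python
def per_channel_moments(feature_map):
    """Mean and variance of each channel, pooled over examples and spatial positions.

    `feature_map[n][h][w][c]` is assumed to be an NHWC nested list.
    """
    num_channels = len(feature_map[0][0][0])
    values = [[] for _ in range(num_channels)]
    for image in feature_map:
        for row in image:
            for pixel in row:
                for c, v in enumerate(pixel):
                    values[c].append(v)
    means = [sum(vs) / len(vs) for vs in values]
    variances = [sum((v - mu) ** 2 for v in vs) / len(vs) for vs, mu in zip(values, means)]
    return means, variances


# 2 images, 2x2 spatial positions, 2 channels: channel 0 is constant, channel 1 counts 0..7.
fmap = [[[[1.0, float(4 * n + 2 * h + w)] for w in range(2)] for h in range(2)] for n in range(2)]
print(per_channel_moments(fmap))  # ([1.0, 3.5], [0.0, 5.25])
```

Each channel then gets one mean and one variance (and, for Batch Renorm, one $r$ and one $d$) shared by every spatial position of every example in the minibatch.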

The training used 50 synchronized workers. Each worker processed a minibatch of 32 examples per training step. The gradients computed for all 50 minibatches were aggregated and then used by the RMSProp optimizer. As is common practice, the inference model used exponentially-decayed moving averages of all model parameters, including the $\mu$ and $\sigma$ computed by both batchnorm and Batch Renorm.

For Batch Renorm, we used $r_{\text{max}} = 1$, $d_{\text{max}} = 0$ (i.e. simply batchnorm) for the first 5000 training steps, after which these were gradually relaxed to reach $r_{\text{max}} = 3$ at 40k steps, and $d_{\text{max}} = 5$ at 25k steps. These final values resulted in clipping a small fraction of $r$s, and none of $d$s. However, at the beginning of training, when the learning rate was larger, it proved important to increase $r_{\text{max}}$ slowly: otherwise, occasional large gradients were observed to suddenly and severely increase the loss. To account for the fact that the means and variances change as the model trains, we used relatively fast updates to the moving statistics $\mu$ and $\sigma$, with $\alpha = 0.01$. Because of this and keeping $r_{\text{max}} = 1$ for a relatively large number of steps, we did not need to apply initialization bias correction [adam].
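The text says only that the limits were gradually relaxed; the sketch below assumes a linear ramp starting at step 5000, which is one plausible reading rather than the schedule actually used. The function names are hypothetical.

```python
def linear_ramp(step, start_step, end_step, start_value, end_value):
    """Hold start_value until start_step, interpolate linearly, then hold end_value."""
    if step <= start_step:
        return start_value
    if step >= end_step:
        return end_value
    frac = (step - start_step) / (end_step - start_step)
    return start_value + frac * (end_value - start_value)


def renorm_limits(step):
    """Assumed schedule: r_max 1 -> 3 over steps 5k-40k, d_max 0 -> 5 over steps 5k-25k."""
    r_max = linear_ramp(step, 5000, 40000, 1.0, 3.0)
    d_max = linear_ramp(step, 5000, 25000, 0.0, 5.0)
    return r_max, d_max


print(renorm_limits(0), renorm_limits(22500), renorm_limits(40000))
# (1.0, 0.0) (2.0, 4.375) (3.0, 5.0)
```

The slower ramp for $r_{\text{max}}$ reflects the observation above that raising it too quickly early in training caused sudden loss spikes.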

All the hyperparameters other than those related to normalization were fixed between the models and across experiments.

As a baseline, we trained the batchnorm model using a minibatch size of 32. More specifically, batchnorm was applied to each of the 50 minibatches; each example was normalized using 32 examples, but the resulting gradients were aggregated over 50 minibatches. This model achieved a top-1 validation accuracy of 78.3% after 130k training steps.

To verify that Batch Renorm does not diminish performance on such minibatches, we also trained the model with Batch Renorm, see Figure 1. The test accuracy of this model closely tracked the baseline, achieving a slightly higher test accuracy (78.5%) after the same number of steps.

Small minibatches

To investigate the effectiveness of Batch Renorm when training on small minibatches, we reduced the number of examples used for normalization to 4. Each minibatch of size 32 was thus broken into “microbatches” each having 4 examples; each microbatch was normalized independently, but the loss for each minibatch was computed as before. In other words, the gradient was still aggregated over 1600 examples per step, but the normalization involved groups of 4 examples rather than 32 as in the baseline. Figure 2 shows the results.

The validation accuracy of the batchnorm model is significantly lower than that of the baseline that normalized over minibatches of size 32, and training is slow, achieving 74.2% at 210k steps. We obtain a substantial improvement much faster (76.5% at 130k steps) by replacing batchnorm with Batch Renorm. However, the resulting test accuracy is still below what we get when applying either batchnorm or Batch Renorm to size 32 minibatches. Although Batch Renorm improves the training with small minibatches, it does not eliminate the benefit of having larger ones.
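The microbatch setup can be sketched with a hypothetical plain-Python helper: one worker's minibatch of 32 values is normalized in independent groups of 4, while the loss and gradient would still be aggregated over the whole step (50 × 32 = 1600 examples). Setting `group_size` to 32 recovers the baseline normalization. The random toy data is an assumption for illustration.

```python
import math
import random


def batchnorm_in_groups(xs, group_size, eps=1e-3):
    """Normalize consecutive groups of `group_size` values, each with its own statistics."""
    out = []
    for start in range(0, len(xs), group_size):
        group = xs[start:start + group_size]
        mean = sum(group) / len(group)
        var = sum((x - mean) ** 2 for x in group) / len(group)
        out.extend((x - mean) / math.sqrt(var + eps) for x in group)
    return out


rng = random.Random(0)
minibatch = [rng.gauss(0.0, 1.0) for _ in range(32)]   # one worker's activations (toy data)
micro = batchnorm_in_groups(minibatch, group_size=4)   # 8 independent microbatches
full = batchnorm_in_groups(minibatch, group_size=32)   # baseline: one group of 32
print(len(micro), len(full))  # 32 32
```

With only 4 values per group, each example contributes a quarter of the mean and variance used to normalize it, which is the noisy, self-referential regime the paragraph above describes.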

Non-i.i.d. minibatches

When examples in a minibatch are not sampled independently, batchnorm can perform rather poorly. However, sampling with dependencies may be necessary for tasks such as metric learning. We may want to ensure that images with the same label have more similar representations than otherwise, and to learn this we require that a reasonable number of same-label image pairs can be found within the same minibatch.

In this experiment (Figure 3), we selected each minibatch of size 32 by randomly sampling 16 labels (out of the total 1000) with replacement, then randomly selecting 2 images for each of those labels. When training with batchnorm, the test accuracy is much lower than for i.i.d. minibatches, achieving only 67%. Surprisingly, even the training accuracy is much lower (72.8%) than the test accuracy in the i.i.d. case, and in fact exhibits a drop that is consistent with overfitting. We suspect that this is in fact what happens: the model learns to predict labels for images that come in a set, where each image has a counterpart with the same label. This does not directly translate to classifying images individually, thus producing a drop in the accuracy computed on the training data. To verify this, we also evaluated the model in the “training mode”, i.e. using minibatch statistics $\mu_{\mathcal{B}}$, $\sigma_{\mathcal{B}}$ instead of moving averages $\mu$, $\sigma$, where each test minibatch had size 50 and was obtained using the same procedure as the training minibatches: 25 labels, with 2 images per label. As expected, this does much better, achieving 76.5%, though still below the baseline accuracy. Of course, this evaluation scenario is usually infeasible, as we want the image representation to be a deterministic function of that image alone.

We can improve the accuracy for this problem by splitting each minibatch into two halves of size 16 each, so that for every pair of images belonging to the same class, one image is assigned to the first half-minibatch, and the other to the second. Each half is then more i.i.d., and this achieves a much better test accuracy (77.4% at 140k steps), but still below the baseline. This method is only applicable when the number of examples per label is small (since this determines the number of microbatches that a minibatch needs to be split into).
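A sketch of the split, assuming (as a hypothetical convention) that the minibatch lists each same-label pair in consecutive positions, as a pair-wise sampler would naturally produce. The two halves would then be normalized independently; the function name and image ids are illustrative.

```python
def split_pairs_into_halves(batch):
    """Send the first image of each consecutive same-label pair to one half, the second to the other."""
    if len(batch) % 2:
        raise ValueError("expected an even number of examples")
    first_half, second_half = [], []
    for i in range(0, len(batch), 2):
        (label_a, img_a), (label_b, img_b) = batch[i], batch[i + 1]
        if label_a != label_b:
            raise ValueError(f"entries {i} and {i + 1} are not a same-label pair")
        first_half.append((label_a, img_a))
        second_half.append((label_b, img_b))
    return first_half, second_half


# 16 labels x 2 images = 32 examples (hypothetical ids)
batch = [(label, f"img_{label}_{k}") for label in range(16) for k in range(2)]
first, second = split_pairs_into_halves(batch)
print(len(first), len(second))  # 16 16
```

With $k$ images per label the same idea would need $k$ microbatches, which is why the approach only suits small $k$.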

With Batch Renorm, we simply trained the model with a minibatch size of 32. The model achieved the same test accuracy (78.5% at 120k steps) as the equivalent model on i.i.d. minibatches, vs. 67% obtained with batchnorm. By replacing batchnorm with Batch Renorm, we ensured that the inference model can effectively classify individual images. This has completely eliminated the effect of overfitting the model to image sets with a biased label distribution.

Conclusions

We have demonstrated that Batch Normalization, while effective, is not well suited to small or non-i.i.d. training minibatches. We hypothesized that these drawbacks are due to the fact that the activations in the model, which are in turn used by other layers as inputs, are computed differently during training than during inference. We address this with Batch Renormalization, which replaces batchnorm and ensures that the outputs computed by the model are dependent only on the individual examples and not the entire minibatch, during both training and inference.

Batch Renormalization extends batchnorm with a per-dimension correction to ensure that the activations match between the training and inference networks. This correction is the identity in expectation; its parameters are computed from the minibatch but are treated as constant by the optimizer. Unlike batchnorm, where the means and variances used during inference do not need to be computed until the training has completed, Batch Renormalization benefits from having these statistics directly participate in the training. Batch Renormalization is as easy to implement as batchnorm itself, runs at the same speed during both training and inference, and significantly improves training on small or non-i.i.d. minibatches. Our method does have extra hyperparameters: the update rate $\alpha$ for the moving averages, and the schedules for the correction limits $d_{\text{max}}$, $r_{\text{max}}$. A more extensive investigation of the effect of these is a part of future work.

Batch Renormalization offers the promise of improving the performance of any model that would normally use batchnorm. This includes Residual Networks. Another application is Generative Adversarial Networks, where the non-determinism introduced by batchnorm has been found to be an issue, and Batch Renorm may provide a solution.

Finally, Batch Renormalization may benefit applications where applying batch normalization has been difficult – such as recurrent networks. There, batchnorm would require each timestep to be normalized independently, but Batch Renormalization may make it possible to use the same running averages to normalize all timesteps, and then update those averages using all timesteps. This remains one of the areas that warrants further exploration.

References