Depth Dependence of $μ$P Learning Rates in ReLU MLPs

Samy Jelassi, Boris Hanin, Ziwei Ji, Sashank J. Reddi, Srinadh Bhojanapalli, Sanjiv Kumar

Introduction

Using a neural network requires many choices. Even after fixing an architecture, one must still specify the initialization scheme, the learning rate (schedule), batch size, data augmentation, regularization strength, and so on. Moreover, model performance is often highly sensitive to these hyperparameters, and yet exhaustive grid-search-type approaches are computationally expensive. It is therefore important to develop theoretically grounded principles for reducing the cost of hyperparameter tuning. In this short note we focus specifically on how to select learning rates in a principled way. More precisely, our purpose is to generalize the maximal update ($\mu$P) approach of to setting learning rates so that it takes network depth into account.

Selecting learning rates cannot be done independently of an initialization scheme. As in , we draw random weights for the network (1.1) from the so-called mean-field initialization

The factor of two in the variance of the hidden-layer weights corresponds to the well-known He initialization , which ensures that the expected squared activations neither grow nor decay with depth:

The much smaller variance of the weights in the final layer distinguishes the initialization scheme (1.2) from the so-called NTK initialization . The difference is twofold. First, when $n$ is large the network output $z^{(L+1)}(x)$ is close to zero. Crucially, however, the parameter gradients $\nabla_{\theta}z^{(L+1)}(x)$ remain non-zero. Second, even in the infinite-width limit $n\rightarrow\infty$, networks trained by gradient descent are capable of feature learning . This is in contrast to the setting where the final-layer weight variance scales like $1/n$, which corresponds to the kernel regime: there, neural networks trained by SGD with a small learning rate on a mean squared error loss converge to linear models and hence cannot learn data-adaptive features .
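The two paragraphs above make three quantitative claims that are easy to probe by simulation: with the factor of two, squared pre-activations stay of order one across depth; with a much smaller final-layer variance the output is close to zero; and the gradient with respect to the final-layer weights does not shrink. The sketch below is a Monte Carlo check under assumed variances chosen to match that description, not a transcription of (1.2): $1/n_0$ in the first layer, $2/n$ in the hidden layers, and $n^{-p}$ in the final layer, with $p=1$ standing in for the NTK-style $1/n$ scaling and $p=2$ for a mean-field-style $1/n^2$ scaling. Biases are set to zero, and `init_stats` is a hypothetical helper.

```python
import numpy as np


def init_stats(n, final_var_power, n0=16, depth=6, trials=200, seed=0):
    """Monte Carlo statistics of a random ReLU MLP at initialization.

    Assumed (illustrative) setup: zero biases, first-layer weight variance
    1/n0, hidden-layer variance 2/n (He), final-layer variance
    n ** (-final_var_power).

    Returns (second_moments, out_sq, grad_sq):
      second_moments[l] estimates E[(1/n) ||z^(l+1)||^2] for the hidden layers,
      out_sq estimates E[z_out^2],
      grad_sq estimates E[(1/n) ||d z_out / d w_out||^2].
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n0)
    x *= np.sqrt(n0) / np.linalg.norm(x)  # normalize so that ||x||^2 = n0
    second_moments = np.zeros(depth)
    out_sq = grad_sq = 0.0
    for _ in range(trials):
        z = rng.normal(0.0, np.sqrt(1.0 / n0), size=(n, n0)) @ x
        second_moments[0] += np.mean(z**2)
        for layer in range(1, depth):
            W = rng.normal(0.0, np.sqrt(2.0 / n), size=(n, n))
            z = W @ np.maximum(z, 0.0)  # ReLU, then the next linear map
            second_moments[layer] += np.mean(z**2)
        h = np.maximum(z, 0.0)
        w_out = rng.normal(0.0, np.sqrt(float(n) ** -final_var_power), size=n)
        out_sq += float(w_out @ h) ** 2
        # d z_out / d w_out = h, which does not involve the final-layer variance.
        grad_sq += np.mean(h**2)
    return second_moments / trials, out_sq / trials, grad_sq / trials


for power in (1, 2):
    moments, out_sq, grad_sq = init_stats(n=256, final_var_power=power)
    print(power, np.round(moments, 2), out_sq, grad_sq)
```

Under these assumptions each entry of `moments` should be near $1$ and `grad_sq` near $1/2$ (since $\mathbb{E}[\mathrm{ReLU}(z)^2]=\tfrac12\mathbb{E}[z^2]$ for symmetric $z$), while `out_sq` should be near $1/2$ for $p=1$ but only near $1/(2n)$ for $p=2$.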

A key contribution of is that the initialization (1.2) not only leads to feature learning at large $n$ but also allows for zero-shot learning rate transfer across widths. This means that, empirically, for a fixed depth $L$, the learning rate that gives the smallest training loss after one epoch stays close to constant as one varies $n$ (strictly speaking, the $\mu$P prescription gives $n$-dependent learning rates for weights in the first and last layers and $n$-independent learning rates for weights in the other layers; see Table 3 of ). Hence, in practice, one may run a logarithmic grid search for good learning rates in relatively small models (with small $n$) and then simply reuse the best learning rate for wider networks.
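The practical recipe just described (search a logarithmic grid at small width, then reuse the winner at larger widths) can be sketched as follows. Here `loss_at_lr` is a hypothetical placeholder for "train a narrow model for one epoch at this rate and return the training loss", and the toy loss in the usage lines is purely illustrative, not an experimental result. Under $\mu$P the first- and last-layer rates still carry an explicit width dependence, so roughly speaking what gets reused is the width-independent base rate.

```python
import numpy as np


def log_grid_search(loss_at_lr, lr_min=1e-4, lr_max=1.0, num=13):
    """Evaluate `loss_at_lr` on a log-spaced grid and return the best learning rate.

    `loss_at_lr` is a hypothetical user-supplied callable, e.g. one that trains
    a narrow network for one epoch at the given rate and returns its final
    training loss.
    """
    grid = np.logspace(np.log10(lr_min), np.log10(lr_max), num)
    losses = [loss_at_lr(lr) for lr in grid]
    return float(grid[int(np.argmin(losses))])


# Toy stand-in for "train a small-width model and report its loss"; it is
# minimized at lr = 1e-2 and is not an experimental result.
def toy_small_width_loss(lr):
    return (np.log10(lr) + 2.0) ** 2


best_lr = log_grid_search(toy_small_width_loss)
print(best_lr)  # about 0.01; this rate would then be reused for wider models
```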

2 Main Result: Extending the $\mu$P Heuristic to Deeper Networks

with the average being over initialization. To study the change in neuron pre-activations under GD, we consider a batch $\mathcal{B}=\{(x,y)\}$ of size $1$ and the associated mean-squared error

where we’ve emphasized the dependence of the network output $z^{(L+1)}(x;\theta)$ on the network weights $\theta$. Let us denote by

The maximal update heuristic then asks that we set the learning rate $\eta$ so that

where the average is over initialization. A priori, $\eta^{*}$ depends on both the network width $n$ and the depth $L$. The article shows that $\eta^{*}$ does not depend on $n$ and hence can be estimated accurately at small $n$. In this note, we take up the question of how $\eta^{*}$ depends on depth. The following theorem shows that $\eta^{*}$ is not depth-independent:

Theorem 1.1. For each $c_{1}>0$ there exist $c_{2},c_{3}>0$ with the following property. Fix a network width $n$ and depth $L$ so that $L/n<c_{1}$. Then,

where $\mathcal{B}=\{(x,y)\}$ is any batch of size one consisting of a normalized datapoint $(x,y)$ sampled independently of the network weights and biases, with:

Theorem 1.1 shows that the $\mu$P heuristic (1.4) dictates that

Proof of Theorem 1.1

We prove a slightly more general result than Theorem 1.1 in two senses. First, we allow for variable widths:

Second, we allow for parameter-dependent learning rates:

Thus, the batch loss $\mathcal{L}_{\mathcal{B}}$ we consider is

With this notation, the forward pass now takes the form

2 Proof Details

where $\Delta\mu$ is the change in $\mu$ after one step of GD. The SGD update satisfies:

We now combine (2.3) and (2.4) to obtain:

Given the distribution of $z_{1;\alpha}^{(L+1)}$ and $y$, we have

We integrate out the weights in layer $L+1$ in (2.9) and (2.10), which yields the stated result.

Consider a random ReLU network with input dimension $n_{0}$, $L$ hidden layers of widths $n_{1},\ldots,n_{L}$, and output dimension $n_{L+1}$ as in (1.1). Suppose that

Adding the contributions (2.17), (2.18) and (2.2) in (2.11) gives the stated result. ∎

We apply the same proof strategy as in 2.5 to get the result. ∎

Combining (2.23) and (LABEL:eq:eq2CL) yields the result. ∎

We apply the same proof strategy as in 2.7 to get the result. ∎
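The one-step setting analyzed in this proof (variable widths $n_0,\ldots,n_{L+1}$, parameter-dependent learning rates, and a batch of one with mean-squared error) can also be simulated directly. The sketch below is a minimal illustration, not the authors' code: it uses one learning rate per layer, the loss $\tfrac12\big(z^{(L+1)}(x;\theta)-y\big)^2$, zero initial biases, and assumed weight variances $1/n_0$, $2/n_{\ell-1}$ and $1/n_L^2$, and it reports $\frac{1}{n_\ell}\|\Delta z^{(\ell)}\|^2$ for every layer after a single GD step, the kind of change in pre-activations that the maximal update heuristic is meant to control. All function names are hypothetical, and the factor $\tfrac12$ and the variances are illustrative choices rather than the exact forms used above.

```python
import numpy as np


def init_params(widths, rng):
    """Hypothetical initialization for widths [n0, n1, ..., nL, n_{L+1}].

    Assumed weight variances: 1/n0 (first layer), 2/n_{l-1} (hidden layers),
    1/nL**2 (final layer). Biases start at zero.
    """
    num_layers = len(widths) - 1
    params = []
    for l in range(num_layers):
        fan_in, fan_out = widths[l], widths[l + 1]
        if l == 0:
            var = 1.0 / fan_in
        elif l == num_layers - 1:
            var = 1.0 / fan_in**2
        else:
            var = 2.0 / fan_in
        W = rng.normal(0.0, np.sqrt(var), size=(fan_out, fan_in))
        params.append((W, np.zeros(fan_out)))
    return params


def forward(params, x):
    """Return the pre-activations [z^(1), ..., z^(L+1)]; ReLU between layers."""
    zs, h = [], x
    for W, b in params:
        z = W @ h + b
        zs.append(z)
        h = np.maximum(z, 0.0)
    return zs


def loss_and_grads(params, x, y):
    """Loss 0.5 * ||z^(L+1) - y||^2 for one example and its gradients by backprop."""
    zs = forward(params, x)
    layer_inputs = [x] + [np.maximum(z, 0.0) for z in zs[:-1]]
    delta = zs[-1] - y  # derivative of the loss w.r.t. z^(L+1)
    loss = 0.5 * float(np.sum(delta**2))
    grads = [None] * len(params)
    for l in reversed(range(len(params))):
        W, _ = params[l]
        grads[l] = (np.outer(delta, layer_inputs[l]), delta)
        if l > 0:
            delta = (W.T @ delta) * (zs[l - 1] > 0)  # chain rule through ReLU
    return loss, grads


def preact_change(params, x, y, lrs):
    """Per-layer (1/n_l) * ||Delta z^(l)||^2 after one GD step with rates `lrs`."""
    _, grads = loss_and_grads(params, x, y)
    stepped = [(W - lr * gW, b - lr * gb)
               for (W, b), (gW, gb), lr in zip(params, grads, lrs)]
    before, after = forward(params, x), forward(stepped, x)
    return [float(np.mean((za - zb) ** 2)) for zb, za in zip(before, after)]


rng = np.random.default_rng(0)
widths = [16, 64, 128, 64, 1]  # variable hidden widths, scalar output
params = init_params(widths, rng)
x = rng.standard_normal(widths[0])
x *= np.sqrt(widths[0]) / np.linalg.norm(x)  # normalize so ||x||^2 = n0
y = np.array([1.0])
print(preact_change(params, x, y, lrs=[0.1, 0.1, 0.1, 0.1]))
```

Sweeping the per-layer rates in `lrs` across several depths at fixed widths would give a rough numerical counterpart to the depth dependence in Theorem 1.1, under the assumptions above.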

Conclusion

In this short note we’ve computed how network depth influences the learning rate predicted by the $\mu$P heuristic. We found that, in contrast to its behavior with respect to width, this learning rate has a non-trivial power-law scaling with respect to depth (see Theorem 1.1). We leave to future work the empirical question of whether this depth dependence indeed leads to learning rate transfer in practice.
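The empirical question raised above could be approached by estimating $\eta^{*}$ at several depths (for instance with a grid search at fixed width) and then fitting a power law in $L$. A minimal sketch of the fitting step, assuming such estimates are already in hand: `fit_power_law` is a hypothetical helper, and the numbers in the usage lines are synthetic, generated from an arbitrary exponent purely to check the fit.

```python
import numpy as np


def fit_power_law(depths, etas):
    """Fit eta ≈ C * L**alpha by least squares in log-log space; return (alpha, C)."""
    log_depth = np.log(np.asarray(depths, dtype=float))
    log_eta = np.log(np.asarray(etas, dtype=float))
    alpha, log_c = np.polyfit(log_depth, log_eta, deg=1)  # slope, intercept
    return float(alpha), float(np.exp(log_c))


# Synthetic check only: values generated from an arbitrary exponent (-0.7),
# not measured learning rates.
depths = np.array([2, 4, 8, 16, 32])
etas = 0.3 * depths.astype(float) ** -0.7
print(fit_power_law(depths, etas))  # approximately (-0.7, 0.3)
```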

References