Depth Dependence of $μ$P Learning Rates in ReLU MLPs
Samy Jelassi, Boris Hanin, Ziwei Ji, Sashank J. Reddi, Srinadh Bhojanapalli, Sanjiv Kumar
Introduction
Using a neural network requires many choices. Even after fixing an architecture, one must still specify the initialization scheme, learning rate (schedule), batch size, data augmentation, regularization strength, and so on. Moreover, model performance is often highly sensitive to the setting of these hyperparameters, and yet exhaustive grid-search approaches are computationally expensive. It is therefore important to develop theoretically grounded principles for reducing the cost of hyperparameter tuning. In this short note we focus specifically on the question of how to select learning rates in a principled way. More precisely, our purpose is to generalize the maximal update (μP) approach to setting learning rates so that it takes network depth into account.
Selecting learning rates cannot be done independently of the initialization scheme. As in prior work on μP, we draw random weights for the network (1.1) from the so-called mean-field initialization (1.2).
The factor of two in the variance of the hidden-layer weights corresponds to the well-known He initialization, which ensures that the expected squared activations neither grow nor decay with depth.
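The role of the factor of two can be checked numerically. For a centred Gaussian pre-activation $z$, ReLU keeps exactly half of the second moment, $\mathbb{E}[\max(z,0)^2] = \mathbb{E}[z^2]/2$, and a hidden-weight variance of 2/width compensates for this loss. The minimal sketch below assumes a bias-free ReLU MLP with square hidden layers and a standard Gaussian input (our simplifications, not the paper's setup); the function name is ours. It tracks the mean squared pre-activation layer by layer for variance 2/width (He) and, for contrast, 1/width.

```python
import numpy as np

def preactivation_second_moments(width=1000, depth=20, var_scale=2.0, seed=0):
    """Mean squared pre-activation in each layer of a bias-free ReLU MLP.

    Hidden weights are drawn i.i.d. from N(0, var_scale / width);
    var_scale=2.0 corresponds to He initialization.
    """
    rng = np.random.default_rng(seed)
    h = rng.normal(size=width)  # assumed input: a standard Gaussian vector
    moments = []
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_scale / width), size=(width, width))
        z = W @ h  # pre-activations of the current layer
        moments.append(float(np.mean(z ** 2)))
        h = np.maximum(z, 0.0)  # ReLU
    return moments

he = preactivation_second_moments(var_scale=2.0)
half = preactivation_second_moments(var_scale=1.0)
print(he[0], he[-1])      # of the same order at depth 1 and depth 20
print(half[0], half[-1])  # shrinks by roughly a factor of 2 per layer
```

With variance 1/width, the second moment is roughly halved at every layer, so after 20 layers it has shrunk by a factor of about $2^{19}$. With He initialization, it only fluctuates around its first-layer value.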
The much smaller variance of the final-layer weights distinguishes the initialization scheme (1.2) from the so-called NTK initialization. The difference is twofold. First, when the width is large, the network output is close to zero; crucially, however, the parameter gradients remain non-zero. Second, even in the infinite-width limit, networks trained by gradient descent are capable of feature learning. This is in contrast to the setting where the final-layer weight variance scales like the inverse of the width, which corresponds to the kernel regime: there, neural networks trained by SGD with a small learning rate on a mean squared error loss converge to linear models and hence cannot learn data-adaptive features.
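The first point can be made concrete with a small Monte Carlo sketch. It compares a final-layer weight variance of $1/n^2$ (our assumed reading of the "much smaller" mean-field variance) with the NTK-style $1/n$. The network has one hidden layer of width $n$, a He-initialized first layer, and an input normalized so that $\|x\|^2$ equals the input dimension (also an assumption). The function name is ours. Under the first choice, the squared output shrinks like $1/n$. The gradient of the output with respect to the final-layer weights, which is just the hidden activation vector, does not vanish in either case.

```python
import numpy as np

def output_and_grad_scale(n, final_var, n0=16, trials=1000, seed=0):
    """Return (E[f^2], E[||df/dw||^2] / n) over random initializations.

    f(x) = w . relu(W1 x) with W1 ~ N(0, 2/n0) and w ~ N(0, final_var).
    The gradient of f with respect to the final-layer weights w is relu(W1 x).
    """
    rng = np.random.default_rng(seed)
    f_sq, grad_sq = [], []
    for _ in range(trials):
        x = rng.normal(size=n0)
        x *= np.sqrt(n0) / np.linalg.norm(x)  # assumed normalization ||x||^2 = n0
        W1 = rng.normal(0.0, np.sqrt(2.0 / n0), size=(n, n0))
        h = np.maximum(W1 @ x, 0.0)
        w = rng.normal(0.0, np.sqrt(final_var), size=n)
        f_sq.append(float(w @ h) ** 2)
        grad_sq.append(float(h @ h) / n)
    return float(np.mean(f_sq)), float(np.mean(grad_sq))

for n in (64, 1024):
    mf = output_and_grad_scale(n, final_var=1.0 / n**2)  # mean-field-style
    ntk = output_and_grad_scale(n, final_var=1.0 / n)    # NTK-style
    print(n, mf, ntk)
```

In this setup, $\mathbb{E}[f^2] = \text{final\_var}\cdot\mathbb{E}\|h\|^2 \approx \text{final\_var}\cdot n$. This is about $1/n$ for the mean-field choice and about $1$ for the NTK choice, while $\|h\|^2/n$ stays near $1$ in both.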
A key contribution of the μP work is that the initialization (1.2) not only leads to feature learning at large width but also allows zero-shot learning rate transfer across widths. This means that, empirically, for a fixed depth, the learning rate that yields the smallest training loss after one epoch is close to constant as one varies the width. (Strictly speaking, the μP prescription gives width-dependent learning rates for weights in the first and last layers and width-independent learning rates for weights in the other layers; see Table 3 of the μP reference.) Hence, in practice, one may do a logarithmic grid search for good learning rates in relatively small models (with small width) and then simply reuse the best learning rate for wider networks.
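The transfer workflow can be sketched as follows. A real application would call a training routine that returns the loss after one epoch for a given learning rate and width. Here we substitute a hypothetical toy loss, `toy_loss`, whose minimizer does not depend on width, purely to exercise the search. Both function names are ours.

```python
import numpy as np

def log_grid_search(loss_fn, lr_min=1e-4, lr_max=1e0, num=13):
    """Return the learning rate on a log-spaced grid with the smallest loss."""
    grid = np.logspace(np.log10(lr_min), np.log10(lr_max), num)
    losses = [loss_fn(lr) for lr in grid]
    return float(grid[int(np.argmin(losses))])

def toy_loss(lr, width):
    # Hypothetical stand-in for "training loss after one epoch": a bowl in
    # log10(lr) centred at 1e-2 for every width (i.e. perfect width transfer).
    return (np.log10(lr) + 2.0) ** 2 + 1.0 / width

# Tune on a small model, then reuse the learning rate for a wide one.
best_lr = log_grid_search(lambda lr: toy_loss(lr, width=64))
print(best_lr)
```

`toy_loss` is minimized at the same learning rate for every width, so the search returns the same value at width 64 and at width 4096. With a real training loop, this width independence is the property μP is meant to provide empirically.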
2 Main Result: Extending the μP Heuristic to Deeper Networks
with the average being over initialization. To study the change in neuron pre-activations under GD, we consider a batch of size one and the associated mean-squared error
where we’ve emphasized the dependence of the network output on the network weights . Let us denote by
The maximal update heuristic then asks that we set the learning rate so that
where the average is over initialization. A priori, the learning rate chosen in this way depends on both the network width and the depth. Prior work on μP shows that it does not depend on the width and hence can be estimated accurately at small width. In this article, we take up the question of how it depends on depth. The following theorem shows that it is not depth-independent:
Theorem 1.1. For each there exists with the following property. Fix a network width and depth so that . Then,
where is any batch of size one consisting of a normalized datapoint sampled independently of the network weights and biases, with:
Theorem 1.1 shows that the learning rate dictated by the μP heuristic (1.4) has a non-trivial power-law dependence on depth.
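The quantity the μP heuristic controls can also be estimated directly by simulation. The sketch below makes several assumptions not fixed by the text above:

- no biases and equal hidden widths;
- mean-field initialization with hidden-weight variance 2/fan-in and final-layer variance $1/n^2$;
- a single learning rate shared by all layers;
- the loss $\tfrac12(f(x)-y)^2$ with a hypothetical standard Gaussian target $y$;
- an input normalized so that $\|x\|^2 = n_0$.

It takes one GD step on a batch of size one and reports the mean squared change in the last hidden layer's pre-activations, averaged over initializations. All function names are ours. This is a tool for exploring the depth dependence empirically, not a reproduction of the proof.

```python
import numpy as np

def init_mean_field(n0, n, depth, rng):
    """Hidden layers: N(0, 2/fan_in) (He). Final layer: N(0, 1/n**2) (assumed)."""
    Ws = [rng.normal(0.0, np.sqrt(2.0 / n0), size=(n, n0))]
    for _ in range(depth - 1):
        Ws.append(rng.normal(0.0, np.sqrt(2.0 / n), size=(n, n)))
    Ws.append(rng.normal(0.0, 1.0 / n, size=(1, n)))  # std 1/n, i.e. variance 1/n**2
    return Ws

def forward(Ws, x):
    """Return the hidden pre-activations z^(1..depth) and the scalar output."""
    zs, h = [], x
    for W in Ws[:-1]:
        z = W @ h
        zs.append(z)
        h = np.maximum(z, 0.0)
    return zs, float((Ws[-1] @ h)[0])

def gradients(Ws, x, y):
    """Backpropagate the loss 0.5 * (f(x) - y)**2 through the ReLU network."""
    zs, out = forward(Ws, x)
    inputs = [x] + [np.maximum(z, 0.0) for z in zs]  # input to each weight matrix
    delta = np.array([out - y])  # d loss / d output
    grads = [None] * len(Ws)
    for k in reversed(range(len(Ws))):
        grads[k] = np.outer(delta, inputs[k])
        if k > 0:
            delta = (Ws[k].T @ delta) * (zs[k - 1] > 0)  # d loss / d z^(k)
    return grads

def mean_sq_preactivation_change(n0, n, depth, lr, trials=20, seed=0):
    """Average over inits of mean_i (Delta z_i^(depth))**2 after one GD step."""
    rng = np.random.default_rng(seed)
    vals = []
    for _ in range(trials):
        Ws = init_mean_field(n0, n, depth, rng)
        x = rng.normal(size=n0)
        x *= np.sqrt(n0) / np.linalg.norm(x)  # assumed normalization ||x||^2 = n0
        y = rng.normal()  # hypothetical scalar target
        z_before = forward(Ws, x)[0][-1]
        step = gradients(Ws, x, y)
        Ws_after = [W - lr * g for W, g in zip(Ws, step)]
        z_after = forward(Ws_after, x)[0][-1]
        vals.append(np.mean((z_after - z_before) ** 2))
    return float(np.mean(vals))

for depth in (2, 4, 8):
    print(depth, mean_sq_preactivation_change(n0=10, n=128, depth=depth, lr=0.05))
```

For small learning rates the change in pre-activations is linear in `lr`, so the printed quantity scales like `lr**2`. To probe the depth dependence, one can scan `depth` at a fixed `lr`, or search for the largest `lr` that keeps the quantity below a fixed threshold.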
Proof of Theorem 1.1
We prove a slightly more general result than Theorem 1.1 in two senses. First, we allow for variable widths:
Second, we will also allow for parameter-dependent learning rates:
Thus, the batch loss we consider is
With this notation, the forward pass now takes the form
2 Proof Details
where is the change in after one step of GD. The SGD update satisfies:
We now combine (2.3) and (2.4) to obtain:
Given the distribution of and , we have
Integrating out the weights of the corresponding layer in (2.9) and (2.10) yields the stated result.
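Integrating out a layer of Gaussian weights typically relies on two standard facts. First, conditionally on the previous layer's activations $h$, each pre-activation $z_i = \sum_j W_{ij} h_j$ with $W_{ij} \sim \mathcal{N}(0, 2/n)$ is Gaussian with variance $2\|h\|^2/n$. Second, for a centred Gaussian $z$ one has $\mathbb{E}[\max(z,0)^2] = \mathbb{E}[z^2]/2$ and $\mathbb{P}(z>0) = 1/2$. The Monte Carlo sketch below checks these identities; the width, sample size, and fixed activation vector are arbitrary choices of ours.

```python
import numpy as np

rng = np.random.default_rng(0)
n, samples = 256, 10000
h = np.maximum(rng.normal(size=n), 0.0)  # fixed previous-layer activations (arbitrary)

# Fresh Gaussian weights for each sample: z = W h with W_ij ~ N(0, 2/n).
W = rng.normal(0.0, np.sqrt(2.0 / n), size=(samples, n))
z = W @ h  # one pre-activation per independent weight draw

predicted_var = 2.0 * (h @ h) / n
print(np.var(z) / predicted_var)                           # close to 1
print(np.mean(np.maximum(z, 0.0) ** 2) / np.mean(z ** 2))  # close to 1/2
print(np.mean(z > 0))                                      # close to 1/2
```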
Consider a random ReLU network as in (1.1), with given input dimension, hidden-layer widths, and output dimension. Suppose that
Adding the contributions (2.17), (2.18) and (2.2) in (2.11) gives the stated result. ∎
We apply the same proof strategy as in 2.5 to get the result. ∎
Combining (2.23) and (LABEL:eq:eq2CL) yields the result. ∎
We apply the same proof strategy as in 2.7 to get the result. ∎
Conclusion
In this short note we’ve computed how network depth influences the learning rate predicted by the μP heuristic. We found that, unlike its behavior with respect to width, this learning rate has a non-trivial power-law scaling with respect to depth (see Theorem 1.1). We leave for future work the empirical validation of whether this depth dependence indeed leads to learning rate transfer across depths in practice.