Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent

Jaehoon Lee, Lechao Xiao, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington

Introduction

Machine learning models based on deep neural networks have achieved unprecedented performance across a wide range of tasks . Typically, these models are regarded as complex systems for which many types of theoretical analyses are intractable. Moreover, characterizing the gradient-based training dynamics of these models is challenging owing to the typically high-dimensional non-convex loss surfaces governing the optimization. As is common in the physical sciences, investigating the extreme limits of such systems can often shed light on these hard problems. For neural networks, one such limit is that of infinite width, which refers either to the number of hidden units in a fully-connected layer or to the number of channels in a convolutional layer. Under this limit, the output of the network at initialization is a draw from a Gaussian process (GP); moreover, the network output remains governed by a GP after exact Bayesian training using squared loss . Aside from its theoretical simplicity, the infinite-width limit is also of practical interest as wider networks have been found to generalize better .

In this work, we explore the learning dynamics of wide neural networks under gradient descent and find that the weight-space description of the dynamics becomes surprisingly simple: as the width becomes large, the neural network can be effectively replaced by its first-order Taylor expansion with respect to its parameters at initialization. For this linear model, the dynamics of gradient descent become analytically tractable. While the linearization is only exact in the infinite width limit, we nevertheless find excellent agreement between the predictions of the original network and those of the linearized version even for finite width configurations. The agreement persists across different architectures, optimization methods, and loss functions.

For squared loss, the exact learning dynamics admit a closed-form solution that allows us to characterize the evolution of the predictive distribution in terms of a GP. This result can be thought of as an extension of “sample-then-optimize” posterior sampling to the training of deep neural networks. Our empirical simulations confirm that the result accurately models the variation in predictions across an ensemble of finite-width models with different random initializations.

Parameter space dynamics: We show that wide network training dynamics in parameter space are equivalent to the training dynamics of a model which is affine in the collection of all network parameters, the weights and biases. This result holds regardless of the choice of loss function. For squared loss, the dynamics admit a closed-form solution as a function of time.

Sufficient conditions for linearization: We formally prove that there exists a threshold learning rate $\eta_{\rm critical}$ (see Theorem 2.1), such that gradient descent training trajectories with learning rate smaller than $\eta_{\rm critical}$ stay in an $\mathcal{O}\left(n^{-1/2}\right)$-neighborhood of the trajectory of the linearized network when $n$, the width of the hidden layers, is sufficiently large.

Output distribution dynamics: We formally show that the predictions of a neural network throughout gradient descent training are described by a GP as the width goes to infinity (see Theorem 2.2), extending results from Jacot et al. . We further derive explicit time-dependent expressions for the evolution of this GP during training. Finally, we provide a novel interpretation of the result. In particular, it offers a quantitative understanding of the mechanism by which gradient descent differs from Bayesian posterior sampling of the parameters: while both methods generate draws from a GP, gradient descent does not generate samples from the posterior of any probabilistic model.

Large scale experimental support: We empirically investigate the applicability of the theory in the finite-width setting and find that it gives an accurate characterization of both learning dynamics and posterior function distributions across a variety of conditions, including some practical network architectures such as the wide residual network .

Parameterization independence: We note that the linearization result holds in both the standard and the NTK parameterization (defined in §2.1), whereas previous work assumed the latter. This emphasizes that the effect is due to the increase in width rather than to the particular parameterization.

Analytic $\operatorname{ReLU}$ and $\operatorname{erf}$ neural tangent kernels: We compute the analytic neural tangent kernel corresponding to fully-connected networks with $\operatorname{ReLU}$ or $\operatorname{erf}$ nonlinearities.

Source code: Example code investigating both function space and parameter space linearized learning dynamics described in this work is released as open source code within . (Note that the open source library has been expanded since the initial submission of this work.) We also provide accompanying interactive Colab notebooks for both parameter space (colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/weight_space_linearization.ipynb) and function space (colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb) linearization.

We build on recent work by Jacot et al. that characterizes the exact dynamics of network outputs throughout gradient descent training in the infinite width limit. Their results establish that full batch gradient descent in parameter space corresponds to kernel gradient descent in function space with respect to a new kernel, the Neural Tangent Kernel (NTK). We examine what this implies about dynamics in parameter space, where training updates are actually made.

Daniely et al. study the relationship between neural networks and kernels at initialization. They bound the difference between the infinite width kernel and the empirical kernel at finite width $n$, which diminishes as $\mathcal{O}(1/\sqrt{n})$. Daniely uses the same kernel perspective to study stochastic gradient descent (SGD) training of neural networks.

Saxe et al. study the training dynamics of deep linear networks, in which the nonlinearities are treated as identity functions. Deep linear networks are linear in their inputs, but not in their parameters. In contrast, we show that the outputs of sufficiently wide neural networks are linear in the updates to their parameters during gradient descent, but not usually their inputs.

Du et al. , Allen-Zhu et al. , Zou et al. study the convergence of gradient descent to global minima. They proved that for i.i.d. Gaussian initialization, the parameters of sufficiently wide networks move little from their initial values during SGD. This small motion of the parameters is crucial to the effect we present, where wide neural networks behave linearly in terms of their parameters throughout training.

Mei et al. , Chizat and Bach , Rotskoff and Vanden-Eijnden , Sirignano and Spiliopoulos analyze the mean field SGD dynamics of training neural networks in the large-width limit. Their mean field analysis describes distributional dynamics of network parameters via a PDE. However, their analysis is restricted to one hidden layer networks with a scaling limit $\left(1/n\right)$ different from ours $\left(1/\sqrt{n}\right)$, which is commonly used in modern networks .

Chizat et al. (we note that this is concurrent work, and an expanded version of this note is presented in parallel at NeurIPS 2019) argued that infinite width networks are in a ‘lazy training’ regime and may be too simple to be applicable to realistic neural networks. Nonetheless, we empirically investigate the applicability of the theory in the finite-width setting and find that it gives an accurate characterization of both the learning dynamics and posterior function distributions across a variety of conditions, including some practical network architectures such as the wide residual network .

Theoretical results

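Let $\eta$ be the learning rate. (Note that compared to the conventional parameterization, $\eta$ is larger by a factor of the width . The NTK parameterization allows usage of a universal learning rate scale irrespective of network width.) Via continuous time gradient descent, the evolution of the parameters $\theta$ and the logits $f$ can be written as

$$\dot{\theta}_{t}=-\eta\nabla_{\theta}f_{t}(\mathcal{X})^{T}\nabla_{f_{t}(\mathcal{X})}\mathcal{L} \tag{2}$$

$$\dot{f}_{t}(\mathcal{X})=\nabla_{\theta}f_{t}(\mathcal{X})\,\dot{\theta}_{t}=-\eta\,\hat{\Theta}_{t}(\mathcal{X},\mathcal{X})\nabla_{f_{t}(\mathcal{X})}\mathcal{L} \tag{3}$$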

where $f_{t}(\mathcal{X})=\operatorname{vec}\left(\left[f_{t}\left(x\right)\right]_{x\in\mathcal{X}}\right)$ is the $k|\mathcal{D}|\times 1$ vector of concatenated logits for all examples, and $\nabla_{f_{t}(\mathcal{X})}\mathcal{L}$ is the gradient of the loss with respect to the model’s output, $f_{t}(\mathcal{X})$. $\hat{\Theta}_{t}\equiv\hat{\Theta}_{t}(\mathcal{X},\mathcal{X})$ is the tangent kernel at time $t$, which is a $k|\mathcal{D}|\times k|\mathcal{D}|$ matrix

$$\hat{\Theta}_{t}=\nabla_{\theta}f_{t}(\mathcal{X})\nabla_{\theta}f_{t}(\mathcal{X})^{T}. \tag{4}$$

One can define the tangent kernel for general arguments, e.g. $\hat{\Theta}_{t}(x,\mathcal{X})$ where $x$ is a test input. At finite width, $\hat{\Theta}$ will depend on the specific random draw of the parameters, and in this context we refer to it as the empirical tangent kernel.
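As an illustration of the empirical tangent kernel, the following is a minimal NumPy sketch for a toy one-hidden-layer tanh network with scalar output. The architecture, the function names (`init_params`, `jacobian`, `empirical_ntk`) and the sizes are assumptions made for this example and are not taken from the released library. The kernel is formed as the product of parameter Jacobians, so it changes with the random draw of the parameters.

```python
import numpy as np

def init_params(n_in, width, rng):
    """Draw N(0, 1) weights; the 1/sqrt(fan_in) factors live in the forward pass."""
    return {"W": rng.standard_normal((width, n_in)),
            "v": rng.standard_normal(width)}

def jacobian(params, X):
    """Jacobian of f(x) = v . tanh(W x / sqrt(n_in)) / sqrt(width) w.r.t. (W, v).

    Returns an array of shape (num_inputs, num_params).
    """
    W, v = params["W"], params["v"]
    width, n_in = W.shape
    h = np.tanh(X @ W.T / np.sqrt(n_in))                  # (m, width)
    d_v = h / np.sqrt(width)                              # df/dv
    # df/dW[j, k] = v[j] * (1 - h_j^2) * x_k / sqrt(n_in * width)
    d_W = (v * (1.0 - h**2))[:, :, None] * X[:, None, :] / np.sqrt(n_in * width)
    return np.concatenate([d_W.reshape(len(X), -1), d_v], axis=1)

def empirical_ntk(params, X1, X2):
    """Empirical tangent kernel: Jacobian(X1) @ Jacobian(X2)^T."""
    return jacobian(params, X1) @ jacobian(params, X2).T

rng = np.random.default_rng(0)
params = init_params(n_in=3, width=256, rng=rng)
X_train = rng.standard_normal((5, 3))
x_test = rng.standard_normal((1, 3))
K_train = empirical_ntk(params, X_train, X_train)   # (5, 5), depends on the draw
K_test = empirical_ntk(params, x_test, X_train)     # (1, 5)
```

Because the kernel is a Gram matrix of Jacobian rows, it is symmetric and positive semi-definite for any parameter draw.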

The dynamics of discrete gradient descent can be obtained by replacing $\dot{\theta}_{t}$ and $\dot{f}_{t}(\mathcal{X})$ with $(\theta_{i+1}-\theta_{i})$ and $(f_{i+1}(\mathcal{X})-f_{i}(\mathcal{X}))$ above, and replacing $e^{-\eta\hat{\Theta}_{0}t}$ with $(1-\eta\hat{\Theta}_{0})^{i}$ below (equivalently, replacing $1-e^{-\eta\hat{\Theta}_{0}t}$ with $1-(1-\eta\hat{\Theta}_{0})^{i}$).
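The discrete-time substitution can be checked numerically. The sketch below assumes a kernel frozen at initialization and an MSE loss whose output gradient is $f-y$ (both assumptions of this example). Under those assumptions the recursion $f_{i+1}=f_{i}-\eta\hat{\Theta}_{0}(f_{i}-y)$ gives $f_{i}-y=(1-\eta\hat{\Theta}_{0})^{i}(f_{0}-y)$, so the matrix power plays the role of $e^{-\eta\hat{\Theta}_{0}t}$. The stand-in Jacobian and all names are hypothetical.

```python
import numpy as np

def closed_form_outputs(kernel, f0, y, lr, steps):
    """f_i = y + (I - lr * kernel)^i (f0 - y), valid for a fixed kernel and MSE loss."""
    decay = np.linalg.matrix_power(np.eye(len(y)) - lr * kernel, steps)
    return y + decay @ (f0 - y)

def iterated_outputs(kernel, f0, y, lr, steps):
    """Apply f_{i+1} = f_i - lr * kernel (f_i - y) step by step."""
    f = f0.copy()
    for _ in range(steps):
        f = f - lr * kernel @ (f - y)
    return f

rng = np.random.default_rng(0)
J = rng.standard_normal((4, 30))           # stand-in Jacobian at initialization
kernel = J @ J.T                           # fixed tangent kernel
f0 = rng.standard_normal(4)
y = np.array([1.0, -1.0, 1.0, -1.0])
lr = 1.0 / np.linalg.eigvalsh(kernel)[-1]  # safely below 2 / lambda_max
gap = np.max(np.abs(closed_form_outputs(kernel, f0, y, lr, 25)
                    - iterated_outputs(kernel, f0, y, lr, 25)))
```

The two computations agree up to floating-point rounding, which is the sense in which the matrix power replaces the matrix exponential.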

2.2 Linearized networks have closed form training dynamics for parameters and outputs

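In this section, we consider the training dynamics of the linearized network. Specifically, we replace the outputs of the neural network by their first order Taylor expansion,

$$f^{\textrm{lin}}_{t}(x)\equiv f_{0}(x)+\nabla_{\theta}f_{0}(x)\big|_{\theta=\theta_{0}}\,\omega_{t}, \tag{5}$$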

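where $\omega_{t}\equiv\theta_{t}-\theta_{0}$ is the change in the parameters from their initial values. Note that $f^{\textrm{lin}}_{t}$ is the sum of two terms: the first term is the initial output of the network, which remains unchanged during training, and the second term captures the change to the initial value during training. The dynamics of gradient flow using this linearized function are governed by

$$\dot{\omega}_{t}=-\eta\nabla_{\theta}f_{0}(\mathcal{X})^{T}\nabla_{f^{\textrm{lin}}_{t}(\mathcal{X})}\mathcal{L} \tag{6}$$

$$\dot{f}^{\textrm{lin}}_{t}(x)=-\eta\,\hat{\Theta}_{0}(x,\mathcal{X})\nabla_{f^{\textrm{lin}}_{t}(\mathcal{X})}\mathcal{L} \tag{7}$$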

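For an arbitrary point $x$, $f^{\textrm{lin}}_{t}(x)=\mu_{t}(x)+\gamma_{t}(x)$, where

$$\mu_{t}(x)=\hat{\Theta}_{0}(x,\mathcal{X})\hat{\Theta}_{0}^{-1}\left(I-e^{-\eta\hat{\Theta}_{0}t}\right)\mathcal{Y} \tag{10}$$

$$\gamma_{t}(x)=f_{0}(x)-\hat{\Theta}_{0}(x,\mathcal{X})\hat{\Theta}_{0}^{-1}\left(I-e^{-\eta\hat{\Theta}_{0}t}\right)f_{0}(\mathcal{X}). \tag{11}$$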

Therefore, we can obtain the time evolution of the linearized neural network without running gradient descent. We only need to compute the tangent kernel $\hat{\Theta}_{0}$ and the outputs $f_{0}$ at initialization and use Equations 8, 10, and 11 to compute the dynamics of the weights and the outputs.
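A minimal sketch of this recipe for the outputs, assuming MSE loss with the closed form $f^{\textrm{lin}}_{t}(x)=f_{0}(x)-\hat{\Theta}_{0}(x,\mathcal{X})\hat{\Theta}_{0}^{-1}(I-e^{-\eta\hat{\Theta}_{0}t})(f_{0}(\mathcal{X})-\mathcal{Y})$ and an invertible training kernel. The function name and the random stand-in Jacobians are assumptions for illustration. The matrix exponential is evaluated in the eigenbasis of the symmetric kernel.

```python
import numpy as np

def linearized_outputs(k_query_train, k_train_train, f0_query, f0_train, y, lr, t):
    """Outputs of the linearized model at time t under gradient flow on MSE loss."""
    evals, evecs = np.linalg.eigh(k_train_train)
    # Theta^{-1} (I - exp(-lr * Theta * t)), computed in the eigenbasis of Theta.
    coeff = -np.expm1(-lr * evals * t) / evals
    solve_op = (evecs * coeff) @ evecs.T
    return f0_query - k_query_train @ solve_op @ (f0_train - y)

rng = np.random.default_rng(0)
J_train = rng.standard_normal((6, 40))    # stand-in Jacobians at initialization
J_test = rng.standard_normal((3, 40))
k_train = J_train @ J_train.T
k_test = J_test @ J_train.T
f0_train, f0_test = rng.standard_normal(6), rng.standard_normal(3)
y = rng.standard_normal(6)

# Test-point predictions at a few times, with no gradient descent loop.
trajectory = [linearized_outputs(k_test, k_train, f0_test, f0_train, y, 1.0, t)
              for t in (0.0, 0.01, 0.1, 1.0)]
```

Passing the training kernel as the query kernel recovers the training outputs, which start at $f_{0}(\mathcal{X})$ and approach the targets as $t$ grows.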

2.3 Infinite width limit yields a Gaussian process

As the width of the hidden layers approaches infinity, the Central Limit Theorem (CLT) implies that the outputs at initialization $\left\{f_{0}(x)\right\}_{x\in\mathcal{X}}$ converge to a multivariate Gaussian in distribution. Informally, this occurs because the pre-activations at each layer are a sum of Gaussian random variables (the weights and bias), and thus become a Gaussian random variable themselves. See for more details, and for a formal treatment.

Therefore, randomly initialized neural networks are in correspondence with a certain class of GPs (hereinafter referred to as NNGPs), which facilitates a fully Bayesian treatment of neural networks . More precisely, let $f_{t}^{i}$ denote the $i$-th output dimension and $\mathcal{K}$ denote the sample-to-sample kernel function (of the pre-activation) of the outputs in the infinite width setting,

then $f_{0}(\mathcal{X})\sim\mathcal{N}(0,\mathcal{K}(\mathcal{X},\mathcal{X}))$, where $\mathcal{K}^{i,j}(x,x^{\prime})$ denotes the covariance between the $i$-th output of $x$ and $j$-th output of $x^{\prime}$, which can be computed recursively (see Lee et al. [5, §2.3] and SM §E). For a test input $x\in\mathcal{X}_{T}$, the joint output distribution $f\left([x,\mathcal{X}]\right)$ is also multivariate Gaussian. Conditioning on the training samples, $f(\mathcal{X})=\mathcal{Y}$ (this imposes that $h^{L+1}$ directly corresponds to the network predictions; in the case of softmax readout, variational or sampling methods are required to marginalize over $h^{L+1}$), the distribution of $\left.f(x)\right|\mathcal{X},\mathcal{Y}$ is also a Gaussian $\mathcal{N}\left(\mu(x),\Sigma(x)\right)$,

$$\mu(x)=\mathcal{K}(x,\mathcal{X})\mathcal{K}^{-1}\mathcal{Y},\qquad\Sigma(x)=\mathcal{K}(x,x)-\mathcal{K}(x,\mathcal{X})\mathcal{K}^{-1}\mathcal{K}(x,\mathcal{X})^{T}, \tag{13}$$

and where $\mathcal{K}=\mathcal{K}(\mathcal{X},\mathcal{X})$. This is the posterior predictive distribution resulting from exact Bayesian inference in an infinitely wide neural network.
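The conditioning step is ordinary noise-free GP regression, which can be sketched in a few lines. The kernel below is a random-feature stand-in used only to obtain a valid positive definite matrix; it is not the recursive NNGP kernel, and the function name `nngp_posterior` is an assumption for this example.

```python
import numpy as np

def nngp_posterior(k_test_test, k_test_train, k_train_train, y):
    """Posterior mean and covariance of a zero-mean GP conditioned on f(X) = y."""
    mean = k_test_train @ np.linalg.solve(k_train_train, y)
    cov = k_test_test - k_test_train @ np.linalg.solve(k_train_train, k_test_train.T)
    return mean, cov

rng = np.random.default_rng(0)
features = rng.standard_normal((8, 50))      # stand-in feature map, not the NNGP recursion
K_all = features @ features.T / 50
train, test = slice(0, 5), slice(3, 8)       # the test set overlaps two training points
y = rng.standard_normal(5)
mean, cov = nngp_posterior(K_all[test, test], K_all[test, train], K_all[train, train], y)
```

At the two test points that coincide with training inputs, the posterior mean reproduces the training targets and the posterior variance collapses to zero, as expected when conditioning without observation noise.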

If we freeze the variables $\theta^{\leq L}$ after initialization and only optimize $\theta^{L+1}$, the original network and its linearization are identical. Letting the width approach infinity, this particular tangent kernel $\hat{\Theta}_{0}$ will converge to $\mathcal{K}$ in probability and Equation 10 will converge to the posterior Equation 13 as $t\to\infty$ (for further details see SM §D). This is a realization of the “sample-then-optimize” approach for evaluating the posterior of a Gaussian process proposed in Matthews et al. .

If none of the variables are frozen, in the infinite width setting, $\hat{\Theta}_{0}$ also converges in probability to a deterministic kernel $\Theta$ , which we sometimes refer to as the analytic kernel, and which can also be computed recursively (see SM §E). For $\operatorname{ReLU}$ and $\operatorname{erf}$ nonlinearities, $\Theta$ can be exactly computed (SM §C), which we use in §3. Letting the width go to infinity, for any $t$, the output $f^{\textrm{lin}}_{t}(x)$ of the linearized network is also Gaussian distributed because Equations 10 and 11 describe an affine transform of the Gaussian $[f_{0}(x),f_{0}(\mathcal{X})]$. Therefore

For every test point $x\in\mathcal{X}_{T}$ and $t\geq 0$, $f^{\textrm{lin}}_{t}(x)$ converges in distribution as the width goes to infinity to a Gaussian with mean and covariance given by (here “+h.c.” is an abbreviation for “plus the Hermitian conjugate”)

Therefore, over random initialization, $\lim_{t\to\infty}\lim_{n\to\infty}f^{\textrm{lin}}_{t}(x)$ has distribution

Unlike the case when only $\theta^{L+1}$ is optimized, Equations 14 and 15 do not admit an interpretation corresponding to the posterior sampling of a probabilistic model. (One possible exception is when the NNGP kernel and NTK are the same up to a scalar multiplication. This is the case when the activation function is the identity function and there is no bias term.) We contrast the predictive distributions from the NNGP, NTK-GP (i.e. Equations 14 and 15) and ensembles of NNs in Figure 2.
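The contrast between the two predictive distributions can be sketched numerically. The example assumes that the late-time linearized output has the affine form $f_{0}(x)-\Theta(x,\mathcal{X})\Theta^{-1}(f_{0}(\mathcal{X})-\mathcal{Y})$ with $f_{0}$ drawn from a zero-mean GP with kernel $\mathcal{K}$, and propagates the mean and covariance through that map. Both kernels are random stand-ins and the function names are hypothetical. The proportional case $\Theta=3\mathcal{K}$ illustrates the exception mentioned above.

```python
import numpy as np

def nngp_posterior(k_tt, k_tx, k_xx, y):
    """Exact Bayesian posterior of a zero-mean GP with kernel K."""
    mean = k_tx @ np.linalg.solve(k_xx, y)
    cov = k_tt - k_tx @ np.linalg.solve(k_xx, k_tx.T)
    return mean, cov

def ntk_gp_limit(theta_tx, theta_xx, k_tt, k_tx, k_xx, y):
    """Mean/covariance of f0(x) - Theta(x,X) Theta^{-1} (f0(X) - Y) for f0 ~ GP(0, K)."""
    a = np.linalg.solve(theta_xx, theta_tx.T).T   # Theta(x, X) Theta^{-1}
    mean = a @ y
    cross = a @ k_tx.T                            # Theta(x,X) Theta^{-1} K(X,x)
    cov = k_tt + a @ k_xx @ a.T - (cross + cross.T)
    return mean, cov

rng = np.random.default_rng(0)
feats = rng.standard_normal((7, 40))
K = feats @ feats.T / 40                          # stand-in NNGP kernel on 7 points
extra = rng.standard_normal((7, 40))
Theta = K + extra @ extra.T / 40                  # stand-in NTK, not proportional to K
tr, te = slice(0, 5), slice(5, 7)
y = rng.standard_normal(5)

mean_gp, cov_gp = nngp_posterior(K[te, te], K[te, tr], K[tr, tr], y)
mean_ntk, cov_ntk = ntk_gp_limit(Theta[te, tr], Theta[tr, tr],
                                 K[te, te], K[te, tr], K[tr, tr], y)
mean_prop, cov_prop = ntk_gp_limit(3 * K[te, tr], 3 * K[tr, tr],
                                   K[te, te], K[te, tr], K[tr, tr], y)
```

With a generic $\Theta$ the resulting covariance differs from the Bayesian posterior covariance, while in the proportional case the two coincide.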

Infinitely-wide neural networks open up ways to study deep neural networks both under fully Bayesian training through the Gaussian process correspondence, and under GD training through the linearization perspective. The resulting distributions over functions are inconsistent (the distribution resulting from GD training does not generally correspond to a Bayesian posterior). We believe understanding the biases over learned functions induced by different training schemes and architectures is a fascinating avenue for future work.

2.4 Infinite width networks are linearized networks

Equations 2 and 3 of the original network are intractable in general, since $\hat{\Theta}_{t}$ evolves with time. However, for the mean squared loss, we are able to prove formally that, as long as the learning rate $\eta<\eta_{\rm critical}:=2\left(\lambda_{\rm min}(\Theta)+\lambda_{\rm max}(\Theta)\right)^{-1}$, where $\lambda_{\rm min/max}(\Theta)$ is the min/max eigenvalue of $\Theta$, the gradient descent dynamics of the original neural network fall into its linearized dynamics regime.
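For a concrete feel for the threshold, the sketch below computes $\eta_{\rm critical}$ for a small hand-picked kernel and the per-step contraction factor of the residual $f-y$ under discrete gradient descent with a fixed kernel. The fixed kernel and the helper names are assumptions of this example; it illustrates the formula only and is not a check of the theorem for the original network.

```python
import numpy as np

def critical_learning_rate(theta):
    """eta_critical = 2 / (lambda_min + lambda_max) for a symmetric PSD kernel."""
    evals = np.linalg.eigvalsh(theta)            # ascending order
    return 2.0 / (evals[0] + evals[-1])

def contraction_factor(theta, lr):
    """Largest |1 - lr * eigenvalue|: per-step shrink factor of f - y for a fixed kernel."""
    evals = np.linalg.eigvalsh(theta)
    return float(np.max(np.abs(1.0 - lr * evals)))

theta = np.diag([1.0, 3.0])                      # toy kernel with eigenvalues 1 and 3
eta_c = critical_learning_rate(theta)            # 2 / (1 + 3)
factor_at_critical = contraction_factor(theta, eta_c)
factor_below = contraction_factor(theta, 0.9 * eta_c)
```

At $\eta_{\rm critical}$ the slowest and fastest eigen-directions contract at the same rate, and since $\eta_{\rm critical}<2/\lambda_{\rm max}$ any smaller learning rate keeps every factor $|1-\eta\lambda|$ below one.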

Therefore, as $n\to\infty$, the distributions of $f_{t}(x)$ and $f^{\textrm{lin}}_{t}(x)$ become the same. Coupling with Corollary 1, we have

We refer the readers to Figure 2 for empirical verification of this theorem. The proof of Theorem 2.1 consists of two steps. The first step is to prove the global convergence of overparameterized neural networks and stability of the NTK under gradient descent (and gradient flow); see SM §G. This stability was first observed and proved in in the gradient flow and sequential limit (i.e. letting $n_{1}\to\infty$, …, $n_{L}\to\infty$ sequentially) setting under certain assumptions about global convergence of gradient flow. In §G, we show how to use the NTK to provide a self-contained (and cleaner) proof of such global convergence and the stability of NTK simultaneously. The second step is to couple the stability of NTK with Grönwall-type arguments to upper bound the discrepancy between $f_{t}$ and $f^{\textrm{lin}}_{t}$, i.e. the first norm in Equation 17. Intuitively, the ODE of the original network (Equation 3) can be considered as a $\|\hat{\Theta}_{t}-\hat{\Theta}_{0}\|_{F}$-fluctuation from the linearized ODE (Equation 7). One expects the difference between the solutions of these two ODEs to be upper bounded by some functional of $\|\hat{\Theta}_{t}-\hat{\Theta}_{0}\|_{F}$; see SM §H. Therefore, for a large width network, the training dynamics can be well approximated by linearized dynamics.

Note that the updates for individual weights in Equation 6 vanish in the infinite width limit, which for instance can be seen from the explicit width dependence of the gradients in the NTK parameterization. Individual weights move by a vanishingly small amount for wide networks in this regime of dynamics, as do hidden layer activations, but they collectively conspire to provide a finite change in the final output of the network, as is necessary for training. An additional insight gained from linearization of the network is that the individual instance dynamics derived in can be viewed as a random features method, where the features are the gradients of the model with respect to its weights. (We thank Alex Alemi for pointing out a subtlety on the correspondence to a random features method.)

2.5 Extensions to other optimizers, architectures, and losses

Our theoretical analysis thus far has focused on fully-connected single-output architectures trained by full batch gradient descent. In SM §B we derive corresponding results for: networks with multi-dimensional outputs, training against a cross entropy loss, and gradient descent with momentum.

In addition to these generalizations, there is good reason to suspect that the results extend to a much broader class of models and optimization procedures. In particular, a wealth of recent literature suggests that the mean field theory governing the wide network limit of fully-connected models extends naturally to residual networks , CNNs , RNNs , batch normalization , and to broad architectures . We postpone the development of these additional theoretical extensions in favor of an empirical investigation of linearization for a variety of architectures.

Experiments

In this section, we provide empirical support showing that the training dynamics of wide neural networks are well captured by linearized models. We consider fully-connected, convolutional, and wide ResNet architectures trained with full- and mini-batch gradient descent using learning rates sufficiently small so that the continuous time approximation holds well. We consider two-class classification on CIFAR-10 (horses and planes) as well as ten-class classification on MNIST and CIFAR-10. When using MSE loss, we treat the binary classification task as regression with one class regressing to $+1$ and the other to $-1$.

Experiments in Figures 1, 4, S2, S3, S4, S5 and S6, were done in JAX . The remaining experiments used TensorFlow . An open source implementation of this work providing tools to investigate linearized learning dynamics is available at www.github.com/google/neural-tangents .

Predictive output distribution: In the case of an MSE loss, the output distribution remains Gaussian throughout training. In Figure 2, the predictive output distribution for input points interpolated between two training points is shown for an ensemble of neural networks and their corresponding GPs. The interpolation is given by $x(\alpha)=\alpha x^{(1)}+(1-\alpha)x^{(2)}$ where $x^{(1,2)}$ are two training inputs with different classes. We observe that the mean and variance dynamics of neural network outputs during gradient descent training follow the analytic dynamics from linearization well (Equations 14, 15). Moreover, the NNGP predictive distribution, which corresponds to exact Bayesian inference, while similar, is noticeably different from the predictive distribution at the end of gradient descent training. For the dynamics of individual function draws see SM Figure S1.

Comparison of training dynamics of linearized network to original network: For a particular realization of a finite width network, one can analytically predict the dynamics of the weights and outputs over the course of training using the empirical tangent kernel at initialization. In Figures 3, 4 (see also S2, S3), we compare these linearized dynamics (Equations 8, 9) with the result of training the actual network. In all cases we see remarkably good agreement. We also observe that for finite networks, dynamics predicted using the empirical kernel $\hat{\Theta}$ better match the data than those obtained using the infinite-width, analytic, kernel $\Theta$. To understand this we note that $\|\hat{\Theta}^{(n)}_{T}-\hat{\Theta}^{(n)}_{0}\|_{F}=\mathcal{O}(1/n)\leq\mathcal{O}(1/\sqrt{n})=\|\hat{\Theta}^{(n)}_{0}-\Theta\|_{F}$, where $\hat{\Theta}^{(n)}_{0}$ denotes the empirical tangent kernel of a width $n$ network, as plotted in Figure 1.

One can directly optimize the parameters of $f^{\textrm{lin}}$ instead of solving the ODE induced by the tangent kernel $\hat{\Theta}$. Standard neural network optimization techniques such as mini-batching, weight decay, and data augmentation can be directly applied. In Figure 4 (S2, S3), we compared the training dynamics of the linearized and original network while directly training both networks.
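Direct optimization of the linearized model can be sketched as follows. This is a toy NumPy illustration, not the released implementation: the one-hidden-layer tanh network, the central-difference Jacobian (chosen to avoid an autodiff dependency), and all names and sizes are assumptions for this example. The linearized model is built once at initialization and then trained with the same gradient descent loop as the original network.

```python
import numpy as np

def network(theta, X, width):
    """Toy one-hidden-layer tanh network with a flat parameter vector theta."""
    n_in = X.shape[1]
    W = theta[: width * n_in].reshape(width, n_in)
    v = theta[width * n_in:]
    return np.tanh(X @ W.T / np.sqrt(n_in)) @ v / np.sqrt(width)

def numerical_jacobian(fn, theta, eps=1e-6):
    """Central-difference Jacobian d fn / d theta, shape (outputs, params)."""
    cols = []
    for p in range(theta.size):
        step = np.zeros_like(theta)
        step[p] = eps
        cols.append((fn(theta + step) - fn(theta - step)) / (2 * eps))
    return np.stack(cols, axis=1)

def train(fn, theta0, y, lr, steps):
    """Plain gradient descent on 0.5 * ||fn(theta) - y||^2."""
    theta = theta0.copy()
    for _ in range(steps):
        theta -= lr * numerical_jacobian(fn, theta).T @ (fn(theta) - y)
    return theta

rng = np.random.default_rng(0)
width, n_in = 64, 3
X = rng.standard_normal((4, n_in))
y = np.array([1.0, -1.0, 1.0, -1.0])
theta0 = rng.standard_normal(width * n_in + width)

net = lambda th: network(th, X, width)
f0, J0 = net(theta0), numerical_jacobian(net, theta0)
net_lin = lambda th: f0 + J0 @ (th - theta0)      # first-order Taylor expansion at theta0

loss = lambda out: 0.5 * np.sum((out - y) ** 2)
theta_net = train(net, theta0, y, lr=0.1, steps=50)
theta_lin = train(net_lin, theta0, y, lr=0.1, steps=50)
gap = np.max(np.abs(net(theta_net) - net_lin(theta_lin)))
```

Because `net_lin` is an ordinary function of the parameters, the same loop could be swapped for mini-batch or momentum updates without changing how the linearization is constructed. The quantity `gap` measures how far the two trained models' outputs have drifted apart on the training inputs.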

With direct optimization of the linearized model, we tested full ($|\mathcal{D}|=50,000$) MNIST digit classification with cross-entropy loss, trained with a momentum optimizer (Figure S3). For cross-entropy loss with softmax output, some logits at late times grow indefinitely, in contrast to MSE loss where logits converge to their target values. The error between the original and linearized model for cross-entropy loss becomes much worse at late times if the two models deviate significantly before the logits enter their late-time steady-growth regime (see Figure S4).

Linearized dynamics successfully describe the training of networks beyond vanilla fully-connected models. To demonstrate the generality of this procedure, we show that we can predict the learning dynamics of a subclass of Wide Residual Networks (WRNs) . WRNs are a class of models that are popular in computer vision and leverage convolutions, batch normalization, skip connections, and average pooling. In Figure 4, we show a comparison between the linearized dynamics and the true dynamics for a wide residual network trained with MSE loss and SGD with momentum on the full CIFAR-10 dataset. We slightly modified the block structure described in Table S1 so that each layer has a constant number of channels (1024 in this case), and otherwise followed the original implementation. As elsewhere, we see strong agreement between the predicted dynamics and the result of training.

Effects of dataset size: The training dynamics of a neural network match those of its linearization when the width is infinite and the dataset is finite. In previous experiments, we chose sufficiently wide networks to achieve small error between neural networks and their linearization for smaller datasets. Overall, we observe that as the width grows the error decreases (Figure S5). Additionally, we see that the error grows with the size of the dataset. Thus, although the error grows with dataset size, this can be counterbalanced by a corresponding increase in the model size.

Discussion

We showed theoretically that the learning dynamics in parameter space of deep nonlinear neural networks are exactly described by a linearized model in the infinite width limit. Empirical investigation revealed that this agrees well with actual training dynamics and predictive distributions across fully-connected, convolutional, and even wide residual network architectures, as well as with different optimizers (gradient descent, momentum, mini-batching) and loss functions (MSE, cross-entropy). Our results suggest that a surprising number of realistic neural networks may be operating in the regime we studied. This is further consistent with recent experimental work showing that neural networks are often robust to re-initialization but not re-randomization of layers (Zhang et al. ).

In the regime we study, since the learning dynamics are fully captured by the kernel $\hat{\Theta}$ and the target signal, studying the properties of $\hat{\Theta}$ to determine trainability and generalization is an interesting future direction. Furthermore, the infinite width limit gives us a simple characterization of both gradient descent and Bayesian inference. By studying properties of the NNGP kernel $\mathcal{K}$ and the tangent kernel $\Theta$, we may shed light on the inductive bias of gradient descent.

Some layers of modern neural networks may be operating far from the linearized regime. Preliminary observations in Lee et al. showed that wide neural networks trained with SGD perform similarly to the corresponding GPs as width increases, while GPs still outperform trained neural networks for both small and large dataset sizes. Furthermore, in Novak et al. , it is shown that the comparison of performance between finite- and infinite-width networks is highly architecture-dependent. In particular, it was found that infinite-width networks perform as well as or better than their finite-width counterparts for many fully-connected or locally-connected architectures. However, the opposite was found in the case of convolutional networks without pooling. It is still an open research question to determine the main factors that determine these performance gaps. We believe that examining the behavior of infinitely wide networks provides a strong basis from which to build up a systematic understanding of finite-width networks (and/or networks trained with large learning rates).

Acknowledgements

We thank Greg Yang and Alex Alemi for useful discussions and feedback. We are grateful to Daniel Freeman, Alex Irpan and anonymous reviewers for providing valuable feedback on the draft. We thank the JAX team for developing a language which makes model linearization and NTK computation straightforward. We would like to especially thank Matthew Johnson for support and debugging help.

References

Appendix A Additional figures

Appendix B Extensions

One direction is to go beyond vanilla gradient descent dynamics. We consider momentum updates (combining the usual two stage update into a single equation)

The discrete update to the function output becomes

where $f^{\textrm{lin}}_{t}(x)$ is the output of the linearized network after $t$ steps. One can take the continuous time limit as in Qian , Su et al. and obtain

B.2 Multi-dimensional output and cross-entropy loss

One can extend the loss function to general functions with multiple output dimensions. Unlike for squared error, we do not have a closed form solution to the dynamics equation. However, the equations for the dynamics can be solved using an ODE solver as an initial value problem.

Let $\hat{\Theta}^{ij}(x,\mathcal{X})=\nabla_{\theta}f^{i}(x)\nabla_{\theta}f^{j}(\mathcal{X})^{T}$. The above is

For general loss, e.g. cross-entropy with softmax output, we need to rely on solving the ODE Equations S10 and S11. We use the dopri5 method for ODE integration, which is the default integrator in TensorFlow (tf.contrib.integrate.odeint).
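A minimal sketch of such an initial value problem in function space, assuming a kernel fixed at initialization that is shared across output dimensions, so that the logits (one row per example, one column per class) obey $\dot{f}=-\eta\,\hat{\Theta}_{0}\left(\operatorname{softmax}(f)-\mathcal{Y}\right)$. A hand-written fixed-step RK4 integrator is used here in place of dopri5 to keep the example dependency-free. The kernel is a random stand-in and all names are hypothetical.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(logits, onehot):
    """Summed softmax cross-entropy over all examples."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.sum(onehot * log_probs)

def logit_velocity(logits, kernel, onehot, lr):
    """d f / dt = -lr * kernel @ (softmax(f) - Y)."""
    return -lr * kernel @ (softmax(logits) - onehot)

def rk4_integrate(logits0, kernel, onehot, lr, t_end, num_steps):
    """Fixed-step fourth-order Runge-Kutta integration of the logit ODE."""
    f, dt = logits0.copy(), t_end / num_steps
    for _ in range(num_steps):
        k1 = logit_velocity(f, kernel, onehot, lr)
        k2 = logit_velocity(f + 0.5 * dt * k1, kernel, onehot, lr)
        k3 = logit_velocity(f + 0.5 * dt * k2, kernel, onehot, lr)
        k4 = logit_velocity(f + dt * k3, kernel, onehot, lr)
        f = f + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return f

rng = np.random.default_rng(0)
m, k = 6, 3                                   # examples, classes
J = rng.standard_normal((m, 40))
kernel = J @ J.T / 40                         # stand-in tangent kernel at initialization
onehot = np.eye(k)[rng.integers(0, k, size=m)]
logits0 = 0.1 * rng.standard_normal((m, k))
logits_T = rk4_integrate(logits0, kernel, onehot, lr=1.0, t_end=5.0, num_steps=200)
loss_start, loss_end = cross_entropy(logits0, onehot), cross_entropy(logits_T, onehot)
```

Since the kernel is positive semi-definite, the loss is non-increasing along the exact flow, which gives a simple sanity check on the integrator.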

Appendix C Neural Tangent kernel for $\operatorname{ReLU}$ and $\operatorname{erf}$

For $\operatorname{ReLU}$ and $\operatorname{erf}$ activation functions, the tangent kernel can be computed analytically. We begin with the case $\phi=\operatorname{ReLU}$; using the formula from Cho and Saul , we can compute $\mathcal{T}$ and $\dot{\mathcal{T}}$ in closed form. Let $\Sigma$ be a $2\times 2$ PSD matrix. We will use

Let $d=2$ and $u=(x\cdot w,y\cdot w)^{T}$. Then $u$ is a mean zero Gaussian with $\Sigma=[[x\cdot x,x\cdot y];[x\cdot y,y\cdot y]]$. Then

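For $\phi=\operatorname{erf}$, let $\Sigma$ be the same as above. Following Williams , we get

$$\mathcal{T}(\Sigma)=\frac{2}{\pi}\sin^{-1}\left(\frac{2\Sigma_{12}}{\sqrt{(1+2\Sigma_{11})(1+2\Sigma_{22})}}\right),\qquad\dot{\mathcal{T}}(\Sigma)=\frac{4}{\pi}\det(I+2\Sigma)^{-1/2}.$$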
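The Gaussian expectations behind these kernels can be checked by Monte Carlo. The sketch below assumes $\mathcal{T}(\Sigma)=\mathbb{E}[\phi(u)\phi(v)]$ and $\dot{\mathcal{T}}(\Sigma)=\mathbb{E}[\phi^{\prime}(u)\phi^{\prime}(v)]$ for $(u,v)\sim\mathcal{N}(0,\Sigma)$, and uses the standard arc-cosine closed forms for $\operatorname{ReLU}$ and the arcsine closed forms for $\operatorname{erf}$. The function names, the example $\Sigma$ and the sample size are choices made for this illustration.

```python
import math
import numpy as np

def relu_kernels(sigma):
    """Closed forms for E[relu(u) relu(v)] and E[relu'(u) relu'(v)], (u, v) ~ N(0, sigma)."""
    norm = math.sqrt(sigma[0, 0] * sigma[1, 1])
    angle = math.acos(min(1.0, max(-1.0, sigma[0, 1] / norm)))
    t = norm * (math.sin(angle) + (math.pi - angle) * math.cos(angle)) / (2 * math.pi)
    t_dot = (math.pi - angle) / (2 * math.pi)
    return t, t_dot

def erf_kernels(sigma):
    """Closed forms for E[erf(u) erf(v)] and E[erf'(u) erf'(v)], (u, v) ~ N(0, sigma)."""
    ratio = 2 * sigma[0, 1] / math.sqrt((1 + 2 * sigma[0, 0]) * (1 + 2 * sigma[1, 1]))
    t = 2 / math.pi * math.asin(ratio)
    t_dot = 4 / math.pi / math.sqrt(np.linalg.det(np.eye(2) + 2 * sigma))
    return t, t_dot

def monte_carlo(phi, sigma, num_samples, rng):
    """Sample-average estimate of E[phi(u) phi(v)]."""
    uv = rng.multivariate_normal(np.zeros(2), sigma, size=num_samples)
    return float(np.mean(phi(uv[:, 0]) * phi(uv[:, 1])))

sigma = np.array([[1.0, 0.5], [0.5, 1.5]])
rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)
relu_prime = lambda z: (z > 0).astype(float)
erf = np.vectorize(math.erf)
erf_prime = lambda z: 2 / math.sqrt(math.pi) * np.exp(-z**2)

n = 200_000
estimates = {
    "relu": (monte_carlo(relu, sigma, n, rng), monte_carlo(relu_prime, sigma, n, rng)),
    "erf": (monte_carlo(erf, sigma, n, rng), monte_carlo(erf_prime, sigma, n, rng)),
}
exact = {"relu": relu_kernels(sigma), "erf": erf_kernels(sigma)}
```

The sampled values in `estimates` should agree with the entries of `exact` to within Monte Carlo error.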

Appendix D Gradient flow dynamics for training only the readout-layer

The connection between Gaussian processes and Bayesian wide neural networks can be extended to the setting when only the readout layer parameters are being optimized. More precisely, we show that when training only the readout layer, the outputs of the network form a Gaussian process (over an ensemble of draws from the parameter prior) throughout training, where that output is an interpolation between the GP prior and GP posterior.

and applying gradient flow to optimize the readout layer (and freezing all other parameters),

where $\eta$ is the learning rate. The solution to this ODE gives the evolution of the output of an arbitrary $x^{*}$. So long as the empirical kernel $\bar{x}(\mathcal{X})\bar{x}(\mathcal{X})^{T}$ is invertible, it is

Moreover, $\bar{x}(\mathcal{X})\theta_{0}^{L+1}$ and the term containing $f_{0}(\mathcal{X})$ are the only stochastic terms over the ensemble of network initializations. Therefore, for any $t$, the output $f(x^{*})$ throughout training converges to a Gaussian distribution in the infinite width limit, with

Thus the output of the neural network is also a GP and the asymptotic solution (i.e. $t\to\infty$) is identical to the posterior of the NNGP (Equation 13). Therefore, in the infinite width case, the optimized neural network is performing posterior sampling if only the readout layer is being trained. This result is a realization of the sample-then-optimize equivalence identified in Matthews et al. .

Appendix E Computing NTK and NNGP Kernel

Using this one can also derive the tangent kernel for gradient descent training. We will use induction to show that

Letting $n_{1},\dots,n_{l-1}\to\infty$ sequentially, the first term converges to the NNGP kernel $\mathcal{K}^{l}(x,x^{\prime})$. By applying the chain rule and the induction step (letting $n_{1},\dots,n_{l-2}\to\infty$ sequentially), the second term is

Appendix F Results in function space for NTK parameterization transfer to standard parameterization

In this Section we present a sketch for why the function space linearization results, derived in for NTK parameterized networks, also apply to networks with a standard parameterization. We follow this up with a formal proof in §G of the convergence of standard parameterization networks to their linearization in the limit of infinite width. A network with standard parameterization is described as:

The NTK parameterization in Equation 1 is not commonly used for training neural networks. While the function that the network represents is the same for both NTK and standard parameterization, training dynamics under gradient descent are generally different for the two parameterizations. However, for a particular choice of layer-dependent learning rates, the training dynamics also become identical. Let $\eta^{l}_{\text{NTK},w}$ and $\eta^{l}_{\text{NTK},b}$ be the layer-dependent learning rates for $W^{l}$ and $b^{l}$ in the NTK parameterization, and $\eta_{\text{std}}=\frac{1}{n_{\text{max}}}\eta_{0}$ be the learning rate for all parameters in the standard parameterization, where $n_{\text{max}}=\max_{l}n_{l}$. Recall that gradient descent training in standard neural networks requires a learning rate that scales with width like $\frac{1}{n_{\text{max}}}$, so $\eta_{0}$ defines a width-invariant learning rate . If we choose

then learning dynamics are identical for networks with NTK and standard parameterizations. With only extremely minor modifications, consisting of incorporating the multiplicative factors in Equation S37 into the per-layer contributions to the Jacobian, the arguments in §2.4 go through for an NTK network with learning rates defined in Equation S37. Since an NTK network with these learning rates exhibits identical training dynamics to a standard network with learning rate $\eta_{\text{std}}$, the result in §2.4 that sufficiently wide NTK networks are linear in their parameters throughout training also applies to standard networks.
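
The equivalence can be checked on a toy model. The following is a minimal sketch under stated assumptions, not the construction used in the paper: a single linear layer $f(x)=W\cdot x+b$ with fan-in $n$, squared loss on one example, and NTK learning rates $\eta_{w}=n\,\eta_{\text{std}}/\sigma_{w}^{2}$ and $\eta_{b}=\eta_{\text{std}}/\sigma_{b}^{2}$ obtained from the chain rule for this toy model (assumed here to play the role of Equation S37 with $n_{\text{max}}=n$). All function and variable names are hypothetical.

```python
import math
import random

def forward(W, b, x):
    """Scalar output of a single linear layer: f(x) = W . x + b."""
    return sum(wi * xi for wi, xi in zip(W, x)) + b

def std_step(W, b, x, y, eta_std):
    """One gradient step on 0.5 * (f - y)^2 in the standard parameterization."""
    r = forward(W, b, x) - y
    W_new = [wi - eta_std * r * xi for wi, xi in zip(W, x)]
    return W_new, b - eta_std * r

def ntk_step(omega, beta, x, y, eta_w, eta_b, sigma_w, sigma_b):
    """One gradient step in the NTK parameterization
    W = sigma_w / sqrt(n) * omega, b = sigma_b * beta."""
    scale = sigma_w / math.sqrt(len(x))
    r = forward([scale * o for o in omega], sigma_b * beta, x) - y
    # Chain rule: dL/domega = scale * dL/dW and dL/dbeta = sigma_b * dL/db.
    omega_new = [o - eta_w * r * scale * xi for o, xi in zip(omega, x)]
    return omega_new, beta - eta_b * r * sigma_b

def max_output_gap(n=16, sigma_w=1.5, sigma_b=0.1, eta_std=0.01, steps=5, seed=0):
    """Train both parameterizations from the same function and return the
    largest difference between their outputs over the training steps."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    y = 1.0
    omega = [rng.gauss(0.0, 1.0) for _ in range(n)]
    beta = rng.gauss(0.0, 1.0)
    scale = sigma_w / math.sqrt(n)
    W, b = [scale * o for o in omega], sigma_b * beta
    # Assumed learning-rate matching for this toy model.
    eta_w = n * eta_std / sigma_w ** 2
    eta_b = eta_std / sigma_b ** 2
    gap = 0.0
    for _ in range(steps):
        W, b = std_step(W, b, x, y, eta_std)
        omega, beta = ntk_step(omega, beta, x, y, eta_w, eta_b, sigma_w, sigma_b)
        f_ntk = forward([scale * o for o in omega], sigma_b * beta, x)
        gap = max(gap, abs(forward(W, b, x) - f_ntk))
    return gap
```

Calling `max_output_gap()` should return a value at the level of floating-point round-off, since with these rates the induced updates of $W$ and $b$ coincide step by step in this toy setting.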

We can verify this property of networks with the standard parameterization experimentally. In Figure S7, we see that for different choices of dataset, activation function and loss function, the two parameterizations yield models of similar final quality at similar values of the normalized learning rate $\eta_{\textrm{std}}=\eta_{\textrm{NTK}}/n$. Also, in Figure S8, we observe that our results are not due to the parameterization choice and hold for wide networks using the standard parameterization.

Appendix G Convergence of neural network to its linearization, and stability of NTK under gradient descent

In this section, we show how to use the NTK to provide a simple proof of the global convergence of a neural network under (full-batch) gradient descent and of the stability of the NTK under gradient descent. We present the proof for the standard parameterization. With some minor changes, the proof also applies to the NTK parameterization. To lighten the notation, we only consider the asymptotic bound here. The neural networks are parameterized as in Equation S36. We make the following assumptions:

The widths of the hidden layers are identical, i.e. $n_{1}=\dots=n_{L}=n$ (our proof extends naturally to the setting $\frac{n_{l}}{n_{l^{\prime}}}\to\alpha_{l,l^{\prime}}\in(0,\infty)$ as $\min\{n_{1},\dots,n_{L}\}\to\infty$).

The analytic NTK $\Theta$ (defined in Equation S42) is full-rank, i.e. $0<\lambda_{\rm{min}}:=\lambda_{\rm{min}}(\Theta)\leq\lambda_{\rm{max}}:=\lambda_{\rm{max}}(\Theta)<\infty$. We set $\eta_{{\rm critical}}=2(\lambda_{\rm{min}}+\lambda_{\rm{max}})^{-1}$.
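
To illustrate the role of $\eta_{\rm critical}$, the sketch below assumes the linear residual recursion $g_{t+1}=(I-\eta\Theta)g_{t}$ for a hypothetical $2\times 2$ positive definite kernel; this recursion is an idealization chosen for illustration, not the finite-width dynamics analyzed in the proofs. For this recursion, $\eta=\eta_{\rm critical}$ makes both eigen-directions contract by the same factor $(\lambda_{\rm max}-\lambda_{\rm min})/(\lambda_{\rm max}+\lambda_{\rm min})$, while any $\eta>2/\lambda_{\rm max}$ diverges.

```python
import math

def eig_sym_2x2(a, b, d):
    """Eigenvalues (min, max) of the symmetric matrix [[a, b], [b, d]]."""
    mean = 0.5 * (a + d)
    rad = math.sqrt((0.5 * (a - d)) ** 2 + b * b)
    return mean - rad, mean + rad

def residual_norm_after(theta, eta, g0, steps):
    """Iterate g <- (I - eta * Theta) g and return the final Euclidean norm."""
    (a, b), (_, d) = theta
    g1, g2 = g0
    for _ in range(steps):
        # Tuple assignment evaluates the right-hand side with the old g1, g2.
        g1, g2 = g1 - eta * (a * g1 + b * g2), g2 - eta * (b * g1 + d * g2)
    return math.hypot(g1, g2)

# Hypothetical kernel and initial residual, chosen only for illustration.
theta = ((2.0, 0.5), (0.5, 1.0))
g0 = (1.0, 1.0)

lam_min, lam_max = eig_sym_2x2(2.0, 0.5, 1.0)
eta_critical = 2.0 / (lam_min + lam_max)
rate = (lam_max - lam_min) / (lam_max + lam_min)

norm_critical = residual_norm_after(theta, eta_critical, g0, steps=20)
norm_too_large = residual_norm_after(theta, 2.2 / lam_max, g0, steps=200)
```

Here `norm_critical` is bounded by `rate ** 20` times the initial norm, whereas `norm_too_large` grows without bound as the number of steps increases.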

Let $\theta_{t}$ denote the parameters at time step $t$. We use the following short-hand

where $|\mathcal{X}|$ is the cardinality of the training set and $k$ is the output dimension of the network. The empirical and analytic NTK of the standard parameterization are defined as

Note that the convergence of the empirical NTK in probability has been proved rigorously. We consider the MSE loss

Since $f(\theta_{0})$ converges in distribution to a mean zero Gaussian with covariance $\mathcal{K}$, one can show that for arbitrarily small $\delta_{0}>0$, there are constants $R_{0}>0$ and $n_{0}$ (both may depend on $\delta_{0}$, $|\mathcal{X}|$ and $\mathcal{K}$) such that for every $n\geq n_{0}$, with probability at least $(1-\delta_{0})$ over random initialization,

The gradient descent update with learning rate $\eta$ is

We prove convergence of neural network training and the stability of the NTK for both discrete gradient descent and gradient flow. Both proofs rely on the local Lipschitzness of the Jacobian $J(\theta)$.

Lemma 1 (local Lipschitzness of the Jacobian). There is a $K>0$ such that for every $C>0$, with high probability over random initialization (w.h.p.o.r.i.) the following holds

The following are the main results of this section.

See the following two subsections for the proof.

One can extend the results in Theorem G.1 and Theorem G.2 to other architectures or functions as long as

The empirical NTK converges in probability and the limit is positive definite.

Lemma 1 holds, i.e. the Jacobian is locally Lipschitz.

G.1 Proof of Theorem G.1

As discussed above, there exist $R_{0}$ and $n_{0}$ such that for every $n\geq n_{0}$, with probability at least $(1-\delta_{0}/10)$ over random initialization,

Let $C=\frac{3KR_{0}}{\lambda_{\rm{min}}}$ in Lemma 1. We first prove Equation S49 by induction. Choose $n_{1}>n_{0}$ such that for every $n\geq n_{1}$ Equation S47 and Equation S53 hold with probability at least $(1-\delta_{0}/5)$ over random initialization. The $t=0$ case is obvious, and we assume Equation S49 holds at step $t$. Then by induction and the second estimate of Equation S47

which gives the first estimate of Equation S49 for $t+1$ and which also implies $\|\theta_{j}-\theta_{0}\|_{2}\leq\frac{3KR_{0}}{\lambda_{\rm{min}}}n^{-\frac{1}{2}}$ for $j=0,\dots,t+1$. To prove the second one, we apply the mean value theorem and the formula for the gradient descent update at step $t+1$

This can be verified by Lemma 1. Because $\hat{\Theta}_{0}\to\Theta$ in probability, one can find $n_{2}$ such that the event

has probability at least $(1-\delta_{0}/5)$ for every $n\geq n_{2}$. The assumption $\eta_{0}<\frac{2}{\lambda_{\rm{min}}+\lambda_{\rm{max}}}$ implies

with probability at least $(1-\delta_{0}/2)$ if

where we have applied the second estimate of Equation S49 and Equation S47.

G.2 Proof of Theorem G.2

The first step is the same. There exist $R_{0}$ and $n_{0}$ such that for every $n\geq n_{0}$, with probability at least $(1-\delta_{0}/10)$ over random initialization,

Let $C=\frac{3KR_{0}}{\lambda_{\rm{min}}}$ in Lemma 1. Using the same arguments as in Section G.1, one can show that there exists $n_{1}$ such that for all $n\geq n_{1}$, with probability at least $(1-\delta_{0}/10)$

We claim $t_{1}=\infty$. If not, then for all $t\leq t_{1}$, $\theta_{t}\in B(\theta_{0},Cn^{-\frac{1}{2}})$ and

This contradicts the definition of $t_{1}$, and thus $t_{1}=\infty$. Note that Equation S78 is the same as the first equation of Equation S51.

G.3 Proof of Lemma 1

The proof relies on upper bounds of operator norms of random Gaussian matrices.

Let $A=A_{N,n}$ be an $N\times n$ random matrix whose entries are independent standard normal random variables. Then for every $t\geq 0$, with probability at least $1-2\exp(-t^{2}/2)$ one has
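
A quick numerical check of this kind of operator-norm bound can be sketched as follows. The display for the bound is not reproduced in this text, so the sketch assumes the standard form $\sigma_{\max}(A)\leq\sqrt{N}+\sqrt{n}+t$ for the largest singular value; this is an assumption about the statement, and the matrix sizes, seed, and helper names are hypothetical. The largest singular value is estimated by power iteration on $A^{\top}A$, which yields a lower estimate that converges to $\sigma_{\max}(A)$.

```python
import math
import random

def gaussian_matrix(N, n, rng):
    """N x n matrix with independent standard normal entries."""
    return [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(N)]

def largest_singular_value(A, iters=300):
    """Estimate sigma_max(A) by power iteration on A^T A."""
    N, n = len(A), len(A[0])
    v = [1.0 / math.sqrt(n)] * n  # unit-norm starting vector
    sigma = 0.0
    for _ in range(iters):
        u = [sum(A[i][j] * v[j] for j in range(n)) for i in range(N)]  # u = A v
        w = [sum(A[i][j] * u[i] for i in range(N)) for j in range(n)]  # w = A^T u
        sigma = math.sqrt(sum(x * x for x in u))  # ||A v|| with ||v|| = 1
        norm_w = math.sqrt(sum(x * x for x in w))
        v = [x / norm_w for x in w]
    return sigma

N, n, t = 40, 20, 3.0
A = gaussian_matrix(N, n, random.Random(0))
sigma_max = largest_singular_value(A)
assumed_bound = math.sqrt(N) + math.sqrt(n) + t
```

For a single draw the comparison `sigma_max <= assumed_bound` is only a sanity check; the statement itself is probabilistic, with failure probability at most $2\exp(-t^{2}/2)$.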

Using this and the assumption on $\phi$ (Equation S38), it is not difficult to show that there is a constant $K_{1}$, depending on $\sigma_{\omega}^{2},\sigma_{b}^{2},|\mathcal{X}|$ and $L$, such that with high probability over random initialization

These two estimates can be obtained via induction. To prove bounds relating to $x^{l}$ and $\delta^{l}$, one starts with $l=1$ and $l=L$, respectively.

Lemma 1 follows from these two estimates. Indeed, with high probability over random initialization

G.4 Remarks on NTK parameterization

For completeness, we also include analogues of Theorem G.1 and Lemma 1 with NTK parameterization.

There is a $K>0$ such that for every $C>0$, with high probability over random initialization the following holds

Appendix H Bounding the discrepancy between the original and the linearized network: MSE loss

We provide the proof for the gradient flow case. The proof for gradient descent can be obtained similarly. To simplify the notation, let $g^{\textrm{lin}}(t)\equiv f^{\textrm{lin}}_{t}(\mathcal{X})-\mathcal{Y}$ and $g(t)\equiv f_{t}(\mathcal{X})-\mathcal{Y}$. The theorem and proof apply to both the standard and NTK parameterizations. We use the notation $\lesssim$ to hide the dependence on uninteresting constants.

Integrating both sides and using the fact $g^{\textrm{lin}}(0)=g(0)$,

Let $\lambda_{0}>0$ be the smallest eigenvalue of $\hat{\Theta}_{0}$ (with high probability $\lambda_{0}>\frac{1}{3}\lambda_{\rm{min}}$). Taking the norm gives

Note that $\alpha(t)$ is non-decreasing. Applying an integral form of Grönwall's inequality gives

Let $\sigma_{t}=\sup_{0\leq s\leq t}\|\hat{\Theta}_{s}-\hat{\Theta}_{0}\|_{op}$. Then

As proved in Theorem G.1, for every $\delta_{0}>0$, with probability at least $(1-\delta_{0})$ over random initialization,

when $n_{1}=\dots=n_{L}=n\to\infty$. Thus for large $n$ and any polynomial $P(t)$ (we use $P(t)=t$ here)

Now we control the discrepancy on a test point $x$. Let $y$ be its true label. Similarly,

Integrating over $[0,t]$ and taking the norm imply

Appendix I Convergence of empirical kernel

As in Novak et al., we can use Monte Carlo estimates of the tangent kernel (Equation 4) to probe convergence to the infinite width kernel (analytically computed using Equations S26, S29). For simplicity, we consider random inputs drawn from $\mathcal{N}(0,1)$ with $n_{0}=1024$. In Figure S9, we observe convergence as both width $n$ increases and the number of Monte Carlo samples $M$ increases. For both NNGP and tangent kernels we observe $\|\hat{\Theta}^{(n)}-\Theta\|_{F}=\mathcal{O}\left(1/\sqrt{n}\right)$ and $\|\hat{\mathcal{K}}^{(n)}-\mathcal{K}\|_{F}=\mathcal{O}\left(1/\sqrt{n}\right)$, as predicted by a CLT in Daniely et al.
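
The $1/\sqrt{n}$ scaling can be reproduced in a much smaller setting than the one above. The following is a minimal sketch, assuming a single hidden ReLU layer with standard normal weights and no bias, for which the infinite-width kernel entry $\mathbb{E}_{w}[\phi(w\cdot x)\phi(w\cdot x^{\prime})]$ has the closed form of the degree-1 arc-cosine kernel. It is not the experiment behind Figure S9; the inputs, widths, and function names are hypothetical.

```python
import math
import random

def analytic_relu_kernel(x, xp):
    """E_w[relu(w.x) relu(w.xp)] for w ~ N(0, I) (degree-1 arc-cosine kernel)."""
    nx = math.sqrt(sum(a * a for a in x))
    nxp = math.sqrt(sum(a * a for a in xp))
    cos = sum(a * b for a, b in zip(x, xp)) / (nx * nxp)
    cos = max(-1.0, min(1.0, cos))  # guard against round-off outside [-1, 1]
    ang = math.acos(cos)
    return nx * nxp * (math.sin(ang) + (math.pi - ang) * cos) / (2.0 * math.pi)

def empirical_relu_kernel(x, xp, n, rng):
    """Average of relu(w_i.x) * relu(w_i.xp) over n random hidden units."""
    total = 0.0
    for _ in range(n):
        w = [rng.gauss(0.0, 1.0) for _ in x]
        zx = sum(wi * a for wi, a in zip(w, x))
        zxp = sum(wi * a for wi, a in zip(w, xp))
        total += max(zx, 0.0) * max(zxp, 0.0)
    return total / n

def rms_error(x, xp, n, trials, rng):
    """Root-mean-square deviation of the width-n estimate from the analytic value."""
    exact = analytic_relu_kernel(x, xp)
    sq = sum((empirical_relu_kernel(x, xp, n, rng) - exact) ** 2
             for _ in range(trials))
    return math.sqrt(sq / trials)

rng = random.Random(0)
x, xp = (1.0, 0.0, 0.0), (0.6, 0.8, 0.0)
err_narrow = rms_error(x, xp, n=100, trials=30, rng=rng)
err_wide = rms_error(x, xp, n=6400, trials=30, rng=rng)
```

Increasing the width by a factor of 64 should shrink the error by roughly a factor of 8 on average, consistent with a $1/\sqrt{n}$ rate for this toy kernel.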

Appendix J Details on Wide Residual Network