Asymptotics of Wide Networks from Feynman Diagrams

Ethan Dyer, Guy Gur-Ari

Introduction

Neural networks achieve remarkable performance on a wide array of machine learning tasks, yet a complete analytic understanding of deep networks remains elusive. One promising approach is to consider the large width limit, in which the number of neurons in one or several layers is taken to be large. In this limit one can use a mean-field approach to better understand the network’s properties at initialization, as well as its training dynamics. Additional related works are cited below.

Suppose that $f(x)$ is the network function evaluated at an input $x$. Let us denote the vector of model parameters by $\theta$, whose elements are initially chosen to be i.i.d. Gaussian. In this work we consider a class of functions we call correlation functions, obtained by taking the ensemble averages of $f$, its products, and its derivatives with respect to the parameters $\theta$, evaluated on arbitrary inputs. Here are a few examples of correlation functions.

Correlation functions often show up in the study of wide networks. For example, the first correlation function in (1) plays a central role in the Gaussian Process picture of wide networks, and has been used to diagnose signal propagation in wide networks. The second example in (1) is the ensemble average of the Neural Tangent Kernel (NTK), which controls the evolution of wide networks under gradient flow, and the third example shows up when computing the time derivative of the NTK with MSE loss.
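The three examples described above can be written out explicitly. The following is a reconstruction from those descriptions; the symbol $\mathbb{E}_{\theta}$ for the ensemble average over the initial parameters and the input labels $x_{1},\dots,x_{4}$ are assumed notation.

```latex
\mathbb{E}_{\theta}\left[ f(x_{1})\, f(x_{2}) \right], \qquad
\sum_{\mu} \mathbb{E}_{\theta}\left[
  \frac{\partial f(x_{1})}{\partial\theta^{\mu}}
  \frac{\partial f(x_{2})}{\partial\theta^{\mu}} \right], \qquad
\sum_{\mu,\nu} \mathbb{E}_{\theta}\left[
  \frac{\partial f(x_{1})}{\partial\theta^{\mu}}
  \frac{\partial^{2} f(x_{2})}{\partial\theta^{\mu}\partial\theta^{\nu}}
  \frac{\partial f(x_{3})}{\partial\theta^{\nu}}
  f(x_{4}) \right].
```

The first is the covariance of the network function, the second is the averaged NTK, and the third has the index structure that arises when one time derivative acts on the kernel under MSE loss.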

While correlation functions can be computed analytically in some special cases, they are not analytically tractable in general. In this work, we present a method for bounding the asymptotic behavior of such functions at large width. Derivation of the method relies on Feynman diagrams, a technique for calculating multivariate Gaussian integrals, and specifically on the ’t Hooft expansion. However, applying the method is straightforward and does not require any knowledge of Feynman diagrams.

We present a general method for bounding the asymptotic behavior of correlation functions. The method is an adaptation of Feynman diagrams to the case of wide neural networks. The adaptation involves a novel treatment of derivatives of the network function, an element that is not present in the original theoretical physics formulation.

We apply the method to the study of wide network evolution under gradient descent. We improve on existing results for gradient flow by deriving tighter bounds, and by extending the analysis to the case of stochastic gradient descent (SGD). Going beyond the infinite-width limit, we present a formalism for deriving finite-width corrections to network evolution, and present explicit formulas for the first order correction. To our knowledge, this is the first time this correction has been calculated.

As additional applications of our method, in Appendix E.2 we show that in the large width limit the SGD updates are linear in the learning rate, and in Appendix E.3 we discuss finite width corrections to the spectrum of the Hessian.

The main result of this paper is a conjecture. We test our predictions extensively using numerical experiments, and prove the conjecture in some cases, but we do not have a proof that applies to all the cases we tested, including for deep networks with general non-linearities. Furthermore, our method can only be used to derive asymptotic bounds at large width; it does not produce the width-independent coefficient, which is often of interest.

For additional works on wide networks, including relating them to Gaussian processes, see . For additional works discussing the training dynamics of wide networks see . For a previous use of diagrams in this context, see .

The Neural Tangent Hierarchy presented in , published during the completion of this version, has significant overlap with the recursive differential equations (11) presented below.

The rest of the paper is organized as follows. In Section 2 we present our main conjecture and supporting evidence. In Section 3 we apply the method to gradient descent evolution of wide networks, and in Section 4 we present details on Feynman diagrams, which is the basic technique used in our proofs. We conclude with a Discussion. Proofs, additional applications, and details can be found in the Appendices.

Note: An earlier version of this work appeared in the ICML 2019 workshop, Theoretical Physics for Deep Learning.

Correlation function asymptotics

In this section we present our main result: a method for computing asymptotic bounds on correlation functions of wide networks. We present the result as a conjecture, supported by analytic and empirical evidence.

Let us now define correlation functions, the class of functions that is the focus of this work. These functions involve derivative tensors of the network function. We denote the rank-$k$ derivative tensor by $T_{\mu_{1}\dots\mu_{k}}(x;f):=\partial^{k}f(x)/\partial\theta^{\mu_{1}}\cdots\partial\theta^{\mu_{k}}$. For $k=0$ we define $T(x;f):=f(x)$, and still refer to this as a derivative tensor for consistency.

A correlation function is the expectation value of a product of derivative tensors, evaluated at arbitrary inputs, where the tensor indices are summed in pairs over all the model parameters. A general correlation function $C$ takes the form

Here, $0\leq k_{1}\leq\cdots\leq k_{m-1}\leq k_{m}$ are integers (when $k_{a}=k_{a-1}$, the tensor $T(x_{a})$ has no derivatives), $m$ and $k_{m}$ are even, $\pi\in S_{k_{m}}$ is a permutation, and $\Delta_{\mu_{1}\dots\mu_{k_{m}}}^{(\pi)}=\delta_{\mu_{\pi(1)}\mu_{\pi(2)}}\cdots\delta_{\mu_{\pi(k_{m}-1)}\mu_{\pi(k_{m})}}$. We use $\delta$ to denote the Kronecker delta.
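A general form consistent with these definitions is sketched below. It is a reconstruction from the surrounding text: the placement of the index ranges on each tensor and the symbol $\mathbb{E}_{\theta}$ for the ensemble average are assumptions.

```latex
C(x_{1},\dots,x_{m}) =
  \sum_{\mu_{1},\dots,\mu_{k_{m}}}
  \Delta^{(\pi)}_{\mu_{1}\dots\mu_{k_{m}}}\,
  \mathbb{E}_{\theta}\left[
    T_{\mu_{1}\dots\mu_{k_{1}}}(x_{1})\,
    T_{\mu_{k_{1}+1}\dots\mu_{k_{2}}}(x_{2})
    \cdots
    T_{\mu_{k_{m-1}+1}\dots\mu_{k_{m}}}(x_{m})
  \right].
```

For instance, taking $m=2$, $k_{1}=1$, $k_{2}=2$ and $\pi$ the identity gives $\sum_{\mu_{1},\mu_{2}}\delta_{\mu_{1}\mu_{2}}\,\mathbb{E}_{\theta}[T_{\mu_{1}}(x_{1})T_{\mu_{2}}(x_{2})]$, the ensemble average of the NTK.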

2.2 Asymptotic bounds on wide networks

We now present our main conjecture, which allows us to place asymptotic bounds on general correlation functions of wide networks.

We will refer to the connected components of a cluster graph $G_{C}$ as the clusters of $C$. Table 1 lists examples of bounds derived using the Conjecture for several correlation functions. The intuition behind Conjecture 1 comes from the following result for deep linear networks.

Conjecture 1 holds for correlation functions of networks with linear activations.

Let us discuss the intuition behind this theorem. Computing correlation functions of deep linear networks amounts to evaluating Gaussian integrals with polynomial integrands in $\theta$. One can evaluate such integrals using Isserlis’ theorem, which tells us how to express moments of multivariate Gaussian variables in terms of their second moments. For example, given centered Gaussian variables $z_{1},\dots,z_{4}$,
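For four centered Gaussian variables, Isserlis’ theorem gives $\mathbb{E}[z_{1}z_{2}z_{3}z_{4}]=\mathbb{E}[z_{1}z_{2}]\mathbb{E}[z_{3}z_{4}]+\mathbb{E}[z_{1}z_{3}]\mathbb{E}[z_{2}z_{4}]+\mathbb{E}[z_{1}z_{4}]\mathbb{E}[z_{2}z_{3}]$, a sum over the three ways of pairing the variables. A minimal numerical check is sketched below; the covariance matrix, the sample count, and the function name are illustrative choices and do not come from the paper.

```python
import numpy as np


def isserlis_fourth_moment(cov):
    """Return E[z1 z2 z3 z4] for centered Gaussians with 4x4 covariance `cov`."""
    return cov[0, 1] * cov[2, 3] + cov[0, 2] * cov[1, 3] + cov[0, 3] * cov[1, 2]


# Illustrative covariance: unit variances, all pairwise covariances equal to 0.5.
cov = 0.5 * np.eye(4) + 0.5 * np.ones((4, 4))

rng = np.random.default_rng(0)
z = rng.multivariate_normal(np.zeros(4), cov, size=400_000)

exact = isserlis_fourth_moment(cov)        # three pairings, each worth 0.5 * 0.5
monte_carlo = np.mean(np.prod(z, axis=1))  # sample average of z1 * z2 * z3 * z4

print(f"Isserlis: {exact:.3f}, Monte Carlo: {monte_carlo:.3f}")
```

The Monte Carlo average should agree with the pairing sum up to sampling error, which shrinks as the number of samples grows.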

Every correlation function of a deep linear network can be similarly reduced to sums over products of Kronecker delta functions and width-independent functions of the inputs. The asymptotic large width behavior is determined by these sums over delta functions, which are tedious to compute by hand. Feynman diagrams are a graphical tool for computing these sums, allowing us to obtain the asymptotic behavior with minimal effort. This tool, which is described in detail in Section 4, is used to prove Theorem 1.

For networks with non-linear activations we further show the following result.

Conjecture 1 holds for (1) networks with ReLU activations, where all inputs are set to be equal, and for (2) networks with one hidden layer and smooth activation.

For case (1), the idea behind the proof is to put an asymptotic bound on the ReLU network in terms of a corresponding deep linear network. For case (2), the basic idea is that each network function contains a single sum over the width, and by keeping track of these sums using Feynman diagrams we are able to bound the asymptotic behavior. We refer the reader to Appendix C for details.

2.3 Numerical experiments

Table 1 lists asymptotic bounds on several correlation functions, derived using Conjecture 1. These are compared against the asymptotic behavior computed using numerical experiments. In addition to the results presented here, we performed experiments using the same correlation functions and experimental setup, but with weights sampled uniformly from $\{\pm 1\}$ instead of from a Gaussian distribution. The results are shown in Appendix A.1. In all cases tested, we found that Conjecture 1 holds. In most cases, we find that the bound is tight. For cases where the bound is not tight, a tight bound can be obtained using the complete Feynman diagram analysis presented below.
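The exponent predicted by the Conjecture can be computed mechanically from the cluster graph. The sketch below assumes the form $s_{C}=n_{e}+\frac{n_{o}}{2}-\frac{m}{2}$ that appears in the proofs later in the paper, and assumes that the cluster graph has one vertex per derivative tensor and one edge per contracted index pair. The function name and the input encoding are hypothetical.

```python
from collections import Counter


def cluster_exponent(num_tensors, contractions):
    """Return (n_even, n_odd, s_C) for a cluster graph.

    num_tensors:  number m of derivative tensors (vertices 0..m-1).
    contractions: iterable of (i, j) pairs, one per contracted index pair.
    """
    parent = list(range(num_tensors))

    def find(a):
        # Union-find root lookup with path halving.
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    for i, j in contractions:
        parent[find(i)] = find(j)

    sizes = Counter(find(v) for v in range(num_tensors))
    n_even = sum(1 for size in sizes.values() if size % 2 == 0)
    n_odd = len(sizes) - n_even
    return n_even, n_odd, n_even + n_odd / 2 - num_tensors / 2


# E[f(x1) f(x2)]: two uncontracted tensors, two odd clusters.
print(cluster_exponent(2, []))                 # (0, 2, 0.0)
# Averaged NTK: two tensors joined by one contraction, one even cluster.
print(cluster_exponent(2, [(0, 1)]))           # (1, 0, 0.0)
# sum E[d_mu f  d_mu d_nu f  d_nu f  f]: clusters of size 3 and 1.
print(cluster_exponent(4, [(0, 1), (1, 2)]))   # (0, 2, -1.0)
```

Under these assumptions the first two examples are $\mathcal{O}(n^{0})$ and the third is $\mathcal{O}(n^{-1})$.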

Applications to training dynamics

Here, $\Theta$ is the Neural Tangent Kernel (NTK), defined by $\Theta(x_{1},x_{2}):=\nabla_{\theta}f^{T}(x_{1})\nabla_{\theta}f(x_{2})$. Prior work showed that the kernel is constant during training up to $\mathcal{O}(n^{-1/2})$ corrections. This leads to a dramatic simplification in training dynamics. In particular, for MSE loss the network map evaluated on the training data evolves as $f(t)=y+e^{-t\Theta^{(0)}}(f^{(0)}-y)$. Here we are using condensed notation: $\Theta^{(0)}$ and $f^{(0)}$ are values at initialization, and $f$, $f^{(0)}$, $y$ are treated as vectors in training set space. The kernel $\Theta^{(0)}$ is a square matrix in the same space. We will use our technology to derive a tighter bound on finite-width corrections to the kernel during training and present explicit formulas for the leading correction.
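The infinite-width evolution $f(t)=y+e^{-t\Theta^{(0)}}(f^{(0)}-y)$ is straightforward to evaluate numerically by diagonalizing the kernel. The sketch below assumes a symmetric positive semi-definite kernel matrix on the training set; the function name and the random test data are hypothetical stand-ins, not the experimental setup of the paper.

```python
import numpy as np


def linearized_outputs(theta0, f0, y, t):
    """Evaluate f(t) = y + exp(-t * theta0) @ (f0 - y) for a symmetric kernel."""
    evals, evecs = np.linalg.eigh(theta0)          # theta0 = evecs diag(evals) evecs^T
    decay = np.exp(-t * evals)                     # matrix exponential in the eigenbasis
    return y + evecs @ (decay * (evecs.T @ (f0 - y)))


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5))
theta0 = a @ a.T + np.eye(5)   # hypothetical positive-definite kernel at initialization
f0 = rng.normal(size=5)        # network outputs at initialization
y = rng.normal(size=5)         # training labels

for t in (0.0, 0.1, 1.0, 50.0):
    residual = np.linalg.norm(linearized_outputs(theta0, f0, y, t) - y)
    print(f"t = {t:5.1f}, |f(t) - y| = {residual:.3e}")
```

Each eigenmode of the kernel decays independently at a rate set by its eigenvalue, so the residual starts at $|f^{(0)}-y|$ and shrinks to zero when the kernel is positive definite.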

The following result is useful in analyzing the behavior of correlation functions under gradient flow.

Here we prove the statement for $k=1$. Appendix D.2 contains a proof for the general case.

Let $C$ have $n_{e}$ even clusters and $n_{o}$ odd clusters. Consider the correlation function

Denote by $n_{e}^{\prime}$ ($n_{o}^{\prime}$) the number of even (odd) clusters in this correlation function, which has $m^{\prime}=m+2$ derivative tensors. One can check that either $(n_{e}^{\prime},n_{o}^{\prime})=(n_{e}+1,n_{o})$ or $(n_{e}^{\prime},n_{o}^{\prime})=(n_{e}-1,n_{o}+2)$, depending on whether the $\partial_{\mu}F$ derivative is acting on an odd or even cluster in $F$ (here we are extending the use of the term cluster to refer to derivative tensors in the integrand itself). Therefore, $n^{\prime}_{e}+\frac{n^{\prime}_{o}}{2}-\frac{m^{\prime}}{2}\leq s_{C}$. ∎
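The final inequality can be checked case by case. The lines below assume $s_{C}=n_{e}+\frac{n_{o}}{2}-\frac{m}{2}$, the form of the exponent used in the proofs of Appendix B, and use $m^{\prime}=m+2$.

```latex
\begin{aligned}
(n_{e}^{\prime},n_{o}^{\prime})=(n_{e}+1,\,n_{o}):\quad
  & n_{e}^{\prime}+\tfrac{n_{o}^{\prime}}{2}-\tfrac{m^{\prime}}{2}
    = n_{e}+1+\tfrac{n_{o}}{2}-\tfrac{m+2}{2} = s_{C}, \\
(n_{e}^{\prime},n_{o}^{\prime})=(n_{e}-1,\,n_{o}+2):\quad
  & n_{e}^{\prime}+\tfrac{n_{o}^{\prime}}{2}-\tfrac{m^{\prime}}{2}
    = n_{e}-1+\tfrac{n_{o}+2}{2}-\tfrac{m+2}{2} = s_{C}-1 .
\end{aligned}
```

In both cases the exponent does not exceed $s_{C}$.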

With this result, it is easy to understand the constancy of the NTK at large width. The first derivative of the NTK is given by

Next, we will compute the explicit time dependence of $\Theta$ and $f$ at order $\mathcal{O}(n^{-1})$ under gradient flow. This is the leading correction to the infinite width result. We define the functions $O_{1}(x):=f(x)$ and

Notice that $O_{2}=\Theta$ is the kernel. It is easy to check that

Let us denote the time-evolved kernel by $\Theta(t)=\Theta^{(0)}+\Theta_{1}(t)+\mathcal{O}(n^{-2})$, where $\Theta^{(0)}$ is the kernel at initialization, and $\Theta_{1}(t)$ is the $\mathcal{O}(n^{-1})$ correction we are seeking. Integrating equations (11) starting with $s=4$, we find

Here we have introduced the notation $\Delta f(x;t)=e^{-t\Theta_{0}}(f^{(0)}-y)$. A detailed derivation can be found in Appendix E.4. There we also evaluate the integrals in (3.1) in terms of the NTK spectrum.

To obtain the $\mathcal{O}(n^{-1})$ correction to the network map (evaluated for simplicity on the training data), we further integrate (11) for $s=1$ and find

Here we have denoted the infinite width evolution by $f_{0}(t)=y+e^{-t\Theta^{(0)}}(f^{(0)}-y)$. Figures 1(b) and 1(c) compare these predictions against empirical results.

Feynman diagrams for deep linear networks

In this section we present the Feynman diagram technique, and show how it allows us to compute the asymptotic behavior of correlation functions. We end this section with a proof of Theorem 1 for the case of networks with a single hidden layer and linear activations.

Given a correlation function $C$, we map it to a family of graphs called Feynman diagrams. The graphs are independent of the inputs, and are defined as follows.

Let $C(x_{1},\dots,x_{m})$ be a correlation function for a network with $d$ hidden layers. The family $\Gamma(C)$ is the set of all graphs that have the following properties.

There are $m$ vertices $v_{1},\dots,v_{m}$, each of degree $d+1$.

Each edge has a type $t\in\{U,W^{1},\dots,W^{d-1},V\}$. Every vertex has one edge of each type.

If two derivative tensors $T_{\mu_{1},\dots,\mu_{k}}(x_{i}),T_{\nu_{1},\dots,\nu_{k^{\prime}}}(x_{j})$ are contracted $k$ times in $C$, the graph must have at least $k$ edges (of any type) connecting the vertices $v_{i},v_{j}$.

The graphs in $\Gamma(C)$ are called the Feynman diagrams of $C$.

As the factors of $x_{1}$ and $x_{2}$ are independent of $n$, we see that $C(x_{1},x_{2})=\mathcal{O}(n^{0})$. Notice that there are two relevant contributions to this answer: each factor of the network function in the integrand contributes $n^{-1/2}$, and the summed-over product of Kronecker deltas contributes $n$. Other details, such as the input dependence, are irrelevant. Feynman diagrams allow us to encode only those details that affect the $n$ scaling, ignoring the rest.
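This width independence can be checked by direct sampling. The sketch below assumes a single-hidden-layer linear network $f(x)=n^{-1/2}V^{T}Ux$ with i.i.d. standard Gaussian entries in $U$ and $V$, for which the two-point function equals $x_{1}^{T}x_{2}$ at every width. The inputs, widths, sample count, and function name are illustrative choices.

```python
import numpy as np


def two_point_estimate(x1, x2, width, num_samples, rng):
    """Monte Carlo estimate of E[f(x1) f(x2)] for f(x) = V^T U x / sqrt(width)."""
    u = rng.normal(size=(num_samples, width, x1.shape[0]))  # first-layer weights
    v = rng.normal(size=(num_samples, width))               # readout weights
    f1 = np.sum(v * (u @ x1), axis=1) / np.sqrt(width)
    f2 = np.sum(v * (u @ x2), axis=1) / np.sqrt(width)
    return float(np.mean(f1 * f2))


rng = np.random.default_rng(0)
x1 = np.array([1.0, 0.0, 0.0])
x2 = np.array([0.6, 0.8, 0.0])

estimates = {n: two_point_estimate(x1, x2, n, 20_000, rng) for n in (4, 16, 64)}
for n, value in estimates.items():
    print(f"width {n:3d}: E[f(x1) f(x2)] ~ {value:.3f}  (x1 . x2 = {x1 @ x2:.1f})")
```

The estimates should stay close to $x_{1}^{T}x_{2}=0.6$ as the width changes, reflecting the cancellation between the two $n^{-1/2}$ normalization factors and the single index sum.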

The set $\Gamma(C)$ for the correlation function (14) consists of a single Feynman diagram, shown in Figure 2(a).

The asymptotic bound on a correlation function is obtained by the following result, which is due to .

Let $C(x_{1},\dots,x_{m})$ be a correlation function with one hidden layer and linear activation. Then $C=\mathcal{O}(n^{s})$ where $s=\max_{\gamma\in\Gamma(C)}l_{\gamma}-\frac{m}{2}$, and $l_{\gamma}$ is the number of loops in $\gamma$.

We are now ready to prove Theorem 1 for the case of single hidden layer with linear activations. A proof for the general case can be found in Appendix B.

Discussion

Ensemble averages of the network function and its derivatives are an important class of functions that often show up in the study of wide neural networks. Examples include the ensemble average of the train and test losses, the covariance of the network function, and the Neural Tangent Kernel. In this work we presented Conjecture 1, which allows one to derive the asymptotic behavior of such functions at large width.

For the case of deep linear networks, we presented a complete analytic understanding of the Conjecture based on Feynman diagrams. In addition, we presented empirical and analytic evidence showing that the Conjecture also holds for deep networks with non-linear activations, as well as for networks with non-Gaussian initialization. We found that the Conjecture holds in all cases we tested.

The basic tools presented in this work can be applied to many aspects of wide network research, greatly simplifying theoretical calculations. We presented several applications of our method to the asymptotic behavior of wide networks during stochastic gradient descent, and additional applications are presented in Appendix E. We were able to improve upon known results by tightening existing bounds, and by applying the technique to SGD as well as to gradient flow. In addition, we took a step beyond the infinite width limit, deriving closed-form expressions for the first finite-width correction to the network evolution. These novel results open the door to studying finite-width networks by systematically expanding around the infinite width limit.

A central question in the study of wide networks is whether the infinite width limit is a good model for describing the behavior of realistic deep networks. In this work we take a step toward answering this question, by working out the next order in a perturbative expansion around the infinite width limit, potentially bringing us closer to an analytic description of finite-width networks. We hope that the techniques presented here provide a basis for systematically answering these and other questions about the behavior of wide networks.

Acknowledgements

The authors would like to thank Alex Alemi, Yasaman Bahri, Boris Hanin, Jared Kaplan, Jaehoon Lee, Behnam Neyshabur, Sam Schoenholz, Sylvia Smullin, and Jascha Sohl-Dickstein for useful discussion. The authors would especially like to thank Ying Xiao for extensive comments on early versions of this manuscript.

Appendix A Experimental details and additional results

In Table 1 we listed asymptotic bounds on several correlation functions, where the model parameters were initialized from a Gaussian distribution. Table 2 shows additional results using the same experimental setup, but with weights sampled uniformly from $\{\pm 1\}$. We again find good agreement with the predictions of Conjecture 1.

A.2 Experimental details

The experiments in Figure 1 were performed on two-class MNIST, computing a single randomly-chosen component of the kernel $\Theta$. Sub-figure (a) uses networks trained for 1024 steps with learning rate 1.0 and 1000 samples per class, averaged over 100 initializations. Each curve in sub-figure (b) represents a single instance of the network map evaluated on a random image over the course of training. The models were trained with 10 samples per class and learning rate 0.1. The input to the network is normalized by the square root of the input dimension as in

Appendix B Feynman diagrams for deep linear networks

Feynman diagrams can be used to derive asymptotic upper bounds on deep linear networks in the large width limit. In this section we describe the method in detail, and use it to prove Theorem 1.

In this section we build on the results of Section 4 and consider correlation functions of deep linear networks with $d$ hidden layers. The network function was defined in (2), and here we set the activation $\sigma$ to be the identity. Definition 2 describes how to map a correlation function $C$ to $\Gamma(C)$, a family of graphs called Feynman diagrams. The Feynman diagram method relies on Isserlis’ theorem, which allows us to express arbitrary moments of multivariate Gaussian variables in terms of their covariance.

Let $z=(z_{1},\dots,z_{l})$ be a centered multivariate Gaussian variable. For any positive integer $k$,

In particular, if the covariance matrix of $z$ is the identity then

Using this theorem, a correlation function $C$ can be expressed as a sum over permutations as in (17). Each term in this sum maps to a Feynman diagram in $\Gamma(C)$.

A similar question appeared in the context of theoretical physics, and the correct generalization is due to . The idea is to treat each Feynman diagram as the triangulation of a Riemann surface, and to define the number of loops in a graph to be the number of faces of the triangulation. In practice, this involves mapping each Feynman diagram to a new double-line diagram: A graph in which each edge corresponds to a single Kronecker delta factor, and loops correspond to triangulation faces of the original diagram.

Each edge $(v^{(1)},v^{(2)})$ in $\gamma$ of type $W^{l}$ is mapped to two edges $(v^{(1)}_{l},v^{(2)}_{l})$, $(v^{(1)}_{l+1},v^{(2)}_{l+1})$.

Each edge $(v^{(1)},v^{(2)})$ in $\gamma$ of type $U$ is mapped to a single edge $(v^{(1)}_{1},v^{(2)}_{1})$.

Each edge $(v^{(1)},v^{(2)})$ in $\gamma$ of type $V$ is mapped to a single edge $(v^{(1)}_{d},v^{(2)}_{d})$.

The discussion above is summarized by the following result, due to , that describes how the asymptotic behavior of a general correlation function can be computed using the Feynman rules for deep linear networks.

The intuition for the formula $s_{\gamma}=l_{\gamma}-\frac{dm}{2}$ is similar to the single hidden layer case, Theorem 3. The term $l_{\gamma}$ counts the number of factors of the form $\sum_{i_{1},\dots,i_{k}}\delta_{i_{1}i_{2}}\delta_{i_{2}i_{3}}\cdots\delta_{i_{k}i_{1}}=n$ that appear in the correlation function after applying Isserlis’ theorem. The term $\left(-\frac{dm}{2}\right)$ is due to the explicit $n^{-d/2}$ normalization of the network function.

B.2 Asymptotics of deep linear networks

We now prove Theorem 1. The theorem follows from the following lemma, again due to , that relates the asymptotic behavior to the number of connected components in a Feynman diagram.

Let $C(x_{1},\dots,x_{m})$ be a correlation function for a deep linear network. Let $c_{\gamma}$ be the number of connected components of a graph $\gamma\in\Gamma(C)$. Then $C=\mathcal{O}(n^{s})$, where

Let us prove the remaining statement about $s_{\gamma^{\prime}}$. The Euler character of the graph $\gamma^{\prime}$ with $v=v_{\gamma^{\prime}}$ vertices, $e$ edges and $f$ faces is $\chi=v-e+f$. The degree of each vertex in the graph is $d+1$, and therefore $e=\frac{(d+1)v}{2}$. Using Theorem 5 the graph is $\mathcal{O}(n^{s_{\gamma^{\prime}}})$ where $s_{\gamma^{\prime}}=f-\frac{dv}{2}$. We therefore find that $\chi=\frac{v}{2}+s_{\gamma^{\prime}}$. The graph $\gamma^{\prime}$ is a triangulation of some connected surface with at least one boundary. The Euler character for such a surface is bounded by $\chi\leq 1$, and therefore $s_{\gamma^{\prime}}\leq 1-\frac{v}{2}$. ∎
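Written out step by step, the substitution used in this argument is as follows. It uses only the relations $e=\frac{(d+1)v}{2}$ and $s_{\gamma^{\prime}}=f-\frac{dv}{2}$ stated in the proof.

```latex
\chi = v - e + f
     = v - \frac{(d+1)\,v}{2} + f
     = \frac{v}{2} + \left( f - \frac{d\,v}{2} \right)
     = \frac{v}{2} + s_{\gamma^{\prime}},
\qquad
\chi \leq 1 \;\Longrightarrow\; s_{\gamma^{\prime}} \leq 1 - \frac{v}{2}.
```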

Let $C(x_{1},\dots,x_{m})$ be a correlation function for a deep linear network. Suppose that the cluster graph $G_{C}$ has $n_{e}$ even size components and $n_{o}$ odd size components. Let $\gamma\in\Gamma(C)$ be a Feynman diagram with $c_{\gamma}$ connected components. We will show that $c_{\gamma}\leq n_{e}+\frac{n_{o}}{2}$. It then follows immediately from Lemma 2 that $C=\mathcal{O}(n^{s})$ where $s=n_{e}+\frac{n_{o}}{2}-\frac{m}{2}$, concluding the proof.

Let us derive the bound on $c_{\gamma}$. First, all vertices that belong to a given cluster (a component of $G_{C}$) will also belong to the same connected component in $\gamma$. This is because every edge in $G_{C}$ is also an edge in $\gamma$ (note that $G_{C}$ and $\gamma$ have the same set of vertices). Therefore, $c_{\gamma}\leq n_{e}+n_{o}$. Second, note that every connected component of the graph $\gamma$ has an even number of vertices. Indeed, each edge has a type $t$, and each vertex has exactly one edge of each type. Therefore, a connected component with $v$ vertices has $\frac{v}{2}$ edges of each type, and so $v$ must be even. It follows that the vertices of even clusters can form their own connected components in a Feynman diagram, while odd clusters must be connected in sets of 2 or more to form connected components. The bound on $c_{\gamma}$ then follows. ∎

Appendix C Non-Linearities

In previous sections we presented the Feynman diagram method for computing the large width asymptotics of correlation functions. In this section we show that the method applies as-is for deep networks with ReLU non-linearities and all-equal inputs, as well as to networks with a single hidden layer, a broader class of non-linearities, and arbitrary inputs. Theorem 2 follows immediately from the results presented in this section.

The following result guarantees that the presence of ReLU non-linearities does not change the asymptotic upper bound compared with linear activations, when all inputs are the same.

Intuitively, we will rely on the fact that for ReLU networks we can, in some cases, treat the binary neuron activations as being statistically independent of the weights. This result is due to . Given this result, we can bound the contribution of the binary activations, and the remaining Gaussian integral is equivalent to that found in a deep linear network. The proof does not work for correlation functions with non-equal inputs, because in that case the independence result no longer holds.

Here, $H$ is the Heaviside step function acting elementwise on its vector argument. We now introduce the construction from . Let $\xi^{j},\eta^{j}$, $j=1,\dots,d$ be diagonal matrices of dimension $n$, whose diagonal elements are $\pm 1$-Bernoulli($p$) variables with $p=\frac{1}{2}$. We define the new variables

$\{\hat{D}^{j}(x),\,\rho^{j},\,j=1,\dots,d\}$ are independent of $\{U,V,W^{1},\dots,W^{d-1}\}$.

$\{\hat{D}^{j}(x),\,j=1,\dots,d\}$ are independent of each other for fixed $x$. The diagonal entries of each diagonal matrix $\hat{D}^{j}(x)$ are independent, and take the values $\{0,1\}$ with probability $\frac{1}{2}$.

Now, the correlation function can be written as

Here, $\vec{i}=\{i_{1,1},\dots,i_{m,d}\}$ and $\vec{\alpha}=\{\alpha_{1},\dots,\alpha_{m}\}$, and

In writing equation (34) we used the facts that $\hat{D},\rho$ are independent of the parameters, and that parameters in different layers are independent. We now use Theorem 4 (Isserlis), which says that each of the expectation values over products of $V$, $W^{j}$, and $U$ elements is equal to a sum over permutations, where each term is a product over Kronecker delta functions, which are the covariance matrices of the parameters.

We can now bound the correlation function as follows.

C.2 Single hidden layer networks

For networks with a single hidden layer, defined by

we can extend our asymptotic analysis to smooth non-linearities. We will show in Theorem 7 that for any correlation function $C$, we have

where $f$ is a deep linear network of equal width and sufficient depth. Therefore, computing the asymptotics using Feynman diagrams for deep linear networks yields a bound on networks with a single hidden layer and smooth non-linearities.

In the last line, we again used the fact that the summands are equal and independent of $n$.

Our approach to establishing the asymptotic scaling will be to first bound the maximum number of index sums appearing in any term in our correlation function, written in the form (48), and then to argue that the summands are bounded by an $n$-independent constant.

Let us introduce a family of diagrams, $\Gamma^{\prime}(C)$, which are different in general than the Feynman diagrams. A given diagram $g\in\Gamma^{\prime}(C)$ is constructed as follows.

Each derivative tensor in $C$ is mapped to a vertex in the graph.

Each edge has a type that corresponds to one of the weight vectors $U$ or $V$.

Each vertex has exactly one edge of $V$ type.

If two derivative tensors are contracted in $C$, the graph must have at least one edge (of any type) connecting the corresponding vertices for each contraction.

Let $n_{e}$ ($n_{o}$) be the number of even (odd) clusters in the cluster graph $G_{C}$ of $C(x_{1},\dots,x_{m};f_{d})$. The cluster graph $G_{C}$ is a subgraph of any graph $g\in\Gamma^{\prime}(C)$. We can thus think about the embedding of even and odd clusters into $g$. In any graph $g\in\Gamma^{\prime}(C)$, an even cluster may belong to its own connected component, while for odd clusters there must be an even number of them in any connected component. This is because an even (odd) cluster contains an even (odd) number of factors of $V$, which must be paired up in any connected component. We find that

The last inequality was used below Lemma 2 in the proof of Theorem 1. ∎

The correlation function in (48) can be bounded as

For fixed inputs, $\mathcal{S}_{i_{1},\ldots,i_{r_{\alpha}}}^{(\alpha)}$ can take at most $\mathcal{O}(1)$ different values as a function of its indices, and the values are independent of $n$. This is because the variables $U_{i}$ are identical. We define $s_{\textrm{max}}$ as the maximum value of $|\mathcal{S}_{i_{1},\ldots,i_{r_{\alpha}}}^{(\alpha)}|$ as a function of $\alpha$ and the indices. Combining this with the above lemmas we can then write

The result of the theorem follows from Lemma 2. ∎

Appendix D Correlation function asymptotics

In this section we prove several general results about correlation function asymptotics in the large width limit. Throughout this section, we assume that Conjecture 1 holds.

Conjecture 1 can be used to bound the variance of the integrands that appear inside correlation functions.

To bound the variance, it is enough to bound the correlation function

As a corollary, notice that if $C=\mathcal{O}(n^{s})$ according to Conjecture 1, then typical realizations of the integrand will also be $\mathcal{O}(n^{s})$. In other words, the asymptotic bound of Conjecture 1 holds for typical initializations, not just in expectation.

D.2 Gradient Flow

The following results are used in the gradient flow calculations of Section 3.

Let $G^{\prime}$ be a graph with $m^{\prime}$ vertices, $n^{\prime}_{e}$ even components, and $n^{\prime}_{o}$ odd components. Let $G$ be a subgraph of $G^{\prime}$ with $m$ vertices, $n_{e}$ even components, and $n_{o}$ odd components. Then $s(n_{e},n_{o},m)\geq s(n^{\prime}_{e},n^{\prime}_{o},m^{\prime})$ where $s(a,b,c):=a+\frac{b-c}{2}$.

It is enough to show that $s(n_{e},n_{o},m)$ does not increase if we (1) add a vertex to $G$, or (2) add an edge to $G$, because $G^{\prime}$ can be obtained from $G$ by performing such operations finitely many times. Adding a vertex to $G$ changes $n_{e}\mapsto n_{e}$, $n_{o}\mapsto n_{o}+1$, and $m\mapsto m+1$, leaving $s(n_{e},n_{o},m)$ unchanged. Next, if we add an edge to $G$ then $m$ does not change, and there are 4 possibilities for how $n_{e}$ and $n_{o}$ change.

The edge connects two even components. Then $n_{e}\mapsto n_{e}-1$, $n_{o}\mapsto n_{o}$, and $s(n_{e},n_{o},m)$ decreases by 1.

The edge connects two odd components. Then $n_{e}\mapsto n_{e}+1$, $n_{o}\mapsto n_{o}-2$, and $s(n_{e},n_{o},m)$ does not change.

The edge connects an even component and an odd component. Then $n_{e}\mapsto n_{e}-1$, $n_{o}\mapsto n_{o}$, and $s(n_{e},n_{o},m)$ decreases by 1.

The edge connects two vertices that belong to the same component. In this case $n_{e}$, $n_{o}$, and $s(n_{e},n_{o},m)$ do not change.
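The monotonicity of $s$ under taking subgraphs can also be checked by brute force on small random graphs. The sketch below is only a sanity check of the case analysis above, not a substitute for it; the random-graph generator, its parameters, and the function names are hypothetical.

```python
import random
from collections import Counter


def exponent(num_vertices, edges):
    """Return s(n_e, n_o, m) = n_e + (n_o - m) / 2 for a graph on vertices 0..m-1."""
    parent = list(range(num_vertices))

    def find(a):
        # Union-find root lookup with path halving.
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    for i, j in edges:
        parent[find(i)] = find(j)

    sizes = Counter(find(v) for v in range(num_vertices))
    n_even = sum(1 for size in sizes.values() if size % 2 == 0)
    n_odd = len(sizes) - n_even
    return n_even + (n_odd - num_vertices) / 2


def subgraph_bound_holds(trials=2000, max_vertices=8, seed=0):
    """Check s(G) >= s(G') for random graphs G' and random subgraphs G."""
    rng = random.Random(seed)
    for _ in range(trials):
        m_big = rng.randint(1, max_vertices)
        big_edges = [(i, j) for i in range(m_big) for j in range(i + 1, m_big)
                     if rng.random() < 0.3]
        # Subgraph: keep vertices 0..m_small-1 and a random subset of their edges.
        m_small = rng.randint(1, m_big)
        small_edges = [(i, j) for (i, j) in big_edges
                       if j < m_small and rng.random() < 0.5]
        if exponent(m_small, small_edges) < exponent(m_big, big_edges):
            return False
    return True


print(subgraph_bound_holds())
```

Because every edge is stored with $i<j$, the condition `j < m_small` keeps exactly the edges whose endpoints both survive in the subgraph.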

We now prove Lemma 1 giving the scaling of time derivatives of correlation functions at initialization. We prove the result for polynomial loss functions. Extension to more general loss functions requires interchanging the large width limit and the Taylor expansion of the loss, which we do not discuss.

D.3 Stochastic Gradient descent

Let $C$ be a correlation function for a network with linear activations, and assume that Conjecture 1 holds, namely $C=\mathcal{O}(n^{s_{C}})$ where $s_{C}$ is defined in the Conjecture. If $C_{t}$ is the evolved correlation function after $t$ SGD steps, then $C_{t}=\mathcal{O}(n^{s_{C}})$.

The integrand $F_{\theta}$ can be written as a product of derivative tensors of the form $T_{\mu_{1}\dots\mu_{k}}(x;\theta)$, with contracted derivatives. Suppose that the network has $d$ hidden layers. Then under an SGD step, we have

Here $D_{B}$ is the mini-batch, and $x^{\prime}_{a}$ are mini-batch samples. The $k$ sum is truncated because higher-order derivatives of $f$ vanish.

We can now see how taking a gradient descent step affects the cluster graph. After taking a step, each derivative tensor in $C_{t}$ is replaced by a sum (over $k,x^{\prime}_{1},\dots,x^{\prime}_{k}$). Each term in the combination of these sums is a correlation function whose cluster graph contains the cluster graph of $C$ as a subgraph. Therefore, by Lemma 6, $C_{t}=\mathcal{O}(n^{s_{C}})$. ∎

We note that for general activation functions the sum in (60) may be infinite. In this case, to complete the proof we would need to show that the infinite sum obeys the same bound as each individual term in the sum. We leave this for future work.

Appendix E Applications

Here we present several applications of our Feynman diagram method for computing large width asymptotics. We assume throughout this section that Conjecture 1 holds.

In this section we prove two results regarding the NTK at large width. We show that the kernel converges in probability, and that during gradient descent it is constant up to $\mathcal{O}(n^{-1})$ corrections.

The Neural Tangent Kernel $\Theta$ of a deep linear network converges in probability in the large width limit, and its variance is $\mathcal{O}(n^{-1})$.

Conjecture 1 is not sufficient for proving this theorem, as we need to use a more detailed Feynman diagram argument. Therefore, we only prove the theorem for the case of deep linear networks.

$A(x,x^{\prime})$ includes all diagrams in which the vertices corresponding to $f^{(1)},f^{(2)}$ share an edge, and also the vertices $f^{(3)},f^{(4)}$ share an edge (due to the explicit derivatives);

$B(x,x^{\prime})$ includes all diagrams in which all edges are either between $f^{(1)},f^{(2)}$ or between $f^{(3)},f^{(4)}$.

Therefore, in the full expression (61), the only remaining diagrams (i.e. the diagrams that do not cancel between the two terms) are those that include

an edge connecting $f^{(3)},f^{(4)}$, and

an edge connecting one of $f^{(1)},f^{(2)}$ to one of $f^{(3)},f^{(4)}$.
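The $\mathcal{O}(n^{-1})$ variance scaling stated in the theorem can be checked numerically in the simplest case. The following is a minimal sketch, not the setup used in the paper: it assumes a one-hidden-layer linear network $f(x)=v\cdot(Wx)/\sqrt{n}$ with i.i.d. standard Gaussian weights (a parameterization chosen here for illustration), for which $\Theta(x,x^{\prime})=\left[(Wx)\cdot(Wx^{\prime})+|v|^{2}\,x\cdot x^{\prime}\right]/n$. The function names and test inputs are hypothetical.

```python
import numpy as np

def ntk_samples(n, x, xp, num_draws, rng):
    """Sample Theta(x, x') for f(x) = v . (W x) / sqrt(n) at random initialization."""
    d = x.shape[0]
    W = rng.standard_normal((num_draws, n, d))  # hidden-layer weights, one set per draw
    v = rng.standard_normal((num_draws, n))     # readout weights, one set per draw
    hx, hxp = W @ x, W @ xp                     # hidden activations, shape (num_draws, n)
    # Theta = sum over parameters of df/dtheta(x) * df/dtheta(x')
    #       = [ (W x).(W x') + |v|^2 (x . x') ] / n
    return (np.sum(hx * hxp, axis=1) + np.sum(v * v, axis=1) * (x @ xp)) / n

def scaled_variances(widths=(32, 128, 512), num_draws=2000, seed=0):
    """Return n * Var[Theta] for each width; roughly constant if Var[Theta] ~ 1/n."""
    rng = np.random.default_rng(seed)
    x = np.array([1.0, 0.5, -0.5])    # arbitrary illustrative inputs
    xp = np.array([0.2, -1.0, 0.7])
    return {n: n * ntk_samples(n, x, xp, num_draws, rng).var() for n in widths}

if __name__ == "__main__":
    for n, value in scaled_variances().items():
        print(f"n = {n:4d}   n * Var[Theta] = {value:.3f}")
```

If the variance is $\mathcal{O}(n^{-1})$, the printed values of $n\cdot\mathrm{Var}[\Theta]$ should be approximately independent of the width, up to sampling noise.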

Next, we show that the large width NTK is constant during training, and compute the asymptotics of the higher-order terms. The following argument is phrased for deep linear networks. More generally, the same argument holds for deep networks with smooth non-linear activations under the additional assumption that the large width limit and Taylor series can be exchanged (note that for such networks, the network function is analytic in the weights).

Let $f(x)$ be the network output of a deep linear network with MSE loss $L$. Let $\Theta_{t}(x,x^{\prime})$ be the Neural Tangent Kernel at SGD step $t$, for some inputs $x,x^{\prime}$. Then in the large width limit, the kernel is constant in $t$ in expectation, and

The fact that the $k,l$ sums are truncated at $d+1$ follows from using linear activations, as higher-order derivatives of the network function vanish in this case. All terms in the sum over $k,l$ include a tensor product of the form $\sum_{\mu,\vec{\mu},\vec{\nu}}T_{\mu\mu_{1}\dots\mu_{k}}(x)T_{\mu\nu_{1}\dots\nu_{l}}(x^{\prime})(\cdots)$ with either $k\geq 1$ or $l\geq 1$, where $(\cdots)$ stands for additional derivative tensor factors. Therefore, all terms in the $k,l$ sum are correlation functions that have a cluster of size at least 3, including $T(x)$, $T(x^{\prime})$, and at least one other tensor contracted through the $\mu_{1}$ or $\nu_{1}$ index. It follows from the Conjecture that each term in the sum is $\mathcal{O}(n^{-1})$. Lemma 5 then implies that the variance of these updates is $\mathcal{O}(n^{-2})$. ∎

E.2 Wide network evolution is linear in the learning rate

In this section we prove that, at large width, the NTK determines the evolution of the network function not just for continuous-time gradient descent but also for discrete-time gradient descent. A similar result holds for stochastic gradient descent, using a stochastic kernel. The perspective presented here helps explain earlier empirical observations that linearized evolution is a good description of wide networks even for relatively large learning rates. Again we prove the deep linear case explicitly, but the result holds for deep networks with smooth non-linear activations under the additional assumption that the large width limit and Taylor series can be exchanged.

Let $f(x)$ be the network output of a deep linear network, and let $f_{t}(x)$ be the evolved function after $t$ gradient descent steps, defined by $\theta_{t+1}=\theta_{t}-\eta\nabla L(\theta_{t})$. In the large width limit, each gradient descent step update of $f_{t}$ is linear in the learning rate $\eta$. Furthermore,

For a deep linear network, under a single gradient descent step we have

First, consider the sum over $k$ in (72). Each term in the sum is a correlation function for which the cluster graph contains a connected subgraph of size at least 3, and is therefore $\mathcal{O}(n^{-1})$ by Lemma 6 and Theorem 1. In the remaining $\mathcal{O}(\eta)$ term, by the same argument as Theorem 10 we can replace $\Theta_{t}=\Theta_{0}+\mathcal{O}(n^{-1})$ in the correlation function. The result is equation (72). ∎

As mentioned above, we note that the proof goes through when using stochastic gradient descent updates, with the difference that in (72) we should sum over mini-batch samples instead of over the entire training set.
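The suppression of the non-linear terms in $\eta$ can be made explicit in a toy case. The sketch below assumes a one-hidden-layer linear network $f(x)=v\cdot(Wx)/\sqrt{n}$, a single training point, and the loss $L=\frac{1}{2}(f(x)-y)^{2}$; these choices and all names are illustrative assumptions rather than the paper's setup. For this model a direct expansion of one gradient descent step gives $f_{1}=f_{0}-\eta\,\Theta\,(f_{0}-y)+\eta^{2}(f_{0}-y)^{2}|x|^{2}f_{0}/n$, so the deviation from the linear-in-$\eta$ update carries an explicit factor of $1/n$.

```python
import numpy as np

def gd_step_residual(n, eta=0.5, seed=0):
    """One gradient descent step for f(x) = v . (W x) / sqrt(n) with L = (f - y)^2 / 2.

    Returns (residual, predicted), where residual is the difference between the
    exact updated output and the linear-in-eta NTK prediction, and predicted is
    the closed-form eta^2 term derived for this toy model.
    """
    rng = np.random.default_rng(seed)
    d = 4
    x = rng.standard_normal(d)   # single training input (illustrative)
    y = 1.0                      # its label (illustrative)
    W = rng.standard_normal((n, d))
    v = rng.standard_normal(n)

    def f(W_, v_):
        return v_ @ (W_ @ x) / np.sqrt(n)

    f0 = f(W, v)
    delta = f0 - y                                   # dL/df
    grad_v = delta * (W @ x) / np.sqrt(n)            # dL/dv
    grad_W = delta * np.outer(v, x) / np.sqrt(n)     # dL/dW
    theta = ((W @ x) @ (W @ x) + (v @ v) * (x @ x)) / n  # NTK Theta(x, x)

    f1 = f(W - eta * grad_W, v - eta * grad_v)       # exact output after the step
    linear_prediction = f0 - eta * theta * delta     # update linear in eta
    predicted = eta**2 * delta**2 * (x @ x) * f0 / n # closed-form quadratic term
    return f1 - linear_prediction, predicted

if __name__ == "__main__":
    for n in (100, 1000, 10000):
        residual, predicted = gd_step_residual(n)
        print(n, residual, predicted)
```

Since $f_{0}$, $f_{0}-y$ and $|x|^{2}$ are all order one at initialization, the residual decreases like $1/n$ as the width grows, in line with the statement above.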

E.3 Spectral properties of the NTK and the Hessian

With an eye towards understanding the structure of the loss landscape at large width and as another example use case of our approach, we investigate the relation between the spectra of the Hessian, and the NTK.

The Hessian of a general loss takes the form,

Thus most moments of the Hessian are equal to moments of $\mathcal{A}$ in expectation. For the first moment in linear or ReLU networks, $\operatorname{Tr}{\mathcal{B}}=0$.

What’s more, we can relate moments of $\mathcal{A}$ to those of the kernel, as both are built out of two logit derivatives. Explicitly,

and so the moments of the Hessian are also related to those of the kernel. Here we have argued for relations between the means of moments. It is not too difficult to see that these relations also hold for typical realizations; this follows from Lemma 5.
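The relation between moments of $\mathcal{A}$ and moments of the kernel is an instance of the cyclicity of the trace, which a few lines of code can illustrate. The sketch assumes that $\mathcal{A}=J^{T}MJ$ and $\Theta=JJ^{T}$, where $J_{a\mu}=\partial f(x_{a})/\partial\theta^{\mu}$ is the Jacobian over a small dataset and $M$ is the matrix of second derivatives of the loss with respect to the outputs. The normalization, the random stand-in values for $J$ and $M$, and the variable names are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, num_params = 5, 40

# Stand-in Jacobian J[a, mu] = d f(x_a) / d theta_mu (random values, for illustration only).
J = rng.standard_normal((num_samples, num_params)) / np.sqrt(num_params)
# Stand-in for M(x_a, x_b) = d^2 L / d f(x_a) d f(x_b); equals the identity for MSE loss.
M = np.diag(rng.uniform(0.5, 1.5, size=num_samples))

A = J.T @ M @ J    # parameter-space matrix, shape (num_params, num_params)
Theta = J @ J.T    # kernel, shape (num_samples, num_samples)

def trace_power(mat, m):
    """Tr(mat^m), the m-th spectral moment up to normalization."""
    return np.trace(np.linalg.matrix_power(mat, m))

# Tr(A^m) = Tr((J^T M J)^m) = Tr((M J J^T)^m) = Tr((M Theta)^m) by cyclicity of the trace.
moments = [(trace_power(A, m), trace_power(M @ Theta, m)) for m in range(1, 5)]

if __name__ == "__main__":
    for m, (lhs, rhs) in enumerate(moments, start=1):
        print(m, lhs, rhs)
```

The two columns agree for every $m$, even though $\mathcal{A}$ acts on parameter space and $M\Theta$ on sample space; the nonzero eigenvalues of the two matrices coincide.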

For the case of MSE loss, we can go even further. In that case, $M(x_{a},x_{b})=\delta_{ab}$ and $\mathcal{B}$ decays to zero during training. We thus have,

These results indicate that the only difference between the spectra of the Hessian and the NTK comes from $\mathcal{B}$, and that $\mathcal{B}$ must have eigenvalues which scale as $1/\sqrt{\textrm{dim}(\mathcal{B})}$. As $\textrm{dim}(\mathcal{B})=\mathcal{O}(n)$ for one hidden layer networks and $\mathcal{O}(n^{2})$ for deep networks, we are left with the relation (78) between the eigenvalues of $\Theta$ and $H$. As the network trains, the difference between these eigenvalues gets even smaller. These results are confirmed experimentally in Figure 8 and Figure 9.

Expression (86) relating the moments of the Hessian to the NTK still holds in this context provided we take $\operatorname{Tr}\left((M\Theta)^{m}\right)\rightarrow\operatorname{Tr}\left((M_{AB}\Theta^{AB})^{m}\right)$, with the more general matrix,

At large width, the NTK becomes proportional to the identity matrix in class space, $\Theta^{AB}\rightarrow N_{\textrm{class}}^{-1}\delta^{AB}\operatorname{Tr}\left(\Theta\right)$. This implies that the NTK spectrum consists of $N_{\textrm{class}}$ repeated copies. This has consequences for the spectrum of the Hessian at large width. For instance, for the case of MSE loss it also implies $N_{\textrm{class}}$ repeated copies of the Hessian spectrum (see Figure 8(b)). It is intriguing to think this could serve as a path towards understanding the emergence of the $N_{\textrm{class}}$ (or $N_{\textrm{class}}-1$) large eigenvalues and corresponding subspace observed in the Hessian spectrum.

E.4 Higher-order network evolution

In this section we explain how to compute higher-order corrections to training dynamics. In principle, this prescription allows one to compute model dynamics as an expansion in $1/n$ to arbitrary order. We apply this explicitly to compute the $\mathcal{O}(1/n)$ training dynamics. In Figure 10, we present experimental confirmation of our predictions for the evolution of the NTK. In the main text, we presented these results for the special case of gradient flow with MSE loss.

So far we have mostly been using diagrammatic techniques to understand the leading order scaling of correlation functions. It is interesting to try to understand finite width networks by asking how training dynamics are modified beyond the leading order asymptotics. There are three clear sources of corrections to the leading behavior.

The initial kernel, $\Theta_{0}$, receives finite width corrections.

The linear (in learning rate) update for the network function, equation (72), is modified away from infinite width.

The kernel is not constant at finite width.

The first source of corrections is automatically taken into account in typical empirical settings, as the finite-width $\Theta_{0}$ can be computed explicitly. The other sources of corrections are non-trivial, and will be the focus of this section. We begin by explaining how to take into account the non-constancy of the kernel order by order in $1/n$, while maintaining the continuous time approximation. Next, we back away from the continuous time limit and write down the full discrete evolution. We introduce a method to compute arbitrary order corrections, and explicitly work out the corrections at order $1/n$ for MSE loss.

Continuous time evolution of the model function in neural networks is governed by the differential equation

As we have discussed at length, in the large width limit, this equation simplifies and $\Theta(t)$ is asymptotically constant. Our goal is to move beyond this leading behavior at large width and solve (96) order-by-order in a $1/n$ expansion. In Section 3 we described how to compute the $\mathcal{O}(n^{-1})$ corrections for the case of MSE loss. Here we explain how to handle corrections more generally, and give a more detailed discussion of the MSE case.

The network map $f$ and the kernel $\Theta$ are members of the following family of operators,

with $O_{1}:=f$. Here, as above, $O_{2}=\Theta$ is the kernel. With a general loss function, these operators satisfy

Equations (97) and (98) give an infinite tower of first order ODEs, the solution of which gives the time evolution of the network map and the kernel.

Solving this infinite tower is not feasible. If we wish to work to a given order in an expansion in $1/n$, however, there is a dramatic simplification which makes a solution possible. First, we can truncate these equations. To see this, note that the operators $O_{s}$ contain $s$ contracted derivative tensors. As a result, by Lemma 6 and Conjecture 1, correlation functions involving the operators $O_{s}$ satisfy

where $F(\vec{x};t)$ is an arbitrary additional contribution to the integrand.

Thus, if we wish to solve for the time evolution up to corrections which scale as $\mathcal{O}(n^{-r})$, we can truncate the tower at $s=2r$ and set $O_{2r}(t)$ equal to its initial value. Note that the leading order solution, (96), is the result of this procedure with $r=1$, and the results presented in the main text for the $\mathcal{O}(n^{-1})$ evolution correspond to $r=2$.

The truncation provides a dramatic simplification; however, it is not immediately clear how to solve even the truncated differential equations in (98). We now describe how to organize the perturbative expansion of the operators $O_{s}(t)$ (including $f$ and $\Theta$) in such a way that the differential equations become tractable.

The central idea is to write each operator, $O_{s}(t)$, as an expansion,

where each order $O_{s}^{(r)}$ captures the $\mathcal{O}(n^{-r})$ evolution of $O_{s}$. For example,

The notation $O_{s}^{(r)}$ means both that any correlation function containing $O_{s}^{(r)}$ is $\mathcal{O}(n^{-r})$ and, by Lemma 5, that typical realizations of the operators scale as $\mathcal{O}(n^{-r})$. Note that in Section 3 we used the alternate notation $\Theta_{1}$ for $\Theta^{(1)}$.

Once the operators $O_{s}$ are organized in this way, solving the differential equations (96) is tractable. As the differential equations describing the evolution of $O_{s}^{(r)}$ for $r>0$ only depend on the time dependent solutions of $O_{s}^{(r-1)}$, we can iteratively solve for the $O_{s}^{(r)}$ order by order. For example, in Section 3, we used the leading order solution for $f$ and $\Theta$ to solve for $\Theta^{(1)}$.

In principle this procedure can be extended to arbitrary order in $1/n$. Before going on to explain the finite step corrections to this procedure, we reproduce the results of Section 3 in more detail.

As such the update equation simplifies to,

The leading order solution to equation (106) is exponential evolution with a constant kernel,

Here we are using a condensed notation where $f_{0}:=f^{(0)}(0)$, $\Theta_{0}:=\Theta^{(0)}(0)$. Equation (107) is a vector equation in which $f_{0}$, $f^{(0)}(t)$, and $y$ are vectors over the training set and the leading order kernel $\Theta_{0}$ is a square matrix over the same space.
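As a concrete illustration of the leading order solution, the sketch below evaluates constant-kernel evolution through the eigen-decomposition of $\Theta_{0}$ and cross-checks it against a fine Euler integration of the flow. It assumes the standard form $f^{(0)}(t)=y+e^{-\Theta_{0}t}(f_{0}-y)$ for gradient flow $\dot{f}=-\Theta_{0}(f-y)$ with loss $\frac{1}{2}\sum_{a}(f(x_{a})-y_{a})^{2}$; the paper's normalization of the loss and of time may differ, and the random kernel and data are placeholders.

```python
import numpy as np

def constant_kernel_solution(theta0, f0, y, t):
    """Evaluate y + exp(-theta0 * t) (f0 - y) using the eigen-decomposition of theta0."""
    lam, e = np.linalg.eigh(theta0)      # eigenvalues lam[i], eigenvectors e[:, i]
    coeffs = e.T @ (f0 - y)              # components of f0 - y in the eigenbasis
    return y + e @ (np.exp(-lam * t) * coeffs)

def euler_flow(theta0, f0, y, t, dt=1e-4):
    """Integrate df/dt = -theta0 (f - y) with small explicit Euler steps."""
    f = f0.copy()
    for _ in range(int(round(t / dt))):
        f = f - dt * theta0 @ (f - y)
    return f

# Placeholder data: a random positive semi-definite kernel over 6 training points.
rng = np.random.default_rng(0)
G = rng.standard_normal((6, 20))
theta0 = G @ G.T / 20
f0 = rng.standard_normal(6)
y = rng.standard_normal(6)

if __name__ == "__main__":
    t = 2.0
    print(constant_kernel_solution(theta0, f0, y, t))
    print(euler_flow(theta0, f0, y, t))
```

Working in the eigenbasis of $\Theta_{0}$ reduces the vector equation to independent scalar exponentials, one for each eigenvalue of $\Theta_{0}$.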

As discussed above, to study the $\mathcal{O}(1/n)$ evolution we can truncate the set of equations in (98) at $s=4$, and set $O_{4}(t)=O_{4}^{(1)}(t)=O_{4}(0)$ up to corrections which scale as $\mathcal{O}(n^{-2})$.

To solve for $O_{3}^{(1)}(t)$, i.e. neglecting $\mathcal{O}(n^{-2})$ corrections, we can plug in the leading order approximations to $O_{4}(t)$ and $f(t)$.

Using the explicit form of $f^{(0)}(t)$ we can write,

Here we are again using a condensed notation, with $\vec{x}=(x_{1},x_{2},x_{3})$. The quantities $f_{0}$ and $y$ are vectors over the training set while $\Theta_{0}$ is a square matrix, and $O_{4}(\vec{x};0)$ is also a vector over the training set, with the value $O_{4}(\vec{x},x^{\prime};0)$ on the point $x^{\prime}\in D_{\rm tr}$.

Plugging in $O^{(1)}_{3}$, and using the eigen-decomposition of $\Theta_{0}$ to perform the integrals gives,

Here, we have introduced the eigenvectors $\hat{e}_{i}$, which are vectors over the training dataset, and the eigenvalues $\lambda_{i}$ of the initial kernel $\Theta_{0}$. The vector $\vec{x}=(x_{1},x_{2})$ has two components, as $\Theta$ depends on two inputs. The operator $O_{3}(\vec{x};0)$ is a vector over the dataset while $O_{4}(\vec{x};0)$ is a square matrix, and $\Delta f_{0}:=f_{0}-y$ is a vector over the training data.

Finally, we can plug the expression (E.4.1) into (106) to obtain the sub-leading behavior of $f(t)$,

This completes the full $\mathcal{O}(n^{-1})$ time dependence. This prediction for the $\mathcal{O}(1/n)$ time dependence is confirmed experimentally in Figure 10.

Note that one consequence of (114) is that the $\mathcal{O}(1/n)$ corrections to $\Theta$ go to a constant at late times.

The asymptotic evolution of $\Delta f(t)$ is then again constant kernel evolution, but with a corrected kernel.

It is worth noting that these predictions are quite detailed, giving the full time dependence, rather than just the overall scaling with width. In deriving these results and in the discrete time analysis below, we implicitly assume that we can control the discrete time corrections or, in the continuous time case, that $f(t)$, $\Theta(t)$ and all the $O_{s}(t)$ are differentiable. This relies on being able to differentiate through the activation function. For non-smooth activations such as ReLU, this assumption is suspect; indeed the evolution is more subtle in the ReLU case, and a full analysis is left to future work. For this reason we present experimental results for networks with smooth activation functions.

E.4.2 Discrete time

With the $1/n$ corrections to continuous time evolution under our belts, it is possible to also keep track of corrections coming from the discrete update step. The essential point, that it is possible to solve iteratively, order by order in $1/n$, is not changed. To see this we can look at how the update equations are modified. Beginning with the network output update equation, we have,

In Section E.2 we argued that the $\mathcal{O}(\eta^{2})$ terms vanish at large width. In more detail, each factor of the gradient, $g^{\mu}$, contains a contracted derivative tensor. Thus, the cluster graph for correlation functions involving the $\eta^{2}$ term will always contain three contracted vertices and thus the correlation functions will scale as $\mathcal{O}(n^{-1})$. The $\eta^{3}$ term will give four vertices, and thus also scale as $\mathcal{O}(1/n)$, the $\eta^{4}$ contribution will be $\mathcal{O}(1/n^{2})$ and so on.
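The counting in the previous paragraph can be written as a small helper. The sketch assumes that the exponent in Conjecture 1 takes the form $s=n_{e}+(n_{o}-m)/2$, with $n_{e}$ and $n_{o}$ the numbers of even and odd components of the cluster graph and $m$ the number of vertices, and it tracks only a single cluster of contracted derivative tensors. This form of the exponent is an assumption stated here for illustration (it reproduces the scalings quoted in the text) and should be checked against the Conjecture; the function names are hypothetical.

```python
def cluster_exponent(num_even, num_odd, num_vertices):
    """Exponent s in C = O(n^s), assuming s = n_e + (n_o - m) / 2."""
    return num_even + (num_odd - num_vertices) / 2

def single_cluster_exponent(m):
    """Exponent contributed by one connected cluster containing m vertices."""
    if m % 2 == 0:
        return cluster_exponent(1, 0, m)
    return cluster_exponent(0, 1, m)

def eta_power_exponent(k):
    """Width exponent of the eta^k term, which contains k + 1 contracted vertices."""
    return single_cluster_exponent(k + 1)

if __name__ == "__main__":
    for k in range(1, 7):
        print(f"eta^{k}: O(n^{eta_power_exponent(k):g})")
```

Under this assumption the $\eta^{2}$ and $\eta^{3}$ terms both come out at order $n^{-1}$ and the $\eta^{4}$ term at order $n^{-2}$, matching the pattern described above: each additional pair of contracted vertices costs one more power of $1/n$.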

Just as in the continuous time case, this equation is still solvable order by order in $1/n$. What we mean by solvable is the following. All terms on the RHS appearing with $f_{t}$ are already higher order as a result of their derivative structure. This means, just as in the continuous case, we can use the lower order solution of $f_{t}$ in the gradient terms on the RHS to solve for $f_{t}$ on the LHS.

Similarly, the update equation for $\Theta$ receives discrete time modifications.

Here we have adopted a notation where $\vec{x}=(x_{1},\ldots,x_{s})$. Just as for $f$, increased powers in $\eta$ are increasingly suppressed in $n$ and we can iteratively solve this equation using solutions at leading orders in $n$ to solve for sub-leading behavior.

More generally the tower of differential equations defined recursively in (98) becomes

As in the continuous time case, at a given order in $n$, we can truncate this tower and use lower order solutions to solve for the time evolution up to the desired order. To ground this discussion in a concrete use case, we again walk through the $1/n$ corrections for MSE loss, now in discrete time.

For discrete time evolution, the leading order behavior of the model map is

At next order we can proceed by solving the truncated set of equations in (E.4.2). The first equation, for $O_{4;t}$, still gives a constant solution,

Here, both the discrete and continuous time terms that would appear on the RHS are $\mathcal{O}(1/n^{2})$.
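The leading order discrete time behavior of the model map can be sketched in the same spirit as the continuous case. The code assumes the update $f_{t+1}=f_{t}-\eta\,\Theta_{0}(f_{t}-y)$ for MSE loss with a constant kernel, whose solution is $f_{t}=y+(1-\eta\Theta_{0})^{t}(f_{0}-y)$; the loss normalization and the random placeholder data are assumptions.

```python
import numpy as np

def discrete_constant_kernel(theta0, f0, y, eta, t):
    """Closed form y + (1 - eta * theta0)^t (f0 - y) via the eigenbasis of theta0."""
    lam, e = np.linalg.eigh(theta0)
    return y + e @ ((1.0 - eta * lam) ** t * (e.T @ (f0 - y)))

def iterate_updates(theta0, f0, y, eta, t):
    """Apply the constant-kernel update f <- f - eta * theta0 (f - y) for t steps."""
    f = f0.copy()
    for _ in range(t):
        f = f - eta * theta0 @ (f - y)
    return f

# Placeholder data: a random positive semi-definite kernel over 5 training points.
rng = np.random.default_rng(1)
G = rng.standard_normal((5, 30))
theta0 = G @ G.T / 30
f0 = rng.standard_normal(5)
y = rng.standard_normal(5)

if __name__ == "__main__":
    eta, t = 0.1, 50
    print(discrete_constant_kernel(theta0, f0, y, eta, t))
    print(iterate_updates(theta0, f0, y, eta, t))
```

Each eigen-component decays by a factor $(1-\eta\lambda_{i})$ per step, which approaches the continuous time factor $e^{-\eta\lambda_{i}}$ when $\eta\lambda_{i}$ is small.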

Next, we must solve for $O_{3;t}$. The discrete time update is,

To order $1/n$ we can drop the order $\eta^{2}$ and higher terms, as they contain expressions with 5 or more contracted derivative tensors. Thus at this order, we are left with the discrete analogue of (109),

which we can sum to get the discrete version of (111),

Here we have adopted notation similar to above, where $O_{4;0}(\vec{x})$ is a vector over the training dataset.

So far, this procedure has been a discrete analogue of what we have done in the continuous time case. However, as we move on to compute $\Theta$ and $f$ we will have to keep track of the novel corrections, which vanish in continuous time. Explicitly, at order $1/n$ the discrete update for $\Theta_{t}$ is given by,

This can be summed to give $\Theta_{t}^{(1)}$.

To move up to the neural network map itself there is one additional complication. The $\mathcal{O}(\eta^{2}/n)$ term, $\sum_{\mu\nu}\frac{\partial^{2}f_{t}(x)}{\partial\theta^{\mu}\partial\theta^{\nu}}\frac{\partial f_{t}(x^{\prime})}{\partial\theta^{\mu}}\frac{\partial f_{t}(x^{\prime\prime})}{\partial\theta^{\nu}}$, in (119) has non-trivial time dependence. We can deal with this just as we have been doing, by taking an extra time derivative and integrating. Defining,

This can in turn be plugged into the update equation for $f_{t}$.

Here $(x_{a},y_{a})$ are elements of the training set $D_{\rm tr}$ and are summed over. This equation can be solved to give $f_{t}^{(1)}$.

where $\textrm{Disc}_{t}$ contains the discrete derivative updates at $\mathcal{O}(1/n)$.

These expressions may look fairly intimidating. The key point is that all terms in the summand in (129) and thus (134) are known functions of time and initial data, just as in the continuous time setting.

References