Asymptotics of Wide Networks from Feynman Diagrams
Ethan Dyer, Guy Gur-Ari
Introduction
Neural networks achieve remarkable performance on a wide array of machine learning tasks, yet a complete analytic understanding of deep networks remains elusive. One promising approach is to consider the large width limit, in which the number of neurons in one or several layers is taken to be large. In this limit one can use a mean-field approach to better understand the network’s properties at initialization, as well as its training dynamics. Additional related works are cited below.
Suppose that $f(x)$ is the network function evaluated at an input $x$. Let us denote the vector of model parameters by $\theta$, whose elements are initially chosen to be i.i.d. Gaussian. In this work we consider a class of functions we call correlation functions, obtained by taking the ensemble averages of $f$, its products, and its derivatives with respect to the parameters $\theta$, evaluated on arbitrary inputs. Here are a few examples of correlation functions.
Correlation functions often show up in the study of wide networks. For example, the first correlation function in (1) plays a central role in the Gaussian Process picture of wide networks, and has been used to diagnose signal propagation in wide networks. The second example in (1) is the ensemble average of the Neural Tangent Kernel (NTK), which controls the evolution of wide networks under gradient flow, and the third example shows up when computing the time derivative of the NTK with MSE loss.
While correlation functions can be computed analytically in some special cases, they are not analytically tractable in general. In this work, we present a method for bounding the asymptotic behavior of such functions at large width. Derivation of the method relies on Feynman diagrams, a technique for calculating multivariate Gaussian integrals, and specifically on the ’t Hooft expansion. However, applying the method is straightforward and does not require any knowledge of Feynman diagrams.
We present a general method for bounding the asymptotic behavior of correlation functions. The method is an adaptation of Feynman diagrams to the case of wide neural networks. The adaptation involves a novel treatment of derivatives of the network function, an element that is not present in the original theoretical physics formulation.
We apply the method to the study of wide network evolution under gradient descent. We improve on existing results for gradient flow by deriving tighter bounds, and by extending the analysis to the case of stochastic gradient descent (SGD). Going beyond the infinite-width limit, we present a formalism for deriving finite-width corrections to network evolution, and present explicit formulas for the first-order correction. To our knowledge, this is the first time this correction has been calculated.
As additional applications of our method, in Appendix E.2 we show that in the large width limit the SGD updates are linear in the learning rate, and in Appendix E.3 we discuss finite width corrections to the spectrum of the Hessian.
The main result of this paper is a conjecture. We test our predictions extensively using numerical experiments, and prove the conjecture in some cases, but we do not have a proof that applies to all the cases we tested, including for deep networks with general non-linearities. Furthermore, our method can only be used to derive asymptotic bounds at large width; it does not produce the width-independent coefficient, which is often of interest.
For additional works on wide networks, including relating them to Gaussian processes, see . For additional works discussing the training dynamics of wide networks see . For a previous use of diagrams in this context, see .
The Neural Tangent Hierarchy presented in , published during the completion of this version, has significant overlap with the recursive differential equations (11) presented below.
The rest of the paper is organized as follows. In Section 2 we present our main conjecture and supporting evidence. In Section 3 we apply the method to gradient descent evolution of wide networks, and in Section 4 we present details on Feynman diagrams, which is the basic technique used in our proofs. We conclude with a Discussion. Proofs, additional applications, and details can be found in the Appendices.
Note: An earlier version of this work appeared in the ICML 2019 workshop, Theoretical Physics for Deep Learning.
Correlation function asymptotics
In this section we present our main result: a method for computing asymptotic bounds on correlation functions of wide networks. We present the result as a conjecture, supported by analytic and empirical evidence.
Let us now define correlation functions, the class of functions that is the focus of this work. These functions involve derivative tensors of the network function. We denote the rank-$k$ derivative tensor by $T_{\mu_1 \ldots \mu_k}(x; f) := \partial^k f(x) / \partial\theta^{\mu_1} \cdots \partial\theta^{\mu_k}$. For $k = 0$ we define $T(x; f) := f(x)$, and still refer to this as a derivative tensor for consistency.
A correlation function is the expectation value of a product of derivative tensors, evaluated at arbitrary inputs, where the tensor indices are summed in pairs over all the model parameters. A general correlation function takes the form
Here, are integers, and are even, is a permutation, and We use to denote the Kronecker delta. When , the tensor has no derivatives.
2 Asymptotic bounds on wide networks
We now present our main conjecture, which allows us to place asymptotic bounds on general correlation functions of wide networks.
We will refer to the connected components of a cluster graph as the clusters of . Table 1 lists examples of bounds derived using the Conjecture for several correlation functions. The intuition behind Conjecture 1 comes from the following result for deep linear networks.
Conjecture 1 holds for correlation functions of networks with linear activations.
Let us discuss the intuition behind this theorem. Computing correlation functions of deep linear networks amounts to evaluating Gaussian integrals with polynomial integrands in . One can evaluate such integrals using Isserlis’ theorem, which tells us how to express moments of multivariate Gaussian variables in terms of their second moments. For example, given centered Gaussian variables ,
Every correlation function of a deep linear network can be similarly reduced to sums over products of Kronecker delta functions and width-independent functions of the inputs. The asymptotic large width behavior is determined by these sums over delta functions, which are tedious to compute by hand. Feynman diagrams are a graphical tool for computing these sums, allowing us to obtain the asymptotic behavior with minimal effort. This tool, which is described in detail in Section 4, is used to prove Theorem 1.
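To make the pairing structure behind Isserlis’ theorem concrete, here is a minimal Python sketch that evaluates a moment of centered Gaussian variables by summing over all pairings of the indices. The helper names `pairings` and `isserlis_moment` are illustrative and are not part of the original work.

```python
import numpy as np

def pairings(indices):
    """Yield every way to split the list `indices` into unordered pairs."""
    if not indices:
        yield []
        return
    first, rest = indices[0], indices[1:]
    for k, partner in enumerate(rest):
        remaining = rest[:k] + rest[k + 1:]
        for tail in pairings(remaining):
            yield [(first, partner)] + tail

def isserlis_moment(cov, indices):
    """E[z_{i1} z_{i2} ... z_{ik}] for a centered Gaussian with covariance `cov`."""
    indices = list(indices)
    if len(indices) % 2 == 1:
        return 0.0  # odd moments of a centered Gaussian vanish
    total = 0.0
    for pairing in pairings(indices):
        term = 1.0
        for i, j in pairing:
            term *= cov[i, j]  # one covariance factor per pair
        total += term
    return total
```

For an identity covariance each pair contributes a Kronecker delta, so the fourth moment of a single variable comes out as the number of pairings of four objects, which is 3. The number of pairings grows quickly with the order of the moment, which is why a graphical bookkeeping device becomes useful.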
For networks with non-linear activations we further show the following
Conjecture 1 holds for (1) networks with ReLU activations, where all inputs are set to be equal, and for (2) networks with one hidden layer and smooth activation.
For case (1), the idea behind the proof is to put an asymptotic bound on the ReLU network in terms of a corresponding deep linear network. For case (2), the basic idea is that each network function contains a single sum over the width, and by keeping track of these sums using Feynman diagrams we are able to bound the asymptotic behavior. We refer the reader to Appendix C for details.
3 Numerical experiments
Table 1 lists asymptotic bounds on several correlation functions, derived using Conjecture 1. These are compared against the asymptotic behavior computed using numerical experiments. In addition to the results presented here, we performed experiments using the same correlation functions and experimental setup, but with weights sampled uniformly from instead of from a Gaussian distribution. The results are shown in Appendix A.1. In all cases tested, we found that Conjecture 1 holds. In most cases, we find that the bound is tight. For cases where the bound is not tight, a tight bound can be obtained using the complete Feynman diagram analysis presented below.
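One common way to extract an asymptotic exponent from measurements at several widths is a least-squares fit on a log-log scale. The sketch below illustrates that idea; the fitting procedure actually used for Table 1 is not specified in this text, so the function `fit_width_exponent` is an assumption made for illustration only.

```python
import numpy as np

def fit_width_exponent(widths, values):
    """Fit |value| ~ const * width**s and return the exponent s."""
    widths = np.asarray(widths, dtype=float)
    values = np.abs(np.asarray(values, dtype=float))
    # A power law is a straight line in log-log coordinates; the slope is s.
    slope, _intercept = np.polyfit(np.log(widths), np.log(values), 1)
    return slope
```

A fitted exponent at or below the predicted one is consistent with the bound, and equality indicates that the bound is tight for that correlation function.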
Applications to training dynamics
Here, is the Neural Tangent Kernel (NTK), defined by . The authors of showed that the kernel is constant during training up to corrections. This leads to a dramatic simplification in training dynamics . In particular, for MSE loss the network map evaluated on the training data evolves as . Here we are using condensed notation: and are values at initialization, and , , are treated as vectors in training set space. The kernel is a square matrix in the same space. We will use our technology to derive a tighter bound on finite-width corrections to the kernel during training and present explicit formulas for the leading correction.
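As an illustration of the constant-kernel dynamics, the following sketch evaluates the solution of df/dt = -Θ (f - y) on the training set. It assumes a loss normalized as one half of the summed squared error and a unit learning rate; the normalization in the original equation may differ. The function name `linearized_mse_flow` is hypothetical.

```python
import numpy as np

def linearized_mse_flow(theta, f0, y, t):
    """Return f(t) = y + exp(-t * theta) @ (f0 - y) for a symmetric kernel matrix."""
    theta = np.asarray(theta, dtype=float)
    f0 = np.asarray(f0, dtype=float)
    y = np.asarray(y, dtype=float)
    # Diagonalize the kernel; each eigen-direction decays independently.
    evals, evecs = np.linalg.eigh(theta)
    decay = np.exp(-t * evals)
    return y + evecs @ (decay * (evecs.T @ (f0 - y)))
```

With a positive-definite kernel the residual f(t) - y decays exponentially in every eigen-direction, at a rate set by the corresponding eigenvalue.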
The following result is useful in analyzing the behavior of correlation functions under gradient flow.
Here we prove the statement for . Appendix D.2 contains a proof for the general case.
Let have even clusters and odd clusters. Consider the correlation function
Denote by () the number of even (odd) clusters in this correlation function, which has derivative tensors. One can check that either or , depending on whether the derivative is acting on an odd or even cluster in . (Here we are extending the use of the term cluster to refer to derivative tensors in the integrand itself.) Therefore, . ∎
With this result, it is easy to understand the constancy of the NTK at large width. The first derivative of the NTK is given by
Next, we will compute the explicit time dependence of and at order under gradient flow. This is the leading correction to the infinite width result. We define the functions and
Notice that is the kernel. It is easy to check that
Let us denote the time-evolved kernel by , where is the kernel at initialization, and is the correction we are seeking. Integrating equations (11) starting with , we find
Here we have introduced the notation . A detailed derivation can be found in Appendix E.4. There we also evaluate the integrals in (3.1) in terms of the NTK spectrum.
To obtain the correction to the network map (evaluated for simplicity on the training data), we further integrate (11) for and find
Here we have denoted the infinite width evolution by . Figures 1(b) and 1(c) compare these predictions against empirical results.
Feynman diagrams for deep linear networks
In this section we present the Feynman diagram technique, and show how it allows us to compute the asymptotic behavior of correlation functions. We end this section with a proof of Theorem 1 for the case of networks with a single hidden layer and linear activations.
Given a correlation function , we map it to a family of graphs called Feynman diagrams. The graphs are independent of the inputs, and are defined as follows.
Let be a correlation function for a network with hidden layers. The family is the set of all graphs that have the following properties.
There are vertices , each of degree .
Each edge has a type . Every vertex has one edge of each type.
If two derivative tensors are contracted times in , the graph must have at least edges (of any type) connecting the vertices .
The graphs in are called the Feynman diagrams of .
As the factors of and are independent of , we see that . Notice that there are two relevant contributions to this answer: each factor of the network function in the integrand contributes , and the summed-over product of Kronecker deltas contributes . Other details, such as the input dependence, are irrelevant. Feynman diagrams allow us to encode only those details that affect the scaling, ignoring the rest.
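The counting described above can be written out explicitly for the two-point function. The following worked equation is a sketch under the assumption that the single-hidden-layer linear network is $f(x) = n^{-1/2} V^{T} U x$, with $U \in \mathbb{R}^{n \times d}$ and $V \in \mathbb{R}^{n}$ having i.i.d. standard Gaussian entries; the symbols are chosen here for illustration.

```latex
\begin{align}
\mathbb{E}_\theta\left[ f(x) f(x') \right]
  &= \frac{1}{n} \sum_{i,j=1}^{n} \sum_{\alpha,\beta=1}^{d}
     \mathbb{E}\left[ V_i V_j \right] \mathbb{E}\left[ U_{i\alpha} U_{j\beta} \right] x_\alpha x'_\beta \\
  &= \frac{1}{n} \sum_{i,j=1}^{n} \delta_{ij}\, \delta_{ij}
     \sum_{\alpha,\beta=1}^{d} \delta_{\alpha\beta}\, x_\alpha x'_\beta \\
  &= \frac{1}{n} \cdot n \cdot x^{T} x' = x^{T} x' .
\end{align}
```

The two factors of $n^{-1/2}$ combine into the prefactor $1/n$, the sum over the product of Kronecker deltas contributes a factor of $n$, and the input dependence $x^{T} x'$ plays no role in the scaling.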
The set for the correlation function (14) consists of a single Feynman diagram, shown in Figure 2(a).
The asymptotic bound on a correlation function is obtained by the following result, which is due to .
Let be a correlation function with one hidden layer and linear activation. Then where , and is the number of loops in .
We are now ready to prove Theorem 1 for the case of a single hidden layer with linear activations. A proof for the general case can be found in Appendix B.
Discussion
Ensemble averages of the network function and its derivatives are an important class of functions that often show up in the study of wide neural networks. Examples include the ensemble average of the train and test losses, the covariance of the network function, and the Neural Tangent Kernel. In this work we presented Conjecture 1, which allows one to derive the asymptotic behavior of such functions at large width.
For the case of deep linear networks, we presented a complete analytic understanding of the Conjecture based on Feynman diagrams. In addition, we presented empirical and analytic evidence showing that the Conjecture also holds for deep networks with non-linear activations, as well as for networks with non-Gaussian initialization. We found that the Conjecture holds in all cases we tested.
The basic tools presented in this work can be applied to many aspects of wide network research, greatly simplifying theoretical calculations. We presented several applications of our method to the asymptotic behavior of wide networks during stochastic gradient descent, and additional applications are presented in Appendix E. We were able to improve upon known results by tightening existing bounds, and by applying the technique to SGD as well as to gradient flow. In addition, we took a step beyond the infinite width limit, deriving closed-form expressions for the first finite-width correction to the network evolution. These novel results open the door to studying finite-width networks by systematically expanding around the infinite width limit.
A central question in the study of wide networks is whether the infinite width limit is a good model for describing the behavior of realistic deep networks. In this work we take a step toward answering this question, by working out the next order in a perturbative expansion around the infinite width limit, potentially bringing us closer to an analytic description of finite-width networks. We hope that the techniques presented here provide a basis for systematically answering this and other questions about the behavior of wide networks.
Acknowledgements
The authors would like to thank Alex Alemi, Yasaman Bahri, Boris Hanin, Jared Kaplan, Jaehoon Lee, Behnam Neyshabur, Sam Schoenholz, Sylvia Smullin, and Jascha Sohl-Dickstein for useful discussion. The authors would especially like to thank Ying Xiao for extensive comments on early versions of this manuscript.
Appendix A Experimental details and additional results
In Table 1 we listed asymptotic bounds on several correlation functions, where the model parameters were initialized from a Gaussian distribution. Table 2 shows additional results using the same experimental setup, but with weights sampled uniformly from . We again find good agreement with the predictions of Conjecture 1.
A.2 Experimental details
The experiments in Figure 1 were performed on two-class MNIST, computing a single randomly-chosen component of the kernel . Sub-figure (a) uses networks trained for 1024 steps with learning rate 1.0 and 1000 samples per class, averaged over 100 initializations. Each curve in sub-figure (b) represents a single instance of the network map evaluated on a random image over the course of training. The models were trained with 10 samples per class and learning rate 0.1. The input to the network is normalized by the square root of the input dimension as in
Appendix B Feynman diagrams for deep linear networks
Feynman diagrams can be used to derive asymptotic upper bounds on deep linear networks in the large width limit. In this section we describe the method in detail, and use it to prove Theorem 1.
In this section we build on the results of Section 4 and consider correlation functions of deep linear networks with hidden layers. The network function was defined in (2), and here we set the activation to be the identity. Definition 2 describes how to map a correlation function to , a family of graphs called Feynman diagrams. The Feynman diagram method relies on Isserlis’ theorem, which allows us to express arbitrary moments of multivariate Gaussian variables in terms of their covariance.
Let be a centered multivariate Gaussian variable. For any positive integer ,
In particular, if the covariance matrix of is the identity then
Using this theorem, a correlation function can be expressed as a sum over permutations as in (17). Each term in this sum maps to a Feynman diagram in .
A similar question appeared in the context of theoretical physics, and the correct generalization is due to . The idea is to treat each Feynman diagram as the triangulation of a Riemann surface, and to define the number of loops in a graph to be the number of faces of the triangulation. In practice, this involves mapping each Feynman diagram to a new double-line diagram: A graph in which each edge corresponds to a single Kronecker delta factor, and loops correspond to triangulation faces of the original diagram.
Each edge in of type is mapped to two edges , .
Each edge in of type is mapped to a single edge .
Each edge in of type is mapped to a single edge .
The discussion above is summarized by the following result, due to , that describes how the asymptotic behavior of a general correlation function can be computed using the Feynman rules for deep linear networks.
The intuition for the formula is similar to the single hidden layer case, Theorem 3. The term counts the number of factors of the form that appear in the correlation function after applying Isserlis’ theorem. The term is due to the explicit normalization of the network function.
B.2 Asymptotics of deep linear networks
We now prove Theorem 1. The theorem follows from the following lemma, again due to , that relates the asymptotic behavior to the number of connected components in a Feynman diagram.
Let be a correlation function for a deep linear network. Let be the number of connected components of a graph . Then , where
Let us prove the remaining statement about . The Euler character of the graph with vertices, edges and faces is . The degree of each vertex in the graph is , and therefore . Using Theorem 5 the graph is where . We therefore find that . The graph is a triangulation of some connected surface with at least one boundary. The Euler character for such a surface is bounded by , and therefore . ∎
Let be a correlation function for a deep linear network. Suppose that the cluster graph has even size components and odd size components. Let be a Feynman diagram with connected components. We will show that . It then follows immediately from Lemma 2 that where , concluding the proof.
Let us derive the bound on . First, all vertices that belong to a given cluster (a component of ) will also belong to the same connected component in . This is because every edge in is also an edge in (note that and have the same set of vertices). Therefore, . Second, note that every connected component of the graph has an even number of vertices. Indeed, each edge has a type , and each vertex has exactly one edge of each type. Therefore, a connected component with vertices has edges of each type, and so must be even. It follows that the vertices of even clusters can form their own connected components in a Feynman diagram, while odd clusters must be connected in sets of 2 or more to form connected components. The bound on then follows. ∎
Appendix C Non-Linearities
In previous sections we presented the Feynman diagram method for computing the large width asymptotics of correlation functions. In this section we show that the method applies as-is for deep networks with ReLU non-linearities and all-equal inputs, as well as to networks with a single hidden layer, a broader class of non-linearities, and arbitrary inputs. Theorem 2 follows immediately from the results presented in this section.
The following result guarantees that the presence of ReLU non-linearities does not change the asymptotic upper bound compared with linear activations, when all inputs are the same.
Intuitively, we will rely on the fact that for ReLU networks we can, in some cases, treat the binary neuron activations as being statistically independent of the weights. This result is due to . Given this result, we can bound the contribution of the binary activations, and the remaining Gaussian integral is equivalent to that found in a deep linear network. The proof does not work for correlation functions with non-equal inputs, because in that case the independence result of no longer holds.
Here, is the Heaviside step function acting elementwise on its vector argument. We now introduce the construction from . Let , be diagonal matrices of dimension , whose diagonal elements are -Bernoulli() variables with . We define the new variables
are independent of .
are independent of each other for fixed . The diagonal entries of each diagonal matrix are independent, and take the values with probability .
Now, the correlation function can be written as
Here, and , and
In writing equation (34) we used the facts that are independent of the parameters, and that parameters in different layers are independent. We now use Theorem 4 (Isserlis), which says that each of the expectation values over products of , , and elements is equal to a sum over permutations, where each term is a product over Kronecker delta functions (the covariance matrices of the parameters).
We can now bound the correlation function as follows.
C.2 Single hidden layer networks
For networks with a single hidden layer, defined by
we can extend our asymptotic analysis to smooth non-linearities. We will show in Theorem 7 that for any correlation function , we have
where is a deep linear network of equal width and sufficient depth. Therefore, computing the asymptotics using Feynman diagrams for deep linear networks yields a bound on networks with a single hidden layer and smooth non-linearities.
In the last line, we again used the fact that the summands are equal and independent of .
Our approach to establishing the asymptotic scaling will be to first bound the maximum number of index sums appearing in any term in our correlation function, written in the form (48), and then to argue that the summands are bounded by an -independent constant.
Let us introduce a family of diagrams, , which in general differ from the Feynman diagrams. A given diagram is constructed as follows.
Each derivative tensor in is mapped to a vertex in the graph.
Each edge has a type that corresponds to one of the weight vectors or .
Each vertex has exactly one edge of type.
If two derivative tensors are contracted in , the graph must have at least one edge (of any type) connecting the corresponding vertices for each contraction.
Let be the number of even (odd) clusters in the cluster graph of . The cluster graph is a subgraph of any graph . We can thus think about the embedding of even and odd clusters into . In any graph , an even cluster may belong to its own connected component, while for odd clusters there must be an even number of them in any connected component. This is because an even (odd) cluster contains an even (odd) number of factors of , which must be paired up in any connected component. We find that
The last inequality was used below Lemma 2 in the proof of Theorem 1. ∎
The correlation function in (48) can be bounded as
For fixed inputs, can take at most different values as a function of its indices, and the values are independent of . This is because the variables are identical. We define as the maximum value of as a function of and the indices. Combining this with the above lemmas we can then write.
The result of the theorem follows from Lemma 2. ∎
Appendix D Correlation function asymptotics
In this section we prove several general results about correlation function asymptotics in the large width limit. Throughout this section, we assume that Conjecture 1 holds.
Conjecture 1 can be used to bound the variance of the integrands that appear inside correlation functions.
To bound the variance, it is enough to bound the correlation function
As a corollary, notice that if according to Conjecture 1, then typical realizations of the integrand will also be . In other words, the asymptotic bound of Conjecture 1 holds for typical initializations, not just in expectation.
D.2 Gradient Flow
The following results are used in the gradient flow calculations of Section 3.
Let be a graph with vertices, even components, and odd components. Let be a subgraph of with vertices, even components, and odd components. Then where .
It is enough to show that does not increase if we (1) add a vertex to , or (2) add an edge to , because can be obtained from by performing such operations finitely many times. Adding a vertex to changes , , and , leaving unchanged. Next, if we add an edge to then does not change, and there are 4 possibilities for how and change.
The edge connects two even components. Then , , and decreases by 1.
The edge connects two odd components. Then , , and does not change.
The edge connects an even component and an odd component. Then , , and decreases by 1.
The edge connects two vertices that belong to the same component. In this case , , and do not change.
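The case analysis above can be checked mechanically. The sketch below assumes that the exponent has the form s = n_e + n_o/2 - m/2, where m is the number of vertices and n_e (n_o) is the number of even (odd) connected components; this form is consistent with the four cases listed, but the formula itself does not survive in this text, so treat it as an assumption. The function names are illustrative.

```python
def component_sizes(num_vertices, edges):
    """Sizes of the connected components, computed with union-find."""
    parent = list(range(num_vertices))

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path compression
            v = parent[v]
        return v

    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_a] = root_b

    sizes = {}
    for v in range(num_vertices):
        root = find(v)
        sizes[root] = sizes.get(root, 0) + 1
    return sorted(sizes.values())

def scaling_exponent(num_vertices, edges):
    """Assumed exponent s = n_e + n_o / 2 - m / 2 of a cluster graph."""
    sizes = component_sizes(num_vertices, edges)
    n_even = sum(1 for size in sizes if size % 2 == 0)
    n_odd = len(sizes) - n_even
    return n_even + n_odd / 2 - num_vertices / 2
```

Here vertices stand for derivative tensors and edges for index contractions. Two uncontracted factors of the network function give `scaling_exponent(2, [])`, and two singly-contracted first derivatives give `scaling_exponent(2, [(0, 1)])`; adding an edge to any graph never increases the returned value, in line with the lemma.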
We now prove Lemma 1 giving the scaling of time derivatives of correlation functions at initialization. We prove the result for polynomial loss functions. Extension to more general loss functions requires interchanging the large width limit and the Taylor expansion of the loss, which we do not discuss.
D.3 Stochastic Gradient descent
Let be a correlation function for a network with linear activations, and assume that Conjecture 1 holds, namely where is defined in the Conjecture. If is the evolved correlation function after SGD steps, then .
The integrand can be written as a product of derivative tensors of the form , with contracted derivatives. Suppose that the network has hidden layers. Then under an SGD step, we have
Here is the mini-batch, and are mini-batch samples. The sum is truncated because higher-order derivatives of vanish.
We can now see how taking a gradient descent step affects the cluster graph. After taking a step, each derivative tensor in is replaced by a sum (over ). Each term in the combination of these sums is a correlation function, whose cluster graph is a subgraph of . Therefore, by Lemma 6, . ∎
We note that for general activation functions the sum in (60) may be infinite. In this case, to complete the proof we would need to show that the infinite sum obeys the same bound as each individual term in the sum. We leave this for future work.
Appendix E Applications
Here we present several applications of our Feynman diagram method for computing large width asymptotics. We assume throughout this section that Conjecture 1 holds.
In this section we prove two results regarding the NTK at large width. We show that the kernel converges in probability, and that during gradient descent it is constant up to corrections.
The Neural Tangent Kernel of a deep linear network converges in probability in the large width limit, and its variance is .
Conjecture 1 is not sufficient for proving this theorem, as we need to use a more detailed Feynman diagram argument. Therefore, we only prove the theorem for the case of deep linear networks.
includes all diagrams in which the vertices corresponding to share an edge, and also the vertices share an edge (due to the explicit derivatives);
includes all diagrams in which all edges are either between or between .
Therefore, in the full expression (61), the only remaining diagrams (i.e. the diagrams that do not cancel between the two terms) are those that include
an edge connecting , and
an edge connecting one of to one of .
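The concentration of the kernel can be observed numerically in the simplest setting. The sketch below assumes a single-hidden-layer linear network f(x) = n^{-1/2} V^T U x with i.i.d. standard Gaussian parameters (an illustrative special case, with symbols chosen here), for which the sum over parameter derivatives can be written in closed form. The function names are hypothetical.

```python
import numpy as np

def empirical_ntk_linear(U, V, x1, x2):
    """NTK of f(x) = V @ U @ x / sqrt(n), summed over all entries of U and V."""
    n = V.shape[0]
    h1, h2 = U @ x1, U @ x2
    # Derivatives with respect to V_i: df/dV_i = (U x)_i / sqrt(n)
    grad_v_term = (h1 @ h2) / n
    # Derivatives with respect to U_ia: df/dU_ia = V_i x_a / sqrt(n)
    grad_u_term = (V @ V) * (x1 @ x2) / n
    return grad_v_term + grad_u_term

def sample_ntk(n, x1, x2, rng):
    """Draw one random initialization of width n and return its kernel entry."""
    d = x1.shape[0]
    U = rng.standard_normal((n, d))
    V = rng.standard_normal(n)
    return empirical_ntk_linear(U, V, x1, x2)
```

Under these assumptions the ensemble average of the kernel entry is 2 x1·x2, and repeated calls to `sample_ntk` at increasing width should scatter more and more tightly around that value.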
Next, we show that the large width NTK is constant during training, and compute the asymptotics of the higher-order terms. The following argument is phrased for deep linear networks. More generally, the same argument holds for deep networks with smooth non-linear activations under the additional assumption that the large width limit and Taylor series can be exchanged (note that for such networks, the network function is analytic in the weights).
Let be the network output of a deep linear network with MSE loss . Let be the Neural Tangent Kernel at SGD step , for some inputs . Then in the large width limit, the kernel is constant in in expectation, and
The fact that the sums are truncated at follows from using linear activations, as higher-order derivatives of the network function vanish in this case. All terms in the sum over include a tensor product of the form with either or , where stands for additional derivative tensor factors. Therefore, all terms in the sum are correlation functions that have a cluster of size at least 3, including , , and at least one other tensor contracted through the or index. It follows from the Conjecture that each term in the sum is . Lemma 5 then implies that the variance of these updates is . ∎
E.2 Wide network evolution is linear in the learning rate
In this section we prove that, at large width, the NTK determines the evolution of the network function not just for continuous-time gradient descent but also for discrete-time gradient descent. A similar result holds for stochastic gradient descent, using a stochastic kernel. The perspective presented here helps explain the empirical observation that linearized evolution is a good description of wide networks even for relatively large learning rates. Again we prove the deep linear case explicitly, but the result holds for deep networks with smooth non-linear activations under the additional assumption that the large width limit and Taylor series can be exchanged.
Let be the network output of a deep linear network, and let be the evolved function after gradient descent steps, defined by . In the large width limit, each gradient descent step update of is linear in the learning rate . Furthermore,
For a deep linear network, under a single gradient descent step we have
First, consider the sum over in (72). Each term in the sum is a correlation function for which the cluster graph contains a connected subgraph of size at least 3, and is therefore by Lemma 6 and Theorem 1. In the remaining term, by the same argument as Theorem 10 we can replace in the correlation function. The result is equation (72). ∎
As mentioned above, we note that the proof goes through when using stochastic gradient descent updates, with the difference that in (72) we should sum over mini-batch samples instead of over the entire training set.
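As a rough illustration of the discrete-time statement, the sketch below iterates the update that is linear in the learning rate, specialized to MSE loss with a fixed kernel matrix on the training set: f_{t+1} = f_t - lr * Θ (f_t - y). The loss normalization and the function name `linearized_gd_steps` are assumptions made for this example.

```python
import numpy as np

def linearized_gd_steps(theta, f0, y, lr, num_steps):
    """Iterate f <- f - lr * theta @ (f - y) and return the whole trajectory."""
    theta = np.asarray(theta, dtype=float)
    y = np.asarray(y, dtype=float)
    f = np.array(f0, dtype=float)
    trajectory = [f.copy()]
    for _ in range(num_steps):
        # Each update is linear in the learning rate for a fixed kernel.
        f = f - lr * theta @ (f - y)
        trajectory.append(f.copy())
    return np.stack(trajectory)
```

For the mini-batch variant, the kernel matrix would be restricted to the samples in the current batch at each step.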
E.3 Spectral properties of the NTK and the Hessian
With an eye towards understanding the structure of the loss landscape at large width, and as another example use case of our approach, we investigate the relation between the spectra of the Hessian and the NTK.
The Hessian of a general loss takes the form,
Thus most moments of the Hessian are equal to moments of in expectation. For the first moment in linear or ReLU networks, .
What’s more, we can relate moments of to those of the kernel, as both are built out of two logit derivatives. Explicitly,
and so the moments of the Hessian are also related to those of the kernel. Here we have argued for relations between the means of moments. It is not too difficult to see that these relations also hold for typical realizations; this follows from Lemma 5.
For the case of MSE loss, we can go even further. In that case, and decays to zero during training. We thus have,
These results indicate that the only difference between the spectra of the Hessian and the NTK comes from and that must have eigenvalues that scale as . As for one hidden layer networks and for deep networks we are left with the relation (78) between the eigenvalues of and . As the network trains, the difference between these eigenvalues gets even smaller. These results are confirmed experimentally in Figure 8 and Figure 9.
Expression (86) relating the moments of the Hessian to the NTK still holds in this context provided we take , with the more general matrix,
At large width, the NTK approaches the identity matrix in class space, . This implies that the NTK spectrum consists of repeated copies. This has consequences for the spectrum of the Hessian at large width. For instance, for MSE loss it also implies repeated copies of the Hessian spectrum (see Figure 8(b)). It is intriguing to think this could serve as a path towards understanding the emergence of the (or ) large eigenvalues and corresponding subspace observed in the Hessian spectrum.
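The "repeated copies" statement is a property of the Kronecker product with an identity matrix, and can be checked directly. In the sketch below, `theta_data` is a hypothetical positive semi-definite stand-in for the kernel in data space, and the number of classes `k` is chosen arbitrarily.

```python
import numpy as np

rng = np.random.default_rng(1)
m, k = 4, 3                                   # hypothetical: 4 samples, 3 classes
A = rng.normal(size=(m, m))
theta_data = A @ A.T                          # PSD stand-in for the data-space kernel

# Kernel that is proportional to the identity in class space.
theta_full = np.kron(np.eye(k), theta_data)

eig_data = np.linalg.eigvalsh(theta_data)     # m eigenvalues, ascending
eig_full = np.linalg.eigvalsh(theta_full)     # m * k eigenvalues, ascending
```

Each eigenvalue of `theta_data` appears `k` times in `eig_full`, which is the sense in which the spectrum consists of repeated copies.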
E.4 Higher-order network evolution
In this section we will explain how to compute higher-order corrections to training dynamics. In principle, this prescription allows one to compute model dynamics as an expansion in to arbitrary order. We apply this explicitly to compute the training dynamics. In Figure 10, we present experimental confirmation of our predictions for the evolution of the NTK. In the main text, we presented these results for the special case of gradient flow with MSE loss.
So far we have mostly been using diagrammatic techniques to understand the leading order scaling of correlation functions. It is interesting to try to understand finite width networks by asking how training dynamics are modified beyond the leading order asymptotics. There are three clear sources of corrections to the leading behavior.
The initial kernel, , receives finite width corrections.
The linear (in learning rate) update for the network function, equation (72), is modified away from infinite width.
The kernel is not constant at finite width.
The first source of corrections is automatically taken into account in typical empirical settings, as the finite-width can be computed explicitly. The other sources of corrections are non-trivial, and will be the focus of this section. We begin by explaining how to take into account the non-constancy of the kernel order by order in , while maintaining the continuous time approximation. Next, we back away from the continuous time limit and write down the full discrete evolution. We introduce a method to compute arbitrary order corrections, and explicitly work out the corrections at order for MSE loss.
Continuous time evolution of the model function in neural networks is governed by the differential equation
As we have discussed at length, in the large width limit, this equation simplifies and is asymptotically constant . Our goal is to move beyond this leading behavior at large width and solve (96) order-by-order in a expansion. In Section 3 we described how to compute the corrections for the case of MSE loss. Here we explain how to handle corrections more generally and give a more detailed discussion of the MSE case.
The network map and the kernel are members of the following family of operators,
with . Here, as above, is the kernel. With a general loss function, these operators satisfy
Equations (97) and (98) give an infinite tower of first order ODEs, the solution of which gives the time evolution of the network map and the kernel.
Solving this infinite tower is not feasible. If we wish to work to a given order in an expansion in , however, there is a dramatic simplification which makes a solution possible. First, we can truncate these equations. To see this, note that the operators contain contracted derivative tensors. As a result, by Lemma 6 and Conjecture 1, correlation functions involving the operators satisfy
where is an arbitrary additional contribution to the integrand.
Thus, if we wish to solve for the time evolution up to corrections which scale as , we can truncate the tower at and set to be equal to its initial value. Note that the leading order solution, (96), is the result of this procedure with , and the results presented in the main text for the evolution correspond to .
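To make the truncation concrete, the sketch below integrates a truncated tower in the simplest toy setting: a single training point with MSE loss, so that the residual r = f - y, the kernel, and the next operator in the family (called `o3` here) are all scalars. The sign conventions dr/dt = -Θ r and dΘ/dt = -O_3 r, the choice to freeze `o3` at its initial value, and all numerical values are assumptions made for illustration. In this toy system Θ^2 - 2 O_3 r is conserved, which gives a closed-form late-time kernel to compare against.

```python
import numpy as np

def integrate_truncated_tower(theta0, o3, r0, t_max=30.0, dt=1e-3):
    """RK4 integration of dr/dt = -theta * r, dtheta/dt = -o3 * r, with o3 held fixed."""
    def rhs(state):
        r, theta = state
        return np.array([-theta * r, -o3 * r])

    state = np.array([r0, theta0], dtype=float)
    for _ in range(int(round(t_max / dt))):
        k1 = rhs(state)
        k2 = rhs(state + 0.5 * dt * k1)
        k3 = rhs(state + 0.5 * dt * k2)
        k4 = rhs(state + dt * k3)
        state = state + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
    return state  # (residual, kernel) at t_max

theta0, o3, r0 = 1.0, 0.05, 1.0               # hypothetical initial data, o3 << theta0
r_final, theta_final = integrate_truncated_tower(theta0, o3, r0)

theta_exact = np.sqrt(theta0**2 - 2.0 * o3 * r0)   # from the conserved quantity
theta_first_order = theta0 - o3 * r0 / theta0      # first order in o3
```

The difference between `theta_exact` and `theta_first_order` is second order in `o3`, which mirrors the idea that treating the higher operator perturbatively reproduces the truncated dynamics up to higher-order corrections.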
The truncation provides a dramatic simplification; however, it is not immediately clear how to solve even the truncated differential equations in (98). We now describe how to organize the perturbative expansion of the operators (including and ) in such a way that the differential equations become tractable.
The central idea is to write each operator, , as an expansion.
where each order captures the evolution of . For example,
The notation means both that any correlation function containing is and, by Lemma 5, that typical realizations of the operators scale as . Note that in Section 3 we used the alternate notation for .
Once the operators are organized in this way, solving the differential equations (96) is tractable. As the differential equations describing the evolution of for only depend on the time dependent solutions of , we can iteratively solve for the order by order. For example, in Section 3, we used the leading order solution for and to solve for .
In principle this procedure can be extended to arbitrary order in . Before going on to explain the finite step corrections to this procedure, we reproduce the results of Section 3 in more detail.
As such the update equation simplifies to,
The leading order solution to equation (106) is exponential, constant kernel evolution,
Here we are using a condensed notation where , . Equation (107) is a vector equation: , , and are vectors over the training set, and the leading order kernel is a square matrix over the same space.
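As a concrete illustration of the leading order, constant kernel solution in this vector notation, the sketch below evaluates f(t) = y + exp(-Θ_0 t)(f(0) - y) through the eigen-decomposition of a symmetric kernel. The explicit form of the solution, the absence of extra normalization factors, and the random stand-ins `theta0`, `f_init`, and `y` are assumptions of this sketch.

```python
import numpy as np

def leading_order_f(t, theta0, f_init, y):
    """f(t) = y + exp(-theta0 * t) (f_init - y), using the eigenbasis of symmetric theta0."""
    lam, U = np.linalg.eigh(theta0)           # eigenvalues and orthonormal eigenvectors
    coeffs = U.T @ (f_init - y)               # initial residual in the eigenbasis
    return y + U @ (np.exp(-lam * t) * coeffs)

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 6))
theta0 = A @ A.T / 6.0                        # hypothetical PSD kernel on 4 training points
f_init = rng.normal(size=4)                   # network outputs at initialization
y = rng.normal(size=4)                        # training targets

f_at_1 = leading_order_f(1.0, theta0, f_init, y)
```

Each eigen-direction of the residual decays independently at a rate set by the corresponding eigenvalue, which is what makes the later integrals over time tractable.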
As discussed above, to study the evolution we can truncate the set of equations in (98) at , and set up to corrections which scale as .
To solve for , i.e. neglecting corrections, we can plug in the leading order approximations to and .
Using the explicit form of we can write,
Here we are again using a condensed notation, . The quantities and are vectors over the training set, while is a square matrix; is also a vector over the training set, with the value on the point .
Plugging in , and using the eigen-decomposition of to perform the integrals gives,
Here, we have introduced the eigenvectors, , which are vectors over the training dataset and eigenvalues of the initial kernel, . The vector as depends on two inputs. is a vector over the dataset while is a square matrix and is a vector over the training data.
Finally, we can plug the expression (E.4.1) into (106) to obtain the sub-leading behavior for
This completes the derivation of the full time dependence. This prediction for the time dependence is confirmed experimentally in Figure 10.
Note that one consequence of (114) is that the corrections to go to a constant at late times,
and the asymptotic evolution of is again constant kernel evolution, but with a corrected kernel.
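The late-time behavior can be seen explicitly in a toy case with a single training point and MSE loss, where every quantity is a scalar. The symbols below are assumptions introduced for this sketch: \Theta_0 is the initial kernel, O_3 is the next operator in the family held at its initial value, r_0 = f(0) - y is the initial residual, and the sign convention d\Theta/dt = -O_3 (f - y) is assumed.

```latex
\begin{align}
r^{(0)}(t) &= r_0\, e^{-\Theta_0 t}, \\
\Theta^{(1)}(t) &= -\int_0^t O_3\, r^{(0)}(s)\, ds
  = -\frac{O_3\, r_0}{\Theta_0}\left(1 - e^{-\Theta_0 t}\right)
  \;\xrightarrow[t \to \infty]{}\; -\frac{O_3\, r_0}{\Theta_0}.
\end{align}
```

In this scalar sketch the kernel correction saturates exponentially fast, so at late times the residual evolves under the shifted constant kernel \Theta_0 - O_3 r_0 / \Theta_0.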
It is worth noting that these predictions are quite detailed, giving the full time dependence rather than just the overall scaling with width. In deriving these results, and in the discrete time analysis below, we implicitly assume that we can control the discrete time corrections or, in the continuous time case, that , and all the are differentiable. This relies on being able to differentiate through the activation function. For non-smooth activations such as ReLU, this assumption is suspect; indeed, the evolution is more subtle in the ReLU case, and a full analysis is left to future work. For this reason we present experimental results for networks with smooth activation functions.
E.4.2 Discrete time.
With the corrections to continuous time evolution under our belts, it is possible to also keep track of corrections coming from the discrete update step. The essential point, that it is possible to solve iteratively, order by order in , is not changed. To see this, we can look at how the update equations are modified. Beginning with the network output update equation, we have,
In Section E.2 we argued that the terms vanish at large width. In more detail, each factor of the gradient, , contains a contracted derivative tensor. Thus, the cluster graph for correlation functions involving the term will always contain three contracted vertices and thus the correlation functions will scale as . The term will give four vertices, and thus also scale as , the contribution will be and so on.
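The counting of gradient factors above can be read off from the Taylor expansion of the network function under one gradient descent step. The expansion below is a sketch assuming the plain update \theta_{t+1} = \theta_t - \eta\, \partial L / \partial \theta with total loss L = \sum_{x'} \ell(x'); the symbols \eta, L, \ell and \Theta are introduced here for illustration.

```latex
\begin{align}
f_{t+1}(x) - f_t(x)
  &= \sum_{k \ge 1} \frac{(-\eta)^k}{k!}
     \sum_{\mu_1, \dots, \mu_k}
     \frac{\partial^k f_t(x)}{\partial\theta^{\mu_1} \cdots \partial\theta^{\mu_k}}\,
     \frac{\partial L}{\partial\theta^{\mu_1}} \cdots \frac{\partial L}{\partial\theta^{\mu_k}} \\
  &= -\eta \sum_{x'} \Theta(x, x')\, \frac{\partial \ell(x')}{\partial f}
     + \frac{\eta^2}{2} \sum_{\mu, \nu}
       \frac{\partial^2 f_t(x)}{\partial\theta^{\mu} \partial\theta^{\nu}}\,
       \frac{\partial L}{\partial\theta^{\mu}}\, \frac{\partial L}{\partial\theta^{\nu}}
     - \dots
\end{align}
```

Each additional power of \eta brings one more gradient factor, and hence one more contracted derivative tensor, which is the origin of the increasing suppression at large width.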
Just as in the continuous time case this equation is still solvable order by order in . What we mean by solvable is the following. All terms on the RHS appearing with are already higher order as a result of their derivative structure. This means, just as in the continuous case, we can use the lower order solution of in the gradient terms on the RHS to solve for on the LHS.
Similarly, the update equation for receives discrete time modifications.
Here we have adopted a notation where . Just as for , higher powers of are increasingly suppressed in , and we can iteratively solve this equation, using solutions at leading orders in to solve for sub-leading behavior.
More generally the tower of differential equations defined recursively in (98) becomes
As in the continuous time case, at a given order in , we can truncate this tower and use lower order solutions to solve for the time evolution up to the desired order. To ground this discussion in a concrete use case, we again walk through the corrections for MSE loss, now in discrete time.
For discrete time evolution, the leading order behavior of the model map is
At next order we can proceed by solving the truncated set of equations in (E.4.2). The first equation, for , still gives a constant solution,
Here, both the discrete and continuous time terms that would appear on the RHS are .
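With the kernel held at its initial value, the leading order discrete time evolution for MSE loss can be written in closed form and checked against direct iteration. The sketch below assumes the update f_{t+1} = f_t - η Θ_0 (f_t - y) without extra normalization factors; `theta0`, `f_init`, `y`, and the learning rate `eta` are hypothetical stand-ins.

```python
import numpy as np

def discrete_leading_order(t, eta, theta0, f_init, y):
    """Closed form f_t = y + (1 - eta * theta0)^t (f_init - y) via the eigenbasis of theta0."""
    lam, U = np.linalg.eigh(theta0)
    coeffs = U.T @ (f_init - y)               # initial residual in the eigenbasis
    return y + U @ ((1.0 - eta * lam) ** t * coeffs)

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 6))
theta0 = A @ A.T / 6.0                        # hypothetical PSD kernel on 4 training points
f_init = rng.normal(size=4)
y = rng.normal(size=4)
eta, steps = 0.1, 25

# Direct iteration of the constant kernel gradient descent update.
f_iter = f_init.copy()
for _ in range(steps):
    f_iter = f_iter - eta * theta0 @ (f_iter - y)

f_closed = discrete_leading_order(steps, eta, theta0, f_init, y)
```

The continuous time exponential is replaced by powers of (1 - η λ) in each eigen-direction, so the two descriptions agree when η λ is small.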
Next, we must solve for . The discrete time update is,
To order we can drop the order and higher terms, as they contain expressions with 5 or more contracted derivative tensors. Thus at this order, we are left with the discrete analogue of (109).
which we can sum to get the discrete version of (111),
Here we have adopted notation similar to above, where is a vector over the training dataset.
So far, this procedure has been a discrete analogue of what we have done in the continuous time case. However, as we move on to compute and , we will have to keep track of the novel corrections, which vanish in continuous time. Explicitly, at order the discrete update for is given by,
This can be summed to give .
To move up to the neural network map itself, there is one additional complication. The term, , in (119) has non-trivial time dependence. We can deal with this just as we have been doing, by taking an extra time derivative and integrating. Defining,
This can in turn be plugged into the update equation for .
Here are elements of the training set and are summed over. This equation can be solved to give .
where contains the discrete derivative updates at .
These expressions may look fairly intimidating. The key point is that all terms in the summand in (129) and thus (134) are known functions of time and initial data, just as in the continuous time setting.