Tensor Programs IIb: Architectural Universality of Neural Tangent Kernel Training Dynamics
Greg Yang, Etai Littwin
Introduction
The pioneering work of Jacot et al. (2018) showed that a multi-layer perceptron (MLP) trained by gradient descent (GD) evolves like a linear model. This spurred a flurry of research papers using this insight to tackle the core questions in deep learning theory, from optimization to generalization, in both finite and infinite width regimes. The argument of Jacot et al. (2018) consists of two observations:
ntkInit: For the output $f(\xi; w)$ of a network with parameters $w$ given example $\xi$, (Jacot et al., 2018) identified the kernel $\mathcal{K}(\xi,\bar\xi)=\langle\nabla_w f(\xi;w),\nabla_w f(\bar\xi;w)\rangle$, known as the Neural Tangent Kernel (NTK). They showed that if $f$ is parametrized and initialized appropriately, then $\mathcal{K}$ converges to a deterministic kernel $\mathring{\mathcal{K}}$ as the width of the network tends to infinity.
ntkTrain: As the infinitely wide network is trained by gradient descent, the NTK remains frozen in its initial state, and the network evolves by kernel gradient descent with kernel $\mathring{\mathcal{K}}$.
In (Yang, 2020a), the ntkInit property was proven to hold for standard architectures, meaning any composition of MLPs, recurrent neural networks (RNN), LSTMs (Hochreiter & Schmidhuber, 1997), gated recurrent unit (GRU) (Cho et al., 2014), convolutions (Fukushima, 1980, 1975; Lecun et al., 1998, 2000; Rumelhart et al., 1986), residual connections (He et al., 2016; Huang et al., 2017), batch normalization (Ioffe & Szegedy, 2015), graph neural networks (Bruna et al., 2014; Defferrard et al., 2016; Duvenaud et al., 2015; Henaff et al., 2015; Kipf & Welling, 2017) and attention (Bahdanau et al., 2015; Vaswani et al., 2017), along with arbitrary weight sharing between components. More generally, it holds for any architecture expressible in a so-called Tensor Program (Yang, 2019b, a, 2020a, 2020b), of which the standard architectures are a subset. However, their reasoning is limited to initialization only.
A statement is architecturally universal if it holds for any reasonable neural architecture. This is an informal property, but here we will formalize it by taking reasonable to be “expressible in Tensor Programs.” By the expressiveness of such programs (Yang, 2019a, 2020a), architectural universality is a fairly robust notion that covers present (and, we expect, future) architectures comprehensively. In this terminology, (Yang, 2020a) showed that ntkInit is architecturally universal.
We show the architectural universality of the entire NTK theory by proving ntkTrain for the same architectures discussed above, including all standard architectures. In the process, we introduce a new graphical form of Tensor Programs that is both required in our proofs and useful for the pedagogy of Tensor Programs.
This paper follows (Yang, 2019b, a, 2020a, 2020b; Yang & Hu, 2020) in the series. While we number this paper “IIb” right after (Yang, 2020a), we actually need the complete theoretical foundation developed in III (Yang, 2020b). See Footnote 22 for more details.
Background
Let $f(\xi; w)$ denote the (scalar) output of a neural network parameterized by $w$, given example $\xi$. To understand how the output changes with a slight change $\delta w$ in the network parameters, we may naively expand the network function using the first order Taylor expansion around a base point $w_0$:
$$f(\xi; w_0+\delta w)-f(\xi; w_0)\approx\langle\nabla_w f(\xi; w_0),\delta w\rangle. \tag{1}$$
Under the SGD algorithm, the weight update is given by the gradient $\delta w=-\eta\,\chi(\hat\xi)\,\nabla_w f(\hat\xi; w_0)$, where $\chi(\hat\xi)$ is the loss derivative, $\hat\xi$ is a sample from the training set, and $\eta$ is the learning rate. Plugging into Eq. 1, we get:
$$f(\xi; w_0+\delta w)-f(\xi; w_0)\approx-\eta\,\chi(\hat\xi)\,\mathcal{K}(\xi,\hat\xi), \tag{2}$$
where $\mathcal{K}(\xi,\hat\xi)=\langle\nabla_w f(\xi; w_0),\nabla_w f(\hat\xi; w_0)\rangle$ is the NTK. The NTK theory of infinitely wide neural networks as first proposed by (Jacot et al., 2018) boils down to the following observations: When the width of $f$ tends to infinity, the NTK $\mathcal{K}$ converges to a fixed kernel $\mathring{\mathcal{K}}$ at random initialization, independent of the specific instantiation of the weights, and remains frozen during the optimization process. Eq. 2 then gives an accurate description of the output evolution at every step if we substitute $\mathcal{K}$ with $\mathring{\mathcal{K}}$. The seemingly complex optimization trajectory of SGD therefore reduces to the convex trajectory of kernel gradient descent with a time-independent kernel $\mathring{\mathcal{K}}$. Consider the output of the network on the full training dataset. As shown in (Jacot et al., 2018), when the $L^2$ loss is used, the evolution of the output $f_t$ at time $t$ under continuous time GD (i.e. gradient flow) takes a simple form:
$$f_t-f^\star=e^{-\eta\mathring{\mathcal{K}}t}\,(f_0-f^\star),$$
where $\mathring{\mathcal{K}}$ is the full NTK matrix evaluated on the training data, $f^\star$ is the label function, and $f_0$ is the output at initialization. Hence, provided $\mathring{\mathcal{K}}$ is full rank, as $t\to\infty$ we have that $f_t\to f^\star$, and the network can fit the training data perfectly.
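The closed-form gradient-flow solution above can be checked numerically. The following is a minimal sketch, assuming the loss is $\frac12\|f-f^\star\|^2$ so that gradient flow reads $\dot f_t=-\eta\,\mathring{\mathcal{K}}(f_t-f^\star)$; the kernel matrix, initial outputs, and labels are hypothetical random placeholders, not quantities computed from any network.

```python
import numpy as np

def gradient_flow_outputs(K, f0, f_star, t, eta=1.0):
    """Closed-form outputs f_t = f_star + exp(-eta * K * t) (f0 - f_star)."""
    # K is symmetric PSD, so its matrix exponential follows from the eigendecomposition.
    evals, evecs = np.linalg.eigh(K)
    decay = np.exp(-eta * evals * t)
    return f_star + evecs @ (decay * (evecs.T @ (f0 - f_star)))

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
K = A @ A.T + 0.5 * np.eye(5)    # hypothetical full-rank kernel matrix
f0 = rng.standard_normal(5)      # hypothetical outputs at initialization
f_star = rng.standard_normal(5)  # hypothetical labels

for t in [0.0, 1.0, 10.0, 100.0]:
    ft = gradient_flow_outputs(K, f0, f_star, t)
    print(t, np.linalg.norm(ft - f_star))  # residual shrinks as t grows
```

Because every eigenvalue of the placeholder kernel is at least 0.5, the residual decays at least as fast as $e^{-0.5\eta t}$, which is the full-rank condition in the text at work.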
A common theme in showing ntkTrain for MLPs is to derive high-probability bounds on the deviation of the NTK from its initial value after training (e.g. Allen-Zhu et al. (2018); Du et al. (2018); Zou et al. (2018)). [Footnote 1: In the original NTK paper (Jacot et al., 2018), the limit is taken as each layer width goes to infinity sequentially, which already does not make sense for weight-tied architectures like RNNs.] Obtaining these bounds usually requires developing ad hoc methods on a per-architecture basis, hindering the scalability of the method to other settings. In the present work we take a more holistic approach, leveraging the recently developed Tensor Programs framework (Yang, 2019b, a, 2020a, 2020b). It consists of two layers of arguments: 1) The bottom layer analyzes how the distribution of (pre-)activations changes throughout the course of training; this crucially leverages the mathematical machinery of the Tensor Programs Master Theorem. [Footnote 2: In particular, we need to use the Master Theorem in (Yang, 2020b), so (Yang, 2020a) could not have obtained ntkTrain at the same time as ntkInit.] 2) The top layer packages these insights systematically via the notion of paths so as to apply to any architecture expressible by a Tensor Program. We will illustrate 1) through examples in Section 3 and 2) through figures in Section 5.1.
In this paper, we will consider the architecture (including depth), data, and training time to be fixed as width $n\to\infty$. [Footnote 3: They will affect the rate of convergence to the infinite-width limit, but since we are only concerned with whether convergence occurs, they do not appear in our theorem statements here.] We describe common notations used in the remainder of the paper. For simplicity, we will consider SGD with batch size 1 and learning rate $\eta$ (often set to 1 WLOG). [Footnote 4: This generalizes readily to any batch size and learning rate.] We use $\xi_t$ to denote the input and $\mathcal{L}_t$ to denote the loss function (absorbing the label) at step $t$. More generally, subscript $t$ on any symbol means time $t$. However, for brevity, we abuse notation and shorthand $f_t$ for $f_t(\xi_t)$, and, for any (pre-)activation $x$, $x_t$ for $x_t(\xi_t)$. [Footnote 5: We will not refer to the function $f_t:\xi\mapsto f_t(\xi)$ (likewise for $x_t$), so this abuse of notation should cause no confusion.] We will also write $\chi_t$ for the loss derivative $\mathcal{L}_t'(f_t)$. For any vector $x(\xi)$ we define $\delta x_{t+1}(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_t(\xi)\big)$ and $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\,\frac{\partial f(\xi)}{\partial x(\xi)}$. We will track the evolution of $f$ on an arbitrary input $\tilde\xi$. [Footnote 6: It might help to think of $\tilde\xi$ as some test sample, but it can also fall in the training set.] Similar to above, we shorthand $\tilde x_t,\tilde f_t$ for $x_t(\tilde\xi),f_t(\tilde\xi)$.
Motivating Examples
The purpose of this section is to illustrate our key ideas via simple, intuitive examples without diving into the specifics of Tensor Programs. In the process, we will gain insight into how randomness from initialization propagates over the course of training. As these examples intend to provide the reader with the proper intuition, we use informal arguments alone and relegate all formal statements to the appendix. For brevity, we will gloss over minor details or routine calculations, but interested readers can see Appendix A for these omissions.
It turns out that the random initialization and the overparametrization of weights cause each (pre-)activation vector $x_t(\xi)\in\mathbb{R}^n$, its gradient $dx_t(\xi)\in\mathbb{R}^n$, and its (scaled) change $\delta x_t(\xi)\in\mathbb{R}^n$ at every time step $t$ to have roughly iid coordinates, not just initially but throughout training. [Footnote 7: This is a consequence of the Tensor Program Master Theorem.] Then, as we shall demonstrate through the examples below, to track the evolution of the neural network function, it suffices to track the evolution of the coordinate distributions of $x(\xi)$, $dx(\xi)$, $\delta x(\xi)$. We write $Z^{x(\xi)}$, $Z^{dx(\xi)}$, $Z^{\delta x(\xi)}$ for the random variables corresponding to such coordinate distributions. [Footnote 8: As we will explain below, different $Z$s may correlate, reflecting correlations between corresponding vectors.]
Our goal is to derive, from these insights, the following claim: in the large width limit, $\tilde f_t=f_t(\tilde\xi)$ changes by
$$\lim_{n\to\infty}\tilde f_{t+1}-\tilde f_t=-\mathring\chi_t\,\mathring{\mathcal{K}}(\tilde\xi,\xi_t) \tag{3}$$
at step $t$, where $\mathring{\mathcal{K}}$ is the limiting NTK of the architecture and $\mathring\chi_t=\mathcal{L}_t'(\lim_{n\to\infty}f_t)$ is the loss derivative.
We start with an example derivation for 1-hidden-layer MLP, before moving on to 2-hidden-layers, where the mathematics quickly become much more involved.
3.1 1 Hidden Layer
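Consider a 1-hidden-layer network with nonlinearity $\phi$:
$$f=V^\top x,\qquad x=\phi(h),\qquad h=U\xi,$$
where $\xi\in\mathbb{R}^d$, $U=\frac{1}{\sqrt d}u\in\mathbb{R}^{n\times d}$, $V=\frac{1}{\sqrt n}v\in\mathbb{R}^{n}$, for trainable parameter tensor $u$, initialized iid from $\mathcal{N}(0,1)$. In the interest of clarity we assume the output layer is not trained, and $d=1$.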
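For a vector $x\in\mathbb{R}^n$, let $x=\Theta(n^a)$ mean that “$x$ has coordinates of order $n^a$ when $n$ is large” [Footnote 9: More rigorously, we mean that $\|x\|^2/n=\Theta(n^{2a})$. Note this is different from the common interpretation that $\|x\|=\Theta(n^a)$.]; likewise for $O(n^a)$, $o(n^a)$, etc. Recall the notations $x_t=x_t(\xi_t)$, $\tilde x_t=x_t(\tilde\xi)$, $\delta x_{t+1}=\sqrt n\,(x_{t+1}-x_t)$ and likewise for $h_t$, $\tilde h_t$, $\delta h_{t+1}$. The key insights are as follows: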
It turns out $\tilde h_{t+1}-\tilde h_t=O(1/\sqrt n)$, so $\delta\tilde h_{t+1}$ has $O(1)$ coordinates. Likewise for $\delta\tilde x_{t+1}$. Consequently, for any $t$ and input $\xi$, by telescoping,
$$h_t(\xi)=h_0(\xi)+o(1). \tag{4}$$
Using $\nabla_u f=\frac{1}{\sqrt n}\,dh_t\,\xi_t^\top$ and $dh_t=\phi'(h_t)\odot v$, we have:
$$\delta\tilde h_{t+1}=-\chi_t\,\xi_t^\top\tilde\xi\,\phi'(h_t)\odot v. \tag{5}$$
Also, $\delta\tilde x_{t+1}=\sqrt n\big(\phi(\tilde h_t+\frac{\delta\tilde h_{t+1}}{\sqrt n})-\phi(\tilde h_t)\big)$. Since $\delta\tilde h_{t+1}=O(1)$, by intuitive Taylor expansion, we have
$$\delta\tilde x_{t+1}\approx\phi'(\tilde h_t)\odot\delta\tilde h_{t+1}. \tag{6}$$
The change in the output on example $\tilde\xi$ from step $t$ to step $t+1$ is given by:
$$\tilde f_{t+1}-\tilde f_t=V^\top(\tilde x_{t+1}-\tilde x_t)=\frac{v^\top\delta\tilde x_{t+1}}{n}. \tag{7}$$
By definition $v$ has iid coordinates. It turns out $h_t(\xi),\delta h_t(\xi)$ (likewise for $x$) all have approx. iid coordinates of size $\Theta(1)$ as well. [Footnote 10: Technically, they have iid coordinates only after conditioning on the initial function (GP) $f_0$. Likewise, when we take expectation in this example (e.g. Eqs. 9 and LABEL:{eqn:dot}), it’s a conditional expectation of this kind. See Section D.1.1 for a rigorous treatment. However, to convey the main intuition, we gloss over this technicality here.] Let $Z^{v},Z^{\delta\tilde x_{t+1}}$ denote the random variables encoding the corresponding coordinate distributions; likewise for the other vectors. Note that $Z^{v},Z^{\delta\tilde x_{t+1}}$ will in general be correlated, reflecting the coordinatewise correlation between $v$ and $\delta\tilde x_{t+1}$.
By Eq. 7 and the Law of Large Numbers, together with Eqs. 5 and 6,
$$\lim_{n\to\infty}\tilde f_{t+1}-\tilde f_t=\mathbb{E}\,Z^{v}Z^{\delta\tilde x_{t+1}} \tag{8}$$
$$=-\mathring\chi_t\,\xi_t^\top\tilde\xi\;\mathbb{E}\big[\phi'(Z^{\tilde h_t})\,\phi'(Z^{h_t})\,(Z^{v})^2\big], \tag{9}$$
where $\mathring\chi_t=\mathcal{L}_t'(\lim_{n\to\infty}f_t)$ as in Section 3.
By Eq. 4, in the limit, $Z^{\tilde h_t}=Z^{\tilde h_0}$ and $Z^{h_t}=Z^{h_0}$. They are independent from $Z^{v}$ and jointly Gaussian with variances $\|\tilde\xi\|^2,\|\xi_t\|^2$ and covariance $\xi_t^\top\tilde\xi$. So (using the initialization of $v$ to simplify $\mathbb{E}(Z^{v})^2=1$),
$$\lim_{n\to\infty}\tilde f_{t+1}-\tilde f_t=-\mathring\chi_t\,\xi_t^\top\tilde\xi\;\mathbb{E}\big[\phi'(Z^{\tilde h_0})\,\phi'(Z^{h_0})\big]. \tag{10}$$
This can easily be seen to be Eq. 3 (recall we assumed for simplicity the output layer is not trained).
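The 1-hidden-layer limit can be sanity-checked by simulation. The sketch below assumes $d=1$, $\phi=\tanh$, $v$ and $u$ sampled iid from $\mathcal{N}(0,1)$, a squared loss $\frac12(f-y)^2$, and learning rate 1 on $u$ only; the inputs, label, and width are arbitrary illustrative choices. It compares the observed change in $f(\tilde\xi)$ after one SGD step with the Gaussian expectation $-\chi_t\,\xi_t\tilde\xi\;\mathbb{E}[\phi'(\tilde\xi Z)\phi'(\xi_t Z)]$, $Z\sim\mathcal{N}(0,1)$, evaluated by Gauss-Hermite quadrature.

```python
import numpy as np

def phi(h):
    return np.tanh(h)

def dphi(h):
    return 1.0 - np.tanh(h) ** 2

def forward(u, v, xi):
    # f(xi) = v^T phi(u * xi) / sqrt(n); with d = 1 the input matrix is U = u.
    return v @ phi(u * xi) / np.sqrt(u.size)

def sgd_step_on_u(u, v, xi_t, chi_t):
    # grad_u f(xi_t) = v * phi'(u * xi_t) * xi_t / sqrt(n); learning rate 1.
    grad = v * dphi(u * xi_t) * xi_t / np.sqrt(u.size)
    return u - chi_t * grad

def limit_change(xi_t, xi_tilde, chi_t, deg=80):
    # Gaussian expectation with Z^{h_0(xi)} = xi * Z for Z ~ N(0, 1),
    # computed with probabilists' Gauss-Hermite nodes and weights.
    nodes, weights = np.polynomial.hermite_e.hermegauss(deg)
    integrand = dphi(nodes * xi_tilde) * dphi(nodes * xi_t)
    expectation = np.sum(weights * integrand) / np.sqrt(2.0 * np.pi)
    return -chi_t * xi_t * xi_tilde * expectation

rng = np.random.default_rng(0)
n = 200_000                      # width (illustrative)
u = rng.standard_normal(n)       # trainable input weights
v = rng.standard_normal(n)       # frozen output weights
xi_t, xi_tilde, label = 0.8, -0.5, 1.0

chi_t = forward(u, v, xi_t) - label   # derivative of 0.5 * (f - label)^2
u_next = sgd_step_on_u(u, v, xi_t, chi_t)
observed = forward(u_next, v, xi_tilde) - forward(u, v, xi_tilde)
predicted = limit_change(xi_t, xi_tilde, chi_t)
print(observed, predicted)
```

The finite-width value of $\chi_t$ is plugged into the prediction, so the comparison isolates the kernel part of the update; the gap between the two printed numbers should shrink roughly like $1/\sqrt n$.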
Our strategy so far has been computing the form of $Z^{\delta\tilde x_{t+1}}$ and plugging it into Eq. 8 to compute the dynamics of the output in the limit. Note that our approach differs from previous work, which mainly focuses on proving a bound on the change of the NTK after training. As the architecture gets more complex, bounding the NTK movement becomes quite difficult, but our approach easily scales due to the automation provided by the Tensor Programs framework (see Section 4).
In the previous example, the coordinate distribution $Z^{\delta\tilde x_{t+1}}$ took a fairly simple form, which allowed us to intuitively compute the expectation $\mathbb{E}\,Z^{\delta\tilde x_{t+1}}Z^{v}$. Before introducing a method for computing coordinate distributions in a general architecture, we move on to a slightly more involved architecture, with the intention of highlighting the intuition behind the general case.
3.2 2 Hidden Layers
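In this example we consider a model of the form:
$$f=V^\top x,\qquad x=\phi(h),\qquad h=Wz,\qquad z=\phi(g),\qquad g=U\xi,$$
where $U=\frac{1}{\sqrt d}u\in\mathbb{R}^{n\times d}$, $W=\frac{1}{\sqrt n}w\in\mathbb{R}^{n\times n}$, $V=\frac{1}{\sqrt n}v\in\mathbb{R}^{n}$, for trainable parameters $u,w$, initialized iid from a normal distribution. As before we assume the last layer is not trained, and $d=1$.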
Again, we want to establish the claim of Section 3 (Eq. 3) for this model. As in the 1-hidden-layer example, the dynamics of the output in the limit is still given by Eq. 8. This time around, the second hidden layer adds nontrivial complexity when evaluating the expectation $\mathbb{E}\,Z^{v}Z^{\delta\tilde x_{t+1}}$. As we shall see, this complexity arises from the dependency of $\tilde x$ on the matrices $W$ and $W^\top$, which will make it wrong to naively apply LLN arguments. Resolving this complexity will pave the way to a general strategy which we will then be able to apply to arbitrary architectures. We now apply the same insights as in the 1-hidden-layer MLP. Namely:
Eq. 4 continues to hold with $h$ replaced by any of $x,h,z,g$.
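After some brief calculations, with $dh_t$ denoting the scaled gradient $\sqrt n\,\partial f/\partial h_t$,
$$\delta\tilde g_{t+1}\approx-\chi_t\,\xi_t^\top\tilde\xi\,\phi'(g_t)\odot\big(W^\top dh_t\big) \tag{11}$$
$$\delta\tilde h_{t+1}\approx W\delta\tilde z_{t+1}-\chi_t\,\frac{z_t^\top\tilde z_t}{n}\,\phi'(h_t)\odot v \tag{12}$$
• As in the 1-hidden-layer case, for all $y\in\{g,z,h,x\}$, the vectors $y_t(\xi),\delta y_t(\xi)$ have iid coordinates of size $\Theta(1)$, as does $v$ by definition. [Footnote 11: Again, this is technically true only after conditioning on $f_0$; see Footnote 10.] Let $Z^{y_t(\xi)},Z^{\delta y_t(\xi)},Z^{v}$ denote the (generally correlated) random variables encoding the corresponding coordinate distributions.
• As in Eq. 6, by naive Taylor expansion we have:
$$\delta\tilde z_{t+1}\approx\phi'(\tilde g_t)\odot\delta\tilde g_{t+1} \tag{13}$$
• Eqs. 7 and 6 in the 1-hidden-layer case continue to hold here. Then by Eq. 12 and Law of Large Numbers,
$$\lim_{n\to\infty}\tilde f_{t+1}-\tilde f_t=\mathbb{E}\,Z^{v}Z^{\delta\tilde x_{t+1}} \tag{14}$$
$$=-\mathring\chi_t\,\mathbb{E}\big[Z^{z_t}Z^{\tilde z_t}\big]\,\mathbb{E}\big[\phi'(Z^{\tilde h_t})\,\phi'(Z^{h_t})\,(Z^{v})^2\big] \tag{15}$$
$$\quad+\,\mathbb{E}\big[\phi'(Z^{\tilde h_t})\,Z^{W\delta\tilde z_{t+1}}\,Z^{v}\big], \tag{16}$$
where $\mathring\chi_t=\mathcal{L}_t'(\lim_{n\to\infty}f_t)$ as in Section 3.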
In this expression, the first term (Eq. 15) can easily be seen to correspond to the contribution from $w$ to the NTK. It remains to show that the second (Eq. 16) corresponds to the contribution from $u$.
To do this, we must reason about the coordinate distribution of $W\delta\tilde z_{t+1}$ (encoded by random variable $Z^{W\delta\tilde z_{t+1}}$) and compute the expectation in Eq. 16. To understand why this represents a greater challenge than it might first appear, note that from $\delta\tilde z_{t+1}\approx\phi'(\tilde g_t)\odot\delta\tilde g_{t+1}$ (Eq. 13), the term $W\delta\tilde z_{t+1}$ hides within itself a dependency on $W^\top dh_t$ through $\delta\tilde g_{t+1}$ (Eq. 11). While at $t=0$ we may assume $W^\top$ to be independent of $W$ and obtain the correct results (Gradient Independent Assumption (Yang & Schoenholz, 2017; Yang, 2020a)), this is no longer the case for $t>0$: $Z^{W\delta\tilde z_{t+1}}$ will be nontrivially correlated with $\phi'(Z^{\tilde h_t})$ and $Z^{v}$ (which would be false if $W^\top$ could be assumed independent of $W$). We will give some intuition why later in Eq. 20. Now, what is this dependence exactly? Claim: Based on the above discussion and some easy calculations, $\delta\tilde z_{t+1}$ can be written as $\Phi(W^\top dh_t)$ for some $\Phi:\mathbb{R}\to\mathbb{R}$ applied coordinatewise (which will depend on other vectors not of the form $W^\top\bullet$). Then it turns out
$$Z^{W\delta\tilde z_{t+1}}=G+Z^{dh_t}\,\mathbb{E}\,\frac{\partial Z^{\delta\tilde z_{t+1}}}{\partial Z^{W^\top dh_t}}, \tag{17}$$
where $G$ is some Gaussian variable independent from $Z^{v}$, and $\frac{\partial Z^{\delta\tilde z_{t+1}}}{\partial Z^{W^\top dh_t}}\overset{\text{def}}{=}\Phi'(Z^{W^\top dh_t})$.
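Thus, from Eqs. 13 and 11, it follows that:
$$Z^{\delta\tilde z_{t+1}}=-\mathring\chi_t\,\xi_t^\top\tilde\xi\,\phi'(Z^{\tilde g_t})\,\phi'(Z^{g_t})\,Z^{W^\top dh_t}, \tag{18}$$
$$\frac{\partial Z^{\delta\tilde z_{t+1}}}{\partial Z^{W^\top dh_t}}=-\mathring\chi_t\,\xi_t^\top\tilde\xi\,\phi'(Z^{\tilde g_t})\,\phi'(Z^{g_t}). \tag{19}$$
Plugging into Eqs. 14 and LABEL:{eqn:dot}, followed by some straightforward calculation, then yields the claim of Section 3 (Eq. 3).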
Eq. 17 may appear cryptic at first, so let’s give some intuition using an example. Suppose $\Phi$ in the claim above is actually identity. For brevity, we set $\mathbf{x}=dh_t$ and $\mathbf{y}=\delta\tilde z_{t+1}=W^\top\mathbf{x}$. Then, following straightforward calculation, $W\mathbf{y}=WW^\top\mathbf{x}$ has coordinates
$$(W\mathbf{y})_\alpha=\sum_{\gamma\neq\alpha}\mathbf{x}_\gamma\sum_{\beta=1}^{n}W_{\alpha\beta}W_{\gamma\beta}+\sum_{\beta=1}^{n}(W_{\alpha\beta})^2\,\mathbf{x}_\alpha. \tag{20}$$
Now, the second sum converges via LLN to $\mathbf{x}_\alpha$ as $n\to\infty$. On the other hand, the first sum will converge via CLT to $\mathcal{N}(0,\lim\|\mathbf{x}\|^2/n)$. Thus, in terms of $Z$s, we have
$$Z^{W\mathbf{y}}=G+Z^{\mathbf{x}}=G+Z^{\mathbf{x}}\,\mathbb{E}\,\Phi' \tag{21}$$
for some Gaussian $G$; this corresponds directly to Eq. 17. [Footnote 12: This example was worked out in (Yang, 2020a, b) as well, though in different contexts. Readers needing more explanation may see those works.] For general $\Phi$, a similar intuition applies after Taylor expansion of $\Phi$.
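The decomposition in Eq. 20 is easy to probe numerically. This sketch assumes $W$ has iid $\mathcal{N}(0,1/n)$ entries and takes $\mathbf{x}$ to be a standard Gaussian vector independent of $W$ (a stand-in for $dh_t$, chosen only for illustration). It compares $WW^\top\mathbf{x}$ with $W\bar W\mathbf{x}$ for an independent matrix $\bar W$ (written `W_indep`), which is what the Gradient Independent Assumption would amount to.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
W = rng.standard_normal((n, n)) / np.sqrt(n)        # entries N(0, 1/n)
W_indep = rng.standard_normal((n, n)) / np.sqrt(n)  # independent copy, for contrast
x = rng.standard_normal(n)                          # stand-in for dh_t

s_tied = W @ (W.T @ x)        # W y with y = W^T x, as in Eq. 20
s_indep = W @ (W_indep @ x)   # W^T replaced by an independent matrix

def slope(s, x):
    # Least-squares coefficient of s on x: the multiple of x hiding inside s.
    return (s @ x) / (x @ x)

slope_tied = slope(s_tied, x)
slope_indep = slope(s_indep, x)
residual_var = np.mean((s_tied - x) ** 2)  # variance of the Gaussian part G
print(slope_tied, slope_indep, residual_var, x @ x / n)
```

With tied weights the coefficient on $\mathbf{x}$ is close to 1 and the residual variance is close to $\|\mathbf{x}\|^2/n$, matching $Z^{W\mathbf{y}}=G+Z^{\mathbf{x}}$; with an independent matrix the coefficient is close to 0, so the correction term disappears.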
This 2-hidden-layer example proceeded much the same as the 1-hidden-layer case, with the main exception of analyzing the interaction of the $n\times n$ Gaussian matrix $W$ and $W^\top$ (Eq. 16) that occurs after taking at least 1 step of SGD. This was absent in the 1-hidden-layer case because each weight matrix has at most one side tending to infinity. Such analysis is crucial to obtaining the right results, as assuming $W^\top$ to be independent from $W$ would imply does not move from initialization. [Footnote 13: One can see this easily by modifying our calculations above.]
It turns out these two examples have essentially covered all of the core ideas needed to extend the analysis into arbitrary architectures. To formalize and scale up our calculations, we now turn to the Tensor Programs framework.
Tensor Programs
So far, our results have been obtained by unrolling the SGD updates on toy models with specific architectures, and using informal arguments. Obviously, these computations quickly become unmanageable when the architecture becomes more complex. The sheer number of architectural innovations that have sprung up in recent years requires us to adopt a much more general formulation of our results. To that end, we adopt the Tensor Programs (TP) framework developed in (Yang, 2019a, 2020a, 2020b). In a nutshell, it provides a language for describing typical computations done in the context of neural networks, such as forward and backward propagation. It is simultaneously simple and expressive, covering all standard architectures (Yang, 2019a, 2020a). Here we review two basic forms of Tensor Programs, $\textsc{Netsor}\top$ and $\textsc{Netsor}\top^{+}$.
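A $\textsc{Netsor}\top$ program is just a sequence of $\mathbb{R}^n$ vectors inductively generated via one of the following instructions from an initial set $\mathcal{V}$ of random $\mathbb{R}^n$ vectors and a set $\mathcal{W}$ of random $n\times n$ matrices:
Nonlin: For $x^1,\ldots,x^k\in\mathbb{R}^n$ in the program and any $\psi:\mathbb{R}^k\to\mathbb{R}$, we can generate $\psi(x^1,\ldots,x^k)\in\mathbb{R}^n$, with $\psi$ applied coordinatewise.
MatMul: Given $W\in\mathbb{R}^{n\times n}$ and $x\in\mathbb{R}^n$, we can generate $Wx\in\mathbb{R}^n$ or $W^\top x\in\mathbb{R}^n$.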
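The two instructions can be mirrored in a few lines of code. The sketch below is a toy illustration only: the `Program` class, its method names, and the string labels for matrices are hypothetical and are not part of the Tensor Programs formalism or any library. It builds the body of a small MLP from one initial vector using Nonlin and MatMul.

```python
import numpy as np

class Program:
    """Toy container for a program: a growing list of R^n vectors."""

    def __init__(self, initial_vectors, matrices):
        self.vectors = list(initial_vectors)  # initial set of vectors
        self.matrices = dict(matrices)        # label -> n x n matrix

    def nonlin(self, psi, *indices):
        # Nonlin: apply psi coordinatewise to previously generated vectors.
        self.vectors.append(psi(*(self.vectors[i] for i in indices)))
        return len(self.vectors) - 1

    def matmul(self, label, index, transpose=False):
        # MatMul: multiply a previously generated vector by W or W^T.
        W = self.matrices[label]
        self.vectors.append((W.T if transpose else W) @ self.vectors[index])
        return len(self.vectors) - 1

rng = np.random.default_rng(0)
n = 1000
prog = Program(
    initial_vectors=[rng.standard_normal(n)],  # e.g. an input embedding
    matrices={
        "W1": rng.standard_normal((n, n)) / np.sqrt(n),
        "W2": rng.standard_normal((n, n)) / np.sqrt(n),
    },
)

h1 = 0                           # index of the initial vector
x1 = prog.nonlin(np.tanh, h1)    # x1 = tanh(h1)
h2 = prog.matmul("W1", x1)       # h2 = W1 x1
x2 = prog.nonlin(np.tanh, h2)    # x2 = tanh(h2)
h3 = prog.matmul("W2", x2)       # h3 = W2 x2
print(len(prog.vectors), prog.vectors[h3].shape)
```

Backpropagation fits the same mold: `prog.matmul("W2", index, transpose=True)` produces a vector of the form $W^\top\bullet$, and the elementwise products with $\phi'$ are Nonlin steps with $k=2$.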
We propose to represent a program as a computational graph, where each node in the graph represents vectors (initial or generated), each (dashed) edge represents a MatMul, and each gate represents a Nonlin. For example, Fig. 1 shows the computation graphs expressing (the forward passes of) an MLP and an RNN. We can express backpropagation as well (see Fig. 6). Graphically, the initial vectors are the empty nodes with only one edge coming out, toward the direction of computation. The matrices correspond to (the labels of) the dashed edges. We can also define the output vectors to correspond to the nodes that have only one edge coming out, against the direction of computation.
Each $\textsc{Netsor}\top$ program can be thought of as computing a function that takes an instantiation of the initial vectors and matrices and computes the values of all output vectors. We can say a program represents a network if it computes the body of it (without the input and output layers), as exemplified by Fig. 1. This is formalized below.
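Consider a neural network $f:(\mathbb{R}^d)^k\to\mathbb{R}$ with input embedding matrices $U^1,\ldots,U^k\in\mathbb{R}^{n\times d}$ (not necessarily distinct) and readout matrix $V\in\mathbb{R}^n$, so that $f(\xi^1,\ldots,\xi^k)=V^\top\Phi(U^1\xi^1,\ldots,U^k\xi^k;\Theta)$ for some function $\Phi(x^1,\ldots,x^k;\Theta)$ with parameters $\Theta$. We say a $\textsc{Netsor}\top$ program represents $f$ if it computes $\Phi$ (under some correspondence of $\mathcal{V}\cup\mathcal{W}$ to $\{x^1,\ldots,x^k\}\cup\Theta$). [Footnote 14: We only consider $f$ with scalar output for simplicity but generalization to multi-dimensional output is straightforward.]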
For example, the programs in Fig. 1 resp. represent a 3-hidden-layer MLP and an RNN running for 3 steps. Note that the initial vectors correspond to a combination of input embeddings (e.g. ) and vector parameters (e.g. biases) and the matrices correspond to matrix parameters (e.g. weights).
Typically, the vectors (resp. matrices) in a program will be sampled iid like $\mathcal{N}(0,1)$ (resp. $\mathcal{N}(0,1/n)$), corresponding to the “standard” initialization of neural networks. [Footnote 15: In the original definition of (Yang, 2019a, 2020a, 2020b), the vectors can have correlations between them, but we can always rewrite these vectors as a linear image of another set of uncorrelated vectors.] In such cases, when $n\to\infty$, a program behaves as follows, in a nutshell:
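IID coordinates: Any vector $x\in\mathbb{R}^n$ in the program has roughly iid coordinates. We write $Z^{x}$ for the random variable encoding this coordinate distribution. This $Z^{x}$ may be correlated with $Z^{y}$ for another vector $y$ in the program, such that, for example, $\lim_{n\to\infty}x^\top y/n=\mathbb{E}\,Z^{x}Z^{y}$.
Nonlin: $Z^{\psi(x^1,\ldots,x^k)}=\psi(Z^{x^1},\ldots,Z^{x^k})$.
MatMul, without $(W,W^\top)$ interaction: Consider a matrix $W\in\mathbb{R}^{n\times n}$ in the program and any set $\mathcal{X}$ of $\mathbb{R}^n$ vectors not dependent on vectors of the form $W^\top\bullet$. Then the set of random variables $\{Z^{Wx}:x\in\mathcal{X}\}$ are jointly Gaussian with mean zero and covariance $\mathrm{Cov}(Z^{Wx},Z^{W\tilde x})=\mathbb{E}\,Z^{x}Z^{\tilde x}$ for any $x,\tilde x\in\mathcal{X}$. If $\bar W\neq W$ is another matrix in the program and $\bar{\mathcal{X}}$ is a set of such vectors w.r.t. $\bar W$, then the set $\{Z^{Wx}:x\in\mathcal{X}\}$ is independent from $\{Z^{\bar W y}:y\in\bar{\mathcal{X}}\}$.
MatMul, with $(W,W^\top)$ interaction: For general $x$, $Z^{Wx}$ decomposes into a sum of a Gaussian part, identical to $Z^{Wx}$ in the above case, and a correction term. This decomposition is a generalization of Eq. 17.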
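The MatMul rule without $(W,W^\top)$ interaction can be illustrated empirically. The sketch below assumes $W$ with iid $\mathcal{N}(0,1/n)$ entries and two vectors built by Nonlin from one Gaussian initial vector (the choices of tanh and ReLU are arbitrary). It compares the empirical covariance of the coordinates of $Wx$ and $W\tilde x$ against the empirical value of $\mathbb{E}\,Z^{x}Z^{\tilde x}$.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 2000
g = rng.standard_normal(n)                     # initial vector, coordinates N(0, 1)
W = rng.standard_normal((n, n)) / np.sqrt(n)   # matrix with entries N(0, 1/n)

x = np.tanh(g)                  # Nonlin applied to g
x_tilde = np.maximum(g, 0.0)    # another Nonlin applied to g

Wx, Wx_tilde = W @ x, W @ x_tilde

lhs = np.mean(Wx * Wx_tilde)    # empirical Cov(Z^{Wx}, Z^{W x~}); the means are ~0
rhs = np.mean(x * x_tilde)      # empirical E[Z^x Z^{x~}]
print(lhs, rhs, np.mean(Wx))
```

Neither `x` nor `x_tilde` depends on a vector of the form $W^\top\bullet$, so the plain Gaussian rule applies and the two printed covariances agree up to $O(1/\sqrt n)$ fluctuations.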
$\textsc{Netsor}\top^{+}$ Programs. (Yang, 2019a, 2020a) showed that $\textsc{Netsor}\top$ suffices to express the forward and backward passes of most architectures such as ResNet (with Batchnorm), but Transformer and other standard architectures require adding to $\textsc{Netsor}\top$ a new “averaging” instruction [Footnote 16: For readers familiar with Tensor Programs, this is equivalent in expressivity to Moment, which is a composition of a Nonlin and this averaging instruction.] that returns the “empirical average” $\frac{1}{n}\sum_{\alpha=1}^{n}x_\alpha$ of a vector $x\in\mathbb{R}^n$. In the limit, this scalar converges to $\mathbb{E}\,Z^{x}$, as would be expected from the intuitions above. This extension of $\textsc{Netsor}\top$ (called $\textsc{Netsor}\top^{+}$) also allows us to express the network output and loss derivative (e.g. in contrast to Fig. 1), which will be a technical requirement for unrolling the entirety of SGD training inside a single program, a key step in the proof of our main result. See discussion of proof formalization in Section 5. We can say a $\textsc{Netsor}\top^{+}$ program represents a network $f$ if it computes the body of $f$.
Universality of Kernel Dynamics
(Yang, 2019a, 2020a) showed that any neural network of standard architecture is represented by a $\textsc{Netsor}\top^{+}$ program. Moreover,
For a neural network as in 5.2 below, its Neural Tangent Kernel at initialization has a well-defined infinite-width limit $\mathring{\mathcal{K}}$.
Suppose a neural network $f$ is represented by a $\textsc{Netsor}\top^{+}$ program (in the sense of 4.2) whose Nonlin all have polynomially bounded derivatives. [Footnote 17: More generally, we can allow any pseudo-Lipschitz function here, but for simplicity we go with the statement in the main text.] Adopt the NTK parametrization: for every matrix parameter $W\in\mathbb{R}^{n\times n}$ of $f$, we factor $W=\frac{1}{\sqrt n}w$ where $w$ is the trainable parameter; likewise, for each input layer matrix $U^i\in\mathbb{R}^{n\times d}$, we factor $U^i=\frac{1}{\sqrt d}u^i$, and likewise the output matrix $V=\frac{1}{\sqrt n}v$, such that $u^i,v$ are trainable. Finally, we randomly initialize all trainable parameters iid as $\mathcal{N}(0,1)$.
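To make the NTK parametrization concrete, here is a minimal sketch that computes the empirical NTK at initialization for a 2-hidden-layer tanh MLP with all of $u,w,v$ trainable. The function name `empirical_ntk`, the architecture, the inputs, and the widths are illustrative assumptions. Gradients are written out by hand, using the scaled gradients $dh=\sqrt n\,\partial f/\partial h$ and $dg=\sqrt n\,\partial f/\partial g$, so each parameter group contributes a product of normalized inner products.

```python
import numpy as np

def empirical_ntk(xi, xi_bar, n, seed):
    """<grad f(xi), grad f(xi_bar)> over (u, w, v) at random initialization."""
    rng = np.random.default_rng(seed)
    d = xi.size
    u = rng.standard_normal((n, d))   # trainable parameters, iid N(0, 1)
    w = rng.standard_normal((n, n))
    v = rng.standard_normal(n)
    U, W = u / np.sqrt(d), w / np.sqrt(n)   # NTK parametrization multipliers

    def passes(inp):
        g = U @ inp
        z = np.tanh(g)
        h = W @ z
        x = np.tanh(h)                             # f = v @ x / sqrt(n)
        dh = (1.0 - np.tanh(h) ** 2) * v           # dh = sqrt(n) * df/dh
        dg = (1.0 - np.tanh(g) ** 2) * (W.T @ dh)  # dg = sqrt(n) * df/dg
        return z, x, dh, dg

    z, x, dh, dg = passes(xi)
    zb, xb, dhb, dgb = passes(xi_bar)
    from_v = x @ xb / n                           # df/dv = x / sqrt(n)
    from_w = (dh @ dhb / n) * (z @ zb / n)        # df/dw = dh z^T / n
    from_u = (dg @ dgb / n) * (xi @ xi_bar / d)   # df/du = dg xi^T / sqrt(n d)
    return from_v + from_w + from_u

xi, xi_bar = np.array([0.8]), np.array([-0.5])
for n in (200, 2000):
    print(n, [round(empirical_ntk(xi, xi_bar, n, seed), 3) for seed in range(3)])
```

The spread across seeds should tighten as the width grows, which is the behavior that a well-defined infinite-width limit $\mathring{\mathcal{K}}$ predicts.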
Our main result is to show that the SGD training of such a neural network described in 5.2 reduces to kernel gradient descent with kernel $\mathring{\mathcal{K}}$ in the infinite-width limit.
Consider training a network described in 5.2 via SGD with batch-size 1 and (WLOG) learning rate 1. Let be the input and be the loss function (absorbing the label) at time . Suppose is continuous for all . Then, for any and , converges almost surely to a random variable as width , such that
$$\mathring{f}_{t+1}(\xi)-\mathring{f}_{t}(\xi)=-\mathring{\mathcal{K}}(\xi,\xi_{t})\,\mathcal{L}_{t}^{\prime}(\mathring{f}_{t}(\xi_{t})) \tag{22}$$
where $\mathring{\mathcal{K}}$ is the infinite-width NTK (at initialization) of the neural network.
The full proof of Theorem 5.3 is given in Appendix D.
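The recursion in Eq. 22 can be iterated directly in function space once a kernel is fixed. The sketch below is only an illustration of kernel gradient descent with batch size 1 and learning rate 1: it uses a hypothetical stand-in kernel $K(a,b)=ab+1$ on scalar inputs (not the NTK of any particular architecture), a squared loss $\frac12(f-y)^2$, and the simplifying assumption $\mathring f_0=0$ (in the theory the initial function is a random Gaussian process sample).

```python
import numpy as np

def kernel_sgd(kernel, eval_points, train_inputs, train_labels, num_steps):
    """Iterate f_{t+1}(xi) = f_t(xi) - K(xi, xi_t) * L_t'(f_t(xi_t)) on fixed points."""
    points = np.concatenate([eval_points, train_inputs])
    f = np.zeros(len(points))        # assumption: start from f_0 = 0
    offset = len(eval_points)
    for t in range(num_steps):
        i = t % len(train_inputs)                 # cycle through the data, batch size 1
        chi = f[offset + i] - train_labels[i]     # derivative of 0.5 * (f - y)^2
        f = f - kernel(points, train_inputs[i]) * chi
    return f[:offset], f[offset:]

def stand_in_kernel(a, b):
    # Hypothetical kernel for scalar inputs; any positive semidefinite kernel would do.
    return a * b + 1.0

train_inputs = np.array([-0.5, 0.0, 0.5])
train_labels = 2.0 * train_inputs + 0.3
eval_points = np.array([0.25])

f_eval, f_train = kernel_sgd(stand_in_kernel, eval_points, train_inputs, train_labels, 150)
print(f_train, train_labels, f_eval)
```

The same kernel value $K(\xi,\xi_t)$ moves both training and test outputs, so the test prediction is determined entirely by the kernel and the sequence of loss derivatives. With this stand-in kernel $K(\xi_t,\xi_t)<2$ on the chosen inputs, which keeps the unit learning rate stable.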
We briefly mention several ways our result can be easily extended. 0) Different batch sizes, learning rate schedules, and nonscalar outputs. 1) Variants of NTK parametrization. We can deal with any parametrization that scales the same way as NTK parametrization, e.g. weights are sampled like for any , with the multipliers for any . 2) Variable width. In real networks, the width of different layers can often be different (e.g. in ResNet). Our result can be extended to the case where the widths tend to infinity at a fixed ratio, using the variable-width version of Tensor Programs (Yang, 2020b). 3) Unsupervised and other learning settings can be covered because their training and testing computation can be written into Tensor Programs. 4) Weight decay, momentum, and other optimizer tricks can be covered as well as they can be straightforwardly written into Tensor Programs, but in general the kernel will change from step to step in contrast to 5.3.
5.1 Proof Sketch of Special Case
To convey the main idea, we give a proof sketch of a simplified problem: we assume 1) the input, output layers and biases are not trained (the network has only matrices as trainable parameters); 2) the forward pass does not contain both a weight matrix and its transpose (but a single matrix can still be used multiple times without being transposed) [Footnote 18: This is not common in deep learning, but appears in some weight-tied autoencoders (Li & Nguyen, 2019).]; 3) input space is (with ), and ; 4) the output vector is a G-var; 5) the network is represented by a $\textsc{Netsor}\top$ (instead of $\textsc{Netsor}\top^{+}$) program. In the appendix, we prove the general case with these simplifications lifted.
It turns out that every $\textsc{Netsor}\top$ program can be simplified into a standard form of sorts, which greatly facilitates our proof.
In a $\textsc{Netsor}\top$ program, a G-var [Footnote 19: “G” because G-vars often are roughly Gaussian vectors.] is an initial vector or a vector created by MatMul, while an X-var is a vector created by Nonlin. [Footnote 20: Var is short for variable, as the vectors are considered variables in the program. In previous works, H-var refers to any vector in the program; we will not use this terminology.] We define a reduced program as a program in which only G-vars are allowed as inputs to a Nonlin, while only an X-var is allowed as input to a MatMul.
Observe that any program may be trivially expressed as a reduced program by: 1) collapsing chains of nonlinearities which appear consecutively, and 2) inserting an identity Nonlin operation ($\psi(x)=x$) in between consecutive G-vars. Hence, we may safely assume that $f$ is representable by a reduced $\textsc{Netsor}\top$ program.
The examples of Sections 3.1 and 3.2 exposed several insights, such as the iid-coordinates intuition, important for proving 5.3. Now we discuss the one remaining key idea for scaling up to general architectures:
Definition (Paths). In a $\textsc{Netsor}\top$ program, a path $p$ starts with an X-var and ends with a G-var, alternating between X- and G-vars along the path. We write $p^0$ for the starting X-var, $p^1$ for the following G-var, and so on, as well as $p^{-1}$ for the ending G-var (see Fig. 2 for a graphical illustration). For odd $i$, let $W^{p^i}$ denote the defining matrix of G-var $p^i$. For two equal length paths $p,q$, we write $p\cong q$ (path $p$ is isomorphic to path $q$) if for all odd $i$, $W^{p^i}$ is the same matrix as $W^{q^i}$. [Footnote 21: Here we are talking about equality of symbols rather than equality of values of those symbols.] In other words, we say path $p$ is isomorphic to path $q$ if their sequences of MatMul matrices are identical (but the Nonlin don’t have to be; see Fig. 3 for a graphical illustration). Let $|p|$ denote the number of vectors in $p$ (this is always an even number).
The collection of paths starting with an X-var $x$ and ending with a G-var $h$ describes all possible pathways of backpropagating an error signal $dh$ at $h$ to an error signal $dx$ at $x$. Simultaneously, it also describes all possible pathways of forward propagating a change in $x$ to a change in $h$.
Because the gradient of a weight is the sum of outer products , summing over all G-vars and X-vars in the program with (where denotes ), we also have
$$(J^{p})^{\top}=\frac{\partial f}{\partial p^{-1}}\times\frac{\partial p^{-1}}{\partial p^{-2}}\times\frac{\partial p^{-2}}{\partial p^{-3}}\times\cdots\times\frac{\partial p^{2}}{\partial h} \tag{24}$$
i.e., denotes the error signal at from backpropagation through path , and ranges over all paths starting with and ending with the output node of the underlying program. Recall factors as where is the trainable parameter, not . By the discussion above, updating with causes to change by
When every parameter is randomly initialized iid as $\mathcal{N}(0,1)$, it turns out that $\langle J^{p},J^{\bar p}\rangle$ will go to 0 as $n\to\infty$ unless $p\cong\bar p$ (Fig. 3). If one thinks of $J^{p}$ as a product of random Gaussian matrices (interleaved with other matrices), then this is akin to the fact that, for a mixed moment $M\overset{\text{def}}{=}\mathbb{E}\prod_{i=1}^{r}Z_{\gamma(i)}$, $\gamma:[r]\to[k]$, of standard iid Gaussians $Z_1,\ldots,Z_k$, $M$ is nonzero iff every $Z_j$ appears an even number of times in the product. This means we can replace the 4 nested sums in Eq. 25 with the single sum $\sum_{p\cong\bar p}$ and rewrite .
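The Gaussian mixed-moment fact used above can be made explicit. The sketch below relies only on independence and the standard moments $\mathbb{E}Z^{c}=(c-1)!!$ for even $c$ and $0$ for odd $c$; the function names are illustrative. The index map $\gamma$ is passed as a list, so `[1, 1, 2, 2]` stands for $\mathbb{E}[Z_1Z_1Z_2Z_2]$.

```python
from collections import Counter
from math import prod

def double_factorial(m):
    # m!! = m * (m - 2) * ... * 1 for odd m >= 1; returns 1 for m <= 0.
    return prod(range(m, 0, -2)) if m > 0 else 1

def mixed_moment(gamma):
    """E[prod_i Z_{gamma[i]}] for independent standard Gaussians Z_1, ..., Z_k."""
    moment = 1
    for count in Counter(gamma).values():
        if count % 2 == 1:
            return 0  # an odd power of some Z_j makes the expectation vanish
        moment *= double_factorial(count - 1)  # E[Z^c] = (c - 1)!! for even c
    return moment

print(mixed_moment([1, 1, 2, 2]))   # every index appears twice
print(mixed_moment([1, 2, 2]))      # index 1 appears once
print(mixed_moment([1, 1, 1, 1]))   # fourth moment of a single Gaussian
```

The path-isomorphism condition plays the same role as the evenness condition here: a pairing of the Gaussian matrices along $p$ and $\bar p$ exists only when the two paths use the same sequence of matrices.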
We have suppressed dependence on input in Eq. 25. Being more explicit about it and performing updates on all weights, we have
$$f_{t+1}(\xi)-f_{t}(\xi)\approx-\mathcal{L}_{t}^{\prime}(f_{t}(\xi_{t}))\sum_{p\cong\bar{p}}\langle J^{p}(\xi),J^{\bar{p}}(\xi_{t})\rangle\,\langle p^{0}(\xi),\bar{p}^{0}(\xi_{t})\rangle \tag{26}$$
after taking a gradient step on input $\xi_t$ at initialization $t=0$. (Here $p^{0}(\xi)$ denotes the vector $p^{0}$ as a function of $\xi$ at initialization.) However, Eq. 25 holds for general $t$ as well: The key insight is similar to Eq. 4 in the 1-hidden-layer example, that vectors $p^{0}(\xi)$, $J^{p}(\xi)$, etc. change vanishingly from their initial values as $n\to\infty$, after any number of SGD steps. Because our arguments above only depend on inner products between vectors, this means that the error in Eq. 26 for $t>0$ vanishes as $n\to\infty$. Finally, at least heuristically, it is straightforward to show the RHS of Eq. 26 is the NTK via a series of calculations exemplified by those in Sections 3.1 and 3.2, to get Eq. 22.
While the core ideas discussed above are intuitive, making them rigorous at face value would be quite challenging. Instead we use the machinery offered by the Tensor Programs framework. The mechanics of the proof then go as follows: 1) First we unroll SGD of $f$ into a $\textsc{Netsor}\top^{+}$ program. [Footnote 22: We note that this formalization crucially relies on $\textsc{Netsor}\top^{+}$ and its Master Theorem from (Yang, 2020b) because the SGD unrolling cannot be done in $\textsc{Netsor}\top$. The reason is that we need to express the output and loss derivatives of the network, which are scalars (or at least finite dimensional), and that cannot be done in a $\textsc{Netsor}\top$ program. Furthermore, the Master Theorem from (Yang, 2020a) only pertains to a specific type of programs that look like the first backpropagation after initialization. Thus, it cannot deal with the complete unrolling of SGD as we do here, which requires the more advanced Master Theorem from (Yang, 2020b).] This is similar to the equations in Sections 3.1 and 3.2; the key here is to express $\delta x_{t+1}$ as a vector in the program, for any (pre-)activation $x$. 2) We apply the Master Theorem (Yang, 2020b) to this program. This yields the coordinate distribution of each vector. The core insights here are demonstrated by the calculation with the $Z$ random variables in Sections 3.1 and 3.2. 3) Finally, we need to show that (the rigorous version of) Eq. 26 indeed recovers the NTK and agrees with Eq. 22. This is done via an inductive (symbolic) computation, and the path concept in this section plays a key role here.
Related Works
The connection between kernel methods and neural networks has had a long history before its recent resurgence. The Gaussian Process (NNGP) view of wide neural networks, which characterizes the behaviour of training only the last layer of a wide neural network, has been studied in (Daniely et al., 2016; Hazan & Jaakkola, 2015; Roux & Bengio, 2007; Lee et al., 2018; Matthews et al., 2018; Hinton & Neal, 1995; Novak et al., 2019). Since the original NTK paper (Jacot et al., 2018), many works have informally derived the infinite-width NTK for various architectures such as CNNs (Arora et al., 2019), RNN (Alemohammad et al., 2020), attention (Hron et al., 2020), ensembles (Littwin et al., 2020b) and graph neural networks (Du et al., 2019), but none of them formally proved ntkInit or ntkTrain for those architectures. Finite width corrections to the NTK were derived for fully connected networks in (Hanin & Nica, 2019; Littwin et al., 2020a). The validity of the NTK theory was empirically studied in (Lee et al., 2019) for a variety of architectures.
The Tensor Program framework (Yang, 2019a, 2020a, 2020b) was introduced in an attempt to unify and generalize the NNGP/NTK theory to a broad range of architectures, eliminating the need to re-develop the theory for each new architecture. For example, (Yang, 2019a) proved the architectural universality of the NNGP correspondence, while (Yang, 2020a) proved that of ntkInit. On the other hand, (Yang, 2020b) developed the most general machinery for Tensor Programs and, as a corollary, constructed a comprehensive nonlinear random matrix theory that can, for example, calculate the singular value distribution of a wide neural network of any architecture. Our proofs depend crucially on the machinery of (Yang, 2020b), as discussed in Section 5.
Conclusion
New theories of deep learning almost always start with MLPs, rightly so as they are the simplest case and can often reveal the key insights more clearly. Of course, as deep learning itself is an applied field, one should always ask whether insights on MLPs extend to more general architectures, i.e. whether there is an architecturally universal extension of a proposed theory of MLPs. This is not always easy to answer.
In this paper, we showed that the NTK theory is architecturally universal, but more importantly, we showed that the Tensor Programs technique is a very powerful tool for answering the above question as a matter of routine. Looking forward, we hope to apply it to generate more novel and general insights.
References
The appendix is organized as follows: In Appendix A we expand upon the examples given in Section 3, adding some further details. In Appendix B we introduce the formal version of the $\textsc{Netsor}\top$ and $\textsc{Netsor}\top^{+}$ programs. In Appendix C we introduce the graphical notation of $\textsc{Netsor}\top^{+}$ and demonstrate other examples of architectures or computations expressible in Tensor Programs. In Appendix D we prove our main result.
For the reader's convenience we restate the notations described in Section 2, along with some additional ones which will be used throughout the appendix. We will consider SGD with batch size 1 and learning rate of 1 (WLOG). We use $\xi_t$ to denote the input and $\mathcal{L}_t$ to denote the loss function (absorbing the label) at step $t$. More generally, subscript $t$ on any symbol means time $t$. However, for brevity, we abuse notation and shorthand $f_t$ for $f_t(\xi_t)$, and, for any (pre-)activation $x$, $x_t$ for $x_t(\xi_t)$. We will also write $\chi_t$ for the loss derivative $\mathcal{L}_t'(f_t)$. For any vector $x(\xi)$ we define $\delta x_{t+1}(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_{t}(\xi)\big)$ and $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial x(\xi)}$. We will track the evolution of $f$ on an arbitrary input $\tilde{\xi}$. [Footnote 23: It might help to think of $\tilde{\xi}$ as some test sample, but it can also fall in the training set.] Similar to above, we shorthand $\tilde{x}_t, \tilde{f}_t$ for $x_t(\tilde{\xi}), f_t(\tilde{\xi})$. In general, omitting the time index for any time dependent quantity implies its value at initialization (i.e. $x = x_0$). Finally, we use $\cong$ to imply equality of symbols (i.e. $W \cong W'$ iff $W$ and $W'$ represent the same variable, as opposed to equality in value).
Appendix A Additional Examples
In this section we flesh out the examples given in Section 3 of the main text to add clarity, while keeping the intuitive arguments presented in each example to perform these calculations. The rigorous justification for these calculations will be given in the following section with the formal introduction of the Tensor Program framework.
Recall that our objective is to derive Section 3 by tracking the coordinate distribution of each (pre-)activation vector $x(\xi)$, $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial x(\xi)}$, $\delta x_{t+1}(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_{t}(\xi)\big)$.
In our calculations we will rely on the following rules relating to the coordinates of any (pre-)activation vector in the large width regime, which we will later formalize:
has coordinates.
has coordinates.
. Consequently .
If x(\xi)=\phi\big{(}y(\xi)\big{)} for some vector , then by Taylor approximation \delta x_{t+1}(\xi)=\sqrt{n}\big{(}\phi(y_{t}(\xi)+\frac{\delta y_{t+1}(\xi)}{\sqrt{n}})-\phi(y_{t}(\xi))\big{)}\approx\phi^{\prime}\big{(}y_{t}(\xi)\big{)}\odot\delta y_{t+1}(\xi). Consequently .
We write to denote the limit coordinate distribution of conditioned on the output function at initialization. Consequently we write to express a conditional expectation given the output function . See Appendix B for the formal statement.
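The Taylor-approximation rule above can be checked numerically. The following is a minimal sketch, assuming $\phi=\tanh$ and hypothetical $\Theta(1)$ Gaussian coordinates for $y_t$ and $\delta y_{t+1}$ (neither choice comes from the text); it measures how far $\sqrt{n}\big(\phi(y+\delta y/\sqrt{n})-\phi(y)\big)$ is from $\phi'(y)\odot\delta y$ as the width $n$ grows.

```python
import numpy as np

def taylor_gap(n, seed=0):
    """Largest coordinate gap between the exact scaled update of x = phi(y)
    and its first-order approximation phi'(y) * dy, for phi = tanh (assumed)."""
    rng = np.random.default_rng(seed)
    y = rng.standard_normal(1000)    # hypothetical Theta(1) coordinates of y_t
    dy = rng.standard_normal(1000)   # hypothetical Theta(1) coordinates of delta y_{t+1}
    exact = np.sqrt(n) * (np.tanh(y + dy / np.sqrt(n)) - np.tanh(y))
    linear = (1.0 - np.tanh(y) ** 2) * dy   # phi'(y) * delta y
    return float(np.max(np.abs(exact - linear)))
```

Calling `taylor_gap(10**2)` and `taylor_gap(10**6)` shows the gap shrinking with width, which is the sense in which the approximation becomes exact in the limit.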
where , for trainable parameter tensor , initialized iid from , and .
The infinite width NTK of this architecture is given by:
$$dh(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial h(\xi)}=\phi'(h(\xi))\odot v. \tag{28}$$
Therefore, by LLN it follows: [Footnote 24: While in A.1, we said denotes expectation conditioned on , the NTK here does not actually depend on .]
To show that Section 3 holds with the kernel in Eq. 29, we track the coordinate distribution at each step of SGD. At step $t$, the update to the weights $u$ is given by the gradient of the loss with respect to $u$:
$$u_{t+1}-u_{t}=-\chi_{t}\frac{dh_{t}\xi_{t}^{\top}}{\sqrt{n}},\qquad dh_{t}=\phi'(h_{t})\odot v. \tag{30}$$
Recall that and . It therefore follows:
$$\delta\tilde{h}_{t+1}=-\chi_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi'(h_{t})\odot v,\qquad \delta\tilde{x}_{t+1}=\sqrt{n}\Big(\phi\big(\tilde{h}_{t}+\tfrac{\delta\tilde{h}_{t+1}}{\sqrt{n}}\big)-\phi(\tilde{h}_{t})\Big). \tag{31}$$
Since and , for large $n$ we may Taylor expand $\phi$ to first order around $\tilde{h}_{t}$:
$$\begin{aligned}
\delta\tilde{x}_{t+1} &\approx\sqrt{n}\Big(\phi(\tilde{h}_{t})+\tfrac{1}{\sqrt{n}}\phi'(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1}-\phi(\tilde{h}_{t})\Big) && (32)\\
&=\phi'(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1} && (33)\\
&=-\chi_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi'(\tilde{h}_{t})\odot\phi'(h_{t})\odot v. && (34)
\end{aligned}$$
Again since , it follows that . Hence, in the infinite width limit the coordinate distribution of is identical to the coordinate distribution of (i.e ). Using Eq. 34, the coordinate distribution of $\delta\tilde{x}_{t+1}$ is given by:
$$Z^{\delta\tilde{x}_{t+1}}=-\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi'(Z^{\tilde{h}_{t}})\phi'(Z^{h_{t}})Z^{v}. \tag{35}$$
In the large width limit, the change in the output is simply given by . Using Eq. 35 and the independence of from the other random variables, [Footnote 25: Again, as in Footnote 10, the expectations in Sections A.1 and A.2 should more rigorously be interpreted as expectation conditional on the function values of .]
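The one-hidden-layer calculation can be illustrated with a small simulation. This is a sketch under stated assumptions rather than the paper's construction: it takes $h=u\xi$, $x=\phi(h)$, $f=v^{\top}x/\sqrt{n}$ with $\phi=\tanh$, $u$ and $v$ iid standard Gaussian, $v$ frozen, and treats the loss derivative `chi` as a given scalar. It applies the update of Eq. 30 once and compares the actual change of $f(\tilde{\xi})$ with $-\chi_t$ times the empirical finite-width kernel.

```python
import numpy as np

def one_step_change(n, xi_t, xi_tilde, chi=1.0, seed=0):
    """Return (actual, predicted) change of f(xi_tilde) after one SGD step on xi_t."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((n, xi_t.shape[0]))  # trainable first-layer weights (assumed N(0,1))
    v = rng.standard_normal(n)                   # frozen output weights (assumed N(0,1))
    dphi = lambda z: 1.0 - np.tanh(z) ** 2

    def f(u_mat, xi):
        return v @ np.tanh(u_mat @ xi) / np.sqrt(n)

    h_t, h_tilde = u @ xi_t, u @ xi_tilde
    dh_t = dphi(h_t) * v                                   # Eq. 28: dh = phi'(h) * v
    u_new = u - chi * np.outer(dh_t, xi_t) / np.sqrt(n)    # Eq. 30 with learning rate 1
    actual = f(u_new, xi_tilde) - f(u, xi_tilde)
    # Empirical kernel: xi_t . xi_tilde times the average of phi'(h) phi'(h_tilde) v^2.
    kernel = (xi_t @ xi_tilde) * np.mean(dphi(h_tilde) * dphi(h_t) * v ** 2)
    return float(actual), float(-chi * kernel)
```

For example, `one_step_change(20000, np.array([1.0, 0.5]), np.array([0.8, -0.2]))` returns two numbers that nearly coincide at this width, consistent with the first-order picture behind Eq. 35.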
A.2 2 hidden layers
where , for trainable parameters , initialized iid from a normal distribution. As before we assume the last layer is not trained, and .
$$=\langle\nabla_{u}f(\xi),\nabla_{u}f(\tilde{\xi})\rangle+\langle\nabla_{w}f(\xi),\nabla_{w}f(\tilde{\xi})\rangle \tag{39}$$
(40)
$$dh(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial h(\xi)}=\phi'\big(h(\xi)\big)\odot v \tag{41}$$
$$dg(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial g(\xi)}=\phi'\big(g(\xi)\big)\odot\big(W^{\top}dh(\xi)\big). \tag{42}$$
Naively using LLN on Eq. 39 (and being independent from everything else) should result in:
$$\mathcal{K}(\xi,\tilde{\xi})=\xi^{\top}\tilde{\xi}\,\mathbb{E}\big[Z^{dg(\xi)}Z^{dg(\tilde{\xi})}\big]+\mathbb{E}\big[Z^{z(\xi)}Z^{z(\tilde{\xi})}\big]\mathbb{E}\big[\phi'(Z^{h(\xi)})\phi'(Z^{h(\tilde{\xi})})\big]. \tag{43}$$
Evaluating the term $\mathbb{E}\big[Z^{dg(\xi)}Z^{dg(\tilde{\xi})}\big]$, however, presents a challenge since depends on both and . As it turns out, at initialization we may naively assume that are independent (formally known in the literature as the gradient independence assumption, or GIA) [Footnote 26: For a rigorous justification of the GIA see (Yang, 2020a).], and we arrive, using simple LLN arguments, at:
Plugging Eq. 44 into Eq. 43 we arrive at the correct expression for the infinite width NTK.
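The backward quantities in Eqs. 41 and 42 can be sanity-checked against finite differences. The sketch below assumes the two-hidden-layer network $f=v^{\top}\phi(W\phi(g))/\sqrt{n}$ with $W=w/\sqrt{n}$ and $\phi=\tanh$; the function names are hypothetical, and the architecture is inferred from Eqs. 41, 42 and 45 rather than quoted from the text.

```python
import numpy as np

def scaled_gradients(g, W, v):
    """Return (dh, dg) as in Eqs. 41-42 for f = v^T tanh(W tanh(g)) / sqrt(n)."""
    h = W @ np.tanh(g)
    dh = (1.0 - np.tanh(h) ** 2) * v             # dh = phi'(h) * v
    dg = (1.0 - np.tanh(g) ** 2) * (W.T @ dh)    # dg = phi'(g) * (W^T dh)
    return dh, dg

def finite_difference_dg(g, W, v, eps=1e-5):
    """Central-difference estimate of sqrt(n) * df/dg for the same network."""
    n = g.shape[0]
    f = lambda g_vec: v @ np.tanh(W @ np.tanh(g_vec)) / np.sqrt(n)
    grad = np.zeros(n)
    for a in range(n):
        e = np.zeros(n)
        e[a] = eps
        grad[a] = (f(g + e) - f(g - e)) / (2 * eps)
    return np.sqrt(n) * grad
```

The two routines agree up to discretization error for any fixed $g$, $W$, $v$; the agreement is a statement about the chain rule only and does not involve the GIA.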
To show that Section 3 holds at any step (where we may not assume that GIA holds), we track the distributions of the vectors throughout training.
At any step the weights are updated according to:
$$u_{t+1}-u_{t}=-\chi_{t}\frac{dg_{t}\xi_{t}^{\top}}{\sqrt{n}},\qquad w_{t+1}-w_{t}=-\chi_{t}\frac{dh_{t}z_{t}^{\top}}{n}. \tag{45}$$
The updates $\delta\tilde{g}_{t+1}\overset{\text{def}}{=}\sqrt{n}(\tilde{g}_{t+1}-\tilde{g}_{t})$ and $\delta\tilde{z}_{t+1}\overset{\text{def}}{=}\sqrt{n}(\tilde{z}_{t+1}-\tilde{z}_{t})$ are given by:
As before, with large we have that and coordinates for replaced by . And so after Taylor expanding around :
$$\delta\tilde{z}_{t+1}\approx\phi'(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}. \tag{47}$$
In a similar fashion, using Eqs. 41 and 46 the updates take the form:
$$\delta\tilde{z}_{t+1}\approx\phi'(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}=-\chi_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi'(g_{t})\odot\phi'(\tilde{g}_{t})\odot\big(W^{\top}dh_{t}\big) \tag{48}$$
(49) (50)
where we used Eq. 45 and
$\delta\tilde{h}_{t+1}$ (51) (52)
to get Eq. 49. Based on Eqs. 41, 48, 49 and 50, the corresponding coordinate distributions take the form:
$Z^{\delta\tilde{z}_{t+1}}$ (54) (55) (56)
As before, the functional update is given by . Plugging Eqs. 56 and 55:
$$\mathbb{E}Z^{v}Z^{\delta\tilde{x}_{t+1}}=-\mathring{\chi}_{t}\mathbb{E}\big[Z^{z_{t}}Z^{\tilde{z}}\big]\mathbb{E}\big[\phi'(Z^{h_{t}})\phi'(Z^{\tilde{h}_{t}})\big]-\mathbb{E}\big[\phi'(Z^{h_{t}})Z^{W\delta\tilde{z}_{t+1}}Z^{v}\big]. \tag{57}$$
To compute the second term of the RHS of Eq. 57, we use Section 3.2.
Applying Section 3.2 to get the expression for :
$$\begin{aligned}
Z^{W\delta\tilde{z}_{t+1}} & && (58)\\
&=G-\phi'(Z^{h_{t}})Z^{v}\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\big[\phi'(Z^{g_{t}})\phi'(Z^{\tilde{g}_{t}})\big] && (59)
\end{aligned}$$
As before, for it holds that , and . Plugging Eq. 58 into Eq. 57 yields Section 3.
Appendix B Tensor Programs: the Formal Version
We briefly review the formal definition of Tensor Programs below, but readers needing more explanation and intuition should see (Yang, 2020b). We will directly describe $\textsc{Netsor}\top^{+}$ programs, which generalize $\textsc{Netsor}\top$.
A program is a sequence of -vectors and -scalars inductively generated via one of the following ways from an initial set of random scalars, of random vectors, and a set of random matrices (which will be sampled with iid Gaussian entries in B.2)
Given , previous scalars and vectors , we can generate a new vector
where applies coordinatewise to each “-slice” .
Given same setup as above, we can also generate a new scalar
A $\textsc{Netsor}\top$ program is just a $\textsc{Netsor}\top^{+}$ program without scalars, without the usage of Moment, and without parameters in Nonlin+.
We will typically randomly sample the initial matrices, vectors, and scalars of the program as follows.
1) For each initial , we sample iid for some variance associated to , independent of other ; 2) for some multivariate Gaussian , we sample the initial set of vectors like iid for each . 3) For each initial scalar , we require for some deterministic .
The following constructs a random variable for every vector and a deterministic scalar for every scalar in the program. The interpretation is that will have iid coordinates distributed like , and will converge to as .
Given a program, we recursively define for each vector and for each scalar as follows.
If , then is defined as in B.2. We also set $\hat{Z}^{h}\overset{\text{def}}{=}Z^{h}$ and $\dot{Z}^{h}\overset{\text{def}}{=}0$.
Given , previous scalars and vectors , we have
Given same setup as above and scalar , then
Here are deterministic, so the expectation is taken over .
$Z^{Wx}\overset{\text{def}}{=}\hat{Z}^{Wx}+\dot{Z}^{Wx}$ for every matrix $W$ (with $\mathcal{N}(0,\sigma_{W}^{2}/n)$ entries) and vector $x$, where
is a Gaussian variable with zero mean. Let denote the set of all vectors in the program of the form for some . Then is defined to be jointly Gaussian with zero mean and covariance
Furthermore, is mutually independent from , where ranges over .
We can always unwind , for some arguments , (where is defined in ZHat), and deterministic function . Define $\partial Z^{x}/\partial\hat{Z}^{W^{\top}y^{i}}\overset{\text{def}}{=}\partial_{i}\Phi(\cdots)$. Then we set
$$\dot{Z}^{Wx}\overset{\text{def}}{=}\sigma_{W}^{2}\sum_{i=1}^{k}Z^{y^{i}}\,\mathbb{E}\frac{\partial Z^{x}}{\partial\hat{Z}^{W^{\top}y^{i}}}. \tag{65}$$
There is some nuance in this definition, so see B.5 and B.6.
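The role of the $\dot{Z}$ correction can be seen in a small experiment. The sketch below is illustrative only: it assumes a single Gaussian initial vector $y$, $x=\tanh(W^{\top}y)$, and $W$ with iid $\mathcal{N}(0,1/n)$ entries (so $\sigma_W=1$). Under Eq. 65, $Z^{Wx}=\hat{Z}^{Wx}+Z^{y}\,\mathbb{E}\,\phi'(Z^{W^{\top}y})$ with $\hat{Z}^{Wx}$ independent of $Z^{y}$, so the empirical correlation $\frac{1}{n}\langle Wx, y\rangle$ should approach $\mathbb{E}[(Z^{y})^{2}]\,\mathbb{E}[\phi'(Z^{W^{\top}y})]$ instead of zero.

```python
import numpy as np

def zdot_check(n, seed=0):
    """Compare (1/n) <W x, y> with the ZDot prediction when x = tanh(W^T y)."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((n, n)) / np.sqrt(n)  # entries N(0, 1/n), i.e. sigma_W = 1 (assumed)
    y = rng.standard_normal(n)                    # initial vector, independent of W
    x = np.tanh(W.T @ y)                          # x depends on W through W^T
    lhs = float(y @ (W @ x) / n)                  # empirical E[Z^{Wx} Z^y]
    rhs = float(np.mean(1.0 - x ** 2) * (y @ y) / n)  # E[(Z^y)^2] * E[phi'(Z^{W^T y})]
    return lhs, rhs
```

A vector $x$ that did not depend on $W^{\top}$ would give a correlation near zero here; the nonzero value is exactly what the $\dot{Z}$ term accounts for.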
The following theorem ties the symbolic nature of the s to the analytic nature of a Tensor Program.
($\textsc{Netsor}\top^{+}$ Master Theorem, cf. Theorem E.15 of (Yang, 2020b)). Fix a Tensor Program initialized according to B.2. Adopt B.8. Then
For any fixed and any pseudo-Lipschitz , as ,
for any vectors in the program, where are as defined in B.3.
Any scalar in the program tends to almost surely, where is as defined in B.3.
The partial derivative in ZDot should be interpreted as follows. By a simple inductive argument, for every vector in the program is defined uniquely as a deterministic function of some in or introduced by MatMul (notationally, we are suppressing the possible dependence on limit scalars ). For instance, if in a program we have , , then , so is given by . Then
$$\partial Z^{x}/\partial\hat{Z}^{x^{i}}\overset{\text{def}}{=}\partial_{i}\varphi(\hat{Z}^{x^{1}},\ldots,\hat{Z}^{x^{k}}),\quad\text{and}\quad\partial Z^{x}/\partial\hat{Z}^{z}\overset{\text{def}}{=}0\text{ for any }z\not\in\{x^{1},\ldots,x^{k}\}.$$
Note this definition depends on the precise way the program is written, not just on the underlying mathematics. For example, if and , then so that . If instead, we have , then so that . However, in both cases, .
The quantity is well defined if is differentiable in . However, even if this is not the case, e.g. if where is the Heaviside step function, we can still define this expectation by leveraging Stein’s lemma:
In ZDot, suppose $y^{1},\ldots,y^{k}$ are all elements of introduced before $x$. Define the matrix $C$ by $C_{ij}\overset{\text{def}}{=}\mathbb{E}Z^{y^{i}}Z^{y^{j}}$ and define the vector $b$ by $b_{i}\overset{\text{def}}{=}\mathbb{E}\hat{Z}^{W^{\top}y^{i}}Z^{x}$. If $a=C^{+}b$ (where $C^{+}$ denotes the pseudoinverse of $C$), then in ZDot we may set
$$\sigma_{W}^{2}\,\mathbb{E}\frac{\partial Z^{x}}{\partial\hat{Z}^{W^{\top}y^{i}}}=a_{i}. \tag{67}$$
This definition agrees with the partial derivative expectation by Stein’s lemma when the latter is well defined. B.4 holds with this broader definition of partial derivative expectation.
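As a concrete illustration of the Stein's-lemma definition, the sketch below estimates $a=C^{+}b$ by Monte Carlo for a non-differentiable nonlinearity. The setup is hypothetical: $\sigma_W=1$, the Gaussians $\hat{Z}^{W^{\top}y^{i}}$ are sampled directly with covariance $C$, and $Z^{x}$ is the Heaviside step of their sum. In this case the expected "derivative" with respect to each argument equals the density of the sum at zero, $1/\sqrt{2\pi s^{2}}$ with $s^{2}=\sum_{ij}C_{ij}$, which gives a closed form to compare against.

```python
import numpy as np

def stein_coefficients(C, num_samples=400_000, seed=0):
    """Monte Carlo estimate of a = C^+ b with b_i = E[zeta_i * step(sum_j zeta_j)]."""
    rng = np.random.default_rng(seed)
    C = np.asarray(C, dtype=float)
    # zeta plays the role of the jointly Gaussian variables with covariance C (sigma_W = 1 assumed).
    zeta = rng.multivariate_normal(np.zeros(C.shape[0]), C, size=num_samples)
    x = (zeta.sum(axis=1) > 0).astype(float)   # Z^x = Heaviside step of the sum
    b = zeta.T @ x / num_samples               # b_i = E[zeta_i Z^x]
    return np.linalg.pinv(C) @ b               # a = C^+ b
```

With `C = [[1.0, 0.3], [0.3, 2.0]]`, both entries of the result should be close to `1 / sqrt(2 * pi * 3.6)`, even though the step function has no classical derivative.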
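Pseudo-Lipschitz functions are, roughly speaking, functions whose weak derivatives are polynomially bounded.
A function $f:\mathbb{R}^{k}\to\mathbb{R}$ is called pseudo-Lipschitz of degree $d$ if $|f(x)-f(y)|\leq C\|x-y\|\big(1+\sum_{i=1}^{k}|x_{i}|^{d}+|y_{i}|^{d}\big)$ for some $C$. We say $f$ is pseudo-Lipschitz if it is so for any degree.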
Here are some basic properties of pseudo-Lipschitz functions:
The norm in B.7 can be any norm equivalent to the norm, e.g. norms. Similarly, can be replaced by , for any .
A pseudo-Lipschitz function is polynomially bounded.
A composition of pseudo-Lipschitz functions of degrees and is pseudo-Lipschitz of degree .
A pseudo-Lipschitz function is Lipschitz on any compact set.
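To make the pseudo-Lipschitz condition concrete, the following sketch estimates the smallest admissible constant for a scalar function by sampling. It assumes the standard form of the definition from the Tensor Programs papers, $|f(x)-f(y)|\le C|x-y|(1+|x|^{d}+|y|^{d})$ in one dimension; the helper name and sampling scheme are hypothetical.

```python
import numpy as np

def pseudo_lipschitz_ratio(f, d, num_pairs=100_000, scale=10.0, seed=0):
    """Largest sampled value of |f(x)-f(y)| / (|x-y| (1 + |x|^d + |y|^d)) for scalar f."""
    rng = np.random.default_rng(seed)
    x = scale * rng.standard_normal(num_pairs)
    y = scale * rng.standard_normal(num_pairs)
    bound = np.abs(x - y) * (1.0 + np.abs(x) ** d + np.abs(y) ** d)
    return float(np.max(np.abs(f(x) - f(y)) / bound))
```

For $f(x)=x^{2}$ the ratio with $d=1$ stays below 1, since $|x^{2}-y^{2}|=|x-y||x+y|$, while with $d=0$ (the ordinary Lipschitz condition) it grows with the sampling range, reflecting that $x^{2}$ is pseudo-Lipschitz but not Lipschitz.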
We adopt the following assumption for the Master Theorem B.4.
If a function with only parameter arguments is used in Moment, then is continuous in those arguments.
Any other function with parameters (where ) used in Nonlin or Moment is pseudo-Lipschitz in all of its arguments (both inputs and parameters).
Statement 1 in B.8 essentially says that if we have scalars in the program, then we can produce a new scalar by applying a continuous function (a weaker restriction than a pseudo-Lipschitz function) to them. Indeed, if converge almost surely, then this new scalar does too. In our setting, statement 1 is used to allow any loss function whose derivative is continuous.
Other versions of the Master Theorem can be found in (Yang, 2020b), for example, versions where we do not assume any smoothness condition at all on the nonlinearities beyond that they be polynomially bounded, in exchange for assuming what’s called a rank stability condition. This rank stability should be generically true, but checking it rigorously is subtle, so we are content with the pseudo-Lipschitz condition in this paper.
Appendix C More Diagrams
We can augment the graphical form of $\textsc{Netsor}\top$ to accommodate the Moment instruction in $\textsc{Netsor}\top^{+}$. See Fig. 4 for an example for layernorm and attention. In short, we denote scalar variables with a square, in contrast to the circle for vector variables, and we use a “bar-gate” to denote the Moment, where the function in the gate corresponds to in Moment.
In addition, for more examples of the expressivity of , Figs. 5 and 6 demonstrate convolution and MLP backpropagation in .
Appendix D Proof of Main Result
We dedicate the following section to proving 5.3. We will begin by proving a simplified version under the same assumptions as Section 5.1, as reproduced below:
Suppose a neural network is represented by a program (in the sense of 4.2) whose Nonlin all have polynomially bounded derivatives. [Footnote 27: More generally, we can allow any pseudo-Lipschitz function here, but for simplicity we go with the statement in the main text.] Adopt the NTK parametrization: for every matrix parameter of , we factor where is the trainable parameter; likewise, for each input layer matrix , we factor , and likewise the output matrix . We randomly initialize all trainable parameters iid as . Furthermore, we assume the following:
Input and output layers , as well as biases are not trained (only weight matrices are trained).
The forward pass does not use both a matrix and its transpose (in different MatMuls).
We assume the last layer embedding is a G-var.
Our main result is to show that the SGD training of such a neural network described in 5.2 reduces to kernel gradient descent with kernel in the infinite-width limit.
Consider training a network described in D.1 via SGD with batch-size 1 and (WLOG) learning rate 1. Let be the input and be the loss function (absorbing the label) at time . Suppose is continuous for all . Then, for any and , converges almost surely to a random variable as width , such that
$$\mathring{f}_{t+1}(\xi)-\mathring{f}_{t}(\xi)=-\mathring{\mathcal{K}}(\xi,\xi_{t})\,\mathcal{L}_{t}'(\mathring{f}_{t}(\xi_{t})) \tag{68}$$
where $\mathring{\mathcal{K}}$ is the infinite-width NTK (at initialization) of the neural network.
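The limiting dynamics of Eq. 68 are simple to iterate once a kernel is given. The sketch below is an illustration under assumptions that are not in the text: a generic positive definite kernel matrix `K` over a finite set of tracked inputs stands in for the infinite-width NTK, and the loss is the squared loss $\mathcal{L}_t(f)=\frac{1}{2}(f-y_t)^2$, so that $\mathcal{L}_t'(f)=f-y_t$.

```python
import numpy as np

def kernel_sgd(K, f0, y, schedule):
    """Iterate Eq. 68 on a finite set of inputs with batch size 1 and learning rate 1.

    K        : kernel matrix over the tracked inputs (stand-in for the limiting NTK)
    f0       : function values at initialization
    y        : labels, used through the squared-loss derivative f - y (assumed loss)
    schedule : sequence of indices, schedule[t] is the training input used at step t
    """
    f = np.asarray(f0, dtype=float).copy()
    for i in schedule:
        chi = f[i] - y[i]        # loss derivative at the current training input
        f = f - K[:, i] * chi    # Eq. 68 applied to every tracked input at once
    return f
```

With, say, a Gaussian kernel on a few well-separated points and a cyclic schedule, the values at the training inputs converge to the labels, which is the kernel gradient descent behaviour the theorem refers to.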
\textsc{Netsor}\top^{+} Program SGD is comprised of a sequence of forward and backward passes computed on some architecture. WLOG, let denote the reduced program implementing the body of network , and let denote the final embedding such that , we will now show how the SGD procedure on can be implemented by a program.
While implements the embeddings by definition, the outputs cannot be implemented trivially in a program, since at initialization is not deterministic and converges non-trivially to a GP, violating the requirement that every scalar type in a program converge to a deterministic limit as . Nevertheless, we can still easily express the evolution of conditioned on (i.e. fixing) the values of at initialization. More formally, let denote a fixed vector of outputs, and let denote a fixed embedding matrix such that . The distribution of when conditioned on and is given by (see e.g. (Yang, 2020b, Sec K.2))
$$v\overset{\mathrm{d}}{=}_{\mathtt{f},X}\sqrt{n}X^{+}\mathtt{f}+\Pi\mathtt{v} \tag{69}$$
where $X^{+}$ is the pseudo-inverse of , $\mathtt{v}$ is an independent copy of $v$ and $\Pi$ is the projection operator projecting onto the orthogonal complement of the space spanned by . Namely:
$$X^{+}=\frac{1}{n}X\Big(\frac{X^{\top}X}{n}\Big)^{+},\quad\Pi=I-X^{+}X^{\top}. \tag{70}$$
Denote $\Sigma=X^{\top}X/n$ and $\mu=X^{\top}\mathtt{v}/n$. Define
$$\hat{v}\overset{\text{def}}{=}X\Big(\frac{\Sigma^{+}\mathtt{f}}{\sqrt{n}}\Big)+\mathtt{v}-X\Sigma^{+}\mu. \tag{71}$$
Then we see via Eq. 69 that
Given and (the columns of) as vectors and as scalars in a program, may be defined in the same program via Nonlin, where and (both finite-dimensional) provide coefficients for the linear combination over (columns of) . Formally, to express the evolution of conditioned on at initialization, the program will calculate the first forward pass up to , calculate the loss derivatives assuming , and then proceed with the backward pass and later forward/backward passes with replaced by .
However, since and (by rank stability, c.f. (Yang, 2020b, Lemma L.11)), these coefficients of the linear combination converge to 0, so that . Intuitively, this means that the distribution of conditioned on the equality is asymptotically the same as no conditioning as . Thus, for the limit calculation of and other quantities, it ends up not mattering whether we use or .
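The conditioning construction in Eqs. 69 to 71 can be verified directly in finite dimensions. The sketch below assumes $X$ is an $n\times k$ matrix of final embeddings with full column rank and that the outputs are $X^{\top}v/\sqrt{n}$; the function name is hypothetical. It builds the conditioned vector in the form of Eq. 71, with $\Sigma=X^{\top}X/n$ and $\mu=X^{\top}\mathtt{v}/n$ as implied by Eqs. 69 and 70.

```python
import numpy as np

def conditioned_output_weights(X, f_target, v_fresh):
    """Build v_hat = X Sigma^+ f / sqrt(n) + v_fresh - X Sigma^+ mu (Eq. 71)."""
    n = X.shape[0]
    sigma_pinv = np.linalg.pinv(X.T @ X / n)   # Sigma^+ with Sigma = X^T X / n
    mu = X.T @ v_fresh / n                     # mu = X^T v_fresh / n
    return X @ (sigma_pinv @ f_target) / np.sqrt(n) + v_fresh - X @ (sigma_pinv @ mu)
```

By construction `X.T @ v_hat / sqrt(n)` reproduces `f_target` exactly, while the two correction terms have coefficients that shrink with `n`, matching the remark that conditioning does not matter in the limit.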
The loss derivative after the first forward pass given can be implemented with Moment instructions using .
D.1.2 Implementing SGD
Under SGD, the update at step to any weight is given by:
$$w_{t+1}-w_{t}=-\chi_{t}\frac{1}{n}\sum_{g,h:\,g=Wh}dg_{t}\,h_{t}^{\top} \tag{73}$$
where the summation in Eq. 73 is over all pairs of vectors in program satisfying (there can be multiple such pairs since may reuse the same matrix ).
To write the full unrolled SGD as a program, we will need to implement the error signal $dg_{t}\overset{\text{def}}{=}\sqrt{n}\frac{\partial f_{t}}{\partial g_{t}}$ for each G-var at time . To accomplish this, we recall the notion of paths in program . Note that a path represents a series of nodes independent of an input, and can be instantiated as by an input , resulting in a series of instantiated G-vars and X-vars .
For any G-var , we can write the error term as the summation of error signals over paths :
$$J^{p}=\Big(\frac{\partial p^{2}}{\partial p^{1}}\Big)^{\top}\Big(\frac{\partial p^{-2}}{\partial p^{-3}}\Big)^{\top}\cdots\Big(\frac{\partial p^{-1}}{\partial p^{-2}}\Big)^{\top}v \tag{75}$$
(Here again, $J^{p}$ represents a symbolic computation that can be instantiated with an input .) Note that $J^{p}$ can be defined recursively:
$$J^{p}=\Big(\frac{\partial p^{2}}{\partial p^{1}}\Big)^{\top}\Big(\frac{\partial p^{3}}{\partial p^{2}}\Big)^{\top}J^{p:3} \tag{76}$$
where $J^{p:k}$ is defined as:
$$J^{p:k}\overset{\text{def}}{=}\begin{cases}\big(\frac{\partial p^{k+1}}{\partial p^{k}}\big)^{\top}\big(\frac{\partial p^{-2}}{\partial p^{-3}}\big)^{\top}\cdots\big(\frac{\partial p^{-1}}{\partial p^{-2}}\big)^{\top}v&k<|p|\\ v&k=|p|\end{cases} \tag{77}$$
Recall that each path starts with an X-var , and alternates between G and X vars. Let denote the defining weight matrix of G-var (i.e ), and let . Then we can re-write Eq. 76 as:
Note that Eq. 78 can be written in language using MatMul instructions using the transposed weights, and Nonlin instructions using , which is pseudo-Lipschitz by D.1.
Recall that is the program defining the network architecture. We now write the unrolled SGD of this network in a new program . Below, recall that lack of time subscript means (e.g. means , the initialized value). In addition, feel free to revisit the notations explained before Appendix A.
$\delta\tilde{g}_{t+1}$ (79)
where, using Eq. 73, we have
$\sqrt{n}(W_{t+1}-W_{t})\tilde{h}_{t}$ (80) (81)
Tensor Program implementation. Eqs. 79, 80 and 81 may be easily implemented using instructions. For instance, Eq. 80 (assuming the sum sums over a single pair ) may be implemented using Moment and Nonlin+ instructions as follows: the term may be implemented by a Moment instruction with . The full term is then a Nonlin+ instruction with scalars and vector .
Eq. 82 may be implemented as a Nonlin+ instruction:
$$\delta\tilde{g}_{t+1}:=\psi^{\star}\Big(\{\tilde{h}_{t}^{i}\}_{i=1}^{k}\cup\{\delta\tilde{h}_{t+1}^{i}\}_{i=1}^{k};\frac{1}{\sqrt{n}}\Big) \tag{83}$$
for a set of vectors and a scalar , where:
$$\psi^{\star}(\{\mu^{i}\}_{i=1}^{k}\cup\{\nu^{i}\}_{i=1}^{k};\theta)_{\alpha}\overset{\text{def}}{=}\begin{cases}\frac{\psi(\mu^{1}+\theta\nu^{1},\ldots,\mu^{k}+\theta\nu^{k})_{\alpha}-\psi(\mu^{1},\ldots,\mu^{k})_{\alpha}}{\theta}&\theta>0\\ \sum_{i=1}^{k}\frac{\partial\psi(\mu^{1},\ldots,\mu^{k})_{\alpha}}{\partial\mu^{i}_{\alpha}}\nu^{i}_{\alpha}&\theta=0.\end{cases} \tag{84}$$
Since is pseudo-Lipschitz by D.1, $\psi^{\star}$ is pseudo-Lipschitz in all of its inputs as well.
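The function $\psi^{\star}$ of Eq. 84 is a difference quotient that is extended continuously to $\theta=0$ by the directional derivative. A minimal sketch follows, assuming a coordinatewise $\psi$ supplied together with a hypothetical helper `grad_psi` that returns its partial derivatives; neither name comes from the paper.

```python
import numpy as np

def psi_star(psi, grad_psi, mus, nus, theta):
    """Evaluate psi_star({mu^i} U {nu^i}; theta) as in Eq. 84.

    psi      : coordinatewise function of k arrays
    grad_psi : returns the k partial derivatives of psi (same shapes as the inputs)
    mus, nus : lists of k arrays (values and scaled updates)
    theta    : the scalar parameter, 1/sqrt(n) in Eq. 83
    """
    mus = [np.asarray(m, dtype=float) for m in mus]
    nus = [np.asarray(v, dtype=float) for v in nus]
    if theta > 0:
        shifted = [m + theta * v for m, v in zip(mus, nus)]
        return (psi(*shifted) - psi(*mus)) / theta        # finite-width difference quotient
    partials = grad_psi(*mus)
    return sum(p * v for p, v in zip(partials, nus))      # theta = 0: linearization
```

Evaluating at a small positive `theta` and at `theta = 0` gives nearly identical vectors, which is the continuity in the scalar argument that lets the Master Theorem pass to the limit $1/\sqrt{n}\to 0$.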
the scalar type outputs at for any input can be implemented using the Moment instruction. The loss derivative given can be implemented with Moment instructions using \psi(-;f(\xi))=\mathcal{L}^{\prime}\big{(}f(\xi)\big{)} where is treated as a scalar type as in the first forward pass.
is implemented using a Nonlin+ instruction g_{t+1}(\xi)=\psi\big{(}g(\xi)\cup\{\delta g_{s}(\xi)\}_{s=1}^{t+1};\frac{1}{\sqrt{n}}\big{)} with .
is implemented using Moment and Nonlin+ instructions.
According to the rules as specified in B.3, we have the following identities:
If , then using Eqs. 79, 80 and 81: (Here , , and )
$$Z^{\delta\tilde{g}_{t+1}}=Z^{W\delta\tilde{h}_{t+1}}-\chi_{t}\sum_{\mathtt{g},\mathtt{h}:\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\,\mathbb{E}\big[Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big] \tag{91}$$
(92)
• If , then using Eqs. 82 and 84, taking the limit ,
$$Z^{\delta\tilde{g}_{t+1}}=\sum_{i=1}^{k}\frac{\partial\psi(Z^{\tilde{h}^{1}},\ldots,Z^{\tilde{h}^{k}})}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}. \tag{93}$$
• Using Eqs. 86, 87, 89 and 90 and taking , we have by ZNonlin+:
D.2 Deriving The NTK
Instantiate paths and on two inputs by (abusing notation slightly). We define an inner product between them as follows:
$$\big\langle p,q\big\rangle\overset{\text{def}}{=}\mathbb{E}\big[Z^{p^{0}}Z^{q^{0}}\big]\prod_{i=2,\,\text{even}}^{|p|-2}\mathbb{E}\Big[\frac{\partial Z^{p^{i}}}{\partial Z^{p^{i-1}}}\frac{\partial Z^{q^{i}}}{\partial Z^{q^{i-1}}}\Big] \tag{95}$$
where are X-vars for all even . Note that for even , is always of the form for some . So the partial derivatives in Eq. 95 are just .
For each weight , the gradient of the output with respect to is given by:
Here, represent nodes in program that can be instantiated by an input . The NTK of can be expressed as:
Using Eqs. 74 and 78, for any G-var , we can write the error term as the summation of error signals over paths :
$$Z^{J^{p}}=Z^{(W^{p^{3}})^{\top}J^{p:3}}\frac{\partial Z^{p^{2}}}{\partial Z^{p^{1}}} \tag{103}$$
where denotes the derivative w.r.t. . By Simple GIA Check (Yang, 2020a), we have that (see ZMatMul). Hence, with abuse of notation , we have
$$\mathbb{E}\big[Z^{J^{p}}Z^{J^{q}}\big]=\mathbb{E}\big[\hat{Z}^{(W^{p^{3}})^{\top}J^{p:3}}\hat{Z}^{(W^{q^{3}})^{\top}J^{q:3}}\big]\,\mathbb{E}\Big[\frac{\partial Z^{p^{2}}}{\partial Z^{p^{1}}}\frac{\partial Z^{q^{2}}}{\partial Z^{q^{1}}}\Big]. \tag{104}$$
From the definition of ZHat, the expectation vanishes if the weights and are not symbolically the same (i.e ). Then by ZHat,
$$\mathbb{E}\big[Z^{J^{p}}Z^{J^{q}}\big]=\begin{cases}\prod_{i=2,\,\text{even}}^{|p|-2}\mathbb{E}\Big[\frac{\partial Z^{p^{i}}}{\partial Z^{p^{i-1}}}\frac{\partial Z^{q^{i}}}{\partial Z^{q^{i-1}}}\Big]&\text{if }p\cong q\\ 0&\text{otherwise.}\end{cases} \tag{106}$$
Combining with Eqs. 100, 74 and 102 proves D.3.
D.3 Getting Section 3
For the remainder of the proof we abbreviate , (i.e path is always evaluated on , while path is always evaluated on ). We prove Section 3 by inducting on all G-vars in the network. We begin by proving the following induction hypothesis.
We write to denote that is a linear combination of for various vectors .
At any time and G-var , the following holds:
$$Z^{\delta\tilde{g}_{t+1}}\equiv-\chi_{t}\sum_{p:p^{-1}=g}\sum_{q:q\cong p}Z^{dq^{-1}}\langle p,q\rangle\mod\hat{Z}^{W\bullet} \tag{107}$$
Here, the sum is over all paths $p$ with endpoint $g$ and all paths $q$ isomorphic to $p$. Recall that $dq^{-1}$ is the (scaled) gradient $\sqrt{n}\,\partial f/\partial q^{-1}$ where $q^{-1}$ is the endpoint of $q$.
D.3.1 Base Case
For initial G-vars , since we are not training the input layers (Assumption A1.). This proves the base case since the sum in Eq. 107 has no terms and thus is 0.
D.3.2 Inductive case
Suppose , where , we then have using Eq. 91:
𝑡1\displaystyle Z^{\delta\tilde{g}_{t+1}} \displaystyle\equiv\dot{Z}^{W\delta\tilde{h}_{t+1}}+\chi_{t}\sum_{\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\operatorname*{\mathbb{E}}\big{[}Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big{]}\mod\hat{Z}^{W\bullet} (108) (109) Note \sum_{\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\operatorname*{\mathbb{E}}\big{[}Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big{]} in Eq. 108 can be written as . Therefore, it suffices to show that
$$\dot{Z}^{W\delta\tilde{h}_{t+1}}=-\chi_{t}\sum_{\mathtt{p}:\mathtt{p}^{-1}=g,\,|\mathtt{p}|\geq 4}\sum_{\mathtt{q}\cong\mathtt{p}}Z^{d\mathtt{q}^{-1}}\langle\mathtt{p},\mathtt{q}\rangle. \tag{110}$$
Showing Eq. 110. By Eq. 82:
$$Z^{\delta\tilde{h}_{t+1}}=\sum_{i=1}^{k}\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}. \tag{111}$$
Since the factors $\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}$ do not depend on $\hat{Z}^{W^{\top}y}$ for any $y$ (by the assumption that we do not use both a matrix and its transpose in the forward pass), from Eq. 111 we have for any $y$:
$$\mathbb{E}\Big[\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]=\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]. \tag{112}$$
Applying the induction hypothesis Eq. 107 to each G-var $h^{i}$, we get
$$\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}=-\chi_{t}\sum_{p:p^{-1}=h^{i}}\sum_{q:q\cong p}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\langle p,q\rangle\mod\hat{Z}^{W\bullet} \tag{113}$$
Plugging this back into (Eq. 109), we get
$$\dot{Z}^{W\delta\tilde{h}_{t+1}}=-\chi_{t}\sum_{y}Z^{y}\sum_{i=1}^{k}\sum_{p:p^{-1}=h^{i}}\sum_{q:q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]. \tag{114}$$
Note that for any path $p$ with $p^{-1}=h^{i}$, we may extend $p$ by the vectors $h,g$ (recall $h=\phi(h^{1},\dots,h^{k})$ and $g=Wh$). Let $\mathtt{p}$ denote this extension. If $\mathtt{q}$ is a path such that
$$\mathtt{q}\cong\mathtt{p}\quad\text{and}\quad\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}=\frac{\partial Z^{\mathtt{q}^{-2}}}{\partial Z^{\mathtt{q}^{-3}}}, \tag{115}$$
then
$$\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]=\langle\mathtt{p},\mathtt{q}\rangle. \tag{116}$$
Our goal now is to show that $q$ in Eq. 114 can be extended appropriately so that we may rewrite Eq. 114 as Eq. 110. This will be done by explicitly computing the term $\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}$ in Eq. 114.
Computing $\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}$. Suppose $g^{1},\dots,g^{r}$ are all G-vars in the program that depend on $q^{-1}$, i.e., for all $j$ we have $g^{j}=W^{j}z^{j}$ where $z^{j}$ is a function of $q^{-1}$, and where $W^{j}$ can be the same or different matrices for different $j$. It follows that:
$$Z^{dq^{-1}}=\sum_{j=1}^{r}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}Z^{(W^{j})^{\top}dg^{j}} \tag{118--119}$$
Note that $\dot{Z}^{(W^{j})^{\top}dg^{j}}=0$ in Eq. 119 by the gradient independence assumption (GIA), because we pass the Simple GIA Check. This may also be easily verified by explicitly computing $\dot{Z}^{(W^{j})^{\top}dg^{j}}$, and noticing that the expectation vanishes from the dependency of on (i.e for some vector which does not depend on ). Since does not depend on and the last layer is not trained, we have . Since we assumed that the forward propagation does not contain both $W$ and $W^{\top}$, it follows from differentiating Eq. 118 that
$$\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}=\sum_{j=1}^{r}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}\frac{\partial\hat{Z}^{(W^{j})^{\top}dg^{j}}}{\partial\hat{Z}^{W^{\top}y}}=\sum_{j:\,W^{j}\triangleq W,\ dg^{j}=y}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}. \tag{120}$$
If this sum over $j$ is nonempty, then there is a unique $j$ such that $W^{j}\triangleq W$ and $dg^{j}=y$. In such a case, we may extend the path $q$ with $z^{j},g^{j}$ to form $\mathtt{q}$ satisfying Eq. 115. Plugging back into Eq. 114, we obtain Eq. 110 as desired.
Hence, we have proven the induction hypothesis.
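As an informal numerical illustration of the GIA step used in the inductive case (not part of the proof), the sketch below estimates $\mathbb{E}[Z^{u}Z^{W^{\top}d}]$ by the finite-width average $u^{\top}W^{\top}d/n$. The function name `zdot_statistic`, the choice of $\tanh$, and the width are assumptions made for this example only. When the backward signal $d$ carries an independent zero-mean readout vector (the situation covered by the Simple GIA Check), the average is close to zero, consistent with a vanishing $\dot{Z}$ correction; when $d$ is built from the forward pass alone, it is not.

```python
import numpy as np

def zdot_statistic(n, zero_mean_output, seed=0):
    """Return u . (W^T d) / n for one random draw at width n.

    This is a finite-width proxy for E[Z^u Z^{W^T d}], which is zero
    when the ZDot correction of the backward vector W^T d vanishes.
    """
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(n)                    # forward input vector
    W = rng.standard_normal((n, n)) / np.sqrt(n)  # entries ~ N(0, 1/n)
    h = W @ u                                     # forward preactivation
    if zero_mean_output:
        # Hypothetical readout v: independent of W and zero mean.
        v = rng.standard_normal(n)
        d = v * (1.0 - np.tanh(h) ** 2)           # backward signal v * tanh'(h)
    else:
        # Backward-like signal with no independent readout (GIA check fails).
        d = np.tanh(h)
    return float(u @ (W.T @ d)) / n

if __name__ == "__main__":
    print("zero-mean readout :", zdot_statistic(2000, True))
    print("no readout        :", zdot_statistic(2000, False))
```

In the second case the statistic approaches $\mathbb{E}[h\tanh(h)]$ for a roughly standard Gaussian $h$, which is bounded away from zero; this is the correlation that the zero-mean, untrained last layer removes.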
D.3.3 Proving Section 3 using the induction hypothesis
WLOG assume $f=V^{\top}x$ for some G-var $x$. Using the induction hypothesis and the Master Theorem (B.4), we have that:
$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]=-\chi_{t}\sum_{p:p^{-1}=x}\sum_{q\cong p}\mathbb{E}\big[Z^{v}Z^{dq^{-1}}\big]\langle p,q\rangle. \tag{121}$$
Note that for any path . Hence, with Eq. 96, we have
$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=-\chi_{t}\sum_{p:p^{-1}=x}\sum_{q\cong p}\langle p,q\rangle=-\chi_{t}\mathring{\mathcal{K}}(\xi_{t},\tilde{\xi}) \tag{122}$$
as desired.
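The one-step identity in Eq. 122 can be checked informally at finite width. The sketch below is an assumption-laden toy and not the construction used in the proof: it takes a two-hidden-layer $\tanh$ network in a standard NTK parametrization, freezes the input and readout layers, trains only the hidden matrix, and uses a fixed scalar `chi` in place of $\chi_{t}$ with unit learning rate. All function names and the two sample inputs are hypothetical. It compares the actual change of the output on $\tilde{\xi}$ after one gradient step on $\xi_{t}$ with $-\chi$ times the empirical kernel $\langle\nabla_{W}f(\xi_{t}),\nabla_{W}f(\tilde{\xi})\rangle$.

```python
import numpy as np

def forward(U, W, v, xi):
    """Toy network f(xi) = v . tanh(W tanh(U xi) / sqrt(n)) / sqrt(n)."""
    n = W.shape[0]
    x1 = np.tanh(U @ xi)          # first hidden activations (U is frozen)
    h2 = W @ x1 / np.sqrt(n)      # trainable hidden layer
    x2 = np.tanh(h2)
    f = v @ x2 / np.sqrt(n)       # frozen readout
    return f, x1, h2

def grad_W(U, W, v, xi):
    """Gradient of f(xi) with respect to the hidden matrix W."""
    n = W.shape[0]
    _, x1, h2 = forward(U, W, v, xi)
    dh2 = v * (1.0 - np.tanh(h2) ** 2) / np.sqrt(n)  # df/dh2
    return np.outer(dh2, x1) / np.sqrt(n)            # df/dW

def one_step_check(n=1000, chi=0.7, seed=0):
    """Return (actual output change, -chi * empirical kernel)."""
    rng = np.random.default_rng(seed)
    U = rng.standard_normal((n, 2))
    W = rng.standard_normal((n, n))
    v = rng.standard_normal(n)
    xi_t = np.array([1.0, 0.5])       # training input (hypothetical)
    xi_tilde = np.array([0.8, 0.7])   # test input (hypothetical)
    g_t = grad_W(U, W, v, xi_t)
    g_tilde = grad_W(U, W, v, xi_tilde)
    kernel = float(np.sum(g_t * g_tilde))
    f_before, _, _ = forward(U, W, v, xi_tilde)
    f_after, _, _ = forward(U, W - chi * g_t, v, xi_tilde)
    return float(f_after - f_before), -chi * kernel

if __name__ == "__main__":
    actual, predicted = one_step_check()
    print("actual change   :", actual)
    print("-chi * kernel   :", predicted)
```

The two numbers agree up to higher-order terms in the step, which shrink as the width grows; the empirical kernel here is a finite-width stand-in for $\mathring{\mathcal{K}}(\xi_{t},\tilde{\xi})$ restricted to the single trained matrix.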
D.4 Relaxing A1., A2., A3. and A4.
We now briefly discuss the case where A1., A2., A3. and A4. are relaxed, as well as the case where $f$ is represented by a \textsc{Netsor}$\top^{+}$ program. As the proof of the general case follows roughly the same logic as in D.1, we only discuss the meaningful differences in each case.
Recall the input and output layers are parameterized by which now depend on . The output evolution is now given by:
$$\tilde{f}_{t+1}-\tilde{f}_{t}=V_{t+1}^{\top}\tilde{x}_{t+1}-V_{t}^{\top}\tilde{x}_{t}=\frac{v^{\top}\delta\tilde{x}_{t+1}}{n}+\frac{\sqrt{n}(v_{t+1}-v_{t})^{\top}\tilde{x}_{t}}{n}+\sum_{s=0}^{t}\frac{(v_{s+1}-v_{s})^{\top}\delta\tilde{x}_{t+1}}{n}. \tag{123}$$
Plugging into Eq. 123 and taking the limit (using rules):
$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]-\chi_{t}\mathbb{E}\big[Z^{x(\xi_{t})}Z^{\tilde{x}}\big]. \tag{124}$$
Evaluating $\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]$ by induction requires altering the path definition so that each path may start with an input , and end with a G-var (that is, a path either starts with an X-var or an input). We reuse the definition of inner product between $p,q$ in Eq. 95, only when both start with inputs respectively then $\mathbb{E}\big[Z^{p^{0}}Z^{q^{0}}\big]$ implies . The remainder of the proof follows the same logic as in D.1. Note that the NTK in this case would yield:
D.4.2 $W,W^{\top}$ in the forward pass
When both $W$ and $W^{\top}$ are allowed in the forward pass, the update equations for each $w$ take the form:
$$w_{t+1}-w_{t}=-\chi_{t}\sum_{g,h:\,g=Wh}\frac{dg_{t}h_{t}^{\top}}{n}-\chi_{t}\sum_{g,h:\,g=W^{\top}h}\frac{h_{t}dg_{t}^{\top}}{n} \tag{126}$$
Some quick calculations using rules show that for G-vars:
It is straightforward to show using GIA (Yang, 2020a) that $\mathbb{E}\big[Z^{d\mathtt{g}(\xi_{t})}Z^{\tilde{h}}\big]=0$ in both cases, leaving us with an expression similar to that in D.1. The induction hypothesis for G-vars in this case takes one of two forms:
If then Eq. 107 holds with replacing .
Some additional complications need to be resolved. Specifically, with setup D.1 we have used in two places the fact that no transpose is used in the forward pass to prove the induction hypothesis (see Eqs. 112 and 120). To prove the induction, and assuming , we now have instead of Eq. 112 (using Eq. 111):
$$\mathbb{E}\Big[\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]=\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]+\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\partial\hat{Z}^{W^{\top}y}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big], \tag{129}$$
where the $h^{i}$ are G-vars. To evaluate the additional term on the RHS of Eq. 129, we use the induction hypothesis to express $Z^{\delta\tilde{h}^{i}_{t+1}}$:
$$\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\partial\hat{Z}^{W^{\top}y}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big]=-\chi_{t}\sum_{p:p^{-1}=h^{i}}\sum_{q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\partial\hat{Z}^{W^{\top}y}}Z^{dq^{-1}}\Big]. \tag{130}$$
Using GIA (Yang, 2020a), it is straightforward to show that the expectation on the RHS of Eq. 130 vanishes, leaving us with the first term on the RHS of Eq. 129, as in D.1. Note that the same logic may be applied in Eq. 120, concluding the proof.
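The two-term structure of the weight update when a matrix and its transpose both appear in the forward pass can be sanity-checked on a small example. The sketch below is illustrative only: the scalar function, the names `loss_and_grad` and `finite_difference_grad`, and the omission of the $\chi_{t}$ and $1/n$ factors are all assumptions of this example. It verifies by central finite differences that the gradient with respect to $W$ is the sum of an outer product $dg\,h^{\top}$ for the use $g=Wh$ and an outer product $h\,dg^{\top}$ for the use $g=W^{\top}h$.

```python
import numpy as np

def loss_and_grad(W, x, y, a, b):
    """f(W) = a . tanh(W x) + b . tanh(W^T y), with its gradient in W."""
    g1 = W @ x        # G-var produced by W
    g2 = W.T @ y      # G-var produced by W^T
    f = a @ np.tanh(g1) + b @ np.tanh(g2)
    dg1 = a * (1.0 - np.tanh(g1) ** 2)  # df/dg1
    dg2 = b * (1.0 - np.tanh(g2) ** 2)  # df/dg2
    # dg h^T for g = W h, plus h dg^T for g = W^T h.
    grad = np.outer(dg1, x) + np.outer(y, dg2)
    return f, grad

def finite_difference_grad(W, x, y, a, b, eps=1e-6):
    """Central finite-difference estimate of df/dW, entry by entry."""
    grad = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            E = np.zeros_like(W)
            E[i, j] = eps
            f_plus, _ = loss_and_grad(W + E, x, y, a, b)
            f_minus, _ = loss_and_grad(W - E, x, y, a, b)
            grad[i, j] = (f_plus - f_minus) / (2.0 * eps)
    return grad

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 5
    W = rng.standard_normal((n, n)) / np.sqrt(n)
    x, y, a, b = (rng.standard_normal(n) for _ in range(4))
    _, g = loss_and_grad(W, x, y, a, b)
    print(np.max(np.abs(g - finite_difference_grad(W, x, y, a, b))))
```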
D.4.3 Multiple outputs and arbitrary batchsize
We have used a scalar output and a batch size of 1 throughout this paper. However, extending to multiple (finite) outputs and an arbitrary batch size requires no additional arguments beyond some additional notation. For instance, the definition of a path must be altered to express dependency on multiple samples (e.g., if batchnorm is used). The proof, however, follows roughly the same logic as in D.1.
D.4.4 X-var embedding
We assumed in our proof that $x$, which represents the final embedding of $\xi$, is a G-var. However, extending the proof to the case where $x$ is an X-var is straightforward. Let $x=\phi(h^{1},\dots,h^{k})$, where $h^{1},\dots,h^{k}$ are G-vars. Using the induction hypothesis, along with Eq. 93, yields:
\begin{align}
\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}&=\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]=\sum_{i=1}^{k}\mathbb{E}\Big[Z^{v}\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big] \tag{131}\\
&=-\chi_{t}\sum_{i=1}^{k}\mathbb{E}\Big[Z^{v}\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}\sum_{p:p^{-1}=h^{i}}\sum_{q\cong p}Z^{dh^{i}(\xi_{t})}\langle p,q\rangle\Big] \tag{132}\\
&=-\chi_{t}\sum_{i=1}^{k}\sum_{p:p^{-1}=h^{i}}\sum_{q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{x(\xi_{t})}}{\partial Z^{h^{i}(\xi_{t})}}\Big] \tag{133}
\end{align}
It is straightforward to show that the expression multiplying $-\chi_{t}$ in Eq. 133 is the NTK in this case.
D.4.5 Network specified by \textsc{Netsor}$\top^{+}$
If the network is more generally represented by a \textsc{Netsor}$\top^{+}$ program instead of just a \textsc{Netsor}$\top$ program, then our proof can be very simply modified to accommodate as follows: The new operation allowed in such a network is the production of a scalar through Moment, say . By a similar inductive argument as before, we will see that 1) for all and for all , so that ; 2) in the backward pass, any backpropagation through will zero out: for example, if is only used later in a Nonlin , then will converge to 0 because of GIA (as is linear in the final layer), and the error signal at times is the constant vector with entries , which is .
Therefore, we can treat any scalar produced through Moment as a constant fixed at initialization, and the notion of path from before carries over here without change (by assuming all nonlinearities with scalar parameters to be parameterless nonlinearities where the parameters are fixed). Then the same reasoning follows.