Tensor Programs IIb: Architectural Universality of Neural Tangent Kernel Training Dynamics

Greg Yang, Etai Littwin

Introduction

(Jacot et al., 2018)’s pioneering work showed that a multi-layer perceptron (MLP) trained by gradient descent (GD) evolves like a linear model. This spurred a flurry of research papers using this insight to tackle the core questions in deep learning theory, from optimization to generalization in both finite and infinite width regimes. (Jacot et al., 2018)’s argument consists of two observations:

ntkInit: For the output $f(\xi;w)$ of a network with parameters $w$ given example $\xi$, (Jacot et al., 2018) identified the kernel $\mathcal{K}(\xi,\bar{\xi})=\langle\nabla f(\xi;w),\nabla f(\bar{\xi};w)\rangle$, known as the Neural Tangent Kernel (NTK). They showed that if $f$ is parametrized and initialized appropriately, then $\mathcal{K}$ converges to a deterministic kernel $\mathring{\mathcal{K}}$ as the width of the network tends to infinity.

ntkTrain: As the infinitely wide network is trained by gradient descent, the NTK remains frozen in its initial state, and the network evolves by kernel gradient descent with kernel $\mathring{\mathcal{K}}$.

In (Yang, 2020a), the ntkInit property was proven to hold for standard architectures, meaning any composition of MLPs, recurrent neural networks (RNN), LSTMs (Hochreiter & Schmidhuber, 1997), gated recurrent units (GRU) (Cho et al., 2014), convolutions (Fukushima, 1980, 1975; Lecun et al., 1998, 2000; Rumelhart et al., 1986), residual connections (He et al., 2016; Huang et al., 2017), batch normalization (Ioffe & Szegedy, 2015), graph neural networks (Bruna et al., 2014; Defferrard et al., 2016; Duvenaud et al., 2015; Henaff et al., 2015; Kipf & Welling, 2017) and attention (Bahdanau et al., 2015; Vaswani et al., 2017), along with arbitrary weight sharing between components. More generally, it holds for any architecture expressible in a so-called Tensor Program (Yang, 2019b, a, 2020a, 2020b), of which the standard architectures are a subset. However, that reasoning is limited to initialization.

A statement is architecturally universal if it holds for any reasonable neural architecture. This is an informal property, but here we will formalize it by taking reasonable to be “expressible in Tensor Programs.” By the expressiveness of such programs (Yang, 2019a, 2020a), architectural universality is a fairly robust notion that covers present (and, we expect, future) architectures comprehensively. In this terminology, (Yang, 2020a) showed that ntkInit is architecturally universal.

We show the architectural universality of the entire NTK theory by proving ntkTrain for the same architectures discussed above, including all standard architectures. In the process, we introduce a new graphical form of Tensor Programs that is both required in our proofs and useful for the pedagogy of Tensor Programs.

This paper follows (Yang, 2019b, a, 2020a, 2020b; Yang & Hu, 2020) in the series. While we number this paper “IIb” right after (Yang, 2020a), we actually need the complete theoretical foundation developed in III (Yang, 2020b). See Footnote 2 for more details.

Background

Let $f(\xi;w)\in\mathbb{R}$ denote the (scalar) output of a neural network parameterized by $w$, given example $\xi$. To understand how the output changes under a small change of the network parameters from $w_{0}$ to $w_{0}+\delta w$, we may naively expand the network function to first order around the base point $w_{0}$:

$$f(\xi;w_{0}+\delta w)-f(\xi;w_{0})\approx\langle\nabla_{w}f(\xi;w_{0}),\delta w\rangle. \tag{1}$$

Under SGD, the weight update is given by the gradient $\delta w=-\eta\chi(\hat{\xi})\nabla_{w}f(\hat{\xi};w_{0})$, where $\chi(\hat{\xi})$ is the loss derivative, $\hat{\xi}$ is a sample from the training set, and $\eta$ is the learning rate. Plugging this into Eq. 1, we get:

$$f(\xi;w_{0}+\delta w)-f(\xi;w_{0})\approx-\eta\chi(\hat{\xi})\mathcal{K}(\xi,\hat{\xi}), \tag{2}$$

where $\mathcal{K}(\xi,\hat{\xi})=\langle\nabla_{w}f(\xi;w_{0}),\nabla_{w}f(\hat{\xi};w_{0})\rangle$ is the NTK. The NTK theory of infinitely wide neural networks, as first proposed by (Jacot et al., 2018), boils down to the following observations: when the width of $f$ tends to infinity, the NTK $\mathcal{K}$ converges at random initialization to a fixed kernel $\mathring{\mathcal{K}}$, independent of the specific instantiation of the weights, and it remains frozen during optimization. Eq. 2 then accurately describes the evolution of the output if we substitute $\mathring{\mathcal{K}}$ for $\mathcal{K}$. The seemingly complex optimization trajectory of SGD therefore reduces to the convex trajectory of kernel gradient descent with a time-independent kernel $\mathring{\mathcal{K}}$. Consider the output of the network $f\in\mathbb{R}^{D}$ on the full training dataset of $D$ examples. As shown in (Jacot et al., 2018), when the L2 loss $\frac{1}{2}\|f-f^{\star}\|^{2}$ is used, the evolution of the output $f_{t}$ at time $t$ under continuous-time GD (i.e. gradient flow) takes a simple form:

$$f_{t}=f^{\star}+e^{-\eta\mathring{\mathcal{K}}t}\big(f_{0}-f^{\star}\big),$$

where $\mathring{\mathcal{K}}\in\mathbb{R}^{D\times D}$ is the full NTK matrix evaluated on the training data, $f^{\star}$ is the label function, and $f_{0}$ is the output at initialization. Hence, provided $\mathring{\mathcal{K}}$ is full rank, we have $f_{t}\to f^{\star}$ as $t\to\infty$, and the network can fit the training data perfectly.
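A minimal pure-Python sketch of the discrete-time counterpart of this flow, kernel gradient descent $f\leftarrow f-\eta\mathring{\mathcal{K}}(f-f^{\star})$ on the training outputs, is below. The 3×3 kernel matrix, initial outputs, labels and step size are made-up values (assumptions chosen only so that the kernel is symmetric positive definite); the sketch is not tied to any architecture in this paper.

```python
def matvec(K, x):
    """Multiply a dense matrix, given as a list of rows, by a vector."""
    return [sum(k * xi for k, xi in zip(row, x)) for row in K]


def kernel_gd(K, f0, f_star, eta, steps):
    """Discrete-time kernel gradient descent on the loss 0.5 * ||f - f_star||^2.

    Each step applies f <- f - eta * K (f - f_star), the Euler discretization of
    the gradient flow above. For symmetric positive-definite K and
    0 < eta < 2 / lambda_max(K), the training outputs converge to f_star.
    """
    f = list(f0)
    for _ in range(steps):
        residual = [fi - yi for fi, yi in zip(f, f_star)]
        update = matvec(K, residual)
        f = [fi - eta * ui for fi, ui in zip(f, update)]
    return f


# Hypothetical full-rank kernel matrix on D = 3 training points (made-up values).
K = [[1.0, 0.5, 0.2],
     [0.5, 1.0, 0.3],
     [0.2, 0.3, 1.0]]
f0 = [0.3, -0.1, 0.8]      # outputs at initialization (made up)
f_star = [1.0, -1.0, 0.5]  # labels (made up)

f_T = kernel_gd(K, f0, f_star, eta=0.5, steps=200)
print(f_T)
```

If the kernel were rank-deficient, the component of $f_{0}-f^{\star}$ in its null space would never decay, which is why full rank is required above.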

A common theme in showing ntkTrain for MLPs is to derive high-probability bounds on the deviation of the NTK $\mathcal{K}$ from its initial value after training (e.g. Allen-Zhu et al. (2018); Du et al. (2018); Zou et al. (2018)).[Footnote 1: In the original NTK paper (Jacot et al., 2018), the limit is taken as each layer width goes to infinity sequentially, which already doesn’t make sense for weight-tied architectures like RNNs.] Obtaining these bounds usually requires developing ad hoc methods on a per-architecture basis, hindering the scalability of the method to other settings. In the present work we take a more holistic approach, leveraging the recently developed Tensor Programs framework (Yang, 2019b, a, 2020a, 2020b). It consists of two layers of arguments: 1) The bottom layer analyzes how the distribution of (pre-)activations changes throughout the course of training; this crucially leverages the mathematical machinery of the Tensor Programs Master Theorem.[Footnote 2: In particular, we need to use the Master Theorem in (Yang, 2020b), so (Yang, 2020a) could not have obtained ntkTrain at the same time as ntkInit.] 2) The top layer packages these insights systematically via the notion of paths so as to apply to any architecture expressible by a Tensor Program. We will illustrate 1) through examples in Section 3 and 2) through figures in Section 5.1.

In this paper, we will consider the architecture (including depth), data, and training time to be fixed as width $n\to\infty$.[Footnote 3: They will affect the rate of convergence to the infinite-width limit, but since we are only concerned with whether convergence occurs, they do not appear in our theorem statements here.] We describe common notations used in the remainder of the paper. For simplicity, we will consider SGD with batch size 1 and learning rate $\eta$ (often set to 1 WLOG).[Footnote 4: This generalizes readily to any batch size and learning rate.] We use $\xi_{t}$ to denote the input and $\mathcal{L}_{t}$ to denote the loss function (absorbing the label) at step $t$. More generally, subscript $t$ on any symbol means time $t$. However, for brevity, we abuse notation and write $f_{t}$ for $f_{t}(\xi_{t})$ and, for any (pre-)activation $x$, $x_{t}$ for $x_{t}(\xi_{t})$.[Footnote 5: We will not refer to the function $x_{t}:\xi\mapsto x_{t}(\xi)$ (likewise for $f_{t},\chi_{t}$), so this abuse of notation should cause no confusion.] We will also write $\chi_{t}$ for the loss derivative $\mathcal{L}_{t}^{\prime}(f_{t})$. For any vector $x(\xi)$ we define $\delta x_{t+1}(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_{t}(\xi)\big)$ and $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial x(\xi)}$. We will track the evolution of $f$ on an arbitrary input $\tilde{\xi}$.[Footnote 6: It might help to think of $\tilde{\xi}$ as some test sample, but it can also fall in the training set.] Similar to above, we write $\tilde{x}_{t},\tilde{f}_{t}$ as shorthand for $x_{t}(\tilde{\xi}),f_{t}(\tilde{\xi})$.

Motivating Examples

The purpose of this section is to illustrate our key ideas via simple, intuitive examples without diving into the specifics of Tensor Programs. In the process, we will gain insight into how randomness from initialization propagates over the course of training. As these examples intend to provide the reader with the proper intuition, we use informal arguments alone and relegate all formal statements to the appendix. For brevity, we will gloss over minor details or routine calculations, but interested readers can see Appendix A for these omissions.

It turns out that the random initialization and the overparametrization of weights cause each (pre-)activation vector $x_{t}(\xi)\in\mathbb{R}^{n}$, its gradient $dx_{t}(\xi)\in\mathbb{R}^{n}$, and its (scaled) change $\delta x_{t}(\xi)\in\mathbb{R}^{n}$ at every time step $t$ to have roughly iid coordinates, not just initially but throughout training.[Footnote 7: This is a consequence of the Tensor Program Master Theorem.] Then, as we shall demonstrate through the examples below, to track the evolution of the neural network function, it suffices to track the evolution of the coordinate distributions of $x(\xi)$, $dx(\xi)$, $\delta x(\xi)$. We write $Z^{x(\xi)}$, $Z^{dx(\xi)}$, $Z^{\delta x(\xi)}\in\mathbb{R}$ for the random variables corresponding to such coordinate distributions.[Footnote 8: As we will explain below, different $Z$s may correlate, reflecting correlations between corresponding vectors.]
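To see what “roughly iid coordinates” buys us in the simplest case (at initialization only), the sketch below takes two preactivation-like vectors $x=\mathrm{relu}(u\xi)$ and $y=\mathrm{relu}(u\xi')$ sharing the same weights $u$ with iid $\mathcal{N}(0,1)$ coordinates, and shows that $x^{\top}y/n$ approaches $\mathbb{E}\,Z^{x}Z^{y}$ by the Law of Large Numbers. The ReLU nonlinearity, the scalar inputs, and the function names are illustrative assumptions; for $\xi,\xi'>0$ the limit is $\xi\xi'/2$.

```python
import random


def relu(a):
    return a if a > 0.0 else 0.0


def empirical_moment(n, xi, xi_prime, seed=0):
    """Return x^T y / n for x = relu(u * xi), y = relu(u * xi_prime).

    u has n iid N(0, 1) coordinates (the d = 1 input weights at initialization),
    so x and y have iid coordinates; by the Law of Large Numbers x^T y / n
    approaches E[Z^x Z^y].
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        u = rng.gauss(0.0, 1.0)
        total += relu(u * xi) * relu(u * xi_prime)
    return total / n


xi, xi_prime = 1.0, 2.0
# For xi, xi' > 0: E[relu(Z xi) relu(Z xi')] = xi * xi' * E[relu(Z)^2] = xi * xi' / 2.
limit = xi * xi_prime / 2
for n in (100, 10_000, 200_000):
    print(n, empirical_moment(n, xi, xi_prime), limit)
```

The correlation between $Z^{x}$ and $Z^{y}$ comes entirely from the shared weights $u$, which is the kind of dependence between different $Z$s mentioned in Footnote 8.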

Our goal is to derive, from these insights, the following claim.

Claim. In the large width limit, $\tilde{f}_{t}=f_{t}(\tilde{\xi})$ changes by

$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=-\mathring{\chi}_{t}\mathring{\mathcal{K}}(\tilde{\xi},\xi_{t}) \tag{3}$$

at step $t$, where $\mathring{\mathcal{K}}$ is the limiting NTK of the architecture and $\mathring{\chi}_{t}=\mathcal{L}_{t}^{\prime}(\lim_{n}f_{t})$ is the loss derivative.

We start with an example derivation for 1-hidden-layer MLP, before moving on to 2-hidden-layers, where the mathematics quickly become much more involved.

3.1 1 Hidden Layer

Consider a 1-hidden-layer network with nonlinearity $\phi$:

$$f(\xi)=V^{\top}x(\xi),\qquad x(\xi)=\phi(h(\xi)),\qquad h(\xi)=U\xi,$$

where $\xi\in\mathbb{R}^{d}$, $U=\frac{u}{\sqrt{d}}\in\mathbb{R}^{n\times d}$, $V=\frac{v}{\sqrt{n}}\in\mathbb{R}^{n\times 1}$, for trainable parameter tensor $u$; both $u$ and $v$ are initialized iid from $\mathcal{N}(0,1)$. In the interest of clarity we assume the output layer is not trained, and $d=1$.

For a vector $v\in\mathbb{R}^{n}$, let $v=\Theta(n^{a})$ mean that “$v$ has coordinates of order $n^{a}$ when $n$ is large”[Footnote 9: More rigorously, we mean that $\|v\|^{2}/n=\Theta(n^{2a})$. Note this is different from the common interpretation that $\|v\|=\Theta(n^{a})$.]; likewise for $o(n^{a})$, etc. Recall the notations $x_{t}=x_{t}(\xi_{t})$, $\tilde{x}_{t}=x_{t}(\tilde{\xi})$, $\delta x_{t}=\sqrt{n}(x_{t}-x_{t-1})$, and likewise for $h_{t},\tilde{h}_{t},\delta h_{t}$. The key insights are as follows:

It turns out $\tilde{x}_{t+1}-\tilde{x}_{t}=\Theta(1/\sqrt{n})=o(1)$, so $\delta\tilde{x}_{t+1}=\sqrt{n}(\tilde{x}_{t+1}-\tilde{x}_{t})$ has $\Theta(1)$ coordinates. Likewise for $\delta\tilde{h}_{t+1}$. Consequently, for any $t$ and input $\xi$, by telescoping,

$$h_{t}(\xi)=h_{0}(\xi)+\frac{1}{\sqrt{n}}\sum_{s=1}^{t}\delta h_{s}(\xi)=h_{0}(\xi)+o(1). \tag{4}$$

Using $\nabla_{u}f=\frac{1}{\sqrt{n}}dh_{t}\xi_{t}^{\top}$ and $dh_{t}=\phi^{\prime}(h_{t})\odot v$, we have:

$$\delta\tilde{h}_{t+1}=-\chi_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi^{\prime}(h_{t})\odot v. \tag{5}$$

Also, $\delta\tilde{x}_{t+1}=\sqrt{n}\big(\phi(\tilde{h}_{t}+\frac{\delta\tilde{h}_{t+1}}{\sqrt{n}})-\phi(\tilde{h}_{t})\big)$. Since $\delta\tilde{h}_{t+1}=\Theta(1)$, by intuitive Taylor expansion, we have

$$\delta\tilde{x}_{t+1}\approx\phi^{\prime}(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1}. \tag{6}$$

The change in the output on example $\tilde{\xi}$ from step $t$ to step $t+1$ is given by:

$$\tilde{f}_{t+1}-\tilde{f}_{t}=V^{\top}(\tilde{x}_{t+1}-\tilde{x}_{t})=\frac{v^{\top}\delta\tilde{x}_{t+1}}{n}. \tag{7}$$

By definition $v$ has iid coordinates. It turns out $h_{t}(\xi),\delta h_{t}(\xi)$ (likewise for $x$) all have approximately iid coordinates of size $\Theta(1)$ as well.[Footnote 10: Technically, they have iid coordinates only after conditioning on the initial function (GP) $f_{0}$. Likewise, when we take expectations in this example (e.g. Eq. 9, and Eq. 16 below), they are conditional expectations of this kind. See Section D.1.1 for a rigorous treatment. However, to convey the main intuition, we gloss over this technicality here.] Let $Z^{\delta\tilde{x}_{t}},Z^{v}$ denote the random variables encoding the corresponding coordinate distributions; likewise for the other vectors. Note that $Z^{\delta\tilde{x}_{t}},Z^{v}$ will in general be correlated, reflecting the coordinatewise correlation between $v$ and $\delta\tilde{x}_{t}$.

$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=\mathbb{E}\,Z^{v}Z^{\delta\tilde{x}_{t+1}} \tag{8}$$
$$=-\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\,\phi^{\prime}(Z^{\tilde{h}_{t}})\phi^{\prime}(Z^{h_{t}})(Z^{v})^{2}, \tag{9}$$

where $\mathring{\chi}_{t}=\mathcal{L}_{t}^{\prime}(\lim_{n}f_{t})$ as in Eq. 3.

By Eq. 4, in the $n\to\infty$ limit, $Z^{\tilde{h}_{t}}=Z^{h_{0}(\tilde{\xi})}$ and $Z^{h_{t}}=Z^{h_{0}(\xi_{t})}$. They are independent from $Z^{v}$ and jointly Gaussian with variances $\|\tilde{\xi}\|^{2},\|\xi_{t}\|^{2}$ and covariance $\tilde{\xi}^{\top}\xi_{t}$. So (using the initialization of $v$ to simplify $\mathbb{E}(Z^{v})^{2}=1$),

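$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=-\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\,\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}}). \tag{10}$$

This can easily be seen to be Eq. 3 (recall we assumed for simplicity that the output layer $V$ is not trained): since $\nabla_{u}f=\frac{1}{\sqrt{n}}dh\,\xi^{\top}$, the NTK here is $\xi_{t}^{\top}\tilde{\xi}\,\frac{1}{n}d\tilde{h}_{t}^{\top}dh_{t}$, which converges to $\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\,\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}})$.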
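The following pure-Python sketch is a rough numerical check of Eq. 10 under assumptions of our own choosing: $\phi=\mathrm{ReLU}$ (picked for its closed-form expectation, even though it is not smooth), $d=1$, squared loss $\frac{1}{2}(f-y_{t})^{2}$, learning rate 1, and a frozen output layer. It takes one SGD step on $u$ and compares the observed change of $\tilde{f}$ with $-\chi_{t}\mathring{\mathcal{K}}(\tilde{\xi},\xi_{t})$. For ReLU with scalar inputs, Eq. 4 gives $Z^{h_{t}}=Z\xi_{t}$ and $Z^{\tilde{h}_{t}}=Z\tilde{\xi}$ for a single standard Gaussian $Z$, so $\mathbb{E}\,\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}})$ is $1/2$ when $\xi_{t}\tilde{\xi}>0$ and $0$ otherwise. The helper names are hypothetical.

```python
import math
import random


def relu(a):
    return a if a > 0.0 else 0.0


def relu_prime(a):
    return 1.0 if a > 0.0 else 0.0


def one_sgd_step_change(n, xi_t, y_t, xi_test, seed=0):
    """Take one SGD step (learning rate 1, loss 0.5 * (f - y_t)^2) on u for
    f(xi) = v^T relu(u * xi) / sqrt(n), with d = 1 and v frozen.

    Returns (change of f at xi_test, loss derivative chi_t).
    """
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    scale = 1.0 / math.sqrt(n)

    def f(weights, xi):
        return scale * sum(vi * relu(ui * xi) for ui, vi in zip(weights, v))

    chi = f(u, xi_t) - y_t  # loss derivative at the current output
    # df/du_i = v_i * relu'(u_i * xi_t) * xi_t / sqrt(n), i.e. (dh_t xi_t^T) / sqrt(n).
    u_new = [ui - chi * scale * vi * relu_prime(ui * xi_t) * xi_t
             for ui, vi in zip(u, v)]
    return f(u_new, xi_test) - f(u, xi_test), chi


def limiting_ntk_relu_1d(xi, xi_test):
    """Kernel in Eq. (10) for phi = ReLU and d = 1.

    Z^{h_t} = Z xi and Z^{h~} = Z xi_test for one standard Gaussian Z, so
    E[phi'(Z^{h_t}) phi'(Z^{h~})] is 1/2 if xi and xi_test share a sign, else 0.
    """
    return xi * xi_test / 2 if xi * xi_test > 0 else 0.0


xi_t, y_t, xi_test = 1.0, 0.5, 2.0
delta_f, chi = one_sgd_step_change(50_000, xi_t, y_t, xi_test)
predicted = -chi * limiting_ntk_relu_1d(xi_t, xi_test)
print(delta_f, predicted)
```

At this width the two printed numbers should agree up to fluctuations of order $1/\sqrt{n}$; this is a sanity check of the informal derivation, not a proof.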

Our strategy so far has been to compute the form of $Z^{\delta\tilde{x}_{t}}$ and plug it into Eq. 8 to obtain the dynamics of the output in the limit. Note that our approach differs from previous work, which mainly focuses on proving a bound on the change of the NTK after training. As the architecture gets more complex, bounding the NTK movement becomes quite difficult, but our approach easily scales due to the automation provided by the Tensor Programs framework (see Section 4).

In the previous example, the coordinate distribution $Z^{\delta\tilde{x}_{t+1}}$ took a fairly simple form, which allowed us to compute the expectation $\mathbb{E}\,Z^{v}Z^{\delta\tilde{x}_{t+1}}$ intuitively. Before introducing a method for computing coordinate distributions in a general architecture, we move on to a slightly more involved architecture, with the intention of highlighting the intuition behind the general case.

3.2 2 Hidden Layers

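In this example we consider a model of the form:

$$f(\xi)=V^{\top}x(\xi),\quad x(\xi)=\phi(h(\xi)),\quad h(\xi)=Wz(\xi),\quad z(\xi)=\phi(g(\xi)),\quad g(\xi)=U\xi,$$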

where $U=\frac{u}{\sqrt{d}}\in\mathbb{R}^{n\times d}$, $W=\frac{w}{\sqrt{n}}\in\mathbb{R}^{n\times n}$, $V=\frac{v}{\sqrt{n}}\in\mathbb{R}^{n\times 1}$, for trainable parameters $u,w$; as before, $u$, $w$ and $v$ are initialized iid from $\mathcal{N}(0,1)$. As before we assume the last layer is not trained, and $d=1$.

Again, we want to establish Eq. 3 for this model. As in the 1-hidden-layer example, the dynamics of the output in the limit is still given by Eq. 8. This time around, the second hidden layer adds nontrivial complexity when evaluating the expectation $\mathbb{E}\,Z^{v}Z^{\delta\tilde{x}_{t+1}}$. As we shall see, this complexity arises from the dependency of $\tilde{x}$ on both the $n\times n$ matrix $w$ and its transpose $w^{\top}$, which makes naive applications of the Law of Large Numbers (LLN) incorrect. Resolving this complexity will pave the way to a general strategy that we can then apply to any architecture. We now apply the same insights as in the 1-hidden-layer MLP. Namely:

Eq. 4 continues to hold with $h$ replaced by any of $\{x,h,z,g\}$.

After some brief calculations, with $dh_{t}$ denoting the scaled gradient $\sqrt{n}\nabla_{h_{t}}f$,

$$\delta\tilde{g}_{t+1}\approx-\chi_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi^{\prime}(g_{t})\odot\big(W^{\top}dh_{t}\big) \tag{11}$$
$$\delta\tilde{h}_{t+1}\approx W\delta\tilde{z}_{t+1}-\chi_{t}\frac{z_{t}^{\top}\tilde{z}_{t}}{n}\phi^{\prime}(h_{t})\odot v. \tag{12}$$

As in the 1-hidden-layer case, for all $x\in\{g,z,h,x\}$, $x_{t}(\xi),\delta x_{t}(\xi)$ have iid coordinates of size $\Theta(1)$, as does $v$ by definition.[Footnote 11: Again, this is technically true only after conditioning on $f_{0}$; see Footnote 10.] Let $Z^{\delta\tilde{x}_{t}},Z^{v}$ denote the (generally correlated) random variables encoding the corresponding coordinate distributions.

As in Eq. 6, by naive Taylor expansion we have:

$$\delta\tilde{z}_{t+1}\approx\phi^{\prime}(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}. \tag{13}$$

Eqs. 6 and 7 from the 1-hidden-layer case continue to hold here. Then by Eq. 12 and the Law of Large Numbers,

$$\lim_{n\to\infty}\tilde{f}_{t+1}-\tilde{f}_{t}=\mathbb{E}\,Z^{v}Z^{\delta\tilde{x}_{t+1}} \tag{14}$$
$$=-\mathring{\chi}_{t}\,\mathbb{E}[Z^{z_{t}}Z^{\tilde{z}_{t}}]\,\mathbb{E}[\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}})] \tag{15}$$
$$\phantom{=}+\mathbb{E}[\phi^{\prime}(Z^{\tilde{h}_{t}})Z^{W\delta\tilde{z}_{t+1}}Z^{v}], \tag{16}$$

where $\mathring{\chi}_{t}=\mathcal{L}_{t}^{\prime}(\lim_{n}f_{t})$ as in Eq. 3.

In this expression, the first term (Eq. 15) can easily be seen to correspond to the contribution from $w$ to the NTK. It remains to show that the second (Eq. 16) corresponds to the contribution from $u$.

To do this, we must reason about the coordinate distribution of $W\delta\tilde{z}_{t+1}$ (encoded by the random variable $Z^{W\delta\tilde{z}_{t+1}}$) and compute the expectation in Eq. 16. To understand why this is harder than it might first appear, note that from $\delta\tilde{z}_{t+1}\approx\phi^{\prime}(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}$ (Eq. 13), the term $W\delta\tilde{z}_{t+1}$ hides within itself a dependency on $W^{\top}dh_{t}$ through $\delta\tilde{g}_{t+1}$ (Eq. 11). While at $t=0$ we may assume $W^{\top}$ to be independent of $W$ and still obtain the correct results (the Gradient Independence Assumption (Yang & Schoenholz, 2017; Yang, 2020a)), this is no longer the case for $t>0$: $Z^{W\delta\tilde{z}_{t+1}}$ will be nontrivially correlated with $\phi^{\prime}(Z^{\tilde{h}_{t}})$ and $Z^{v}$ (which would be false if $W^{\top}$ could be assumed independent of $W$). We will give some intuition for why in Eq. 20. Now, what exactly is this dependence?

Claim. Based on the above discussion and some easy calculations, $\delta\tilde{z}_{t+1}$ can be written as $\Phi(W^{\top}dh_{t})$ for some $\Phi:\mathbb{R}\to\mathbb{R}$ applied coordinatewise (which will depend on other vectors not of the form $W^{\top}\bullet$). Then it turns out that

$$Z^{W\delta\tilde{z}_{t+1}}=G+Z^{dh_{t}}\,\mathbb{E}\frac{\partial Z^{\delta\tilde{z}_{t+1}}}{\partial Z^{W^{\top}dh_{t}}}, \tag{17}$$

where $G$ is some Gaussian variable independent from $Z^{v}$, and $\frac{\partial Z^{\delta\tilde{z}_{t+1}}}{\partial Z^{W^{\top}dh_{t}}}\overset{\text{def}}{=}\Phi^{\prime}(Z^{W^{\top}dh_{t}})$.

Thus, from Eqs. 13 and 11, it follows that:

$$Z^{\delta\tilde{z}_{t+1}}=-\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\phi^{\prime}(Z^{\tilde{g}_{t}})\phi^{\prime}(Z^{g_{t}})Z^{W^{\top}dh_{t}} \tag{18}$$
$$\mathbb{E}\frac{\partial Z^{\delta\tilde{z}_{t+1}}}{\partial Z^{W^{\top}dh_{t}}}=-\mathring{\chi}_{t}\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}[\phi^{\prime}(Z^{\tilde{g}_{t}})\phi^{\prime}(Z^{g_{t}})]. \tag{19}$$

Plugging Eqs. 17 and 19 into Eq. 16 and combining with Eqs. 14 and 15, followed by some straightforward calculation, then yields Eq. 3.
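The bookkeeping of Eqs. 12 and 17 to 19 can be turned into a small Monte Carlo computation of the limiting kernel. The sketch below is our own illustration and assumes $d=1$, $\phi=\mathrm{ReLU}$ and inputs $\xi_{t},\tilde{\xi}>0$. It estimates the $w$ contribution $\mathbb{E}[Z^{z_{t}}Z^{\tilde{z}_{t}}]\,\mathbb{E}[\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}})]$ coming from the second term of Eq. 12, and the $u$ contribution $\xi_{t}\tilde{\xi}\,\mathbb{E}[\phi^{\prime}(Z^{g_{t}})\phi^{\prime}(Z^{\tilde{g}_{t}})]\,\mathbb{E}[\phi^{\prime}(Z^{h_{t}})\phi^{\prime}(Z^{\tilde{h}_{t}})]$ obtained from Eq. 16 via Eqs. 17 and 19 (the Gaussian $G$ drops out because it and $Z^{\tilde{h}_{t}}$ are independent of the mean-zero $Z^{v}$). Here $(Z^{h_{t}},Z^{\tilde{h}_{t}})$ is jointly Gaussian with covariance given by the second moments of $(Z^{z_{t}},Z^{\tilde{z}_{t}})$. For these inputs each contribution equals $\xi_{t}\tilde{\xi}/4$, so the sum should be close to $\xi_{t}\tilde{\xi}/2$.

```python
import math
import random


def relu(a):
    return a if a > 0.0 else 0.0


def relu_prime(a):
    return 1.0 if a > 0.0 else 0.0


def mean(xs):
    return sum(xs) / len(xs)


def limiting_ntk_2layer(xi, xi_test, samples=100_000, seed=0):
    """Monte Carlo estimate of the limiting NTK of the 2-hidden-layer example
    (u and w trained, v frozen, d = 1, phi = ReLU, xi > 0 assumed).

    Returns the sum of the w contribution (second term of Eq. 12) and the
    u contribution (Eqs. 16, 17 and 19), i.e. the factor multiplying -chi_t.
    """
    rng = random.Random(seed)
    zu = [rng.gauss(0.0, 1.0) for _ in range(samples)]    # samples of Z^u
    g = [s * xi for s in zu]                               # Z^{g_t}
    g_test = [s * xi_test for s in zu]                     # Z^{g~}
    z = [relu(a) for a in g]                               # Z^{z_t}
    z_test = [relu(a) for a in g_test]                     # Z^{z~}

    var_h = mean([a * a for a in z])                       # Var Z^{h_t}
    var_h_test = mean([b * b for b in z_test])             # Var Z^{h~}
    cov_h = mean([a * b for a, b in zip(z, z_test)])       # E Z^{z_t} Z^{z~}

    # Sample (Z^{h_t}, Z^{h~}) jointly Gaussian through a 2x2 Cholesky factor.
    l11 = math.sqrt(var_h)
    l21 = cov_h / l11
    l22 = math.sqrt(max(var_h_test - l21 * l21, 0.0))
    hits = 0.0
    for _ in range(samples):
        n1, n2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        hits += relu_prime(l11 * n1) * relu_prime(l21 * n1 + l22 * n2)
    e_dphi_h = hits / samples                 # E phi'(Z^{h_t}) phi'(Z^{h~})
    e_dphi_g = mean([relu_prime(a) * relu_prime(b) for a, b in zip(g, g_test)])

    w_term = cov_h * e_dphi_h
    u_term = xi * xi_test * e_dphi_g * e_dphi_h
    return w_term + u_term


K_hat = limiting_ntk_2layer(1.0, 2.0)
print(K_hat)  # exact value for these inputs: xi * xi~ / 2 = 1.0
```

Replacing the Monte Carlo averages with closed-form Gaussian integrals would give the exact kernel; sampling is used here only to keep the correspondence with the $Z$ random variables visible.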

Eq. 17 may appear cryptic at first, so let’s give some intuition using an example. Suppose $\Phi$ in the claim above is actually the identity. For brevity, we set $\mathbf{x}=dh_{t}$, $\mathbf{y}=\delta\tilde{z}_{t+1}\in\mathbb{R}^{n}$. Then, following a straightforward calculation, $W\mathbf{y}=WW^{\top}\mathbf{x}$ has coordinates

$$(W\mathbf{y})_{\alpha}=\sum_{\gamma\neq\alpha}\mathbf{x}_{\gamma}\sum_{\beta=1}^{n}W_{\alpha\beta}W_{\gamma\beta}+\sum_{\beta=1}^{n}(W_{\alpha\beta})^{2}\mathbf{x}_{\alpha}. \tag{20}$$

Now, the second sum converges via LLN to $\mathbf{x}_{\alpha}$ as $n\to\infty$. On the other hand, the first sum converges via CLT to $\mathcal{N}(0,\lim\|\mathbf{x}\|^{2}/n)$. Thus, in terms of $Z$s, we have

$$Z^{W\mathbf{y}}=G+Z^{\mathbf{x}}=G+Z^{\mathbf{x}}\,\mathbb{E}\,\Phi^{\prime} \tag{21}$$

for some Gaussian $G$; this corresponds directly to Eq. 17.[Footnote 12: This example was worked out in (Yang, 2020a, b) as well, though in different contexts. Readers needing more explanation may see those works.] For general $\Phi$, a similar intuition applies after Taylor expansion of $\Phi$.
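A small pure-Python experiment illustrates Eq. 21 in the identity case $\Phi(s)=s$: with $\mathbf{y}=W^{\top}\mathbf{x}$, the coordinates of $W\mathbf{y}$ carry a component exactly along $\mathbf{x}$ (coefficient $\mathbb{E}\,\Phi^{\prime}=1$), whereas replacing $W^{\top}$ by an independent Gaussian matrix, as a naive independence assumption would, removes that component. The width $n=500$, the Gaussian choice of $\mathbf{x}$, and the least-squares coefficient used as the diagnostic are illustrative assumptions.

```python
import random


def gaussian_matrix(n, rng):
    """n x n matrix with iid N(0, 1/n) entries, i.e. W = w / sqrt(n) with w ~ N(0, 1)."""
    sd = n ** -0.5
    return [[rng.gauss(0.0, sd) for _ in range(n)] for _ in range(n)]


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def transpose_matvec(M, x):
    """Compute M^T x without building M^T."""
    out = [0.0] * len(M[0])
    for row, xi in zip(M, x):
        for j, m in enumerate(row):
            out[j] += m * xi
    return out


def coefficient_on_x(n=500, seed=0):
    """Least-squares coefficient of W y on x for y = W^T x (Phi = identity), and
    the same coefficient when W^T is replaced by an independent copy."""
    rng = random.Random(seed)
    W = gaussian_matrix(n, rng)
    W_indep = gaussian_matrix(n, rng)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]    # plays the role of dh_t
    xx = sum(a * a for a in x)
    Wy = matvec(W, transpose_matvec(W, x))          # W W^T x
    Wy_indep = matvec(W, transpose_matvec(W_indep, x))
    coef = sum(a * b for a, b in zip(x, Wy)) / xx
    coef_indep = sum(a * b for a, b in zip(x, Wy_indep)) / xx
    return coef, coef_indep


coef, coef_indep = coefficient_on_x()
print(coef, coef_indep)  # expected: close to E[Phi'] = 1, and close to 0
```

The residual $W\mathbf{y}-\mathbf{x}$ plays the role of the Gaussian part $G$; with an independent copy of $W^{\top}$ only a Gaussian part remains, which is the correlation that is lost when $W^{\top}$ is wrongly treated as independent of $W$ after the first SGD step.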

This 2-hidden-layer example proceeded much the same as the 1-hidden-layer case, with the main exception of analyzing the interaction between the $n\times n$ Gaussian matrix $W$ and its transpose $W^{\top}$ (Eq. 16) that occurs after taking at least 1 step of SGD. This interaction was absent in the 1-hidden-layer case because each weight matrix there has at most one side tending to infinity. Such analysis is crucial to obtaining the right results, as assuming $W^{\top}$ to be independent from $W$ would imply $f$ does not move from initialization.[Footnote 13: One can see this easily by modifying our calculations above.]

It turns out these two examples have essentially covered all of the core ideas needed to extend the analysis into arbitrary architectures. To formalize and scale up our calculations, we now turn to the Tensor Programs framework.

Tensor Programs

So far, our results have been obtained by unrolling the SGD updates on toy models with specific architectures, and using informal arguments. Obviously, these computations quickly become unmanageable when the architecture becomes more complex. The sheer number of architectural innovations that have sprung up in recent years requires us to adopt a much more general formulation of our results. To that end, we adopt the Tensor Programs (TP) framework developed in (Yang, 2019a, 2020a, 2020b). In a nutshell, it provides a language for describing typical computations done in the context of neural networks, such as forward and backward propagation. It is simultaneously simple and expressive, covering all standard architectures (Yang, 2019a, 2020a). Here we review two basic forms of Tensor Programs, Netsor$\top$ and Netsor$\top^{+}$.

Definition 4.1. A Netsor$\top$ program is just a sequence of $\mathbb{R}^{n}$ vectors inductively generated via one of the following instructions from an initial set $\mathcal{V}$ of random $\mathbb{R}^{n}$ vectors and a set $\mathcal{W}$ of random $n\times n$ matrices:

Nonlin: For $x^{1},\ldots,x^{k}\in\mathbb{R}^{n}$ in the program and any $\psi:\mathbb{R}^{k}\to\mathbb{R}$, we can generate $\psi(x^{1},\ldots,x^{k})\in\mathbb{R}^{n}$ (with $\psi$ applied coordinatewise).

MatMul: Given $W\in\mathbb{R}^{n\times n}$ and $x\in\mathbb{R}^{n}$, we can generate $Wx\in\mathbb{R}^{n}$ or $W^{\top}x\in\mathbb{R}^{n}$.

We propose to represent a Netsor$\top$ program as a computational graph, where each node represents a vector (initial or generated), each (dashed) edge represents a MatMul, and each gate represents a Nonlin. For example, Fig. 1 shows the computation graphs expressing (the forward passes of) an MLP and an RNN. We can express backpropagation as well (see Fig. 6). Graphically, the initial vectors are the empty nodes with only one edge coming out, toward the direction of computation. The matrices correspond to (the labels of) the dashed edges. We can also define the output vectors to correspond to the nodes that have only one edge coming out, against the direction of computation.

Each Netsor$\top$ program can be thought of as computing a function $(\mathbb{R}^{n})^{\mathcal{V}}\times(\mathbb{R}^{n\times n})^{\mathcal{W}}\to(\mathbb{R}^{n})^{\mathcal{Y}}$ that takes an instantiation of the initial vectors $\mathcal{V}$ and matrices $\mathcal{W}$ and computes the values of all output vectors $\mathcal{Y}$. We say a program represents a network if it computes the body of it (without the input and output layers), as exemplified by Fig. 1. This is formalized below.

Definition 4.2. Consider a neural network $f:(\mathbb{R}^{d})^{k}\to\mathbb{R}$ with input embedding matrices $U^{1},\ldots,U^{k}\in\mathbb{R}^{n\times d}$ (not necessarily distinct) and readout matrix $V\in\mathbb{R}^{n}$, so that $f(\xi^{1},\ldots,\xi^{k})=V^{\top}\Phi(U^{1}\xi^{1},\ldots,U^{k}\xi^{k};\Theta)$ for some function $\Phi(x^{1},\ldots,x^{k};\Theta)$ with parameters $\Theta$. We say a Netsor$\top$ program represents $f$ if it computes $\Phi$ (under some correspondence of $\mathcal{V}\cup\mathcal{W}$ to $\{x^{1},\ldots,x^{k}\}\cup\Theta$).[Footnote 14: We only consider $f$ with scalar output for simplicity, but generalization to multi-dimensional output is straightforward.]

For example, the programs in Fig. 1 represent, respectively, a 3-hidden-layer MLP and an RNN running for 3 steps. Note that the initial vectors correspond to a combination of input embeddings (e.g. $W^{1}\xi$) and vector parameters (e.g. biases), and the matrices correspond to matrix parameters (e.g. weights).
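To make the definitions concrete, here is a hypothetical mini-interpreter for Netsor$\top$ programs in pure Python, together with a program for the body of a 3-hidden-layer MLP in the spirit of Fig. 1. The instruction encoding, the vector and matrix names (h1, W2, …), $\phi=\tanh$ and the width are our own illustrative choices, not an interface from the Tensor Programs papers. The input embedding $W^{1}\xi$ enters as an initial vector, and the readout $V^{\top}$ is applied outside the program, as in the definition above.

```python
import math
import random


def run_program(program, init_vectors, matrices):
    """Evaluate a (hypothetical) encoding of a Netsor-top program.

    Each instruction is (name, "nonlin", psi, [input names]) or
    (name, "matmul", matrix name, input name, transpose flag).
    Returns a dict mapping every vector name to its value in R^n.
    """
    env = dict(init_vectors)
    for name, op, *args in program:
        if op == "nonlin":
            psi, inputs = args
            # psi is applied coordinatewise to the k input vectors.
            env[name] = [psi(*coords) for coords in zip(*(env[i] for i in inputs))]
        elif op == "matmul":
            mat_name, src, transpose = args
            M, x = matrices[mat_name], env[src]
            n = len(x)
            if transpose:   # (M^T x)_b = sum_a M[a][b] x[a]
                env[name] = [sum(M[a][b] * x[a] for a in range(n)) for b in range(n)]
            else:           # (M x)_a = sum_b M[a][b] x[b]
                env[name] = [sum(M[a][b] * x[b] for b in range(n)) for a in range(n)]
        else:
            raise ValueError(f"unknown instruction {op!r}")
    return env


n = 200
rng = random.Random(0)
# Initial vector: the input embedding W^1 xi for d = 1, xi = 1 (iid N(0, 1) coordinates).
init = {"h1": [rng.gauss(0.0, 1.0) for _ in range(n)]}
# Matrix parameters with iid N(0, 1/n) entries.
mats = {name: [[rng.gauss(0.0, n ** -0.5) for _ in range(n)] for _ in range(n)]
        for name in ("W2", "W3")}
mlp_body = [
    ("x1", "nonlin", math.tanh, ["h1"]),
    ("h2", "matmul", "W2", "x1", False),
    ("x2", "nonlin", math.tanh, ["h2"]),
    ("h3", "matmul", "W3", "x2", False),
    ("x3", "nonlin", math.tanh, ["h3"]),
]
env = run_program(mlp_body, init, mats)

# The readout V = v / sqrt(n) sits outside the program.
v = [rng.gauss(0.0, 1.0) for _ in range(n)]
f = sum(vi * xi for vi, xi in zip(v, env["x3"])) / math.sqrt(n)
print(f)
```

Backpropagation fits the same encoding: gradient vectors are produced by MatMul instructions with the transpose flag set, interleaved with Nonlin instructions that multiply by $\phi^{\prime}$.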

Typically, the vectors (resp. matrices) in a program will be sampled iid like $\mathcal{N}(0,1)$ (resp. $\mathcal{N}(0,1/n)$), corresponding to the “standard” initialization of neural networks.[Footnote 15: In the original definition of (Yang, 2019a, 2020a, 2020b), the vectors can have correlations between them, but we can always rewrite these vectors as linear images of another set of uncorrelated vectors.] In such cases, when $n\to\infty$, a program behaves as follows, in a gist:

IID coordinates: Any vector $x\in\mathbb{R}^{n}$ in the program has roughly iid coordinates. We write $Z^{x}$ for the random variable encoding this coordinate distribution. This $Z^{x}$ may be correlated with $Z^{y}$ for other vectors $y$ in the program, such that, for example, $\lim_{n\to\infty}x^{\top}y/n=\mathbb{E}\,Z^{x}Z^{y}$.

Nonlin: $Z^{\psi(x^{1},\ldots,x^{k})}=\psi(Z^{x^{1}},\ldots,Z^{x^{k}})$.

MatMul without $(W,W^{\top})$ interaction: Consider a matrix $W\in\mathbb{R}^{n\times n}$ in the program and any set $\mathbf{X}$ of $\mathbb{R}^{n}$ vectors not dependent on vectors of the form $W^{\top}\bullet$. Then the random variables $\{Z^{Wx}:x\in\mathbf{X}\}$ are jointly Gaussian with mean zero and covariance $\mathrm{Cov}(Z^{Wx},Z^{W\tilde{x}})=\mathbb{E}\,Z^{x}Z^{\tilde{x}}$ for any $x,\tilde{x}\in\mathbf{X}$. If $\bar{W}\neq W$ is another matrix in the program and $\mathbf{Y}$ is a set of such vectors w.r.t. $\bar{W}$, then the set $\{Z^{Wx}:x\in\mathbf{X}\}$ is independent from $\{Z^{\bar{W}y}:y\in\mathbf{Y}\}$.

MatMul with $(W,W^{\top})$ interaction: For general $x$, $Z^{Wx}$ decomposes into the sum of a Gaussian part, identical to $Z^{Wx}$ in the above case, and a correction term. This decomposition is a generalization of Eq. 17.
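The Gaussian MatMul rule can be checked numerically at moderate width. In the sketch below (the width, the seed, and the choice $\tilde{x}=\tanh(x)$ are illustrative assumptions), $W$ and $\bar{W}$ are independent matrices with iid $\mathcal{N}(0,1/n)$ entries and $x,\tilde{x}$ are built without any $W^{\top}$ product. The empirical covariance of the coordinates of $Wx$ and $W\tilde{x}$ should be close to $x^{\top}\tilde{x}/n$, while $Wx$ and $\bar{W}x$ should be nearly uncorrelated, which for jointly Gaussian limits is the independence stated above.

```python
import math
import random


def gaussian_matrix(n, rng):
    """n x n matrix with iid N(0, 1/n) entries."""
    sd = 1.0 / math.sqrt(n)
    return [[rng.gauss(0.0, sd) for _ in range(n)] for _ in range(n)]


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def inner_over_n(a, b):
    """Empirical second moment a^T b / n, the finite-n proxy for E[Z^a Z^b]."""
    return sum(p * q for p, q in zip(a, b)) / len(a)


n = 600
rng = random.Random(1)
W, W_bar = gaussian_matrix(n, rng), gaussian_matrix(n, rng)
# x and x_tilde = tanh(x) are built without any W^T product.
x = [rng.gauss(0.0, 1.0) for _ in range(n)]
x_tilde = [math.tanh(c) for c in x]
Wx, Wx_tilde, Wbar_x = matvec(W, x), matvec(W, x_tilde), matvec(W_bar, x)

print("x^T x~ / n:         ", inner_over_n(x, x_tilde))
print("(Wx)^T (Wx~) / n:   ", inner_over_n(Wx, Wx_tilde))  # should be close to the line above
print("(Wx)^T (Wbar x) / n:", inner_over_n(Wx, Wbar_x))    # should be close to 0
```

If $x$ were instead built from $W^{\top}\bullet$, the first comparison would pick up the correction term described in the last rule, as in the $WW^{\top}$ experiment for Eq. 21.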

Netsor$\top^{+}$ Programs. (Yang, 2019a, 2020a) showed that Netsor$\top$ suffices to express the forward and backward passes of most architectures such as ResNet (with Batchnorm), but Transformer and other standard architectures require adding to Netsor$\top$ a new “averaging” instruction[Footnote 16: For readers familiar with Tensor Programs, this is equivalent in expressivity to Moment, which is a composition of a Nonlin and this averaging instruction.] that returns the “empirical average” $\frac{1}{n}\sum_{\alpha=1}^{n}x_{\alpha}$ of a vector $x\in\mathbb{R}^{n}$. In the $n\to\infty$ limit, this scalar converges to $\mathbb{E}\,Z^{x}$, as would be expected from the intuitions above. This extension of Netsor$\top$ (called Netsor$\top^{+}$) also allows us to express the network output and loss derivative (in contrast to Fig. 1), which will be a technical requirement for unrolling the entirety of SGD training inside a single program, a key step in the proof of our main result. See the discussion of proof formalization in Section 5. We say a Netsor$\top^{+}$ program represents a network $f$ if it computes the body of $f$.

Universality of Kernel Dynamics

(Yang, 2019a, 2020a) showed that any neural network of standard architecture is represented by a Netsor$\top^{+}$ program. Moreover,

Theorem 5.1. For a neural network as in Setup 5.2 below, its Neural Tangent Kernel at initialization has a well-defined infinite-width limit $\mathring{\mathcal{K}}$.

Setup 5.2. Suppose a neural network $f$ is represented by a Netsor$\top^{+}$ program (in the sense of Definition 4.2) whose Nonlin all have polynomially bounded derivatives.[Footnote 17: More generally, we can allow any pseudo-Lipschitz function here, but for simplicity we go with the statement in the main text.] Adopt the NTK parametrization: for every matrix parameter $W\in\mathbb{R}^{n\times n}$ of $f$, we factor $W=\frac{1}{\sqrt{n}}w$ where $w$ is the trainable parameter; likewise, for each input layer matrix $U^{i}\in\mathbb{R}^{n\times d}$, we factor $U^{i}=\frac{1}{\sqrt{d}}u^{i}$, and likewise the output matrix $V=\frac{1}{\sqrt{n}}v$, such that $u^{i},v$ are trainable. Finally, we randomly initialize all trainable parameters iid as $\mathcal{N}(0,1)$.

Our main result is to show that the SGD training of such a neural network described in Setup 5.2 reduces to kernel gradient descent with kernel $\mathring{\mathcal{K}}$ in the infinite-width limit.

Theorem 5.3. Consider training a network $f$ described in Setup 5.2 via SGD with batch size 1 and (WLOG) learning rate 1. Let $\xi_{t}$ be the input and $\mathcal{L}_{t}:\mathbb{R}\to\mathbb{R}$ be the loss function (absorbing the label) at time $t$. Suppose $\mathcal{L}_{t}$ is continuous for all $t$. Then, for any $\xi$ and $t$, $f_{t}(\xi)$ converges almost surely to a random variable $\mathring{f}_{t}(\xi)$ as width $\to\infty$, such that

$$\mathring{f}_{t+1}(\xi)-\mathring{f}_{t}(\xi)=-\mathring{\mathcal{K}}(\xi,\xi_{t})\,\mathcal{L}_{t}'\big(\mathring{f}_{t}(\xi_{t})\big), \tag{22}$$
where $\mathring{\mathcal{K}}$ is the infinite-width NTK (at initialization) of the neural network.
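Eq. 22 describes training purely in function space. The sketch below (Python; `kernel_gradient_descent` is a hypothetical helper, and the label is passed separately instead of being absorbed into $\mathcal{L}_t$) iterates this update for an arbitrary kernel. It stores the past inputs $\xi_s$ and loss derivatives $\mathcal{L}_s'(f_s(\xi_s))$, so that $f_t(\xi)=f_0(\xi)-\sum_{s<t}\mathcal{K}(\xi,\xi_s)\,\mathcal{L}_s'(f_s(\xi_s))$ can be evaluated at any $\xi$.

```python
from typing import Callable, List, Tuple


def kernel_gradient_descent(
    kernel: Callable[[float, float], float],
    f0: Callable[[float], float],
    stream: List[Tuple[float, float]],
    loss_grad: Callable[[float, float], float],
) -> Callable[[float], float]:
    """Apply f_{t+1}(xi) = f_t(xi) - K(xi, xi_t) * L_t'(f_t(xi_t)) along a stream of (xi_t, y_t).

    Learning rate 1 and batch size 1, as in Eq. 22.
    """
    history: List[Tuple[float, float]] = []  # pairs (xi_s, chi_s) with chi_s = L_s'(f_s(xi_s))

    def f(xi: float) -> float:
        return f0(xi) - sum(chi * kernel(xi, xi_s) for xi_s, chi in history)

    for xi_t, y_t in stream:
        chi_t = loss_grad(f(xi_t), y_t)  # evaluated with the current f_t
        history.append((xi_t, chi_t))
    return f


# Usage: linear kernel K(a, b) = a * b, squared loss 0.5 * (f - y)^2, zero initial function.
f_final = kernel_gradient_descent(
    kernel=lambda a, b: a * b,
    f0=lambda xi: 0.0,
    stream=[(1.0, 1.0), (1.0, 1.0)],
    loss_grad=lambda pred, y: pred - y,
)
print(f_final(2.0))  # prints 2.0
```

The first step fits $\xi=1$ exactly, so the second step has zero loss derivative and leaves $f$ unchanged.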

The full proof of 5.3 is given in Appendix D.

We briefly mention several ways our result can be easily extended. 0) Different batch sizes, learning rate schedules, and nonscalar outputs. 1) Variants of NTK parametrization. We can deal with any parametrization that scales the same way as NTK parametrization, e.g. weights sampled like $\mathcal{N}(0,\sigma^{2})$ for any $\sigma$, with multipliers $\gamma/\sqrt{\mathrm{fanin}}$ for any $\gamma$. 2) Variable width. In real networks, different layers often have different widths (e.g. in ResNet). Our result can be extended to the case where the widths tend to infinity at fixed ratios, using the variable-width version of Tensor Programs (Yang, 2020b). 3) Unsupervised and other learning settings can be covered, because their training and testing computations can be written into Tensor Programs. 4) Weight decay, momentum, and other optimizer tricks can be covered as well, since they can be straightforwardly written into Tensor Programs, but in general the kernel will then change from step to step, in contrast to 5.3.

1 Proof Sketch of Special Case

To convey the main idea, we give a proof sketch of a simplified problem. We assume 1) the input layer, output layer, and biases are not trained (the network has only $\mathbb{R}^{n\times n}$ matrices as trainable parameters); 2) the forward pass does not contain both a weight matrix $W$ and its transpose $W^{\top}$ (but a single matrix $W$ can still be used multiple times without being transposed) (Footnote 18: This is not common in deep learning, but appears in some weight-tied autoencoders (Li & Nguyen, 2019).); 3) the input space is $\mathbb{R}^{d}$ (with $k=1$), and $f=V^{\top}x\in\mathbb{R}$; 4) the output vector $x$ is a G-var; 5) the network is represented by a $\textsc{Netsor}\top$ (instead of $\textsc{Netsor}\top^{+}$) program. In the appendix, we prove the general case with these simplifications lifted.

It turns out that every $\textsc{Netsor}\top$ program can be simplified into a standard form of sorts, which greatly facilitates our proof.

In a $\textsc{Netsor}\top$ program, a G-var (Footnote 19: "G" because G-vars often are roughly Gaussian vectors.) is an initial vector or a vector created by MatMul, while an X-var is a vector created by Nonlin. (Footnote 20: Var is short for variable, as the vectors are considered variables in the program. In previous works, H-var refers to any vector in the program; we will not use this terminology.) We define a reduced $\textsc{Netsor}\top$ program as a program in which only G-vars are allowed as inputs to a Nonlin, while only an X-var is allowed as input to a MatMul.

Observe that any $\textsc{Netsor}\top$ program may be trivially expressed as a reduced $\textsc{Netsor}\top$ program by 1) collapsing chains of nonlinearities that appear consecutively, and 2) inserting a Nonlin operation with $\psi(x)=x$ between consecutive G-vars. Hence, we may safely assume that $f$ is representable by a reduced $\textsc{Netsor}\top$ program.
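The two rewriting rules can be sketched as a small Python pass over a program. The tuple encoding of instructions, the `_id` naming scheme for inserted identity Nonlins, and the helpers `reduce_program` and `is_reduced` are hypothetical; each Nonlin is represented as a Python function acting on one coordinate slice.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical encoding of a Netsor⊤ program (not the paper's notation):
#   ("input", name)              initial vector, a G-var
#   ("matmul", name, W, arg)     G-var: name = W arg
#   ("nonlin", name, fn, args)   X-var: name = fn(*args), applied coordinatewise


def reduce_program(prog: List[tuple]) -> List[tuple]:
    """Rewrite so that Nonlins only read G-vars and MatMuls only read X-vars."""
    is_g: Dict[str, bool] = {}
    defs: Dict[str, Tuple[Callable, Tuple[str, ...]]] = {}  # X-var -> (fn, its G-var args)
    out: List[tuple] = []

    def as_x(arg: str) -> str:
        """Return an X-var equal to `arg`, inserting an identity Nonlin if `arg` is a G-var."""
        if not is_g[arg]:
            return arg
        name = arg + "_id"
        if name not in defs:
            defs[name] = (lambda a: a, (arg,))
            is_g[name] = False
            out.append(("nonlin", name, defs[name][0], (arg,)))
        return name

    for ins in prog:
        op, name = ins[0], ins[1]
        if op == "input":
            is_g[name] = True
            out.append(ins)
        elif op == "matmul":
            is_g[name] = True
            out.append(("matmul", name, ins[2], as_x(ins[3])))
        else:
            fn, args = ins[2], ins[3]
            g_args: List[str] = []
            getters: List[Callable] = []

            def slot(var: str) -> int:
                """Position of G-var `var` among the collapsed Nonlin's arguments."""
                if var not in g_args:
                    g_args.append(var)
                return g_args.index(var)

            for a in args:
                if is_g[a]:
                    getters.append(lambda vals, i=slot(a): vals[i])
                else:  # X-var argument: substitute its definition in terms of G-vars
                    inner_fn, inner_args = defs[a]
                    idx = tuple(slot(b) for b in inner_args)
                    getters.append(lambda vals, g=inner_fn, idx=idx: g(*[vals[j] for j in idx]))
            new_fn = lambda *vals, fn=fn, getters=tuple(getters): fn(*[get(vals) for get in getters])
            defs[name] = (new_fn, tuple(g_args))
            is_g[name] = False
            out.append(("nonlin", name, new_fn, tuple(g_args)))
    return out


def is_reduced(prog: List[tuple]) -> bool:
    """Check that Nonlin inputs are all G-vars and every MatMul input is an X-var."""
    kind: Dict[str, str] = {}
    for ins in prog:
        op, name = ins[0], ins[1]
        if op == "nonlin" and any(kind[a] != "G" for a in ins[3]):
            return False
        if op == "matmul" and kind[ins[3]] != "X":
            return False
        kind[name] = "X" if op == "nonlin" else "G"
    return True


prog = [
    ("input", "h0"),
    ("matmul", "h1", "W", "h0"),                 # MatMul of a G-var
    ("nonlin", "x1", math.tanh, ("h1",)),
    ("nonlin", "x2", lambda a: a * a, ("x1",)),  # Nonlin of an X-var
    ("matmul", "h2", "W", "x2"),
]
reduced = reduce_program(prog)
x2 = next(ins for ins in reduced if ins[1] == "x2")
print(is_reduced(prog), is_reduced(reduced), x2[3], x2[2](0.5))
```

After the pass, `x2` is a single Nonlin $\tanh(h_1)^2$ of the G-var `h1`, and `h1 = W h0` reads from the inserted identity X-var `h0_id`.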

The examples of Sections 3.1 and 3.2 exposed several insights important for proving 5.3, such as the iid-coordinates intuition. We now discuss the one remaining key idea for scaling up to general architectures:

Definition (Paths). In a $\textsc{Netsor}\top$ program, a path $p$ starts with an X-var and ends with a G-var, alternating between X- and G-vars along the path. We write $p^{0}$ for the starting X-var, $p^{1}$ for the following G-var, and so on, as well as $p^{-1}$ for the ending G-var (see Fig. 2 for a graphical illustration). For odd $i$, let $W^{p^{i}}$ denote the defining matrix of the G-var $p^{i}$. For two equal-length paths $p,q$, we write $p\cong q$ (path $p$ is isomorphic to path $q$) if for all odd $i$, $W^{p^{i}}$ is the same matrix as $W^{q^{i}}$. (Footnote 21: Here we are talking about equality of symbols rather than equality of values of those symbols.) In other words, path $p$ is isomorphic to path $q$ if their sequences of MatMul matrices are identical (but the Nonlins don't have to be; see Fig. 3 for a graphical illustration). Let $|p|$ denote the number of vectors in $p$ (this is always an even number).
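The definition can be made concrete with a short Python sketch that enumerates all paths between an X-var and a G-var of a reduced program and tests path isomorphism by comparing the MatMul matrices at odd positions. The dictionary encoding of the program, the example program, and the helper names are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical encoding of a reduced Netsor⊤ program as two dictionaries:
#   matmul[h] = (W, x)      the G-var h is W x
#   nonlin[x] = (g1, ...)   the X-var x is a Nonlin of the G-vars g1, ...
Program = Tuple[Dict[str, Tuple[str, str]], Dict[str, Tuple[str, ...]]]


def enumerate_paths(prog: Program, start: str, end: str) -> List[List[str]]:
    """All paths p with p^0 = start (an X-var) and p^{-1} = end (a G-var)."""
    matmul, nonlin = prog
    succ: Dict[str, List[str]] = {}
    for h, (_, x) in matmul.items():      # X-var x -> G-var h = W x
        succ.setdefault(x, []).append(h)
    for x, args in nonlin.items():        # G-var g -> X-var x = psi(..., g, ...)
        for g in args:
            succ.setdefault(g, []).append(x)

    found: List[List[str]] = []

    def walk(path: List[str]) -> None:
        if path[-1] == end:
            found.append(path)
            return
        for nxt in succ.get(path[-1], []):
            walk(path + [nxt])

    walk([start])
    return found


def isomorphic(prog: Program, p: List[str], q: List[str]) -> bool:
    """p ≅ q iff |p| = |q| and W^{p^i} is the same matrix symbol as W^{q^i} for every odd i."""
    matmul, _ = prog
    return len(p) == len(q) and all(
        matmul[p[i]][0] == matmul[q[i]][0] for i in range(1, len(p), 2)
    )


# x0 = phi(g0), y0 = psi(g0); h1 = A x0, h2 = A y0, h3 = C x0; x1 = rho(h1, h2, h3); out = B x1
prog: Program = (
    {"h1": ("A", "x0"), "h2": ("A", "y0"), "h3": ("C", "x0"), "out": ("B", "x1")},
    {"x0": ("g0",), "y0": ("g0",), "x1": ("h1", "h2", "h3")},
)
px = enumerate_paths(prog, "x0", "out")  # [['x0', 'h1', 'x1', 'out'], ['x0', 'h3', 'x1', 'out']]
py = enumerate_paths(prog, "y0", "out")  # [['y0', 'h2', 'x1', 'out']]
print([isomorphic(prog, p, py[0]) for p in px])  # [True, False]
```

The paths through `h1` and `h2` are isomorphic because both use the matrix sequence $(A, B)$, even though they start at different X-vars; the path through `h3` uses $(C, B)$.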

The collection of paths $p$ starting with an X-var $p^{0}=x$ and ending with a G-var $h$ describes all possible pathways of backpropagating an error signal $dh$ at $h$ to an error signal $dx$ at $x$. Simultaneously, it also describes all possible pathways of forward-propagating a change in $x$ to a change in $h$.

Because the gradient $\nabla_{W}f$ of a weight $W\in\mathbb{R}^{n\times n}$ is the sum of outer products $\sum_{h,x:\,h=Wx}dh\otimes x$, summing over all G-vars $h$ and X-vars $x$ in the program with $h=Wx$ (where $dh$ denotes $\nabla_{h}f$), we also have

$$(J^{p})^{\top}=\frac{\partial f}{\partial p^{-1}}\times\frac{\partial p^{-1}}{\partial p^{-2}}\times\frac{\partial p^{-2}}{\partial p^{-3}}\times\cdots\times\frac{\partial p^{2}}{\partial h}, \tag{24}$$
i.e., $J^{p}$ denotes the error signal at $h$ from backpropagation through path $p$, and $p$ ranges over all paths starting with $p^{0}=x$, $p^{1}=h$ and ending with the output node of the underlying program. Recall that $W$ factors as $\frac{1}{\sqrt{n}}w$, where $w$, not $W$, is the trainable parameter. By the discussion above, updating $w$ with $\nabla_{w}f=\frac{1}{\sqrt{n}}\nabla_{W}f$ causes $f$ to change by

When every parameter $w$ is randomly initialized iid as $\mathcal{N}(0,1)$, it turns out that $\langle J^{p},J^{\bar{p}}\rangle$ goes to 0 as $n\to\infty$ unless $p\cong\bar{p}$ (Fig. 3). If one thinks of $J^{p}$ as a product of random Gaussian matrices (interleaved with other matrices), then this is akin to the following fact: for a mixed moment $M\overset{\text{def}}{=}\mathbb{E}\prod_{i=1}^{r}Z_{\gamma(i)}$, $\gamma:[r]\to[k]$, of standard iid Gaussians $Z_{1},\ldots,Z_{k}$, $M$ is nonzero iff every $Z_{i}$ appears an even number of times in the product. This means we can replace the 4 nested sums in Eq. 25 with the single sum $\sum_{p\cong\bar{p}}$ and rewrite $x=p^{0}$, $\bar{x}=\bar{p}^{0}$.
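The Gaussian moment fact invoked here can be checked directly. In the sketch below (hypothetical helper names, standard library only), independence factorizes $M$ into $\prod_j\mathbb{E}[Z^{m_j}]$, where $m_j$ counts how often $Z_j$ appears, and $\mathbb{E}[Z^m]$ is $0$ for odd $m$ and $(m-1)!!$ for even $m$; a Monte Carlo estimate is included for comparison.

```python
import random
from collections import Counter
from typing import Sequence


def odd_double_factorial(m: int) -> int:
    """m!! = m * (m - 2) * ... * 3 * 1 for odd m >= 1."""
    out = 1
    for j in range(m, 0, -2):
        out *= j
    return out


def mixed_moment(gamma: Sequence[int]) -> int:
    """Exact E prod_i Z_{gamma(i)} for iid standard Gaussians Z_1, ..., Z_k."""
    result = 1
    for m in Counter(gamma).values():
        if m % 2:  # some Z_j appears an odd number of times
            return 0
        result *= odd_double_factorial(m - 1)  # E[Z^m] = (m - 1)!! for even m
    return result


def monte_carlo_moment(gamma: Sequence[int], samples: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same mixed moment."""
    rng = random.Random(seed)
    k = max(gamma)
    total = 0.0
    for _ in range(samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(k + 1)]  # z[0] unused; indices are 1-based
        prod = 1.0
        for i in gamma:
            prod *= z[i]
        total += prod
    return total / samples


print(mixed_moment([1, 1, 1, 1]), mixed_moment([1, 2, 1, 2]),
      mixed_moment([1, 2]), mixed_moment([1, 1, 1]))  # prints 3 1 0 0
print(monte_carlo_moment([1, 2, 1, 2]))
```

For instance, $\mathbb{E}[Z_1Z_2]=0$ because each variable appears once, while $\mathbb{E}[Z_1^2Z_2^2]=1$.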

We have suppressed the dependence on the input in Eq. 25. Making it explicit and performing updates on all weights, we have

$$f_{t+1}(\xi)-f_{t}(\xi)\approx-\mathcal{L}_{t}'\big(f_{t}(\xi_{t})\big)\sum_{p\cong\bar{p}}\langle J^{p}(\xi),J^{\bar{p}}(\xi_{t})\rangle\langle p^{0}(\xi),\bar{p}^{0}(\xi_{t})\rangle \tag{26}$$
after taking a gradient step on input $\xi_{t}$ at initialization $t=0$. (Here $p^{0}(\xi)$ denotes the vector $p^{0}$ as a function of $\xi$ at initialization.) However, Eq. 25 holds for general $t$ as well. The key insight, similar to Eq. 4 in the 1-hidden-layer example, is that the vectors $p^{0}(\xi)$, $J^{p}(\xi)$, etc. change vanishingly from their initial values as $n\to\infty$, after any number of SGD steps. Because our arguments above only depend on inner products between vectors, this means that the error in Eq. 26 for $t>0$ vanishes as $n\to\infty$. Finally, at least heuristically, it is straightforward to show that the RHS of Eq. 26 is the NTK via a series of calculations exemplified by those in Sections 3.1 and 3.2, which yields Eq. 22.
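The claim that the vectors change vanishingly can be illustrated numerically in the simplest case. The sketch below trains only the input weights of a hypothetical 1-hidden-layer network $f(\xi)=v^{\top}\tanh(u\xi)/\sqrt{n}$ ($d=1$, learning rate 1, squared loss, a made-up stream of three examples) and reports the root-mean-square change of the activation $x=\tanh(0.7\,u)$ at a fixed probe input. Under the scaling above, this change should shrink like $1/\sqrt{n}$, so that $\sqrt{n}$ times it, which is the scale of $\delta x$, stays roughly constant.

```python
import math
import random


def rms_activation_change(n: int, steps: int = 6, seed: int = 0) -> float:
    """RMS change of x = tanh(u * probe) after `steps` SGD steps on a 1-hidden-layer net.

    Network (d = 1): f(xi) = v^T tanh(u xi) / sqrt(n), with u, v iid N(0, 1); only u is
    trained, with learning rate 1 and squared loss 0.5 * (f - y)^2.
    """
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    stream = [(1.0, 1.0), (-1.0, -1.0), (0.5, 2.0)]  # hypothetical (xi_t, y_t) pairs
    probe = 0.7                                      # a fixed test input
    x_init = [math.tanh(a * probe) for a in u]

    for t in range(steps):
        xi, y = stream[t % len(stream)]
        h = [a * xi for a in u]
        f = sum(b * math.tanh(c) for b, c in zip(v, h)) / math.sqrt(n)
        chi = f - y  # loss derivative L_t'(f_t)
        # df/du_a = v_a * tanh'(h_a) * xi / sqrt(n)
        u = [a - chi * b * (1.0 - math.tanh(c) ** 2) * xi / math.sqrt(n)
             for a, b, c in zip(u, v, h)]

    x_final = [math.tanh(a * probe) for a in u]
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(x_final, x_init)) / n)


for n in (100, 2_500, 40_000):
    change = rms_activation_change(n)
    print(n, change, math.sqrt(n) * change)
```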

While the core ideas discussed above are intuitive, making them rigorous at face value would be quite challenging. Instead, we use the machinery offered by the Tensor Programs framework. The mechanics of the proof then go as follows: 1) First, we unroll SGD of $f$ into a $\textsc{Netsor}\top^{+}$ program. (Footnote 22: We note that this formalization crucially relies on $\textsc{Netsor}\top^{+}$ and its Master Theorem from (Yang, 2020b), because the SGD unrolling cannot be done in $\textsc{Netsor}\top$. The reason is that we need to express the output and loss derivatives of the network, which are scalars (or at least finite-dimensional), and that cannot be done in a $\textsc{Netsor}\top$ program. Furthermore, the Master Theorem from (Yang, 2020a) only pertains to a specific type of program that looks like the first backpropagation after initialization. Thus, it cannot deal with the complete unrolling of SGD as we do here, which requires the more advanced Master Theorem from (Yang, 2020b).) This is similar to the equations in Sections 3.1 and 3.2; the key here is to express $\delta x_{t+1}=\sqrt{n}(x_{t+1}-x_{t})$ as a vector in the program, for any (pre-)activation $x$. 2) We apply the $\textsc{Netsor}\top^{+}$ Master Theorem (Yang, 2020b) to this program. This yields the coordinate distribution of each vector. The core insights here are demonstrated by the calculations with the $Z$ random variables in Sections 3.1 and 3.2. 3) Finally, we need to show that (the rigorous version of) Eq. 26 indeed recovers the NTK and agrees with Eq. 22. This is done via an inductive (symbolic) computation, in which the path concept of this section plays a key role.

Related Works

The connection between kernel methods and neural networks has had a long history before its recent resurgence. The Gaussian Process (NNGP) view of wide neural networks, which characterizes the behaviour of training only the last layer of a wide neural network, has been studied in (Daniely et al., 2016; Hazan & Jaakkola, 2015; Roux & Bengio, 2007; Lee et al., 2018; Matthews et al., 2018; Hinton & Neal, 1995; Novak et al., 2019). Since the original NTK paper (Jacot et al., 2018), many works have informally derived the infinite-width NTK for various architectures such as CNNs (Arora et al., 2019), RNNs (Alemohammad et al., 2020), attention (Hron et al., 2020), ensembles (Littwin et al., 2020b) and graph neural networks (Du et al., 2019), but none of them formally proved ntkInit or ntkTrain for those architectures. Finite-width corrections to the NTK were derived for fully connected networks in (Hanin & Nica, 2019; Littwin et al., 2020a). The validity of the NTK theory was empirically studied in (Lee et al., 2019) for a variety of architectures.

The Tensor Program framework (Yang, 2019a, 2020a, 2020b) was introduced in an attempt to unify and generalize the NNGP/NTK theory to a broad range of architectures, eliminating the need to re-develop the theory for each new architecture. For example, (Yang, 2019a) proved the architectural universality of the NNGP correspondence, while (Yang, 2020a) proved that of ntkInit. On the other hand, (Yang, 2020b) developed the most general machinery for Tensor Programs and, as a corollary, constructed a comprehensive nonlinear random matrix theory that can, for example, calculate the singular value distribution of a wide neural network of any architecture. Our proofs crucially depend on the machinery of (Yang, 2020b), as discussed in Section 5.

Conclusion

New theories of deep learning almost always start with MLPs, rightly so, as they are the simplest case and can often reveal the key insights more clearly. Of course, as deep learning itself is an applied field, one should always ask whether insights on MLPs extend to more general architectures, i.e. whether there is an architecturally universal extension of a proposed theory of MLPs. This is not always easy to answer.

In this paper, we showed that the NTK theory is architecturally universal, but more importantly, we showed that the Tensor Programs technique is a very powerful tool for answering the above question as a matter of routine. Looking forward, we hope to apply it to generate more novel and general insights.

References

The appendix is organized as follows. In Appendix A we expand upon the examples given in Section 3, adding some additional details. In Appendix B we introduce the formal versions of $\textsc{Netsor}\top$ and $\textsc{Netsor}\top^{+}$ programs. In Appendix C we introduce the graphical notation of $\textsc{Netsor}\top^{+}$ and demonstrate other examples of architectures or computations expressible in Tensor Programs. In Appendix D we prove our main result.

For the reader's convenience, we restate the notation described in Section 2, along with some additional notation used throughout the appendix. We consider SGD with batch size 1 and learning rate 1 (WLOG). We use $\xi_{t}$ to denote the input and $\mathcal{L}_{t}$ to denote the loss function (absorbing the label) at step $t$. More generally, a subscript $t$ on any symbol means time $t$. However, for brevity, we abuse notation and write $f_{t}$ for $f_{t}(\xi_{t})$ and, for any (pre-)activation $x$, $x_{t}$ for $x_{t}(\xi_{t})$. We also write $\chi_{t}$ for the loss derivative $\mathcal{L}_{t}'(f_{t})$. For any vector $x(\xi)$ we define $\delta x_{t+1}(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_{t}(\xi)\big)$ and $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial x(\xi)}$. We track the evolution of $f$ on an arbitrary input $\tilde{\xi}$. (Footnote 23: It might help to think of $\tilde{\xi}$ as a test sample, but it can also be in the training set.) Similar to above, we write $\tilde{x}_{t},\tilde{f}_{t}$ for $x_{t}(\tilde{\xi}),f_{t}(\tilde{\xi})$. In general, omitting the time index $t$ from any time-dependent quantity denotes its value at initialization (i.e. $x(\xi)=x_{0}(\xi)$, $f(\xi)=f_{0}(\xi)$). Finally, we use $\triangleq$ to denote equality of symbols (i.e. $W^{1}\triangleq W^{2}$ iff $W^{1},W^{2}$ represent the same variable, as opposed to equality in value).

Appendix A Additional Examples

In this section we flesh out the examples given in Section 3 of the main text for additional clarity, keeping the intuitive arguments used in each example to perform these calculations. The rigorous justification for these calculations is given in the following section, with the formal introduction of the Tensor Program framework.

Recall that our objective is to derive the claim of Section 3 by tracking the coordinate distribution of each (pre-)activation vector $x(\xi)$, $dx(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial x(\xi)}$, and $\delta x(\xi)\overset{\text{def}}{=}\sqrt{n}\big(x_{t+1}(\xi)-x_{t}(\xi)\big)$.

In our calculations we will rely on the following rules for the coordinates of any $\mathbb{R}^{n}$ (pre-)activation vector $x(\xi)$ in the large-width regime, which we will later formalize:

$x_{t+1}(\xi)-x_{t}(\xi)$ has $\Theta(\frac{1}{\sqrt{n}})$ coordinates.

$\delta x_{t+1}(\xi)$ has $\Theta(1)$ coordinates.

$x_{t}(\xi)=x(\xi)+o(1)$. Consequently, $Z^{x_{t}(\xi)}=Z^{x(\xi)}$.

If $x(\xi)=\phi\big(y(\xi)\big)$ for some vector $y(\xi)\in\mathbb{R}^{n}$, then by Taylor approximation $\delta x_{t+1}(\xi)=\sqrt{n}\Big(\phi\big(y_{t}(\xi)+\frac{\delta y_{t+1}(\xi)}{\sqrt{n}}\big)-\phi\big(y_{t}(\xi)\big)\Big)\approx\phi'\big(y_{t}(\xi)\big)\odot\delta y_{t+1}(\xi)$. Consequently, $Z^{\delta x_{t+1}(\xi)}=\phi'(Z^{y_{t}(\xi)})Z^{\delta y_{t+1}(\xi)}$.

We write $Z^{x}$ to denote the limit coordinate distribution of $x\in\mathbb{R}^{n}$ conditioned on the output function $f$ at initialization. Consequently, we write $\mathbb{E}Z^{x}Z^{y}$ to express a conditional expectation given the output function $f$. See Appendix B for the formal statement.

where $\xi\in\mathbb{R}^{d}$, $U=\frac{u}{\sqrt{d}}\in\mathbb{R}^{n\times d}$, $V=\frac{v}{\sqrt{n}}\in\mathbb{R}^{n\times 1}$, for the trainable parameter tensor $u$, initialized iid from $\mathcal{N}(0,1)$, and $d=1$.

The infinite width NTK of this architecture is given by:

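$$dh(\xi)\overset{\text{def}}{=}\sqrt{n}\frac{\partial f(\xi)}{\partial h(\xi)}=\phi'\big(h(\xi)\big)\odot v. \tag{28}$$
Since $\mathcal{K}(\xi,\tilde{\xi})=\langle\nabla_{u}f(\xi),\nabla_{u}f(\tilde{\xi})\rangle=\xi^{\top}\tilde{\xi}\,\frac{dh(\xi)^{\top}dh(\tilde{\xi})}{n}$, and $Z^{v}$ is independent of $Z^{h(\xi)},Z^{h(\tilde{\xi})}$ with $\mathbb{E}(Z^{v})^{2}=1$, by the LLN it follows that (Footnote 24: While in A.1 we said $\mathbb{E}$ denotes expectation conditioned on $\lim f_{0}$, the NTK here does not actually depend on $\lim f_{0}$.)
$$\mathring{\mathcal{K}}(\xi,\tilde{\xi})=\xi^{\top}\tilde{\xi}\,\mathbb{E}\big[\phi'(Z^{h(\xi)})\,\phi'(Z^{h(\tilde{\xi})})\big]. \tag{29}$$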
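The LLN step can be checked numerically for $\phi=\tanh$. With $d=1$ and only $u$ trained, the finite-width kernel is $\xi\tilde{\xi}\,\frac{1}{n}\sum_{\alpha}\tanh'(u_\alpha\xi)\tanh'(u_\alpha\tilde{\xi})\,v_\alpha^2$, while its limit $\xi\tilde{\xi}\,\mathbb{E}[\tanh'(\xi Z)\tanh'(\tilde{\xi}Z)]$ (using $Z^{h(\xi)}=\xi Z^{u}$) is a one-dimensional Gaussian integral, computed below by the trapezoid rule. The inputs $0.7,-1.3$ and the helper names are arbitrary illustrative choices.

```python
import math
import random


def gaussian_expectation(fn, num_points=801, limit=10.0):
    """Approximate E[fn(Z)] for Z ~ N(0, 1) with the trapezoid rule on [-limit, limit]."""
    step = 2 * limit / (num_points - 1)
    total = 0.0
    for i in range(num_points):
        z = -limit + i * step
        weight = 0.5 if i in (0, num_points - 1) else 1.0
        total += weight * fn(z) * math.exp(-0.5 * z * z)
    return total * step / math.sqrt(2 * math.pi)


def dtanh(a):
    """Derivative of tanh."""
    return 1.0 - math.tanh(a) ** 2


def empirical_ntk(xi, xi_tilde, n, seed=0):
    """Finite-width xi * xi~ * dh(xi)^T dh(xi~) / n with dh = tanh'(u xi) * v (only u trained)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        u_a, v_a = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        total += dtanh(u_a * xi) * dtanh(u_a * xi_tilde) * v_a * v_a
    return xi * xi_tilde * total / n


def limit_ntk(xi, xi_tilde):
    """xi * xi~ * E[phi'(Z^{h(xi)}) phi'(Z^{h(xi~)})] with Z^{h(xi)} = xi * Z^u (d = 1)."""
    return xi * xi_tilde * gaussian_expectation(lambda z: dtanh(xi * z) * dtanh(xi_tilde * z))


for n in (100, 10_000, 200_000):
    print(n, empirical_ntk(0.7, -1.3, n), limit_ntk(0.7, -1.3))
```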

To show that the claim of Section 3 holds with the kernel in Eq. 29, we track the coordinate distribution $Z^{\delta\tilde{x}_{t+1}}$ at each step of SGD. At step $t$, the update to the weights $u_{t+1}-u_{t}$ is given by the negative gradient of the loss with respect to $u_{t}$:

$$u_{t+1}-u_{t}=-\chi_{t}\frac{dh_{t}\,\xi_{t}^{\top}}{\sqrt{n}},\qquad dh_{t}=\phi'(h_{t})\odot v. \tag{30}$$
Recall that $\delta\tilde{h}_{t+1}=\sqrt{n}(\tilde{h}_{t+1}-\tilde{h}_{t})=\sqrt{n}(u_{t+1}\tilde{\xi}-u_{t}\tilde{\xi})$ and $\delta\tilde{x}_{t+1}=\sqrt{n}(\tilde{x}_{t+1}-\tilde{x}_{t})$. It therefore follows that

$$\delta\tilde{h}_{t+1}=-\chi_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\phi'(h_{t})\odot v,\qquad\delta\tilde{x}_{t+1}=\sqrt{n}\Big(\phi\big(\tilde{h}_{t}+\tfrac{\delta\tilde{h}_{t+1}}{\sqrt{n}}\big)-\phi(\tilde{h}_{t})\Big). \tag{31}$$
Since the coordinates of $\tilde{h}_{t}$ and $\delta\tilde{h}_{t+1}$ are $\Theta(1)$, for large $n$ we may Taylor expand $\phi$ to first order around $\tilde{h}_{t}$:

$$\begin{aligned}
\delta\tilde{x}_{t+1} &\approx \sqrt{n}\Big(\phi(\tilde{h}_{t})+\tfrac{1}{\sqrt{n}}\phi'(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1}-\phi(\tilde{h}_{t})\Big) && (32)\\
&= \phi'(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1} && (33)\\
&= -\chi_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\phi'(\tilde{h}_{t})\odot\phi'(h_{t})\odot v. && (34)
\end{aligned}$$
Again, since $\delta h_{t}(\xi)=\Theta(1)$, it follows that $h_{t}(\xi)=h(\xi)+\sum_{s=1}^{t}\frac{\delta h_{s}(\xi)}{\sqrt{n}}=h(\xi)+o(1)$. Hence, in the infinite-width limit, the coordinate distribution of $h_{t}(\xi)$ is identical to that of $h(\xi)$ (i.e. $Z^{h_{t}(\xi)}=Z^{h(\xi)}$). Using Eq. 34, the coordinate distribution of $\delta\tilde{x}_{t+1}$ is given by

$$Z^{\delta\tilde{x}_{t+1}}=-\mathring{\chi}_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\phi'(Z^{\tilde{h}_{t}})\,\phi'(Z^{h_{t}})\,Z^{v}. \tag{35}$$
In the large-width limit, the change in the output is simply given by $\tilde{f}_{t+1}-\tilde{f}_{t}=\frac{v^{\top}\delta\tilde{x}_{t+1}}{n}\to\mathbb{E}Z^{\delta\tilde{x}_{t+1}}Z^{v}$. Using Eq. 35 and the independence of $Z^{v}$ from the other random variables (Footnote 25: Again, as in Footnote 10, the expectations in Sections A.1 and A.2 should more rigorously be interpreted as expectations conditional on the function values of $\lim f_{0}$.), together with $Z^{h_{t}}=Z^{h}$, $Z^{\tilde{h}_{t}}=Z^{\tilde{h}}$ and $\mathbb{E}(Z^{v})^{2}=1$, we obtain
$$\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)=-\mathring{\chi}_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\big[\phi'(Z^{\tilde{h}})\,\phi'(Z^{h})\big]=-\mathring{\chi}_{t}\,\mathring{\mathcal{K}}(\tilde{\xi},\xi_{t}),$$
with $\mathring{\mathcal{K}}$ as in Eq. 29.
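A one-step numerical check of this derivation, in the same setting ($d=1$, $\phi=\tanh$, only $u$ trained) with a made-up example $(\xi_t,y_t)=(1,1)$, squared loss, and test input $\tilde{\xi}=-0.6$: the sketch takes one SGD step using Eq. 30 and compares the actual change $\tilde{f}_{1}-\tilde{f}_{0}$ with $-\chi_0\,\mathcal{K}(\tilde{\xi},\xi_0)$, where $\mathcal{K}$ is the finite-width NTK at initialization. The discrepancy comes from the second-order Taylor term dropped in Eq. 32 and should vanish as $n$ grows. The helper `one_step_check` is hypothetical.

```python
import math
import random


def dtanh(a):
    """Derivative of tanh."""
    return 1.0 - math.tanh(a) ** 2


def one_step_check(n, xi_t=1.0, y_t=1.0, xi_test=-0.6, seed=0):
    """Return (actual change of f(xi_test) after one SGD step, -chi_t * K(xi_test, xi_t))."""
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]

    def f(params, xi):
        return sum(b * math.tanh(a * xi) for a, b in zip(params, v)) / math.sqrt(n)

    chi = f(u, xi_t) - y_t  # loss derivative for the squared loss 0.5 * (f - y)^2
    # Eq. 30 with d = 1: u_{t+1} - u_t = -chi * (tanh'(h_t) * v) * xi_t / sqrt(n)
    u_new = [a - chi * dtanh(a * xi_t) * b * xi_t / math.sqrt(n) for a, b in zip(u, v)]
    actual = f(u_new, xi_test) - f(u, xi_test)
    ntk = xi_test * xi_t * sum(dtanh(a * xi_test) * dtanh(a * xi_t) * b * b
                               for a, b in zip(u, v)) / n
    return actual, -chi * ntk


for n in (100, 10_000, 100_000):
    actual, predicted = one_step_check(n)
    print(n, actual, predicted, actual - predicted)
```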

A.2 2 hidden layers

where $U=\frac{u}{\sqrt{d}}\in\mathbb{R}^{n\times d}$, $W=\frac{w}{\sqrt{n}}\in\mathbb{R}^{n\times n}$, $V=\frac{v}{\sqrt{n}}\in\mathbb{R}^{n\times 1}$, for trainable parameters $u,w$, initialized iid from a normal distribution. As before, we assume the last layer is not trained, and $d=1$.

$$\begin{aligned}
\mathcal{K}(\xi,\tilde{\xi}) &= \langle\nabla_{u}f(\xi),\nabla_{u}f(\tilde{\xi})\rangle+\langle\nabla_{w}f(\xi),\nabla_{w}f(\tilde{\xi})\rangle && (39)\\
&= \xi^{\top}\tilde{\xi}\,\frac{dg(\xi)^{\top}dg(\tilde{\xi})}{n}+\frac{z(\xi)^{\top}z(\tilde{\xi})}{n}\,\frac{dh(\xi)^{\top}dh(\tilde{\xi})}{n} && (40)\\
dh(\xi) &\overset{\text{def}}{=} \sqrt{n}\frac{\partial f(\xi)}{\partial h(\xi)}=\phi'\big(h(\xi)\big)\odot v && (41)\\
dg(\xi) &\overset{\text{def}}{=} \sqrt{n}\frac{\partial f(\xi)}{\partial g(\xi)}=\phi'\big(g(\xi)\big)\odot\big(W^{\top}dh(\xi)\big). && (42)
\end{aligned}$$
Naively applying the LLN to Eq. 39 (with $Z^{v}$ independent of everything else) should result in

$$\mathcal{K}(\xi,\tilde{\xi})=\xi^{\top}\tilde{\xi}\,\mathbb{E}\big[Z^{dg(\xi)}Z^{dg(\tilde{\xi})}\big]+\mathbb{E}\big[Z^{z(\xi)}Z^{z(\tilde{\xi})}\big]\,\mathbb{E}\big[\phi'(Z^{h(\xi)})\,\phi'(Z^{h(\tilde{\xi})})\big]. \tag{43}$$
Evaluating the term $\mathbb{E}\big[Z^{dg(\xi)}Z^{dg(\tilde{\xi})}\big]$, however, presents a challenge, since $dg(\xi)$ depends on both $W$ and $W^{\top}$. As it turns out, at initialization we may naively assume that $W^{\top}$ and $W$ are independent (known in the literature as the gradient independence assumption, or GIA) (Footnote 26: For a rigorous justification of the GIA, see (Yang, 2020a).), and using simple LLN arguments we arrive at

Plugging Eq. 44 into Eq. 43 we arrive at the correct expression for the infinite width NTK.
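Under the GIA, Eq. 44 presumably takes the form $\mathbb{E}[Z^{dg(\xi)}Z^{dg(\tilde{\xi})}]=\mathbb{E}[\phi'(Z^{g(\xi)})\phi'(Z^{g(\tilde{\xi})})]\,\mathbb{E}[\phi'(Z^{h(\xi)})\phi'(Z^{h(\tilde{\xi})})]$; this form is our assumption, obtained by treating $W^{\top}dh$ as a fresh Gaussian with covariance $\mathbb{E}[Z^{dh(\xi)}Z^{dh(\tilde{\xi})}]$. The sketch below evaluates Eq. 43 with this expression by deterministic quadrature for $d=1$, where $Z^{g(\xi)}=\xi Z^{u}$ and $(Z^{h(\xi)},Z^{h(\tilde{\xi})})$ is a centered Gaussian pair with covariance entries $\mathbb{E}[\phi(Z^{g(\cdot)})\phi(Z^{g(\cdot)})]$. For $\phi=\mathrm{id}$, every $\phi'$ factor equals 1 and the kernel reduces to $2\xi^{\top}\tilde{\xi}$, which serves as a check. The helper names are hypothetical.

```python
import math


def gauss_expect(fn, num_points=201, limit=8.0):
    """Approximate E[fn(Z)] for Z ~ N(0, 1) with the trapezoid rule on [-limit, limit]."""
    step = 2 * limit / (num_points - 1)
    total = 0.0
    for i in range(num_points):
        z = -limit + i * step
        weight = 0.5 if i in (0, num_points - 1) else 1.0
        total += weight * fn(z) * math.exp(-0.5 * z * z)
    return total * step / math.sqrt(2 * math.pi)


def gauss_expect2(fn, s11, s12, s22):
    """E[fn(Z1, Z2)] for (Z1, Z2) ~ N(0, [[s11, s12], [s12, s22]]), via a Cholesky factor."""
    a = math.sqrt(s11)
    b = s12 / a if a > 0 else 0.0
    c = math.sqrt(max(s22 - b * b, 0.0))
    return gauss_expect(lambda x: gauss_expect(lambda y: fn(a * x, b * x + c * y)))


def ntk_two_hidden_limit(xi, xi_t, phi, dphi):
    """Eq. 43 for f = V^T phi(W phi(U xi)), u and w trained, d = 1, with the GIA form of Eq. 44."""
    e_dg = gauss_expect(lambda z: dphi(xi * z) * dphi(xi_t * z))  # E[phi'(Z^g) phi'(Z^g~)]
    e_zz = gauss_expect(lambda z: phi(xi * z) * phi(xi_t * z))     # E[Z^z Z^z~]
    var_h = gauss_expect(lambda z: phi(xi * z) ** 2)                # Var Z^{h(xi)}
    var_ht = gauss_expect(lambda z: phi(xi_t * z) ** 2)             # Var Z^{h(xi~)}
    e_dh = gauss_expect2(lambda p, q: dphi(p) * dphi(q), var_h, e_zz, var_ht)
    return xi * xi_t * e_dg * e_dh + e_zz * e_dh


def dtanh(a):
    """Derivative of tanh."""
    return 1.0 - math.tanh(a) ** 2


print(ntk_two_hidden_limit(0.7, -1.3, lambda a: a, lambda a: 1.0))  # approximately 2 * 0.7 * (-1.3)
print(ntk_two_hidden_limit(0.7, -1.3, math.tanh, dtanh))
```

Quadrature is used instead of Monte Carlo so that the result is deterministic; the trapezoid rule is very accurate for smooth integrands against a Gaussian weight.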

To show that the claim of Section 3 holds at any step $t$ (where we may not assume that the GIA holds), we track the distributions of the vectors $g(\xi),z(\xi),h(\xi),x(\xi)$ throughout training.

At any step tt the weights are updated according to:

$$u_{t+1}-u_{t}=-\chi_{t}\frac{dg_{t}\,\xi_{t}^{\top}}{\sqrt{n}},\qquad w_{t+1}-w_{t}=-\chi_{t}\frac{dh_{t}\,z_{t}^{\top}}{n}. \tag{45}$$
The updates $\delta\tilde{g}_{t+1}\overset{\text{def}}{=}\sqrt{n}(\tilde{g}_{t+1}-\tilde{g}_{t})$ and $\delta\tilde{z}_{t+1}\overset{\text{def}}{=}\sqrt{n}(\tilde{z}_{t+1}-\tilde{z}_{t})$ are given by

As before, for large $n$, $\bullet_{t+1}(\xi)-\bullet_{t}(\xi)$ has $\Theta(\frac{1}{\sqrt{n}})$ coordinates and $\delta\bullet_{t+1}(\xi)$ has $\Theta(1)$ coordinates, for $\bullet$ replaced by any of $\{g,z,h,x\}$. And so, after Taylor expanding $\phi\big(\tilde{g}_{t}+\frac{\delta\tilde{g}_{t+1}}{\sqrt{n}}\big)$ around $\tilde{g}_{t}$:

$$\delta\tilde{z}_{t+1}\approx\phi'(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}. \tag{47}$$
In a similar fashion, using Eqs. 41 and 46, the updates $\delta\tilde{z}_{t+1},\delta\tilde{h}_{t+1},\delta\tilde{x}_{t+1}$ take the form

$$\begin{aligned}
\delta\tilde{z}_{t+1} &\approx \phi'(\tilde{g}_{t})\odot\delta\tilde{g}_{t+1}=-\chi_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\phi'(g_{t})\odot\phi'(\tilde{g}_{t})\odot\big(W^{\top}dh_{t}\big) && (48)\\
\delta\tilde{h}_{t+1} &\approx W\delta\tilde{z}_{t+1}-\chi_{t}\frac{z_{t}^{\top}\tilde{z}_{t}}{n}\,\phi'(h_{t})\odot v-\frac{1}{\sqrt{n}}\sum_{s=0}^{t}\chi_{s}\frac{z_{s}^{\top}\delta\tilde{z}_{t+1}}{n}\,\phi'(h_{s})\odot v && (49)\\
\delta\tilde{x}_{t+1} &\approx \phi'(\tilde{h}_{t})\odot\delta\tilde{h}_{t+1}, && (50)
\end{aligned}$$
where we used Eq. 45 and

$$\begin{aligned}
\delta\tilde{h}_{t+1} &= W_{t+1}\delta\tilde{z}_{t+1}+\sqrt{n}(W_{t+1}-W_{t})\tilde{z}_{t} && (51)\\
&= W\delta\tilde{z}_{t+1}+\sum_{s=0}^{t}(W_{s+1}-W_{s})\delta\tilde{z}_{t+1}+\sqrt{n}(W_{t+1}-W_{t})\tilde{z}_{t} && (52)
\end{aligned}$$
to get Eq. 49. Based on Eqs. 41, 48, 49 and 50, the corresponding coordinate distributions take the form

$$\begin{aligned}
Z^{\delta\tilde{z}_{t+1}} &= -\mathring{\chi}_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\phi'(Z^{g_{t}})\,\phi'(Z^{\tilde{g}_{t}})\,Z^{W^{\top}dh_{t}} && (54)\\
Z^{\delta\tilde{h}_{t+1}} &= Z^{W\delta\tilde{z}_{t+1}}-\mathring{\chi}_{t}\,\mathbb{E}\big[Z^{z_{t}}Z^{\tilde{z}}\big]\,\phi'(Z^{h_{t}})\,Z^{v} && (55)\\
Z^{\delta\tilde{x}_{t+1}} &= \phi'(Z^{\tilde{h}_{t}})\,Z^{\delta\tilde{h}_{t+1}}. && (56)
\end{aligned}$$
As before, the functional update is given by $\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)=\lim_{n\to\infty}\frac{v^{\top}\delta\tilde{x}_{t+1}}{n}=\mathbb{E}Z^{v}Z^{\delta\tilde{x}_{t+1}}$. Plugging in Eqs. 56 and 55:

$$\mathbb{E}Z^{v}Z^{\delta\tilde{x}_{t+1}}=-\mathring{\chi}_{t}\,\mathbb{E}\big[Z^{z_{t}}Z^{\tilde{z}}\big]\,\mathbb{E}\big[\phi'(Z^{h_{t}})\,\phi'(Z^{\tilde{h}_{t}})\big]+\mathbb{E}\big[\phi'(Z^{\tilde{h}_{t}})\,Z^{W\delta\tilde{z}_{t+1}}\,Z^{v}\big]. \tag{57}$$
To compute the second term on the RHS of Eq. 57, we use the claim of Section 3.2.

Applying the claim of Section 3.2, we get the expression for $Z^{W\delta\tilde{z}_{t+1}}$:

$$\begin{aligned}
Z^{W\delta\tilde{z}_{t+1}} &= G+Z^{dh_{t}}\,\mathbb{E}\frac{\partial Z^{\delta\tilde{z}_{t+1}}}{\partial Z^{W^{\top}dh_{t}}} && (58)\\
&= G-\phi'(Z^{h_{t}})\,Z^{v}\,\mathring{\chi}_{t}\,\xi_{t}^{\top}\tilde{\xi}\,\mathbb{E}\big[\phi'(Z^{g_{t}})\,\phi'(Z^{\tilde{g}_{t}})\big], && (59)
\end{aligned}$$
where $G$ denotes the zero-mean Gaussian part $\hat{Z}^{W\delta\tilde{z}_{t+1}}$. As before, for $\bullet\in\{g,z,h,x\}$ it holds that $\bullet_{t}(\xi)=\bullet(\xi)+\sum_{s=1}^{t}\frac{\delta\bullet_{s}(\xi)}{\sqrt{n}}=\bullet(\xi)+o(1)$, and $Z^{\bullet_{t}(\xi)}=Z^{\bullet(\xi)}$. Since $Z^{v}$ has zero mean and is independent of $G$ and $Z^{\tilde{h}_{t}}$, the contribution of $G$ to Eq. 57 vanishes, and plugging Eq. 58 into Eq. 57 yields the claim of Section 3.

Appendix B Tensor Programs: the Formal Version

We briefly review the formal definition of Tensor Programs below; readers needing more explanation and intuition should see (Yang, 2020b). We directly describe $\textsc{Netsor}\top^{+}$ programs, which generalize $\textsc{Netsor}\top$.

A $\textsc{Netsor}\top^{+}$ program is a sequence of $\mathbb{R}^{n}$-vectors and $\mathbb{R}$-scalars inductively generated via one of the following ways from an initial set $\mathcal{C}$ of random scalars, $\mathcal{V}$ of random $\mathbb{R}^{n}$ vectors, and a set $\mathcal{W}$ of random $\mathbb{R}^{n\times n}$ matrices (which will be sampled with iid Gaussian entries in B.2)

Given $\psi:\mathbb{R}^{k}\times\mathbb{R}^{l}\to\mathbb{R}$, previous scalars $\theta_{1},\ldots,\theta_{l}\in\mathbb{R}$ and vectors $x^{1},\ldots,x^{k}\in\mathbb{R}^{n}$, we can generate a new vector

where $\psi(-;\theta_{1},\ldots,\theta_{l})$ applies coordinatewise to each "$\alpha$-slice" $(x_{\alpha}^{1},\ldots,x_{\alpha}^{k})$.

Given the same setup as above, we can also generate a new scalar

A $\textsc{Netsor}\top$ program is just a $\textsc{Netsor}\top^{+}$ program without scalars, without the usage of Moment, and without parameters $\theta_{1},\ldots,\theta_{l}$ in Nonlin+.

We will typically randomly sample the initial matrices, vectors, and scalars of the program as follows.

1) For each initial $W\in\mathcal{W}$, we sample iid $W_{\alpha\beta}\sim\mathcal{N}(0,\sigma_{W}^{2}/n)$ for some variance $\sigma_{W}^{2}$ associated to $W$, independent of other $W'\in\mathcal{W}$; 2) for some multivariate Gaussian $Z^{\mathcal{V}}=\{Z^{h}:h\in\mathcal{V}\}\in\mathbb{R}^{\mathcal{V}}$, we sample the initial set of vectors $\mathcal{V}$ like $\{h_{\alpha}:h\in\mathcal{V}\}\sim Z^{\mathcal{V}}$ iid for each $\alpha\in[n]$; 3) for each initial scalar $\theta\in\mathcal{C}$, we require $\theta\xrightarrow{\mathrm{a.s.}}\mathring{\theta}$ for some deterministic $\mathring{\theta}\in\mathbb{R}$.

The following constructs a random variable $Z^{h}$ for every vector $h$ and a deterministic scalar $\mathring{\theta}$ for every scalar $\theta$ in the program. The interpretation is that $h$ will have iid coordinates distributed like $Z^{h}$, and $\theta$ will converge to $\mathring{\theta}$ as $n\to\infty$.

Given a $\textsc{Netsor}\top^{+}$ program, we recursively define $Z^{h}$ for each vector $h$ and $\mathring{\theta}$ for each scalar $\theta$ as follows.

If $h\in\mathcal{V}$, then $Z^{h}$ is defined as in B.2. We also set $\hat{Z}^{h}\overset{\text{def}}{=}Z^{h}$ and $\dot{Z}^{h}\overset{\text{def}}{=}0$.

Given $\psi:\mathbb{R}^{k}\times\mathbb{R}^{l}\to\mathbb{R}$, previous scalars $\theta_{1},\ldots,\theta_{l}\in\mathbb{R}$ and vectors $x^{1},\ldots,x^{k}\in\mathbb{R}^{n}$, we have

Given the same setup as above and the scalar $\theta=\frac{1}{n}\sum_{\alpha=1}^{n}\psi(x_{\alpha}^{1},\ldots,x_{\alpha}^{k};\theta_{1},\ldots,\theta_{l})$, we have

Here $\mathring{\theta}_{1},\ldots,\mathring{\theta}_{l}$ are deterministic, so the expectation is taken over $Z^{x^{1}},\ldots,Z^{x^{k}}$.

$Z^{Wx}\overset{\text{def}}{=}\hat{Z}^{Wx}+\dot{Z}^{Wx}$ for every matrix $W$ (with $\mathcal{N}(0,\sigma_{W}^{2}/n)$ entries) and vector $x$, where

$\hat{Z}^{Wx}$ is a Gaussian variable with zero mean. Let $\mathcal{V}_{W}$ denote the set of all vectors in the program of the form $Wy$ for some $y$. Then $\{\hat{Z}^{Wy}:Wy\in\mathcal{V}_{W}\}$ is defined to be jointly Gaussian with zero mean and covariance

Furthermore, $\{\hat{Z}^{Wy}:Wy\in\mathcal{V}_{W}\}$ is mutually independent from $\{\hat{Z}^{v}:v\in\mathcal{V}\cup\bigcup_{\bar{W}\neq W}\mathcal{V}_{\bar{W}}\}$, where $\bar{W}$ ranges over $\mathcal{W}\cup\{A^{\top}:A\in\mathcal{W}\}$.

We can always unwind $Z^{x}=\Phi(\cdots)$ for some arguments $(\cdots)=\big(\{\hat{Z}^{W^{\top}y^{i}}\}_{i=1}^{k},\{\hat{Z}^{z^{i}}\}_{i=1}^{j};\{\mathring{\theta}_{i}\}_{i=1}^{l}\big)$ with $z^{i}\notin\mathcal{V}_{W^{\top}}$ (where $\mathcal{V}_{W^{\top}}$ is defined as in the definition of $\hat{Z}$ above), and a deterministic function $\Phi:\mathbb{R}^{k+j+l}\to\mathbb{R}$. Define $\partial Z^{x}/\partial\hat{Z}^{W^{\top}y^{i}}\overset{\text{def}}{=}\partial_{i}\Phi(\cdots)$. Then we set

$$\dot{Z}^{Wx}\overset{\text{def}}{=}\sigma_{W}^{2}\sum_{i=1}^{k}Z^{y^{i}}\,\mathbb{E}\frac{\partial Z^{x}}{\partial\hat{Z}^{W^{\top}y^{i}}}. \tag{65}$$
There is some nuance in this definition; see B.5 and B.6.
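The $\dot{Z}$ correction can be observed numerically. Take an initial vector $y$ with iid $\mathcal{N}(0,1)$ coordinates, $\sigma_W=1$, and $x=\tanh(W^{\top}y)$. Eq. 65 then gives $Z^{Wx}=\hat{Z}^{Wx}+Z^{y}\,\mathbb{E}[\tanh'(Z)]$ with $Z\sim\mathcal{N}(0,1)$, and $\hat{Z}^{Wx}$ is independent of $Z^{y}$, so the least-squares coefficient of $Wx$ on $y$ should approach $\mathbb{E}[\tanh'(Z)]$, rather than the value $0$ predicted by treating $W$ and $W^{\top}$ as independent. The sketch below (hypothetical helper names, pure Python, so only a modest $n$) checks this.

```python
import math
import random


def gaussian_expectation(fn, num_points=801, limit=10.0):
    """Approximate E[fn(Z)] for Z ~ N(0, 1) with the trapezoid rule on [-limit, limit]."""
    step = 2 * limit / (num_points - 1)
    total = 0.0
    for i in range(num_points):
        z = -limit + i * step
        weight = 0.5 if i in (0, num_points - 1) else 1.0
        total += weight * fn(z) * math.exp(-0.5 * z * z)
    return total * step / math.sqrt(2 * math.pi)


def zdot_coefficient(n, seed=0):
    """Least-squares coefficient of Wx on y, where x = tanh(W^T y) and W_ab ~ N(0, 1/n)."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, 1.0) / math.sqrt(n) for _ in range(n)] for _ in range(n)]
    y = [rng.gauss(0.0, 1.0) for _ in range(n)]
    s = [sum(w_ab * y_a for w_ab, y_a in zip(col, y)) for col in zip(*W)]  # s = W^T y
    x = [math.tanh(c) for c in s]                                          # Nonlin
    wx = [sum(w_ab * x_b for w_ab, x_b in zip(row, x)) for row in W]       # MatMul with W
    return sum(p * q for p, q in zip(wx, y)) / sum(q * q for q in y)


predicted = gaussian_expectation(lambda z: 1.0 - math.tanh(z) ** 2)  # sigma_W^2 * E[tanh'(Z)]
print(zdot_coefficient(1000), predicted)
```

Since $(Wx)^{\top}y=x^{\top}(W^{\top}y)$, the coefficient is also a finite-$n$ instance of Stein's lemma, $\mathbb{E}[Z\tanh(Z)]=\mathbb{E}[\tanh'(Z)]$.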

The following theorem ties the symbolic nature of the $Z$s to the analytic nature of a Tensor Program.

Theorem B.4 (Netsor⊤+ Master Theorem; cf. Theorem E.15 of (Yang, 2020b)). Fix a Tensor Program initialized according to B.2, and adopt B.8. Then

For any fixed $k$ and any pseudo-Lipschitz $\psi:\mathbb{R}^{k}\to\mathbb{R}$, as $n\to\infty$,
$$\frac{1}{n}\sum_{\alpha=1}^{n}\psi(h_{\alpha}^{1},\ldots,h_{\alpha}^{k})\xrightarrow{\mathrm{a.s.}}\mathbb{E}\,\psi(Z^{h^{1}},\ldots,Z^{h^{k}})$$

for any vectors $h^{1},\ldots,h^{k}$ in the program, where the $Z^{h^{i}}$ are as defined in B.3.

Any scalar $\theta$ in the program tends to $\mathring{\theta}$ almost surely, where $\mathring{\theta}$ is as defined in B.3.
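As a rough numerical illustration of the Master Theorem together with ZMoment, the sketch below (plain Python, standard library only) takes the one-instruction program $x=Av$, with $A$ having iid $\mathcal{N}(0,1/n)$ entries and $v$ having iid $\mathcal{N}(0,1)$ entries, and the test function $\psi=\cos$ (our choice for illustration, not from the text). The rules predict $Z^{Av}\sim\mathcal{N}(0,\mathbb{E}(Z^{v})^{2})=\mathcal{N}(0,1)$, so $\frac{1}{n}\sum_{\alpha}\cos((Av)_{\alpha})$ should approach $\mathbb{E}\cos Z=e^{-1/2}\approx0.607$; the width $n=400$ and the seed are arbitrary.

```python
import math
import random


def master_theorem_demo(n: int, seed: int = 0) -> float:
    """Return (1/n) * sum_alpha cos((A v)_alpha) for one random draw.

    A has iid N(0, 1/n) entries and v has iid N(0, 1) entries. The Z-rules
    predict Z^{Av} ~ N(0, E[(Z^v)^2]) = N(0, 1), so the average should be
    close to E cos(Z) = exp(-1/2) for large n.
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    entry_std = 1.0 / math.sqrt(n)  # standard deviation of each entry of A
    total = 0.0
    for _ in range(n):
        # draw one row of A on the fly instead of storing the n x n matrix
        av_alpha = sum(rng.gauss(0.0, entry_std) * v_j for v_j in v)
        total += math.cos(av_alpha)
    return total / n


print(master_theorem_demo(400), math.exp(-0.5))
```

Drawing each row of $A$ on the fly keeps memory at $O(n)$; the fluctuation around the limit shrinks roughly like $1/\sqrt{n}$.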

The partial derivative in ZDot should be interpreted as follows. By a simple inductive argument, $Z^{x}$ for every vector $x$ in the program is defined uniquely as a deterministic function $\varphi(\hat{Z}^{x^{1}},\ldots,\hat{Z}^{x^{k}})$ of some $x^{1},\ldots,x^{k}$ in $\mathcal{V}$ or introduced by MatMul (notationally, we suppress the possible dependence on the limit scalars $\mathring{\theta}_{1},\ldots,\mathring{\theta}_{l}$). For instance, if in a program we have $A\in\mathcal{W}$, $v\in\mathcal{V}$, $y=Av$ and $x=A^{\top}y$, then $Z^{x}=\hat{Z}^{x}+\hat{Z}^{v}$, so $\varphi$ is given by $\varphi(a,b)=a+b$. Then
$$\partial Z^{x}/\partial\hat{Z}^{x^{i}}\overset{\text{def}}{=}\partial_{i}\varphi(\hat{Z}^{x^{1}},\ldots,\hat{Z}^{x^{k}}),\quad\text{and}\quad\partial Z^{x}/\partial\hat{Z}^{z}\overset{\text{def}}{=}0\ \text{ for any }z\notin\{x^{1},\ldots,x^{k}\}.$$

Note that this definition depends on the precise way the program is written, not just on the underlying mathematics. For example, if $y,z\in\mathcal{V}$ and $x=\phi(W(y+z))$, then $Z^{x}=\phi(\hat{Z}^{W(y+z)})$, so that $\partial Z^{x}/\partial\hat{Z}^{Wy}=\partial Z^{x}/\partial\hat{Z}^{Wz}=0$. If instead $x=\phi(Wy+Wz)$, then $Z^{x}=\phi(\hat{Z}^{Wy}+\hat{Z}^{Wz})$, so that $\partial Z^{x}/\partial\hat{Z}^{W(y+z)}=0$. However, in both cases $\dot{Z}^{W^{\top}x}=(Z^{y}+Z^{z})\,\mathbb{E}\,\phi^{\prime}(\hat{Z}^{W(y+z)})$.
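The example with $y=Av$ and $x=A^{\top}y$ is small enough to check numerically, and it shows why the $\dot{Z}$ term matters. A minimal sketch, assuming $\sigma_{A}=1$ and iid $\mathcal{N}(0,1)$ entries for $v$: by ZDot, $Z^{x}=\hat{Z}^{x}+Z^{v}$, so $\frac{1}{n}v^{\top}x\to\mathbb{E}\,Z^{x}Z^{v}=1$, whereas dropping $\dot{Z}$ (treating $A^{\top}$ as a fresh independent matrix) would predict $0$. For $n=300$ the printed value should land near $1$, with fluctuations on the order of $2/\sqrt{n}$.

```python
import math
import random


def zdot_demo(n: int, seed: int = 0) -> float:
    """Return (1/n) v^T x for the program y = A v, x = A^T y.

    A has iid N(0, 1/n) entries (sigma_A = 1) and v has iid N(0, 1) entries.
    ZDot gives Z^x = Zhat^x + Z^v, hence (1/n) v^T x -> E[(Z^v)^2] = 1.
    Ignoring the Zdot term would wrongly predict 0.
    """
    rng = random.Random(seed)
    entry_std = 1.0 / math.sqrt(n)
    A = [[rng.gauss(0.0, entry_std) for _ in range(n)] for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    y = [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]  # y = A v
    x = [sum(A[i][j] * y[i] for i in range(n)) for j in range(n)]    # x = A^T y
    return sum(x_j * v_j for x_j, v_j in zip(x, v)) / n


print(zdot_demo(300))
```

Algebraically $\frac{1}{n}v^{\top}A^{\top}Av=\frac{1}{n}\|Av\|^{2}$, which makes it plain that the limit cannot be $0$.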

The quantity $\mathbb{E}\frac{\partial Z^{x}}{\partial\hat{Z}^{W^{\top}y}}$ is well defined if $Z^{x}$ is differentiable in $\hat{Z}^{W^{\top}y}$. However, even if this is not the case, e.g. if $x=\theta(W^{\top}y)$ where $\theta$ is the Heaviside step function, we can still define this expectation by leveraging Stein's lemma:

In ZDot, suppose $\{W^{\top}y^{i}\}_{i=1}^{k}$ are all the elements of $\mathcal{V}_{W^{\top}}$ introduced before $x$. Define the matrix $C\in\mathbb{R}^{k\times k}$ by $C_{ij}\overset{\text{def}}{=}\mathbb{E}\,Z^{y^{i}}Z^{y^{j}}$ and the vector $b\in\mathbb{R}^{k}$ by $b_{i}\overset{\text{def}}{=}\mathbb{E}\,\hat{Z}^{W^{\top}y^{i}}Z^{x}$. If $a=C^{+}b$ (where $C^{+}$ denotes the pseudoinverse of $C$), then in ZDot we may set
$$\sigma_{W}^{2}\,\mathbb{E}\frac{\partial Z^{x}}{\partial\hat{Z}^{W^{\top}y^{i}}}=a_{i}.\tag{67}$$
This definition agrees with the partial derivative expectation given by Stein's lemma when the latter is well defined. Theorem B.4 holds with this broader definition of the partial derivative expectation.
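For $k=1$ the recipe above reduces to $a=b/C$ (or $a=0$ if $C=0$), which makes the Heaviside example concrete. The Monte Carlo sketch below assumes $\sigma_{W}=1$ and takes $Z^{x}=\mathbb{1}\{\hat{Z}^{W^{\top}y}>0\}$ with $\hat{Z}^{W^{\top}y}\sim\mathcal{N}(0,C)$; for simplicity it ignores any $\dot{Z}$ contribution inside the step function. Stein's lemma then says $a$ should match the $\mathcal{N}(0,C)$ density at $0$, namely $1/\sqrt{2\pi C}$, which is the natural value of "$\mathbb{E}\,\delta(\hat{Z})$".

```python
import math
import random


def stein_coefficient(C: float, num_samples: int = 200_000, seed: int = 0) -> float:
    """Monte Carlo version of a = C^+ b for k = 1 and sigma_W = 1.

    Assumes Z^x = step(Zhat) with Zhat = Zhat^{W^T y} ~ N(0, C), where
    C = E[(Z^y)^2]; b = E[Zhat * Z^x] is estimated by sampling.
    """
    rng = random.Random(seed)
    std = math.sqrt(C)
    b = 0.0
    for _ in range(num_samples):
        z_hat = rng.gauss(0.0, std)
        if z_hat > 0.0:            # Heaviside step: Z^x = 1{Zhat > 0}
            b += z_hat
    b /= num_samples
    c_pinv = 1.0 / C if C > 0.0 else 0.0   # pseudoinverse of a 1 x 1 matrix
    return c_pinv * b


C = 4.0
# estimate versus the N(0, C) density at 0
print(stein_coefficient(C), 1.0 / math.sqrt(2.0 * math.pi * C))
```

The pseudoinverse matters only in the degenerate case: with $C=0$ the sketch returns $0$ instead of dividing by zero.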

Pseudo-Lipschitz functions are, roughly speaking, functions whose weak derivatives are polynomially bounded.

Definition B.7. A function $f:\mathbb{R}^{k}\to\mathbb{R}$ is called pseudo-Lipschitz of degree $d$ if $|f(x)-f(y)|\leq C\|x-y\|(1+\sum_{i=1}^{k}|x_{i}|^{d}+|y_{i}|^{d})$ for some $C$. We say $f$ is pseudo-Lipschitz if it is so for some degree $d$.
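Random sampling can only give a lower bound on the best constant $C$, never a proof, but it is a quick sanity check of the definition. The sketch below handles the scalar case $k=1$ (the helper names are ours) and evaluates $|f(x)-f(y)|/(|x-y|(1+|x|^{d}+|y|^{d}))$ over random pairs for $f(x)=x^{2}$: with $d=1$ the ratio stays below $1$, since $|x^{2}-y^{2}|=|x-y|\,|x+y|$, while with $d=0$ it grows with the sampling range. So $x^{2}$ is pseudo-Lipschitz of degree $1$ but not of degree $0$.

```python
import random
from typing import Callable


def pl_ratio(f: Callable[[float], float], x: float, y: float, d: int) -> float:
    """|f(x) - f(y)| / (|x - y| (1 + |x|^d + |y|^d)) for scalar inputs (k = 1)."""
    return abs(f(x) - f(y)) / (abs(x - y) * (1.0 + abs(x) ** d + abs(y) ** d))


def max_ratio(f: Callable[[float], float], d: int, trials: int = 10_000,
              radius: float = 100.0, seed: int = 0) -> float:
    """Largest ratio over random pairs in [-radius, radius]^2: a lower bound on C."""
    rng = random.Random(seed)
    best = 0.0
    for _ in range(trials):
        x, y = rng.uniform(-radius, radius), rng.uniform(-radius, radius)
        if x != y:
            best = max(best, pl_ratio(f, x, y, d))
    return best


def square(t: float) -> float:
    return t * t


print(max_ratio(square, d=1))  # bounded by 1, since |x^2 - y^2| = |x - y| |x + y|
print(max_ratio(square, d=0))  # grows with `radius`: degree 0 is not enough
```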

Here are some basic properties of pseudo-Lipschitz functions:

The norm $\|\cdot\|$ in B.7 can be any norm equivalent to the $\ell_{2}$ norm, e.g. the $\ell_{p}$ norms for $p\geq1$. Similarly, $\sum_{i=1}^{k}|x_{i}|^{d}+|y_{i}|^{d}$ can be replaced by $\|x\|_{p}^{d}+\|y\|_{p}^{d}$ for any $p\geq1$.

A pseudo-Lipschitz function is polynomially bounded.

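A composition of pseudo-Lipschitz functions is pseudo-Lipschitz: if the outer function has degree $d_{1}$ and the inner one has degree $d_{2}$, the composition has degree at most $d_{1}(d_{2}+1)+d_{2}$. (Degree $d_{1}+d_{2}$ does not suffice in general: $x\mapsto x^{2}$ has degree $1$, but $x\mapsto x^{4}$ is not pseudo-Lipschitz of degree $2$.)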

A pseudo-Lipschitz function is Lipschitz on any compact set.

We adopt the following assumption for the Master Theorem B.4.

1. If a function $\phi(;-):\mathbb{R}^{0+l}\to\mathbb{R}$ with only parameter arguments is used in Moment, then $\phi$ is continuous in those arguments.

2. Any other function $\phi(-;-):\mathbb{R}^{k+l}\to\mathbb{R}$ with parameters (where $k>0$) used in Nonlin or Moment is pseudo-Lipschitz in all of its arguments (both inputs and parameters).

Statement 1 in B.8 essentially says that if we have scalars $\theta_{1},\ldots,\theta_{l}$ in the program, then we can produce a new scalar by applying a continuous function (a weaker restriction than a pseudo-Lipschitz function) to them. Indeed, if $\theta_{1},\ldots,\theta_{l}$ converge almost surely, then so does this new scalar. In our setting, statement 1 is used to allow any loss function whose derivative is continuous.

Other versions of the Master Theorem can be found in (Yang, 2020b), for example versions where we do not assume any smoothness condition on the nonlinearities beyond polynomial boundedness, in exchange for assuming a so-called rank stability condition. This rank stability condition should hold generically, but checking it rigorously is subtle, so we are content with the pseudo-Lipschitz condition in this paper.

Appendix C More Diagrams

We can augment the graphical form of Netsor⊤ to accommodate the Moment instruction in Netsor⊤+. See Fig. 4 for examples with layernorm and attention. In short, we denote scalar variables with a square, in contrast to the circle for vector variables, and we use a “bar-gate” to denote Moment, where the function in the gate corresponds to $\psi$ in Moment.

As further examples of the expressivity of Netsor⊤, Figs. 5 and 6 show convolution and MLP backpropagation written in Netsor⊤.

Appendix D Proof of Main Result

This section is dedicated to proving 5.3. We begin by proving a simplified version under the same assumptions as Section 5.1, reproduced below:

Suppose a neural network $f\in\mathbb{R}$ is represented by a Netsor⊤ program (in the sense of 4.2) whose Nonlin functions all have polynomially bounded derivatives (footnote 27: more generally, we can allow any pseudo-Lipschitz function here, but for simplicity we go with the statement in the main text). Adopt the NTK parametrization: for every matrix parameter $W\in\mathbb{R}^{n\times n}$ of $f$, we factor $W=\frac{1}{\sqrt{n}}w$, where $w$ is the trainable parameter; likewise, for each input layer matrix $U^{i}\in\mathbb{R}^{n\times d}$ we factor $U^{i}=\frac{1}{\sqrt{d}}u^{i}$, and for the output matrix $V=\frac{1}{\sqrt{n}}v$. We randomly initialize all trainable parameters iid as $\mathcal{N}(0,1)$. Furthermore, we assume the following:

A1. The input and output layers $\{u^{i}\},v$, as well as the biases, are not trained (only the $\mathbb{R}^{n\times n}$ weight matrices are trained).

A2. The forward pass does not use both a matrix and its transpose (in different MatMuls).

A3. The last-layer embedding is a G-var.

Our main result shows that SGD training of such a neural network, as described in 5.2, reduces to kernel gradient descent with kernel $\mathring{\mathcal{K}}$ in the infinite-width limit.

Consider training a network $f$ described in D.1 via SGD with batch size 1 and (WLOG) learning rate 1. Let $\xi_{t}$ be the input and $\mathcal{L}_{t}:\mathbb{R}\to\mathbb{R}$ the loss function (absorbing the label) at time $t$. Suppose the derivative $\mathcal{L}_{t}^{\prime}$ is continuous for all $t$. Then, for any $\xi$ and $t$, $f_{t}(\xi)$ converges almost surely to a random variable $\mathring{f}_{t}(\xi)$ as width $\to\infty$, such that
$$\mathring{f}_{t+1}(\xi)-\mathring{f}_{t}(\xi)=-\mathring{\mathcal{K}}(\xi,\xi_{t})\,\mathcal{L}_{t}^{\prime}(\mathring{f}_{t}(\xi_{t})),\tag{68}$$
where $\mathring{\mathcal{K}}$ is the infinite-width NTK (at initialization) of the neural network.
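On a finite set of inputs, Eq. 68 is an explicit recursion. The sketch below iterates it for a user-supplied kernel matrix and given initial outputs (the random $\mathring{f}_{0}$ is treated as fixed numbers), with learning rate 1 as in the theorem. It only illustrates the limiting dynamics; it does not compute $\mathring{\mathcal{K}}$ or simulate a network. The usage example assumes the squared loss $\mathcal{L}_{t}(f)=\frac{1}{2}(f-y_{t})^{2}$, so that $\mathcal{L}_{t}^{\prime}(f)=f-y_{t}$; the kernel entries and labels are made up.

```python
from typing import Callable, List, Sequence


def kernel_gd(K: Sequence[Sequence[float]], f0: Sequence[float],
              stream: Sequence[int],
              loss_grad: Callable[[int, float], float]) -> List[List[float]]:
    """Iterate Eq. 68 on a finite input set, with learning rate 1.

    K[i][j]    : limiting kernel between inputs i and j (assumed given)
    f0         : initial outputs, treated as fixed numbers
    stream[t]  : index of the training input xi_t used at step t
    loss_grad  : loss_grad(t, value) returns L_t'(value)
    Returns the list of output vectors f_0, f_1, ..., f_T.
    """
    f = list(f0)
    history = [list(f)]
    for t, idx in enumerate(stream):
        chi = loss_grad(t, f[idx])                       # chi_t = L_t'(f_t(xi_t))
        f = [f_i - K[i][idx] * chi for i, f_i in enumerate(f)]
        history.append(list(f))
    return history


# Usage with squared loss L_t(f) = (f - label)^2 / 2, so L_t'(f) = f - label.
labels = [1.0, -1.0]
stream = [0, 1]
traj = kernel_gd([[1.0, 0.5], [0.5, 1.0]], [0.0, 0.0], stream,
                 lambda t, value: value - labels[stream[t]])
print(traj[-1])  # [0.25, -1.0]
```

Each step moves every output in proportion to its kernel similarity with the current training input, which is exactly the "linear model" behaviour described in the introduction.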

Netsor⊤+ program. SGD consists of a sequence of forward and backward passes computed on some architecture. WLOG, let $\pi_{0}$ denote the reduced program implementing the body of the network $f$, and let $x(\xi)$ denote the final embedding, so that $f(\xi)=V^{\top}x(\xi)$. We now show how the SGD procedure on $\pi_{0}$ can be implemented by a Netsor⊤+ program.

While $\pi_{0}$ implements the embeddings $x(\xi)$ by definition, the outputs $f(\xi)$ cannot be implemented trivially in a program: at initialization, $f(\xi)=\frac{v^{\top}x(\xi)}{\sqrt{n}}$ is not deterministic and converges non-trivially to a GP, violating the requirement that every scalar in a Netsor⊤+ program converge to a deterministic limit as $n\to\infty$. Nevertheless, we can still easily express the evolution of $f$ conditioned on (i.e. fixing) the values of $f$ at initialization. More formally, let $\mathtt{f}=[f(\xi_{0}),f(\xi_{1}),\ldots,f(\xi_{D-1})]^{\top}\in\mathbb{R}^{D}$ denote a fixed vector of outputs, and let $X=[x(\xi_{0}),x(\xi_{1}),\ldots,x(\xi_{D-1})]\in\mathbb{R}^{n\times D}$ (with columns $x(\xi_{i})$) denote a fixed embedding matrix such that $\mathtt{f}=\frac{X^{\top}v}{\sqrt{n}}$. The distribution of $v$ conditioned on $\mathtt{f}$ and $X$ is given by (see e.g. (Yang, 2020b, Sec. K.2))

$$v\overset{\mathrm{d}}{=}_{\mathtt{f},X}\sqrt{n}X^{+}\mathtt{f}+\Pi\mathtt{v},\tag{69}$$
where $X^{+}\in\mathbb{R}^{n\times D}$ is the pseudo-inverse of $X^{\top}$, $\mathtt{v}$ is an independent copy of $v$, and $\Pi$ is the projection onto the orthogonal complement of the column space of $X$. Namely:

$$X^{+}=\frac{1}{n}X\Big(\frac{X^{\top}X}{n}\Big)^{+},\qquad\Pi=I-X^{+}X^{\top}.\tag{70}$$
Denote $\Sigma=\frac{X^{\top}X}{n}\in\mathbb{R}^{D\times D}$ and $\mu=\frac{X^{\top}\mathtt{v}}{n}\in\mathbb{R}^{D}$. Define

$$\hat{v}\overset{\text{def}}{=}X\Big(\frac{\Sigma^{+}\mathtt{f}}{\sqrt{n}}\Big)+\mathtt{v}-X\Sigma^{+}\mu.\tag{71}$$
Then we see via Eq. 69 that
$$\hat{v}\overset{\mathrm{d}}{=}_{\mathtt{f},X}v.\tag{72}$$

Given $\mathtt{v}$ and (the columns of) $X$ as vectors and $\mathtt{f}$ as scalars in a program, $\hat{v}$ may be defined in the same program via Nonlin+, where $\frac{\Sigma^{+}\mathtt{f}}{\sqrt{n}}$ and $\Sigma^{+}\mu$ (both finite-dimensional) provide the coefficients of the linear combination over (the columns of) $X$. Formally, to express the evolution of $f$ conditioned on $f_{0}=\mathtt{f}$ at initialization, the program calculates the first forward pass up to $X$, calculates the loss derivatives $\chi$ assuming $f_{0}=\mathtt{f}$, and then proceeds with the backward pass and later forward/backward passes with $v$ replaced by $\hat{v}$.

However, since $\frac{\mathtt{f}}{\sqrt{n}},\mu\xrightarrow{\mathrm{a.s.}}0$ and $\Sigma^{+}\xrightarrow{\mathrm{a.s.}}\mathring{\Sigma}^{+}$ (by rank stability, cf. (Yang, 2020b, Lemma L.11)), these coefficients of the linear combination converge to $0$, so that $Z^{\hat{v}}=Z^{\mathtt{v}}$. Intuitively, this means that the distribution of $v$ conditioned on the equality $\mathtt{f}=X^{\top}v/\sqrt{n}$ is asymptotically the same as without conditioning as $n\to\infty$. Thus, for the limit calculation of $\delta f_{t}$ and other quantities, it does not matter whether we use $\hat{v}$ or $v$.
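A minimal sketch of Eq. 71 for a single input ($D=1$), standard library only. It uses iid Gaussian stand-ins for the embedding $x$ and for the independent copy $\mathtt{v}$ (assumptions for illustration; in the program $x$ would be the actual embedding) and takes $\mu=x^{\top}\mathtt{v}/n$, which is what makes $X\Sigma^{+}\mu$ equal to the component of $\mathtt{v}$ along $X$. The readout $x^{\top}\hat{v}/\sqrt{n}$ reproduces the prescribed $\mathtt{f}$ up to rounding for every $n$, while both coefficients multiplying $x$ shrink as $n$ grows, which is the mechanism behind $Z^{\hat{v}}=Z^{\mathtt{v}}$.

```python
import math
import random


def conditioned_readout(n: int, f_target: float, seed: int = 0):
    """Build v_hat from Eq. 71 for D = 1, using Gaussian stand-in vectors.

    x  : stand-in for the single embedding x(xi_0)          (assumption)
    vv : independent copy of the readout weights ("v" in typewriter font)
    Returns (x, v_hat, coef_f, coef_mu) with
        v_hat = x * coef_f + vv - x * coef_mu,
        coef_f = Sigma^+ f / sqrt(n),  coef_mu = Sigma^+ mu,
        Sigma = x.x / n,               mu = x.vv / n.
    """
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    vv = [rng.gauss(0.0, 1.0) for _ in range(n)]
    sigma = sum(a * a for a in x) / n
    mu = sum(a * b for a, b in zip(x, vv)) / n
    sigma_pinv = 1.0 / sigma if sigma > 0.0 else 0.0   # 1 x 1 pseudoinverse
    coef_f = sigma_pinv * f_target / math.sqrt(n)
    coef_mu = sigma_pinv * mu
    v_hat = [a * coef_f + b - a * coef_mu for a, b in zip(x, vv)]
    return x, v_hat, coef_f, coef_mu


for n in (100, 10_000):
    x, v_hat, coef_f, coef_mu = conditioned_readout(n, f_target=0.7)
    readout = sum(a * b for a, b in zip(x, v_hat)) / math.sqrt(n)  # x^T v_hat / sqrt(n)
    print(n, readout, coef_f, coef_mu)
```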

The loss derivative $\chi(\xi)=\frac{\partial\mathcal{L}(f(\xi))}{\partial f(\xi)}$ after the first forward pass, given $f(\xi)$, can be implemented with Moment instructions using $\psi(;f(\xi))=\mathcal{L}^{\prime}(f(\xi))$.

D.1.2 Implementing SGD

Under SGD, the update at step $t+1$ to any weight $w\in\mathbb{R}^{n\times n}$ is given by:

$$w_{t+1}-w_{t}=-\chi_{t}\sum_{\mathtt{g},\mathtt{h}:\mathtt{g}=W\mathtt{h}}\frac{d\mathtt{g}_{t}\mathtt{h}_{t}^{\top}}{n},\tag{73}$$
where the summation in Eq. 73 is over all pairs of vectors $\mathtt{g},\mathtt{h}$ in program $\pi_{0}$ satisfying $\mathtt{g}=W\mathtt{h}$ (there can be multiple such pairs since $\pi_{0}$ may reuse the same matrix $W$).

To write the full unrolled SGD as a Netsor⊤+ program, we need to implement the error signal $dg_{t}\overset{\text{def}}{=}\sqrt{n}\frac{\partial f_{t}}{\partial g_{t}}$ for each G-var $g$ at time $t$. To accomplish this, we recall the notion of paths in program $\pi_{0}$ (as defined in the main text). Note that a path $p$ represents a series of nodes independent of any input, and can be instantiated as $p(\xi)$ by an input $\xi$, resulting in a series of instantiated G-vars and X-vars $p^{i}(\xi)$.

For any G-var $g=Wh$, we can write the error term $dg$ as the summation of error signals over paths $p$:

$$dg=\sum_{p:\,p^{1}=g}J^{p},\tag{74}$$
$$J^{p}\overset{\text{def}}{=}\Big(\frac{\partial p^{2}}{\partial p^{1}}\Big)^{\top}\Big(\frac{\partial p^{3}}{\partial p^{2}}\Big)^{\top}\cdots\Big(\frac{\partial p^{-1}}{\partial p^{-2}}\Big)^{\top}v.\tag{75}$$
(Here again, $J^{p}$ represents a symbolic computation that can be instantiated with an input as $J^{p}(\xi)$.) Note that $J^{p}$ can be defined recursively:

$$J^{p}=\Big(\frac{\partial p^{2}}{\partial p^{1}}\Big)^{\top}\Big(\frac{\partial p^{3}}{\partial p^{2}}\Big)^{\top}J^{p:3},\tag{76}$$
where $J^{p:k}$, $k\leq|p|$, is defined as:

$$J^{p:k}\overset{\text{def}}{=}\begin{cases}\big(\frac{\partial p^{k+1}}{\partial p^{k}}\big)^{\top}\cdots\big(\frac{\partial p^{-1}}{\partial p^{-2}}\big)^{\top}v & k<|p|\\ v & k=|p|\end{cases}\tag{77}$$
Recall that each path $p$ starts with an X-var $p^{0}$ and alternates between G-vars and X-vars. Let $W^{p^{3}}$ denote the defining weight matrix of the G-var $p^{3}$ (i.e. $p^{3}=W^{p^{3}}p^{2}$), and let $p^{2}=\psi(\ldots,p^{1},\ldots)$. Then we can rewrite Eq. 76 as:
$$J^{p}=\psi^{\prime}(\ldots,p^{1},\ldots)\odot\big((W^{p^{3}})^{\top}J^{p:3}\big),\tag{78}$$
where $\psi^{\prime}$ denotes the derivative of $\psi$ with respect to $p^{1}$.

Note that Eq. 78 can be written in the Netsor⊤ language using MatMul instructions with the transposed weights and Nonlin instructions using $\psi^{\prime}$, which is pseudo-Lipschitz by D.1.

Recall that $\pi_{0}$ is the program defining the network architecture. We now write the unrolled SGD of this network in a new program $\pi$. Below, recall that the lack of a time subscript means $t=0$ (e.g. $W$ means $W_{0}$, the initialized value). In addition, feel free to revisit the notation explained before Appendix A.

$$\delta\tilde{g}_{t+1}=\sqrt{n}(W_{t+1}\tilde{h}_{t+1}-W_{t}\tilde{h}_{t})=W\delta\tilde{h}_{t+1}+\sqrt{n}(W_{t+1}-W_{t})\tilde{h}_{t}+\sum_{s=0}^{t}(W_{s+1}-W_{s})\delta\tilde{h}_{t+1},\tag{79}$$
where, using Eq. 73, we have

$$\sqrt{n}(W_{t+1}-W_{t})\tilde{h}_{t}=-\chi_{t}\sum_{\mathtt{g},\mathtt{h}:\mathtt{g}=W\mathtt{h}}d\mathtt{g}_{t}\frac{\mathtt{h}_{t}^{\top}\tilde{h}_{t}}{n},\tag{80}$$
$$\sum_{s=0}^{t}(W_{s+1}-W_{s})\delta\tilde{h}_{t+1}=-\sum_{s=0}^{t}\frac{\chi_{s}}{\sqrt{n}}\sum_{\mathtt{g},\mathtt{h}:\mathtt{g}=W\mathtt{h}}d\mathtt{g}_{s}\frac{\mathtt{h}_{s}^{\top}\delta\tilde{h}_{t+1}}{n}.\tag{81}$$
Tensor Program implementation. Eqs. 79, 80 and 81 may be easily implemented using Netsor⊤+ instructions. For instance, Eq. 80 (assuming the sum runs over a single pair $\{\mathtt{h},\mathtt{g}\}$) may be implemented using Moment and Nonlin+ instructions as follows: the term $\frac{\mathtt{h}_{t}^{\top}\tilde{h}_{t}}{n}=\frac{1}{n}\sum_{\alpha}(\tilde{h}_{t})_{\alpha}(\mathtt{h}_{t})_{\alpha}$ may be implemented by a Moment instruction with $\psi(a,b)=ab$ applied to $(\tilde{h}_{t},\mathtt{h}_{t})$. The full term is then a Nonlin+ instruction $\psi(d\mathtt{g}_{t};\chi_{t},\{\frac{\mathtt{h}_{t}^{\top}\tilde{h}_{t}}{n}\}_{\mathtt{h}})$ with scalars $\chi_{t},\{\frac{\mathtt{h}_{t}^{\top}\tilde{h}_{t}}{n}\}_{\mathtt{h}}$ and vector $d\mathtt{g}_{t}$.
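The factorization in Eq. 80 is associativity of a rank-one update, and it is what lets a Moment (the average $\mathtt{h}_{t}^{\top}\tilde{h}_{t}/n$) followed by a Nonlin+ stand in for an $n\times n$ matrix. The sketch below checks this for a single pair $(\mathtt{g},\mathtt{h})$ using random stand-in vectors for $d\mathtt{g}_{t}$, $\mathtt{h}_{t}$ and $\tilde{h}_{t}$ (an assumption; $\chi_{t}=0.3$ and $n=50$ are arbitrary). It forms $W_{t+1}-W_{t}=(w_{t+1}-w_{t})/\sqrt{n}$ explicitly from Eq. 73 and compares both routes.

```python
import math
import random


def check_eq80(n: int = 50, chi: float = 0.3, seed: int = 0) -> float:
    """Max |direct - factored| for sqrt(n) (W_{t+1} - W_t) h_tilde, one pair (g, h).

    Eq. 73 (learning rate 1): w_{t+1} - w_t = -chi * dg h^T / n, and W = w / sqrt(n).
    dg, h, h_tilde are random stand-ins for d g_t, h_t and h~_t.
    """
    rng = random.Random(seed)
    dg = [rng.gauss(0.0, 1.0) for _ in range(n)]
    h = [rng.gauss(0.0, 1.0) for _ in range(n)]
    h_tilde = [rng.gauss(0.0, 1.0) for _ in range(n)]

    # Direct route: materialize the n x n update of W, then multiply by h_tilde.
    delta_W = [[-chi * dg[i] * h[j] / n / math.sqrt(n) for j in range(n)]
               for i in range(n)]
    direct = [math.sqrt(n) * sum(row[j] * h_tilde[j] for j in range(n))
              for row in delta_W]

    # Eq. 80 route: one Moment (an average over coordinates), then a Nonlin+ in dg.
    inner = sum(a * b for a, b in zip(h, h_tilde)) / n   # h_t^T h~_t / n
    factored = [-chi * dg_i * inner for dg_i in dg]

    return max(abs(a - b) for a, b in zip(direct, factored))


print(check_eq80())
```

The discrepancy is pure floating-point rounding; the factored route never stores an $n\times n$ object, which is why it fits the Netsor⊤+ instruction set.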

Eq. 82 may be implemented as a Nonlin+ instruction:

$$\delta\tilde{g}_{t+1}:=\psi^{\star}\Big(\{\tilde{h}_{t}^{i}\}_{i=1}^{k}\cup\{\delta\tilde{h}_{t+1}^{i}\}_{i=1}^{k};\frac{1}{\sqrt{n}}\Big)\tag{83}$$
for the vectors $\{\tilde{h}_{t}^{i}\}_{i=1}^{k},\{\delta\tilde{h}_{t+1}^{i}\}_{i=1}^{k}$ and the scalar $\frac{1}{\sqrt{n}}$, where:

$$\psi^{\star}(\{\mu^{i}\}_{i=1}^{k}\cup\{\nu^{i}\}_{i=1}^{k};\theta)_{\alpha}\overset{\text{def}}{=}\begin{cases}\dfrac{\psi(\mu^{1}+\theta\nu^{1},\ldots,\mu^{k}+\theta\nu^{k})_{\alpha}-\psi(\mu^{1},\ldots,\mu^{k})_{\alpha}}{\theta} & \theta>0\\[4pt] \sum_{i=1}^{k}\dfrac{\partial\psi(\mu^{1},\ldots,\mu^{k})_{\alpha}}{\partial\mu^{i}_{\alpha}}\nu^{i}_{\alpha} & \theta=0.\end{cases}\tag{84}$$
Since $\psi^{\prime}$ is pseudo-Lipschitz by D.1, $\psi^{\star}$ is pseudo-Lipschitz in all of its inputs as well.
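The point of Eq. 84 is that the finite-difference branch at $\theta=1/\sqrt{n}$ tends to the derivative branch at $\theta=0$, so $\psi^{\star}$ is continuous in its scalar argument. A small sketch for one coordinate $\alpha$ (since $\psi$ acts coordinatewise), using a hypothetical $k=2$ nonlinearity $\psi(a,b)=\sin(a)\,b$ chosen only for illustration; the input values are arbitrary.

```python
import math
from typing import Callable, Sequence


def psi_star(psi: Callable[[Sequence[float]], float],
             dpsi: Callable[[Sequence[float]], Sequence[float]],
             mu: Sequence[float], nu: Sequence[float], theta: float) -> float:
    """One coordinate alpha of Eq. 84 (psi acts coordinatewise).

    mu[i], nu[i] stand for mu^i_alpha, nu^i_alpha; dpsi(mu) returns the
    partial derivatives of psi with respect to each argument at mu.
    """
    if theta > 0.0:
        shifted = [m + theta * v for m, v in zip(mu, nu)]
        return (psi(shifted) - psi(mu)) / theta
    return sum(d * v for d, v in zip(dpsi(mu), nu))


# Hypothetical k = 2 nonlinearity, for illustration only: psi(a, b) = sin(a) * b.
def psi(z: Sequence[float]) -> float:
    return math.sin(z[0]) * z[1]


def dpsi(z: Sequence[float]) -> Sequence[float]:
    return [math.cos(z[0]) * z[1], math.sin(z[0])]


mu, nu = [0.4, -1.2], [0.7, 0.3]
for n in (10**2, 10**4, 10**8):
    print(n, psi_star(psi, dpsi, mu, nu, 1.0 / math.sqrt(n)))  # theta = 1/sqrt(n)
print("theta = 0:", psi_star(psi, dpsi, mu, nu, 0.0))
```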

The scalar outputs $f_{t}(\xi)$ at $t>0$ for any input $\xi$ can be implemented using the Moment instruction. The loss derivatives $\chi_{t}$, $t>0$, given $f_{t}$ can be implemented with Moment instructions using $\psi(;f(\xi))=\mathcal{L}^{\prime}(f(\xi))$, where $f(\xi)$ is treated as a scalar as in the first forward pass.

$\tilde{g}_{t+1}$ is implemented using a Nonlin+ instruction $g_{t+1}(\xi)=\psi\big(g(\xi)\cup\{\delta g_{s}(\xi)\}_{s=1}^{t+1};\frac{1}{\sqrt{n}}\big)$ with $\psi(\mu\cup\{\nu_{s}\}_{s=1}^{t+1};\theta)=\mu+\theta\sum_{s}\nu_{s}$.

$dg_{t+1}(\xi)$ is implemented using Moment and Nonlin+ instructions.

According to the Netsor⊤+ rules as specified in B.3, we have the following identities:

• If $g=Wh$, then using Eqs. 79, 80 and 81 (here $Z^{d\mathtt{g}_{t}}=Z^{d\mathtt{g}_{t}(\xi_{t})}$, $Z^{\mathtt{h}_{t}}=Z^{\mathtt{h}_{t}(\xi_{t})}$, and $Z^{\tilde{h}}=Z^{h_{t}(\tilde{\xi})}$):

$$Z^{\delta\tilde{g}_{t+1}}=Z^{W\delta\tilde{h}_{t+1}}-\chi_{t}\sum_{\mathtt{g},\mathtt{h}:\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\,\mathbb{E}\big[Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big],\tag{91}$$
$$\text{where}\quad Z^{W\delta\tilde{h}_{t+1}}=\hat{Z}^{W\delta\tilde{h}_{t+1}}+\sum_{y}Z^{y}\,\mathbb{E}\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}.\tag{92}$$

• If $g=\psi(h^{1},\ldots,h^{k})$, then using Eqs. 82 and 84 and taking the limit $1/\sqrt{n}\to0$,

$$Z^{\delta\tilde{g}_{t+1}}=\sum_{i=1}^{k}\frac{\partial\psi(Z^{\tilde{h}^{1}},\ldots,Z^{\tilde{h}^{k}})}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}.\tag{93}$$

• Using Eqs. 86, 87, 89 and 90 and taking $\frac{1}{\sqrt{n}}\to0$, we have by ZNonlin+:

D.2 Deriving The NTK

Instantiate paths $p$ and $q$ on two inputs $\xi,\xi^{\prime}$ by $p=p(\xi)$, $q=q(\xi^{\prime})$ (abusing notation slightly). We define an inner product between them as follows:

$$\langle p,q\rangle\overset{\text{def}}{=}\mathbb{E}\big[Z^{p^{0}}Z^{q^{0}}\big]\prod_{i=2,\,i\text{ even}}^{|p|-2}\mathbb{E}\Big[\frac{\partial Z^{p^{i}}}{\partial Z^{p^{i-1}}}\frac{\partial Z^{q^{i}}}{\partial Z^{q^{i-1}}}\Big],\tag{95}$$
where $p^{i},q^{i}$ are X-vars for all even $i$. Note that for even $i$, $p^{i}$ is always of the form $p^{i}=\psi(\ldots,p^{i-1},\ldots)$ for some $\psi$. So the partial derivatives in Eq. 95 are just $\frac{\partial Z^{p^{i}}}{\partial Z^{p^{i-1}}}=\psi^{\prime}(\ldots,Z^{p^{i-1}},\ldots)$, where $\psi^{\prime}$ denotes the derivative with respect to the argument $p^{i-1}$.

For each weight $W\in\mathcal{W}$, the gradient of the output with respect to $w$ is given by:

Here, $g,h$ represent nodes in program $\pi_{0}$ that can be instantiated by an input $\xi$. The NTK of $f$ can be expressed as:

Using Eqs. 74 and 78, for any G-var $g=Wh$, we can write the error term $dg$ as the summation of error signals over paths $p$:

$$Z^{J^{p}}=Z^{(W^{p^{3}})^{\top}J^{p:3}}\,\frac{\partial Z^{p^{2}}}{\partial Z^{p^{1}}},\tag{103}$$
where $\psi^{\prime}$ denotes the derivative with respect to $p^{1}$. By the Simple GIA Check (Yang, 2020a), we have $Z^{(W^{p^{3}})^{\top}J^{p:3}}=\hat{Z}^{(W^{p^{3}})^{\top}J^{p:3}}$ (see ZMatMul). Hence, with the abuse of notation $J^{p}=J^{p}(\xi)$, $J^{q}=J^{q}(\tilde{\xi})$, $p=p(\xi)$, $q=q(\tilde{\xi})$, we have

$$\mathbb{E}\big[Z^{J^{p}}Z^{J^{q}}\big]=\mathbb{E}\big[\hat{Z}^{(W^{p^{3}})^{\top}J^{p:3}}\hat{Z}^{(W^{q^{3}})^{\top}J^{q:3}}\big]\,\mathbb{E}\Big[\frac{\partial Z^{p^{2}}}{\partial Z^{p^{1}}}\frac{\partial Z^{q^{2}}}{\partial Z^{q^{1}}}\Big].\tag{104}$$
From the definition of ZHat, the expectation $\mathbb{E}[\hat{Z}^{(W^{p^{3}})^{\top}J^{p:3}}\hat{Z}^{(W^{q^{3}})^{\top}J^{q:3}}]$ vanishes unless the weights $W^{p^{3}}$ and $W^{q^{3}}$ are symbolically the same (i.e. $W^{p^{3}}\triangleq W^{q^{3}}$). Then by ZHat,

$$\mathbb{E}\big[Z^{J^{p}}Z^{J^{q}}\big]=\begin{cases}\prod_{i=2,\,i\text{ even}}^{|p|-2}\mathbb{E}\Big[\frac{\partial Z^{p^{i}}}{\partial Z^{p^{i-1}}}\frac{\partial Z^{q^{i}}}{\partial Z^{q^{i-1}}}\Big] & \text{if }p\cong q\\ 0 & \text{otherwise.}\end{cases}\tag{106}$$
Combining with Eqs. 100, 74 and 102 proves D.3.
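To make the path sum concrete, the sketch below evaluates $\sum_{p}\sum_{q\cong p}\langle p,q\rangle$ (the form used in Eq. 122) for a small assumed architecture: $x=W_{2}\,\phi(W_{1}\,\phi(U\xi))$ with $\phi=\mathrm{ReLU}$, $U=u/\sqrt{d}$, only $W_{1},W_{2}$ trained (A1), and $x$ a G-var (A3). With $h^{0}=\phi(U\xi)$, $g^{1}=W_{1}h^{0}$ and $h^{1}=\phi(g^{1})$, the two path classes are $(h^{0},g^{1},h^{1},x)$ and $(h^{1},x)$, so Eq. 95 gives $\mathring{\mathcal{K}}=\mathbb{E}[Z^{h^{0}}Z^{h^{0\prime}}]\,\mathbb{E}[\phi^{\prime}(Z^{g^{1}})\phi^{\prime}(Z^{g^{1\prime}})]+\mathbb{E}[Z^{h^{1}}Z^{h^{1\prime}}]$. ReLU is chosen only because these Gaussian expectations have closed (arc-cosine) forms; its derivative is a step function, so strictly this leans on the Stein-type definition of B.6, and the code is an illustration rather than part of the proof.

```python
import math
from typing import Sequence, Tuple


def relu_moments(c_ab: float, c_aa: float, c_bb: float) -> Tuple[float, float]:
    """(E[relu(a) relu(b)], E[1{a>0} 1{b>0}]) for a centered Gaussian pair (a, b).

    c_ab, c_aa, c_bb are Cov(a, b), Var(a), Var(b); both variances must be > 0.
    Uses the arc-cosine closed forms.
    """
    rho = max(-1.0, min(1.0, c_ab / math.sqrt(c_aa * c_bb)))  # clamp rounding
    angle = math.acos(rho)
    both = math.sqrt(c_aa * c_bb) * (math.sin(angle) + (math.pi - angle) * rho) / (2 * math.pi)
    steps = (math.pi - angle) / (2 * math.pi)
    return both, steps


def ntk_two_hidden_relu(xi: Sequence[float], xj: Sequence[float]) -> float:
    """Path-sum NTK for x = W2 relu(W1 relu(U xi)), U = u / sqrt(d); only W1, W2 trained."""
    d = len(xi)

    def cov(a: Sequence[float], b: Sequence[float]) -> float:
        return sum(p * q for p, q in zip(a, b)) / d      # Cov of Z^{U a}, Z^{U b}

    c_ij, c_ii, c_jj = cov(xi, xj), cov(xi, xi), cov(xj, xj)
    h0_ij, _ = relu_moments(c_ij, c_ii, c_jj)            # E[Z^{h0} Z^{h0'}]
    h0_ii, _ = relu_moments(c_ii, c_ii, c_ii)
    h0_jj, _ = relu_moments(c_jj, c_jj, c_jj)
    # Z^{g1} = Z^{W1 h0} has covariance E[Z^{h0} Z^{h0'}] (sigma_W = 1).
    h1_ij, step1_ij = relu_moments(h0_ij, h0_ii, h0_jj)
    path_through_w1 = h0_ij * step1_ij   # path (h0, g1, h1, x)
    path_through_w2 = h1_ij              # path (h1, x)
    return path_through_w1 + path_through_w2


print(ntk_two_hidden_relu([2.0], [2.0]))  # 2.0
print(ntk_two_hidden_relu([1.0, 0.0, 2.0], [0.5, 1.0, -1.0]))
```

For identical inputs every angle is $0$, each ReLU layer halves the second moment and $\mathbb{E}\,\phi^{\prime}(Z)^{2}=\frac{1}{2}$, so with $\xi=[2]$ both paths contribute $1$.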

D.3 Getting Section 3

For the remainder of the proof we abbreviate $p=p(\tilde{\xi})$, $q=q(\xi_{t})$, $p^{i}=p^{i}(\tilde{\xi})$, $q^{i}=q^{i}(\xi_{t})$ (i.e. path $p$ is always evaluated on $\tilde{\xi}$, while path $q$ is always evaluated on $\xi_{t}$). We prove Section 3 by inducting on all G-vars in the network. We begin by proving the following induction hypothesis.

We write $Z^{x}\equiv Z^{y}\mod\hat{Z}^{W\bullet}$ to denote that $Z^{x}-Z^{y}$ is a linear combination of $\hat{Z}^{Wu}$ for various vectors $u$.

At any time $t$ and G-var $g=Wh$, the following holds:

$$Z^{\delta\tilde{g}_{t+1}}\equiv-\chi_{t}\sum_{p:p^{-1}=g}\sum_{q:q\cong p}Z^{dq^{-1}}\langle p,q\rangle\mod\hat{Z}^{W\bullet}\tag{107}$$
Here, the sum is over all paths $p$ with endpoint $g$ and all paths $q$ isomorphic to $p$. Recall that $dq^{-1}$ is the (scaled) gradient $dy$, where $y=q^{-1}$ is the endpoint of $q$.

D.3.1 Base Case

For initial G-vars $g$, $\delta\tilde{g}_{t}=0$ since we are not training the input layers (Assumption A1.). This proves the base case, since the sum in Eq. 107 has no terms and is thus 0.

D.3.2 Inductive case

Suppose $g=Wh$, where $h=\psi(h^{1},\ldots,h^{k})$. Using Eq. 91, we then have:

$$Z^{\delta\tilde{g}_{t+1}}\equiv\dot{Z}^{W\delta\tilde{h}_{t+1}}-\chi_{t}\sum_{\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\,\mathbb{E}\big[Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big]\mod\hat{Z}^{W\bullet},\tag{108}$$
$$\text{where}\quad\dot{Z}^{W\delta\tilde{h}_{t+1}}=\sum_{y}Z^{y}\,\mathbb{E}\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}.\tag{109}$$
Note that $\sum_{\mathtt{g}=W\mathtt{h}}Z^{d\mathtt{g}_{t}}\mathbb{E}\big[Z^{\mathtt{h}_{t}}Z^{\tilde{h}}\big]$ in Eq. 108 can be written as $\sum_{\mathtt{p}:\mathtt{p}^{-1}=g,|\mathtt{p}|=2}\sum_{\mathtt{q}\cong\mathtt{p}}Z^{d\mathtt{q}^{-1}}\langle\mathtt{p},\mathtt{q}\rangle$. Therefore, it suffices to show that

$$\dot{Z}^{W\delta\tilde{h}_{t+1}}=-\chi_{t}\sum_{\mathtt{p}:\mathtt{p}^{-1}=g,|\mathtt{p}|\geq4}\sum_{\mathtt{q}\cong\mathtt{p}}Z^{d\mathtt{q}^{-1}}\langle\mathtt{p},\mathtt{q}\rangle.\tag{110}$$
Showing Eq. 110. By Eq. 82:

$$Z^{\delta\tilde{h}_{t+1}}=\sum_{i=1}^{k}\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}.\tag{111}$$
Since $Z^{\tilde{h}^{1}},\ldots,Z^{\tilde{h}^{k}}$ do not depend on $\hat{Z}^{W^{\top}y}$ for any $y$ (by the assumption that we do not use both a matrix and its transpose in the forward pass), from Eq. 111 we have, for any $y$:

$$\mathbb{E}\Big[\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]=\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big].\tag{112}$$
Applying the induction hypothesis Eq. 107 to each G-var $h^{i}$, we get

$$\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}=-\chi_{t}\sum_{p:p^{-1}=h^{i}}\sum_{q:q\cong p}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\langle p,q\rangle\mod\hat{Z}^{W\bullet}.\tag{113}$$
Plugging this back into $\dot{Z}^{W\delta\tilde{h}_{t+1}}$ (Eq. 109), we get

$$\dot{Z}^{W\delta\tilde{h}_{t+1}}=-\chi_{t}\sum_{y}Z^{y}\sum_{i=1}^{k}\sum_{p:p^{-1}=h^{i}}\sum_{q:q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\Big].\tag{114}$$
Note that any path $p$ with $p^{-1}=h^{i}$ may be extended by the vectors $h,g$ (recall $g=Wh$ and $h=\psi(h^{1},\ldots,h^{k})$). Let $\mathtt{p}$ denote this extension. If $\mathtt{q}$ is a path such that

$$\mathtt{q}\cong\mathtt{p}\quad\text{and}\quad\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}=\frac{\partial Z^{\mathtt{q}^{-2}}}{\partial Z^{\mathtt{q}^{-3}}},\tag{115}$$
then

$$\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]=\langle\mathtt{p},\mathtt{q}\rangle.\tag{116}$$
Our goal now is to show that $q$ in Eq. 114 can be extended appropriately so that Eq. 114 can be rewritten as Eq. 110. This is done by explicitly computing the term $\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}$ in Eq. 114.

Computing $\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}$. Suppose $\{g^{1},\ldots,g^{r}\}$ are all the G-vars in the program $\pi_{0}$ that depend on $q^{-1}$, i.e. for all $1\leq j\leq r$ we have $g^{j}=W^{j}z^{j}$, where $z^{j}=\psi_{j}(\ldots,q^{-1},\ldots)$ and where the $W^{j}$ may be the same or different matrices for different $j$. It follows that:

$$Z^{dq^{-1}}=\sum_{j=1}^{r}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}Z^{(W^{j})^{\top}dg^{j}},\tag{118}$$
$$\text{where}\quad Z^{(W^{j})^{\top}dg^{j}}=\hat{Z}^{(W^{j})^{\top}dg^{j}}+\dot{Z}^{(W^{j})^{\top}dg^{j}}=\hat{Z}^{(W^{j})^{\top}dg^{j}}.\tag{119}$$
Note that $\dot{Z}^{(W^{j})^{\top}dg^{j}}=0$ in Eq. 119 by the gradient independence assumption (GIA), because we pass the Simple GIA Check. This may also be verified directly by computing $\dot{Z}^{(W^{j})^{\top}dg^{j}}=\sum_{y}Z^{y}\,\mathbb{E}\frac{\partial Z^{dg^{j}}}{\partial\hat{Z}^{W^{j}y}}$ and noticing that the expectation vanishes because of how $Z^{dg^{j}}$ depends on $Z^{v}$ (i.e. $Z^{dg^{j}}=Z^{v}Z^{\mu}$ for some vector $\mu$ that does not depend on $v$). Since $y$ does not depend on $v$ and the last layer $v$ is not trained, we have $\mathbb{E}\frac{\partial Z^{dg^{j}}}{\partial\hat{Z}^{W^{j}y}}=\mathbb{E}[Z^{v}]\,\mathbb{E}[\cdots]=0$. Since we assumed that the forward propagation does not contain both $W$ and $W^{\top}$, it follows from differentiating Eq. 118 that

$$\frac{\partial Z^{dq^{-1}}}{\partial\hat{Z}^{W^{\top}y}}=\sum_{j=1}^{r}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}\frac{\partial\hat{Z}^{(W^{j})^{\top}dg^{j}}}{\partial\hat{Z}^{W^{\top}y}}=\sum_{j:W^{j}\triangleq W,\,dg^{j}=y}\frac{\partial Z^{z^{j}}}{\partial Z^{q^{-1}}}.\tag{120}$$
If this sum over $j$ is nonempty, then there is a unique $j$ such that $W^{j}\triangleq W$ and $dg^{j}=y$. In that case, we may extend the path $q$ with $z^{j},g^{j}$ to form a path $\mathtt{q}$ satisfying Eq. 115. Plugging back into Eq. 114, we obtain Eq. 110 as desired.

Hence, we have proven the induction hypothesis.

D.3.3 Proving Section 3 using the induction hypothesis

WLOG, assume $f(\xi)=V^{\top}x(\xi)$ for some G-var $x(\xi)$. Using the induction hypothesis and the Master Theorem (B.4), we have

$$\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)=\mathbb{E}\,Z^{v}Z^{\delta\tilde{x}_{t+1}}=-\chi_{t}\sum_{p:\,p^{-1}=x}\ \sum_{q\cong p}\mathbb{E}\big[Z^{v}Z^{dq^{-1}}\big]\langle p,q\rangle. \tag{121}$$

Note that $Z^{dq^{-1}}=Z^{v}$ for any path $q$ with $q^{-1}=x$, so each expectation in Eq. 121 equals $\mathbb{E}[(Z^{v})^{2}]=1$. Hence, with Eq. 96, we have

$$\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)=-\chi_{t}\sum_{p:\,p^{-1}=x}\ \sum_{q\cong p}\langle p,q\rangle=-\chi_{t}\mathring{\mathcal{K}}(\xi_{t},\tilde{\xi}), \tag{122}$$

as desired.
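Eq. 122 says that, to leading order, one gradient step on $\xi_{t}$ moves the output at $\tilde{\xi}$ by $-\chi_{t}$ times the NTK. The following sketch checks the finite-width analogue of this statement on a hypothetical two-layer ReLU network $f(\xi)=v^{\top}\mathrm{relu}(w\xi/\sqrt{d})/\sqrt{n}$ in which only $w$ is trained and $v$ is fixed, as in the setting above. It assumes squared loss, so that $\chi_{t}$ is taken to be $f(\xi_{t})-y_{t}$, and keeps the learning rate $\eta$ explicit rather than folding it into $\chi_{t}$. The empirical kernel $\langle\nabla_{w}f(\xi_{t}),\nabla_{w}f(\tilde{\xi})\rangle$ computed here converges to $\mathring{\mathcal{K}}$ only as $n\to\infty$.

```python
import numpy as np


def forward(w, v, xi):
    """Output f(xi) = v^T relu(w xi / sqrt(d)) / sqrt(n) of a toy two-layer ReLU net."""
    n, d = w.shape
    pre = w @ xi / np.sqrt(d)
    return float(v @ np.maximum(pre, 0.0) / np.sqrt(n))


def grad_w(w, v, xi):
    """Gradient of f(xi) with respect to the hidden weights w (the readout v is held fixed)."""
    n, d = w.shape
    pre = w @ xi / np.sqrt(d)
    # df/dw[a, b] = v[a] * relu'(pre[a]) * xi[b] / sqrt(n * d)
    return np.outer(v * (pre > 0), xi) / np.sqrt(n * d)


rng = np.random.default_rng(0)
d, n = 5, 4096
w = rng.standard_normal((n, d))
v = rng.standard_normal(n)
xi_t = np.array([1.0, 0.5, -0.3, 0.2, 0.0])
xi_tilde = np.array([0.8, 0.6, -0.1, 0.4, 0.1])
y_t, eta = 1.0, 1e-3

chi_t = forward(w, v, xi_t) - y_t                  # dL/df for the squared loss L = (f - y)^2 / 2
g_t = grad_w(w, v, xi_t)
ntk = float(np.sum(g_t * grad_w(w, v, xi_tilde)))  # empirical K(xi_t, xi_tilde)

w_new = w - eta * chi_t * g_t                      # one gradient-descent step on (xi_t, y_t)
actual = forward(w_new, v, xi_tilde) - forward(w, v, xi_tilde)
predicted = -eta * chi_t * ntk

print(f"actual change: {actual:.6e}")
print(f"-eta*chi_t*K:  {predicted:.6e}")
```

Because ReLU is piecewise linear, the two numbers agree essentially exactly unless some pre-activation changes sign during the step; with a smooth nonlinearity the discrepancy would instead be $O(\eta^{2})$.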

D.4 Relaxing A1., A2., A3. and A4.

We now briefly discuss the case where A1., A2., A3. and A4. are relaxed, as well as the case where $f$ is represented by a Netsor⊤⁺ program. Since the proof of the general case follows roughly the same logic as in D.1, we only discuss the meaningful differences in each case.

Recall that the input and output layers are parametrized by $\{u^{i}\}$ and $v$, which now depend on $t$. The output evolution is now given by

$$\tilde{f}_{t+1}-\tilde{f}_{t}=V_{t+1}^{\top}\tilde{x}_{t+1}-V_{t}^{\top}\tilde{x}_{t}=\frac{v^{\top}\delta\tilde{x}_{t+1}}{n}+\frac{\sqrt{n}\,(v_{t+1}-v_{t})^{\top}\tilde{x}_{t}}{n}+\sum_{s=0}^{t}\frac{(v_{s+1}-v_{s})^{\top}\delta\tilde{x}_{t+1}}{n}. \tag{123}$$

Plugging $v_{t+1}-v_{t}=-\chi_{t}\frac{1}{\sqrt{n}}x_{t}$ into Eq. 123 and taking the limit (using Netsor⊤⁺ rules), we obtain

$$\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)=\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]-\chi_{t}\,\mathbb{E}\big[Z^{x(\xi_{t})}Z^{\tilde{x}}\big]. \tag{124}$$

Evaluating $\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]$ by induction requires altering the path definition so that a path, which still ends with a G-var, may start with an input $\xi$ (that is, a path starts with either an X-var or an input). We reuse the definition of the inner product between paths $p\cong q$ from Eq. 95, except that when both start with inputs, $\xi$ and $\tilde{\xi}$ respectively, the factor $\mathbb{E}\big[Z^{p^{0}}Z^{q^{0}}\big]$ is replaced by $\xi^{\top}\tilde{\xi}$. The remainder of the proof follows the same logic as in D.1. Note that the NTK in this case would yield:

D.4.2 $W,W^{\top}$ in the forward pass

When both $W$ and $W^{\top}$ are allowed in the forward pass, the update equations for each $w_{t}$ take the form

$$w_{t+1}-w_{t}=-\chi_{t}\sum_{g,h:\,g=Wh}\frac{dg_{t}\,h_{t}^{\top}}{n}-\chi_{t}\sum_{g,h:\,g=W^{\top}h}\frac{h_{t}\,dg_{t}^{\top}}{n}. \tag{126}$$

Some quick calculations using Netsor⊤⁺ rules show that for G-vars:

It is straightforward to show using GIA (Yang, 2020a) that $\mathbb{E}\big[Z^{d\mathtt{g}(\xi_{t})}Z^{\tilde{h}}\big]=0$ in both cases, leaving us with an expression similar to that in D.1. The induction hypothesis for G-vars in this case takes one of two forms:

If $g=W^{\top}h$, then Eq. 107 holds with $\mod\hat{Z}^{W^{\top}\bullet}$ replacing $\mod\hat{Z}^{W\bullet}$.
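Eq. 126 collects one outer product $dg\,h^{\top}$ for every use $g=Wh$ and one transposed outer product $h\,dg^{\top}$ for every use $g=W^{\top}h$. The sketch below checks this bookkeeping against central finite differences on a hypothetical toy scalar function that uses a square matrix $W$ once as $W$ and once as $W^{\top}$. It drops the $-\chi_{t}/n$ factor and the time index, and the names `h1`, `h2`, `a`, `b` are illustrative only.

```python
import numpy as np


def f(W, h1, h2, a, b):
    """Toy scalar output that uses W twice: g1 = W h1 and g2 = W^T h2."""
    return float(a @ np.tanh(W @ h1) + b @ np.tanh(W.T @ h2))


def grad_f(W, h1, h2, a, b):
    """Analytic gradient: dg1 h1^T from the use g1 = W h1, plus h2 dg2^T from g2 = W^T h2."""
    dg1 = a * (1.0 - np.tanh(W @ h1) ** 2)    # error signal at g1
    dg2 = b * (1.0 - np.tanh(W.T @ h2) ** 2)  # error signal at g2
    return np.outer(dg1, h1) + np.outer(h2, dg2)


def finite_diff_grad(W, h1, h2, a, b, eps=1e-6):
    """Central finite-difference gradient of f with respect to every entry of W."""
    G = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            E = np.zeros_like(W)
            E[i, j] = eps
            G[i, j] = (f(W + E, h1, h2, a, b) - f(W - E, h1, h2, a, b)) / (2 * eps)
    return G


rng = np.random.default_rng(0)
n = 4
W = rng.standard_normal((n, n)) / np.sqrt(n)
h1, h2, a, b = (rng.standard_normal(n) for _ in range(4))
max_err = float(np.max(np.abs(grad_f(W, h1, h2, a, b) - finite_diff_grad(W, h1, h2, a, b))))
print(f"max |analytic - finite difference| = {max_err:.2e}")
```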

Some additional complications need to be resolved. Specifically, in the setup of D.1, our proof of the induction hypothesis used in two places the fact that no transpose appears in the forward pass (see Eqs. 112 and 120). To prove the induction step, assuming $g=Wh$, we now have, instead of Eq. 112 (using Eq. 111),

$$\mathbb{E}\frac{\partial Z^{\delta\tilde{h}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}=\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{\delta\tilde{h}^{i}_{t+1}}}{\partial\hat{Z}^{W^{\top}y}}\Big]+\sum_{i=1}^{k}\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\,\partial\hat{Z}^{W^{\top}y}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big], \tag{129}$$

where $\{h^{i}\}$ are G-vars. To evaluate the additional term on the RHS of Eq. 129, we use the induction hypothesis to express $Z^{\delta\tilde{h}^{i}_{t+1}}$:

$$\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\,\partial\hat{Z}^{W^{\top}y}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big]=-\chi_{t}\sum_{p:\,p^{-1}=h^{i}}\ \sum_{q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial^{2}Z^{\tilde{h}}}{\partial Z^{\tilde{h}^{i}}\,\partial\hat{Z}^{W^{\top}y}}Z^{dq^{-1}}\Big]. \tag{130}$$

Using GIA (Yang, 2020a), it is straightforward to show that the expectation on the RHS of Eq. 130 vanishes, leaving only the first term on the RHS of Eq. 129, as in D.1. The same argument applies to Eq. 120, which concludes the proof.

D.4.3 Multiple outputs and arbitrary batch size

Throughout this paper we have used a scalar output and a batch size of 1. Extending to multiple (finite) outputs and an arbitrary batch size requires no new arguments, only additional notation. For example, the definition of a path must be altered to express dependence on multiple samples (e.g., when batchnorm is used). Otherwise, the proof follows the same logic as in D.1.
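With a batch $\xi_{1},\ldots,\xi_{B}$ and $m$ outputs, the scalar kernel becomes a $(Bm)\times(Bm)$ Gram matrix of output gradients. A minimal sketch, assuming a hypothetical toy two-layer ReLU network with an $n\times m$ readout $V$ that is held fixed, and no batch-dependent layers such as batchnorm, computes this matrix as $JJ^{\top}$ from the parameter Jacobian $J$. For this toy network with independent readout columns, the cross-output blocks are expected to shrink relative to the same-output blocks as $n$ grows.

```python
import numpy as np


def jacobian_rows(w, V, xi):
    """Rows d f_k(xi) / d w (flattened), one per output k, for f(xi) = V^T relu(w xi / sqrt(d)) / sqrt(n)."""
    n, d = w.shape
    act = (w @ xi / np.sqrt(d) > 0).astype(float)  # relu'(pre-activation)
    rows = [np.outer(V[:, k] * act, xi).ravel() for k in range(V.shape[1])]
    return np.stack(rows) / np.sqrt(n * d)


def empirical_ntk(w, V, batch):
    """(B*m) x (B*m) Gram matrix J J^T; entry (i*m + k, j*m + l) = <grad f_k(xi_i), grad f_l(xi_j)>."""
    J = np.concatenate([jacobian_rows(w, V, xi) for xi in batch])
    return J @ J.T


rng = np.random.default_rng(0)
d, n, m, B = 5, 4096, 2, 3
w = rng.standard_normal((n, d))
V = rng.standard_normal((n, m))       # readout held fixed in this sketch
batch = rng.standard_normal((B, d))

K = empirical_ntk(w, V, batch)
K4 = K.reshape(B, m, B, m)            # indexed as [sample i, output k, sample j, output l]
same_output = K4[:, 0, :, 0]
cross_output = K4[:, 0, :, 1]
print(K.shape)                        # (6, 6)
print(np.max(np.abs(cross_output)) / np.max(np.abs(same_output)))
```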

D.4.4 X-var embedding

Our proof assumed that $x$, which represents the final embedding of $f$, is a G-var. Extending the proof to the case where $x$ is an X-var is straightforward. Let $f(\xi)=V^{\top}x(\xi)$, where $x=\psi(h^{1},\ldots,h^{k})$ and $h^{1},\ldots,h^{k}$ are G-vars. Using the induction hypothesis along with Eq. 93 yields

$$\begin{aligned}
\lim_{n\to\infty}\big(\tilde{f}_{t+1}-\tilde{f}_{t}\big)
&=\mathbb{E}\big[Z^{v}Z^{\delta\tilde{x}_{t+1}}\big]=\sum_{i=1}^{k}\mathbb{E}\Big[Z^{v}\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}Z^{\delta\tilde{h}^{i}_{t+1}}\Big] &&\text{(131)}\\
&=-\chi_{t}\sum_{i=1}^{k}\mathbb{E}\Big[Z^{v}\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}\sum_{p:\,p^{-1}=h^{i}}\ \sum_{q\cong p}Z^{dh^{i}(\xi_{t})}\langle p,q\rangle\Big] &&\text{(132)}\\
&=-\chi_{t}\sum_{i=1}^{k}\sum_{p:\,p^{-1}=h^{i}}\ \sum_{q\cong p}\langle p,q\rangle\,\mathbb{E}\Big[\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{x(\xi_{t})}}{\partial Z^{h^{i}(\xi_{t})}}\Big] &&\text{(133)}\\
&=-\chi_{t}\mathring{\mathcal{K}}(\xi_{t},\tilde{\xi}). &&\text{(134)}
\end{aligned}$$

Here Eq. 133 uses $Z^{dh^{i}(\xi_{t})}=Z^{v}\frac{\partial Z^{x(\xi_{t})}}{\partial Z^{h^{i}(\xi_{t})}}$, the independence of $Z^{v}$ from the remaining factors, and $\mathbb{E}[(Z^{v})^{2}]=1$. It is straightforward to show that the expression for $\mathring{\mathcal{K}}(\xi_{t},\tilde{\xi})$ in Eq. 133 is the NTK in this case.
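For an X-var embedding, Eq. 133 weights each pair of paths by $\mathbb{E}\big[\frac{\partial Z^{\tilde{x}}}{\partial Z^{\tilde{h}^{i}}}\frac{\partial Z^{x(\xi_{t})}}{\partial Z^{h^{i}(\xi_{t})}}\big]$. As a hypothetical special case, take $k=1$ and $\psi=\mathrm{relu}$, and assume $Z^{h(\xi_{t})}$ and $Z^{\tilde{h}}$ are jointly Gaussian with unit variances and correlation $\rho$ (the actual covariance depends on the architecture). The factor is then the orthant probability $\Pr[Z^{h(\xi_{t})}>0,\ Z^{\tilde{h}}>0]=(\pi-\arccos\rho)/(2\pi)$, which the sketch below compares with a Monte Carlo estimate.

```python
import numpy as np


def relu_derivative_factor_mc(rho, num_samples=1_000_000, seed=0):
    """Monte Carlo estimate of E[relu'(Z1) relu'(Z2)] for unit-variance Gaussians with correlation rho."""
    rng = np.random.default_rng(seed)
    z1 = rng.standard_normal(num_samples)
    z2 = rho * z1 + np.sqrt(1.0 - rho ** 2) * rng.standard_normal(num_samples)
    return float(np.mean((z1 > 0) & (z2 > 0)))


def relu_derivative_factor_exact(rho):
    """Closed form of the same expectation: P(Z1 > 0, Z2 > 0) = (pi - arccos(rho)) / (2 pi)."""
    return float((np.pi - np.arccos(rho)) / (2.0 * np.pi))


for rho in (-0.5, 0.0, 0.5, 0.9):
    print(rho, relu_derivative_factor_mc(rho), relu_derivative_factor_exact(rho))
```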

D.4.5 Network specified by Netsor⊤⁺

If the network is more generally represented by a Netsor⊤⁺ program instead of just a Netsor⊤ program, then our proof can be modified as follows. The new operation allowed in such a network is the production of a scalar through Moment, say $a=\frac{1}{n}\sum_{\alpha=1}^{n}\psi(x^{1}_{\alpha},\ldots,x^{k}_{\alpha};\theta^{1},\ldots,\theta^{l})$. By an inductive argument similar to the one before, we will see that (1) $x^{i}_{t}=x^{i}_{0}+o(1)$ for all $i\in[k]$ and $\theta^{j}_{t}=\theta^{j}_{0}+o(1)$ for all $j\in[l]$, so that $a_{t}=a_{0}+o(1)$; and (2) in the backward pass, any backpropagation through $a$ vanishes. For example, if $a$ is only used later in a Nonlin $z=\psi(y;a)$, then $\frac{1}{\sqrt{n}}\nabla_{a}f=\langle dz,\partial_{a}\psi(y;a)\rangle/n$ converges to 0 because of GIA (as $dz$ is linear in the final layer), and the error signal at $x^{i}$ times $\sqrt{n}$ is the constant vector with entries $\frac{1}{\sqrt{n}}\nabla_{a}f$, which is $o(1)$.

Therefore, we can treat any scalar produced through Moment as a constant fixed at initialization, and the notion of path from before carries over without change (by treating every nonlinearity with scalar parameters as a parameterless nonlinearity whose parameters are fixed). The same reasoning then applies.
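The two facts used above for Moment scalars can be illustrated numerically: a Moment scalar concentrates at a deterministic value, and the rescaled gradient $\frac{1}{\sqrt{n}}\nabla_{a}f=\langle dz,\partial_{a}\psi(y;a)\rangle/n$ decays like $n^{-1/2}$ when $dz$ is linear in zero-mean readout weights that are independent of $y$. This is a sketch under hypothetical choices: $a=\frac{1}{n}\sum_{\alpha}x_{\alpha}^{2}$ with standard Gaussian $x$, the Nonlin $\psi(y;a)=\tanh(ay)$, and $dz=v$; none of these come from a specific architecture.

```python
import numpy as np


def moment_and_backprop(n, seed=0):
    """Return (a, g) at width n for hypothetical toy choices.

    a = (1/n) sum_alpha x_alpha^2   is a Moment scalar; its limit is E[(Z^x)^2] = 1.
    g = <dz, d/da tanh(a * y)> / n  is the rescaled gradient (1/sqrt(n)) grad_a f,
    with dz = v, where the readout v is zero-mean and independent of y.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    y = rng.standard_normal(n)
    v = rng.standard_normal(n)
    a = float(np.mean(x ** 2))
    d_a_psi = y * (1.0 - np.tanh(a * y) ** 2)  # derivative of tanh(a * y) with respect to a
    dz = v                                      # error signal linear in the final layer
    g = float(dz @ d_a_psi / n)
    return a, g


for n in (10**2, 10**4, 10**6):
    a, g = moment_and_backprop(n)
    print(f"n={n:>8}: a={a:.4f}, (1/sqrt(n)) grad_a f = {g:+.5f}")
```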