Can SGD Learn Recurrent Neural Networks with Provable Generalization?

Zeyuan Allen-Zhu, Yuanzhi Li

Introduction

Recurrent neural networks (RNNs) are among the most popular models in sequential data analysis. When processing an input sequence, RNNs repeatedly and sequentially apply the same operation to each input token. The recurrent structure of RNNs allows them to capture the dependencies among different tokens inside each sequence, which has been shown empirically to be effective in many applications such as natural language processing, speech recognition, and so on.

The recurrent structure in RNNs shows great power in practice; however, it also poses a great challenge in theory. To date, RNNs remain one of the least theoretically understood models in deep learning. Many fundamental open questions are still largely unsolved for RNNs, including

(Optimization). When can RNNs be trained efficiently?

(Generalization). When do the results learned by RNNs generalize to test data?

Question 1 is technically challenging due to the notorious problem of vanishing/exploding gradients, and the non-convexity of the training objective induced by non-linear activation functions.

Question 2 requires an even deeper understanding of RNNs. For example, in natural language processing, “Juventus beats Barcelona” and “Barcelona beats Juventus” have completely different meanings. How can the same operation in an RNN encode a different rule for “Juventus” at token 1 vs. “Juventus” at token 3, instead of merely memorizing each training example?

There has been some recent progress towards obtaining a more principled understanding of these questions.

On the optimization side, Hardt, Ma, and Recht show that over-parameterization can help in the training process of a linear dynamic system, which is a special case of RNNs with linear activation functions. Allen-Zhu, Li, and Song show that over-parameterization also helps in training RNNs with ReLU activations. This latter result gives no generalization guarantee.

Indeed, bridging the gap between optimization (question 1) and generalization (question 2) can be quite challenging in neural networks. The case of RNN is particularly so due to the (potentially) exponential blowup in input length.

Generalization $\nrightarrow$ Optimization. One could imagine adding a strong regularizer to ensure $\beta\leq 1$ for generalization purposes; however, it is unclear how an optimization algorithm such as stochastic gradient descent (SGD) finds a network that both minimizes training loss and maintains $\beta\leq 1$. One could also use a very small network so the number of parameters is limited; however, it is not clear how SGD finds a small network with small training loss.

Optimization $\nrightarrow$ Generalization. One could try to train RNNs without any regularization; however, it is then quite possible that the number of parameters needs to be large and $\beta>1$ after the training. This is so both in practice (since “memory implies larger spectral radius”) and in theory. All known generalization bounds fail to apply in this regime.

In this paper, we give arguably the first theoretical analysis of RNNs that captures optimization and generalization simultaneously. Given any set of input sequences, as long as the outputs are (approximately) realizable by some smooth function in a certain concept class, then after training a vanilla RNN with ReLU activations, SGD provably finds a solution that has both small training and generalization error. Our result allows $\beta$ to be larger than $1$ by a constant, but is still efficient, meaning that the iteration complexity of SGD, the sample complexity, and the time complexity scale only polynomially (or almost polynomially) with the length of the input.

Notations

Note that in the event that $\prod_{j=1}^{i-1}(I-\widehat{v}_{j}\widehat{v}_{j}^{\top})v_{i}$ is the zero vector, we let $\widehat{v}_{i}$ be an arbitrary unit vector that is orthogonal to $\widehat{v}_{1},\dots,\widehat{v}_{i-1}$.
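The degenerate-case convention above can be illustrated with a minimal sketch of Gram-Schmidt orthonormalization. The function name `gram_schmidt`, the tolerance `tol`, and the choice of a standard basis vector as the "arbitrary" fallback direction are assumptions made for illustration only; any unit vector orthogonal to the earlier ones would serve equally well.

```python
import math

def gram_schmidt(vectors, tol=1e-12):
    """Return orthonormal vectors v_hat_1, ..., v_hat_n built from `vectors`.

    If the residual of v_i after projecting out the earlier directions is
    (numerically) zero, fall back to a unit vector orthogonal to the earlier
    ones, as the convention in the text allows.
    """
    dim = len(vectors[0])
    basis = []

    def norm(v):
        return math.sqrt(sum(x * x for x in v))

    def residual(v):
        # Apply prod_j (I - v_hat_j v_hat_j^T) to v, one factor at a time.
        r = list(v)
        for b in basis:
            coeff = sum(x * y for x, y in zip(r, b))
            r = [x - coeff * y for x, y in zip(r, b)]
        return r

    for v in vectors:
        r = residual(v)
        if norm(r) <= tol:
            # Degenerate case: pick the standard basis vector whose residual
            # is largest (an illustrative choice of "arbitrary" direction).
            candidates = []
            for k in range(dim):
                e_k = [1.0 if t == k else 0.0 for t in range(dim)]
                candidates.append(residual(e_k))
            r = max(candidates, key=norm)
            if norm(r) <= tol:
                raise ValueError("no orthogonal direction left")
        n = norm(r)
        basis.append([x / n for x in r])
    return basis

# The second and third inputs lie in the span of earlier directions.
basis = gram_schmidt([[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
```

In this small example every input after the first triggers the fallback, yet the output is still an orthonormal family of the same length as the input.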

where $C^{*}$ is a sufficiently large constant (e.g., $10^{4}$). It holds that $\mathfrak{C}_{\mathfrak{s}}(\phi,R)\leq\mathfrak{C}_{\varepsilon}(\phi,R)\leq\mathfrak{C}_{\mathfrak{s}}(\phi,O(R))\times\mathrm{poly}(1/\varepsilon)$, and for $\sin z$, $e^{z}$ or low degree polynomials, they only differ by $o(1/\varepsilon)$.

Problem Formulation

where $\varepsilon_{x}\in(0,1)$ is a parameter to be chosen later. We then feed this actual sequence $x$ into the RNN.

In this way we have ensured that the actual input sequence is normalized:

We say that $W,A,B$ are at random initialization if the entries of $W$ and $A$ are i.i.d. generated from $\mathcal{N}(0,\frac{2}{m})$, and the entries of $B$ are i.i.d. generated from $\mathcal{N}(0,\frac{1}{d})$.
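As a minimal sketch of this random initialization, the snippet below samples the three matrices with the stated variances. The shapes ($W$ is $m\times m$, $A$ is $m\times d_{\mathrm{in}}$, $B$ is $d\times m$) and the names `random_init`, `d_in`, `d_out`, and `empirical_variance` are assumptions made for illustration; only the Gaussian variances $\frac{2}{m}$ and $\frac{1}{d}$ are taken from the definition above.

```python
import math
import random

def random_init(m, d_in, d_out, seed=0):
    """Sample W, A with i.i.d. N(0, 2/m) entries and B with N(0, 1/d_out) entries."""
    rng = random.Random(seed)

    def gaussian_matrix(rows, cols, variance):
        std = math.sqrt(variance)
        return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]

    W = gaussian_matrix(m, m, 2.0 / m)              # recurrent weights (assumed m x m)
    A = gaussian_matrix(m, d_in, 2.0 / m)           # input weights (assumed m x d_in)
    B = gaussian_matrix(d_out, m, 1.0 / d_out)      # output weights (assumed d x m)
    return W, A, B

def empirical_variance(matrix):
    """Sample variance of all entries of a matrix stored as a list of rows."""
    entries = [v for row in matrix for v in row]
    mean = sum(entries) / len(entries)
    return sum((v - mean) ** 2 for v in entries) / len(entries)

W, A, B = random_init(m=200, d_in=5, d_out=3)
print(empirical_variance(W), 2.0 / 200)  # the two numbers should be close
```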

Since we only update $W$, the label sequence $y^{\star}_{3},\dots,y^{\star}_{L}$ is off from the input sequence $x^{\star}_{2},\dots,x^{\star}_{L-1}$ by one. The last $x_{L}$ can be made zero, but we keep it normalized for notational simplicity. The first $x_{1}$ gives a random seed fed into the RNN (one can equivalently put it into $h_{0}$). We have scaled down the input signals by $\varepsilon_{x}$, which can be equivalently thought of as scaling down $A$.

Concept Class

For proof simplicity, we assume $\Phi_{i\to j,r,s}(0)=0$. We also use

Agnostic PAC-learning language. Our concept class consists of all functions $F^{*}$ in the form of (3.1) with complexity bounded by threshold $C$ and parameter $p$ bounded by threshold $p_{0}$. Let $\mathsf{OPT}$ be the population risk achieved by the best target function in this concept class. Then, our goal is to learn this concept class with population risk $\mathsf{OPT}+\varepsilon$ using sample and time complexity polynomial in $C$, $p_{0}$ and $1/\varepsilon$. In the remainder of this paper, to simplify notations, we do not explicitly define this concept class parameterized by $C$ and $p$. Instead, we equivalently state our theorem with respect to any (unknown) target function $F^{*}$ with specific parameters $C$ and $p$.

Our concept class is general enough and contains functions where the output at each token is generated from inputs of previous tokens using any two-layer neural network. Indeed, one can verify that our general form (3.1) includes functions of the following:

Our Result: RNN Provably Learns the Concept Class

Suppose the distribution $\mathcal{D}$ is generated by some (unknown) target function $F^{*}$ of the form (3.1) in the concept class with population risk $\mathsf{OPT}$, namely,

and suppose we are given a training dataset $\mathcal{Z}$ consisting of $N$ i.i.d. samples from $\mathcal{D}$. We consider the following stochastic training objective

For every $0<\varepsilon<\widetilde{O}\big(\frac{1}{\mathrm{poly}(L,d)\cdot p\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, define complexity $C=\mathfrak{C}_{\varepsilon}(\Phi,\sqrt{L})$ and $\lambda=\widetilde{\Theta}\big(\frac{\varepsilon}{L^{2}d}\big)$. If the number of neurons $m\geq\mathrm{poly}(C,\varepsilon^{-1})$ and the number of samples is $N=|\mathcal{Z}|\geq\mathrm{poly}(C,\varepsilon^{-1},\log m)$, then SGD with $\eta=\widetilde{\Theta}\big(\frac{1}{\varepsilon L^{2}d^{2}m}\big)$ and

satisfies that, with probability at least $1-e^{-\Omega(\rho^{2})}$ over the random initialization

Sample complexity. Our sample complexity only scales with $\log(m)$, making the result applicable to over-parameterized RNNs that have $m\gg N$. Following Example 2.1, if $\phi(z)$ is a constant-degree polynomial we have $C=\mathrm{poly}(L,\log\varepsilon^{-1})$, so Theorem 1 says that RNN learns such a concept class

Non-linear measurements. Our result shows that vanilla RNNs can efficiently learn a weighted average of non-linear measurements of the input. As we argued in Example 3.5, this at least includes functions where the output at each token is generated from inputs of previous tokens using any two-layer neural network. An average of non-linear measurements can be quite powerful, achieving state-of-the-art performance in some sequential applications such as sentence embedding and many others, and acts as the base of the attention mechanism in RNNs.

In our result, the function $\Phi_{i\to j,r,s}$ only adapts with the positions of the input tokens, but in many applications, we would like the function to adapt with the values of the past tokens $x^{\star}_{1},\dots,x^{\star}_{i-1}$ as well. We believe a study on other models (such as LSTM) can potentially settle these questions.

Comparison to feed-forward networks. Recently there have been many interesting results on analyzing the learning process of feed-forward neural networks. Most of them either assume that the input is structured (e.g. Gaussian or separable) or only consider linear networks. Allen-Zhu, Li, and Liang show a result in the same flavor as this paper but for two and three-layer networks. Since RNNs apply the same unit repeatedly to each input token in a sequence, our analysis is significantly different from theirs and has to overcome many additional difficulties.

Conclusion

We show that RNNs can actually learn some notable concept class efficiently, using the simple SGD method with sample complexity polynomial or almost-polynomial in the input length. This concept class at least includes functions where each output token is generated from inputs of earlier tokens using a smooth neural network. To the best of our knowledge, this is the first proof that some non-trivial concept class is efficiently learnable by RNNs. Our sample complexity is almost independent of $m$, making the result applicable to over-parameterized settings. On a separate note, our proof explains why the same recurrent unit is capable of learning various functions from different input tokens to different output tokens.

Our proof of Theorem 1 divides into four conceptual steps.

We obtain a first-order approximation of how much the outputs of the RNN change if we move from $W$ to $W+W^{\prime}$. This change (up to small error) is a linear function in $W^{\prime}$ (see Section 6).

(This step can be derived from prior work without much difficulty.)

(This step is the most interesting part of this paper.)

We argue that the SGD method moves in a direction nearly as good as $W^{\divideontimes}$ and thus efficiently decreases the training objective (see Section 7).

(This is a routine analysis of SGD in the non-convex setting given Steps 1&2.)

We use the first-order linear approximation to derive a Rademacher complexity bound that does not grow exponentially in $L$ (see Section 8). By feeding the output of SGD into this Rademacher complexity bound, we finish the proof of Theorem 1 (see Section 9).

(This is a one-page proof given Steps 1&2&3.)

Although our proofs are technical, to help the readers, we write 7 pages of sketch proofs for Steps 1 through 4. These can be found in Sections 5 through 9. Our final proofs rely on many other technical properties of RNNs that may be of independent interest, such as properties of RNNs at random initialization (which we include in Sections B and C), and properties of RNN stability (which we include in Sections D, E, F). Some of these properties are simple modifications from prior work, but some are completely new and require new proof techniques (namely, Sections C, D and E). We introduce some notations for analysis purposes.

Throughout the proofs, to simplify notations when specifying polynomial factors, we introduce

We assume $m\geq\mathrm{poly}(\varrho)$ for some sufficiently large polynomial factor.

Existence of Good Network Through Backward

Furthermore, $W^{\divideontimes}$ is appropriately bounded in Frobenius norm. In our sketched proof below, it shall become clear how this same matrix $W^{\divideontimes}$ can simultaneously represent functions $\Phi_{i\to j^{\prime}}$ that come from different input tokens $i$. Since SGD can be shown to descend in a direction “comparable” to $W^{\divideontimes}$, it converges to a matrix $W$ with similar guarantees.

In order to show (5.2), we first show a variant of the “indicator to function” lemma from prior work.

Above, Lemma 5.1(a) says that we can use a bounded function $\mathds{1}_{\langle a,x^{\star}\rangle+n\geq 0}H(a)$ to fit a target function $\Phi(\langle w^{*},x^{\star}\rangle)$, and Lemma 5.1(b) says that if the magnitude of $n$ is large then this function is close to being constant. For this reason, we can view $n$ as “noise.” While the proof of Lemma 5.1(a) is from prior work, the property in Lemma 5.1(b) is completely new, and it is technically challenging to guarantee Lemma 5.1(a) and Lemma 5.1(b) simultaneously. The proof is in Appendix G.1.

Fitting a Single Function

We now try to apply Lemma 5.1 to approximate a single function $\Phi_{i\to j,r,s}(\langle w^{*}_{i\to j,r,s},x^{\star}_{i}\rangle)$. For this purpose, let us consider two (normalized) input sequences. The first (null) sequence $x^{(0)}$ is given as

The second sequence $x$ is generated from an input $x^{\star}$ in the support of $\mathcal{D}$ (recall Definition 3.1). Let

For every $2\leq i<j\leq L$, $r\in[p]$, $s\in[d]$ and every constant $\varepsilon_{e}\in\big(0,\frac{1}{\mathfrak{C}_{\mathfrak{s}}(\Phi_{i\to j,r,s},O(\sqrt{L}))}\big)$, there exists $C^{\prime}=\mathfrak{C}_{\varepsilon_{e}}(\Phi_{i\to j,r,s},\sqrt{L})$ so that, for every

$x$ be a fixed input sequence defined by some $x^{\star}$ in the support of $\mathcal{D}$ (see Definition 3.1),

$\widetilde{w}_{k},\widetilde{a}_{k}\sim\mathcal{N}\left(0,\frac{2\mathbf{I}}{m}\right)$ be fresh random vectors,

with probability at least $1-e^{-\Omega(\rho^{2})}$ over $W$ and $A$,

(off target), for every $i^{\prime}\neq i$

Lemma 5.2 implies there is a quantity $\mathds{1}_{|\langle\widetilde{w}_{k},h_{i-1}^{(0)}\rangle|\leq\frac{\varepsilon_{c}}{\sqrt{m}}}H_{i\to j,r,s}(\widetilde{a}_{k})$ that only depends on the target function and the random initialization (namely, $\widetilde{w}_{k},\widetilde{a}_{k}$) such that,

when multiplying $\mathds{1}_{\langle\widetilde{w}_{k},h_{i-1}\rangle+\langle\widetilde{a}_{k},x_{i}\rangle\geq 0}$ gives the target $\Phi_{i\to j,r,s}(\langle w^{*}_{i\to j,r,s},x^{\star}_{i}\rangle)$, but

when multiplying $\mathds{1}_{\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle+\langle\widetilde{a}_{k},x_{i^{\prime}}\rangle\geq 0}$ gives near zero.

The full proof is in Appendix G.2 but we sketch why Lemma 5.2 can be derived from Lemma 5.1.

Let us focus on the indicator $\mathds{1}_{\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle+\langle\widetilde{a}_{k},x_{i^{\prime}}\rangle\geq 0}$:

$\langle\widetilde{a}_{k},x_{i^{\prime}}\rangle$ is distributed like $\mathcal{N}(0,\frac{2\varepsilon_{x}^{2}}{m})$ because $\langle\widetilde{a}_{k},x_{i^{\prime}}\rangle=\big\langle\widetilde{a}_{k},(\varepsilon_{x}x^{\star}_{i^{\prime}},0)\big\rangle$; but

$\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle$ is roughly $\mathcal{N}(0,\frac{2}{m})$ because $\|h_{i^{\prime}-1}\|\approx 1$ by random initialization (see Lemma B.1(a)).

Thus, if we treat $\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle$ as the “noise $n$” in Lemma 5.1, it can be $\frac{1}{\varepsilon_{x}}$ times larger than $\langle\widetilde{a}_{k},x_{i^{\prime}}\rangle$.

To show Lemma 5.2(a), we only need to focus on $|\langle\widetilde{w}_{k},h_{i^{\prime}-1}^{(0)}\rangle|\leq\frac{\varepsilon_{c}}{\sqrt{m}}$ because $i=i^{\prime}$. Since $h^{(0)}$ can be shown to be close to $h$ (see Lemma D.1), this is almost equivalent to $|\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle|\leq\frac{\varepsilon_{c}}{\sqrt{m}}$. Conditioning on this event, the “noise $n$” must be small so we can apply Lemma 5.1(a).

To show Lemma 5.2(b), we can show that when $i^{\prime}\neq i$, the indicator on $|\langle\widetilde{w}_{k},h_{i-1}\rangle|\leq\frac{\varepsilon_{c}}{\sqrt{m}}$ gives little information about the true noise $\langle\widetilde{w}_{k},h_{i^{\prime}-1}\rangle$. This is so because $h_{i-1}$ and $h_{i^{\prime}-1}$ are somewhat uncorrelated (details in Lemma B.1(k)). As a result, the “noise $n$” is still large and thus Lemma 5.1(b) applies with $\Phi_{i\to j,r,s}(0)=0$. ∎
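The scale comparison used in this sketch (the “noise” term is about $\frac{1}{\varepsilon_{x}}$ times larger than the input term) can be checked numerically. The following is a minimal Monte Carlo sketch under simplifying assumptions: the hidden state is replaced by a fixed unit vector `h`, the input by a fixed vector `x` of norm $\varepsilon_{x}$ (the padded zero coordinate is omitted since it does not affect the inner product), and the names `inner_product_std`, `d_in`, and the specific sizes are hypothetical.

```python
import math
import random

def inner_product_std(vec, m, trials, rng):
    """Empirical standard deviation of <g, vec>, where g has i.i.d. N(0, 2/m) entries."""
    std = math.sqrt(2.0 / m)
    samples = []
    for _ in range(trials):
        samples.append(sum(rng.gauss(0.0, std) * v for v in vec))
    mean = sum(samples) / trials
    return math.sqrt(sum((s - mean) ** 2 for s in samples) / trials)

rng = random.Random(0)
m, d_in, eps_x, trials = 50, 4, 0.1, 4000

h = [1.0 / math.sqrt(m)] * m             # stand-in for h_{i'-1}, with ||h|| = 1
x = [eps_x / math.sqrt(d_in)] * d_in     # stand-in for x_{i'}, with ||x|| = eps_x

noise_std = inner_product_std(h, m, trials, rng)    # about sqrt(2/m)
signal_std = inner_product_std(x, m, trials, rng)   # about eps_x * sqrt(2/m)
ratio = noise_std / signal_std
print(ratio, 1.0 / eps_x)  # the two numbers should be close
```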

Fitting the Target Function

Suppose $\varepsilon_{e}\in\big(0,\frac{1}{\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, $C^{\prime}=\mathfrak{C}_{\varepsilon_{e}}(\Phi,\sqrt{L})$, $\varepsilon_{x}\in\big(0,\frac{1}{\rho^{4}C^{\prime}}\big)$, we choose

The following lemma says that $f_{j^{\prime},s^{\prime}}$ is close to the target function $F^{*}_{j^{\prime},s^{\prime}}$.

The construction of $W^{\divideontimes}$ in Definition 5.3 satisfies the following. For every normalized input sequence $x$ generated from $x^{\star}$ in the support of $\mathcal{D}$, with probability at least $1-e^{-\Omega(\rho^{2})}$ over $W,A,B$, it holds for every $3\leq j^{\prime}\leq L$ and $s^{\prime}\in[d]$

Using the definition of $f_{j^{\prime},s^{\prime}}$ in (5.1) and $W^{\divideontimes}$, one can write down

The summands in (5.3) with $i\neq i^{\prime}$ are negligible owing to Lemma 5.2(b).

The summands in (5.3) with $i=i^{\prime}$ but $j\neq j^{\prime}$ are negligible, after proving that $\operatorname{\mathsf{Back}}_{i\to j}$ and $\operatorname{\mathsf{Back}}_{i\to j^{\prime}}$ are very uncorrelated (details in Lemma C.1).

The summands in (5.3) with $s\neq s^{\prime}$ are negligible using the randomness of $B$.

One can also prove $\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}\approx\operatorname{\mathsf{Back}}^{(0)}_{i^{\prime}\to j^{\prime}}$ and $h_{i^{\prime}-1}\approx h^{(0)}_{i^{\prime}-1}$ (details in Lemma D.1).

Applying Lemma 5.2(a) and using our choice of $C_{i^{\prime}\to j^{\prime},s^{\prime}}$, this gives (in expectation)

Proving concentration (with respect to $k\in[m]$) is a lot more challenging due to the sophisticated correlations across different indices $k$. To achieve this, we replace some of the pairs $w_{k},a_{k}$ with fresh new samples $\widetilde{w}_{k},\widetilde{a}_{k}$ for all $k\in\mathcal{N}$ and apply concentration only with respect to $k\in\mathcal{N}$. Here, $\mathcal{N}$ is a random subset of $[m]$ with cardinality $m^{0.1}$. We show that the network stabilizes (details in Section E) against such re-randomization. The full proof is in Section G.3. ∎

Finally, one can show $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$ (see Claim G.1). Crucially, this Frobenius norm scales as $m^{-1/2}$, so standard SGD analysis shall ensure that our sample complexity does not depend on $m$ (up to log factors).

Coupling and First-Order Approximation

Consider now the scenario when the random initialization matrix $W$ is perturbed to $W+W^{\prime}$ with $W^{\prime}$ being small in spectral norm. Intuitively, this $W^{\prime}$ will later capture how much SGD has moved away from the random initialization, so it may depend on the randomness of $W,A,B$. To untangle this possibly complicated correlation, all lemmas in this section hold for every $W^{\prime}$ that is small.

The first lemma below states that the $j$-th layer output difference $B(h_{j}+h^{\prime}_{j})-Bh_{j}$ can be approximated by a linear function in $W^{\prime}$, that is $\sum_{i=1}^{j-1}\operatorname{\mathsf{Back}}_{i\to j}D_{i+1}W^{\prime}h_{i}$. We remind the reader that this linear function in $W^{\prime}$ is exactly the same as our notation of $f_{j^{\prime},s^{\prime}}$ from (5.2).

Let $W,A,B$ be at random initialization, $x$ be a fixed normalized input sequence, and $\Delta\in[\varrho^{-100},\varrho^{100}]$. With probability at least $1-e^{-\Omega(\rho)}$ over $W,A,B$ the following holds. Given any perturbation matrix $W^{\prime}$ with $\|W^{\prime}\|_{2}\leq\frac{\Delta}{\sqrt{m}}$, letting

The proof of Lemma 6.1 is similar to the semi-smoothness theorem of prior work and can be found in Section H.1.

The next lemma says that, for this linear function $\sum_{i=1}^{j-1}\operatorname{\mathsf{Back}}_{i\to j}D_{i+1}\widetilde{W}h_{i}$ over $\widetilde{W}$, one can replace $h,D,\operatorname{\mathsf{Back}}$ with $h+h^{\prime},D+D^{\prime},\operatorname{\mathsf{Back}}+\operatorname{\mathsf{Back}}^{\prime}$ without changing much in its output. It is a direct consequence of the adversarial stability properties of RNNs from prior work (see Section F).

Let $W,A,B$ be at random initialization, $x$ be a fixed normalized input sequence, and $\Delta\in[\varrho^{-100},\varrho^{100}]$. With probability at least $1-e^{-\Omega(\rho)}$ over $W,A,B$ the following holds. Given any matrix $W^{\prime}$ with $\|W^{\prime}\|_{2}\leq\frac{\Delta}{\sqrt{m}}$, and any $\widetilde{W}$ with $\|\widetilde{W}\|_{2}\leq\frac{\omega}{\sqrt{m}}$, letting

A direct corollary of Lemma 6.2 is that our matrix $W^{\divideontimes}$ constructed in Definition 5.3 satisfies the same property of Lemma 5.4 after perturbation. Namely,

$W^{\divideontimes}$ in Definition 5.3 satisfies the following. Let $W,A,B$ be at random initialization, $x$ be a fixed normalized input sequence generated by $x^{\star}$ in the support of $\mathcal{D}$, and $\Delta\in[\varrho^{-100},\varrho^{100}]$. With probability at least $1-e^{-\Omega(\rho)}$ over $W,A,B$ the following holds. Given any matrix $W^{\prime}$ with $\|W^{\prime}\|_{2}\leq\frac{\Delta}{\sqrt{m}}$, any $3\leq j^{\prime}\leq L$, and any $s^{\prime}\in[d]$:

Combining Lemma 5.4 and Lemma 6.2 gives the proof. ∎

Optimization and Convergence

Our main convergence lemma for SGD on the training objective is as follows.

For every constant $\varepsilon\in\big(0,\frac{1}{p\cdot\mathrm{poly}(\rho)\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,\sqrt{L})}\big)$, there exists $C^{\prime}=\mathfrak{C}_{\varepsilon}(\Phi,\sqrt{L})$ and parameters

so that, as long as $m\geq\mathrm{poly}(\varrho)$ and $N\geq\Omega\big(\frac{\rho^{3}p\,\mathfrak{C}_{\mathfrak{s}}^{2}(\Phi,1)}{\varepsilon^{2}}\big)$, setting learning rate $\eta=\Theta\big(\frac{1}{\varepsilon\rho^{2}m}\big)$ and $T=\Theta\big(\frac{p^{2}C^{2}\mathrm{poly}(\rho)}{\varepsilon^{2}}\big)$, we have

and $\|W_{t}\|_{F}\leq\frac{\Delta}{\sqrt{m}}$ for $\Delta=\frac{C^{2}p^{2}\mathrm{poly}(\rho)}{\varepsilon^{2}}$.

The full proof is in Section I and we sketch the main idea here. Recall the training objective

Let $x$ be a normalized input sequence generated by some $x^{\star}$ in the support of $\mathcal{D}$. Consider an iteration $t$ where the current weight matrix is $W+W_{t}$. Let

which is a linear function over $\widetilde{W}$. Let us define a loss function $\widetilde{G}$ as:

Let $W^{\divideontimes}$ be defined in Definition 5.3. By Lemma 6.3, we know that as long as $\|W_{t}\|_{2}$ is small (which we shall ensure towards the end),

Thus, by the 1-Lipschitz continuity of $G$, one can derive that

By Lemma 6.1 and Lemma 6.2 together, we know that

Using the linearity of $R_{j}$ and the 1-Lipschitz continuity of $G$, we have

where ① is by our choice of $\lambda$ which implies $\lambda\|F_{j}(x^{\star};W)\|\leq\lambda\cdot O(\rho)\leq\frac{\varepsilon}{10L}$ by Lemma B.1(h).

Together, we have $\widetilde{G}\left(\frac{1}{\lambda}W^{\divideontimes}-W_{t}\right)\leq\mathsf{OPT}+\frac{\varepsilon}{5}$. Thus, by the convexity of $\widetilde{G}(\widetilde{W})$ (composing a convex function with a linear function is convex), we know

Suppose in this high-level sketch that we apply gradient descent as opposed to SGD. Then, $W_{t+1}=W_{t}-\eta\nabla\widetilde{G}(0)$ and we have

Putting (7.1) into this formula, we know that as long as $\widetilde{G}(0)>\mathsf{OPT}+\frac{\varepsilon}{5}$, then $\diamondsuit$ is a very negative term and thus, when $\eta$ is sufficiently small, it guarantees to decrease $\left\|\lambda^{-1}W^{\divideontimes}-W_{t+1}\right\|_{F}$. This cannot happen for too many iterations, and thus we arrive at a convergence statement. ∎
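The mechanism behind this last step (a convex objective, a reference point with smaller loss, and a small step size together force the distance to the reference point to shrink) can be illustrated on a toy problem. The sketch below is not the paper's training procedure: the loss is a mean absolute error of a linear map (convex and Lipschitz), `w_ref` plays the role of $\lambda^{-1}W^{\divideontimes}-W_{t}$ after re-centering, and the names `loss_and_grad`, `distance`, `rows`, and the step size are hypothetical choices.

```python
import math

def loss_and_grad(w, rows, targets):
    """Convex loss: mean absolute error of a linear map, with one subgradient."""
    n = len(rows)
    loss = 0.0
    grad = [0.0] * len(w)
    for g, y in zip(rows, targets):
        r = sum(a * b for a, b in zip(g, w)) - y
        loss += abs(r) / n
        sign = (r > 0) - (r < 0)
        for k, a in enumerate(g):
            grad[k] += sign * a / n
    return loss, grad

def distance(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_ref = [1.0, -2.0]                                   # reference point with zero loss
targets = [sum(a * b for a, b in zip(g, w_ref)) for g in rows]

w = [0.0, 0.0]
eta = 0.05
distances = []
for _ in range(20):
    loss, grad = loss_and_grad(w, rows, targets)
    distances.append(distance(w, w_ref))
    # By convexity, <grad, w_ref - w> <= loss(w_ref) - loss(w) < 0, so for a
    # small eta the negative cross term dominates the eta^2 * ||grad||^2 term.
    w = [a - eta * b for a, b in zip(w, grad)]
distances.append(distance(w, w_ref))
print(distances[0], distances[-1])
```

Since the distance to the reference point is bounded below by zero, a guaranteed per-step decrease of this kind can only continue for a bounded number of iterations, which is the counting argument used in the sketch above.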

Rademacher Complexity Through Coupling

We have the following simple lemma about the Rademacher complexity of RNNs. It first uses the coupling Lemma 6.1 to reduce the network to a linear function, and then calculates the Rademacher complexity for this linear function class.

where we use $\operatorname{\mathsf{Back}}_{q,i\to j}$, $h_{q,i}$ and $D_{q,i}$ to denote the quantities calculated from sample $x^{\star}_{q}$. Since this function is linear in $W^{\prime}$, we can write it as

We have $\|G_{q}\|_{F}\leq O(L\rho\sqrt{m/d})$ from Lemma B.1. We bound the Rademacher complexity of this linear function using Proposition A.3 as follows.

Proof of Theorem 1

For every $\varepsilon\in\big(0,\frac{1}{\mathrm{poly}(\rho)\cdot p\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, define complexity $C=\mathfrak{C}_{\varepsilon}(\Phi,\sqrt{L})$ and $\lambda=\frac{\varepsilon}{10L\rho}$. If the number of neurons $m\geq\mathrm{poly}(C,p,L,d,\varepsilon^{-1})$ and the number of samples is $N\geq\mathrm{poly}(C,p,L,d,\varepsilon^{-1})$, then SGD with $\eta=\Theta\big(\frac{1}{\varepsilon\rho^{2}m}\big)$ and

satisfies that, with probability at least $1-e^{-\Omega(\rho^{2})}$ over the random initialization

One can first apply Lemma 7.1 to obtain $W_{t}$ for $t=0,1,\dots,T-1$ satisfying (recall (I.1))

We can also apply Lemma B.1(h) together with the adversarial stability properties of Section F to derive that for each fixed $(x^{\star},y^{\star})\sim\mathcal{D}$, with probability at least $1-e^{-\Omega(\rho^{2})}$, it satisfies for every $j=3,4,\dots,L$,

and therefore by the 1-Lipschitz continuity of $G(\cdot,y^{\star})$,

Plugging the Rademacher complexity bound of Lemma 8.1 together with the choice $b=O(\varepsilon\rho^{6}\Delta)$ into a standard generalization argument (see Corollary A.2), we have with probability at least $1-e^{-\Omega(\rho^{2})}$, for all $t$

where the additional factor $\lambda$ is because we have scaled $F_{j}$ with factor $\lambda$. In sum, it suffices to choose $N\geq\Omega\big(\frac{\lambda^{2}\rho^{8}\Delta^{2}}{\varepsilon^{2}}\big)=\Omega(\rho^{6}\Delta^{2})$ and $N\geq\Omega\big(\frac{\rho^{4}b^{2}}{\varepsilon^{2}}\big)=\Omega(\mathrm{poly}(\rho)\Delta^{2})$. ∎

Appendix A Rademacher Complexity Review

Suppose $\mathcal{X}=(x_{1},\dots,x_{N})$ where each $x_{i}$ is generated i.i.d. from a distribution $\mathcal{D}$. If every $f\in\mathcal{F}$ satisfies $|f|\leq b$, then for every $\delta\in(0,1)$, with probability at least $1-\delta$ over the randomness of $\mathcal{Z}$, it satisfies

Let $\mathcal{F}^{\prime}$ be the class of functions obtained by composing $L$ with $\mathcal{F}_{1},\dots,\mathcal{F}_{k}$, that is, $\mathcal{F}^{\prime}=\{L_{x}\circ(f_{1},\dots,f_{k})\mid f_{1}\in\mathcal{F}_{1},\dots,f_{k}\in\mathcal{F}_{k}\}$. By the (vector version of the) contraction lemma of Rademacher complexity, it satisfies $\widehat{\mathfrak{R}}(\mathcal{Z};\mathcal{F}^{\prime})\leq O(1)\cdot\sum_{r=1}^{k}\widehat{\mathfrak{R}}(\mathcal{Z};\mathcal{F}_{r})$. (There are slightly different versions of the contraction lemma in the literature. For the scalar case without absolute value, see [21, Section 3.8]; for the scalar case with absolute value, see [6, Theorem 12]; the vector case without absolute value is the version used here.) ∎

We recall the simple calculation of the Rademacher complexity for linear function class.

Suppose $\|x\|_{2}=1$ for all $x\in\mathcal{X}$. The class $\mathcal{F}=\{x\mapsto\langle w,x\rangle\mid\|w\|_{2}\leq B\}$ has Rademacher complexity $\widehat{\mathfrak{R}}(\mathcal{X};\mathcal{F})\leq O(\frac{B}{\sqrt{N}})$.
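This bound can be checked numerically with a short Monte Carlo sketch. For a linear class the supremum over $\|w\|_{2}\leq B$ has the closed form $\frac{B}{N}\big\|\sum_{i}\sigma_{i}x_{i}\big\|_{2}$, and Jensen's inequality then gives at most $\frac{B}{\sqrt{N}}$ when every $\|x_{i}\|_{2}=1$. The function names, the random unit-norm data points, and the parameter values below are assumptions made for illustration.

```python
import math
import random

def random_unit_vector(dim, rng):
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def empirical_rademacher_linear(points, radius, trials, rng):
    """Monte Carlo estimate of E_sigma sup_{||w|| <= radius} (1/N) sum_i sigma_i <w, x_i>.

    The supremum equals (radius / N) * || sum_i sigma_i x_i ||_2 (Cauchy-Schwarz).
    """
    n = len(points)
    dim = len(points[0])
    total = 0.0
    for _ in range(trials):
        s = [0.0] * dim
        for x in points:
            sigma = rng.choice((-1.0, 1.0))   # Rademacher sign
            for k in range(dim):
                s[k] += sigma * x[k]
        total += radius * math.sqrt(sum(c * c for c in s)) / n
    return total / trials

rng = random.Random(0)
N, dim, B = 400, 5, 3.0
points = [random_unit_vector(dim, rng) for _ in range(N)]
estimate = empirical_rademacher_linear(points, B, 500, rng)
bound = B / math.sqrt(N)
print(estimate, bound)  # the estimate should not exceed the bound by more than sampling noise
```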

Appendix B Random Initialization: Basic Properties

We first note some important properties about the random initialization of our RNNs. Some of them have already appeared in prior work, and the remaining ones can be easily derived from prior work.

Let $W,A,B$ be at random initialization and $x_{1},\dots,x_{L}$ be any fixed normalized input sequence (see Definition 3.2). Recall $h_{0}=0$ and

$\left|\mathbf{e}_{r}^{\top}\operatorname{\mathsf{Back}}_{i\to j}\mathbf{e}_{k}\right|\leq O\big(\frac{\rho}{\sqrt{d}}\big)$ for every $k\in[m]$, $r\in[d]$, $1\leq i\leq j\leq L$ (backward signal)

$\left\|\mathbf{e}_{r}^{\top}\operatorname{\mathsf{Back}}_{i\to j}\right\|\geq\Omega\big(\frac{\sqrt{m}}{\sqrt{d}}\big)$ for every $r\in[d]$, $1\leq i\leq j\leq L$ (backward signal)

Finally, letting $\delta=\frac{\rho}{\sqrt{m}}$ finishes the proof.

This is similar to the proof of Lemma B.1(a), except noticing the $\sqrt{2}$ factor: $\sqrt{2}(1-\delta)\sqrt{1+\varepsilon_{x}^{2}}\leq\|g_{1}\|\leq\sqrt{2}(1+\delta)\sqrt{1+\varepsilon_{x}^{2}}$.

$y$ has at most $O(sm)$ coordinates $k$ with $|y_{k}|\leq\frac{s}{2\sqrt{m}}$.

Since $z$ is of dimension at most $L$, we can apply a standard $\epsilon$-net argument over all possible $z$ with $\|z\|\in[0.5,3]$ (with $\epsilon=O(s/\sqrt{m})$) and then apply a union bound. Since $sm\geq\rho^{2}$, with probability at least $1-e^{-\Omega(\rho^{2})}$, for all $z$ in this range, it satisfies

$y$ has at most $O(sm)$ coordinates $k$ with $|y_{k}|\leq\frac{s}{\sqrt{m}}$.
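As a rough numerical illustration of this counting claim, the sketch below makes an explicit assumption that is not restated in the surviving text: the coordinates $y_{k}$ are modeled as i.i.d. $\mathcal{N}(0,\frac{2\|z\|^{2}}{m})$, as would be the case for a Gaussian matrix with $\mathcal{N}(0,\frac{2}{m})$ entries applied to a fixed vector $z$. Under that model each coordinate falls in $[-\frac{s}{\sqrt{m}},\frac{s}{\sqrt{m}}]$ with probability at most $\frac{s}{\|z\|\sqrt{\pi}}$, so the expected count is $O(sm)$ when $\|z\|\geq 0.5$. The name `count_small_coordinates` and the parameter values are hypothetical.

```python
import math
import random

def count_small_coordinates(m, s, z_norm, rng):
    """Count coordinates with |y_k| <= s / sqrt(m), for y_k ~ N(0, 2 * z_norm^2 / m) i.i.d."""
    std = math.sqrt(2.0) * z_norm / math.sqrt(m)
    threshold = s / math.sqrt(m)
    return sum(1 for _ in range(m) if abs(rng.gauss(0.0, std)) <= threshold)

rng = random.Random(0)
m, s, z_norm = 20000, 0.05, 0.5           # z_norm is the smallest norm in [0.5, 3]
count = count_small_coordinates(m, s, z_norm, rng)

# Density bound: P(|y_k| <= s/sqrt(m)) <= s / (z_norm * sqrt(pi)).
expected_upper = m * s / (z_norm * math.sqrt(math.pi))
print(count, expected_upper, s * m)
```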

The “intermediate layers” part of the “basic properties at random initialization” of prior work in fact shows (e.g. their Claim B.12 of version 3) that, for a fixed unit vector $u$, with probability at least $1-e^{-\Omega(m/L^{2})}$, it satisfies

Further using the randomness of $B$, we have with probability at least $1-e^{-\Omega(\rho^{2})}$,

This finishes the proof after plugging in $u=\mathbf{e}_{k}$.

Using the same argument as above, we have with probability at least $1/2$,

This finishes the proof after plugging in $u=\mathbf{e}_{k}$ for all $k\in[m]$ and taking a Chernoff bound.

This is similar to the proof of Lemma B.1(f) but requires a careful $\epsilon$-net argument. It is already included in the “intermediate layers: spectral norm” part of the “Basic Properties at Random Initialization” of prior work (e.g. Lemma B.11 in version 3).

This can be proved in the same way as Lemma B.1(f). It is already included in the “intermediate layers: sparse spectral norm” part of the “Basic Properties at Random Initialization” of prior work (e.g. Lemma B.14 in version 3).

See the “forward correlation” part of prior work (e.g. Lemma B.6 of version 3).

Appendix C Random Initialization: Backward Correlation

For every $\varepsilon_{x}\in[0,1/L]$, every fixed normalized input sequence $x_{1},\dots,x_{L}$, with probability at least $1-e^{-\Omega(\rho^{2})}$ over $W,A,B$: for every $1\leq i\leq j<j^{\prime}\leq L$:

Since $\|\operatorname{\mathsf{Back}}_{i\to j}\|_{2}$ is on the order of $\sqrt{m}$, the above Lemma C.1 says that the two vectors $u^{\top}\operatorname{\mathsf{Back}}_{i\to j}$ and $v^{\top}\operatorname{\mathsf{Back}}_{i\to j^{\prime}}$ are very uncorrelated whenever $j\neq j^{\prime}$.

In fact, one can prove the same Lemma C.1 for the un-correlation between $u^{\top}\operatorname{\mathsf{Back}}_{i\to j}$ and $v^{\top}\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}$ whenever $j-i\neq j^{\prime}-i^{\prime}$. We do not need that stronger version in this paper.

so it suffices to bound the absolute value of $\sum_{p\in[m]}\Xi_{p}=\big\langle u^{\top}\operatorname{\mathsf{Back}}_{i\to j},v^{\top}\operatorname{\mathsf{Back}}_{i\to j^{\prime}}\big\rangle$.

Let us fix NN coordinates (without loss of generality the first NN coordinates) and calculate p[N]Ξp\sum_{p\in[N]}\Xi_{p} only over [N][N] by induction. Define

where the first is due to Lemma lem:done1e. Let

Above, ① is by the definition $U_{i}=\mathsf{GS}(h_{1},\dots,h_{i})$, and ② is because for each fixed unit vector $u$, we have $|\langle u,\xi_{p}\rangle|\leq O(\rho/\sqrt{m})$ with probability at least $1-e^{-\Omega(\rho^{2})}$.

We consider the two terms on the right hand side separately:

Now, suppose for a moment that we view $a$ and $b$ as fixed. Then, it is a simple exercise to verify that with probability $1-\exp(-\Omega(m/L^{2}))$ (over the randomness of $M,A$; see for instance Claim B.13 of version 3 of ),

Combining the above two properties and using induction, we have

Above, ① is because for every fixed vector $x$, with probability at least $1-e^{-\Omega(\rho^{2})}$ over $B$ it satisfies $|u^{\top}Bx|\leq O(\rho\|x\|)$, and therefore we have (similarly if we replace $u$ with $v$)

To bound $\clubsuit$, we note that the following $2N$ vectors

are pairwise orthogonal, and therefore, when they are left-multiplied by the matrix $B$, the resulting vectors are independent (over the randomness of $B$). This means $|\clubsuit|\leq O\big(\sqrt{N}\rho^{2}\big)$ with probability at least $1-e^{-\Omega(\rho^{2})}$ over $B$. Choosing $N=m^{1/2}$ we have

Finally, we can divide all the $m$ coordinates into $\sqrt{m}$ chunks each of size $\sqrt{m}$. Performing the above calculation for each of them gives the desired bound. ∎

In this section we consider two (normalized) input sequences. The first sequence $x^{(0)}$ is given as

The second sequence $x$ is generated from an arbitrary $x^{\star}=(x^{\star}_{2},\dots,x^{\star}_{L-1})$ in the support of $\mathcal{D}$:

We study the following two executions of RNNs under input $x^{(0)}$ and $x$ respectively:

We emphasize that Lemma lem:dropping_xsa is technically the most involved, and the remaining two properties Lemma lem:dropping_xsb and Lemma lem:dropping_xsc are simple corollaries.

Let $\widehat{z}_{0}=z_{0}/\|z_{0}\|$, then we can write

Using the concentration, we know with probability at least $1-e^{-\Omega(\rho^{2})}$, for any fixed $z,z_{0}$ satisfying $\|z\|,\|z_{0}\|\in[0.5,3]$ and $\|z-z_{0}\|\leq 0.1$,

Using again $\|z\|=\zeta_{n}\pm O(\frac{\rho^{2}}{\sqrt{m}})$ and the Lipschitz continuity property of $\zeta_{c}(\sqrt{x})$ from Proposition D.2, we have

Recall $-|\beta|^{3}\leq\zeta_{c}(\beta)-\frac{\beta^{2}}{2}\leq\frac{\beta^{4}}{4}$ from Proposition D.2. This means

D.2 Proof of Lemma lem:dropping_xsb and lem:dropping_xsc

Euclidean norm $\|u^{\top}B(WD\cdots W)D^{\prime}\|\leq O(\rho\sqrt{s})\|u\|\leq O(\rho^{7/6}\varepsilon_{x}^{1/3}m^{1/2})\|u\|$ from Lemma lem:done1f.

Spectral norm $\|D^{\prime}(WD\cdots W)D^{\prime}\|_{2}\leq O(\rho\sqrt{s/m})\leq\frac{1}{100L}$ from Lemma lem:done1j.

Spectral norm $\|D^{\prime}(WD\cdots W)\|_{2}\leq O(L^{3})$ from Lemma lem:done1i.

D.3 Mathematical Tools

Let $g_{1},g_{2}$ be two independent standard Gaussian random variables $\mathcal{N}(0,1)$, and let parameters $\beta\in[-\frac{3}{4},\frac{3}{4}]$ and $\alpha=\sqrt{1-\beta^{2}}\in\big[\frac{3}{4},\frac{5}{4}\big]$. Define

For $\beta\in[-\frac{3}{4},\frac{3}{4}]$ we have $-|\beta|^{3}\leq\zeta_{c}(\beta)-\frac{\beta^{2}}{2}\leq\frac{\beta^{4}}{4}$.

Over $x\in[-0.05,0.05]$, the function $\zeta_{c}(\sqrt{x})$ is $\frac{1}{2}$-Lipschitz continuous over $x$.

At least for the range of $\alpha\in\big[\frac{3}{4},\frac{5}{4}\big]$ and $\beta\geq 0$, one can exactly integrate out this squared difference over two Gaussian variables. For instance, the “$\delta$-separateness” part of “properties at random initialization” of (e.g. Claim A.6 of their version 3) has already done this for us:

It is easy to see that, as long as $\beta\leq\alpha$, we always have $\frac{(\alpha+k)\beta^{2k+1}}{(2k+1)\alpha^{2k+1}}\geq\frac{(\alpha+k+1)\beta^{2k+3}}{(2k+3)\alpha^{2k+3}}$. Therefore

For the value approximation, we have $1-\alpha\geq\zeta_{c}(\alpha,\beta)\geq 1-\alpha-\beta^{3}$ from the above formula, and thus plugging $\alpha=\sqrt{1-\beta^{2}}$, we have

Taking derivative with respect to $x$, we have

It is easy to verify that for all $\beta\in[0,0.9]$:

This proves the $0.5$-Lipschitz continuity over $x$.
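The value approximation above combines $1-\alpha\geq\zeta_{c}(\alpha,\beta)\geq 1-\alpha-\beta^{3}$ with an elementary fact about $\alpha=\sqrt{1-\beta^{2}}$, namely $0\leq 1-\alpha-\frac{\beta^{2}}{2}\leq\frac{\beta^{4}}{4}$ on $\beta\in[-\frac{3}{4},\frac{3}{4}]$ (for $\beta\geq 0$ this yields both sides of the stated bound). The following is a minimal numerical sanity check of that elementary inequality only. It is not part of the original proof, it does not evaluate $\zeta_{c}$ itself, and the function names and grid size are illustrative assumptions.

```python
import math


def value_gap(beta):
    """Return (1 - alpha) - beta^2 / 2 with alpha = sqrt(1 - beta^2)."""
    alpha = math.sqrt(1.0 - beta * beta)
    return (1.0 - alpha) - beta * beta / 2.0


def check_value_approximation(num_points=1501, tol=1e-12):
    """Check 0 <= value_gap(beta) <= beta^4 / 4 on a grid over [-3/4, 3/4].

    The grid and the floating-point tolerance are assumptions of this sketch;
    a grid check is a sanity test, not a proof.
    """
    for i in range(num_points):
        beta = -0.75 + 1.5 * i / (num_points - 1)
        gap = value_gap(beta)
        if gap < -tol or gap > beta ** 4 / 4.0 + tol:
            return False
    return True
```

Calling `check_value_approximation()` scans the whole interval. The upper bound is tightest near the endpoints $\beta=\pm\frac{3}{4}$, which is why the range of $\beta$ is restricted.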

If $s$ can be optimally chosen to minimize the above right hand side,

Appendix E Stability: After Re-Randomization

In this section we study a scenario where we re-randomize a fixed set of rows in the random initialization matrices $W$ and $A$. Formally, consider a fixed set $\mathcal{N}\subseteq[m]$ with cardinality $N=|\mathcal{N}|$. Define

There are many terms in this difference, and we treat them separately below.

We have that $\|\beta^{\prime}_{0}\|_{0}\leq N$, $\|\beta^{\prime}_{0}\|\leq O(\frac{\rho\sqrt{N}}{\sqrt{m}}+\frac{\rho^{2}\sqrt{s}}{\sqrt{m}}\tau_{1})$, $\|\beta^{\prime}_{1}\|_{0}\leq s+N$, and $\|\beta^{\prime}_{1}\|_{2}\leq 3\tau_{0}+O(\frac{\rho^{2}\sqrt{s}}{\sqrt{m}}\tau_{1})$.

Finally, we choose $\tau_{0}=\Theta(\frac{\rho\sqrt{N}}{\sqrt{m}}+\frac{\rho^{2}\sqrt{s}}{\sqrt{m}}\tau_{1})$ and $\tau_{1}=\Theta(\frac{\rho\sqrt{N}}{\sqrt{m}})$ to satisfy (using $N\leq m/\rho^{23}$)

and compute its Euclidean norm. One can in fact expand out all the (exponentially many) difference terms and bound them separately.

If $D^{\prime}$ shows up once and $W^{\prime}$ never shows up, then the term is

We have $\|(DW)^{a}\|_{2}\leq O(L^{3})$ by Lemma lem:done1i and $\|D^{\prime}(WD)^{b}W\mathbf{e}_{k}\|\leq O(\rho\sqrt{s/m})$ by Lemma lem:done1j. Therefore, its absolute value is at most $O(\rho^{4}\sqrt{s/m})$, and there are at most $L$ such terms.

If $W^{\prime}=D_{\mathcal{N}}(\widetilde{W}-W)$ shows up once and $D^{\prime}$ never shows up, then the term is

We have (DW)aDO(L3)\|(DW)^{a}D\|\leq O(L^{3}) by Lemma lem:done1i and DNW(DW)bekO(ρs/m)\|D_{\mathcal{N}}W^{\prime}(DW)^{b}\mathbf{e}_{k}\|\leq O(\rho\sqrt{s/m}) by Lemma lem:done1j. (We also have DNW~(DW)bekO(ρs/m)\|D_{\mathcal{N}}\widetilde{W}(DW)^{b}\mathbf{e}_{k}\|\leq O(\rho\sqrt{s/m}) but this is much easier to prove because W~\widetilde{W} is fresh new random.) Therefore, its absolute value is at most O(ρ4s/m)O(\rho^{4}\sqrt{s/m}), and there are at most LL such terms.

If the total number of times $D^{\prime}$ and $W^{\prime}$ show up is $2$, then the occurrences of $D^{\prime}$ and $W^{\prime}$ divide the difference term into three consecutive parts. As before, the norms of the first and the last parts are at most $O(L^{3})$ and $O(\rho\sqrt{s/m})$ respectively, so it suffices to bound the matrix spectral norm of the middle part. There are four possibilities for this middle part:

$D_{\mathcal{N}}(\widetilde{W}-W)DW\cdots WD^{\prime}$.

$D_{\mathcal{N}}(\widetilde{W}-W)DW\cdots WD_{\mathcal{N}}$.

All such matrices have spectral norm at most $O(\rho\sqrt{s/m})\leq\frac{1}{100L}$ by Lemma lem:done1j because $D^{\prime}$ and $D_{\mathcal{N}}$ are both $s$-sparse. Therefore, although there are at most $(2L)^{2}$ such difference terms, each of them is at most $\frac{1}{100L}\cdot O(\rho^{4}\sqrt{s/m})$ in magnitude. Therefore, their total contribution is negligible compared to cases (1) and (2).

If the total number of times $D^{\prime}$ and $W^{\prime}$ show up is $3$, then there are at most $(2L)^{3}$ such terms and each of them is at most $\frac{1}{(100L)^{2}}\cdot O(\rho^{4}\sqrt{s/m})$ in magnitude.
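The counting in the last two cases follows a geometric pattern: with $k$ occurrences there are at most $(2L)^{k}$ terms, each at most $(100L)^{-(k-1)}$ times the common factor $O(\rho^{4}\sqrt{s/m})$. The sketch below tabulates these weights in units of that common factor. Extending the pattern beyond $k=3$ is an assumption made here for illustration (the text states it only for $k=2$ and $k=3$), and the function names are hypothetical.

```python
from fractions import Fraction


def higher_order_weight(L, k):
    """Number of terms times per-term size for k occurrences of D' or W'.

    Measured in units of the common factor O(rho^4 sqrt(s/m)).
    Equals (2L)^k / (100L)^(k-1) = 2L * (1/50)^(k-1).
    """
    return Fraction((2 * L) ** k, (100 * L) ** (k - 1))


def total_higher_order_weight(L, k_max):
    """Sum of the weights over k = 2, ..., k_max (exact rational arithmetic)."""
    return sum(higher_order_weight(L, k) for k in range(2, k_max + 1))
```

Under this pattern the total over all $k\geq 2$ stays below $\frac{2L}{49}$ units, whereas cases (1) and (2) are each bounded by $L$ units, which is the sense in which the higher-order terms are negligible.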

and finally using the randomness of $B$ we have with probability at least $1-e^{-\Omega(\rho^{2})}$

Appendix F Stability: After Adversarial Perturbation

In this section we study a scenario where the random initialization matrix $W$ is perturbed to $W+W^{\prime}$ with $W^{\prime}$ being small in spectral norm. Intuitively, this $W^{\prime}$ will later capture how much SGD has moved away from the random initialization, so it may depend on the randomness of $W,A,B$. To untangle this possibly complicated correlation, we consider stability with respect to all $W^{\prime}$ being small. The following lemma has appeared already in the “Stability After Adversarial Perturbation” section of .

Let $W,A,B$ be at random initialization, $x$ be a fixed normalized input sequence, and $\Delta\in[\varrho^{-100},\varrho^{100}]$. With probability at least $1-e^{-\Omega(\rho)}$ over the randomness of $W,A,B$, the following holds. Given any perturbation matrix $W^{\prime}$ with $\|W^{\prime}\|_{2}\leq\frac{\Delta}{\sqrt{m}}$, letting

$\|h_{i}^{\prime}\|\leq O(\rho^{6}\Delta/\sqrt{m})$ for every $i\in[L]$ (forward stability)

$\|D^{\prime}_{i}\|_{0}\leq O(\rho^{4}\Delta^{2/3}m^{2/3})$ for every $i\in[L]$ (sign change)

$\|\operatorname{\mathsf{Back}}_{i\to j}^{\prime}\|_{2}\leq O\big(\Delta^{1/3}\rho^{6}m^{1/3}\big)$ for every $1\leq i\leq j\leq L$ (backward stability)

Specifically, Lemma lem:stability:adva and lem:stability:advb can be found in Lemma C.2 and Lemma lem:stability:advc can be found in Lemma C.9 of [2, ver.3].

Appendix G Proof for Section 5

We now revise (G.1) in two ways without changing much of its original proof.

First, the above $b_{0}\sim\mathcal{N}(0,1)$ is quite an arbitrary choice in their proof, and can be replaced with any other $b_{0}\sim\mathcal{N}(0,\tau^{2})$ for constant $\tau\in(0,1]$. They have constructed $H^{\Psi}$ by first expanding $\Psi$ into its Taylor expansion, and then approximating each term $x^{i}$ with $h_{i}(z)$, the probabilists’ Hermite polynomial of degree $i$, with $z=\langle v,v^{*}\rangle$. Only the coefficient in front of each $h_{i}(z)$, namely $c^{\prime}_{i}$ in their Eq. (117), depends on the choice of $\tau$, and $c^{\prime}_{i}$ decreases from its original value as $\tau$ decreases from $1$. Therefore, their final construction of $H^{\Psi}$ will only have a smaller magnitude in these coefficients, so (G.1) remains unchanged if we choose $\tau=\frac{1}{\sqrt{3+4\sigma^{2}}}$.
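For readers who want the polynomials $h_{i}(z)$ concretely, the sketch below generates the probabilists’ Hermite polynomials through the standard three-term recurrence $h_{n+1}(z)=z\,h_{n}(z)-n\,h_{n-1}(z)$ with $h_{0}=1$ and $h_{1}=z$. It illustrates only the polynomial family; the coefficients $c^{\prime}_{i}$ and the construction of $H^{\Psi}$ from the prior work are not reproduced, and the function names are illustrative assumptions.

```python
def probabilists_hermite(degree):
    """Coefficients (lowest degree first) of the probabilists' Hermite polynomial.

    Uses the recurrence He_{n+1}(z) = z * He_n(z) - n * He_{n-1}(z).
    """
    previous, current = [1], [0, 1]  # He_0 = 1, He_1 = z
    if degree == 0:
        return previous
    for n in range(1, degree):
        shifted = [0] + current  # coefficients of z * He_n
        padded = previous + [0] * (len(shifted) - len(previous))
        following = [a - n * b for a, b in zip(shifted, padded)]
        previous, current = current, following
    return current


def evaluate(coefficients, z):
    """Evaluate a polynomial given lowest-degree-first coefficients (Horner's rule)."""
    value = 0.0
    for c in reversed(coefficients):
        value = value * z + c
    return value
```

For example, `probabilists_hermite(3)` gives the coefficient list of $z^{3}-3z$, and `evaluate` can then be applied at any $z=\langle v,v^{*}\rangle$.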

Parameter Choices. Having restated (G.2) and (G.3) from the prior work, let us choose parameters to apply them. Let us separate out the last coordinate for these vectors. Suppose

$v\sim\mathcal{N}(0,\mathbf{I})$, $\|y\|=1$, $b_{0}\sim\mathcal{N}(0,\frac{1}{3+4\sigma^{2}})$

$\langle v,y\rangle+b_{0}=\frac{1}{\sqrt{\sigma^{2}+3/4}}\big(\langle a_{\triangleleft},x^{\star}_{\triangleleft}\rangle+n+\frac{a_{\triangleright}}{2}\big)=\frac{1}{\sqrt{\sigma^{2}+3/4}}\big(\langle a,x^{\star}\rangle+n\big)$

$\Phi(\langle v^{*},y\rangle)=\Psi(\sqrt{\sigma^{2}+3/4}\langle v^{*},y\rangle)=\Phi(\langle w^{*}_{\triangleleft},x^{\star}_{\triangleleft}\rangle)=\Phi(\langle w^{*},x^{\star}\rangle)$

$\langle v,v^{*}\rangle=\langle w^{*}_{\triangleleft},a_{\triangleleft}\rangle$

The first statement above finishes the proof of Lemma lem:fit_fun_olda. We note that $\mathfrak{C}_{\varepsilon_{e}}(\Psi,1)=\mathfrak{C}_{\varepsilon_{e}}(\Phi,\sigma)$ because $\Phi$ is re-scaled from $\Psi$ by $\sqrt{\sigma^{2}+3/4}\leq O(\sigma)$.

Off Target. For Lemma lem:fit_fun_oldb, we can derive the following

Above, ① is because $|\langle a,x^{\star}\rangle|>\sqrt{\log(\gamma\sigma)}$ with probability at most $O\big(\frac{1}{\gamma\sigma}\big)$ and $|H|\leq C^{\prime}$; ② uses $|H|\leq C^{\prime}$. ∎

G.2 Missing Proof of Lemma 5.2

Using Lemma lem:done1a and Lemma lem:dropping_xsa, we can write (abbreviating by $\zeta_{n}=\zeta_{n}(\varepsilon_{x},i-1)$ and $\zeta_{d}=\zeta_{d}(\varepsilon_{x},i-1)$),

Note that by Lemma lem:done1a and Lemma lem:dropping_xsa we have

to get the $H$ with $|H|\leq C^{\prime}$, and let us define

A standard property of Gaussian random variables shows that:

We can decompose $h_{i-1}$ into the projection on $h_{i-1}^{(0)}$ and on the perpendicular space of $h_{i-1}^{(0)}$. We can write $(\widetilde{g}_{i})_{k}$ as:

By the fact that $\|h^{(0)}_{i-1}\|,\|h_{i-1}\|\in[0.9,2]$ from Lemma lem:done1a, we know that $\big|\frac{\langle h^{(0)}_{i-1},h_{i-1}\rangle}{\|h^{(0)}_{i-1}\|^{2}}\big|\leq 3$. As a result, conditioning on $|\langle\widetilde{w}_{k},h_{i-1}^{(0)}\rangle|\leq\frac{\varepsilon_{c}}{\sqrt{m}}$, we have

Since $\langle\widetilde{a}_{k},x_{i}\rangle\sim\mathcal{N}\big(0,\frac{2\varepsilon_{x}^{2}}{m}\big)$ is independent of the randomness of $\widetilde{w}_{k}$, equation (G.5) implies

Above, ① uses (G.6) and $|H|\leq C^{\prime}$.

Next, recall from (G.4) that $\big|\|(I-\widehat{h}\widehat{h}^{\top})h_{i-1}\|-\tau\big|=O\big(\frac{\rho^{2}}{\sqrt{m}}\big)$. As a result, with probability at least $1-e^{-\Omega(\rho^{2})}$ over $\widetilde{w}_{k}$,

Above, ① is because $\big\langle\widetilde{w}_{k},\frac{(I-\widehat{h}\widehat{h}^{\top})h_{i-1}}{\|(I-\widehat{h}\widehat{h}^{\top})h_{i-1}\|}\big\rangle\sim\mathcal{N}\big(0,\frac{2}{m}\big)$, so with probability at least $1-e^{-\Omega(\rho^{2})}$ over $\widetilde{w}_{k}$ its absolute value is at most $O(\rho/\sqrt{m})$. Using (G.8), together with a similar argument to (G.6), we have

Above, ① is because of (G.9) and the definition of $n_{k}$; and ② is because of Lemma 5.1 with $x_{i}=(\varepsilon_{x}x^{\star}_{i},0)$ and $\sigma=\frac{\tau}{\varepsilon_{x}}=O(\sqrt{L})$ and re-scaling. Putting together (G.7) and (G.10) finishes the proof for the on target part.
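The concentration step used in this part (a $\mathcal{N}(0,\frac{2}{m})$ variable is at most $O(\rho/\sqrt{m})$ in absolute value with probability at least $1-e^{-\Omega(\rho^{2})}$) is the standard Gaussian tail bound $\Pr[|Z|>t]\leq 2e^{-t^{2}/2}$ for $Z\sim\mathcal{N}(0,1)$. The sketch below compares the exact tail with that bound; the function names and the constant $1$ in the threshold $\rho\sqrt{2/m}$ are assumptions chosen for illustration, not the constants of the proof.

```python
import math


def gaussian_abs_tail(t):
    """Exact P(|Z| > t) for Z ~ N(0, 1), via the complementary error function."""
    return math.erfc(t / math.sqrt(2.0))


def gaussian_tail_bound(t):
    """Standard upper bound 2 * exp(-t^2 / 2) on P(|Z| > t)."""
    return 2.0 * math.exp(-t * t / 2.0)


def deviation_threshold(rho, m):
    """rho standard deviations of N(0, 2/m), i.e. rho * sqrt(2/m)."""
    return rho * math.sqrt(2.0 / m)
```

A variable distributed as $\mathcal{N}(0,\frac{2}{m})$ exceeds `deviation_threshold(rho, m)` in absolute value with probability exactly `gaussian_abs_tail(rho)`, which decays as $e^{-\Omega(\rho^{2})}$.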

We again decompose $h_{i^{\prime}-1}$ into the projection on $h_{i-1}^{(0)}$ and on the perpendicular space of $h_{i-1}^{(0)}$. We can write $(\widetilde{g}_{i^{\prime}})_{k}$ as:

By the fact that $\|h^{(0)}_{i-1}\|,\|h_{i^{\prime}-1}\|\in[0.9,2]$ from Lemma lem:done1a, we again know $\big|\frac{\langle h^{(0)}_{i-1},h_{i^{\prime}-1}\rangle}{\|h^{(0)}_{i-1}\|^{2}}\big|\leq 3$. Therefore, the same derivation of (G.7) implies

Using Lemma lem:fit_fun_oldb, we have for every $\gamma>1$,

Above, ① is by applying Lemma lem:fit_fun_oldb after re-scaling and $\Phi_{i\to j,r,s}(0)=0$.

Using Lemma lem:done1k, we know for $i^{\prime}>i$ with high probability $\|(I-\widehat{h}\widehat{h}^{\top})h^{(0)}_{i^{\prime}-1}\|\geq\Omega\big(\frac{1}{L^{2}\log^{3}m}\big)$. (One can argue similarly with the help of Lemma B.1 for $i^{\prime}<i$.) Thus, by the closeness property $\|h_{i^{\prime}-1}-h^{(0)}_{i^{\prime}-1}\|\leq O(\sqrt{L}\varepsilon_{x})$ from Lemma lem:dropping_xsa, we also have with probability at least $1-e^{-\Omega(\rho^{2})}$ over the randomness of $W,A$,

Putting this into (G.12), we have with probability at least $1-e^{-\Omega(\rho^{2})}$,

Combining this with (G.11), and using our choice of $\varepsilon_{c}$, finishes the proof for the off target part.

G.3 Missing Proof of Lemma 5.4

Let us choose $s=O(\rho^{5}N^{2/3}/m^{1/6})$, and by Lemma lem:done1d, we have with probability $1-e^{-\Omega(\rho^{2})}$ over the randomness of $W,A$,

Using the randomness of $\mathcal{N}$, we have with probability $1-e^{-\Omega(\rho^{2})}$ over the randomness of $W,A,\mathcal{N}$,

Let us define $\mathcal{N}_{1}$ to be the above subset. Since $[g_{i^{\prime}}]_{k}=\langle w_{k},h_{i^{\prime}-1}\rangle+\langle a_{k},x_{i^{\prime}}\rangle$ and since $|\langle w_{k},h_{i^{\prime}-1}\rangle-\langle w_{k},\widetilde{h}_{i^{\prime}-1}\rangle|\leq\frac{s}{\sqrt{m}}$ from (G.15), we have

For a similar reason, using $\|\widetilde{h}_{i-1}^{(0)}\|\in[0.5,3]$ from Lemma lem:done1a and the independence between $w_{k}$ and $h^{(0)}_{i-1}$, we know with probability $1-e^{-\Omega(\rho^{2})}$ over the randomness of $W,A,\widetilde{W},\widetilde{A},\mathcal{N}$,

Let us define $\mathcal{N}_{2}$ to be the above subset. Using $|\langle w_{k},h^{(0)}_{i-1}-\widetilde{h}^{(0)}_{i-1}\rangle|\leq\frac{s}{\sqrt{m}}$ from (G.15) again, we have

Together, we have $|\mathcal{N}_{1}\cup\mathcal{N}_{2}|\leq O(sN+\rho^{2})=O(\rho^{5}N^{5/3}/m^{1/6}+\rho^{2})$. Let us choose

Let us now fix $i,i^{\prime},j,j^{\prime},s,s^{\prime},r$. Summing over $k\in\mathcal{N}\setminus\mathcal{N}_{1}\cup\mathcal{N}_{2}$, we have with probability $1-e^{-\Omega(\rho^{2})}$,

Thus, putting together the two summations we have

Now, we consider fixed $\widetilde{W},\widetilde{A},\mathcal{N}$ and only use the randomness of $\{w_{k},a_{k}\}_{k\in\mathcal{N}}$ to analyze $\widetilde{G}_{i,i^{\prime},j,j^{\prime},r,s,s^{\prime}}$. It is a summation of $N$ independent random variables. By Chernoff bound, with probability $1-e^{-\Omega(\rho^{2})}$,

By Lemma 5.2 we know that the expectation is given by:

Using (G.14) and (G.16) again, we can bound

Therefore, so far we have calculated that with probability at least $1-e^{-\Omega(\rho^{2})}$ over the randomness of $W,A,\mathcal{N}$:

Summing it up, and using $C_{i\to j,s}\geq\Omega(\frac{1}{d})$ from Claim claim:Ws-constructb, we can write

with $error=\frac{p}{m}\cdot O(\varepsilon_{e}\rho^{5}m+Cm\rho^{11}(N/m)^{1/6}+C\frac{m}{N}\rho^{7}+C\rho^{5}\frac{m}{\sqrt{N}})$.

When $s\neq s^{\prime}$, using the randomness of $B$, it is easy to see with probability at least $1-e^{-\Omega(\rho^{2})}$

where the last equality is due to Lemma lem:done1i and $\|B\|_{2}\leq O(\rho\sqrt{m})$. Therefore, we can write

Above, ① uses our choice $N=\Theta(m^{0.1}/\rho^{2})$.

With probability at least $1-e^{-\Omega(\rho^{2})}$ over the random initialization $W,A,B$, we have

$\|W^{\divideontimes}\|_{2,\infty}\leq O\big(\frac{p\rho^{3}C}{m}\big)$ so $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$.

$\Omega(\frac{1}{d})\leq C_{i\to j,s}\leq O(\frac{\rho^{2}}{d})$ for every $1\leq i\leq j\leq L$ and $s\in[d]$.

where the notation $C$ comes from Definition 5.3.

For the norm bound on $W^{\divideontimes}$, note that for any unit vector $z$, we have

and as a result $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$. Lemma lem:done1a and Lemma lem:done1g together imply $C_{i\to j,s}\geq\Omega(\frac{1}{d})$. Lemma lem:done1a and Lemma lem:done1f together imply $C_{i\to j,s}\leq O(\frac{\rho^{2}}{d})$. ∎
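The passage from the $\|\cdot\|_{2,\infty}$ bound to the Frobenius bound loses exactly a factor $\sqrt{m}$: if every one of the $m$ rows has Euclidean norm at most $a$, then $\|W\|_{F}\leq\sqrt{m}\,a$. The sketch below checks this relation on small matrices. It assumes that $\|\cdot\|_{2,\infty}$ denotes the largest Euclidean row norm and that the matrix has $m$ rows; the helper names and the random test matrix are illustrative only.

```python
import math
import random


def row_norms(matrix):
    """Euclidean norm of each row of a matrix given as a list of lists."""
    return [math.sqrt(sum(v * v for v in row)) for row in matrix]


def norm_2_inf(matrix):
    """Largest Euclidean row norm (the assumed meaning of the (2, infinity) norm)."""
    return max(row_norms(matrix))


def frobenius_norm(matrix):
    """Square root of the sum of squared entries."""
    return math.sqrt(sum(r * r for r in row_norms(matrix)))


def random_matrix(rows, cols, seed=0):
    """A Gaussian test matrix; the seed makes the sketch reproducible."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]
```

With $a=O\big(\frac{p\rho^{3}C}{m}\big)$ this gives $\sqrt{m}\cdot a=O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$, matching the stated Frobenius bound.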

Appendix H Proof for Section 6

$\|BDW\cdots DWD^{\prime}\|_{2}\leq O(\rho\sqrt{\|D^{\prime}\|_{0}})\leq O(\rho^{3}\Delta^{1/3}m^{1/3})$ by Lemma lem:done1f

$\|D^{\prime}WD\cdots DWD^{\prime}\|_{2}\leq O(\rho\frac{\sqrt{\|D^{\prime}\|_{0}}}{\sqrt{m}})\leq\frac{1}{100L}$ by Lemma lem:done1j.

$\|(WD\cdots WD)W^{\prime}(h+h^{\prime})\|\leq\|WD\cdots WD\|_{2}\|W^{\prime}\|_{2}\|h+h^{\prime}\|\leq O\big(\frac{L^{3}\Delta}{\sqrt{m}}\big)$ by Lemma lem:done1i (and $\|h+h^{\prime}\|\leq O(1)$ using Lemma lem:done1a with Lemma lem:stability:adva).

Putting (H.2) and (H.3) back to (H.1) finishes the proof. ∎
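The first two bounds in this proof come from substituting the sign-change bound $\|D^{\prime}\|_{0}\leq O(\rho^{4}\Delta^{2/3}m^{2/3})$ into $\rho\sqrt{\|D^{\prime}\|_{0}}$ and $\rho\sqrt{\|D^{\prime}\|_{0}/m}$. The sketch below does that exponent bookkeeping with exact fractions. It tracks only the exponents of $\rho$, $\Delta$ and $m$ and ignores the constants hidden in $O(\cdot)$; the dictionary representation and helper names are assumptions made for illustration.

```python
from fractions import Fraction


def multiply(a, b):
    """Multiply two monomials stored as {symbol: exponent} dictionaries."""
    out = dict(a)
    for name, exponent in b.items():
        out[name] = out.get(name, Fraction(0)) + exponent
    return out


def power(a, exponent):
    """Raise a monomial to a rational power."""
    return {name: e * exponent for name, e in a.items()}


# Sign-change bound on ||D'||_0: rho^4 * Delta^(2/3) * m^(2/3).
SIGN_CHANGE = {"rho": Fraction(4), "Delta": Fraction(2, 3), "m": Fraction(2, 3)}

# rho * sqrt(||D'||_0)
first_bound = multiply({"rho": Fraction(1)}, power(SIGN_CHANGE, Fraction(1, 2)))

# rho * sqrt(||D'||_0 / m)
second_bound = multiply(first_bound, {"m": Fraction(-1, 2)})
```

Here `first_bound` has exponents $(3,\frac{1}{3},\frac{1}{3})$ for $(\rho,\Delta,m)$, which is the stated $O(\rho^{3}\Delta^{1/3}m^{1/3})$, while `second_bound` carries $m^{-1/6}$, so it falls below $\frac{1}{100L}$ once $m$ is sufficiently large.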

H.2 Proof of Lemma 6.2

We analyze the three error terms separately.

$\|\clubsuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}^{\prime}_{i^{\prime}\to j^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot O(1)\leq O\big(\frac{\omega\rho^{6}\Delta^{1/3}}{m^{1/6}}\big)$ using Lemma lem:stability:advc.

$\|\spadesuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot\|h^{\prime}_{i^{\prime}}\|\leq O\big(\frac{\omega\rho^{7}\Delta}{\sqrt{m}}\big)$ using Lemma lem:done1f and Lemma lem:stability:adva.

$\|\diamondsuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}D^{\prime}_{i^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot O(1)\leq O\big(\frac{\omega\rho^{2}\Delta^{1/3}}{m^{1/6}}\big)$ using Lemma lem:done1f and the bound $\|D^{\prime}_{i^{\prime}}\|_{0}$ from Lemma lem:stability:advb. ∎

Appendix I Proof for Lemma 7.1

Since it satisfies $\|F^{*}_{j}(x^{\star})\|\leq O(\sqrt{pLd}\,\mathfrak{C}_{\mathfrak{s}}(\Phi,1))$ for all $x^{\star}$, by the 1-Lipschitz continuity of $G(\cdot,y^{\star})$, we also have $|G(F^{*}_{j}(x^{\star}),y^{\star}_{j})|\leq O(\sqrt{pLd}\,\mathfrak{C}_{\mathfrak{s}}(\Phi,1))$. Therefore, by Chernoff bound, as long as $N\geq\Omega(\frac{\rho^{3}\cdot p\cdot\mathfrak{C}_{\mathfrak{s}}^{2}(\Phi,1)}{\varepsilon^{2}})$ we also have

Let $x$ be a normalized input sequence generated by some $x^{\star}$ in the support of $\mathcal{D}$. Consider an iteration $t$ where the current weight matrix is $W+W_{t}$. Let

which is a linear function over $\widetilde{W}$. Let us define a loss function $\widetilde{G}$ as:

Let $W^{\divideontimes}$ be defined in Definition 5.3. By Lemma 6.3, we know that as long as $\|W_{t}\|_{2}\leq\frac{\Delta}{\sqrt{m}}$ (for some parameter $\Delta\in[\varrho^{-100},\varrho^{100}]$ we shall choose at the end),

where the last inequality is by choosing $\varepsilon_{e}=\Theta(\frac{\varepsilon}{p\rho^{13}})$ and $\varepsilon_{x}=\frac{1}{\operatorname{poly}(\rho,p,\varepsilon^{-1},C^{\prime})}$ and sufficiently large $m$. Taking union bound over all samples in $\mathcal{Z}$, by the 1-Lipschitz continuity of $G$, we have

Using 1-Lipschitz continuity of $G$, we have

where ① is by our choice of $\lambda$ which implies $\lambda\|F_{j}(x^{\star};W)\|\leq\lambda\cdot O(\rho)\leq\frac{\varepsilon}{10L}$ by Lemma lem:done1h.

Thus, by the convexity of $\widetilde{G}(\widetilde{W})$ (the composition of a convex function with a linear function is convex), we know

Finally, recall that SGD takes a stochastic gradient so

On one hand, we have (for any $(x^{\star},y^{\star})\sim\mathcal{D}$)

Above, ① uses the $1$-Lipschitz continuity of $G$, and ② uses Lemma B.1 and the choice of $\lambda\leq\frac{\varepsilon}{L\rho}$.

Telescoping over all $t=0,1,\dots,T-1$, we have

so we can choose $\Delta=\Theta\big(\frac{C^{2}\rho^{11}p^{2}}{\varepsilon^{2}}\big)$.

Finally, we can replace the notation $\mathfrak{C}_{\varepsilon_{e}}$ with $\mathfrak{C}_{\varepsilon}$ because $\log(1/\varepsilon_{e})=O(\log(1/\varepsilon))$. ∎

References