Can SGD Learn Recurrent Neural Networks with Provable Generalization?
Zeyuan Allen-Zhu, Yuanzhi Li
Introduction
Recurrent neural networks (RNNs) are among the most popular models in sequential data analysis. When processing an input sequence, RNNs repeatedly and sequentially apply the same operation to each input token. The recurrent structure of RNNs allows them to capture the dependencies among different tokens inside each sequence, which has been shown empirically to be effective in many applications such as natural language processing, speech recognition, and so on.
The recurrent structure of RNNs shows great power in practice; however, it also poses great challenges in theory. Until now, RNNs remain one of the least theoretically understood models in deep learning. Many fundamental open questions about RNNs are still largely unsolved, including
(Optimization). When can RNNs be trained efficiently?
(Generalization). When do the results learned by RNNs generalize to test data?
Question 1 is technically challenging due to the notorious problem of vanishing/exploding gradients, and the non-convexity of the training objective induced by non-linear activation functions.
Question 2 requires an even deeper understanding of RNNs. For example, in natural language processing, “Juventus beats Barcelona” and “Barcelona beats Juventus” have completely different meanings. How can the same operation in an RNN encode a different rule for “Juventus” at token 1 vs. “Juventus” at token 3, instead of merely memorizing each training example?
There has been some recent progress towards a more principled understanding of these questions.
On the optimization side, Hardt, Ma, and Recht show that over-parameterization can help in the training process of a linear dynamical system, which is a special case of RNNs with linear activation functions. Allen-Zhu, Li, and Song show that over-parameterization also helps in training RNNs with ReLU activations. This latter result gives no generalization guarantee.
Indeed, bridging the gap between optimization (question 1) and generalization (question 2) can be quite challenging in neural networks. This is particularly so for RNNs, due to the (potentially) exponential blowup in the input length.
Generalization ⇒ Optimization. One could imagine adding a strong regularizer to ensure $\|W\|_2\leq 1$ for generalization purposes; however, it is unclear how an optimization algorithm such as stochastic gradient descent (SGD) finds a network that both minimizes the training loss and maintains $\|W\|_2\leq 1$. One could also use a very small network so that the number of parameters is limited; however, it is not clear how SGD finds a small network with small training loss.
Optimization ⇒ Generalization. One could try to train RNNs without any regularization; however, it is then quite possible that the number of parameters needs to be large and $\|W\|_2>1$ after training. This is so both in practice (since “memory implies larger spectral radius”) and in theory. All known generalization bounds fail to apply in this regime.
In this paper, we give arguably the first theoretical analysis of RNNs that captures optimization and generalization simultaneously. Given any set of input sequences, as long as the outputs are (approximately) realizable by some smooth function in a certain concept class, then after training a vanilla RNN with ReLU activations, SGD provably finds a solution that has both small training and generalization error. Our result allows $\|W\|_2$ to be larger than $1$ by a constant, but is still efficient: the iteration complexity of SGD, the sample complexity, and the time complexity scale only polynomially (or almost polynomially) with the length of the input.
Notations
Note that in the case that is the zero vector, we let be an arbitrary unit vector that is orthogonal to .
where is a sufficiently large constant (e.g., ). It holds , and for or low degree polynomials, they only differ by .
Problem Formulation
where is a parameter to be chosen later. We then feed this actual sequence into the RNN.
In this way we have ensured that the actual input sequence is normalized:
We say that $W, A, B$ are at random initialization, if the entries of $W$ and $A$ are i.i.d. generated from $\mathcal{N}(0,\frac{2}{m})$, and the entries of $B$ are i.i.d. generated from $\mathcal{N}(0,\frac{1}{d})$.
Since we only update $W$, the label sequence is off from the input sequence by one. The last can be made zero, but we keep it normalized for notational simplicity. The first $x_1$ gives a random seed fed into the RNN (one can equivalently put it into ). We have scaled down the input signals by $\varepsilon_x$, which can be equivalently thought of as scaling down $A$.
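A minimal pure-Python sketch of this setup, for readers who want something concrete. It assumes the vanilla recurrence $h_0=0$, $h_\ell=\mathrm{ReLU}(Wh_{\ell-1}+Ax_\ell)$, $y_\ell=Bh_\ell$, Gaussian initialization with the entries of $W,A$ from $\mathcal{N}(0,\frac{2}{m})$ and those of $B$ from $\mathcal{N}(0,\frac{1}{d})$, and the normalized input $x_1=(0,\dots,0,1)$, $x_\ell=(\varepsilon_x x^\star_\ell,0)$ for later tokens with unit-norm $x^\star_\ell$. These forms are our reading of the formulation (they match distributions used later, e.g. $\langle\widetilde{a}_k,x_i\rangle\sim\mathcal{N}(0,\frac{2\varepsilon_x^2}{m})$); all function names are hypothetical.

```python
import math
import random


def gaussian_matrix(rows, cols, var, rng):
    """Matrix with i.i.d. N(0, var) entries, stored as a list of rows."""
    std = math.sqrt(var)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def random_init(m, d, seed=0):
    """Assumed initialization: W, A entries ~ N(0, 2/m); B entries ~ N(0, 1/d)."""
    rng = random.Random(seed)
    W = gaussian_matrix(m, m, 2.0 / m, rng)
    A = gaussian_matrix(m, d + 1, 2.0 / m, rng)
    B = gaussian_matrix(d, m, 1.0 / d, rng)
    return W, A, B


def normalize_inputs(x_star, eps_x):
    """Prepend the seed token x_1 = (0,...,0,1); map each x* to (eps_x * x*, 0)."""
    d = len(x_star[0])
    seq = [[0.0] * d + [1.0]]
    seq += [[eps_x * c for c in x] + [0.0] for x in x_star]
    return seq


def forward(W, A, B, xs):
    """h_0 = 0, h_l = ReLU(W h_{l-1} + A x_l), y_l = B h_l; returns all h_l and y_l."""
    h = [0.0] * len(W)
    hs, ys = [], []
    for x in xs:
        g = [u + v for u, v in zip(matvec(W, h), matvec(A, x))]
        h = [max(0.0, gi) for gi in g]  # ReLU
        hs.append(h)
        ys.append(matvec(B, h))
    return hs, ys


# Illustrative run: m = 64 neurons, d = 4, four random unit-norm tokens.
m, d, eps_x = 64, 4, 0.1
W, A, B = random_init(m, d)
rng = random.Random(1)
x_star = []
for _ in range(4):
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(c * c for c in v))
    x_star.append([c / norm for c in v])
xs = normalize_inputs(x_star, eps_x)
hs, ys = forward(W, A, B, xs)
print(len(xs), [round(y, 3) for y in ys[-1]])
```

Only $W$ would be trained in this setting; $A$ and $B$ stay at their random values, which is why the sketch keeps them as plain fixed matrices.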
2 Concept Class
For proof simplicity, we assume . We also use
Agnostic PAC-learning language. Our concept class consists of all functions in the form of (3.1) with complexity bounded by threshold and parameter bounded by threshold . Let be the population risk achieved by the best target function in this concept class. Then, our goal is to learn this concept class with population risk using sample and time complexity polynomial in , and . In the remainder of this paper, to simplify notations, we do not explicitly define this concept class parameterized by and . Instead, we equivalently state our theorem with respect to any (unknown) target function with specific parameters and .
Our concept class is general enough and contains functions where the output at each token is generated from inputs of previous tokens using any two-layer neural network. Indeed, one can verify that our general form (3.1) includes functions of the following:
Our Result: RNN Provably Learns the Concept Class
Suppose the distribution is generated by some (unknown) target function of the form (3.1) in the concept class with population risk , namely,
and suppose we are given a training dataset consisting of i.i.d. samples from . We consider the following stochastic training objective
For every $0<\varepsilon<\widetilde{O}\big(\frac{1}{\mathrm{poly}(L,d)\cdot p\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, define complexity and $\lambda=\widetilde{\Theta}\big(\frac{\varepsilon}{L^{2}d}\big)$; if the number of neurons and the number of samples is , then SGD with $\eta=\widetilde{\Theta}\big(\frac{1}{\varepsilon L^{2}d^{2}m}\big)$ and
satisfies that, with probability at least over the random initialization
Sample complexity. Our sample complexity only scales with , making the result applicable to over-parameterized RNNs that have . Following Example 2.1, if is constant degree polynomial we have so Theorem 1 says that RNN learns such concept class
Non-linear measurements. Our result shows that vanilla RNNs can efficiently learn a weighted average of non-linear measurements of the input. As we argued in Example 3.5, this at least includes functions where the output at each token is generated from inputs of previous tokens using any two-layer neural network. Averages of non-linear measurements can be quite powerful: they achieve state-of-the-art performance in some sequential applications such as sentence embedding, among many others, and act as the basis of the attention mechanism in RNNs.
In our result, the function only adapts to the positions of the input tokens, but in many applications we would like the function to adapt to the values of the past tokens as well. We believe a study of other models (such as LSTM) can potentially settle these questions.
Comparison to feed-forward networks. Recently, there have been many interesting results on analyzing the learning process of feed-forward neural networks. Most of them either assume that the input is structured (e.g., Gaussian or separable) or only consider linear networks. Allen-Zhu, Li, and Liang show a result in the same flavor as this paper but for two- and three-layer networks. Since RNNs apply the same unit repeatedly to each input token in a sequence, this weight sharing creates many difficulties in the analysis, and our analysis is significantly different from theirs.
2 Conclusion
We show that RNNs can actually learn some notable concept class efficiently, using the simple SGD method with sample complexity polynomial or almost polynomial in the input length. This concept class at least includes functions where each output token is generated from inputs of earlier tokens using a smooth neural network. To the best of our knowledge, this is the first proof that some non-trivial concept class is efficiently learnable by RNNs. Our sample complexity is almost independent of $m$, making the result applicable to over-parameterized settings. On a separate note, our proof explains why the same recurrent unit is capable of learning various functions from different input tokens to different output tokens.
Our proof of Theorem 1 divides into four conceptual steps.
We obtain a first-order approximation of how much the outputs of the RNN change if we move from $W$ to $W+W'$. This change (up to a small error) is a linear function in $W'$ (see Section 6).
(This step can be derived from prior work without much difficulty.)
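We show that there exists a matrix $W^{\divideontimes}$ with small Frobenius norm such that this linear function, evaluated at $W^{\divideontimes}$, approximates the target function (see Section 5).
(This step is the most interesting part of this paper.)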
We argue that the SGD method moves in a direction nearly as good as $W^{\divideontimes}$ and thus efficiently decreases the training objective (see Section 7).
(This is a routine analysis of SGD in the non-convex setting given Steps 1&2.)
We use the first-order linear approximation to derive a Rademacher complexity bound that does not grow exponentially in $L$ (see Section 8). By feeding the output of SGD into this Rademacher complexity bound, we finish the proof of Theorem 1 (see Section 9).
(This is a one-page proof given Steps 1, 2, and 3.)
Although our proofs are technical, to help the readers, we write 7 pages of sketch proofs for Steps 1 through 4. These can be found in Sections 5 through 9. Our final proofs rely on many other technical properties of RNNs that may be of independent interest, such as properties of RNNs at random initialization (which we include in Sections B and C) and properties of RNN stability (which we include in Sections D, E, and F). Some of these properties are simple modifications from prior work, but some are completely new and require new proof techniques (namely, Sections C, D, and E). We introduce some notations for analysis purposes.
Throughout the proofs, to simplify notations when specifying polynomial factors, we introduce
We assume for some sufficiently large polynomial factor.
Existence of Good Network Through Backward
Furthermore, $W^{\divideontimes}$ is appropriately bounded in Frobenius norm. In our sketched proof below, it shall become clear how this same matrix can simultaneously represent functions that come from different input tokens . Since SGD can be shown to descend in a direction “comparable” to $W^{\divideontimes}$, it converges to a matrix with similar guarantees.
In order to show (5.2), we first show a variant of the “indicator to function” lemma from prior work.
Above, Lemma 5.1a says that we can use a bounded function to fit a target function , and Lemma 5.1b says that if the magnitude of is large then this function is close to being constant. For this reason, we can view as “noise.” While the proof of Lemma 5.1a is from prior work, our new property Lemma 5.1b is completely new, and it requires overcoming some technical challenges to simultaneously guarantee Lemma 5.1a and Lemma 5.1b. The proof is in Appendix G.1.
2 Fitting a Single Function
We now try to apply Lemma 5.1 to approximate a single function . For this purpose, let us consider two (normalized) input sequences. The first (null) sequence is given as
The second sequence is generated from an input in the support of (recall Definition 3.1). Let
For every , and every constant $\varepsilon_{e}\in\big(0,\frac{1}{\mathfrak{C}_{\mathfrak{s}}(\Phi_{i\to j,r,s},O(\sqrt{L}))}\big)$, there exists so that, for every
be a fixed input sequence defined by some in the support of (see Definition 3.1),
be fresh random vectors,
with probability at least over and ,
(off target), for every
Lemma 5.2 implies there is a quantity that only depends on the target function and the random initialization (namely, ) such that,
when multiplying gives the target , but
when multiplying gives near zero.
The full proof is in Appendix G.2 but we sketch why Lemma 5.2 can be derived from Lemma 5.1.
Let us focus on indicator :
is distributed like because $\langle\widetilde{a}_{k},x_{i'}\rangle=\langle\widetilde{a}_{k},(\varepsilon_{x}x^{\star}_{i'},0)\rangle$; but
is roughly because by random init. (see Lemma B.1a).
Thus, if we treat as the “noise” in Lemma 5.1, it can be times larger than .
To show Lemma 5.2a, we only need to focus on because . Since can be shown close to (see Lemma D.1), this is almost equivalent to . Conditioning on this event, the “noise” must be small, so we can apply Lemma 5.1a.
To show Lemma 5.2b, we can show that when , the indicator on gives little information about the true noise . This is because and are somewhat uncorrelated (details in Lemma B.1k). As a result, the “noise” is still large and thus Lemma 5.1b applies with . ∎
3 Fitting the Target Function
Suppose $\varepsilon_{e}\in\big(0,\frac{1}{\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, , $\varepsilon_{x}\in\big(0,\frac{1}{\rho^{4}C'}\big)$, we choose
The following lemma says that is close to the target function .
The construction of in Definition 5.3 satisfies the following. For every normalized input sequence generated from in the support of , we have with probability at least over , it holds for every and
Using the definition of in (5.1) and , one can write down
The summands in (5.3) with are negligible owing to Lemma 5.2b.
The summands in (5.3) with but are negligible, after proving that and are very uncorrelated (details in Lemma C.1).
The summands in (5.3) with are negligible using the randomness of .
One can also prove and (details in Lemma D.1).
Applying Lemma 5.2a and using our choice of , this gives (in expectation)
Proving concentration (with respect to ) is a lot more challenging due to the sophisticated correlations across different indices . To achieve this, we replace some of the pairs with fresh samples for all and apply concentration only with respect to . Here, is a random subset of with cardinality . We show that the network stabilizes against such re-randomization (details in Section E). The full proof is in Section G.3. ∎
Finally, one can show $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$ (see Claim G.1). Crucially, this Frobenius norm scales as $1/\sqrt{m}$, so standard SGD analysis ensures that our sample complexity does not depend on $m$ (up to log factors).
Coupling and First-Order Approximation
Consider now the scenario where the random initialization matrix $W$ is perturbed to $W+W'$ with $W'$ being small in spectral norm. Intuitively, this will later capture how much SGD has moved away from the random initialization, so it may depend on the randomness of . To untangle this possibly complicated correlation, all lemmas in this section hold for all $W'$ being small.
The first lemma below states that the $\ell$-th layer output difference can be approximated by a linear function in $W'$, that is . We remind the reader that this linear function in $W'$ is exactly the same as our notation of from (5.2).
Let be at random initialization, be a fixed normalized input sequence, and . With probability at least over the following holds. Given any perturbation matrix with , letting
The proof of Lemma 6.1 is similar to the semi-smoothness theorem of prior work and can be found in Section H.1.
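A numerical sketch of the first-order approximation in Lemma 6.1, under an assumed recurrence $h_\ell=\mathrm{ReLU}(Wh_{\ell-1}+Ax_\ell)$, $y_\ell=Bh_\ell$ and Gaussian initialization. We take the linear map to be $\sum_{i\le\ell}\mathsf{Back}_{i\to\ell}D_iW'h_{i-1}$ with $\mathsf{Back}_{i\to\ell}=BD_\ell W\cdots D_{i+1}W$ and $D_i=\mathrm{diag}(\mathbb{1}[g_i\ge 0])$; this explicit form is our assumption about the notation, not a quotation. For a perturbation $W'$ this small, ReLU sign patterns essentially never flip, so the linearization should agree with the true change in $y_L$ up to second-order terms. Sizes and helper names are hypothetical.

```python
import math
import random


def gauss_mat(rows, cols, var, rng):
    s = math.sqrt(var)
    return [[rng.gauss(0.0, s) for _ in range(cols)] for _ in range(rows)]


def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def vadd(u, v):
    return [a + b for a, b in zip(u, v)]


def norm(v):
    return math.sqrt(sum(c * c for c in v))


def forward(W, A, B, xs):
    """Return pre-activations [g_1..g_L], states [h_0..h_L], and y_L = B h_L."""
    h = [0.0] * len(W)
    gs, hs = [], [h]
    for x in xs:
        g = vadd(matvec(W, h), matvec(A, x))
        h = [max(0.0, t) for t in g]
        gs.append(g)
        hs.append(h)
    return gs, hs, matvec(B, h)


def linearized_change(W, Wp, B, gs, hs):
    """First-order change of y_L under W -> W + Wp, with sign patterns D_l frozen.

    Recursion: dh_l = D_l (Wp h_{l-1} + W dh_{l-1}), dh_0 = 0. Unrolled, B dh_L equals
    sum_i Back_{i->L} D_i Wp h_{i-1}, where Back_{i->L} = B D_L W ... D_{i+1} W.
    """
    dh = [0.0] * len(W)
    for l, g in enumerate(gs):  # g is g_{l+1}; hs[l] is h_l
        pre = vadd(matvec(Wp, hs[l]), matvec(W, dh))
        dh = [p if gi >= 0 else 0.0 for p, gi in zip(pre, g)]
    return matvec(B, dh)


m, d, L, eps_x = 60, 3, 5, 0.1
rng = random.Random(0)
W = gauss_mat(m, m, 2.0 / m, rng)
A = gauss_mat(m, d + 1, 2.0 / m, rng)
B = gauss_mat(d, m, 1.0 / d, rng)
xs = [[0.0] * d + [1.0]]
xs += [[eps_x * rng.gauss(0.0, 1.0) for _ in range(d)] + [0.0] for _ in range(L - 1)]
Wp = gauss_mat(m, m, 1e-12 / m, rng)  # tiny perturbation, spectral norm on the order of 1e-6

gs, hs, y = forward(W, A, B, xs)
_, _, y_new = forward([vadd(r, rp) for r, rp in zip(W, Wp)], A, B, xs)
true_change = [a - b for a, b in zip(y_new, y)]
lin_change = linearized_change(W, Wp, B, gs, hs)
rel_err = norm([a - b for a, b in zip(true_change, lin_change)]) / norm(lin_change)
print(f"relative error of the linearization: {rel_err:.2e}")
```

The recursion form is used instead of materializing every $\mathsf{Back}_{i\to L}$ because it computes the same sum with one matrix-vector product per layer. Lemma 6.1 itself concerns much larger $W'$ (bounded spectral norm rather than vanishing), where sign flips do occur and the error term is controlled by the analysis rather than being negligible.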
The next lemma says that, for this linear function over , one can replace with without changing much in its output. It is a direct consequence of the adversarial stability properties of RNN from prior work (see Section F).
Let be at random initialization, be a fixed normalized input sequence, and . With probability at least over the following holds. Given any matrix with , and any with , letting
A direct corollary of Lemma 6.2 is that our matrix $W^{\divideontimes}$ constructed in Definition 5.3 satisfies the same property as Lemma 5.4 after perturbation. Namely,
in Definition 5.3 satisfies the following. Let be at random initialization, be a fixed normalized input sequence generated by in the support of , and . With probability at least over the following holds. Given any matrix with , any , and any :
Combining Lemma 5.4 and Lemma 6.2 gives the proof. ∎
Optimization and Convergence
Our main convergence lemma for SGD on the training objective is as follows.
For every constant $\varepsilon\in\big(0,\frac{1}{p\cdot\mathrm{poly}(\rho)\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,\sqrt{L})}\big)$, there exists and parameters
so that, as long as and , setting learning rate $\eta=\Theta\big(\frac{1}{\varepsilon\rho^{2}m}\big)$ and $T=\Theta\big(\frac{p^{2}C^{2}\mathrm{poly}(\rho)}{\varepsilon^{2}}\big)$, we have
and for .
The full proof is in Section I and we sketch the main idea here. Recall the training objective
Let be a normalized input sequence generated by some in the support of . Consider an iteration where the current weight matrix is . Let
which is a linear function over . Let us define a loss function as:
Let be defined in Definition 5.3. By Lemma 6.3, we know that as long as is small (which we shall ensure towards the end),
Thus, by the 1-Lipschitz continuity of , one can derive that
By Lemma 6.1 and Lemma 6.2 together, we know that
Using the linearity of and the 1-Lipschitz continuity of , we have
where ① is by our choice of which implies by Lemma B.1h.
Together, we have . Thus, by the convexity of (the composition of a convex function with a linear function is convex), we know
Suppose in this high-level sketch that we apply gradient descent as opposed to SGD. Then, and we have
Putting (7.1) into this formula, we know that as long as , then is a very negative term and thus, when is sufficiently small, it is guaranteed to decrease . This cannot happen for too many iterations, and thus we arrive at a convergence statement. ∎
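The convexity step above is what drives a standard (sub)gradient-descent argument. As a sketch of that generic argument (not necessarily the paper's exact bookkeeping, which uses SGD and the approximation of Lemma 6.3): if $f$ is convex, $g_t$ is a subgradient at $w_t$ with $\|g_t\|\le G$, and $w_{t+1}=w_t-\eta g_t$, then $\|w_{t+1}-w^\star\|^2\le\|w_t-w^\star\|^2-2\eta\,(f(w_t)-f(w^\star))+\eta^2G^2$, and summing over $t$ gives $\frac{1}{T}\sum_{t<T}(f(w_t)-f(w^\star))\le\frac{\|w_0-w^\star\|^2}{2\eta T}+\frac{\eta G^2}{2}$. The check below uses a hypothetical toy loss $f(w)=\sum_i|\langle a_i,w\rangle-b_i|$ (convex, piecewise linear, and realizable so that $f(w^\star)=0$) as a stand-in for the linearized objective.

```python
import math
import random


def f_and_subgrad(w, data):
    """f(w) = sum_i |<a_i, w> - b_i| and one valid subgradient at w."""
    val, grad = 0.0, [0.0] * len(w)
    for a, b in data:
        r = sum(ai * wi for ai, wi in zip(a, w)) - b
        val += abs(r)
        s = 1.0 if r > 0 else (-1.0 if r < 0 else 0.0)  # subgradient of |r|
        grad = [gi + s * ai for gi, ai in zip(grad, a)]
    return val, grad


def subgradient_descent(data, w0, eta, T):
    """Run w <- w - eta * g for T steps; return f(w_0..w_{T-1}) and max ||g_t||."""
    w, vals, gmax = list(w0), [], 0.0
    for _ in range(T):
        val, g = f_and_subgrad(w, data)
        vals.append(val)
        gmax = max(gmax, math.sqrt(sum(x * x for x in g)))
        w = [wi - eta * gi for wi, gi in zip(w, g)]
    return vals, gmax


rng = random.Random(0)
dim, n = 5, 20
w_star = [rng.gauss(0.0, 1.0) for _ in range(dim)]
data = []
for _ in range(n):
    a = [rng.gauss(0.0, 1.0) / math.sqrt(dim) for _ in range(dim)]
    data.append((a, sum(ai * wi for ai, wi in zip(a, w_star))))  # realizable: f(w*) = 0

w0 = [0.0] * dim
eta, T = 0.01, 500
vals, G = subgradient_descent(data, w0, eta, T)
avg_gap = sum(vals) / T  # equals the average of f(w_t) - f(w*) since f(w*) = 0
bound = sum((x - y) ** 2 for x, y in zip(w0, w_star)) / (2 * eta * T) + eta * G ** 2 / 2
print(f"average gap {avg_gap:.4f} <= bound {bound:.4f}")
```

The bound holds deterministically for any convex $f$, which is why the argument only needs the linearized loss to be convex and the comparator $W^{\divideontimes}$ to have small Frobenius norm.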
Rademacher Complexity Through Coupling
We have the following simple lemma about the Rademacher complexity of RNNs. It first uses the coupling Lemma 6.1 to reduce the network to a linear function, and then calculates the Rademacher complexity for this linear function class.
where we use and to denote that calculated from sample . Since this function is linear in , we can write it as
We have from Lemma B.1. We bound the Rademacher complexity of this linear function using Proposition A.3 as follows.
Proof of Theorem 1
For every $\varepsilon\in\big(0,\frac{1}{\mathrm{poly}(\rho)\cdot p\cdot\mathfrak{C}_{\mathfrak{s}}(\Phi,O(\sqrt{L}))}\big)$, define complexity and , if the number of neurons and the number of samples is , then SGD with $\eta=\Theta\big(\frac{1}{\varepsilon\rho^{2}m}\big)$ and
satisfies that, with probability at least over the random initialization
One can first apply Lemma 7.1 to obtain for satisfying (recall (I.1))
We can also apply Lemma B.1h together with Lemma F.1a to derive that for each fixed , with probability at least , it satisfies for every ,
and therefore by the 1-Lipschitz continuity of ,
Plugging the Rademacher complexity Lemma 8.1 together with the choice into the standard generalization argument (see Corollary A.2), we have with probability at least , for all
where the additional factor is because we have scaled with factor . In sum, it suffices to choose $N\geq\Omega\big(\frac{\lambda^{2}\rho^{8}\Delta^{2}}{\varepsilon^{2}}\big)=\Omega(\rho^{6}\Delta^{2})$ and $N\geq\Omega\big(\frac{\rho^{4}b^{2}}{\varepsilon^{2}}\big)=\Omega(\mathrm{poly}(\rho)\Delta^{2})$. ∎
Appendix A Rademacher Complexity Review
Suppose where each is generated i.i.d. from a distribution . If every satisfies , for every with probability at least over the randomness of , it satisfies
Let be the class of functions by composing with , that is, . By the (vector version of the) contraction lemma of Rademacher complexity, it satisfies . ∎
(Footnote: There are slightly different versions of the contraction lemma in the literature. For the scalar case without absolute value, see [21, Section 3.8]; for the scalar case with absolute value, see [6, Theorem 12]; and for the vector case without absolute value, see .)
We recall the simple calculation of the Rademacher complexity for linear function class.
Suppose for all . The class has Rademacher complexity .
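For intuition about this linear-class calculation, consider the standard case $\mathcal{F}=\{x\mapsto\langle w,x\rangle:\|w\|_2\le\Lambda\}$ with $\|x_i\|_2\le R$ (an assumed form; the exact symbols of Proposition A.3 are not reproduced here). The supremum has a closed form, $\sup_{\|w\|\le\Lambda}\frac{1}{N}\sum_i\sigma_i\langle w,x_i\rangle=\frac{\Lambda}{N}\|\sum_i\sigma_i x_i\|$, and Jensen's inequality gives $\mathbb{E}_\sigma\|\sum_i\sigma_i x_i\|\le(\sum_i\|x_i\|^2)^{1/2}\le R\sqrt{N}$, hence the familiar $\Lambda R/\sqrt{N}$ rate. A Monte Carlo sketch with hypothetical helper names:

```python
import math
import random


def empirical_rademacher_linear(xs, radius_w, trials, rng):
    """Monte Carlo estimate of E_sigma sup_{||w||<=radius_w} (1/N) sum_i sigma_i <w, x_i>.

    The supremum is attained at w = radius_w * s / ||s|| with s = sum_i sigma_i x_i,
    so it equals (radius_w / N) * ||s||.
    """
    N, dim = len(xs), len(xs[0])
    total = 0.0
    for _ in range(trials):
        s = [0.0] * dim
        for x in xs:
            sigma = rng.choice((-1.0, 1.0))  # Rademacher sign
            s = [si + sigma * xi for si, xi in zip(s, x)]
        total += radius_w * math.sqrt(sum(c * c for c in s)) / N
    return total / trials


rng = random.Random(0)
N, dim, radius_w = 200, 10, 2.0
xs = []
for _ in range(N):
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    n = math.sqrt(sum(c * c for c in v))
    xs.append([c / n for c in v])  # unit vectors, so R = 1

estimate = empirical_rademacher_linear(xs, radius_w, trials=500, rng=rng)
bound = radius_w * 1.0 / math.sqrt(N)  # Lambda * R / sqrt(N)
print(f"estimate {estimate:.4f}, bound {bound:.4f}")
```

In the RNN setting the role of $w$ is played by the (vectorized) perturbation $W'$ with a Frobenius-norm bound, which is why the linearization from Lemma 6.1 makes this simple calculation applicable.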
Appendix B Random Initialization: Basic Properties
We first note some important properties about the random initialization of our RNNs. Some of them have already appeared in prior work, and the remaining ones can be easily derived from it.
Let be at random initialization and be any fixed normalized input sequence (see Definition 3.2). Recall and
$\big|\mathbf{e}_{r}^{\top}\mathsf{Back}_{i\to j}\mathbf{e}_{k}\big|\leq O\big(\frac{\rho}{\sqrt{d}}\big)$ for every (backward signal)
$\big\|\mathbf{e}_{r}^{\top}\mathsf{Back}_{i\to j}\big\|\geq\Omega\big(\frac{\sqrt{m}}{\sqrt{d}}\big)$ for every (backward signal)
Finally, letting finishes the proof.
This is similar to the proof of Lemma B.1a, except noticing the factor: .
has at most coordinates with .
Since is of dimension at most , we can apply a standard -net argument over all possible with (with ) and then apply union bound. Since , with probability at least , for all in this range, it satisfies
has at most coordinates with .
The “intermediate layers” part of “basic properties at random initialization” in prior work in fact shows (e.g., Claim B.12 of its version 3) that, for a fixed unit vector , with probability at least , it satisfies
Further using the randomness of , we have with probability at least ,
This finishes the proof after plugging in .
Using the same as above, we have with probability at least ,
This finishes the proof after plugging in for all and taking Chernoff bound.
This is similar to the proof of Lemma B.1f but requires a careful ε-net argument. It is already included in the “intermediate layers: spectral norm” part of the “Basic Properties at Random Initialization” of prior work (e.g., Lemma B.11 in version 3).
This can be proved in the same way as Lemma B.1f. It is already included in the “intermediate layers: sparse spectral norm” part of the “Basic Properties at Random Initialization” of prior work (e.g., Lemma B.14 in version 3).
See the “forward correlation” part of prior work (e.g., Lemma B.6 of version 3).
Appendix C Random Initialization: Backward Correlation
For every , every fixed normalized input sequence , with probability at least over : for every :
Since is on the order of , the above Lemma C.1 says that the two vectors and are very uncorrelated whenever .
In fact, one can prove the same Lemma C.1 for the lack of correlation between and whenever . We do not need that stronger version in this paper.
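A numerical sketch of what Lemma C.1 (together with the backward-signal bounds of Lemma B.1) describes, under our reading of the notation: $\mathsf{Back}_{i\to j}=BD_jWD_{j-1}W\cdots D_{i+1}W$ for $j>i$, $D_\ell=\mathrm{diag}(\mathbb{1}[g_\ell\ge 0])$ with pre-activation $g_\ell=Wh_{\ell-1}+Ax_\ell$, and Gaussian initialization with $W,A$ entries from $\mathcal{N}(0,\frac{2}{m})$ and $B$ entries from $\mathcal{N}(0,\frac{1}{d})$. These forms, the sizes, and the helper names are assumptions for illustration. For random unit $u,v$ one expects the normalized inner product of $u^\top\mathsf{Back}_{2\to4}$ and $v^\top\mathsf{Back}_{2\to6}$ to be small (order $1/\sqrt{m}$ up to constants), while $\|\mathbf{e}_r^\top\mathsf{Back}_{2\to4}\|$ stays on the order of $\sqrt{m/d}$.

```python
import math
import random


def gauss_mat(rows, cols, var, rng):
    s = math.sqrt(var)
    return [[rng.gauss(0.0, s) for _ in range(cols)] for _ in range(rows)]


def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def vecmat(v, M):
    """Row vector v^T times matrix M (list of rows)."""
    out = [0.0] * len(M[0])
    for vk, row in zip(v, M):
        if vk != 0.0:  # masked coordinates contribute nothing
            out = [o + vk * r for o, r in zip(out, row)]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def unit(v):
    n = math.sqrt(dot(v, v))
    return [c / n for c in v]


def preactivations(W, A, xs):
    """g_l = W h_{l-1} + A x_l with h_0 = 0 and h_l = ReLU(g_l); returns [g_1..g_L]."""
    h, gs = [0.0] * len(W), []
    for x in xs:
        g = [a + b for a, b in zip(matvec(W, h), matvec(A, x))]
        gs.append(g)
        h = [max(0.0, t) for t in g]
    return gs


def back_row(u, B, W, gs, i, j):
    """u^T Back_{i->j}, assuming Back_{i->j} = B D_j W D_{j-1} W ... D_{i+1} W."""
    v = vecmat(u, B)
    for l in range(j, i, -1):  # l = j, j-1, ..., i+1
        v = [c if g >= 0 else 0.0 for c, g in zip(v, gs[l - 1])]  # multiply by D_l
        v = vecmat(v, W)
    return v


m, d, L, eps_x = 300, 5, 6, 0.1
rng = random.Random(0)
W = gauss_mat(m, m, 2.0 / m, rng)
A = gauss_mat(m, d + 1, 2.0 / m, rng)
B = gauss_mat(d, m, 1.0 / d, rng)
xs = [[0.0] * d + [1.0]]
xs += [[eps_x * rng.gauss(0.0, 1.0) for _ in range(d)] + [0.0] for _ in range(L - 1)]
gs = preactivations(W, A, xs)

u = unit([rng.gauss(0.0, 1.0) for _ in range(d)])
v = unit([rng.gauss(0.0, 1.0) for _ in range(d)])
a = back_row(u, B, W, gs, i=2, j=4)
b = back_row(v, B, W, gs, i=2, j=6)
corr = dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))

e_r = [1.0] + [0.0] * (d - 1)
row = back_row(e_r, B, W, gs, i=2, j=4)
ratio = math.sqrt(dot(row, row)) / math.sqrt(m / d)
print(f"normalized inner product {corr:+.3f} (1/sqrt(m) = {1 / math.sqrt(m):.3f})")
print(f"||e_r^T Back_2->4|| / sqrt(m/d) = {ratio:.2f}")
```

A single run only illustrates typical behavior; the lemma is a high-probability statement over the random initialization, with explicit polylogarithmic factors.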
so it suffices to bound the absolute value of $\sum_{p\in[m]}\Xi_{p}=\big\langle u^{\top}\mathsf{Back}_{i\to j},v^{\top}\mathsf{Back}_{i\to j'}\big\rangle$.
Let us fix coordinates (without loss of generality the first coordinates) and calculate only over by induction. Define
where the first is due to Lemma B.1e. Let
Above, ① is by the definition , and ② is because for each fixed unit vector , we have with probability at least .
We consider the two terms on the right hand side separately:
Now, suppose for a moment that we view and as fixed. Then, it is a simple exercise (see, for instance, Claim B.13 of version 3 of prior work) to verify that with probability (over the randomness of ),
Combining the above two properties and using induction, we have
Above, ① is because for every fixed vector , with probability at least over it satisfies , and therefore we have (similarly if we replace with )
To bound , we note that the following vectors
are pairwise orthogonal, and therefore, when left-multiplied with matrix , their randomness (over ) is independent. This means $|\clubsuit|\leq O\big(\sqrt{N}\rho^{2}\big)$ with probability at least over . Choosing we have
Finally, we can divide all the coordinates into chunks each of size . Performing the above calculation for each of them gives the desired bound. ∎
In this section we consider two (normalized) input sequences. The first sequence is given as
The second sequence is generated from an arbitrary in the support of :
We study the following two executions of RNNs under input and respectively:
We emphasize that Lemma D.1a is technically the most involved, and the remaining two properties, Lemma D.1b and Lemma D.1c, are simple corollaries.
Let , then we can write
Using the concentration, we know with probability at least , for any fixed satisfying and ,
Using again and the Lipschitz continuity property of from Proposition D.2, we have
Recall from Proposition D.2. This means
D.2 Proof of Lemma D.1b and D.1c
Euclidean norm from Lemma B.1f.
Spectral norm from Lemma B.1j.
Spectral norm from Lemma B.1i.
D.3 Mathematical Tools
Let be two independent standard Gaussian random variables, and let parameters and $\alpha=\sqrt{1-\beta^{2}}\in\big[\frac{3}{4},\frac{5}{4}\big]$. Define
For we have .
Over , the function is -Lipschitz continuous over .
At least for the range of $\alpha\in\big[\frac{3}{4},\frac{5}{4}\big]$ and , one can exactly integrate out this squared difference over two Gaussian variables. For instance, the “-separateness” part of “properties at random initialization” of prior work (e.g., Claim A.6 of their version 3) has already done this for us:
It is easy to see that, as long as , we always have . Therefore
For the value approximation, we have from the above formula, and thus plugging , we have
Taking derivative with respect to , we have
It is easy to verify that for all :
This proves the -Lipschitz continuity over .
If can be optimally chosen to minimize the above right hand side,
Appendix E Stability: After Re-Randomization
In this section we study a scenario where we re-randomize a fixed set of rows in the random initialization matrices and . Formally, consider a fixed set with cardinality . Define
There are many terms in this difference, and we treat them separately below.
We have that , , , and .
Finally, we choose and to satisfy (using )
and compute its Euclidean norm. One can in fact expand out all the (exponentially many) difference terms and bound them separately.
If shows up once and never shows up, then the term is
We have by Lemma B.1i and by Lemma B.1j. Therefore, its absolute value is at most , and there are at most such terms.
If shows up once and never shows up, then the term is
We have by Lemma B.1i and by Lemma B.1j. (We also have but this is much easier to prove because is freshly random.) Therefore, its absolute value is at most , and there are at most such terms.
If the total number of times and show up is , then the occurrences of and divide the difference term into three consecutive parts. As before, the norms of the first and the last parts are at most and , respectively, so it suffices to bound the matrix spectral norm of the middle part. There are four possibilities for this middle part:
.
.
All such matrices have spectral norm at most by Lemma B.1j because and are both -sparse. Therefore, although there are at most such difference terms, each of them is at most in magnitude. Therefore, their total contribution is negligible compared to cases (1) and (2).
If the total number of times and show up is , then there are at most such terms and each of them is at most in magnitude.
and finally using the randomness of we have with probability at least
Appendix F Stability: After Adversarial Perturbation
In this section we study a scenario where the random initialization matrix $W$ is perturbed to $W+W'$ with $W'$ being small in spectral norm. Intuitively, this will later capture how much SGD has moved away from the random initialization, so it may depend on the randomness of . To untangle this possibly complicated correlation, we consider stability with respect to all $W'$ being small. The following lemma has already appeared in the “Stability After Adversarial Perturbation” section of [2].
Let be at random initialization, be a fixed normalized input sequence, and . With probability at least over the randomness of , the following holds. Given any perturbation matrix with , letting
for every (forward stability)
for every (sign change)
$\|\mathsf{Back}_{i\to j}'\|_{2}\leq O\big(\Delta^{1/3}\rho^{6}m^{1/3}\big)$ for every (backward stability)
Specifically, Lemma F.1a and Lemma F.1b can be found in Lemma C.2, and Lemma F.1c can be found in Lemma C.9 of [2, ver.3].
Appendix G Proof for Section 5
We now revise (G.1) in two ways without changing much of its original proof.
First, the above is quite an arbitrary choice in their proof, and can be replaced with any other for constant . They have constructed by first expanding into its Taylor expansions, and then approximating each term with — probabilists’ Hermite polynomial of degree — with . Only the coefficient in front of each , namely in their Eq. (117) depends on the choice of , and decreases from its original value as decreases from . Therefore, their final construction of will only have a smaller magnitude in these coefficients so (G.1) remains unchanged if we choose .
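The probabilists' Hermite polynomials used in this construction satisfy $He_0=1$, $He_1(x)=x$, $He_{k+1}(x)=x\,He_k(x)-k\,He_{k-1}(x)$, and are orthogonal under the standard Gaussian: $\mathbb{E}_{g\sim\mathcal{N}(0,1)}[He_i(g)He_j(g)]=i!\cdot\mathbb{1}[i=j]$. The sketch below (helper names hypothetical) evaluates them by the recurrence and checks orthogonality with a trapezoid-rule integral against the Gaussian density.

```python
import math


def hermite_prob(n, x):
    """Probabilists' Hermite polynomial He_n(x) via the three-term recurrence."""
    if n == 0:
        return 1.0
    prev, cur = 1.0, x
    for k in range(1, n):
        prev, cur = cur, x * cur - k * prev
    return cur


def gaussian_inner_product(i, j, lo=-12.0, hi=12.0, steps=2400):
    """Approximate E[He_i(g) He_j(g)] for g ~ N(0, 1) with the trapezoid rule."""
    h = (hi - lo) / steps
    total = 0.0
    for s in range(steps + 1):
        x = lo + s * h
        weight = 0.5 if s in (0, steps) else 1.0
        density = math.exp(-x * x / 2) / math.sqrt(2 * math.pi)
        total += weight * hermite_prob(i, x) * hermite_prob(j, x) * density
    return total * h


# Gram matrix of He_0..He_4 under N(0, 1): approximately diag(0!, 1!, 2!, 3!, 4!).
for i in range(5):
    print(i, [round(gaussian_inner_product(i, j), 6) for j in range(5)])
```

Orthogonality is what makes the expansion in this basis convenient: a coefficient can be changed (for example, shrunk when the variance parameter decreases) without affecting the other terms.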
Parameter Choices. Having restated (G.2) and (G.3) from the prior work, let us choose parameters to apply them. Let us separate out the last coordinate for these vectors. Suppose
, ,
$\langle v,y\rangle+b_{0}=\frac{1}{\sqrt{\sigma^{2}+3/4}}\big(\langle a_{\triangleleft},x^{\star}_{\triangleleft}\rangle+n+\frac{a_{\triangleright}}{2}\big)=\frac{1}{\sqrt{\sigma^{2}+3/4}}\big(\langle a,x^{\star}\rangle+n\big)$
The first statement above finishes the proof of Lemma 5.1a. We note that because is re-scaled from by .
Off Target. For Lemma 5.1b, we can derive the following
Above, ① is because with probability at most $O\big(\frac{1}{\gamma\sigma}\big)$ and ; ② uses . ∎
G.2 Missing Proof of Lemma 5.2
Using Lemma B.1a and Lemma D.1a, we can write (abbreviating by and ),
Note that by Lemma B.1a and Lemma D.1a we have
to get the with , and let us define
A standard property of Gaussian random variables shows that:
We can decompose into the projection on and on the perpendicular space of . We can write as:
By the fact that from Lemma B.1a, we know that $\big|\frac{\langle h^{(0)}_{i-1},h_{i-1}\rangle}{\|h^{(0)}_{i-1}\|^{2}}\big|\leq 3$. As a result, conditioning on , we have
Since $\langle\widetilde{a}_{k},x_{i}\rangle\sim\mathcal{N}\big(0,\frac{2\varepsilon_{x}^{2}}{m}\big)$ is independent of the randomness of , equation (G.5) implies
Above, ① uses (G.6) and .
Next, recall from (G.4) that $\big|\|(I-\widehat{h}\widehat{h}^{\top})h_{i-1}\|-\tau\big|=O\big(\frac{\rho^{2}}{\sqrt{m}}\big)$. As a result, with probability at least over ,
Above, ① is because $\big\langle\widetilde{w}_{k},\frac{(I-\widehat{h}\widehat{h}^{\top})h_{i-1}}{\|(I-\widehat{h}\widehat{h}^{\top})h_{i-1}\|}\big\rangle\sim\mathcal{N}\big(0,\frac{2}{m}\big)$, so with probability at least over it is at most . Using (G.8), together with a similar argument to (G.6), we have
Above, ① is because of (G.9) and the definition of ; and ② is because of Lemma 5.1 with and and re-scaling. Putting together (G.7) and (G.10) finishes the proof for the on-target part.
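The decomposition used above rests on a standard Gaussian fact: if $w$ has i.i.d. $\mathcal{N}(0,\sigma^2)$ entries and $\widehat{h}$ is a unit vector, then $\langle w,h\rangle=\langle w,\widehat{h}\rangle\langle\widehat{h},h\rangle+\langle w,(I-\widehat{h}\widehat{h}^\top)h\rangle$, and for any unit $u\perp\widehat{h}$ the projections $\langle w,\widehat{h}\rangle$ and $\langle w,u\rangle$ are independent $\mathcal{N}(0,\sigma^2)$ variables. This is what allows conditioning on one projection while still using the randomness of the other. The sketch below (sizes and names are illustrative; $\sigma^2=2/m$ mirrors the assumed initialization of the rows $\widetilde{w}_k$) checks the identity exactly and the decorrelation empirically.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def split(h, h_hat):
    """Return (parallel, perpendicular) parts of h with respect to unit vector h_hat."""
    c = dot(h_hat, h)
    par = [c * x for x in h_hat]
    perp = [a - b for a, b in zip(h, par)]  # (I - h_hat h_hat^T) h
    return par, perp


rng = random.Random(0)
m = 50
h0 = [rng.gauss(0.0, 1.0) for _ in range(m)]
n0 = math.sqrt(dot(h0, h0))
h_hat = [x / n0 for x in h0]
h = [x + 0.3 * rng.gauss(0.0, 1.0) for x in h0]  # a vector correlated with h0

par, perp = split(h, h_hat)
u = [x / math.sqrt(dot(perp, perp)) for x in perp]  # unit vector orthogonal to h_hat

sigma = math.sqrt(2.0 / m)
samples = 5000
prod_sum, identity_err = 0.0, 0.0
for _ in range(samples):
    w = [rng.gauss(0.0, sigma) for _ in range(m)]
    a, b = dot(w, h_hat), dot(w, u)
    prod_sum += a * b
    # <w, h> = <w, h_hat><h_hat, h> + <w, (I - h_hat h_hat^T) h>
    identity_err = max(identity_err, abs(dot(w, h) - (a * dot(h_hat, h) + dot(w, perp))))
cov = prod_sum / samples  # empirical E[<w, h_hat><w, u>], close to 0
print(f"identity error {identity_err:.1e}, normalized covariance {cov / sigma ** 2:+.3f}")
```

For Gaussian vectors, zero covariance between the two projections already implies independence, which the sketch can only probe through the empirical covariance.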
We again decompose into the projection on and on the perpendicular space of . We can write as:
By the fact that from Lemma B.1a, we again know $\big|\frac{\langle h^{(0)}_{i-1},h_{i'-1}\rangle}{\|h^{(0)}_{i-1}\|^{2}}\big|\leq 3$. Therefore, the same derivation of (G.7) implies
Using Lemma 5.1b, we have for every ,
Above, ① is by applying Lemma 5.1b after re-scaling and .
Using Lemma B.1k, we know for with high probability $\|(I-\widehat{h}\widehat{h}^{\top})h^{(0)}_{i'-1}\|\geq\Omega\big(\frac{1}{L^{2}\log^{3}m}\big)$. (One can argue similarly with the help of Lemma B.1 for .) Thus, by the closeness property from Lemma D.1a, we also have with probability at least over the randomness of ,
Putting this into (G.12), we have with probability at least ,
Combining this with (G.11) and using our choice of finishes the proof of for the off-target part.
G.3 Missing Proof of Lemma 5.4
Let us choose , and by Lemma lem:done1d, we have with probability over the randomness of ,
Using the randomness of , we have with probability over the randomness of ,
Let us define be the above subset. Since and since from (G.15), we have
For a similar reason, using from Lemma lem:done1a and the independence between and , we know with probability over the randomness of ,
Let us define be the above subset. Using from (G.15) again, we have
Together, we have . Let us choose
Let us now fix . Summing over , we have with probability ,
Thus, putting together the two summations we have
Now, we consider fixed and only use the randomness of to analyze . It is a sum of independent random variables. By a Chernoff bound, with probability ,
By Lemma 5.2 we know that the expectation is given by:
Using (G.14) and (G.16) again, we can bound
Therefore, so far we have calculated that with probability at least over the randomness of :
Summing it up, and using from Claim claim:Ws-constructb, we can write
with .
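The Chernoff-type step above controls a sum of independent random variables around its mean. As an illustration only, the sketch below uses Hoeffding's inequality for bounded summands (a standard Chernoff-type bound swapped in here; the exact variant used in the proof is not spelled out) and checks it against simulated sums of independent $\pm1$ signs. The helper names are hypothetical.

```python
import math
import random


def hoeffding_bound(t, ranges):
    """Two-sided Hoeffding bound for a sum S of independent X_i in [a_i, b_i]:
    P(|S - E[S]| >= t) <= 2 exp(-2 t^2 / sum_i (b_i - a_i)^2)."""
    spread = sum((b - a) ** 2 for a, b in ranges)
    return 2.0 * math.exp(-2.0 * t * t / spread)


def empirical_tail(n, t, trials, seed=0):
    """Fraction of trials in which a sum of n independent +/-1 signs has |S| >= t."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        s = sum(rng.choice((-1, 1)) for _ in range(n))
        hits += abs(s) >= t
    return hits / trials


n, t = 100, 20
bound = hoeffding_bound(t, [(-1, 1)] * n)     # 2 * exp(-2) for these parameters
freq = empirical_tail(n, t, trials=5000)
print(freq, "<=", bound)
```

The bound is loose for this example (the true tail is much smaller), but it is what makes "with high probability" statements uniform over many indices after a union bound.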
When , using the randomness of , it is easy to see with probability at least
where the last equality is due to Lemma lem:done1i and . Therefore, we can write
Above, ① uses our choice .
With probability at least over the random initialization , we have
$\|W^{\divideontimes}\|_{2,\infty}\leq O\big(\frac{p\rho^{3}C}{m}\big)$, so $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$.
for every and .
where the notation comes from Definition 5.3.
For the norm bound on , note that for any unit vector , we have
and as a result $\|W^{\divideontimes}\|_{F}\leq O\big(\frac{p\rho^{3}C}{\sqrt{m}}\big)$. Lemma lem:done1a and Lemma lem:done1g together imply . Lemma lem:done1a and Lemma lem:done1f together imply . ∎
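The step from the $\|\cdot\|_{2,\infty}$ bound to the Frobenius bound is the inequality $\|W\|_F\le\sqrt{m}\,\|W\|_{2,\infty}$ for a matrix with $m$ rows. A minimal sketch, assuming (as is common) that $\|W\|_{2,\infty}$ denotes the largest row $\ell_2$ norm and that the rows index the $m$ neurons; with $\|W^{\divideontimes}\|_{2,\infty}\le O(p\rho^3C/m)$ this gives $O(p\rho^3C/\sqrt{m})$. The matrices and the constant `c` are hypothetical.

```python
import math
import random


def norm_2_inf(W):
    """Largest l2 norm among the rows of W (assumed convention for ||.||_{2,inf})."""
    return max(math.sqrt(sum(v * v for v in row)) for row in W)


def norm_fro(W):
    """Frobenius norm of W."""
    return math.sqrt(sum(v * v for row in W for v in row))


m, d = 64, 8
rng = random.Random(0)
W = [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(m)]

# Sum of m squared row norms is at most m times the largest one.
print(norm_fro(W), "<=", math.sqrt(m) * norm_2_inf(W))

# Equality case: every row has norm exactly c/m, so ||W||_F = c / sqrt(m).
c = 3.0
W_eq = [[c / m] + [0.0] * (d - 1) for _ in range(m)]
print(norm_fro(W_eq), c / math.sqrt(m))
```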
Appendix H Proof for Section 6
by Lemma lem:done1f
by Lemma lem:done1j.
$\|(WD\cdots WD)W^{\prime}(h+h^{\prime})\|\leq\|WD\cdots WD\|_{2}\|W^{\prime}\|_{2}\|h+h^{\prime}\|\leq O\big(\frac{L^{3}\Delta}{\sqrt{m}}\big)$ by Lemma lem:done1i (and using Lemma lem:done1a with Lemma lem:stability:adva).
Putting (H.2) and (H.3) back into (H.1) finishes the proof. ∎
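The first inequality in the bound on $\|(WD\cdots WD)W^{\prime}(h+h^{\prime})\|$ is submultiplicativity of the spectral norm, $\|ABx\|\le\|A\|_2\|B\|_2\|x\|$. The sketch below checks it on random $2\times 2$ matrices, where the spectral norm has a closed form; $A$, $B$ and $x$ are generic stand-ins for $WD\cdots WD$, $W^{\prime}$ and $h+h^{\prime}$, not the actual matrices of the proof, and the helper names are hypothetical.

```python
import math
import random


def matmul(A, B):
    """Product of two 2x2 matrices given as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def matvec(A, x):
    """Product of a 2x2 matrix with a length-2 vector."""
    return [A[0][0] * x[0] + A[0][1] * x[1], A[1][0] * x[0] + A[1][1] * x[1]]


def spectral_norm_2x2(M):
    """Largest singular value of a 2x2 matrix in closed form.

    sigma1^2 and sigma2^2 are the roots of s^2 - ||M||_F^2 s + det(M)^2 = 0.
    """
    f2 = sum(v * v for row in M for v in row)
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    disc = max(f2 * f2 - 4.0 * det * det, 0.0)  # clamp tiny negative rounding
    return math.sqrt((f2 + math.sqrt(disc)) / 2.0)


rng = random.Random(0)
worst_ratio = 0.0
for _ in range(1000):
    A = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(2)]
    B = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(2)]
    x = [rng.gauss(0.0, 1.0) for _ in range(2)]
    lhs = math.hypot(*matvec(matmul(A, B), x))
    rhs = spectral_norm_2x2(A) * spectral_norm_2x2(B) * math.hypot(*x)
    worst_ratio = max(worst_ratio, lhs / rhs)

print("largest ||ABx|| / (||A|| ||B|| ||x||):", worst_ratio)
```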
H.2 Proof of Lemma 6.2
We analyze the three error terms separately.
$\|\clubsuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}^{\prime}_{i^{\prime}\to j^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot O(1)\leq O\big(\frac{\omega\rho^{6}\Delta^{1/3}}{m^{1/6}}\big)$ using Lemma lem:stability:advc.
$\|\spadesuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot\|h^{\prime}_{i^{\prime}}\|\leq O\big(\frac{\omega\rho^{7}\Delta}{\sqrt{m}}\big)$ using Lemma lem:done1f and Lemma lem:stability:adva.
$\|\diamondsuit_{i^{\prime}}\|\leq\|\operatorname{\mathsf{Back}}_{i^{\prime}\to j^{\prime}}D^{\prime}_{i^{\prime}}\|_{2}\cdot\|\widetilde{W}\|_{2}\cdot O(1)\leq O\big(\frac{\omega\rho^{2}\Delta^{1/3}}{m^{1/6}}\big)$ using Lemma lem:done1f and the bound from Lemma lem:stability:advb. ∎
Appendix I Proof for Lemma 7.1
Since it satisfies for all , by the 1-Lipschitz continuity of , we also have . Therefore, by a Chernoff bound, as long as we also have
Let be a normalized input sequence generated by some in the support of . Consider an iteration where the current weight matrix is . Let
which is a linear function over . Let us define a loss function as:
Let be defined in Definition 5.3. By Lemma 6.3, we know that as long as (for some parameter we shall choose at the end),
where the last inequality is by choosing and and sufficiently large . Taking a union bound over all samples in , by the 1-Lipschitz continuity of , we have
Using 1-Lipschitz continuity of , we have
where ① is by our choice of which implies by Lemma lem:done1h.
Thus, by the convexity of (the composition of a convex function with a linear function is convex), we know
Finally, recall that SGD takes a stochastic gradient so
On one hand, we have (for any )
Above, ① uses the -Lipschitz continuity of , and ② uses Lemma B.1 and the choice of .
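The convexity step relies on the fact that a convex loss composed with a function that is linear in the weights is convex in the weights. The sketch below checks the defining inequality $f(\lambda w_1+(1-\lambda)w_2)\le\lambda f(w_1)+(1-\lambda)f(w_2)$ at random points, assuming a logistic loss and a linear map $w\mapsto\langle w,z\rangle$ as hypothetical stand-ins for the loss and the linear function of this proof.

```python
import math
import random


def logistic_loss(z):
    """log(1 + exp(-z)), a convex scalar loss, evaluated in a numerically stable way."""
    return math.log1p(math.exp(-abs(z))) + max(-z, 0.0)


def linear_map(w, z):
    """<w, z>: linear in the weights w (hypothetical stand-in for the linear function)."""
    return sum(a * b for a, b in zip(w, z))


def objective(w, z):
    """Convex loss composed with a linear map, hence convex in w."""
    return logistic_loss(linear_map(w, z))


rng = random.Random(0)
violations = 0
for _ in range(2000):
    z = [rng.gauss(0.0, 1.0) for _ in range(5)]
    w1 = [rng.gauss(0.0, 1.0) for _ in range(5)]
    w2 = [rng.gauss(0.0, 1.0) for _ in range(5)]
    lam = rng.random()
    mix = [lam * a + (1.0 - lam) * b for a, b in zip(w1, w2)]
    lhs = objective(mix, z)
    rhs = lam * objective(w1, z) + (1.0 - lam) * objective(w2, z)
    violations += lhs > rhs + 1e-12   # small slack for floating-point rounding

print("convexity violations:", violations)
```

Convexity is what allows the loss gap at each iterate to be bounded by the inner product between the (stochastic) gradient and $W_t-W^*$.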
Telescoping over all , we have
so we can choose $\Delta=\Theta\big(\frac{C^{2}\rho^{11}p^{2}}{\varepsilon^{2}}\big)$.
Finally, we can replace the notation with because . ∎
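The telescoping step is the standard online gradient descent argument. Assuming the plain update $W_{t+1}=W_t-\eta g_t$, expanding $\|W_{t+1}-W^*\|^2=\|W_t-W^*\|^2-2\eta\langle g_t,W_t-W^*\rangle+\eta^2\|g_t\|^2$ and summing over $t$ gives
$\sum_{t<T}\langle g_t,W_t-W^*\rangle=\frac{\|W_0-W^*\|^2-\|W_T-W^*\|^2}{2\eta}+\frac{\eta}{2}\sum_{t<T}\|g_t\|^2\le\frac{\|W_0-W^*\|^2}{2\eta}+\frac{\eta}{2}\sum_{t<T}\|g_t\|^2$.
The sketch below checks this identity numerically on an illustrative convex quadratic with noisy gradients; the objective, step size and noise level are hypothetical choices, not the loss of the proof.

```python
import random


def sq_norm(v):
    """Squared Euclidean norm."""
    return sum(a * a for a in v)


def run_sgd(T=500, eta=0.05, seed=0):
    """Plain SGD w_{t+1} = w_t - eta * g_t on the illustrative objective
    0.5 * ||w - c||^2 with Gaussian gradient noise (hypothetical setup).

    Returns the quantities that appear in the telescoped identity."""
    rng = random.Random(seed)
    c = [1.0, -2.0, 0.5]             # minimizer of the illustrative objective
    w_star = c                       # comparator point W^*
    w = [0.0, 0.0, 0.0]              # initial point W_0
    dist0 = sq_norm([a - b for a, b in zip(w, w_star)])
    inner_sum = 0.0                  # sum_t <g_t, w_t - w_star>
    grad_sq_sum = 0.0                # sum_t ||g_t||^2
    for _ in range(T):
        g = [wi - ci + rng.gauss(0.0, 0.5) for wi, ci in zip(w, c)]
        inner_sum += sum(gi * (wi - si) for gi, wi, si in zip(g, w, w_star))
        grad_sq_sum += sq_norm(g)
        w = [wi - eta * gi for wi, gi in zip(w, g)]
    distT = sq_norm([a - b for a, b in zip(w, w_star)])
    return inner_sum, dist0, distT, grad_sq_sum


eta = 0.05
inner_sum, dist0, distT, grad_sq_sum = run_sgd(eta=eta)
identity_rhs = (dist0 - distT) / (2 * eta) + eta / 2 * grad_sq_sum
upper_bound = dist0 / (2 * eta) + eta / 2 * grad_sq_sum
print(inner_sum, identity_rhs, upper_bound)
```

Dividing by $T$ and combining with convexity bounds the average loss of the iterates against that of $W^*$, which is where the choice of $\eta$ and $T$ enters.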