Learning ReLUs via Gradient Descent

Mahdi Soltanolkotabi

Introduction

A natural approach to fitting ReLUs to data is via minimizing the least-squares misfit aggregated over the data. This optimization problem takes the form

Fitting nonlinear models such as ReLUs has a rich history in statistics and learning theory, with interesting new developments emerging (we shall discuss all these results in greater detail in Section 4). Most recently, nonlinear data fitting problems in the form of neural networks (a.k.a. deep learning) have emerged as powerful tools for automatically extracting interpretable and actionable information from raw forms of data, leading to striking breakthroughs in a multitude of applications. In these and many other empirical domains it is common to use local search heuristics such as gradient or stochastic gradient descent for nonlinear data fitting. These local search heuristics are surprisingly effective on real or randomly generated data. However, despite their empirical success, the reasons for their effectiveness remain mysterious.

Focusing on fitting ReLUs, a priori it is completely unclear why local search heuristics such as gradient descent should converge for problems of the form (1.1), as not only the regularization function may be nonconvex but also the loss function! Efficient fitting of ReLUs in this high-dimensional setting poses new challenges: When are the iterates able to escape local optima and saddle points and converge to global optima? How many samples do we need? How does the number of samples depend on the prior knowledge available about the weights? What regularizer is best suited to utilizing a particular form of prior knowledge? How many passes (or iterations) of the algorithm are required to get to an accurate solution? At the heart of answering these questions is the ability to predict the convergence behavior/rate of (non)convex constrained optimization algorithms. In this paper we build upon a new framework developed by the author for analyzing nonconvex optimization problems to address such challenges.

Precise measures for statistical resources

The set of descent of a function $\mathcal{R}$ at a point $\bm{w}^{*}$ is defined as

The cone of descent is defined as a closed cone $\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$ that contains the descent set, i.e. $\mathcal{D}_{\mathcal{R}}(\bm{w}^{*})\subset\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$. The tangent cone is the conic hull of the descent set, that is, the smallest closed cone $\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$ obeying $\mathcal{D}_{\mathcal{R}}(\bm{w}^{*})\subset\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$.
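As a small illustration of these definitions, the sketch below tests membership in a descent set. It assumes the standard definition $\mathcal{D}_{\mathcal{R}}(\bm{w}^{*})=\{\bm{h}:\ \mathcal{R}(\bm{w}^{*}+\bm{h})\leq\mathcal{R}(\bm{w}^{*})\}$, and it uses the $\ell_1$ norm as an example regularizer; the helper names `in_descent_set` and `l1_norm` are hypothetical and chosen only for this example.

```python
import numpy as np

def l1_norm(w):
    """Example regularizer R(w) = ||w||_1 (an illustrative choice)."""
    return float(np.sum(np.abs(w)))

def in_descent_set(regularizer, w_star, h):
    """Assumed definition: h is in the descent set iff R(w* + h) <= R(w*)."""
    return regularizer(w_star + h) <= regularizer(w_star)

w_star = np.array([1.0, 0.0, 0.0])        # a 1-sparse weight vector
h_descent = np.array([-0.5, 0.2, 0.0])    # ||w* + h||_1 = 0.7 <= 1
h_ascent = np.array([0.0, 0.1, 0.0])      # ||w* + h||_1 = 1.1 > 1

print(in_descent_set(l1_norm, w_star, h_descent))
print(in_descent_set(l1_norm, w_star, h_ascent))
```

Any closed cone containing all such directions $\bm{h}$ is then a cone of descent in the sense above.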

We note that the capability of the regularizer $\mathcal{R}$ in capturing the properties of the unknown weight vector $\bm{w}^{*}$ depends on the size of the descent cone $\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$. The smaller this cone is, the better suited the function $\mathcal{R}$ is to capturing the properties of $\bm{w}^{*}$. To quantify the size of this set we shall use the notion of mean width.
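To make the notion of mean width concrete, here is a hedged Monte Carlo sketch. It assumes the usual Gaussian mean width $\omega(\mathcal{S})=\mathbb{E}_{\bm{g}}\big[\sup_{\bm{z}\in\mathcal{S}}\langle\bm{g},\bm{z}\rangle\big]$ with $\bm{g}\sim\mathcal{N}(\bm{0},\bm{I}_d)$, and it uses the set of $s$-sparse vectors in the unit ball as a stand-in for a cone intersected with $\mathcal{B}^{d}$. Both the set and the function name are illustrative assumptions, not objects taken from the analysis.

```python
import numpy as np

def mean_width_sparse_unit_ball(d, s, num_trials=2000, seed=0):
    """Monte Carlo estimate of the Gaussian mean width of
    {z in R^d : z is s-sparse, ||z||_2 <= 1}.

    For a fixed g the supremum of <g, z> over this set is attained by
    aligning z with the s largest-magnitude entries of g, so it equals
    the l2 norm of those s entries.
    """
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((num_trials, d))
    top_s_squares = np.sort(g ** 2, axis=1)[:, -s:]   # s largest g_j^2 per draw
    return float(np.mean(np.sqrt(top_s_squares.sum(axis=1))))

d, s = 1000, 10
omega = mean_width_sparse_unit_ball(d, s)
print("squared mean width:", omega ** 2, "ambient dimension:", d)
```

For structured sets of this kind the squared mean width is far smaller than the ambient dimension $d$, which is the sense in which a small cone corresponds to a regularizer that captures the structure well.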

We now have all the definitions in place to quantify the capability of the function $\mathcal{R}$ in capturing the properties of the unknown parameter $\bm{w}^{*}$. This naturally leads us to the definition of the minimum required number of samples.

Let $\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$ be a cone of descent of $\mathcal{R}$ at $\bm{w}^{*}$. We define the minimal sample function as

We shall often use the shorthand $n_{0}=\mathcal{M}(\mathcal{R},\bm{w}^{*})$ with the dependence on $\mathcal{R},\bm{w}^{*}$ implied.

We note that $n_{0}$ is exactly the minimum number of samples required for structured signal recovery from linear measurements when using convex regularizers. Specifically, the optimization problem

succeeds at recovering an unknown weight vector $\bm{w}^{*}$ with high probability from $n$ observations of the form $\bm{y}_{i}=\langle\bm{a}_{i},\bm{w}^{*}\rangle$ if and only if $n\geq n_{0}$. We would like to note that $n_{0}$ only approximately characterizes the minimum number of samples required. A more precise characterization is $\phi^{-1}(\omega^{2}(\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})\cap\mathcal{B}^{d}))\approx\omega^{2}(\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})\cap\mathcal{B}^{d})$ where $\phi(t)=\sqrt{2}\frac{\Gamma\left(\frac{t+1}{2}\right)}{\Gamma\left(\frac{t}{2}\right)}\approx\sqrt{t}$. However, since our results have unspecified constants we avoid this more accurate characterization. While this result is only known to be true for convex regularization functions, we believe that $n_{0}$ also characterizes the minimal number of samples even for nonconvex regularizers in (2.1). Prior work provides some results in the nonconvex case, as well as on the role this quantity plays in the computational complexity of projected gradient schemes for linear inverse problems. Given that with nonlinear samples we have less information (we lose some information compared to linear observations), we cannot hope to recover the weight vector from $n\leq n_{0}$ samples when using (1.1). Therefore, we can use $n_{0}$ as a lower bound on the minimum number of observations required for the projected gradient descent iterations (3.2) to succeed at finding the right model.
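The approximation $\phi(t)\approx\sqrt{t}$ quoted above can be checked numerically. The sketch below is only an illustration: it evaluates $\phi$ through log-gamma functions to avoid overflow for large $t$, and the function name `phi` simply mirrors the symbol in the text.

```python
import math

def phi(t):
    """phi(t) = sqrt(2) * Gamma((t + 1) / 2) / Gamma(t / 2), computed via lgamma."""
    return math.sqrt(2.0) * math.exp(math.lgamma((t + 1) / 2) - math.lgamma(t / 2))

for t in (1, 10, 100, 1000):
    # Print t, phi(t), sqrt(t), and their ratio, which approaches 1 as t grows.
    print(t, phi(t), math.sqrt(t), phi(t) / math.sqrt(t))
```

The ratio $\phi(t)/\sqrt{t}$ tends to one quickly, which is consistent with the remark that the refinement is immaterial once unspecified numerical constants are allowed.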

Theoretical results for learning ReLUs

A simple heuristic for optimizing (1.1) is to use gradient descent. One challenging aspect of the above loss function is that it is not differentiable and it is not clear how to run projected gradient descent. However, this does not pose a fundamental challenge as the loss function is differentiable except for isolated points and we can use the notion of generalized gradients to define the gradient at a non-differentiable point as one of the limit points of the gradient in a local neighborhood of the non-differentiable point. For the loss in (1.1) the generalized gradient takes the form

Therefore, projected gradient descent takes the form

To estimate $\bm{w}^{*}$, we start from the initial point $\bm{w}_{0}=\bm{0}$ and apply the Projected Gradient Descent (PGD) updates of the form

holds for a fixed numerical constant $c$. Then there is an event of probability at least $1-9e^{-\gamma n}$ such that on this event the updates (3.3) obey

Here $\gamma$ is a fixed numerical constant.

The first interesting and perhaps surprising aspect of this result is its generality: it applies not only to convex regularization functions but also nonconvex ones! As we mentioned earlier the optimization problem in (1.1) is not known to be tractable even for convex regularizers. Despite the nonconvexity of both the objective and regularizer, the theorem above shows that with a near minimal number of data samples, projected gradient descent provably learns the original weight vector $\bm{w}^{*}$ without getting trapped in any local optima.
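The procedure discussed above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions rather than the exact scheme analyzed in the theorem: the loss is taken to be $\frac{1}{n}\sum_{i}(\text{ReLU}(\langle\bm{x}_{i},\bm{w}\rangle)-y_{i})^{2}$, the generalized gradient uses $(1+\text{sgn}(z))/2$ as the derivative of the ReLU, the constraint set is an $\ell_2$ ball of radius $\|\bm{w}^{*}\|_{2}$, the features are i.i.d. Gaussian with noiseless labels, and the step sizes (2 for the first step, 1 afterwards) are assumed. All function names are hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def generalized_gradient(w, X, y):
    """Generalized gradient of (1/n) * sum_i (ReLU(<x_i, w>) - y_i)^2.

    The ReLU derivative is taken as (1 + sgn(z)) / 2, i.e. 1/2 at z = 0,
    so the gradient is informative even at the initial point w = 0.
    """
    z = X @ w
    residual = relu(z) - y
    return X.T @ (residual * (1.0 + np.sign(z))) / len(y)

def project_l2_ball(w, radius):
    """Euclidean projection onto {w : ||w||_2 <= radius} (assumed constraint set)."""
    norm = np.linalg.norm(w)
    return w if norm <= radius else w * (radius / norm)

def pgd_relu(X, y, radius, num_iters=50):
    w = np.zeros(X.shape[1])                  # start from w_0 = 0
    for tau in range(num_iters):
        step = 2.0 if tau == 0 else 1.0       # assumed step-size schedule
        w = project_l2_ball(w - step * generalized_gradient(w, X, y), radius)
    return w

rng = np.random.default_rng(0)
n, d = 2000, 20
w_star = rng.standard_normal(d)
X = rng.standard_normal((n, d))               # Gaussian features (assumption)
y = relu(X @ w_star)                          # noiseless ReLU observations
w_hat = pgd_relu(X, y, radius=np.linalg.norm(w_star))
rel_err = np.linalg.norm(w_hat - w_star) / np.linalg.norm(w_star)
print("relative error:", rel_err)
```

Swapping `project_l2_ball` for the projection onto another sublevel set of a regularizer (for instance an $\ell_1$ ball) changes only one function, which mirrors the generality of the result.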

Another interesting aspect of the above result is that the convergence rate is linear. Therefore, to achieve a relative error of $\epsilon$ the total number of iterations is on the order of $\mathcal{O}(\log(1/\epsilon))$. Thus the overall computational complexity is on the order of $\mathcal{O}\left(nd\log(1/\epsilon)\right)$ (in general the cost is the total number of iterations multiplied by the cost of applying the feature matrix $\bm{X}$ and its transpose). As a result, the computational complexity is also optimal in terms of dependence on the matrix dimensions. Indeed, for a dense matrix even verifying that a good solution has been achieved requires one matrix-vector multiplication, which takes $\mathcal{O}(nd)$ time.
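The iteration and cost counts can be made concrete with a short calculation. The sketch below assumes a generic per-iteration contraction factor $\rho<1$ (the default $\rho=1/2$ is an illustrative assumption) and counts roughly $2nd$ floating-point operations for each dense product with $\bm{X}$ or its transpose; both function names are hypothetical.

```python
import math

def iterations_for_accuracy(eps, contraction=0.5):
    """Smallest tau with contraction**tau <= eps, for a linear rate."""
    return math.ceil(math.log(1.0 / eps) / math.log(1.0 / contraction))

def total_flops(n, d, eps, contraction=0.5):
    """Rough cost: each iteration applies X and X^T once, about 2*n*d flops each."""
    return iterations_for_accuracy(eps, contraction) * 4 * n * d

print(iterations_for_accuracy(1e-3))   # iterations needed for relative error 1e-3
print(total_flops(2000, 20, 1e-3))     # corresponding flop estimate
```

The count grows only logarithmically in $1/\epsilon$, so the dominant factor in the total cost is the $\mathcal{O}(nd)$ price of a single pass over the data.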

Discussions and prior art

Proofs

In this section we gather some useful results on concentration of stochastic processes which will be crucial in our proofs. These results are mostly adapted from prior work. We begin with a lemma which is a direct consequence of Gordon’s escape through the mesh lemma.

for a fixed numerical constant $c$. Then for all $\bm{h}\in\mathcal{C}$

holds with probability at least $1-2e^{-\frac{\delta^{2}}{360}n}$.

We also need a generalization of the above lemma stated below.

for a fixed numerical constant $c$. Then for all $\bm{u},\bm{h}\in\mathcal{C}$

holds with probability at least $1-6e^{-\frac{\delta^{2}}{1440}n}$.

We next state a generalization of Gordon’s escape through the mesh lemma also from .

The previous lemma leads to the following Corollary.

5.2 Convergence proof (Proof of Theorem 3.1)

In this section we shall prove Theorem 3.1. Throughout, we use the shorthand $\mathcal{C}$ to denote the descent cone of $\mathcal{R}$ at $\bm{w}^{*}$, i.e. $\mathcal{C}=\mathcal{C}_{\mathcal{R}}(\bm{w}^{*})$. We begin by analyzing the first iteration. Using $\bm{w}_{0}=\bm{0}$ we have

We use the argument of [Page 25, inequality (7.34)] which shows that

Using $\text{ReLU}(z)=\frac{z+\left|z\right|}{2}$ we have

We proceed by bounding the first term in the above equality. To this aim we decompose $\bm{u}$ in the direction parallel/perpendicular to that of $\bm{w}^{*}$ and arrive at

holds with probability at least $1-2e^{-n\frac{\Delta^{2}}{8}}$. Also,

holds with probability at least $1-e^{-\frac{\eta^{2}}{2}}$. Plugging (5.4) with $\Delta=\frac{\delta}{6}$ and (5.5) with $\eta=\frac{\delta}{6}\sqrt{n}$ into (5.2), as long as

holds with probability at least $1-3e^{-n\frac{\delta^{2}}{288}}$.
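The probability accounting in this step can be spelled out as follows. This worked calculation assumes that the two failure probabilities quoted above are those of (5.4) and (5.5), respectively, and combines them with a union bound.

```latex
\begin{align*}
\Delta=\tfrac{\delta}{6}
  &\;\Longrightarrow\; 2e^{-n\frac{\Delta^{2}}{8}} = 2e^{-n\frac{\delta^{2}}{288}},\\
\eta=\tfrac{\delta}{6}\sqrt{n}
  &\;\Longrightarrow\; e^{-\frac{\eta^{2}}{2}} = e^{-n\frac{\delta^{2}}{72}} \leq e^{-n\frac{\delta^{2}}{288}},\\
\text{union bound}
  &\;\Longrightarrow\; 2e^{-n\frac{\delta^{2}}{288}} + e^{-n\frac{\delta^{2}}{72}} \leq 3e^{-n\frac{\delta^{2}}{288}}.
\end{align*}
```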

We now focus on bounding the second term in (5.2). To this aim we decompose $\bm{u}$ in the direction parallel/perpendicular to that of $\bm{w}^{*}$ and arrive at

with a fixed numerical constant. Thus by a Bernstein-type inequality ([Proposition 5.16])

holds with probability at least $1-2e^{-\gamma n\min\left(t^{2},t\right)}$ with $\gamma$ a fixed numerical constant. Also note that

holds with probability at least $1-2e^{-n\frac{\Delta^{2}}{8}}$ and

holds with probability at least $1-e^{-\frac{\eta^{2}}{2}}$. Combining the last two inequalities we conclude that

holds with probability at least $1-2e^{-n\frac{\Delta^{2}}{8}}-e^{-\frac{\eta^{2}}{2}}$. Plugging (5.8) and (5.9) with $t=\frac{\delta}{6}$, $\Delta=1$, and $\eta=\frac{\delta}{6\sqrt{2}}\sqrt{n}$ into (5.2)

holds with probability at least $1-3e^{-\gamma n\delta^{2}}-2e^{-\frac{n}{8}}$ as long as

Thus plugging (5.6) and (5.10) into (5.1) we conclude that for $\delta=7/400$

holds with probability at least $1-8e^{-\gamma n}$ as long as

To introduce our general convergence analysis we begin by defining

To prove Theorem 3.1 we use the argument of [Page 25, inequality (7.34)] which shows that if we apply the projected gradient descent update

the error $\bm{h}_{\tau}=\bm{w}_{\tau}-\bm{w}^{*}$ obeys

To complete the convergence analysis it is then sufficient to prove

We will instead prove that the following stronger result holds for all $\bm{u}\in\mathcal{C}\cap\mathcal{B}^{n}$ and $\bm{w}\in E(\epsilon)$

The equation (5.13) above implies (5.12) which when combined with (5.11) proves the convergence result of the Theorem (specifically equation (3.5)).

The rest of this section is dedicated to proving (5.13). To this aim note that $\text{ReLU}(\langle\bm{x}_{i},\bm{w}\rangle)=\frac{\langle\bm{x}_{i},\bm{w}\rangle+\left|\langle\bm{x}_{i},\bm{w}\rangle\right|}{2}$. Therefore, the loss function can alternatively be written as

Now defining $\bm{h}=\bm{w}-\bm{w}^{*}$ we conclude that

We now proceed by bounding each of the four terms in (5.14) and then combine them in Section 5.2.5.
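The identity $\text{ReLU}(z)=\frac{z+|z|}{2}$ that drives this rewriting, together with the resulting split of the residual into a linear part in $\bm{h}=\bm{w}-\bm{w}^{*}$ and a difference of absolute values, can be sanity-checked numerically. The snippet below is only an illustrative check on random Gaussian data; the variable names are hypothetical and the split shown is the elementary algebraic consequence of the identity, not a reproduction of (5.14).

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
n, d = 500, 10
X = rng.standard_normal((n, d))
w = rng.standard_normal(d)
w_star = rng.standard_normal(d)
h = w - w_star

z, z_star = X @ w, X @ w_star

# ReLU(z) = (z + |z|) / 2, checked entrywise.
identity_holds = np.allclose(relu(z), (z + np.abs(z)) / 2.0)

# Residual split: ReLU(z) - ReLU(z*) = (<x, h> + |z| - |z*|) / 2.
residual = relu(z) - relu(z_star)
split = (X @ h + np.abs(z) - np.abs(z_star)) / 2.0
split_holds = np.allclose(residual, split)

print(identity_holds, split_holds)
```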

To bound the first term we use Lemma 5.2, which implies that as long as

then for all $\bm{u}\in\mathcal{C}\cap\mathcal{B}^{n}$ and $\bm{h}\in E(\epsilon)$

holds with probability at least $1-6e^{-\frac{\delta^{2}}{1440}n}$.

5.2.2 Bounding the second term in (5.14)

We now proceed by bounding the four terms on the right-hand side of the above inequality. To bound the first term we use (5.8) to conclude that

holds with probability at least $1-2e^{-\gamma n\min\left(t^{2},t\right)}$. To bound the second and third terms in (5.2.2) we use (5.9) to conclude that

holds with probability at least $1-2e^{-n\frac{\Delta^{2}}{8}}-e^{-\frac{\eta^{2}}{2}n}$. To bound the last term let $\epsilon_{i}$ be i.i.d. $\pm 1$ random variables independent from $\bm{x}_{i}$. Then,

Now to bound the term in the parentheses note that by concentration of sums of sub-Gaussian random variables [18, Proposition 5.10]

holds with probability at least $1-2e^{-\gamma n\Delta^{2}}$. Now note that

holds with probability at least $1-e^{-\frac{\eta^{2}}{2}n}$. Thus using the union bound

holds with probability at least $1-ne^{-\frac{\eta^{2}}{2}n}$. Plugging this into (5.20) and using Lemma 5.3 with $\bm{S}=\bm{I}$, we conclude that

holds with probability at least $1-2e^{-\gamma n\Delta^{2}}-ne^{-\frac{\eta^{2}}{2}n}-6e^{-\frac{\eta^{2}}{8}n}$, completing the bound of the last term of (5.2.2). Combining (5.17), (5.2.2), and (5.21) with $\eta=\Delta=1$, we conclude that

holds with probability at least $1-2e^{-\gamma\delta^{2}n}-(n+10)e^{-\gamma n}$ as long as

This completes the bound on the second term in (5.14).

5.2.3 Bounding the third term in (5.14)

where the last inequality follows from Cauchy-Schwarz. Note that by Lemma 5.1 as long as $n\geq\max\left(80\frac{n_{0}}{\delta^{2}},\frac{2}{\delta}-1\right)$, then

We now proceed by using the following result from .

with $c$ a fixed numerical constant. Then

holds with probability at least $1-2e^{-\gamma\delta^{2}n}$ for all vectors $\bm{h}\in\mathcal{C}$ obeying

Proof This lemma follows from the argument on pages 27-30 of .

Combining the lemma with equations (5.2.3), (5.2.3), and (5.2.3) we conclude that as long as $n\geq cn_{0}$, then

5.2.4 Bounding the fourth term in (5.14)

To bound the fourth term of (5.14) note that by the argument leading to (5.2.3)

with $c$ a fixed numerical constant. Then

holds with probability at least $1-2e^{-\gamma\delta^{2}n}$ for all vectors $\bm{h}\in\mathcal{C}$ obeying

Proof This lemma follows from the argument on pages 27-30 of .

Combining the lemma with (5.2.4) we conclude that as long as $n\geq cn_{0}$, then

5.2.5 Putting the bounds together

In this section we put together the bounds of the previous sections. Combining (5.15), (5.22), (5.26), and (5.28) we conclude that

Acknowledgements

M.S. would like to thank Adam Klivans and Matus Telgarsky for discussions related to and the Isotron algorithm. This work was done in part while the author was visiting the Simons Institute for the Theory of Computing.

References