Learning ReLUs via Gradient Descent
Mahdi Soltanolkotabi
Introduction
A natural approach to fitting ReLUs to data is via minimizing the least-squares misfit aggregated over the data. This optimization problem takes the form
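Written out under notational assumptions, one such constrained least-squares formulation reads as follows. Here $x_i \in \mathbb{R}^d$ are the feature vectors, $y_i$ the labels, $w$ the weight vector, $\mathcal{R}$ a regularization function and $R$ a tuning parameter; these symbol names and the $1/n$ normalization are chosen for illustration and may differ from the numbered display (1.1).

```latex
\min_{w \in \mathbb{R}^d} \; \mathcal{L}(w)
  := \frac{1}{n} \sum_{i=1}^{n} \bigl( \max(0, \langle w, x_i \rangle) - y_i \bigr)^2
\quad \text{subject to} \quad \mathcal{R}(w) \le R .
```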
Fitting nonlinear models such as ReLUs has a rich history in statistics and learning theory, with interesting new developments emerging (we shall discuss all these results in greater detail in Section 4). Most recently, nonlinear data fitting problems in the form of neural networks (a.k.a. deep learning) have emerged as powerful tools for automatically extracting interpretable and actionable information from raw forms of data, leading to striking breakthroughs in a multitude of applications. In these and many other empirical domains it is common to use local search heuristics such as gradient or stochastic gradient descent for nonlinear data fitting. These local search heuristics are surprisingly effective on real or randomly generated data. However, despite their empirical success, the reasons for their effectiveness remain mysterious.
Focusing on fitting ReLUs, a priori it is completely unclear why local search heuristics such as gradient descent should converge for problems of the form (1.1), since not only the regularization function but also the loss function may be nonconvex! Efficient fitting of ReLUs in this high-dimensional setting poses new challenges: When are the iterates able to escape local optima and saddle points and converge to global optima? How many samples do we need? How does the number of samples depend on the prior knowledge available about the weights? What regularizer is best suited to utilizing a particular form of prior knowledge? How many passes (or iterations) of the algorithm are required to get to an accurate solution? At the heart of answering these questions is the ability to predict the convergence behavior/rate of (non)convex constrained optimization algorithms. In this paper we build upon a new framework developed by the author for analyzing nonconvex optimization problems to address such challenges.
Precise measures for statistical resources
The set of descent of a function at a point is defined as
The cone of descent is defined as a closed cone that contains the descent set. The tangent cone is the conic hull of the descent set, that is, the smallest closed cone containing it.
We note that the capability of the regularizer in capturing the properties of the unknown weight vector depends on the size of the descent cone. The smaller this cone is, the better suited the regularizer is to capturing these properties. To quantify the size of this set we shall use the notion of mean width.
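To make the notion of mean width concrete, the sketch below estimates it by Monte Carlo for one specific set: vectors of Euclidean norm at most one with at most s nonzero entries. It assumes the Gaussian mean width, i.e. the expected supremum of the inner product between a standard Gaussian vector and the points of the set. The function name, the choice of set, and the parameter values are illustrative assumptions and are not taken from the paper. For this particular set the supremum has a closed form: the Euclidean norm of the s largest-magnitude entries of the Gaussian vector.

```python
import numpy as np

def sparse_mean_width(d, s, trials=2000, seed=0):
    """Monte Carlo estimate of E sup_{z in T} <g, z> with g ~ N(0, I_d),
    where T = {z : ||z||_2 <= 1, z has at most s nonzeros} (illustrative set)."""
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((trials, d))
    # For each draw the supremum is attained by aligning z with the s
    # largest-magnitude coordinates of g, so it equals the norm of those entries.
    top = np.sort(np.abs(g), axis=1)[:, d - s:]
    return np.linalg.norm(top, axis=1).mean()

if __name__ == "__main__":
    d = 200
    for s in (5, 20, 200):
        w = sparse_mean_width(d, s)
        print(f"s={s:3d}  squared width ~ {w**2:.1f}  (ambient dimension d={d})")
```

The squared width grows with s and stays well below the ambient dimension for small s, which is the sense in which a structured set is "small".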
We now have all the definitions in place to quantify the capability of the regularization function in capturing the properties of the unknown parameter. This naturally leads us to the definition of the minimum required number of samples.
Let be a cone of descent of at . We define the minimal sample function as
We shall often use a shorthand for this quantity, with the dependence on the regularizer and the weight vector implied.
We note that this quantity is exactly the minimum number of samples required for structured signal recovery from linear measurements when using convex regularizers. Specifically, the optimization problem
succeeds at recovering an unknown weight vector with high probability from linear observations if and only if the number of observations is at least this minimal sample count. We would like to note that this quantity only approximately characterizes the minimum number of samples required. A more precise characterization is available; however, since our results have unspecified constants we avoid it. While this result is only known to be true for convex regularization functions, we believe that the same quantity also characterizes the minimal number of samples even for nonconvex regularizers in (2.1). Results in the nonconvex case are also known, together with the role this quantity plays in the computational complexity of projected gradient schemes for linear inverse problems. Given that with nonlinear samples we have less information (we lose some information compared to linear observations), we cannot hope to recover the weight vector from fewer samples when using (1.1). Therefore, we can use this quantity as a lower bound on the minimum number of observations required for projected gradient descent iterations (3.2) to succeed at finding the right model.
Theoretical results for learning ReLUs
A simple heuristic for optimizing (1.1) is to use gradient descent. One challenging aspect of the above loss function is that it is not differentiable and it is not clear how to run projected gradient descent. However, this does not pose a fundamental challenge as the loss function is differentiable except for isolated points and we can use the notion of generalized gradients to define the gradient at a non-differentiable point as one of the limit points of the gradient in a local neighborhood of the non-differentiable point. For the loss in (1.1) the generalized gradient takes the form
Therefore, projected gradient descent takes the form
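A minimal sketch of such a projected gradient iteration is given below. It assumes Gaussian features, noiseless labels generated by a planted ReLU, a least-squares loss with 1/n normalization, initialization at the origin, a constant unit step size, and an s-sparsity constraint whose projection is hard thresholding. All function names, the step size, and the problem sizes are illustrative assumptions rather than the paper's exact update (3.2).

```python
import numpy as np

def relu_grad(w, X, y):
    """Generalized gradient of (1/n) * sum_i (max(0, <w, x_i>) - y_i)^2."""
    z = X @ w
    residual = np.maximum(z, 0.0) - y
    # (1 + sign(z)) / 2 is 1 where z > 0, 0 where z < 0, and 1/2 at z == 0:
    # one valid choice of generalized derivative of the ReLU at its kink.
    active = (1.0 + np.sign(z)) / 2.0
    return (2.0 / len(y)) * (X.T @ (residual * active))

def project_sparse(w, s):
    """Euclidean projection onto vectors with at most s nonzeros."""
    out = np.zeros_like(w)
    keep = np.argsort(np.abs(w))[-s:]
    out[keep] = w[keep]
    return out

def pgd_relu(X, y, s, step=1.0, iters=50):
    """Projected gradient descent from the origin (assumed initialization)."""
    w = np.zeros(X.shape[1])
    for _ in range(iters):
        w = project_sparse(w - step * relu_grad(w, X, y), s)
    return w

def demo(n=3000, d=100, s=5, seed=0):
    """Relative recovery error on a synthetic planted-ReLU problem."""
    rng = np.random.default_rng(seed)
    w_star = np.zeros(d)
    support = rng.choice(d, size=s, replace=False)
    w_star[support] = rng.choice([-1.0, 1.0], size=s)
    X = rng.standard_normal((n, d))
    y = np.maximum(X @ w_star, 0.0)  # noiseless ReLU labels
    w_hat = pgd_relu(X, y, s)
    return np.linalg.norm(w_hat - w_star) / np.linalg.norm(w_star)

if __name__ == "__main__":
    print("relative error:", demo())
```

Using `(1 + sign(z)) / 2` rather than a strict `z > 0` mask matters at the origin: with a strict mask the gradient at the zero vector vanishes and the iteration never leaves its starting point.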
To estimate , we start from the initial point and apply the Projected Gradient (PGD) updates of the form
holds for a fixed numerical constant . Then there is an event of probability at least such that on this event the updates (3.3) obey
Here is a fixed numerical constant.
The first interesting and perhaps surprising aspect of this result is its generality: it applies not only to convex regularization functions but also nonconvex ones! As we mentioned earlier the optimization problem in (1.1) is not known to be tractable even for convex regularizers. Despite the nonconvexity of both the objective and regularizer, the theorem above shows that with a near minimal number of data samples, projected gradient descent provably learns the original weight vector without getting trapped in any local optima.
Another interesting aspect of the above result is that the convergence rate is linear. Therefore, to achieve a relative error of the total number of iterations is on the order of . Thus the overall computational complexity is on the order of (in general the cost is the total number of iterations multiplied by the cost of applying the feature matrix and its transpose). As a result, the computational complexity is also now optimal in terms of dependence on the matrix dimensions. Indeed, for a dense matrix even verifying that a good solution has been achieved requires one matrix-vector multiplication which takes time.
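The iteration count implied by a linear convergence rate can be computed directly. The sketch below assumes an error bound of the generic form error after tau steps <= rho^tau times the initial error, with a contraction factor rho strictly between 0 and 1. The function name and the default value rho = 1/2 are illustrative assumptions.

```python
import math

def iterations_for_tolerance(eps, rho=0.5):
    """Smallest tau with rho**tau <= eps, assuming error_tau <= rho**tau * error_0."""
    if not (0.0 < eps < 1.0 and 0.0 < rho < 1.0):
        raise ValueError("eps and rho must lie in (0, 1)")
    tau = max(0, math.ceil(math.log(1.0 / eps) / math.log(1.0 / rho)))
    # Guard against floating-point rounding in the logarithm ratio, in both directions.
    while tau > 0 and rho ** (tau - 1) <= eps:
        tau -= 1
    while rho ** tau > eps:
        tau += 1
    return tau

if __name__ == "__main__":
    for eps in (1e-1, 1e-3, 1e-6):
        print(eps, iterations_for_tolerance(eps))
```

With rho = 1/2, a relative error of 1e-3 is reached after 10 iterations, and the count grows only logarithmically in 1/eps.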
Discussions and prior art
Proofs
In this section we gather some useful results on concentration of stochastic processes which will be crucial in our proofs. These results are mostly adapted from . We begin with a lemma which is a direct consequence of Gordon’s escape through the mesh lemma.
for a fixed numerical constant . Then for all
holds with probability at least .
We also need a generalization of the above lemma stated below.
for a fixed numerical constant . Then for all
holds with probability at least .
We next state a generalization of Gordon’s escape through the mesh lemma also from .
The previous lemma leads to the following Corollary.
5.2 Convergence proof (Proof of Theorem 3.1)
In this section we shall prove Theorem 3.1. Throughout, we use the shorthand to denote the descent cone of at , i.e. . We begin by analyzing the first iteration. Using we have
We use the argument of [Page 25, inequality (7.34)] which shows that
Using ReLU we have
We proceed by bounding the first term in the above equality. To this aim we decompose in the direction parallel/perpendicular to that of and arrive at
holds with probability at least . Also,
holds with probability at least . Plugging (5.4) with and (5.5) with into (5.2), as long as
holds with probability at least .
We now focus on bounding the second term in (5.2). To this aim we decompose in the direction parallel/perpendicular to that of and arrive at
with a fixed numerical constant. Thus by a Bernstein-type inequality ([18, Proposition 5.16])
holds with probability at least with a fixed numerical constant. Also note that
holds with probability at least and
holds with probability at least . Combining the last two inequalities we conclude that
holds with probability at least . Plugging (5.8) and (5.9) with , , and into (5.2)
holds with probability at least as long as
Thus plugging (5.6) and (5.10) into (5.1) we conclude that for
holds with probability at least as long as
To introduce our general convergence analysis we begin by defining
To prove Theorem 3.1 we use the argument of [Page 25, inequality (7.34)] which shows that if we apply the projected gradient descent update
the error obeys
To complete the convergence analysis it is then sufficient to prove
We will instead prove that the following stronger result holds for all and
Equation (5.13) implies (5.12), which when combined with (5.11) proves the convergence result of Theorem 3.1 (specifically equation (3.5)).
The rest of this section is dedicated to proving (5.13). To this aim note that . Therefore, the loss function can alternatively be written as
Now defining we conclude that
We now proceed by bounding each of the four terms in (5.14) and then combine them in Section 5.2.5.
To bound the first term we use Lemma 5.2, which implies that as long as
then for all and
holds with probability at least .
5.2.2 Bounding the second term in (5.14)
We now proceed by bounding the four terms on the right-hand side of the above inequality. To bound the first term we use (5.8) to conclude that
holds with probability at least . To bound the second and third terms in (5.2.2) we use (5.9) to conclude that
holds with probability at least . To bound the last term let be i.i.d. random variables independent from . Then,
Now to bound the term in the parenthesis note that by concentration of sums of sub-Gaussian random variables [18, Proposition 5.10]
holds with probability at least . Now note that
holds with probability at least . Thus using the union bound
holds with probability at least . Plugging this into (5.20) and using Lemma 5.3 with , we conclude that
holds with probability at least , completing the bound of the last term of (5.2.2). Combining (5.17), (5.2.2), and (5.21) with , we conclude that
holds with probability at least as long as
This completes the bound on the second term in (5.14).
5.2.3 Bounding the third term in (5.14)
where the last inequality follows from the Cauchy-Schwarz inequality. Note that by Lemma 5.1 as long as , then
We now proceed by using the following result from .
with a fixed numerical constant. Then
holds with probability at least for all vectors obeying
Proof This lemma follows from the argument on pages 27-30 of .
Combining the lemma with equations (5.2.3), (5.2.3), and (5.2.3) we conclude that as long as , then
5.2.4 Bounding the fourth term in (5.14)
To bound the fourth term of (5.14) note that by the argument leading to (5.2.3)
with a fixed numerical constant. Then
holds with probability at least for all vectors obeying
Proof This lemma follows from the argument on pages 27-30 of .
Combining the lemma with (5.2.4) we conclude that as long as , then
5.2.5 Putting the bounds together
In this section we put together the bounds of the previous sections. Combining (5.15), (5.22), (5.26), and (5.28) we conclude that
Acknowledgements
M.S. would like to thank Adam Klivans and Matus Telgarsky for discussions related to and the Isotron algorithm. This work was done in part while the author was visiting the Simons Institute for the Theory of Computing.