Exponential expressivity in deep neural networks through transient chaos
Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, Surya Ganguli
Introduction
Deep feedforward neural networks, with multiple hidden layers, have achieved remarkable performance across many domains. A key factor thought to underlie their success is their high expressivity. This informal notion has manifested itself primarily in two forms of intuition. The first is that deep networks can compactly express highly complex functions over input space in a way that shallow networks with one hidden layer and the same number of neurons cannot. The second piece of intuition, which has captured the imagination of machine learning and neuroscience alike, is that deep neural networks can disentangle highly curved manifolds in input space into flattened manifolds in hidden space, to aid the performance of simple linear readouts. These intuitions, while attractive, have been difficult to formalize mathematically, and thereby to test rigorously.
For the first intuition, seminal works have exhibited examples of particular functions that can be computed with a polynomial number of neurons (in the input dimension) in a deep network but require an exponential number of neurons in a shallow network. This raises a central open question: are such functions merely rare curiosities, or is any function computed by a generic deep network not efficiently computable by a shallow network? The theoretical techniques employed in prior work both limited the applicability of theory to specific nonlinearities and dictated the particular measure of deep functional complexity involved. For example, one line of work focused on ReLU nonlinearities and the number of linear regions as a complexity measure, another focused on sum-product networks and the number of monomials as a complexity measure, and a third focused on Pfaffian nonlinearities and topological measures of complexity, like the sum of Betti numbers of a decision boundary. There is, however, also an interesting analysis of a general class of compositional functions. The limits of prior theoretical techniques raise another central question: is there a unifying theoretical framework for deep neural expressivity that is simultaneously applicable to arbitrary nonlinearities, generic networks, and a natural, general measure of functional complexity?
Here we attack both central problems of deep neural expressivity by combining a very different set of tools, namely Riemannian geometry and dynamical mean field theory. This novel combination enables us to show that for very broad classes of nonlinearities, even random deep neural networks can construct hidden internal representations whose global extrinsic curvature grows exponentially with depth but not width. Our geometric framework enables us to quantitatively define a notion of disentangling and verify this notion even in deep random networks. Furthermore, our methods yield insights into the emergent, deterministic nature of signal propagation through large random feedforward networks, revealing the existence of an order to chaos transition as a function of the statistics of weights and biases. We find that the transient, finite depth evolution in the chaotic regime underlies the origins of exponential expressivity in deep random networks.
In our companion paper, we study several related measures of expressivity in deep random neural networks with piecewise linear activations.
A mean field theory of deep nonlinear signal propagation
Consider a deep feedforward network with $D$ layers of weights $W^1, \dots, W^D$ and $D+1$ layers of neural activity vectors $x^0, \dots, x^D$, with $N_l$ neurons in layer $l$. The feedforward dynamics elicited by an input $x^0$ are
$$x^l = \phi(h^l), \qquad h^l = W^l x^{l-1} + b^l \qquad \text{for } l = 1, \dots, D, \tag{1}$$
where $b^l$ is a vector of biases, $h^l$ is the pattern of inputs to neurons at layer $l$, and $\phi$ is a single neuron scalar nonlinearity that acts component-wise to transform inputs $h^l$ to activities $x^l$. We wish to understand the nature of typical functions computable by such networks, as a consequence of their depth. We therefore study ensembles of random networks in which the synaptic weights $W^l_{ij}$ are drawn i.i.d. from a zero mean Gaussian with variance $\sigma_w^2 / N_{l-1}$, while the biases are drawn i.i.d. from a zero mean Gaussian with variance $\sigma_b^2$. This weight scaling ensures that the input contribution to each individual neuron at layer $l$ from activities in layer $l-1$ remains $O(1)$, independent of the layer width $N_{l-1}$. This ensemble constitutes a maximum entropy distribution over deep neural networks, subject to constraints on the means and variances of weights and biases. This ensemble induces no further structure in the resulting set of deep functions, so its analysis provides an opportunity to understand the specific contribution of depth alone to the nature of typical functions computed by deep networks.
In the limit of large layer widths, $N_l \gg 1$, certain aspects of signal propagation through deep random neural networks take on an essentially deterministic character. This emergent determinism in large random neural networks enables us to understand how the Riemannian geometry of simple manifolds in the input layer is typically modified as the manifold propagates into the deep layers. For example, consider the simplest case of a single input vector $x^0$. As it propagates through the network, its length in downstream layers will change. We track this changing length by computing the normalized squared length of the input vector at each layer:
$$q^l = \frac{1}{N_l} \sum_{i=1}^{N_l} (h^l_i)^2. \tag{2}$$
This length is the second moment of the empirical distribution of inputs $h^l_i$ across all $N_l$ neurons in layer $l$. For large $N_l$, this empirical distribution converges to a zero mean Gaussian since each $h^l_i = \sum_j W^l_{ij} \phi(h^{l-1}_j) + b^l_i$ is a weighted sum of a large number of uncorrelated random variables, i.e. the weights $W^l_{ij}$ and biases $b^l_i$, which are independent of the activity in previous layers. By propagating this Gaussian distribution across one layer, we obtain an iterative map for $q^l$ in (2):
$$q^l = \mathcal{V}(q^{l-1} \mid \sigma_w, \sigma_b) \equiv \sigma_w^2 \int \mathcal{D}z \, \phi\!\left(\sqrt{q^{l-1}}\, z\right)^2 + \sigma_b^2 \qquad \text{for } l = 2, \dots, D, \tag{3}$$
where $\mathcal{D}z = \frac{dz}{\sqrt{2\pi}} e^{-z^2/2}$ is the standard Gaussian measure, and the initial condition is $q^1 = \sigma_w^2 q^0 + \sigma_b^2$, where $q^0 = \frac{1}{N_0} x^0 \cdot x^0$ is the length in the initial activity layer. See Supplementary Material (SM) for a derivation of (3). Intuitively, the integral over $z$ in (3) replaces an average over the empirical distribution of $h^l_i$ across neurons $i$ in layer $l$ at large layer width $N_l$.
The function $\mathcal{V}$ in (3) is an iterative variance, or length, map that predicts how the length of an input in (2) changes as it propagates through the network. This length map is plotted in Fig. 1A for the special case of a sigmoidal nonlinearity, $\phi(h) = \tanh(h)$. For monotonic nonlinearities, this length map is a monotonically increasing, concave function whose intersections with the unity line determine its fixed points $q^*(\sigma_w, \sigma_b)$. For $\sigma_b = 0$ and $\sigma_w < 1$, the only intersection is at $q^* = 0$. In this bias-free, small weight regime, the network shrinks all inputs to the origin. For $\sigma_b = 0$ and $\sigma_w > 1$, the $q^* = 0$ fixed point becomes unstable and the length map acquires a second nonzero fixed point, which is stable. In this bias-free, large weight regime, the network expands small inputs and contracts large inputs. Also, for any nonzero bias $\sigma_b$, the length map has a single stable non-zero fixed point. In such a regime, even with small weights, the injected biases at each layer prevent signals from decaying to $0$. The dynamics of the length map leads to rapid convergence of length to its fixed point with depth (Fig. 1B,D), often within only a few layers. The fixed points $q^*(\sigma_w, \sigma_b)$ are shown in Fig. 1C.
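The fixed-point behavior of the length map can be explored numerically. The following is a minimal sketch, not the authors' code. It assumes the length map has the standard mean-field form $\mathcal{V}(q) = \sigma_w^2 \int \mathcal{D}z\, \phi(\sqrt{q} z)^2 + \sigma_b^2$ with $\phi = \tanh$, where $\sigma_w$ and $\sigma_b$ denote the weight and bias standard deviations, and it evaluates the Gaussian integral by Gauss-Hermite quadrature. The function names and the number of quadrature nodes are illustrative choices.

```python
import numpy as np

# Gauss-Hermite nodes and weights for the weight function exp(-z^2 / 2).
_Z, _W = np.polynomial.hermite_e.hermegauss(101)
_W = _W / np.sqrt(2.0 * np.pi)  # normalize so the weights sum to 1 (measure Dz)


def length_map(q, sigma_w, sigma_b, phi=np.tanh):
    """One application of the assumed length map V(q | sigma_w, sigma_b)."""
    return sigma_w**2 * np.sum(_W * phi(np.sqrt(q) * _Z) ** 2) + sigma_b**2


def iterate_length(q0, sigma_w, sigma_b, depth):
    """Return the lengths [q^1, ..., q^depth] for an input of length q0."""
    q = sigma_w**2 * q0 + sigma_b**2  # initial condition for the first layer
    qs = [q]
    for _ in range(depth - 1):
        q = length_map(q, sigma_w, sigma_b)
        qs.append(q)
    return np.array(qs)
```

For example, `iterate_length(1.0, 0.5, 0.0, 50)` decays toward zero, while `iterate_length(1.0, 2.0, 0.0, 50)` settles at a nonzero value, in line with the two bias-free regimes described above.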
Transient chaos in deep networks
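Now consider the layer-wise propagation of two inputs $x^{0,1}$ and $x^{0,2}$. The geometry of these two inputs as they propagate through the network is captured by the $2$ by $2$ matrix of inner products:
$$q^l_{ab} = \frac{1}{N_l} \sum_{i=1}^{N_l} h^l_i(x^{0,a})\, h^l_i(x^{0,b}), \qquad a, b \in \{1, 2\}. \tag{4}$$
The dynamics of the two diagonal terms are each theoretically predicted by the length map in (3). We derive (see SM) a correlation map $\mathcal{C}$ that predicts the layer-wise dynamics of $q^l_{12}$:
$$q^l_{12} = \mathcal{C}(c^{l-1}_{12}, q^{l-1}_{11}, q^{l-1}_{22} \mid \sigma_w, \sigma_b) \equiv \sigma_w^2 \int \mathcal{D}z_1\, \mathcal{D}z_2\, \phi(u_1)\, \phi(u_2) + \sigma_b^2,$$
$$u_1 = \sqrt{q^{l-1}_{11}}\, z_1, \qquad u_2 = \sqrt{q^{l-1}_{22}} \left[ c^{l-1}_{12} z_1 + \sqrt{1 - (c^{l-1}_{12})^2}\, z_2 \right], \tag{5}$$
where $c^l_{12} = q^l_{12} (q^l_{11} q^l_{22})^{-1/2}$ is the correlation coefficient. Here $z_1$ and $z_2$ are independent standard Gaussian variables, while $u_1$ and $u_2$ are correlated Gaussian variables with covariance matrix $\langle u_a u_b \rangle = q^{l-1}_{ab}$. Together, (3) and (5) constitute a theoretical prediction for the typical evolution of the geometry of $2$ points in (4) in a fixed large network.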
Analysis of these equations reveals an interesting order to chaos transition in the $\sigma_w$ and $\sigma_b$ plane. In particular, what happens to two nearby points as they propagate through the layers? Their relation to each other can be tracked by the correlation coefficient $c^l_{12}$ between the two points, which approaches a fixed point $c^*(\sigma_w, \sigma_b)$ at large depth. Since the length of each point rapidly converges to $q^*(\sigma_w, \sigma_b)$, as shown in Fig. 1BD, we can compute $c^*$ by simply setting $q^l_{11} = q^l_{22} = q^*(\sigma_w, \sigma_b)$ in (5) and dividing by $q^*$ to obtain an iterative correlation coefficient map, or $\mathcal{C}$-map, for $c^l_{12}$:
$$c^l_{12} = \frac{1}{q^*}\, \mathcal{C}(c^{l-1}_{12}, q^*, q^* \mid \sigma_w, \sigma_b). \tag{6}$$
This $\mathcal{C}$-map is shown in Fig. 2A. It always has a fixed point at $c^* = 1$ as can be checked by direct calculation. However, the stability of this fixed point depends on the slope of the map at $1$, which is
$$\chi_1 \equiv \left. \frac{\partial c^l_{12}}{\partial c^{l-1}_{12}} \right|_{c = 1} = \sigma_w^2 \int \mathcal{D}z \left[ \phi'\!\left(\sqrt{q^*}\, z\right) \right]^2. \tag{7}$$
See SM for a derivation of (7). If the slope $\chi_1$ is less than $1$, then the $\mathcal{C}$-map is above the unity line, the fixed point at $1$ under the $\mathcal{C}$-map in (6) is stable, and nearby points become more similar over time. Conversely, if $\chi_1 > 1$ then this fixed point is unstable, and nearby points separate as they propagate through the layers. Thus we can intuitively understand $\chi_1$ as a multiplicative stretch factor. This intuition can be made precise by considering the Jacobian $J^l_{ij} = W^l_{ij} \phi'(h^{l-1}_j)$ at a point $h^{l-1}$ with length $q^*$. $J^l$ is a linear approximation of the network map from layer $l-1$ to $l$ in the vicinity of $h^{l-1}$. Therefore a small random perturbation $h^{l-1} + u$ will map to $h^l + J^l u$. The growth of the perturbation, $\|J^l u\|_2^2 / \|u\|_2^2$, becomes $\chi_1$ after averaging over the random perturbation $u$, weight matrix $W^l$, and Gaussian distribution of $h^{l-1}_i$ across $i$. Thus $\chi_1$ directly reflects the typical multiplicative growth or shrinkage of a random perturbation across one layer.
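The stability criterion above lends itself to a short numerical check. The sketch below is illustrative rather than the authors' implementation. It assumes a $\tanh$ nonlinearity, the mean-field length map $q \mapsto \sigma_w^2 \int \mathcal{D}z \tanh(\sqrt{q} z)^2 + \sigma_b^2$, and the slope formula $\chi_1 = \sigma_w^2 \int \mathcal{D}z\, [\phi'(\sqrt{q^*} z)]^2$. The helper names, the starting value of the iteration, and the iteration count are hypothetical choices.

```python
import numpy as np

# Gauss-Hermite quadrature for the standard Gaussian measure Dz.
_Z, _W = np.polynomial.hermite_e.hermegauss(101)
_W = _W / np.sqrt(2.0 * np.pi)


def fixed_point_length(sigma_w, sigma_b, q_init=1.0, iters=200):
    """Iterate the assumed tanh length map until it settles near q*."""
    q = q_init
    for _ in range(iters):
        q = sigma_w**2 * np.sum(_W * np.tanh(np.sqrt(q) * _Z) ** 2) + sigma_b**2
    return q


def chi_1(sigma_w, sigma_b):
    """Slope of the correlation map at c = 1, evaluated at q*."""
    q_star = fixed_point_length(sigma_w, sigma_b)
    dphi = 1.0 - np.tanh(np.sqrt(q_star) * _Z) ** 2  # tanh'(h) = 1 - tanh(h)^2
    return sigma_w**2 * np.sum(_W * dphi**2)


def phase(sigma_w, sigma_b):
    """Label a point of the (sigma_w, sigma_b) plane by the sign of chi_1 - 1."""
    return "chaotic" if chi_1(sigma_w, sigma_b) > 1.0 else "ordered"
```

Under these assumptions, small weights such as `phase(0.5, 0.3)` fall in the ordered phase, while large weights such as `phase(4.0, 0.3)` fall in the chaotic phase.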
The dynamics of the iterative $\mathcal{C}$-map and its agreement with network simulations is shown in Fig. 2B. The correlation dynamics are much slower than the length dynamics because the $\mathcal{C}$-map is closer to the unity line (Fig. 2A) than the length map (Fig. 1A). Thus correlations typically take many more layers to approach their fixed point than lengths do. The fixed point $c^*$ and slope $\chi_1$ of the $\mathcal{C}$-map are shown in Fig. 2CD. For any fixed, finite $\sigma_b$, as $\sigma_w$ increases three qualitative regions occur. For small $\sigma_w$, $c^* = 1$ is the only fixed point, and it is stable because $\chi_1 < 1$. In this strong bias regime, any two input points converge to each other as they propagate through the network. As $\sigma_w$ increases, $\chi_1$ increases and crosses $1$, destabilizing the $c^* = 1$ fixed point. In this intermediate regime, a new stable fixed point $c^*$ appears, which decreases as $\sigma_w$ increases. Here an equal footing competition between weights and nonlinearities (which de-correlate inputs) and the biases (which correlate them) leads to a finite $c^*$. At larger $\sigma_w$, the strong weights overwhelm the biases and maximally de-correlate inputs to make them orthogonal, leading to a stable fixed point at $c^* = 0$.
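The three regimes can be reproduced by iterating the correlation coefficient map directly. This is a hedged sketch under stated assumptions: a $\tanh$ nonlinearity, both inputs already at the fixed-point length $q^* > 0$, and a correlation map of the form $c \mapsto \frac{1}{q^*}\left[\sigma_w^2 \int \mathcal{D}z_1 \mathcal{D}z_2\, \phi(u_1)\phi(u_2) + \sigma_b^2\right]$ with $u_1 = \sqrt{q^*} z_1$ and $u_2 = \sqrt{q^*}\,(c z_1 + \sqrt{1 - c^2}\, z_2)$. The two-dimensional Gaussian integral is approximated on a tensor-product Gauss-Hermite grid, and all function names are illustrative.

```python
import numpy as np

# One-dimensional Gauss-Hermite rule for Dz, and its tensor product for Dz1 Dz2.
_Z, _W = np.polynomial.hermite_e.hermegauss(101)
_W = _W / np.sqrt(2.0 * np.pi)
_Z1, _Z2 = np.meshgrid(_Z, _Z, indexing="ij")
_W2 = np.outer(_W, _W)


def q_star(sigma_w, sigma_b, iters=200):
    """Fixed point of the assumed tanh length map (requires q* > 0 below)."""
    q = 1.0
    for _ in range(iters):
        q = sigma_w**2 * np.sum(_W * np.tanh(np.sqrt(q) * _Z) ** 2) + sigma_b**2
    return q


def c_map(c, q, sigma_w, sigma_b):
    """One step of the correlation coefficient map at fixed length q."""
    u1 = np.sqrt(q) * _Z1
    # max(...) guards against c exceeding 1 by floating point round-off.
    u2 = np.sqrt(q) * (c * _Z1 + np.sqrt(max(0.0, 1.0 - c**2)) * _Z2)
    q12 = sigma_w**2 * np.sum(_W2 * np.tanh(u1) * np.tanh(u2)) + sigma_b**2
    return q12 / q


def iterate_correlation(c0, sigma_w, sigma_b, depth):
    """Return [c^0, c^1, ..., c^depth] starting from correlation c0."""
    q = q_star(sigma_w, sigma_b)
    cs = [c0]
    for _ in range(depth):
        cs.append(c_map(cs[-1], q, sigma_w, sigma_b))
    return np.array(cs)
```

With small weights, for instance `iterate_correlation(0.0, 0.5, 0.3, 50)`, the correlation climbs toward 1. With large weights, for instance `iterate_correlation(0.999, 4.0, 0.3, 50)`, two nearly identical inputs de-correlate.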
Thus the equation $\chi_1(\sigma_w, \sigma_b) = 1$ yields a phase transition boundary in the $(\sigma_w, \sigma_b)$ plane, separating it into a chaotic (or ordered) phase, in which nearby points separate (or converge). In dynamical systems theory, the logarithm of $\chi_1$ is related to the well known Lyapunov exponent which is positive (or negative) for chaotic (or ordered) dynamics. However, in a feedforward network, the dynamics is truncated at a finite depth $D$, and hence the dynamics are a form of transient chaos.
The propagation of manifold geometry through deep networks
To illustrate these concepts, it is useful to compute all of them for the circle $h^1(\theta)$ defined above: $g^E(\theta) = Nq$, $L^E = 2\pi\sqrt{Nq}$, $\kappa(\theta) = 1/\sqrt{Nq}$, $g^G(\theta) = 1$, and $L^G = 2\pi$. As expected, $\kappa(\theta)$ is the inverse of the radius of curvature, which is $\sqrt{Nq}$. Now consider how these quantities change if the circle is scaled up so that $h(\theta) \to \chi h(\theta)$. The length $L^E$ and radius scale up by $\chi$, but the curvature $\kappa$ scales down as $\chi^{-1}$, and so $L^G$ does not change. Thus linear expansion increases length and decreases curvature, thereby maintaining constant Grassmannian length $L^G$.
We now show that nonlinear propagation of this same circle through a deep network can behave very differently from linear expansion: in the chaotic regime, length can increase without any decrease in extrinsic curvature! To remove the scaling with $N$ in the above quantities, we will work with the renormalized quantities $\bar\kappa = \sqrt{N}\,\kappa$, $\bar g^E = \frac{1}{N} g^E$, and $\bar L^E = \frac{1}{\sqrt{N}} L^E$. Thus, $1/\bar\kappa^2$ can be thought of as a radius of curvature squared per neuron of the osculating circle, while $(\bar L^E)^2$ is the squared Euclidean length of the curve per neuron. For the circle, these quantities are $q$ and $4\pi^2 q$ respectively. For simplicity, in the inputs to the first layer of neurons, we begin with a circle $h^1(\theta)$ with squared radius per neuron $q^*$, so this radius is already at the fixed point of the length map in (3). In the SM, we derive an iterative formula for the extrinsic curvature and Euclidean metric of this manifold as it propagates through the layers of a deep network:
$$\bar g^{E,l} = \chi_1\, \bar g^{E,l-1}, \qquad (\bar\kappa^l)^2 = 3\,\frac{\chi_2}{\chi_1^2} + \frac{1}{\chi_1}\, (\bar\kappa^{l-1})^2, \qquad \bar g^{E,1} = q^*, \quad (\bar\kappa^1)^2 = 1/q^*, \tag{8}$$
where $\chi_1$ is the stretch factor defined in (7) and $\chi_2$ is defined analogously as
$$\chi_2 = \sigma_w^2 \int \mathcal{D}z \left[ \phi''\!\left(\sqrt{q^*}\, z\right) \right]^2. \tag{9}$$
$\chi_2$ is closely related to the second derivative of the $\mathcal{C}$-map in (6) at $c^{l-1}_{12} = 1$; this second derivative is $q^* \chi_2$. See SM for a derivation of the evolution equations (8) for the extrinsic geometry of a curve as it propagates through a deep network.
Intriguingly for a sigmoidal neural network, these evolution equations behave very differently in the chaotic ($\chi_1 > 1$) versus ordered ($\chi_1 < 1$) phase. In the chaotic phase, the Euclidean metric $\bar g^E$ grows exponentially with depth due to multiplicative stretching through $\chi_1$. This stretching does multiplicatively attenuate any curvature in layer $l-1$ by a factor $1/\chi_1$ (see the update equation for $\bar\kappa^l$ in (8)), but new curvature is added in due to a nonzero $\chi_2$, which originates from the curvature of the single neuron nonlinearity in (9). Thus, unlike in linear expansion, extrinsic curvature is not lost, but maintained, and ultimately approaches a fixed point $\bar\kappa^*$. This implies that the global curvature measure $\bar L^G$ grows exponentially with depth. These highly nontrivial predictions of the metric and curvature evolution equations in (8) are quantitatively confirmed in simulations in Figure 4C-E.
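The contrast between exponentially growing metric and saturating curvature can be seen by iterating the evolution equations numerically. The sketch below is an illustration under assumptions, not the authors' code: it takes $\phi = \tanh$, assumes the recursions have the form $\bar g^{l} = \chi_1 \bar g^{l-1}$ and $(\bar\kappa^l)^2 = 3\chi_2/\chi_1^2 + (\bar\kappa^{l-1})^2/\chi_1$ with initial values $q^*$ and $1/q^*$, and defines $\chi_1$ and $\chi_2$ as Gaussian averages of $\phi'^2$ and $\phi''^2$ at the fixed-point length. Function names are hypothetical.

```python
import numpy as np

# Gauss-Hermite quadrature for the standard Gaussian measure Dz.
_Z, _W = np.polynomial.hermite_e.hermegauss(101)
_W = _W / np.sqrt(2.0 * np.pi)


def _stats(sigma_w, sigma_b, iters=200):
    """Return (q*, chi_1, chi_2) for a tanh network under the assumed formulas."""
    q = 1.0
    for _ in range(iters):
        q = sigma_w**2 * np.sum(_W * np.tanh(np.sqrt(q) * _Z) ** 2) + sigma_b**2
    t = np.tanh(np.sqrt(q) * _Z)
    d1 = 1.0 - t**2               # tanh'(h)
    d2 = -2.0 * t * (1.0 - t**2)  # tanh''(h)
    chi1 = sigma_w**2 * np.sum(_W * d1**2)
    chi2 = sigma_w**2 * np.sum(_W * d2**2)
    return q, chi1, chi2


def propagate_geometry(sigma_w, sigma_b, depth):
    """Iterate the assumed metric and squared-curvature recursions for `depth` layers."""
    q, chi1, chi2 = _stats(sigma_w, sigma_b)
    g, k2 = [q], [1.0 / q]  # circle at the fixed-point radius in the first layer
    for _ in range(depth - 1):
        g.append(chi1 * g[-1])
        k2.append(3.0 * chi2 / chi1**2 + k2[-1] / chi1)
    return np.array(g), np.array(k2), chi1, chi2
```

When $\chi_1 > 1$, the squared curvature in this recursion converges to $3\chi_2 / (\chi_1(\chi_1 - 1))$, which follows from setting $(\bar\kappa^l)^2 = (\bar\kappa^{l-1})^2$, while the metric grows by a factor $\chi_1$ per layer.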
Intuitively, this exponential growth of global curvature $\bar L^G$ in the chaotic phase implies that the curve explores many different tangent directions in hidden representation space. This further implies that the coordinate functions of the embedding $h^l_i(\theta)$ become highly complex curved basis functions on the input manifold coordinate $\theta$, allowing a deep network to compute exponentially complex functions over simple low dimensional manifolds (Figure 5A-C, details in SM). In our companion paper, we further develop the relationship between length and expressivity in terms of the number of achievable classification patterns on a set of inputs. Moreover, we explore how training a single layer at different depths from the output affects network performance.
Shallow networks cannot achieve exponential expressivity
Consider a shallow network with $1$ hidden layer $x^1$, one input layer $x^0$, with $x^1 = \phi(W^1 x^0 + b^1)$, and a linear readout layer. How complex can the hidden representation be as a function of its width $N_1$, relative to the results above for depth? We prove a general upper bound on $L^E$ (see SM):
Suppose $\phi(h)$ is monotonically non-decreasing with bounded dynamic range $R$, i.e. $\max_h \phi(h) - \min_h \phi(h) = R$. Further suppose that $x^0(\theta)$ is a curve in input space such that no 1D projection of $\partial_\theta x^0(\theta)$ changes sign more than $s$ times over the range of $\theta$. Then for any choice of $W^1$ and $b^1$ the Euclidean length of $x^1(\theta)$ satisfies $L^E \le N_1 (1 + s) R$.
For the circle input and the $\tanh$ nonlinearity, for which $R = 2$, the normalized length $\bar L^E = L^E / \sqrt{N_1}$ can therefore grow at most in proportion to $\sqrt{N_1}$. In contrast, for deep networks in the chaotic regime $\bar g^E$ grows exponentially with depth in $h$ space, and so consequently also in $x$ space. Therefore the length of curves typically expands exponentially in depth even for random deep networks, but can only expand as the square root of width no matter what shallow network is chosen. Moreover, as we have seen above, it is the exponential growth of $\bar g^E$ that fundamentally drives the exponential growth of $\bar L^G$ with depth. Indeed shallow random networks exhibit minimal growth in expressivity even at large widths (Figure 5D).
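The shallow-network bound can be checked numerically on a single example. The sketch below is illustrative only: it assumes the bound has the form $L^E \le N_1 (1 + s) R$, uses a $\tanh$ hidden layer (dynamic range $R = 2$), and embeds a circle in the first two input coordinates. The sign-change count `s=2` is a deliberately conservative assumption for one full period of a sinusoid, and the function names, weight scale, and discretization are hypothetical choices.

```python
import numpy as np


def hidden_curve_length(W, b, radius=1.0, n_theta=20001):
    """Polyline length of x^1(theta) = tanh(W x^0(theta) + b) for a circular input."""
    theta = np.linspace(0.0, 2.0 * np.pi, n_theta)
    x0 = np.zeros((n_theta, W.shape[1]))
    x0[:, 0] = radius * np.cos(theta)  # circle lies in the first two input axes
    x0[:, 1] = radius * np.sin(theta)
    x1 = np.tanh(x0 @ W.T + b)
    # Sum of Euclidean distances between consecutive samples along the curve.
    return np.sum(np.linalg.norm(np.diff(x1, axis=0), axis=1))


def shallow_length_bound(n_hidden, s, R):
    """Assumed upper bound N_1 (1 + s) R on the hidden-layer curve length."""
    return n_hidden * (1 + s) * R
```

For random weights, for example `W = rng.normal(scale=5.0, size=(50, 10))`, the measured length stays below `shallow_length_bound(50, s=2, R=2.0)`. A single example of this kind illustrates the bound but does not prove it.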
Classification boundaries acquire exponential local curvature with depth
We have focused so far on how simple manifolds in input space can acquire both exponential Euclidean and Grassmannian length with depth, thereby exponentially de-correlating and filling up hidden representation space. Another natural question is how the complexity of a decision boundary grows as it is backpropagated to the input layer. Consider a linear classifier $y = \mathrm{sgn}(\beta \cdot x^D - \beta_0)$ acting on the final layer. In this layer, the $N_D - 1$ dimensional decision boundary is the hyperplane $\beta \cdot x^D - \beta_0 = 0$. However, in the input layer $x^0$, the decision boundary is a curved $N_0 - 1$ dimensional manifold $\mathcal{M}$ that arises as the solution set of the nonlinear equation $G(x^0) \equiv \beta \cdot x^D(x^0) - \beta_0 = 0$, where $x^D(x^0)$ is the nonlinear feedforward map from input to output.
At any point $x^*$ on the decision boundary in layer $0$, the gradient $\nabla G$ is perpendicular to the $N_0 - 1$ dimensional tangent plane $T_{x^*}\mathcal{M}$ (see Fig. 4F). The normal vector $\nabla G$, along with any unit tangent vector $\hat v \in T_{x^*}\mathcal{M}$, spans a $2$ dimensional subspace whose intersection with $\mathcal{M}$ yields a geodesic curve in $\mathcal{M}$ passing through $x^*$ with velocity vector $\hat v$. This geodesic will have extrinsic curvature $\kappa(x^*, \hat v)$. Maximizing this curvature over $\hat v$ yields the first principal curvature $\kappa_1(x^*)$. A sequence of successive maximizations of $\kappa(x^*, \hat v)$, while constraining $\hat v$ to be perpendicular to all previous solutions, yields the sequence of principal curvatures $\kappa_1(x^*) \ge \kappa_2(x^*) \ge \dots \ge \kappa_{N_0 - 1}(x^*)$. These principal curvatures arise as the eigenvalues of a normalized Hessian operator projected onto the tangent plane $T_{x^*}\mathcal{M}$: $H = \|\nabla G\|_2^{-1}\, P\, \frac{\partial^2 G}{\partial x\, \partial x^T}\, P$, where $P = I - \widehat{\nabla G}\, \widehat{\nabla G}^T$ is the projection operator onto $T_{x^*}\mathcal{M}$ and $\widehat{\nabla G}$ is the unit normal vector. Intuitively, near $x^*$, the decision boundary $\mathcal{M}$ can be approximated as a paraboloid with a quadratic form $H$ whose $N_0 - 1$ eigenvalues are the principal curvatures (Fig. 4F).
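The construction of principal curvatures from a projected, normalized Hessian can be written in a few lines. This is a minimal sketch, assuming the operator has the form $H = \|\nabla G\|^{-1} P\, (\partial^2 G / \partial x\, \partial x^T)\, P$ with $P = I - \hat n \hat n^T$ and $\hat n$ the unit gradient. The function name is hypothetical, and the sign convention simply follows this formula. The gradient and Hessian are taken as given inputs, for instance from automatic differentiation.

```python
import numpy as np


def principal_curvatures(grad, hess):
    """Eigenvalues of the tangent-projected, normalized Hessian at a boundary point."""
    grad = np.asarray(grad, dtype=float)
    hess = np.asarray(hess, dtype=float)
    g_norm = np.linalg.norm(grad)
    n_hat = grad / g_norm
    P = np.eye(grad.size) - np.outer(n_hat, n_hat)  # projector onto the tangent plane
    H = P @ (0.5 * (hess + hess.T)) @ P / g_norm     # symmetrize for numerical safety
    vals, vecs = np.linalg.eigh(H)
    # H always has a zero eigenvalue along the normal direction; drop that one.
    normal_idx = np.argmax(np.abs(vecs.T @ n_hat))
    vals = np.delete(vals, normal_idx)
    return np.sort(vals)[::-1]  # descending: kappa_1 >= kappa_2 >= ...
```

As a sanity check, for the sphere $G(x) = \|x\|^2 - r^2$ the gradient is $2x$ and the Hessian is $2I$, so every principal curvature returned has magnitude $1/r$. For a hyperplane the Hessian vanishes and all curvatures are zero.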
We compute these curvatures numerically as a function of depth in Fig. 4G (see SM for details). We find, remarkably, that a subset of principal curvatures grow exponentially with depth. Here the principal curvatures are signed, with positive (negative) curvature indicating that the associated geodesic curves towards (away from) the normal vector $\nabla G$. Thus the decision boundary can become exponentially curved with depth, enabling highly complex classifications. Moreover, this exponentially curved boundary is disentangled and mapped to a flat boundary in the output layer.
Discussion
Moreover, our analysis of a maximum entropy distribution over deep networks constitutes an important null model of deep signal propagation that can be used to assess and understand different behavior in trained networks. For example, the metrics we have adapted from Riemannian geometry, combined with an understanding of their behavior in random networks, may provide a basis for understanding what is special about trained networks. Furthermore, while we have focused on the notion of input-output chaos, the duality between inputs and synaptic weights implies a form of weight chaos, in which deep neural networks rapidly traverse function space as weights change (see SM). Indeed, just as autocorrelation lengths between outputs as a function of inputs shrink exponentially with depth, so too will autocorrelations between outputs as a function of weights.
Supplementary Material
Below is a series of appendices giving derivations of results in the main paper, followed by details of results along with more visualizations.
Note: code to programmatically reproduce all plots in the paper in Jupyter notebooks will be released upon publication.
We consider the feedforward dynamics
$$x^l = \phi(h^l), \qquad h^l = W^l x^{l-1} + b^l \qquad \text{for } l = 1, \dots, D, \tag{10}$$
where $b^l$ is a vector of biases, $h^l$ is the pattern of inputs to neurons at layer $l$, and $\phi$ is a single neuron scalar nonlinearity that acts component-wise to transform inputs $h^l$ to activities $x^l$. The synaptic weights $W^l_{ij}$ are drawn i.i.d. from a zero mean Gaussian with variance $\sigma_w^2 / N_{l-1}$, while the biases are drawn i.i.d. from a zero mean Gaussian with variance $\sigma_b^2$. This weight scaling ensures that the input contribution to each individual neuron at layer $l$ from activities in layer $l-1$ remains $O(1)$, independent of the layer width $N_{l-1}$.
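As a single input point $x^0$ propagates through the network, its length in downstream layers can either grow or shrink. To track the propagation of this length, we track the normalized squared length of the input vector at each layer,
$$q^l = \frac{1}{N_l} \sum_{i=1}^{N_l} (h^l_i)^2. \tag{11}$$
This length is the second moment of the empirical distribution of inputs $h^l_i$ across all $N_l$ neurons in layer $l$ for a fixed set of weights. This empirical distribution is expected to be Gaussian for large $N_l$, since each individual $h^l_i = W^{l,i} \cdot \phi(h^{l-1}) + b^l_i$, with $W^{l,i}$ the vector of weights into neuron $i$, is Gaussian distributed, as a sum of a large number of independent random variables, and each $h^l_i$ is independent of $h^l_j$ for $i \neq j$ because the synaptic weight vectors and biases into each neuron are chosen independently.
While the mean of this Gaussian is $0$, its variance can be computed by considering the variance of the input to a single neuron:
$$q^l = \left\langle (h^l_i)^2 \right\rangle = \left\langle \left[ W^{l,i} \cdot \phi(h^{l-1}) + b^l_i \right]^2 \right\rangle = \sigma_w^2\, \frac{1}{N_{l-1}} \sum_{j=1}^{N_{l-1}} \phi(h^{l-1}_j)^2 + \sigma_b^2, \tag{12}$$
where $\langle \cdot \rangle$ denotes an average over the distribution of weights and biases into neuron $i$ at layer $l$. Here we have used the identity $\langle W^l_{ij} W^l_{ik} \rangle = \delta_{jk}\, \sigma_w^2 / N_{l-1}$. Now the empirical distribution of inputs across layer $l-1$ is also Gaussian, with mean zero and variance $q^{l-1}$. Therefore we can replace the average over neurons in layer $l-1$ in (12) with an integral over a Gaussian random variable, obtaining
$$q^l = \mathcal{V}(q^{l-1} \mid \sigma_w, \sigma_b) \equiv \sigma_w^2 \int \mathcal{D}z \, \phi\!\left(\sqrt{q^{l-1}}\, z\right)^2 + \sigma_b^2 \qquad \text{for } l = 2, \dots, D, \tag{13}$$
where $\mathcal{D}z = \frac{dz}{\sqrt{2\pi}} e^{-z^2/2}$ is the standard Gaussian measure, and the initial condition for the variance map is $q^1 = \sigma_w^2 q^0 + \sigma_b^2$, where $q^0 = \frac{1}{N_0} x^0 \cdot x^0$ is the length in the initial activity layer. The function $\mathcal{V}$ in (13) is an iterative variance map that predicts how the length of an input in (11) changes as it propagates through the network. Its derivation relies on the well-known self-averaging assumption in the statistical physics of disordered systems, which, in our context, means that the empirical distribution of inputs across neurons for a fixed network converges, for large width, to the distribution of inputs to a single neuron across random networks.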
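The self-averaging assumption can be probed with a direct simulation of one wide random network. The sketch below is illustrative and is not the authors' Keras/Theano code. It assumes $\tanh$ units, weights with variance $\sigma_w^2 / N_{l-1}$, biases with variance $\sigma_b^2$, and the variance map $q \mapsto \sigma_w^2 \int \mathcal{D}z \tanh(\sqrt{q} z)^2 + \sigma_b^2$ with initial condition $\sigma_w^2 q^0 + \sigma_b^2$. Function names, widths, and the quadrature order are hypothetical choices.

```python
import numpy as np


def simulate_lengths(x0, sigma_w, sigma_b, depth, width, rng):
    """Empirical q^l = mean_i (h_i^l)^2 in one random tanh network."""
    x, qs = x0, []
    for _ in range(depth):
        n_in = x.shape[0]
        W = rng.normal(0.0, sigma_w / np.sqrt(n_in), size=(width, n_in))
        b = rng.normal(0.0, sigma_b, size=width)
        h = W @ x + b
        qs.append(np.mean(h**2))
        x = np.tanh(h)
    return np.array(qs)


def theory_lengths(q0, sigma_w, sigma_b, depth):
    """Mean-field prediction for q^1, ..., q^depth under the assumed variance map."""
    z, w = np.polynomial.hermite_e.hermegauss(101)
    w = w / np.sqrt(2.0 * np.pi)  # weights of the standard Gaussian measure
    q = sigma_w**2 * q0 + sigma_b**2
    qs = [q]
    for _ in range(depth - 1):
        q = sigma_w**2 * np.sum(w * np.tanh(np.sqrt(q) * z) ** 2) + sigma_b**2
        qs.append(q)
    return np.array(qs)
```

A typical use is to draw an input `x0 = rng.normal(size=2000)`, set `q0 = np.mean(x0**2)`, and compare `simulate_lengths(x0, 2.0, 0.3, 6, 2000, rng)` with `theory_lengths(q0, 2.0, 0.3, 6)`. The two should agree up to finite-width fluctuations, which shrink as the width grows.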
A.2 Derivation of a correlation map for the propagation of two points
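Now consider the layer-wise propagation of two inputs $x^{0,1}$ and $x^{0,2}$. The geometry of these two inputs as they propagate through the layers is captured by the $2$ by $2$ matrix of inner products
$$q^l_{ab} = \frac{1}{N_l} \sum_{i=1}^{N_l} h^l_i(x^{0,a})\, h^l_i(x^{0,b}), \qquad a, b \in \{1, 2\}. \tag{14}$$
The joint empirical distribution of $h^l_i(x^{0,a})$ and $h^l_i(x^{0,b})$ across $i$ at large $N$ will converge to a 2 dimensional Gaussian distribution with covariance $q^l_{ab}$. Propagating this joint distribution forward one layer using ideas similar to the derivation above for $1$ input yields
$$q^l_{12} = \mathcal{C}(c^{l-1}_{12}, q^{l-1}_{11}, q^{l-1}_{22} \mid \sigma_w, \sigma_b) \equiv \sigma_w^2 \int \mathcal{D}z_1\, \mathcal{D}z_2\, \phi(u_1)\, \phi(u_2) + \sigma_b^2,$$
$$u_1 = \sqrt{q^{l-1}_{11}}\, z_1, \qquad u_2 = \sqrt{q^{l-1}_{22}} \left[ c^{l-1}_{12} z_1 + \sqrt{1 - (c^{l-1}_{12})^2}\, z_2 \right], \tag{15}$$
where $c^l_{12} = q^l_{12} (q^l_{11} q^l_{22})^{-1/2}$ is the correlation coefficient (CC). Here $z_1$ and $z_2$ are independent standard Gaussian variables, while $u_1$ and $u_2$ are correlated Gaussian variables with covariance matrix $\langle u_a u_b \rangle = q^{l-1}_{ab}$. The integration over $z_1$ and $z_2$ can be thought of as the large $N$ limit of sums over $h^{l-1}_i(x^{0,a})$ and $h^{l-1}_i(x^{0,b})$.
When both input points are at their fixed point length, $q^*$, the dynamics of their correlation coefficient can be obtained by simply setting $q^l_{11} = q^l_{22} = q^*$ in (15) and dividing by $q^*$ to obtain a recursion relation for $c^l_{12}$:
$$c^l_{12} = \frac{1}{q^*}\, \mathcal{C}(c^{l-1}_{12}, q^*, q^* \mid \sigma_w, \sigma_b). \tag{16}$$
Direct calculation reveals that $\mathcal{C}(1, q^*, q^* \mid \sigma_w, \sigma_b) = q^*$ as expected. Of particular interest is the slope of this map at $c^{l-1}_{12} = 1$. A direct, if tedious, calculation shows that
$$\frac{\partial c^l_{12}}{\partial c^{l-1}_{12}} = \sigma_w^2 \int \mathcal{D}z_1\, \mathcal{D}z_2\, \phi'(u_1)\, \phi'(u_2). \tag{17}$$
To obtain this result, one has to apply the chain rule and product rule from calculus, as well as employ the identity
$$\int \mathcal{D}z\, F(z)\, z = \int \mathcal{D}z\, F'(z), \tag{18}$$
which can be obtained via integration by parts. Evaluating the derivative at $c^{l-1}_{12} = 1$ yields
$$\chi_1 \equiv \left. \frac{\partial c^l_{12}}{\partial c^{l-1}_{12}} \right|_{c = 1} = \sigma_w^2 \int \mathcal{D}z \left[ \phi'\!\left(\sqrt{q^*}\, z\right) \right]^2. \tag{19}$$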
Appendix B Derivation of evolution equations for Riemannian curvature
Here we derive recursion relations for Riemannian curvature quantities.
with . At large , the inner-product structure of translation invariant manifolds remains approximately translation invariant as it propagates through the network. Therefore, at large , we can express inner products of derivatives of in terms of derivatives of . For example, the Euclidean metric is given by
Here, each dot is a short hand notation for derivative w.r.t. . Also, the extrinsic curvature
where and , simplifies to
Now if the translation invariant manifold lives on a sphere whose radius is the fixed point radius of the length map (squared length per neuron $q^*$), then its radius does not change as it propagates through the system. Then we can also express $g^E$ and $\kappa$ in terms of the correlation coefficient function $c(\theta)$ (up to an overall constant factor). Thus to understand the propagation of local quantities like Euclidean length and curvature, we need to understand the propagation of derivatives of $c(\theta)$ at $\theta = 0$ under the $\mathcal{C}$-map in (16). Note that $c(\theta)$ is symmetric and achieves a maximum value of $1$ at $\theta = 0$. Thus the function $1 - c(\theta)$ is symmetric with a minimum at $\theta = 0$. We consider the propagation of $1 - c(\theta)$ through the $\mathcal{C}$-map. But first we consider the propagation of derivatives under function composition in general.
B.2 Behavior of first and second derivatives under function composition
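Assume $C(\theta)$ is an even function and $C(0) = 0$, so that its Taylor expansion can be written as $C(\theta) = \frac{1}{2} \ddot{C}(0)\, \theta^2 + \frac{1}{4!} \ddddot{C}(0)\, \theta^4 + \dots$. We are interested in determining how the second and fourth derivatives of $C$ propagate under composition with another function $F$, so that $C^{l+1}(\theta) = F(C^l(\theta))$. We assume $F(0) = 0$. We can use the chain rule and the product rule to derive:
$$\ddot{C}^{l+1}(0) = \dot{F}(0)\, \ddot{C}^l(0), \qquad \ddddot{C}^{l+1}(0) = 3\, \ddot{F}(0)\, \ddot{C}^l(0)^2 + \dot{F}(0)\, \ddddot{C}^l(0).$$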
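The derivative rules for composition can be verified on polynomials, where composition is exact. The sketch below assumes the rules take the form $\ddot{C}^{l+1}(0) = \dot{F}(0)\, \ddot{C}^l(0)$ and $\ddddot{C}^{l+1}(0) = 3 \ddot{F}(0)\, \ddot{C}^l(0)^2 + \dot{F}(0)\, \ddddot{C}^l(0)$ for an even $C$ with $C(0) = 0$ and an $F$ with $F(0) = 0$. The helper names and the test coefficients are arbitrary illustrative choices.

```python
from math import factorial

import numpy as np
from numpy.polynomial import polynomial as npoly


def compose(F_coefs, C_coefs):
    """Coefficients (lowest order first) of the polynomial F(C(theta))."""
    result = np.zeros(1)
    power = np.ones(1)  # C(theta)^0
    for f in F_coefs:
        result = npoly.polyadd(result, f * power)
        power = npoly.polymul(power, C_coefs)
    return result


def derivative_at_zero(coefs, n):
    """n-th derivative at theta = 0, read off from the Taylor coefficient."""
    return factorial(n) * coefs[n] if n < len(coefs) else 0.0


# Even C with C(0) = 0: C(theta) = a theta^2 + b theta^4.
a, b = -0.5, 0.1
C = [0.0, 0.0, a, 0.0, b]
# F with F(0) = 0: F(x) = f1 x + f2 x^2 + f3 x^3.
f1, f2, f3 = 1.7, -0.8, 0.3
F = [0.0, f1, f2, f3]

composite = compose(F, C)
d2C, d4C = 2.0 * a, 24.0 * b   # second and fourth derivatives of C at 0
dF, d2F = f1, 2.0 * f2         # first and second derivatives of F at 0
```

Here `derivative_at_zero(composite, 2)` matches `dF * d2C`, and `derivative_at_zero(composite, 4)` matches `3 * d2F * d2C**2 + dF * d4C`, consistent with the assumed rules.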
B.3 Evolution equations for curvature and length
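We now apply the above iterations with $C(\theta) = 1 - c(\theta)$ and $F(C) = 1 - \frac{1}{q^*}\, \mathcal{C}(1 - C, q^*, q^* \mid \sigma_w, \sigma_b)$. Clearly, the symmetric $1 - c(\theta)$ obeys $C(0) = 0$, satisfying the above iterations of second and fourth derivatives. Taking into account these derivative recursions, using the expressions for $g^E$ and $\kappa$ in terms of derivatives of $c(\theta)$ at $\theta = 0$, and carefully accounting for factors of $q^*$ and $N$, we obtain the final evolution equations that have been successfully tested against experiments:
$$\bar g^{E,l} = \chi_1\, \bar g^{E,l-1}, \qquad (\bar\kappa^l)^2 = 3\,\frac{\chi_2}{\chi_1^2} + \frac{1}{\chi_1}\, (\bar\kappa^{l-1})^2,$$
where $\chi_1$ is the stretch factor defined in (19) and $\chi_2$ is defined analogously as
$$\chi_2 = \sigma_w^2 \int \mathcal{D}z \left[ \phi''\!\left(\sqrt{q^*}\, z\right) \right]^2.$$
$\chi_2$ is closely related to the second derivative of the correlation coefficient map in (16) at $c^{l-1}_{12} = 1$. Indeed this second derivative is $q^* \chi_2$.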
Appendix C Upper bounds on the complexity of shallow neural representations
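Consider a shallow network with $1$ hidden layer $x^1$ and one input layer $x^0$, so that $x^1 = \phi(W^1 x^0 + b^1)$. The network can compute functions through a linear readout of the hidden layer $x^1$. We are interested in how complex these neural representations can get, with one layer of synaptic weights and nonlinearities, as a function of the number of hidden units $N_1$. In particular, we are interested in how the length and curvature of an input manifold $x^0(\theta)$ changes as it propagates to become $x^1(\theta)$ in the hidden layer. We would like to upper bound the maximal achievable length and curvature over all possible choices of $W^1$ and $b^1$.
Here, we derive such an upper bound on the Euclidean length for a very general class of nonlinearities $\phi(h)$. We simply assume that (1) $\phi(h)$ is monotonically non-decreasing (so that $\phi'(h) \ge 0$) and (2) $\phi(h)$ has bounded dynamic range $R$, i.e. $\max_h \phi(h) - \min_h \phi(h) = R$. The Euclidean length in hidden space is
$$L^E = \int d\theta\, \sqrt{\sum_{i=1}^{N_1} \left( \frac{\partial x^1_i(\theta)}{\partial \theta} \right)^2} \le \sum_{i=1}^{N_1} \int d\theta\, \left| \frac{\partial x^1_i(\theta)}{\partial \theta} \right|,$$
where the inequality follows from the triangle inequality. Now suppose that for any $i$, $\partial_\theta x^1_i(\theta)$ never changes sign across $\theta$. Furthermore, assume that $\theta$ ranges from $0$ to $\Theta$. Then
$$\int_0^\Theta d\theta\, \left| \frac{\partial x^1_i(\theta)}{\partial \theta} \right| = \left| x^1_i(\Theta) - x^1_i(0) \right| \le R.$$
More generally, let $s_1$ denote the maximal number of times that any one neuron has a change in sign of the derivative $\partial_\theta x^1_i(\theta)$ across $\theta$. Then applying the above argument to each segment of constant sign yields
$$L^E \le N_1 (1 + s_1) R.$$
Now how many times can $\partial_\theta x^1_i(\theta)$ change sign? Since $\partial_\theta x^1_i(\theta) = \phi'(h_i)\, \partial_\theta h_i(\theta)$, where $h_i(\theta) = [W^1 x^0(\theta)]_i + b^1_i$, and $\phi$ is monotonically increasing, the number of times $\partial_\theta x^1_i(\theta)$ changes sign equals the number of times the input derivative $\partial_\theta h_i(\theta)$ changes sign. In turn, suppose $s_0$ is the maximal number of times any one dimensional projection of the derivative vector $\partial_\theta x^0(\theta)$ changes sign across $\theta$. Then the number of times the sign of $\partial_\theta h_i(\theta)$ changes for any $i$ cannot exceed $s_0$ because $\partial_\theta h_i(\theta)$ is a linear projection of $\partial_\theta x^0(\theta)$. Together this implies $L^E \le N_1 (1 + s_0) R$. We have thus proven the upper bound on $L^E$ stated in the main paper.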
Appendix D Simulation details
All neural network simulations were implemented in Keras and Theano. For all simulations (except Figure 5C), we used inputs and hidden layers with a width of 1,000 and tanh activations. We found that our results were mostly insensitive to width, but using larger widths decreased the fluctuations in the averaged quantities. Simulation error bars are all standard deviations, with the variance computed across the different inputs, . If not mentioned, the weights in the network are initialized in the chaotic regime with , .
Computing $\kappa(\theta)$ requires the computation of the velocity and acceleration vectors, corresponding to the first and second derivatives of the neural network with respect to $\theta$. As $\theta$ is always one-dimensional, we can greatly speed up these computations by using forward-mode auto-differentiation, evaluating the Jacobian and Hessian in a feedforward manner. We implemented this using the R-op in Theano.
To identify the curvature of the decision boundary, we first had to identify points that lay along the decision boundary. We randomly initialized data points and then optimized them toward the decision boundary with respect to the input using Adam. This yields a set of inputs where we compute the Jacobian and Hessian of $G$ to evaluate principal curvatures.
D.2 Details on Figure 5C-D: evaluating expressivity
To evaluate the set of functions reachable by a network, we first parameterized function space using a Fourier basis up to a particular maximum frequency, on a sampled set of one dimensional inputs parameterized by $\theta$. We then took the output activations of each neural network and linearly regressed the output activations onto each Fourier basis. For each basis, we computed the angle between the predicted basis vector and the true basis vector. These are the quantities that appear in Figure 5C-D. Given any function with bounded frequency, we can represent it in this Fourier basis, and decompose the error in the prediction of the function into the error in prediction of each Fourier component. Thus the error in predicting the Fourier basis is a reasonable proxy for the error in predicting functions with bounded frequency.
Appendix E Additional visualization of hidden actions
Appendix F A view from the function space perspective
We have shown above that for a fixed set of weights and biases in the chaotic regime, the internal representation at large depth , rapidly de-correlates from itself as the input changes (see e.g. Fig. 3B in the main paper). Here we ask a dual question: for a fixed input manifold, how does a deep network move in a function space over this manifold as the weights in a single layer change? Consider for example, a random one parameter family of deep networks parameterized by . In this family, we assume that the bias vectors in each layer are chosen as i.i.d. random Gaussian vectors with zero mean and variance , independent of . Moreover, we assume the weight matrix has elements that are drawn i.i.d. from zero mean Gaussians with variance , independent of for all layers except . The only dependence on in this family of networks originates in the weights in layer , chosen as
Here both a base matrix and a perturbation matrix have matrix elements that are zero mean i.i.d. Gaussians with variance . Each matrix element of thus also has variance just like all the other layers. In turn, this family of networks induces a family of functions . For simplicity, we restrict these functions to a simple input manifold, the circle,
and the associated correlation coefficient,
Because of our restriction to an input circle at the fixed point radius, for all and in the large width limit. By using logic similar to the derivation of (15), we can derive a recursion relation for the function space correlation :
where . The initial condition for this recursion is , since the family of functions in the first layer of inputs is independent of . Now, the difference in weights at a nonzero reduces the function space correlation to . At this point, the representation in is different for the two networks at parameter values and . Moreover, in the chaotic regime, this difference will amplify due to the similarity between the function space evolution equation in (37) and the evolution equation for the similarity of two points in (15). In essence, just as two points in the input exponentially separate as they propagate through a single network in the chaotic regime, a pair of different functions separate when computed in the final layer. Thus a small perturbation in the weights into layer can yield a very large change in the space of functions from the input manifold to layer . Moreover, as varies from -1 to 1, the function roughly undergoes a random walk in function space whose autocorrelation length decreases exponentially with depth . This weight chaos, or a sensitive dependence of the function computed by a deep network with respect to weight changes far from the final layer, is another manifestation of deep neural expressivity. Our companion paper further explores the expressivity of deep random networks in function space and also finds an exponential growth in expressivity with depth.