The Emergence of Spectral Universality in Deep Networks

Jeffrey Pennington, Samuel S. Schoenholz, Surya Ganguli

INTRODUCTION

A well-conditioned initialization is essential for successfully training neural networks. Seminal initial work focused on random weight initializations ensuring that the second moment of the spectrum of singular values of the network Jacobian from input to output remained one, thereby preventing exponential explosion or vanishing of gradients. However, recent work has shown that even among different random initializations sharing this property, those whose entire spectrum tightly concentrates around one can yield learning that is faster by orders of magnitude. For example, deep linear networks with orthogonal initializations, for which the entire spectrum is exactly one, can achieve depth-independent learning speeds, while the corresponding Gaussian initializations cannot.

Recently, it was shown that a similarly well-conditioned Jacobian could be constructed for deep nonlinear networks using a combination of orthogonal weights and $\tanh$ nonlinearities. This improved conditioning yielded an orders-of-magnitude speedup in learning for $\tanh$ networks. However, the same study also proved that a well-conditioned Jacobian cannot be achieved with rectified linear units (ReLUs). Together, these results explain why, historically, orthogonal weight initialization had in some cases been found to improve training efficiency only slightly.

These empirical results connecting the conditioning of the Jacobian to a dramatic speedup in learning raise an important theoretical question. Namely, how does the entire shape of this spectrum depend on a network’s nonlinearity, weight and bias distribution, and depth? Here we provide a detailed analytic answer by using powerful tools from free probability theory. Our answer provides theoretical guidance on how to choose these different network ingredients so as to achieve tight concentration of deep Jacobian spectra even at very large depths. Along the way, we find several surprises, and we summarize our results in the discussion.

PRELIMINARIES

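Consider a deep feedforward network of width $N$ and depth $L$, with weight matrices $\mathbf{W}^{l}$, bias vectors $\mathbf{b}^{l}$, and pointwise nonlinearity $\phi$, whose dynamics are given by

$$\mathbf{x}^{l}=\phi(\mathbf{h}^{l}),\qquad \mathbf{h}^{l}=\mathbf{W}^{l}\mathbf{x}^{l-1}+\mathbf{b}^{l},\qquad l=1,\dots,L. \tag{1}$$

The input-output Jacobian is then the product

$$\mathbf{J}=\frac{\partial\mathbf{x}^{L}}{\partial\mathbf{x}^{0}}=\prod_{l=1}^{L}\mathbf{D}^{l}\mathbf{W}^{l}. \tag{2}$$

Here $\mathbf{D}^{l}$ is a diagonal matrix with entries $D^{l}_{ij}=\phi'(h^{l}_{i})\,\delta_{ij}$, where $\delta_{ij}$ is the Kronecker delta. The input-output Jacobian $\mathbf{J}$ is closely related to the backpropagation operator mapping output errors to weight matrices at a given layer, in the sense that if the former is well-conditioned, then the latter tends to be well-conditioned for all weight layers. We are therefore interested in understanding the entire singular value spectrum of $\mathbf{J}$ for deep networks with randomly initialized weights and biases.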

In particular, we take the biases $b^{l}_{i}$ to be drawn i.i.d. from a zero-mean Gaussian with standard deviation $\sigma_b$. For the weights, we consider two random matrix ensembles: (i) random Gaussian weights, in which each $W^{l}_{ij}$ is drawn i.i.d. from a Gaussian with variance $\sigma_w^{2}/N$, and (ii) random orthogonal weights, drawn from the uniform distribution over scaled orthogonal matrices obeying $(\mathbf{W}^{l})^{T}\mathbf{W}^{l}=\sigma_w^{2}\,\mathbf{I}$.

2 Review of Signal Propagation

The random matrices $\mathbf{D}^{l}$ in (2) depend on the empirical distribution of pre-activations $h^{l}_{i}$ for $i=1,\dots,N$ entering the nonlinearity $\phi$ in (1). The propagation of this empirical distribution through different layers $l$ has been studied in prior work, where it was shown that in the large-$N$ limit this empirical distribution converges to a Gaussian with zero mean and variance $q^{l}$, where $q^{l}$ obeys a recursion relation induced by the dynamics in (1):

$$q^{l}=\sigma_w^{2}\int\mathcal{D}h\,\phi\big(\sqrt{q^{l-1}}\,h\big)^{2}+\sigma_b^{2}, \tag{3}$$

with initial condition $q^{1}=\frac{\sigma_w^{2}}{N}\sum_{i=1}^{N}(x^{0}_{i})^{2}+\sigma_b^{2}$, where $\mathcal{D}h=\frac{dh}{\sqrt{2\pi}}\exp(-\frac{h^{2}}{2})$ denotes the standard Gaussian measure. This recursion has a fixed point obeying

$$q^{*}=\sigma_w^{2}\int\mathcal{D}h\,\phi\big(\sqrt{q^{*}}\,h\big)^{2}+\sigma_b^{2}. \tag{4}$$

If the input $\mathbf{x}^{0}$ is chosen so that $q^{1}=q^{*}$, then the dynamics start at the fixed point and the distribution of $\mathbf{D}^{l}$ is independent of $l$. Moreover, even if $q^{1}\neq q^{*}$, a few layers are often sufficient to approximately converge to the fixed point. As such, when $L$ is large, it is often a good approximation to assume that $q^{l}=q^{*}$ for all depths $l$ when computing the spectrum of $\mathbf{J}$.
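The recursion (3) and its fixed point (4) are easy to explore numerically. Below is a minimal sketch, not the authors' code: it approximates the Gaussian measure $\mathcal{D}h$ with a trapezoid rule on the truncated interval $[-12, 12]$ (an assumption of the sketch), and the helper names `gauss_expect` and `fixed_point_q` are hypothetical. For linear and ReLU units the fixed point has the closed forms $q^*=\sigma_b^2/(1-\sigma_w^2)$ and $q^*=\sigma_b^2/(1-\sigma_w^2/2)$, which the example reproduces.

```python
import math


def gauss_expect(g, n=4001, cutoff=12.0):
    """Approximate the integral of g against Dh (h ~ N(0, 1)) with the trapezoid rule."""
    dx = 2 * cutoff / (n - 1)
    total = 0.0
    for i in range(n):
        h = -cutoff + i * dx
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * g(h) * math.exp(-0.5 * h * h)
    return total * dx / math.sqrt(2 * math.pi)


def fixed_point_q(phi, sigma_w, sigma_b, q1=1.0, tol=1e-12, max_layers=10_000):
    """Iterate the length map (3) from q^1 = q1 until it settles at q* of (4)."""
    q = q1
    for _ in range(max_layers):
        second_moment = gauss_expect(lambda h: phi(math.sqrt(q) * h) ** 2)
        q_next = sigma_w**2 * second_moment + sigma_b**2
        if abs(q_next - q) < tol * max(1.0, q):
            return q_next
        q = q_next
    raise RuntimeError("length map did not converge; the fixed point may be unstable")


# Closed-form checks: q* = sigma_b^2 / (1 - sigma_w^2) for linear units,
# and q* = sigma_b^2 / (1 - sigma_w^2 / 2) for ReLU; both equal 2 here.
q_linear = fixed_point_q(lambda x: x, math.sqrt(0.5), 1.0)
q_relu = fixed_point_q(lambda x: max(x, 0.0), 1.0, 1.0)
q_tanh = fixed_point_q(math.tanh, 1.5, 0.3)
print(q_linear, q_relu, q_tanh)
```

Iterating the map layer by layer mirrors how the pre-activation variance actually evolves with depth, which also makes it easy to see how many layers are needed before $q^l\approx q^*$.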

Another important quantity governing signal propagation through deep networks is

$$\chi=\sigma_w^{2}\int\mathcal{D}h\,\big[\phi'\big(\sqrt{q^{*}}\,h\big)\big]^{2}, \tag{5}$$

where $\phi'$ is the derivative of $\phi$. Here $\chi$ is the mean of the distribution of squared singular values of the matrix $\mathbf{D}\mathbf{W}$ when the pre-activations are at their fixed-point distribution with variance $q^{*}$. As shown in prior work, $\chi(\sigma_w,\sigma_b)$ separates the $(\sigma_w,\sigma_b)$ plane into two regions: (a) when $\chi>1$, forward signal propagation expands and folds space in a chaotic manner and back-propagated gradients explode exponentially; and (b) when $\chi<1$, forward signal propagation contracts space in an ordered manner and back-propagated gradients vanish exponentially. Thus the constraint $\chi(\sigma_w,\sigma_b)=1$ determines a critical line in the $(\sigma_w,\sigma_b)$ plane separating the ordered and chaotic regimes. Moreover, the mean of the distribution of squared singular values of $\mathbf{J}$ was shown to be simply $\chi^{L}$. Fig. 1 shows an example of an order-chaos transition for the tanh nonlinearity.
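To illustrate how the critical line $\chi=1$ can be located in practice, the following sketch evaluates (5) at the fixed point of (4) and bisects in $\sigma_w$ at fixed $\sigma_b$. It assumes, without proof, that $\chi$ increases monotonically with $\sigma_w$ on the bracket $[1, 3]$ (true for the tanh example used here); the quadrature settings and function names are illustrative choices, not taken from the paper.

```python
import math


def gauss_expect(g, n=401, cutoff=8.0):
    """Trapezoid-rule approximation of the integral of g against Dh."""
    dx = 2 * cutoff / (n - 1)
    total = 0.0
    for i in range(n):
        h = -cutoff + i * dx
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * g(h) * math.exp(-0.5 * h * h)
    return total * dx / math.sqrt(2 * math.pi)


def q_star(phi, sigma_w, sigma_b, tol=1e-10, max_layers=10_000):
    """Fixed point q* of (4), found by iterating the length map (3) from q = 1."""
    q = 1.0
    for _ in range(max_layers):
        q_next = sigma_w**2 * gauss_expect(lambda h: phi(math.sqrt(q) * h) ** 2) + sigma_b**2
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    return q


def chi(phi, dphi, sigma_w, sigma_b):
    """chi from (5): sigma_w^2 times the mean squared slope at the fixed point."""
    q = q_star(phi, sigma_w, sigma_b)
    return sigma_w**2 * gauss_expect(lambda h: dphi(math.sqrt(q) * h) ** 2)


def critical_sigma_w(phi, dphi, sigma_b, lo=1.0, hi=3.0, iters=25):
    """Bisect for chi(sigma_w, sigma_b) = 1, assuming chi < 1 at lo and chi > 1 at hi."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if chi(phi, dphi, mid, sigma_b) < 1.0:
            lo = mid  # still in the ordered phase
        else:
            hi = mid  # chaotic phase
    return 0.5 * (lo + hi)


def dtanh(x):
    return 1.0 - math.tanh(x) ** 2


sigma_w_crit = critical_sigma_w(math.tanh, dtanh, 0.3)
print(sigma_w_crit)
```

Repeating the bisection over a grid of $\sigma_b$ values would trace out a critical line of the kind shown in Fig. 1.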

3 Review of Free Probability

The previous section revealed that the mean squared singular value of $\mathbf{J}$ is $\chi^{L}$. Indeed, when $\chi\ll 1$ or $\chi\gg 1$, the vanishing or explosion of gradients, respectively, dominates the learning dynamics and provides a compelling case for choosing an initialization that is critical, with $\chi=1$. We would like to investigate whether all cases with $\chi=1$ are the same and, in particular, to obtain more detailed information about the entire singular value distribution of $\mathbf{J}$ when $\chi=1$. Since (2) consists of a product of random matrices, free probability becomes relevant as a powerful tool for computing the spectrum of $\mathbf{J}$, as we now review. Pedagogical introductions to free probability, as well as prior work applying it to deep learning, are available in the literature.

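In general, given a random matrix $\mathbf{X}$, its limiting spectral density is defined as

$$\rho_X(\lambda)\equiv\left\langle\frac{1}{N}\sum_{i=1}^{N}\delta(\lambda-\lambda_i)\right\rangle_X, \tag{6}$$

where $\lambda_i$ are the eigenvalues of $\mathbf{X}$ and $\langle\cdot\rangle_X$ denotes an average with respect to the distribution over the random matrix $\mathbf{X}$.

The Stieltjes transform of $\rho_X$ is defined as

$$G_X(z)\equiv\int_{\mathbb{R}}\frac{\rho_X(t)}{z-t}\,dt,\qquad z\in\mathbb{C}\setminus\mathbb{R}, \tag{7}$$

which can be inverted to recover the density through

$$\rho_X(\lambda)=-\frac{1}{\pi}\lim_{\epsilon\to 0^{+}}\operatorname{Im}G_X(\lambda+i\epsilon). \tag{8}$$

$G_X$ is related to the moment generating function $M_X$,

$$M_X(z)\equiv zG_X(z)-1=\sum_{k=1}^{\infty}\frac{m_k}{z^{k}}, \tag{9}$$

where $m_k$ is the $k$th moment of the distribution $\rho_X$,

$$m_k=\int d\lambda\,\rho_X(\lambda)\,\lambda^{k}=\frac{1}{N}\left\langle\operatorname{tr}\mathbf{X}^{k}\right\rangle_X. \tag{10}$$

In turn, we denote the functional inverse of $M_X$ by $M_X^{-1}$, which by definition satisfies $M_X(M_X^{-1}(z))=M_X^{-1}(M_X(z))=z$. Finally, the S-transform is defined as

$$S_X(z)=\frac{1+z}{z\,M_X^{-1}(z)}. \tag{11}$$

The utility of the S-transform arises from its behavior under multiplication. Specifically, if $\mathbf{A}$ and $\mathbf{B}$ are two freely independent random matrices, then the S-transform of the product random matrix ensemble $\mathbf{A}\mathbf{B}$ is simply the product of their S-transforms,

$$S_{AB}(z)=S_A(z)\,S_B(z). \tag{12}$$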
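As a small check of (12), one can work with only the leading Taylor coefficients of the S-transform. Expanding (9) and (11) for small $z$ gives $S_X(z)=m_1^{-1}\left[1+\left(1-m_2/m_1^{2}\right)z+O(z^{2})\right]$; this short expansion is our own and is stated here as the assumption behind the sketch. Multiplying these truncated series for $L$ freely independent Wishart factors (each with $m_1=1$, $m_2=2$) should reproduce the second Fuss-Catalan moment, $L+1$, of a product of $L$ Gaussian matrices. All function names are hypothetical.

```python
from math import comb


def s_coeffs(m1, m2):
    """(a0, a1) such that S_X(z) = a0 + a1 z + O(z^2), from the first two moments."""
    return 1.0 / m1, (1.0 - m2 / m1**2) / m1


def moments_from_s(a0, a1):
    """Inverse of s_coeffs: recover (m1, m2) from the two leading Taylor coefficients."""
    m1 = 1.0 / a0
    return m1, m1**2 * (1.0 - a1 / a0)


def product_moments(factors):
    """First two moments of a product of freely independent matrices, via S_AB = S_A S_B."""
    a0, a1 = 1.0, 0.0  # S-transform of the identity matrix
    for m1, m2 in factors:
        b0, b1 = s_coeffs(m1, m2)
        a0, a1 = a0 * b0, a0 * b1 + a1 * b0  # product of truncated power series
    return moments_from_s(a0, a1)


def fuss_catalan(L, k):
    """k-th moment of the free product of L Marchenko-Pastur (ratio one) laws."""
    return comb((L + 1) * k, k) / (L * k + 1)


for L in range(1, 6):
    m1, m2 = product_moments([(1.0, 2.0)] * L)
    print(L, m1, m2, fuss_catalan(L, 2))
```

Only two coefficients are tracked, so this checks the mean and second moment of the product; higher moments would need longer truncated series.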

MASTER EQUATION FOR SPECTRAL DENSITY

We can now write down an implicit expression for the spectral density of $\mathbf{J}\mathbf{J}^{T}$, which is also the distribution of the squared singular values of $\mathbf{J}$. In particular, in the supplementary material (SM) Sec. 1, we combine (12) with the facts that the S-transform depends only on the moments through (9), and that these moments, being normalized traces of matrix powers, are invariant under cyclic permutations, to derive a simple expression for the S-transform of $\mathbf{J}\mathbf{J}^{T}$,

$$S_{JJ^{T}}(z)=\prod_{l=1}^{L}S_{(D^{l})^{2}}(z)\,S_{(W^{l})^{T}W^{l}}(z)=S_{D^{2}}^{L}(z)\,S_{W^{T}W}^{L}(z). \tag{13}$$

Here the lack of dependence on the layer index $l$ on the right-hand side is valid if the input $\mathbf{x}^{0}$ is such that $q^{1}=q^{*}$.

Thus, given expressions for the S-transforms associated with the nonlinearity, $S_{D^2}$, and the weights, $S_{W^TW}$, one can compute the S-transform of the input-output Jacobian, $S_{JJ^T}$, at any network depth $L$ through (13). Then, starting from $S_{JJ^T}$, one can invert the sequence (7), (9), and (11) to obtain $\rho_{JJ^T}(\lambda)$.

2 An Efficient Master Equation

The previous section provides a naive method for computing the spectrum $\rho_{JJ^T}(\lambda)$ through a complex sequence of calculations. One must start from $\rho_{W^TW}(\lambda)$ and $\rho_{D^2}(\lambda)$, compute their respective Stieltjes transforms, moment generating functions, inverse moment generating functions, and S-transforms, take the product in (13), and then invert this sequence of steps to finally arrive at $\rho_{JJ^T}(\lambda)$. Here we provide a much simpler “master” equation for extracting information about $\rho_{JJ^T}(\lambda)$ and its moments directly from knowledge of the moment generating function of the nonlinearity, $M_{D^2}(z)$, and the S-transform of the weights, $S_{W^TW}(z)$. As we shall see, these two functions are the simplest to work with for arbitrary nonlinearities.

To derive the master equation, we insert (11), for $\mathbf{X}=\mathbf{D}^{2}$, into (13), and perform some algebraic manipulations (see SM Sec. 3 for details) to obtain implicit functional equations for $M_{JJ^T}(z)$ and $G(z)\equiv G_{JJ^T}(z)$,

$$M_{JJ^{T}}(z)=M_{D^{2}}\!\left(z^{1/L}f\big(M_{JJ^{T}}(z)\big)\right), \tag{14}$$

$$zG(z)-1=M_{D^{2}}\!\left(z^{1/L}f\big(zG(z)-1\big)\right), \tag{15}$$

where $f(x)\equiv S_{W^{T}W}(x)\left(\frac{1+x}{x}\right)^{1-1/L}$.

In principle, a solution to (15) allows us to compute the entire spectrum of $\mathbf{J}\mathbf{J}^{T}$. In practice, when an exact solution in terms of elementary functions is lacking, it is still possible to extract robust numerical solutions, as we describe in the next subsection.

3 Numerical Extraction of Spectra

Here we describe how to solve (15) numerically. The difficulty is that (15) implicitly defines $G(z)$ through an equation of the form $\mathcal{F}(G,z)=0$. For any given $z$, this equation may have multiple roots in $G$. The correct branch can be chosen by requiring that $G(z)\sim 1/z$ as $|z|\to\infty$. Therefore, one point on the correct branch can be found by taking $|z|$ large and finding the solution to $\mathcal{F}(G,z)=0$ that is closest to $G=1/z$. Recall that to obtain the density $\rho_{JJ^T}(\lambda)$ through the inversion formula (8), we need to extract the behavior of $G(z)$ near the real axis at a point $z=\lambda+i\epsilon$ where $\rho_{JJ^T}(\lambda)$ has support. So, practically speaking, for each $\lambda$ we can walk along the imaginary direction with $\operatorname{Re}(z)=\lambda$ from large imaginary values to small ones, repeatedly solving $\mathcal{F}(G,z)=0$ and always choosing the root that is closest to the previous root.

In the following sections, we demonstrate through many examples a precise numerical match between the outcome of this procedure (Algorithm 1) and direct simulations of various random neural networks, thereby justifying not only (15), but also the efficacy of our algorithm.
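The root-tracking procedure described above can be sketched in a few lines of Python. This is an illustrative implementation under stated assumptions, not the authors' Algorithm 1: it uses Newton's method with a finite-difference derivative, principal branches for the complex powers in $f$, and a geometric walk of $\operatorname{Im} z$ from 50 down to $10^{-6}$. To avoid evaluating $f$ at $M=0$, it seeds the first solve at $1/z+1/z^{2}$ (the large-$|z|$ expansion for a mean-one spectrum) rather than exactly at $1/z$. The check uses a depth-one linear Gaussian network, for which the answer is the Marchenko-Pastur law $\rho(\lambda)=\sqrt{(4-\lambda)/\lambda}/(2\pi)$; branch choices for $L>1$ have not been verified here, and all names are hypothetical.

```python
import math


def master_residual(G, z, L, M_D2, S_W):
    """F(G, z) for the master equation (15): zG - 1 - M_D2(z^(1/L) f(zG - 1))."""
    M = z * G - 1
    f = S_W(M) * ((1 + M) / M) ** (1 - 1 / L)  # f(x) = S_{W^TW}(x) ((1+x)/x)^(1-1/L)
    return M - M_D2(z ** (1 / L) * f)


def newton(F, G, tol=1e-13, max_iter=100):
    """Complex Newton iteration using a centred finite-difference derivative."""
    for _ in range(max_iter):
        h = 1e-7 * (abs(G) + 1e-7)
        step = F(G) / ((F(G + h) - F(G - h)) / (2 * h))
        G -= step
        if abs(step) <= tol * abs(G):
            break
    return G


def spectral_density(lam, L, M_D2, S_W, y_start=50.0, y_end=1e-6, n_steps=200):
    """rho_{JJ^T}(lam) via root tracking along Re z = lam, then the inversion formula (8)."""
    G = None
    for k in range(n_steps):
        y = y_start * (y_end / y_start) ** (k / (n_steps - 1))  # geometric walk down
        z = complex(lam, y)
        # First step: seed on the physical branch G ~ 1/z + 1/z^2 (mean-one spectrum).
        # Later steps: continue from the previous root.
        start = 1 / z + 1 / z**2 if G is None else G
        G = newton(lambda g: master_residual(g, z, L, M_D2, S_W), start)
    return -G.imag / math.pi


# Linear network (D = I, so M_D2(z) = 1/(z - 1)) with Gaussian weights at sigma_w = 1
# (S_W(z) = 1/(1 + z)); at depth L = 1 the result should be Marchenko-Pastur.
def m_identity(z):
    return 1 / (z - 1)


def s_wishart(z):
    return 1 / (1 + z)


rho_1 = spectral_density(1.0, 1, m_identity, s_wishart)
print(rho_1, math.sqrt(3) / (2 * math.pi))
```

Tracking the previous root, rather than re-solving from scratch at each $z$, is what keeps the iteration on the physical branch as $\operatorname{Im} z$ shrinks.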

4 Moments of Deep Spectra

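In addition to numerically extracting the spectrum of $\mathbf{J}\mathbf{J}^{T}$, we can also calculate its moments $m_k$, encoded in the function

$$M_{JJ^{T}}(z)=\sum_{k=1}^{\infty}\frac{m_k}{z^{k}}. \tag{16}$$

These moments in turn can be computed in terms of the series expansions of $S_{W^TW}$ and $M_{D^2}$, which we define as

$$S_{W^{T}W}(z)=\sigma_w^{-2}\left(1+\sum_{k=1}^{\infty}s_k z^{k}\right), \tag{17}$$

$$M_{D^{2}}(z)=\sum_{k=1}^{\infty}\frac{\mu_k}{z^{k}}, \tag{18}$$

where the moments $\mu_k$ of $\mathbf{D}^{2}$ are given by

$$\mu_k=\int\mathcal{D}h\,\phi'\big(\sqrt{q^{*}}\,h\big)^{2k}. \tag{19}$$

Substituting these expansions into (14), we obtain equations for the unknown moments $m_k$ in terms of the known moments $\mu_k$ and $s_k$. We can solve for the low-order moments by expanding (14) in powers of $z^{-1}$. By equating the coefficients of $z^{-1}$ and $z^{-2}$, we find equations for $m_1$ and $m_2$ whose solution yields (see SM Sec. 3)

$$m_1=(\sigma_w^{2}\mu_1)^{L},\qquad m_2=(\sigma_w^{2}\mu_1)^{2L}\left[1+L\left(\frac{\mu_2}{\mu_1^{2}}-1-s_1\right)\right]. \tag{21}$$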

Note that the combination $\sigma_w^{2}\mu_1$ is none other than $\chi$ defined in (5), and so (21) recovers the result that the mean squared singular value $m_1$ of $\mathbf{J}$ either explodes or vanishes exponentially unless $\chi(\sigma_w,\sigma_b)=1$, i.e., unless the network lies on the critical boundary between order and chaos. However, even on this critical boundary, where the mean $m_1$ of the spectrum of $\mathbf{J}\mathbf{J}^{T}$ is one for any depth $L$, the variance

$$\sigma_{JJ^{T}}^{2}=m_2-m_1^{2}=L\left(\frac{\mu_2}{\mu_1^{2}}-1-s_1\right) \tag{22}$$

grows linearly with depth $L$ for generic values of $\mu_1$, $\mu_2$, and $s_1$. Thus $\mathbf{J}$ can be highly ill-conditioned at large depths $L$ for generic choices of nonlinearities and weights, even when $\sigma_w$ and $\sigma_b$ are tuned to criticality.
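Formulas (21) and (22) only need $\mu_1$, $\mu_2$, and $s_1$, so the depth scaling of the variance is easy to tabulate. The sketch below (function name hypothetical) evaluates them for linear and ReLU networks at criticality, using $\mu_1=\mu_2=1$ for linear units, $\mu_1=\mu_2=1/2$ for ReLU (its squared slope is 0 or 1 with equal probability), and $s_1=0$ or $-1$ for orthogonal or Gaussian weights.

```python
def jacobian_moments(sigma_w2, mu1, mu2, s1, L):
    """Return (m1, m2, variance) of the spectrum of JJ^T from (21) and (22)."""
    chi = sigma_w2 * mu1
    m1 = chi**L
    m2 = chi ** (2 * L) * (1 + L * (mu2 / mu1**2 - 1 - s1))
    return m1, m2, m2 - m1**2


L = 10
cases = {
    "linear, orthogonal": (1.0, 1.0, 1.0, 0.0),  # (sigma_w^2, mu1, mu2, s1)
    "linear, Gaussian": (1.0, 1.0, 1.0, -1.0),
    "ReLU, orthogonal": (2.0, 0.5, 0.5, 0.0),
    "ReLU, Gaussian": (2.0, 0.5, 0.5, -1.0),
}
for name, (sw2, mu1, mu2, s1) in cases.items():
    print(name, jacobian_moments(sw2, mu1, mu2, s1, L))
```

The printed variances are $0$, $L$, $L$, and $2L$, matching the linear and ReLU results discussed below; moving off criticality ($\chi\neq 1$) makes $m_1=\chi^L$ explode or vanish.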

SPECIAL CASES OF DEEP SPECTRA

Exploiting the master equation (14) requires information about $M_{D^2}(z)$ and $S_{W^TW}(z)$. We first provide this information and then use it to examine special cases of deep networks.

First, for any nonlinearity $\phi(h)$, we have, through (7) and (9),

$$M_{D^{2}}(z)=\int\mathcal{D}h\,\frac{\phi'\big(\sqrt{q^{*}}\,h\big)^{2}}{z-\phi'\big(\sqrt{q^{*}}\,h\big)^{2}}. \tag{23}$$

The integral over the Gaussian measure $\mathcal{D}h$ reflects a sum over all the pre-activations $h^{l}_{i}$ in a layer $l$, since in the large-$N$ limit the empirical distribution of pre-activations converges to a Gaussian with standard deviation $\sqrt{q^{*}}$. Moreover, a pre-activation $h^{l}_{i}$ feels a squared slope $\phi'(h^{l}_{i})^{2}$, which appears as an eigenvalue of the diagonal matrix $(\mathbf{D}^{l})^{2}$. Thus $M_{D^2}(z)$ naturally involves an integral of a function of $\phi'(\cdot)^{2}$ against a Gaussian.

Table 1 provides the moment generating function and moments of $\mathbf{D}^{2}$ for several nonlinearities. Detailed derivations of the results in Table 1, which follow from performing the integral in (23), can be found in SM Sec. 3. In the Erf case, $\Phi$ denotes the special function known as the Lerch transcendent, which arises when the series over the moments $\mu_k$ is summed.
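The moments $\mu_k$ that enter (23) can be checked by direct quadrature. In the sketch below, the derivative functions and `mu_k` are hypothetical helpers, and the reference values are our own elementary Gaussian integrals rather than entries copied from Table 1: $\mu_k=1/2$ for ReLU, $\mu_k=\operatorname{erf}(1/\sqrt{2q^*})$ for hard tanh, and $\mu_k=(1+k\pi q^*)^{-1/2}$ for $\phi(x)=\operatorname{erf}(\frac{\sqrt{\pi}}{2}x)$. A midpoint rule is used because two of the derivatives are discontinuous.

```python
import math


def mu_k(dphi, q, k, n=50_000, cutoff=10.0):
    """mu_k = integral of phi'(sqrt(q) h)^(2k) against Dh, by the midpoint rule.

    For the discontinuous slopes of ReLU and hard tanh, the error is O(1/n)."""
    dx = 2 * cutoff / n
    total = 0.0
    for i in range(n):
        h = -cutoff + (i + 0.5) * dx
        total += dphi(math.sqrt(q) * h) ** (2 * k) * math.exp(-0.5 * h * h)
    return total * dx / math.sqrt(2 * math.pi)


def d_relu(x):
    return 1.0 if x > 0 else 0.0


def d_htanh(x):
    return 1.0 if abs(x) < 1 else 0.0


def d_erf(x):
    # phi(x) = erf(sqrt(pi) x / 2) has phi'(x) = exp(-pi x^2 / 4), so phi'(0) = 1
    return math.exp(-math.pi * x * x / 4)


q = 0.5
closed_forms = {
    "ReLU": (d_relu, lambda k: 0.5),
    "hard tanh": (d_htanh, lambda k: math.erf(1 / math.sqrt(2 * q))),
    "erf": (d_erf, lambda k: (1 + k * math.pi * q) ** -0.5),
}
results = {}
for name, (dphi, exact) in closed_forms.items():
    for k in (1, 2):
        results[name, k] = (mu_k(dphi, q, k), exact(k))
        print(name, k, *results[name, k])
```

For ReLU and hard tanh the squared slope is 0 or 1, so every $\mu_k$ coincides; only the smooth Erf nonlinearity has moments that decay with $k$.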

2 Transforms of Weights

The S-transforms of the weights can be obtained through the sequence of equations (7), (9), and (11), starting with $\rho_{W^TW}(\lambda)=\delta(\lambda-1)$ for an orthogonal random matrix $\mathbf{W}$, and $\rho_{W^TW}(\lambda)=(2\pi)^{-1}\sqrt{(4-\lambda)/\lambda}$ for $\lambda\in[0,4]$ for a Gaussian random matrix $\mathbf{W}$ with variance $\frac{1}{N}$ (see SM Sec. 5). Furthermore, under the scaling $\mathbf{W}\rightarrow\sigma_w\mathbf{W}$, the S-transform scales as $S_{W^TW}\rightarrow\sigma_w^{-2}S_{W^TW}$, yielding the S-transforms and first moments in Table 2.
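For the Gaussian ensemble, the first two moments of $\rho_{W^TW}$ already fix $s_1$. The following sketch (names hypothetical) integrates the Marchenko-Pastur density numerically after the substitution $\lambda=4\sin^{2}t$, and then applies the small-$z$ relation $s_1=1-m_2/m_1^{2}$ (our own expansion of (9) and (11), valid for $\sigma_w=1$), recovering $s_1=-1$.

```python
import math


def mp_moment(k, n=2000):
    """k-th moment of rho(lam) = sqrt((4 - lam)/lam) / (2 pi) on [0, 4].

    Substituting lam = 4 sin^2(t) turns rho(lam) d(lam) into (4/pi) cos^2(t) dt on
    [0, pi/2], which removes the 1/sqrt(lam) singularity at the origin."""
    dt = (math.pi / 2) / n
    total = 0.0
    for i in range(n):
        t = (i + 0.5) * dt
        total += (4 * math.sin(t) ** 2) ** k * math.cos(t) ** 2
    return total * dt * 4 / math.pi


m0, m1, m2 = (mp_moment(k) for k in range(3))
s1 = 1 - m2 / m1**2  # from S(z) = (1/m1)[1 + (1 - m2/m1^2) z + O(z^2)]
print(m0, m1, m2, s1)  # approximately 1, 1, 2, -1
```

For orthogonal weights the same relation gives $s_1=0$, since $m_2=m_1^2$ for a delta-function spectrum.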

3 Exact Properties of Deep Spectra

Now, for different randomly initialized deep networks, we insert the appropriate expressions in Tables 1 and 2 into our master equations (14) and (15) to obtain information about the spectrum of $\mathbf{J}\mathbf{J}^{T}$, including its entire shape, through Algorithm 1, and its variance $\sigma_{JJ^T}^{2}$ through (21) and (22). We always work at criticality, so that in (5), $\chi=\sigma_w^{2}\mu_1=1$. The resulting condition for $\sigma_w^{2}$ at criticality and the value of $\sigma_{JJ^T}^{2}$ are shown in Table 1 for different nonlinearities, both for orthogonal ($s_1=0$) and Gaussian ($s_1=-1$) weights.

3.1 Linear Networks

For linear networks, the fixed point equation (4) reduces to $q^{*}=\sigma_w^{2}q^{*}+\sigma_b^{2}$, and $(\sigma_w,\sigma_b)=(1,0)$ is the only critical point. Moreover, linear Gaussian networks behave very differently from orthogonal ones. The latter are well conditioned, with $\sigma^{2}_{JJ^T}=0$, because the product of orthogonal matrices is orthogonal and so $\rho_{JJ^T}(\lambda)=\delta(\lambda-1)$ for all $L$. However, $\sigma^{2}_{JJ^T}=L$ for Gaussian weights. This radically different behavior of the spectrum of $\mathbf{J}\mathbf{J}^{T}$ is shown in Fig. 2A.

3.2 ReLU Networks

For ReLU networks, the fixed point equation (4) reduces to $q^{*}=\frac{1}{2}\sigma_w^{2}q^{*}+\sigma_b^{2}$, and $(\sigma_w,\sigma_b)=(\sqrt{2},0)$ is the only critical point. Unlike in the linear case, $\sigma_{JJ^T}^{2}$ grows as $L$ for orthogonal weights and as $2L$ for Gaussian weights. In essence, the ReLU nonlinearity destroys the qualitative scaling advantage that orthogonal weights confer over Gaussian weights in linear networks. The qualitative similarity between the spectra of orthogonal ReLU networks and Gaussian linear networks is shown in Fig. 2AB.

3.3 Hard Tanh and Erf Networks

For Hard Tanh and Erf networks, the criticality condition $\sigma_w^{2}=\mu_1^{-1}$ does not determine a unique value of $\sigma_w^{2}$, because $\mu_1$, the mean squared slope $\phi'(h)^{2}$, now depends on the variance $q^{*}$ of the distribution of pre-activations $h$. Since $q^{*}$ is itself a function of $\sigma_w$ and $\sigma_b$ through (4), these networks enjoy an entire critical curve in the $(\sigma_w,\sigma_b)$ plane, similar to that shown in Fig. 1. As $q^{*}$ decreases monotonically towards zero, the corresponding point on this curve approaches $(\sigma_w,\sigma_b)=(1,0)$.

Moreover, Table 1 shows that $\sigma_{JJ^T}^{2}=L(\mathcal{F}(q^{*})-1-s_1)$ with $\lim_{q^{*}\rightarrow 0}\mathcal{F}(q^{*})=1$. This implies that for Gaussian weights ($s_1=-1$), no matter how small one makes $\sigma_w$, $\sigma^{2}_{JJ^T}\propto L$. However, for orthogonal weights ($s_1=0$), for any fixed $L$, one can reduce $\sigma_w$, and therefore $q^{*}$, so as to make $\sigma_{JJ^T}^{2}$ arbitrarily small. Thus Hard Tanh and Erf nonlinearities rescue the scaling advantage that orthogonal weights possess over Gaussian ones, which was present in linear networks but destroyed in ReLU networks. Examples of the well-conditioned nature of orthogonal Hard Tanh and Erf networks, compared to orthogonal ReLU networks, are shown in Fig. 2.
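The critical curve of a hard tanh network can be traced in closed form. The sketch below is our own derivation, not the paper's code: it parametrizes the curve by $q^*$, uses $\mu_1=\mu_2=\operatorname{erf}(1/\sqrt{2q^*})$ (the squared slope is 1 inside the linear region and 0 outside), sets $\sigma_w^{2}=1/\mu_1$, and solves (4) for $\sigma_b^{2}$. It also evaluates $\sigma_{JJ^T}^{2}=L(\mathcal{F}-1-s_1)$ with $\mathcal{F}=\mu_2/\mu_1^{2}=1/\mu_1$. Function names are hypothetical.

```python
import math


def htanh_critical_point(q):
    """Critical (sigma_w, sigma_b) of a hard tanh network whose fixed point is q* = q.

    Uses chi = sigma_w^2 mu_1 = 1 with mu_1 = P(|sqrt(q) h| < 1), together with the
    fixed point (4), where E[htanh(sqrt(q) h)^2] is evaluated in closed form."""
    c = 1.0 / math.sqrt(q)                               # clip point in units of sqrt(q)
    mu1 = math.erf(c / math.sqrt(2))                     # mean squared slope
    pdf = math.exp(-c * c / 2) / math.sqrt(2 * math.pi)  # standard normal density at c
    tail = math.erfc(c / math.sqrt(2))                   # P(|h| > c)
    sigma_w2 = 1.0 / mu1
    # q - sigma_w^2 * [q (mu1 - 2 c pdf) + tail], simplified to avoid cancellation:
    sigma_b2 = (2 * pdf / c - tail) / mu1
    return math.sqrt(sigma_w2), math.sqrt(max(sigma_b2, 0.0))


def htanh_variance(q, L, s1):
    """sigma^2_{JJ^T} = L (F - 1 - s1), with F = mu_2 / mu_1^2 = 1 / mu_1 for hard tanh."""
    F = 1.0 / math.erf(1.0 / math.sqrt(2 * q))
    return L * (F - 1 - s1)


for q in (1.0, 0.3, 0.1, 0.03):
    sw, sb = htanh_critical_point(q)
    print(f"q*={q:5.2f}  sigma_w={sw:.4f}  sigma_b={sb:.4f}  "
          f"var_orth(L=100)={htanh_variance(q, 100, 0.0):.3g}  "
          f"var_gauss(L=100)={htanh_variance(q, 100, -1.0):.3g}")
```

As $q^*$ shrinks, the printed points approach $(\sigma_w,\sigma_b)=(1,0)$; the orthogonal variance collapses, while the Gaussian one stays at least $L$.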

UNIVERSALITY IN DEEP SPECTRA

Table 1 shows that for orthogonal Erf and Hard Tanh networks (but not ReLU networks), since $\sigma_{JJ^T}^{2}=L(\mathcal{F}(q^{*})-1)$ with $\lim_{q^{*}\rightarrow 0}\mathcal{F}(q^{*})=1$, one can always choose $q^{*}$ to decrease with $L$ so as to achieve a desired $L$-independent constant variance $\sigma^{2}_{JJ^T}\equiv\sigma^{2}_{0}$. To achieve this scaling, $q^{*}(L)$ should satisfy the equation $\mathcal{F}(q^{*}(L))=1+\frac{\sigma^{2}_{0}}{L}$, which implies $\sigma_w\to 1$ and $q^{*}\rightarrow 0$ as $L\to\infty$.
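For hard tanh, the double-scaling condition can be solved numerically. In this sketch (names hypothetical), $\mathcal{F}(q^*)=\mu_2/\mu_1^2=1/\operatorname{erf}(1/\sqrt{2q^*})$, because the squared slope of hard tanh is Bernoulli, and bisection in $\log q^*$ finds the $q^*(L)$ that pins $L(\mathcal{F}-1)$ at $\sigma_0^{2}$. Because $\mu_2=\mu_1$ for hard tanh, criticality gives $\sigma_w^{2}=1/\mu_1=\mathcal{F}=1+\sigma_0^{2}/L$, so $\sigma_w\to 1$ as $L$ grows.

```python
import math


def F_htanh(q):
    """F(q*) = mu_2 / mu_1^2 for hard tanh, which equals 1 / erf(1 / sqrt(2 q*))."""
    return 1.0 / math.erf(1.0 / math.sqrt(2 * q))


def q_star_for_depth(L, sigma0_sq, lo=1e-6, hi=10.0, iters=100):
    """Solve F(q*) = 1 + sigma0^2 / L by bisection in log q (F increases with q)."""
    target = 1.0 + sigma0_sq / L
    for _ in range(iters):
        mid = math.sqrt(lo * hi)
        if F_htanh(mid) < target:
            lo = mid
        else:
            hi = mid
    return math.sqrt(lo * hi)


for L in (10, 100, 1000, 10000):
    q = q_star_for_depth(L, 0.25)
    sigma_w = math.sqrt(F_htanh(q))  # criticality: sigma_w^2 = 1 / mu_1 = F
    print(L, q, sigma_w, L * (F_htanh(q) - 1))
```

Bisecting in $\log q^*$ rather than $q^*$ keeps the search well behaved as $q^*(L)$ shrinks toward zero with depth.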

Remarkably, in this double scaling limit, not only does the variance of the spectrum of $\mathbf{J}\mathbf{J}^{T}$ remain constant at the fixed value $\sigma_0^{2}$, but the entire shape of the distribution converges to a universal limiting distribution as $L\rightarrow\infty$. There is more than one possible limiting distribution, but its form depends on $\phi$ only through the distribution of $\phi'(h)^{2}$ as $q^{*}\to 0$, via the expression for $M_{D^2}(z)$ in (23). Therefore, many qualitatively different activation functions may in fact be members of the same universality class. We identify two universality classes that correspond to many common activation functions: the Bernoulli universality class and the smooth universality class, named after the distribution of $\phi'(h)^{2}$ as $q^{*}\to 0$.

The Bernoulli universality class contains many piecewise linear activation functions, such as Hard Tanh (Fig. 3C) and a version of ReLU shifted so as to be linear at the origin, which for concreteness we define as $\phi(x)=[x+\frac{1}{2}]_{+}-\frac{1}{2}$ (Fig. 3E). While these functions look quite different, their derivatives are both Bernoulli-distributed (Fig. 3DF), and the limiting spectra of their corresponding Jacobians are the same (Fig. 4AB).

The smooth universality class contains many smooth activation functions, such as Erf (Fig. 3G) and a smoothed version of ReLU that we take to be the sigmoid-weighted linear unit (SiLU) (Fig. 3I). In this case, not only do the activation functions themselves look different, but so do their derivatives (Fig. 3HJ). Nevertheless, in the double scaling limit, the limiting spectra of their corresponding Jacobians are the same (Fig. 4CD). The rate of convergence to the limiting distribution differs, however, because the moments $\mu_k$ differ substantially for nonzero $q^{*}$.

Unlike the smoothed and shifted versions of ReLU, the vanilla ReLU activation (Fig. 3AB) behaves entirely differently and has no limiting distribution, because its moments $\mu_k$ are independent of $q^{*}$, and it is therefore impossible to attain an $L$-independent constant variance $\sigma^{2}_{JJ^T}\equiv\sigma^{2}_{0}$ in this case.

To understand the mechanism behind the emergence of spectral universality, we now examine orthogonal networks whose activation functions have squared derivatives obeying a Bernoulli distribution, and show that they all share a universal limiting distribution as $L\to\infty$. To this end, we suppose that

$$\rho_{D^{2}}(\lambda)=p(q^{*})\,\delta(\lambda-1)+\big(1-p(q^{*})\big)\,\delta(\lambda),\qquad\text{so that}\qquad \mu_k=p(q^{*})\ \text{ for all } k\ge 1, \tag{24}$$

for some function $p(q^{*})$ that measures the probability of the nonlinearity having slope one as a function of $q^{*}$. We assume that $p(q^{*})\to 1$ as $q^{*}\to 0$. The relevant ratio of moments and the weight variance $\sigma_w^{2}$ are then

$$\frac{\mu_2}{\mu_1^{2}}=\frac{1}{p(q^{*})},\qquad \sigma_w^{2}=\frac{1}{\mu_1}=\frac{1}{p(q^{*})}. \tag{25}$$

Notice that a solution $q^{*}(L)$ to (22) with $\sigma^{2}_{JJ^T}=\sigma_0^{2}$ will exist for large $L$, since we are assuming $p(q^{*})\to 1$ as $q^{*}\to 0$. Substituting this solution into (24) and (25) gives, for large $L$,

$$p\big(q^{*}(L)\big)=\frac{1}{1+\sigma_0^{2}/L}, \tag{26}$$

$$\sigma_w^{2}=1+\frac{\sigma_0^{2}}{L}. \tag{27}$$

Using these expressions and (11), we find that the S-transform obeys

$$S_{JJ^{T}}(z)=\left[\frac{p\,(1+z)}{z+p}\right]^{L}, \tag{28}$$

$$\lim_{L\to\infty}S_{JJ^{T}}(z)=\exp\!\left(-\frac{\sigma_0^{2}\,z}{1+z}\right). \tag{29}$$

Using (9) and (11) to solve for $G(z)$ gives

$$G(z)=\frac{\sigma_0^{2}}{z\left[\sigma_0^{2}+W\!\left(-\sigma_0^{2}/z\right)\right]}, \tag{30}$$

where $W$ denotes the principal branch of the Lambert-W function, which solves the transcendental equation

$$W(x)\,e^{W(x)}=x. \tag{31}$$

The spectral density can be extracted from (30) using (8). The results are shown as black lines in Fig. 4AB. Both Hard Tanh and shifted ReLU have Bernoulli-distributed $\phi'(h)^{2}$ and, despite being qualitatively different activation functions, have the same limiting spectral distribution. It is evident that the empirical spectral densities converge to this universal limiting distribution as the depth increases.
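The limiting Stieltjes transform of the Bernoulli class can be checked numerically. The sketch below assumes that (30) reads $G(z)=\sigma_0^{2}/\big(z[\sigma_0^{2}+W(-\sigma_0^{2}/z)]\big)$ and that the limiting S-transform is $S(z)=\exp(-\sigma_0^{2}z/(1+z))$. It evaluates $G$ on the real axis to the right of all spectral mass ($z>e^{\sigma_0^{2}}$), where the principal real branch of Lambert W applies, and verifies that $M^{-1}(zG-1)=z$ with $M^{-1}$ obtained from (11). The Newton-based `lambert_w0` is a hypothetical helper; SciPy's `scipy.special.lambertw` could be used instead if it is available.

```python
import math


def lambert_w0(x, tol=1e-15, max_iter=100):
    """Principal branch W0 of w * exp(w) = x on the real axis (x >= -1/e), by Newton."""
    if x < -1 / math.e:
        raise ValueError("real W0 requires x >= -1/e")
    if x == -1 / math.e:
        return -1.0
    # log1p(x) lies on or above W0(x), so Newton on the convex map w*e^w descends monotonically.
    w = math.log1p(x)
    for _ in range(max_iter):
        ew = math.exp(w)
        step = (w * ew - x) / (ew * (w + 1))
        w -= step
        if abs(step) < tol * (1 + abs(w)):
            break
    return w


def G_bernoulli(z, sigma0_sq):
    """Limiting Stieltjes transform sigma0^2 / (z [sigma0^2 + W(-sigma0^2 / z)])."""
    return sigma0_sq / (z * (sigma0_sq + lambert_w0(-sigma0_sq / z)))


def M_inverse(m, sigma0_sq):
    """M^{-1}(m) = (1 + m) / (m S(m)) with the limiting S(m) = exp(-sigma0^2 m / (1 + m))."""
    return (1 + m) / m * math.exp(sigma0_sq * m / (1 + m))


sigma0_sq = 0.25
for z in (2.0, 5.0, 10.0, 100.0):
    G = G_bernoulli(z, sigma0_sq)
    print(z, G, M_inverse(z * G - 1, sigma0_sq))  # last column should reproduce z
```

Evaluating the density itself requires the complex branch of $W$ just above the real axis; the real-axis check above only confirms that the closed form is consistent with the S-transform.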

Next we build some additional understanding of the spectral density implied by (30). Because the spectral density is proportional to the imaginary part of $G(z)$, we expect the locations of the spectral edges to be related to branch points of $G(z)$, or more generally to poles in its derivative. Using the relation

$$W'(x)=\frac{W(x)}{x\,\big(1+W(x)\big)}, \tag{32}$$

we can inspect the derivative of $G(z)$. Writing $W\equiv W(-\sigma_0^{2}/z)$, it may be expressed as

$$G'(z)=-\frac{\sigma_0^{2}\left[\sigma_0^{2}(1+W)+W^{2}\right]}{z^{2}\,(1+W)\,(\sigma_0^{2}+W)^{2}}. \tag{33}$$

By inspection, we find that $G'(z)$ has double poles at

$$\lambda_0=0,\qquad \lambda_2=e^{\sigma_0^{2}}, \tag{34}$$

which are locations where the spectral density diverges, i.e., there are delta-function peaks at $\lambda_0$ and $\lambda_2$. Note that there is only a pole at $\lambda_2$ if $\sigma_0\leq 1$, since the principal branch satisfies $W\geq -1$ on the real axis. There is also a single pole at

$$\lambda_1=e\,\sigma_0^{2}, \tag{35}$$

where $W(-\sigma_0^{2}/z)=-1$, which defines the right spectral edge, i.e., the maximum value of the bulk of the density.

The above observations regarding $\lambda_0$, $\lambda_1$, and $\lambda_2$ are evident in Fig. 4AB. Noting that $\sigma_0=1/2$ in the figure, we predict the bulk of the density to have its right edge located at $s=\sqrt{\lambda_1}=\sqrt{e}/2\approx 0.82$, and a delta-function peak at $s=\sqrt{\lambda_2}=e^{1/8}\approx 1.13$, both of which are reflected in the figure.
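The predicted singular-value locations in Fig. 4AB follow from simple arithmetic. The sketch below assumes $\lambda_0=0$, $\lambda_1=e\,\sigma_0^{2}$ (where the argument of $W$ reaches the branch point $-1/e$), and $\lambda_2=e^{\sigma_0^{2}}$ (where $W(-\sigma_0^{2}/z)=-\sigma_0^{2}$, attainable on the principal branch only when $\sigma_0\le 1$); these expressions reproduce the values quoted above for $\sigma_0=1/2$. The function name is hypothetical.

```python
import math


def bernoulli_spectral_features(sigma0):
    """Eigenvalue locations (lambda0, lambda1, lambda2) for the Bernoulli limit."""
    s2 = sigma0**2
    lam0 = 0.0                                # delta peak at the origin
    lam1 = math.e * s2                        # -s2 / z = -1/e: branch point, right bulk edge
    lam2 = math.exp(s2) if s2 <= 1 else None  # W(-s2 / z) = -s2 needs s2 <= 1
    return lam0, lam1, lam2


lam0, lam1, lam2 = bernoulli_spectral_features(0.5)
print(math.sqrt(lam1), math.sqrt(lam2))  # approximately 0.8244 and 1.1331
```

Since $e^{x}\ge e\,x$ for all $x$, the isolated peak at $\lambda_2$ always lies at or beyond the right edge $\lambda_1$ of the bulk, touching it only at $\sigma_0=1$.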

A similar analysis can be carried out for activation functions for which the distribution of $\phi'(h)^{2}$ is smooth and concentrates around one as $q^{*}\to 0$. The analysis for Erf is presented in the SM. We find that

$$\lim_{L\to\infty}S_{JJ^{T}}(z)=e^{-\sigma_0^{2}z}, \tag{36}$$

and that $G(z)$ can be expressed in terms of a generalized Lambert-W function. The locations of the spectral edges are given by $s_{\pm}=e^{-\frac{1}{4}\sigma_{\pm}^{2}}\sqrt{1+\frac{1}{2}\sigma_{\mp}^{2}}$, where

$$\sigma_{\pm}^{2}=\sigma_0^{2}\mp\sigma_0\sqrt{\sigma_0^{2}+4}. \tag{37}$$

For $\sigma_0=1/2$, these results give $s_{-}\approx 0.57$ and $s_{+}\approx 1.56$, which is in excellent agreement with the behavior observed in Fig. 4CD. Overall, Fig. 4 provides strong evidence supporting our predictions that orthogonal Hard Tanh and shifted ReLU networks have the Bernoulli limit distribution, while orthogonal Erf and smoothed ReLU networks have the smooth limit distribution.
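The smooth-class edges can be cross-checked in two ways. The sketch below assumes the limiting S-transform $S(z)=e^{-\sigma_0^{2}z}$, so that $M^{-1}(z)=\frac{1+z}{z}e^{\sigma_0^{2}z}$ by (11), and takes the spectral edges to sit at the real critical points of $M^{-1}$, which solve $z^{2}+z-1/\sigma_0^{2}=0$. It then compares the resulting singular values with the closed form $s_{\pm}=e^{-\sigma_{\pm}^{2}/4}\sqrt{1+\sigma_{\mp}^{2}/2}$, taking $\sigma_{\pm}^{2}=\sigma_0^{2}\mp\sigma_0\sqrt{\sigma_0^{2}+4}$. Function names are hypothetical.

```python
import math


def smooth_class_edges(sigma0):
    """(s_minus, s_plus) from the real critical points of M^{-1}(z) = (1+z)/z exp(sigma0^2 z)."""
    s2 = sigma0**2
    root = math.sqrt(1 + 4 / s2)
    edges = []
    for z in ((-1 - root) / 2, (-1 + root) / 2):  # roots of z^2 + z - 1/sigma0^2 = 0
        lam = (1 + z) / z * math.exp(s2 * z)      # eigenvalue edge lambda = M^{-1}(z)
        edges.append(math.sqrt(lam))              # convert to a singular value
    return tuple(edges)


def smooth_class_edges_closed_form(sigma0):
    """s_pm = exp(-sigma_pm^2 / 4) sqrt(1 + sigma_mp^2 / 2)."""
    s2 = sigma0**2
    sp2 = s2 - sigma0 * math.sqrt(s2 + 4)  # sigma_+^2 (negative)
    sm2 = s2 + sigma0 * math.sqrt(s2 + 4)  # sigma_-^2
    s_minus = math.exp(-sm2 / 4) * math.sqrt(1 + sp2 / 2)
    s_plus = math.exp(-sp2 / 4) * math.sqrt(1 + sm2 / 2)
    return s_minus, s_plus


print(smooth_class_edges(0.5), smooth_class_edges_closed_form(0.5))
```

Both routes give about 0.567 and 1.557 for $\sigma_0=1/2$, consistent with the values quoted above.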

Finally, we derived these universal limits assuming orthogonal weights. In the SM we show that orthogonality is in fact necessary for the existence of a stable limiting distribution for the spectrum of $\mathbf{J}\mathbf{J}^{T}$: no other random matrix ensemble can yield a stable distribution for any choice of nonlinearity with $\phi'(0)=1$. Essentially, any spread in the singular values of $\mathbf{W}$ grows in an unbounded way with depth and cannot be nonlinearly damped.

DISCUSSION

In summary, motivated by a lack of theoretical clarity on when and why different weight initializations and nonlinearities combine to yield well-conditioned spectra that speed up deep learning, we developed a calculational framework based on free probability that provides, in unprecedented detail, analytic information about the entire Jacobian spectrum of deep networks with arbitrary nonlinearities. Our results provide a principled framework for initializing weights and choosing nonlinearities so as to produce well-conditioned Jacobians and fast learning. Intriguingly, we find novel universality classes of deep spectra that remain well-conditioned as the depth goes to infinity, as well as theoretical conditions for their existence. Our results lend additional support to the surprising conclusions of earlier work, namely that using either Gaussian initializations or ReLU nonlinearities precludes the possibility of obtaining stable spectral distributions for very deep networks. Beyond the sigmoidal units advocated in that work, our results suggest that a wide variety of nonlinearities, including shifted and smoothed variants of ReLU, can achieve dynamical isometry, provided the weights are orthogonal. Interesting future work could involve the discovery of new universality classes of well-conditioned deep spectra for more diverse nonlinearities than those considered here.


Review of free probability

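For what follows, we define the key objects of free probability. Given a random matrix $\mathbf{X}$, its limiting spectral density is defined as

$$\rho_X(\lambda)\equiv\left\langle\frac{1}{N}\sum_{i=1}^{N}\delta(\lambda-\lambda_i)\right\rangle_X, \tag{S1}$$

where $\langle\cdot\rangle_X$ denotes an average with respect to the distribution over the random matrix $\mathbf{X}$. For large $N$, the empirical histogram of eigenvalues of a single realization of $\mathbf{X}$ converges to $\rho_X$. In turn, the Stieltjes transform of $\rho_X$ is defined as

$$G_X(z)\equiv\int_{\mathbb{R}}\frac{\rho_X(t)}{z-t}\,dt,\qquad z\in\mathbb{C}\setminus\mathbb{R}, \tag{S2}$$

which can be inverted to recover the density,

$$\rho_X(\lambda)=-\frac{1}{\pi}\lim_{\epsilon\to 0^{+}}\operatorname{Im}G_X(\lambda+i\epsilon). \tag{S3}$$

$G_X$ is related to the moment generating function $M_X$,

$$M_X(z)\equiv zG_X(z)-1=\sum_{k=1}^{\infty}\frac{m_k}{z^{k}}, \tag{S4}$$

where $m_k$ is the $k$th moment of the distribution $\rho_X$,

$$m_k=\int d\lambda\,\rho_X(\lambda)\,\lambda^{k}=\frac{1}{N}\left\langle\operatorname{tr}\mathbf{X}^{k}\right\rangle_X. \tag{S5}$$

In turn, we denote the functional inverse of $M_X$ by $M_X^{-1}$, which by definition satisfies $M_X(M_X^{-1}(z))=M_X^{-1}(M_X(z))=z$. Finally, the S-transform is defined in terms of the functional inverse $M_X^{-1}$ as

$$S_X(z)=\frac{1+z}{z\,M_X^{-1}(z)}. \tag{S6}$$

The utility of the S-transform arises from its behavior under multiplication. Specifically, if $\mathbf{A}$ and $\mathbf{B}$ are two freely independent random matrices, then the S-transform of the product random matrix ensemble $\mathbf{A}\mathbf{B}$ is simply the product of their S-transforms,

$$S_{AB}(z)=S_A(z)\,S_B(z). \tag{S7}$$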

Free probability and deep networks

We will now use eqn. (S7) to write down an implicit definition of the spectral density of $\mathbf{J}\mathbf{J}^{T}$, which is also the distribution of the squared singular values of $\mathbf{J}$. Here $\mathbf{J}$ is the input-output Jacobian of a deep network defined in the main paper. First notice that, by eqn. (9), $M(z)$, and thus $S(z)$, depend only on the moments of the spectral density. The moments, in turn, can be defined in terms of traces (as in eqn. (S5)), which are invariant under cyclic permutations, i.e.,

where the last equality holds if each term in the Jacobian product is identically distributed. Given the expression for $S_{JJ^T}$, a simple procedure recovers the density of singular values of $\mathbf{J}$:

Use eqn. (S6) to obtain the moment generating function $M_{JJ^T}(z)$.

Use eqn. (9) to obtain the Stieltjes transform $G_{JJ^T}(z)$.

Use eqn. (S3) to obtain the spectral density $\rho_{JJ^T}(\lambda)$.

Use the relation $\lambda=\sigma^{2}$ to obtain the density of singular values of $\mathbf{J}$.

So in order to compute the distribution of singular values of $\mathbf{J}$, all that remains is to compute the S-transforms of $\mathbf{W}^{T}\mathbf{W}$ and of $\mathbf{D}^{2}$. We attack this problem for specific activation functions and matrix ensembles in the following sections.

Derivation of master equations for the spectrum of the Jacobian

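To derive the master equation, we first insert (S6), for $\mathbf{X}=\mathbf{D}^{2}$, into (S10) to obtain

$$S_{JJ^{T}}(z)=\left[\frac{1+z}{z\,M_{D^{2}}^{-1}(z)}\right]^{L}S_{W^{T}W}^{L}(z).$$

Then we find $M_{JJ^T}^{-1}(z)=(1+z)\big(zS_{JJ^T}(z)\big)^{-1}$ by inverting (S6), which combined with the above equation yields

$$M_{D^{2}}^{-1}(z)=\big[M_{JJ^{T}}^{-1}(z)\big]^{1/L}f(z),\qquad f(z)\equiv S_{W^{T}W}(z)\left(\frac{1+z}{z}\right)^{1-1/L}.$$

Applying $M_{D^2}$ to both sides gives

$$z=M_{D^{2}}\!\left(\big[M_{JJ^{T}}^{-1}(z)\big]^{1/L}f(z)\right).$$

Finally, evaluating this equation at $z=M_{JJ^T}(z)$ gives our sought-after master equation:

$$M_{JJ^{T}}(z)=M_{D^{2}}\!\left(z^{1/L}f\big(M_{JJ^{T}}(z)\big)\right). \tag{S11}$$

This is an implicit functional equation for $M_{JJ^T}(z)$, an unknown quantity, in terms of the known functions $M_{D^2}(z)$ and $S_{W^TW}(z)$. Furthermore, by substituting (S4), $M_{JJ^T}=zG_{JJ^T}-1$, into (S11), we also obtain an implicit functional equation for the Stieltjes transform $G$ of $\rho_{JJ^T}(\lambda)$,

$$zG(z)-1=M_{D^{2}}\!\left(z^{1/L}f\big(zG(z)-1\big)\right).$$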

Derivation of Moments of deep spectra

The moments $m_k$ of the spectrum of $\mathbf{J}\mathbf{J}^{T}$ are encoded in the moment generating function

$$M_{JJ^{T}}(z)=\sum_{k=1}^{\infty}\frac{m_k}{z^{k}}.$$

These moments in turn can be computed in terms of the series expansions of $S_{W^TW}$ and $M_{D^2}$, which we define as

$$S_{W^{T}W}(z)=\sigma_w^{-2}\left(1+\sum_{k=1}^{\infty}s_k z^{k}\right),\qquad M_{D^{2}}(z)=\sum_{k=1}^{\infty}\frac{\mu_k}{z^{k}},$$

where the moments $\mu_k$ of $\mathbf{D}^{2}$ are given by

$$\mu_k=\int\mathcal{D}h\,\phi'\big(\sqrt{q^{*}}\,h\big)^{2k}.$$

We can substitute these moment expansions into (S11) to obtain equations for the unknown moments $m_k$ of the spectrum of $\mathbf{J}\mathbf{J}^{T}$ in terms of the known moments $\mu_k$ and $s_k$. We can solve for the low-order moments by expanding (S11) in powers of $z^{-1}$. By equating the coefficients of $z^{-1}$ and $z^{-2}$, we obtain the following equations for $m_1$ and $m_2$,

Transforms of Nonlinearities

Here we compute the moment generating functions $M_{D^2}(z)$ for various choices of the nonlinearity $\phi$, some of which are displayed in Table 1 of the main paper.

3 $\phi(x)=\operatorname{htanh}(x)$

5 $\phi(x)=\operatorname{erf}(\frac{\sqrt{\pi}}{2}x)$

where $\Phi$ is the special function known as the Lerch transcendent.

6 $\phi(x)=\frac{2}{\pi}\arctan(\frac{\pi}{2}x)$

Transforms of Weights

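First consider the case of an orthogonal random matrix satisfying $\mathbf{W}^{T}\mathbf{W}=\mathbf{I}$. Then

$$\rho_{W^{T}W}(\lambda)=\delta(\lambda-1),\quad G_{W^{T}W}(z)=\frac{1}{z-1},\quad M_{W^{T}W}(z)=\frac{1}{z-1},\quad M^{-1}_{W^{T}W}(z)=\frac{1+z}{z},\quad S_{W^{T}W}(z)=1.$$

The case of a Gaussian random matrix $\mathbf{W}$ with zero-mean, variance-$\frac{1}{N}$ entries is more complex, but well known:

$$\rho_{W^{T}W}(\lambda)=\frac{1}{2\pi}\sqrt{\frac{4-\lambda}{\lambda}}\ \ (0\le\lambda\le 4),\quad G_{W^{T}W}(z)=\frac{z-\sqrt{z^{2}-4z}}{2z},\quad M_{W^{T}W}(z)=\frac{z-2-\sqrt{z^{2}-4z}}{2},\quad M^{-1}_{W^{T}W}(z)=\frac{(1+z)^{2}}{z},\quad S_{W^{T}W}(z)=\frac{1}{1+z}.$$

Furthermore, under the scaling $\mathbf{W}\rightarrow\sigma_w\mathbf{W}$, the S-transform scales as $S_{W^TW}\rightarrow\sigma_w^{-2}S_{W^TW}$, yielding the S-transforms in Table 1.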

Universality class of orthogonal Hard Tanh networks

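We consider hard tanh with orthogonal weights. The moment generating function is

$$M_{D^{2}}(z)=\frac{\operatorname{erf}\!\big(1/\sqrt{2q^{*}}\big)}{z-1},$$

since $\phi'(h)^{2}$ equals one with probability $\operatorname{erf}(1/\sqrt{2q^{*}})$ and zero otherwise. The variance (22) is then $\sigma_{JJ^T}^{2}=L\big(1/\operatorname{erf}(1/\sqrt{2q^{*}})-1\big)$, so we must take

$$q^{*}=\frac{1}{2\left[\operatorname{erf}^{-1}\!\left(\frac{L}{L+\sigma_0^{2}}\right)\right]^{2}}$$

if we wish to scale $q^{*}$ with depth $L$ so as to achieve a depth-independent constant variance $\sigma_{JJ^T}^{2}=\sigma_0^{2}$ as $L\rightarrow\infty$. This expression for $q^{*}$ gives

$$G(z)=\frac{\sigma_0^{2}}{z\left[\sigma_0^{2}+W\!\left(-\sigma_0^{2}/z\right)\right]},$$

where $W$ is the standard Lambert-W function, or product log. The derivative of this function has double poles at

$$\lambda_0=0,\qquad \lambda_2=e^{\sigma_0^{2}},$$

which are locations where the spectral density diverges. There is also a single pole at

$$\lambda_1=e\,\sigma_0^{2},$$

which is the maximum value of the bulk of the density.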

Universality class of orthogonal $\operatorname{erf}$ networks

Consider $\phi(x)=\sqrt{\frac{\pi}{2}}\operatorname{erf}\big(\frac{x}{\sqrt{2}}\big)$, which has been scaled so that $\phi'(0)=1$ and $\phi'''(0)=-1$. The $\mu_k$ are given by

$$\mu_k=\int\mathcal{D}h\,e^{-kq^{*}h^{2}}=\frac{1}{\sqrt{1+2kq^{*}}}.$$

If we wish to scale $q^{*}$ with depth $L$ so as to achieve a depth-independent constant variance $\sigma_{JJ^T}^{2}=\sigma_0^{2}$ as $L\rightarrow\infty$, then, since $\mu_2/\mu_1^{2}=(1+2q^{*})/\sqrt{1+4q^{*}}=1+2q^{*2}+O(q^{*3})$, we can choose

$$q^{*}=\frac{\sigma_0}{\sqrt{2L}}.$$

Since we also assume the network is critical, we also have

$$\sigma_w^{2}=\frac{1}{\mu_1}=\sqrt{1+2q^{*}}.$$

To illustrate universality, we next consider an arbitrary activation function, and assume that it has a Taylor expansion around 0. This allows us to expand the $\mu_k$. First we write,

We will need ϕ10\phi_{1}\neq 0. First we will assume that ϕ20\phi_{2}\neq 0. Using this expansion we can write,

where we have used the fact that the network is critical, so that $\mu_1=\sigma_w^{-2}$. Using the Lagrange inversion theorem to expand $M_{D^2}^{-1}$, we find that

Next we will assume that $\phi_2=0$ and $\phi_3\neq 0$. (We suspect these additional assumptions are unnecessary and that the results which follow are valid so long as there exists a $k$ for which $\phi_k\neq 0$; it would be interesting to prove this.) Using the above expansion we can write,

where we have used the fact that the network is critical, so that $\mu_1=\sigma_w^{-2}$. Using the Lagrange inversion theorem to expand $M_{D^2}^{-1}$, we find that

establishing a universal limiting S-transform (subject to our assumptions). From this result we can extract the Stieltjes transform and thus the spectral density. The result establishes a universal double scaling limiting spectral distribution. Next we observe that the Stieltjes transform can be expressed in terms of a generalization of the Lambert-$W$ function called the r-Lambert function, $W_r(z)$, which is defined by

In terms of this function, the Stieltjes transform is,

We can extract the maximum and minimum eigenvalues by finding the branch points of this function. It suffices to look for poles in the derivative of the numerator of $G(z)$. Using $r=-\sigma_0^{2}ze^{\sigma_0^{2}}$, eqn. (S52) and its total derivative with respect to $z$ yield the following equation defining the locations of these poles,

where $W$ is the standard Lambert W function. Next we substitute this relation into eqn. (S52); zeros in $z$ then define the locations of the branch points. Some straightforward algebra yields the maximum and minimum eigenvalues,

Orthogonal weights are required for stable, universal limiting distributions

We work at criticality, so $\chi=\sigma_w^{2}\mu_1=1$. This implies that

Observe that Jensen’s inequality requires $\mu_2\geq\mu_1^{2}$. If we require $\sigma_{JJ^T}^{2}$ to approach a constant as $L\to\infty$, we must have

we can relate $\sigma_w$ and $s_1$ to $\mathfrak{m}_1$ and $\mathfrak{m}_2$. Specifically, evaluating the relation,

Expanding this equation to second order gives,

Positivity of variance gives $s_1\leq 0$, which, together with eqn. (S58), implies

Altogether we see that the variance of the distribution of eigenvalues of $\mathbf{W}\mathbf{W}^{T}$ must be zero. Since its mean is equal to $\sigma_w^{2}$, the only valid distribution for the eigenvalues of $\mathbf{W}\mathbf{W}^{T}$ is a delta function peaked at $\sigma_w^{2}$, i.e., the distribution of squared singular values of an orthogonal matrix scaled by $\sigma_w$.