Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal Propagation in Recurrent Neural Networks

Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz

Introduction

Recurrent Neural Networks (RNNs) (Rumelhart et al., 1986; Elman, 1990) have found widespread use across a variety of domains from language modeling (Mikolov et al., 2010; Kiros et al., 2015; Jozefowicz et al., 2016) and machine translation (Bahdanau et al., 2014) to speech recognition (Graves et al., 2013) and recommendation systems (Hidasi et al., 2015; Wu et al., 2017). However, RNNs as originally proposed are difficult to train and are rarely used in practice. Instead, variants of RNNs - such as Long Short-Term Memory (LSTM) networks (Hochreiter & Schmidhuber, 1997) and Gated Recurrent Units (GRU) (Chung et al., 2014) - that feature various forms of “gating” perform significantly better than their vanilla counterparts. Often, these models must be paired with techniques such as normalization layers (Ioffe & Szegedy, 2015b; Ba et al., 2016) and gradient clipping (Pascanu et al., 2013) to achieve good performance.

A rigorous explanation for the remarkable success of gated recurrent networks remains elusive (Jozefowicz et al., 2015; Greff et al., 2017). Recent work (Collins et al., 2016) provides empirical evidence that the benefits of gating are mostly rooted in improved trainability rather than increased capacity or expressivity. The problem of disentangling trainability from expressivity is widespread in machine learning since state-of-the-art architectures are nearly always the result of sparse searches in high dimensional spaces of hyperparameters. As a result, we often mistake trainability for expressivity. Seminal early work (Glorot & Bengio; Bertschinger et al.) showed that a major hindrance to trainability was the vanishing and exploding of gradients.

Recently, progress has been made in the feed-forward setting (Schoenholz et al., 2017; Pennington et al., 2017; Yang & Schoenholz, 2017) by developing a theory of both the forward-propagation of signal and the backward-propagation of gradients. This theory is based on studying neural networks whose weights and biases are randomly distributed. This is equivalent to studying the behavior of neural networks after random initialization or, equivalently, to studying the prior over functions induced by a particular choice of hyperparameters (Lee et al., 2017). It was shown that randomly initialized neural networks are trainable if three conditions are satisfied: (1) the size of the output of the network is finite for finite inputs, (2) the output of the network is sensitive to changes in the input, and (3) gradients neither explode nor vanish. Moreover, neural networks achieving dynamical isometry, i.e. having input-output Jacobian matrices that are well-conditioned, were shown to train orders of magnitude faster than networks that do not.

In this work, we combine mean field theory and random matrix theory to extend these results to the recurrent setting. We will be particularly focused on understanding the role that gating plays in trainability. As we will see, there are a number of subtleties that must be addressed for (gated) recurrent networks that were not present in the feed-forward setting. To clarify the discussion, we will therefore contrast vanilla RNNs with a gated RNN cell that we call the minimalRNN, which is significantly simpler than LSTMs and GRUs but implements a similar form of gating. We expect the framework introduced here to be applicable to more complicated gated architectures.

The first main contribution of this paper is the development of a mean field theory for forward propagation of signal through vanilla RNNs and minimalRNNs. In doing so, we identify a theory of the maximum timescale over which signal can propagate in each case. Next, we produce a random matrix theory for the end-to-end Jacobian of the minimalRNN. As in the feed-forward setting, we establish that the duality between the forward propagation of signal and the backward propagation of gradients persists in the recurrent setting. We then show that our theory is indeed predictive of trainability in recurrent neural networks by comparing the maximum trainable number of steps of RNNs with the timescale predicted by the theory. Overall, we find remarkable alignment between theory and practice. Additionally, we develop a closed-form initialization procedure for both networks and show that on a variety of tasks RNNs initialized to be dynamically isometric are significantly easier to train than those lacking this property.

Corroborating the experimental findings of Collins et al. (2016), we show that both signal propagation and dynamical isometry in vanilla RNNs are far more precarious than in the case of the minimalRNN. Indeed the vanilla RNN achieves dynamical isometry only if the network is initialized with orthogonal weights at the boundary between order and chaos, a one-dimensional line in parameter space. Owing to its gating mechanism, the minimalRNN on the other hand enjoys a robust multi-dimensional subspace of good initializations which all enable dynamical isometry. Based on these insights, we conjecture that more complex gated recurrent neural networks also benefit from similar effects.

Related Work

Identity and Orthogonal initialization schemes have been identified as a promising approach to improve trainability of deep neural networks (Le et al., 2015; Mishkin & Matas, 2015). Additionally, Arjovsky et al. (2016); Hyland & Rätsch (2017); Xie et al. (2017) advocate going beyond initialization to constrain the transition matrix to be orthogonal throughout the entire learning process either through re-parametrisation or by constraining the optimization to the Stiefel manifold (Wisdom et al., 2016). However, as was pointed out in Vorontsov et al. (2017), strictly enforcing orthogonality during training may hinder training speed and generalization performance. While these contributions are similar to our own, in the sense that they attempt to construct networks that feature dynamical isometry, it is worth noting that orthogonal weight matrices do not guarantee dynamical isometry. This is due to the nonlinear nature of deep neural networks as shown in Pennington et al. (2017). In this paper we continue this trend and show that orthogonality has little impact on the conditioning of the Jacobian (and so trainability) in gated RNNs.

The notion of “edge of chaos” initialization has been explored previously, especially in the case of recurrent neural networks. Bertschinger et al. and Glorot & Bengio propose edge-of-chaos initialization schemes that they show lead to improved performance. Additionally, architectural innovations such as batch normalization (Ioffe & Szegedy, 2015a), orthogonal matrix initialization (Saxe et al., 2013), random walk initialization (Sussillo & Abbott, 2014), composition kernels (Daniely et al., 2016), or residual network architectures (He et al., 2015) all share a common goal of stabilizing gradients and improving training dynamics.

There is a long history of applying mean field-like approaches to understand the behavior of neural networks. Indeed several pieces of seminal work used statistical physics (Derrida & Pomeau; Sompolinsky et al., 1988) and Gaussian Processes (Neal, 2012) to show that neural networks exhibit remarkable regularity as the width of the network gets large. Mean field theory also has long been used to study Boltzmann machines (Ackley et al.) and sigmoid belief networks (Saul et al., 1996). More recently, there has been a revitalization of mean field theory to explore questions of trainability and expressivity in fully-connected networks and residual networks (Poole et al., 2016; Schoenholz et al., 2017; Yang & Schoenholz, 2017; Karakida et al., 2018; Hayou et al., 2018; Hanin & Rolnick, 2018; Yang & Schoenholz, 2018). Our approach will closely follow these later contributions and extend many of their techniques to the case of recurrent networks with gating. Beyond mean field theory, there have been several attempts at understanding signal propagation in RNNs, e.g., using the Geršgorin circle theorem (Zilly et al., 2016) or time invariance (Tallec & Ollivier, 2018).

Theory and Critical Initialization

We begin by developing a mean field theory for vanilla RNNs and discuss the notion of dynamical isometry. Afterwards, we move on to a simple gated architecture to explain the role of gating in facilitating signal propagation in RNNs.

Vanilla RNNs are described by the recurrence relation,

Next, we apply mean-field theory to vanilla RNNs, following a strategy similar to that introduced in (Poole et al., 2016; Schoenholz et al., 2017). At the level of mean-field theory, vanilla RNNs will prove to be intimately related to feed-forward networks and so this discussion proceeds analogously. For a more detailed discussion, see these earlier studies.

To make progress, we proceed by developing the theory of signal propagation for RNNs with untied weights. This allows for several simplifications, including the application of the CLT to conclude that $e^{t}_{ia}$ are jointly Gaussian distributed,

With this approximation in mind, we will now quantify how the pre-activation hidden states $\{\bm{e}^{t}_{1}\}$ and $\{\bm{e}^{t}_{2}\}$ evolve by deriving the recurrence relation of the covariance matrix $\bm{q}^{t}$ from the recurrence on $\bm{e}^{t}$ in eq. (1). Using identical arguments to Poole et al. (2016) one can show that,

is a Gaussian measure with covariance matrix $\bm{q}$. By symmetry, our normalization allows us to define $q_{11}^{t}=q_{22}^{t}=q^{t}$ to be the magnitude of the pre-activation hidden state and $c^{t}=q_{12}^{t}/q^{t}$ to be the cosine similarity between the hidden states. We will be particularly concerned with understanding the dynamics of the cosine similarity, $c^{t}$.
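As a rough illustration of the two summary statistics just defined, the sketch below estimates $q$ and $c$ empirically from a pair of pre-activation vectors. The function name `magnitude_and_cosine` and the synthetic Gaussian vectors are assumptions made for this example only; they are not taken from the paper, and the second moments are computed without mean subtraction on the assumption that the pre-activations are zero-mean.

```python
import numpy as np

def magnitude_and_cosine(e1, e2):
    """Estimate q (= q11 = q22 under the normalization) and c = q12 / q.

    e1, e2: pre-activation vectors of the same network for two inputs.
    Assumes zero-mean pre-activations, so raw second moments are used.
    """
    e1 = np.asarray(e1, dtype=float)
    e2 = np.asarray(e2, dtype=float)
    q11 = np.mean(e1 * e1)
    q22 = np.mean(e2 * e2)
    q12 = np.mean(e1 * e2)
    # With q11 == q22 this reduces to q12 / q; the geometric mean keeps the
    # estimate symmetric when the two empirical magnitudes differ slightly.
    q = np.sqrt(q11 * q22)
    return q, q12 / q

# Synthetic stand-in for two hidden-state vectors with cosine similarity 0.6.
rng = np.random.default_rng(0)
e1 = rng.standard_normal(10_000)
e2 = 0.6 * e1 + 0.8 * rng.standard_normal(10_000)
q, c = magnitude_and_cosine(e1, e2)
print(f"q = {q:.3f}, c = {c:.3f}")
```

In a width-$N$ network the averages run over the $N$ units, so the estimates sharpen as the network gets wider.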

In feed-forward networks, the inputs dictate the initial value of the cosine similarity, $c^{0}$, and then the evolution of $c^{t}$ is determined solely by the network architecture. By contrast, in recurrent networks, inputs perturb $c^{t}$ at each timestep. Analyzing the dynamics of $c^{t}$ for arbitrary $\bm{\Sigma}^{t}$ is therefore challenging; however, significant insight can be gained by studying the off-diagonal entries of eq. (2) for $\bm{\Sigma}^{t}=\bm{\Sigma}$ independent of time. In the case of time-independent $\bm{\Sigma}^{t}$, as $t\to\infty$ both $q^{t}\to q^{\ast}$ and $c^{t}\to c^{\ast}$, where $q^{\ast}$ and $c^{\ast}$ are fixed points of the variance of the pre-activation hidden state and the cosine similarity between pre-activation hidden states respectively. As was discussed previously (Poole et al., 2016; Schoenholz et al., 2017), the dynamics of $q^{t}$ are generally uninteresting provided $q^{\ast}$ is finite. We therefore choose to normalize the hidden state such that $q^{0}=q^{\ast}$, which implies that $q^{t}=q^{\ast}$ independent of time.

In this setting it was shown in Schoenholz et al. (2017) that in the vicinity of a fixed point, the off-diagonal term in eq. (2) can be expanded to lowest order in $\epsilon^{t}=c^{\ast}-c^{t}$ to give the linearized dynamics, $\epsilon^{t}=\chi_{c^{\ast}}\epsilon^{t-1}$ where

These dynamics have the solution $\epsilon^{t}=\chi_{c^{\ast}}^{t-t_{0}}\epsilon^{t_{0}}$ where $t_{0}$ is the time when $c^{t}$ is close enough to $c^{\ast}$ for the linear approximation to be valid. If $\chi_{c^{\ast}}<1$ it follows that $c^{t}$ approaches $c^{\ast}$ exponentially quickly over a timescale $\tau=-1/\log\chi_{c^{\ast}}$ and $c^{\ast}$ is called a stable fixed point. When $c^{t}$ gets too close to $c^{\ast}$ to be distinguished from it to within numerical precision, information about the initial inputs has been lost. Thus, $\tau$ sets the maximum timescale over which we expect the RNN to be able to remember information. If $\chi_{c^{\ast}}>1$ then $c^{t}$ gets exponentially farther from $c^{\ast}$ over time and $c^{\ast}$ is an unstable fixed point. In this case, for the activation function considered here, another fixed point that is stable will emerge. Note that $\chi_{c^{\ast}}$ is independent of $\bm{\Sigma}$ and so the dynamics of $c^{t}$ near $c^{\ast}$ do not depend on $\bm{\Sigma}$.
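The relationship between $\chi_{c^{\ast}}$ and the memory timescale can be made concrete with a few lines of code. The sketch below simply evaluates $\tau=-1/\log\chi$ and the linearized decay $\epsilon^{t}=\chi^{t}\epsilon^{0}$; the function names are illustrative choices for this example and the values of $\chi$ are arbitrary rather than computed from any particular network.

```python
import math

def memory_timescale(chi: float) -> float:
    """Return tau = -1 / log(chi) for a stable (chi < 1) or critical (chi == 1) fixed point."""
    if chi <= 0.0 or chi > 1.0:
        raise ValueError("expects 0 < chi <= 1")
    if chi == 1.0:
        return math.inf  # critical point: the timescale diverges
    return -1.0 / math.log(chi)

def perturbation(eps0: float, chi: float, steps: float) -> float:
    """Linearized distance from the fixed point after `steps` steps."""
    return eps0 * chi ** steps

for chi in (0.5, 0.9, 0.99, 0.999):
    tau = memory_timescale(chi)
    # After tau steps the perturbation has shrunk by a factor of e.
    print(f"chi = {chi:<6} tau = {tau:9.2f}  eps(tau)/eps(0) = {perturbation(1.0, chi, tau):.4f}")
```

The loop makes the divergence visible: moving $\chi$ from 0.9 to 0.999 lengthens the timescale by roughly two orders of magnitude.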

In vanilla fully-connected networks $c^{\ast}=1$ is always a fixed point of $c^{t}$, but it is not always stable. Indeed, it was shown that these networks exhibit a phase transition where $c^{\ast}=1$ goes from being a stable fixed point to an unstable one as a function of the network’s hyperparameters. This is known as the order-to-chaos transition and it occurs exactly when $\chi_{1}=1$. Since $\tau=-1/\log(\chi_{1})$, signal can propagate infinitely far at the boundary between order and chaos. Comparing the diagonal and off-diagonal entries of eq. (2), we see that in recurrent networks, $c^{\ast}=1$ is a fixed point only when $\Sigma_{12}=1$, and in this case the discussion is identical to the feed-forward setting. When $\Sigma_{12}<1$, it is easy to see that $c^{\ast}<1$ since if $c^{t}=1$ at some time $t$ then $c^{t+1}=1-\sigma_{v}^{2}R(1-\Sigma_{12})/q^{\ast}<1$. We see that in recurrent networks noise from the inputs destroys the ordered phase and there is no order-to-chaos critical point. As a result we should expect the maximum timescale over which memory may be stored in vanilla RNNs to be fundamentally limited by noise from the inputs.
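To see how input noise pushes the network away from $c=1$, the one-step expression quoted above can be evaluated directly. This is only a numerical reading of that single formula; the function name and the parameter values are hypothetical and chosen for illustration.

```python
def cosine_after_one_step(sigma_v2: float, R: float, sigma12: float, q_star: float) -> float:
    """Evaluate c^{t+1} = 1 - sigma_v^2 R (1 - Sigma_12) / q* starting from c^t = 1."""
    return 1.0 - sigma_v2 * R * (1.0 - sigma12) / q_star

# Illustrative (not fitted) values: sigma_v^2 = 1, R = 1, q* = 2.
for sigma12 in (1.0, 0.99, 0.5, 0.0):
    c_next = cosine_after_one_step(1.0, 1.0, sigma12, 2.0)
    print(f"Sigma_12 = {sigma12:<5} -> c after one step = {c_next:.4f}")
```

Only perfectly correlated inputs ($\Sigma_{12}=1$) leave $c=1$ untouched; any decorrelation lowers the cosine similarity in a single step.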

The end-to-end Jacobian of a vanilla RNN with untied weights is in fact formally identical to the input-output Jacobian of a feedforward network, and thus the results from (Pennington et al., 2017) regarding conditions for dynamical isometry apply directly. In particular, dynamical isometry is achieved with orthogonal state-to-state transition matrices $\bm{W}$, $\tanh$ non-linearities, and small values of $q^{\ast}$. Perhaps surprisingly, these conclusions continue to be valid if the assumption of untied weights is relaxed. To understand why this is the case, consider the example of a linear network. For untied weights, the end-to-end Jacobian is given by $\hat{\bm{J}}=\prod_{t=1}^{T}\bm{W}_{t}$, while for tied weights the Jacobian is given by $\bm{J}=\bm{W}^{T}$. It turns out that as $N\to\infty$ there is sufficient self-averaging to overcome the dependencies induced by weight tying and the asymptotic singular value distributions of $\hat{\bm{J}}$ and $\bm{J}$ are actually identical (Haagerup & Larsen, 2000).
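The linear-network example lends itself to a small numerical experiment. The sketch below builds the untied product $\prod_{t}\bm{W}_{t}$ and the tied power $\bm{W}^{T}$ for orthogonal and Gaussian weights and reports the spread of their singular values. It is a finite-$N$ illustration under assumed sizes ($N=64$, $T=10$), not a verification of the asymptotic result; the helper names are invented for this example.

```python
import numpy as np

def random_orthogonal(n, rng):
    """Draw an orthogonal matrix via QR of a Gaussian matrix (sign-fixed columns)."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    return q * np.sign(np.diag(r))

def draw_weight(n, rng, orthogonal):
    if orthogonal:
        return random_orthogonal(n, rng)
    return rng.standard_normal((n, n)) / np.sqrt(n)  # entries ~ N(0, 1/n)

def singular_values_untied(n, T, rng, orthogonal):
    """Singular values of W_T ... W_1 with an independent draw per step."""
    J = np.eye(n)
    for _ in range(T):
        J = draw_weight(n, rng, orthogonal) @ J
    return np.linalg.svd(J, compute_uv=False)

def singular_values_tied(n, T, rng, orthogonal):
    """Singular values of W^T (matrix power) with one shared draw."""
    W = draw_weight(n, rng, orthogonal)
    return np.linalg.svd(np.linalg.matrix_power(W, T), compute_uv=False)

rng = np.random.default_rng(0)
for orthogonal in (True, False):
    label = "orthogonal" if orthogonal else "gaussian"
    for name, fn in (("untied", singular_values_untied), ("tied", singular_values_tied)):
        sv = fn(64, 10, rng, orthogonal)
        print(f"{label:10s} {name:6s} min = {sv.min():.3e}  max = {sv.max():.3e}")
```

For orthogonal weights every factor preserves norms, so both products have all singular values equal to one up to round-off, whereas Gaussian factors produce a wide spread in both the tied and untied case.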

MinimalRNN

We note that $\bm{R}^{t}$ is fixed by the input, but it remains for us to work out $\bm{Q}^{t}$. We find that (see SI section B),

Here we assume that the expectation factorizes so that $\bm{h}^{t-1}$ and $\bm{u}^{t}$ are approximately independent. We believe this approximation becomes exact in the $N\to\infty$ limit.

We choose to normalize the data in a similar manner to the vanilla case so that $R_{11}^{t}=R_{22}^{t}=R$ independent of time. An immediate consequence of this normalization is that $Q_{11}^{t}=Q_{22}^{t}=Q^{t}$ and $q_{11}^{t}=q_{22}^{t}=q^{t}$. We then write $C^{t}=Q_{12}^{t}/Q^{t}$ and $c^{t}=q_{12}^{t}/q^{t}$ as the cosine similarities between the hidden states and the pre-activations respectively. With this normalization, we can work out the mean-field recurrence relation characterizing the covariance matrix for the minimalRNN. This analysis can be done by deriving the recurrence relation for either $\bm{Q}^{t}$ or $\bm{q}^{t}$. We will choose to study the dynamics of $\bm{q}^{t}$; however, the two are trivially related by eq. (6). In SI section C, we analyze the dynamics of the diagonal term in the recurrence relation and prove that there is always a fixed point at some $q^{\ast}$. In SI section D, we compute the depth scale over which $q^{t}$ approaches $q^{\ast}$. However, as in the case of the vanilla RNN, the dynamics of $q^{\ast}$ are generally uninteresting.

We now turn our attention to the dynamics of the cosine similarity between the pre-activations, $c^{t}$. As in the case of vanilla RNNs, we note that $q^{t}$ approaches $q^{\ast}$ quickly relative to the dynamics of $c^{t}$. We therefore choose to normalize the hidden state of the RNN so that $Q^{0}=Q^{\ast}$, in which case both $Q^{t}=Q^{\ast}$ and $q^{t}=q^{\ast}$ independent of time. From eq. (6) and (7) it follows that the cosine similarity of the pre-activation evolves as,

where we have defined $\rho^{t}=R\Sigma_{12}^{t}/q^{\ast}$. As in the case of the vanilla RNN, we can study the behavior of $c^{t}$ in the vicinity of a fixed point, $c^{*}$. By expanding eq. (3.2.1) to lowest order in $\epsilon^{t}=c^{*}-c^{t}$ we arrive at a linearized recurrence relation that has an exponential solution $\epsilon^{t+1}=\chi_{c^{*}}\epsilon^{t}$ where here,

The discussion above in the vanilla case carries over directly to the minimalRNN with the appropriate replacement of $\chi_{c^{*}}$. Unlike in the case of the vanilla RNN, here we see that $\chi_{c^{*}}$ itself depends on $\Sigma_{12}$.

Again $c^{*}=1$ is a fixed point of the dynamics only when $\Sigma_{12}=1$. In this case, the minimalRNN experiences an order-to-chaos phase transition when $\chi_{1}=1$, at which point the maximum timescale over which signal can propagate goes to infinity. Similar to the vanilla RNN, when $\Sigma_{12}<1$, we expect that the phase transition will be destroyed and the maximum duration of signal propagation will be severely limited. However, in a significant departure from the vanilla case, when $\mu_{b}\to\infty$ we notice that $\sigma(z+\mu_{b})\to 1$ and $\sigma^{\prime}(z+\mu_{b})\to 0$ for all $z$. Considering eq. (9) we notice that in this regime $\chi_{c^{*}}\to 1$ independent of $\Sigma_{12}$. In other words, gating allows for arbitrarily long term signal propagation in recurrent neural networks independent of $\Sigma_{12}$.
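The saturation argument can be checked numerically. The sketch below evaluates the logistic sigmoid and its derivative at $z+\mu_{b}$ over a fixed grid of $z$ values and reports the smallest gate value and the largest slope as $\mu_{b}$ grows. The grid $z\in[-3,3]$ and the helper name `gate_extremes` are assumptions for this illustration; in the theory $z$ is integrated against a Gaussian measure rather than restricted to a grid.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1.0 - s)

def gate_extremes(mu_b, z):
    """Return (min over z of sigma(z + mu_b), max over z of sigma'(z + mu_b))."""
    return sigmoid(z + mu_b).min(), sigmoid_prime(z + mu_b).max()

z = np.linspace(-3.0, 3.0, 7)  # stand-in for typical values of a unit Gaussian
for mu_b in (0.0, 4.0, 8.0, 16.0):
    lo_gate, hi_slope = gate_extremes(mu_b, z)
    print(f"mu_b = {mu_b:5.1f}  min gate = {lo_gate:.6f}  max slope = {hi_slope:.2e}")
```

As the bias mean increases, the gate stays pinned near one over the whole grid and its slope vanishes, which is the regime in which the text argues $\chi_{c^{*}}\to 1$.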

We explore agreement between our theory and MC simulations of the minimalRNN in fig. 1. In this set of experiments, we consider inputs such that $\Sigma_{12}^{t}=0$ for $t<10$ and $\Sigma_{12}^{t}=1$ for $t\geq 10$. Fig. 1 (a,c,d) show excellent quantitative agreement between our theory and MC simulations. In fig. 1 (a,b) we compare the MC simulations of the minimalRNN with and without weight tying. While we observe that for many choices of hyperparameters the untied weight approximation is quite good (particularly when $c^{\ast}\approx 1$), deeper into the chaotic phase the quantitative agreement between the two breaks down. Nonetheless, we observe that the untied approximation describes the qualitative behavior of the real minimalRNN overall. In fig. 1 (e) we plot the timescale for signal propagation for $\Sigma_{12}=1,0.99,$ and for the minimalRNN with identical choices of hyperparameters. We see that while $\tau\to\infty$ as $\mu_{b}$ gets large independent of $\Sigma_{12}$, a critical point at $\mu_{b}=0$ is only observed when $\Sigma_{12}=1$.

In the previous subsection, we derived a quantity $\chi_{1}$ that defines the boundary between the ordered and the chaotic phases of forward propagation. Here we show that it also defines the boundary between exploding and vanishing gradients. To see this, consider the Jacobian of the state-to-state transition operator,

where $\bm{D}_{\bm{x}}$ denotes a diagonal matrix with $\bm{x}$ along its diagonal. We can compute the expected norm-squared of back-propagated error signals, which measures the growth or shrinkage of gradients. It is equal to the mean-squared singular value of the Jacobian (Poole et al., 2016; Schoenholz et al., 2017) or the first moment of $\bm{J}_{t}\bm{J}_{t}^{T}$,
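The identity between the mean-squared singular value and the first moment of $\bm{J}\bm{J}^{T}$ holds for any square matrix, and can be checked numerically. The sketch below uses a hypothetical Jacobian of the form diagonal-times-Gaussian purely as a stand-in; it is not the minimalRNN Jacobian, and the variable names are illustrative.

```python
import numpy as np

def first_moment(J):
    """Normalized trace of J J^T, i.e. the first moment of its eigenvalue distribution."""
    return np.trace(J @ J.T) / J.shape[0]

def mean_squared_singular_value(J):
    s = np.linalg.svd(J, compute_uv=False)
    return np.mean(s ** 2)

rng = np.random.default_rng(0)
n = 200
d = rng.uniform(0.0, 1.0, size=n)              # stand-in diagonal factor
W = rng.standard_normal((n, n)) / np.sqrt(n)   # Gaussian weights with variance 1/n
J = np.diag(d) @ W                             # hypothetical single-step Jacobian

print(first_moment(J), mean_squared_singular_value(J))
```

Both numbers agree to floating-point precision; a value above one signals gradients that grow on average per step and a value below one signals gradients that shrink.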

As argued in (Pennington et al., 2017, 2018), controlling the variance of back-propagated gradients is necessary but not sufficient to guarantee trainability, especially for very deep networks. Beyond the first moment, the entire distribution of eigenvalues of $\bm{J}\bm{J}^{T}$ (or of singular values of $\bm{J}$) is relevant. Indeed, it was found in (Pennington et al., 2017, 2018) that enabling dynamical isometry, namely the condition that all singular values of $\bm{J}$ are close to unity, can drastically improve training speed for very deep feed-forward networks.

Following (Pennington et al., 2017, 2018), we use tools from free probability theory to compute the variance $\sigma^{2}_{\bm{J}\bm{J}^{T}}$ of the limiting spectral density of $\bm{J}\bm{J}^{T}$; however, unlike previous work, in our case the relevant matrices are not symmetric and therefore we must invoke tools from non-Hermitian free probability; see (Cakmak, 2012) for a review. As in the previous section, we make the simplifying assumption that the weights are untied, relying on the same motivations given in section 3.1. Using these tools, an un-illuminating calculation reveals that,

and $s_{1}$ is the first term in the Taylor expansion of the S-transform of the eigenvalue distribution of $WW^{T}$ (Pennington et al., 2018). For example, for Gaussian matrices, $s_{1}=-1$ and for orthogonal matrices $s_{1}=0$.

Some remarks are in order about eq. (12). First, we note the duality between the forward and backward signal propagation (eq. (9) and eq. (13)). For critical initializations, $\chi_{1}=1$, so $\sigma^{2}_{\bm{J}\bm{J}^{T}}$ does not grow exponentially, but it still grows linearly with $T$. This situation is entirely analogous to the feed-forward analysis of (Pennington et al., 2017, 2018). In the case of the vanilla RNN, the coefficient of the linear term is proportional to $q^{*}$, and can only be reduced by taking the weight and bias variances $(\sigma_{w}^{2},\sigma_{b}^{2})\to(1,0)$. A crucial difference in the minimalRNN is that the coefficient of the linear term can be made arbitrarily small by simply adjusting the bias mean $\mu_{b}$ to be positive, which will send $\mu_{2}\to 0$ and $\mu_{1}\to 1$ independent of $\Sigma$. Therefore the conditions for dynamical isometry decouple from the weight and bias variances, implying that trainability can occur for a higher-dimensional, more robust, slice of parameter space. Moreover, the value of $s_{1}$ has no effect on the capacity of the minimalRNN to achieve dynamical isometry. We believe these are fundamental reasons why gated cells such as the minimalRNN perform well in practice.

Algorithm 1 describes the procedure to find $\sigma_{w}^{2},\sigma_{v}^{2}$ and $\sigma_{b}^{2}$ that achieve the $\chi_{1}=1$ condition for the minimalRNN. Given $\sigma_{w}^{2},\sigma_{v}^{2},\sigma_{b}^{2}$, we then construct the weight matrices and biases accordingly. $Q^{\ast}$ is used to initialize $h^{0}$ to avoid a transient phase.

Experiments

Having established a theory for the behavior of random vanilla RNNs and minimalRNNs, we now discuss the connection between our theory and trainability in practice. We begin by corroborating the claim that the maximum timescale over which memory can be stored in a RNN is controlled by the timescale $\tau$ identified in the previous section. We will then investigate the role of dynamical isometry in speeding up learning.

Dataset. To verify the results of our theoretical calculation, we consider a task that is reflective of the theory above. To that end, we constructed a sequence dataset for training RNNs from MNIST (LeCun et al., 1998). Each $28\times 28$ digit image is flattened into a vector of $784$ pixels and sent as the first input to a RNN. We then send $T$ random inputs $\bm{x}^{t}\sim\mathcal{N}(0,\sigma_{x}^{2})$, $0<t<T$, into the RNN, varying $T$ between 10 and 1000 steps. As the only salient information about the digit is in the first layer, the network will need to propagate information through $T$ layers to accurately identify the MNIST digit. The random inputs are drawn independently for each example and so this is a regime where $\Sigma^{t}=0$ for all $t>0$.
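A minimal sketch of how one such sequence might be assembled is given below. It is an assumption-laden illustration rather than the authors' pipeline: the function name `make_noisy_sequence` is invented, the noise inputs are assumed to have the same dimension (784) as the flattened digit, exactly $T$ noise steps are appended after the digit, and a synthetic array stands in for an MNIST image so that the example needs no dataset download.

```python
import numpy as np

def make_noisy_sequence(image, T, sigma_x, rng):
    """Return an array of shape (T + 1, 784): the flattened digit, then T noise inputs.

    image   : array of shape (28, 28)
    T       : number of i.i.d. Gaussian inputs appended after the digit (assumed)
    sigma_x : standard deviation of the noise inputs
    """
    first = np.asarray(image, dtype=np.float64).reshape(1, -1)
    noise = rng.normal(0.0, sigma_x, size=(T, first.shape[1]))
    return np.concatenate([first, noise], axis=0)

rng = np.random.default_rng(0)
fake_digit = rng.uniform(0.0, 1.0, size=(28, 28))  # placeholder for an MNIST image
seq = make_noisy_sequence(fake_digit, T=100, sigma_x=1.0, rng=rng)
print(seq.shape)  # first row carries the digit; the remaining rows are pure noise
```

Because the noise is redrawn independently for every example, the inputs after the first step are uncorrelated across examples, which corresponds to the $\Sigma^{t}=0$ regime described above.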

We then performed a series of experiments on this task to make a connection with our theory. In each case we experimented with both tied and untied weights. The results are shown in fig. 2. In the case of untied weights, we observe strong quantitative agreement between our theoretical prediction for $\tau$ and the maximum depth $T$ where the network is still trainable. When the weights of the network are tied, we observe quantitative deviations between our theory and experiments, but the overall qualitative picture remains.

We train vanilla RNNs for $10^{3}$ steps (around 10 epochs) varying $\sigma_{w}\in[0.5,1.5]$ while fixing $\sigma_{v}=0.025$. The results of this experiment are shown in fig. 2 (a-b). We train minimalRNNs for $10^{2}$ steps (around 1 epoch) fixing $\sigma_{v}=1.39$. We perform three different experiments here: 1) varying $\mu_{b}$ with $\sigma_{w}=6.88$ shown in fig. 2 (c-d), 2) varying $\sigma_{w}\in[0.5,10]$ with $\mu_{b}=4$ shown in fig. 2 (e-f), 3) varying $\sigma_{w}\in[0.5,10]$ with $\mu_{b}=6$ shown in fig. 2 (g-h). Comparing fig. 2 (a,b) with fig. 2 (c,d,g,h), the minimalRNN with large depth $T$ is trainable over a much wider range of hyperparameters than the vanilla RNN despite the fact that the network was trained for an order of magnitude less time.

Critical initialization

Dataset. To study the impact of critical initialization on training speed, we constructed a more realistic sequence dataset from MNIST. We unroll the pixels into a sequence of $T$ inputs, each containing $784/T$ pixels. We tested $T=196$ and $T=784$ to vary the difficulty of the tasks.
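The unrolling step amounts to a reshape. The sketch below shows one way to do it, assuming row-major pixel order (the ordering is not specified in the text) and using a synthetic array in place of an MNIST digit; `unroll_pixels` is a name chosen for this example.

```python
import numpy as np

def unroll_pixels(image, T):
    """Split a 28x28 image into T consecutive inputs of 784 / T pixels each (row-major)."""
    flat = np.asarray(image).reshape(-1)
    if flat.size % T != 0:
        raise ValueError("T must divide the number of pixels")
    return flat.reshape(T, flat.size // T)

fake_digit = np.arange(784).reshape(28, 28)  # placeholder for an MNIST image
for T in (196, 784):
    seq = unroll_pixels(fake_digit, T)
    print(T, seq.shape)
```

With $T=196$ each step sees four pixels, and with $T=784$ each step sees a single pixel, so the longer setting requires information to be carried across four times as many steps.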

Note that we are more interested in the training speed of these networks under different initialization conditions than in the test accuracy. We compare the convergence speed of the vanilla RNN and the minimalRNN under four initialization conditions: 1) critical initialization with orthogonal weights (solid blue); 2) critical initialization with Gaussian distributed weights (solid red); 3) off-critical initialization with orthogonal weights (dotted green); 4) off-critical initialization with Gaussian distributed weights (dotted black).

We fix $\sigma_{b}^{2}$ to zero in all settings. Under critical initialization, $\sigma_{w}^{2}$ and $\sigma_{v}^{2}$ are carefully chosen to achieve $\chi_{1}=1$ as defined in eq. (4) for the vanilla RNN and eq. (13) (detailed in algorithm 1) for the minimalRNN respectively. When testing networks off criticality, we employ a common initialization procedure in which $\sigma_{w}^{2}=1.0$ and $\sigma_{v}^{2}=1.0$.

Figure 3 summarizes our findings: there is a clear difference in training speed between models trained with critical initialization compared with models initialized far from criticality. We observe two orders of magnitude difference in training speed between a critical and off-critical initialization for vanilla RNNs. While a critically initialized model reaches a test accuracy of 90% after 750 optimization steps, the off-critical network takes over 16,000 updates. A similar trend was observed for the minimalRNN. This difference is even more pronounced in the case of the longer sequence with $T=784$. Both vanilla RNNs and minimalRNNs initialized off-criticality failed at the task. The well-conditioned minimalRNN trains a factor of three faster than the vanilla RNN. As predicted above, the difference in training speed between orthogonal and Gaussian initialization schemes is significant for vanilla RNNs but is insignificant for the minimalRNN. This is corroborated in fig. 3 (b,d) where the distribution of the weights has no impact on the training speed.

Language modeling

We compare the minimalRNN against more complex gated RNNs such as LSTM and GRU on the Penn Tree-Bank corpus (Marcus et al., 1993). Language modeling is a difficult task, and competitive performance is often achieved by more complicated RNN cells. We show that the minimalRNN achieves competitive performance despite its simplicity.

We follow the precise setup of (Mikolov et al., 2010; Zaremba et al., 2014), and train RNNs of two sizes: a small configuration with 5M parameters and a medium-sized configuration with 20M parameters (the hidden layer sizes of these networks are adjusted accordingly to reach the target model size). We report the perplexity on the validation and test sets. We focus our comparison on single layer RNNs; however, we also report perplexities for multi-layer RNNs from the literature for reference. We follow the learning schedule of Zaremba et al. (2014) and Jozefowicz et al. (2015). We review additional hyperparameter ranges in section F of the supplementary material.

Table 1 summarizes our results. We find that single layer RNNs perform on par with their multi-layer counterparts. Despite being a significantly simpler model, the minimalRNN performs comparably to GRUs. Given the closed-form critical initialization developed here that significantly boosts convergence speed, the minimalRNN might be a favorable alternative to GRUs. There is a gap in perplexity between the performance of LSTMs and minimalRNNs. We hypothesize that this is due to the removal of an independent gate on the input. The same strategy is employed in GRUs and may cause a conflict between keeping longer-range memory and updating new information as was originally pointed out by Hochreiter & Schmidhuber (1997).

Discussion

We have developed a theory of signal propagation for random vanilla RNNs and a simple gated RNN. We demonstrate rigorously that the theory predicts trainability of these networks and that gating mechanisms allow for a significantly larger trainable region. We are planning to extend the theory to more complicated RNN cells as well as RNNs with multiple layers.

Acknowledgements

We thank Jascha Sohl-Dickstein and Greg Yang for helpful discussions and Ashish Bora for many contributions to early stages of this project.

References

Supplemental material

Appendix B Diagonal Recurrence Relation

Here $i$ denotes the (pre)-activation and $a$ denotes an input to the network. Thus, $u^{t}_{i}$ acts as a gate on the $t$’th step. We take $W_{ij}\sim\mathcal{N}(0,\sigma_{w}^{2}/N)$, $V_{ij}\sim\mathcal{N}(0,\sigma_{v}^{2}/M)$ and $b_{i}\sim\mathcal{N}(\mu_{b},\sigma_{b}^{2})$.
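The sampling distributions just stated translate directly into code. The sketch below draws one set of Gaussian parameters with the stated variances; the function name, the argument order, and the specific sizes are illustrative assumptions, and the hyperparameter values are arbitrary rather than the critical values produced by the paper's algorithm.

```python
import numpy as np

def sample_gaussian_params(N, M, sigma_w, sigma_v, mu_b, sigma_b, rng):
    """Draw W (N x N), V (N x M) and b (N,) with the variances stated in the text.

    sigma_w, sigma_v, sigma_b are standard deviations, so the entries of W have
    variance sigma_w**2 / N and the entries of V have variance sigma_v**2 / M.
    """
    W = rng.normal(0.0, sigma_w / np.sqrt(N), size=(N, N))
    V = rng.normal(0.0, sigma_v / np.sqrt(M), size=(N, M))
    b = rng.normal(mu_b, sigma_b, size=N)
    return W, V, b

rng = np.random.default_rng(0)
W, V, b = sample_gaussian_params(N=400, M=50, sigma_w=2.0, sigma_v=1.5,
                                 mu_b=4.0, sigma_b=0.0, rng=rng)
# Rescaled empirical variances should be close to sigma_w**2 and sigma_v**2.
print(W.var() * 400, V.var() * 50, b.mean())
```

The $1/N$ and $1/M$ scalings keep the variance of each pre-activation independent of the layer widths, which is what allows the mean-field limit to be taken.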

By the CLT we can make a mean field assumption that $v^{t}_{i;a}\sim\mathcal{N}(\mu_{b},q^{t}_{ab})$ where,

where we have assumed that the expectation factorizes so that $h^{t-1}_{i;a}$ and $u^{t}_{i;a}$ are approximately independent.

We choose to normalize the data so that $R_{aa}^{t}=R_{bb}^{t}=R$ independent of time. An immediate consequence of this normalization is that $Q_{aa}^{t}=Q_{bb}^{t}=Q^{t}$ and $q_{aa}^{t}=q_{bb}^{t}=q^{t}$. We then write $R_{ab}^{t}=R\Sigma^{t}$, $Q_{ab}^{t}=Q^{t}C^{t}$ and $q_{ab}^{t}=q^{t}c^{t}$ where $\Sigma^{t}$, $C^{t}$, and $c^{t}$ are cosine similarities between the inputs, the hidden states, and the $v^{t}_{a,b}$ respectively. With this normalization, we can work out the mean-field recurrence relation characterizing the covariance matrix for the minimalRNN.

We begin by considering the diagonal recurrence relations. We find that the dynamics are described by the equation,

As expected, the first and second integrands determine how much of the update of the random network is controlled by the norm of the hidden state and how much is determined by the norm of the input. Since $\sigma(z)=1-\sigma(-z)$ it follows that when $\mu_{b}=0$ the first and second term will be equal and so,

In general, $\mu_{b}$ will therefore control the degree to which the hidden state of the random minimalRNN is updated based on the previous hidden state or based on the inputs, with $\mu_{b}=0$ implying parity between the two. This is reflected in eq. (23).

In the event that the norm of the inputs is time-independent, $R^{t}=R$ for all $t$, then the minimalRNN will have a fixed point provided there exists a $Q^{*}$ that satisfies a transcendental equation, namely that

It is easy to see that such a solution always exists. When $Q^{*}\to\infty$ the first term of $\mathcal{F}(Q^{*})$ approaches $1$ while the magnitude of the second increases without bound and so $\mathcal{F}(Q^{*})<0$. Conversely, when $Q^{*}\to 0$ the first term is positive while $Q^{*}/R\to 0$ and so $\mathcal{F}(Q^{*})>0$. The existence of a $Q^{*}$ satisfying the transcendental equation then follows directly from the intermediate value theorem.

We can now investigate the dynamics of the norm of the hidden state in the vicinity of $Q^{*}$. To do this suppose that $Q^{t}=Q^{*}+\epsilon^{t}$ with $\epsilon\ll 1$. Our goal is then to expand eq. (21) about $Q^{*}$. First, we note that,

Letting $\zeta(z)=\sqrt{q^{*}}z+\mu_{b}$ this implies that,

Appendix E Off-Diagonal Recurrence Relation

We now turn our attention to the off-diagonal term. From eq. (7) it follows that,

By expanding eq. (37) as $c^{t}=c^{*}+\epsilon^{t}$ we find $\epsilon^{t+1}=\chi_{c^{*}}\epsilon^{t}$ where,

We note that when $c^{*}=1$ it follows that $\chi_{c^{*}}=\chi_{1}$.

Appendix F Additional Hyperparameter Ranges

We tune the learning hyper-parameters in the following ranges for all the models:

learning rate: {0.1, 0.2, 0.3, 0.5, 1, 2}
