Unbiased Online Recurrent Optimization

Corentin Tallec, Yann Ollivier

Related work

A widespread approach to online learning of recurrent neural networks is Truncated Backpropagation Through Time (truncated BPTT) [Jae02], which mimics Backpropagation Through Time but zeroes gradient flows after a fixed number of timesteps. This truncation makes gradient estimates biased; consequently, truncated BPTT does not provide any convergence guarantee. Arguably, truncated BPTT might still learn some dependencies beyond its truncation range, by a mechanism similar to Echo State Networks [Jae02]. However, truncated BPTT's gradient estimate has a marked bias towards short-term rather than long-term dependencies, as shown in the first experiment of Section 4. Truncated BPTT also requires storing the inputs and states of the most recent timesteps.
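To make the truncation concrete, here is a minimal sketch on a toy model that is an assumption of this example, not taken from the paper: a scalar linear recurrence $s_t = \theta s_{t-1} + x_t$. Unrolling it gives $\partial s_t/\partial\theta = \sum_{j=0}^{t-1}\theta^j s_{t-1-j}$, where the term $j$ comes from the use of $\theta$ at step $t-j$; truncated BPTT with window $T$ keeps only the terms with $j < T$. The helper names `run` and `grad_state` are hypothetical.

```python
# Toy model (an assumption for illustration): scalar linear recurrence
#   s_t = theta * s_{t-1} + x_t,  with s_0 given.


def run(theta, xs, s0=0.0):
    """Return the list of states [s_0, s_1, ..., s_T]."""
    states = [s0]
    for x in xs:
        states.append(theta * states[-1] + x)
    return states


def grad_state(theta, states, t, truncation=None):
    """d s_t / d theta, accumulated backwards over at most `truncation` steps.

    Unrolling gives d s_t / d theta = sum_{j=0}^{t-1} theta^j * s_{t-1-j}: the use of
    theta at step t-j reaches s_t through j further multiplications by theta.
    Truncated BPTT keeps only the terms with j < truncation.
    """
    steps = t if truncation is None else min(truncation, t)
    return sum(theta ** j * states[t - 1 - j] for j in range(steps))


theta, xs = 0.9, [1.0] * 20
states = run(theta, xs)
full = grad_state(theta, states, t=20)
truncated = grad_state(theta, states, t=20, truncation=4)
print(f"exact: {full:.3f}  4-truncated: {truncated:.3f}")
```

With positive states and $0 < \theta < 1$, every dropped term is positive, so in this toy case truncation systematically underestimates the gradient, a simple instance of the short-term bias.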

Online, exact gradient computation methods, such as Real Time Recurrent Learning (RTRL) [WZ89, Pea95], have long been known, but their computational cost rules them out for reasonably sized networks.

NoBackTrack (NBT) [OTC15] also provides unbiased gradient estimates for recurrent neural networks. However, contrary to UORO, NBT cannot be applied in a blackbox fashion, making it extremely tedious to implement for complex architectures.

Other previous attempts to introduce generic online learning algorithms with a reasonable computational cost all result in biased gradient estimates. Echo State Networks (ESNs) [Jae02, JLPS07] simply set the gradients of recurrent parameters to 0. Others, e.g., [MNM02, Ste04], introduce approaches resembling ESNs, but keep a partial estimate of the recurrent gradients. The original Long Short Term Memory algorithm [HS97] (LSTM now refers to a particular architecture) cuts gradient flows going out of gating units to make gradient computation tractable. Decoupled Neural Interfaces [JCO+16] bootstrap truncated gradient estimates using synthetic gradients generated by feedforward neural networks. The algorithm in [MMW02] provides zeroth-order estimates of recurrent gradients via diffusion networks; it could arguably be turned online by running randomized alternative trajectories. Generally, these approaches lack a strong theoretical backing, except arguably ESNs.

Background

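UORO is a learning algorithm for recurrent computational graphs. Formally, the aim is to optimize $\theta$, a parameter controlling the evolution of a dynamical system

$$s_{t+1} = F_{\text{state}}(x_{t+1}, s_t, \theta), \qquad (1)$$

where $x_t$ is the input and $s_t$ the state at time $t$, so as to minimize a total loss $\sum_t \ell_t$ accumulated along the trajectory.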

Unbiased Online Recurrent Optimization

Unbiased Online Recurrent Optimization is built on top of a forward computation of the gradients, rather than backpropagation. Forward gradient computation for neural networks (RTRL) is described in [WZ89] and we review it in Section 3.1. The derivation of UORO follows in Section 3.2. Implementation details are given in Section 3.3. UORO's derivation is strongly connected to [OTC15] but differs in one critical aspect: the sparsity hypothesis made in the latter is relaxed, resulting in reduced implementation complexity without any model restriction.

Forward computation of the gradient for a recurrent model (RTRL) is directly obtained by applying the chain rule to both the loss function and the state equation (1), as follows.

Here, the term $\partial s_t/\partial\theta$ represents the effect on the state at time $t$ of a change of parameter during the whole past trajectory. This term can be computed inductively from time $t$ to $t+1$. Intuitively, looking at the update equation (1), there are two contributions to $\partial s_{t+1}/\partial\theta$:

The direct effect of a change of $\theta$ on the computation of $s_{t+1}$, given $s_t$.

The past effect of $\theta$ on $s_t$ via the whole past trajectory.

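With this in mind, differentiating (1) with respect to $\theta$ yields

$$\frac{\partial s_{t+1}}{\partial\theta} = \frac{\partial F_{\text{state}}}{\partial s}(x_{t+1}, s_t, \theta)\,\frac{\partial s_t}{\partial\theta} + \frac{\partial F_{\text{state}}}{\partial\theta}(x_{t+1}, s_t, \theta),$$

where $F_{\text{state}}$ is the transition function in (1), $s_{t+1} = F_{\text{state}}(x_{t+1}, s_t, \theta)$: the first term carries the past effect and the second is the direct effect.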

A major drawback of RTRL is that $\partial s_t/\partial\theta$ is of size $\dim(\text{state})\times\dim(\text{params})$. For instance, with a fully connected standard recurrent network with $n$ units, $\partial s_t/\partial\theta$ scales as $n^3$, since the recurrent weights alone contribute $n^2$ parameters. This makes RTRL impractical for reasonably sized networks.
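The RTRL recursion can be sketched on a hypothetical two-unit model $s_{t+1} = \tanh(W s_t + x_{t+1})$ with $\theta = W$ flattened row-major; this model and the function names are illustrative assumptions. The Jacobian `J` has shape $n \times n^2$, and updating it costs $O(n^4)$ operations per step in this dense form, which is what rules RTRL out at scale. A central finite difference checks the result.

```python
import math

# Hypothetical toy RNN: s_{t+1} = tanh(W s_t + x_{t+1}), with theta = W (n x n)
# flattened row-major, so d s_t / d theta is an n x n^2 matrix (n^3 entries).


def step(W, s, x):
    n = len(s)
    return [math.tanh(sum(W[i][k] * s[k] for k in range(n)) + x[i]) for i in range(n)]


def rtrl(W, xs, s0):
    """Run the model forward, propagating J_t = d s_t / d theta alongside the state."""
    n = len(s0)
    s = list(s0)
    J = [[0.0] * (n * n) for _ in range(n)]  # d s_0 / d theta = 0
    for x in xs:
        s_next = step(W, s, x)
        d = [1.0 - v * v for v in s_next]  # tanh derivative at the pre-activation
        J = [[d[i] * (sum(W[i][k] * J[k][q] for k in range(n))    # past effect
                      + (s[q % n] if q // n == i else 0.0))        # direct effect
              for q in range(n * n)]
             for i in range(n)]
        s = s_next
    return s, J


def finite_difference(W, xs, s0, h=1e-6):
    """Central-difference approximation of d s_T / d theta, for checking RTRL."""
    n = len(s0)
    J = [[0.0] * (n * n) for _ in range(n)]
    for q in range(n * n):
        Wp = [row[:] for row in W]
        Wm = [row[:] for row in W]
        Wp[q // n][q % n] += h
        Wm[q // n][q % n] -= h
        sp, sm = rtrl(Wp, xs, s0)[0], rtrl(Wm, xs, s0)[0]
        for i in range(n):
            J[i][q] = (sp[i] - sm[i]) / (2 * h)
    return J


W = [[0.5, -0.3], [0.2, 0.4]]
xs = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.2], [0.05, 0.1]]
s_T, J = rtrl(W, xs, [0.0, 0.0])
J_fd = finite_difference(W, xs, [0.0, 0.0])
print(max(abs(a - b) for ra, rb in zip(J, J_fd) for a, b in zip(ra, rb)))
```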

3.2 Rank-one trick: from RTRL to UORO

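UORO maintains a rank-one, unbiased estimate $\tilde s_t \tilde\theta_t^\top$ of $\partial s_t/\partial\theta$. Plugging this estimate into the RTRL recursion keeps the past-effect term rank-one, since it is a matrix times a rank-one matrix. However, the direct-effect term (the derivative of the transition function with respect to $\theta$, at fixed $s_t$) is in general not rank-one, so the resulting estimate of $\partial s_{t+1}/\partial\theta$ is no longer rank-one either. The rank-one trick below restores this property without introducing bias.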

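Let $A$ be a real matrix that decomposes as

$$A = \sum_{i=1}^{k} v_i w_i^\top.$$

Let $\nu$ be a vector of $k$ independent random signs, and $\rho$ a vector of $k$ positive numbers. Consider the rank-one matrix

$$\tilde A = \Big(\sum_{i=1}^{k} \rho_i \nu_i v_i\Big)\Big(\sum_{i=1}^{k} \frac{\nu_i w_i}{\rho_i}\Big)^\top.$$

Then $\tilde A$ is an unbiased estimate of $A$: $\mathbb{E}_\nu \tilde A = A$, since $\mathbb{E}[\nu_i \nu_j] = \delta_{ij}$.

The rank-one trick can be applied for any $\rho$. The choice of $\rho$ influences the variance of the approximation; choosing

$$\rho_i = \sqrt{\frac{\|w_i\|}{\|v_i\|}}$$

minimizes the variance $\mathbb{E}_\nu \|\tilde A - A\|^2$ (Frobenius norm).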

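This results in a rank-two, unbiased estimate of $\partial s_{t+1}/\partial\theta$ by substituting (10) into (6). Applying the rank-one trick a second time to this sum of two rank-one terms $v_0 w_0^\top + v_1 w_1^\top$ yields a rank-one, still unbiased, estimate; choosing its coefficients by the same norm-ratio rule, $\rho_i = \sqrt{\|w_i\|/\|v_i\|}$ for $i \in \{0, 1\}$, minimizes variance of the second reduction.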
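The rank-one trick can be checked exactly for small $k$: since $\nu$ takes only $2^k$ values, the expectation and the variance $\mathbb{E}\|\tilde A - A\|_F^2$ can be computed by enumeration rather than sampling. The following pure-Python sketch (function names are hypothetical) verifies unbiasedness and compares the norm-ratio choice of $\rho$ with $\rho = 1$.

```python
import itertools
import math


def rank_one_estimate(vs, ws, rho, nu):
    """(sum_i rho_i nu_i v_i)(sum_i nu_i w_i / rho_i)^T for one sign vector nu."""
    left = [sum(r * s * v[a] for r, s, v in zip(rho, nu, vs)) for a in range(len(vs[0]))]
    right = [sum(s * w[b] / r for r, s, w in zip(rho, nu, ws)) for b in range(len(ws[0]))]
    return [[x * y for y in right] for x in left]


def norm(u):
    return math.sqrt(sum(c * c for c in u))


def optimal_rho(vs, ws):
    """rho_i = sqrt(||w_i|| / ||v_i||), the variance-minimizing choice."""
    return [math.sqrt(norm(w) / norm(v)) for v, w in zip(vs, ws)]


def exact_moments(vs, ws, rho):
    """Exact E[A_tilde] and E||A_tilde - A||_F^2, enumerating all 2^k sign vectors."""
    rows, cols = len(vs[0]), len(ws[0])
    A = [[sum(v[a] * w[b] for v, w in zip(vs, ws)) for b in range(cols)] for a in range(rows)]
    signs = list(itertools.product([-1.0, 1.0], repeat=len(vs)))
    mean = [[0.0] * cols for _ in range(rows)]
    var = 0.0
    for nu in signs:
        At = rank_one_estimate(vs, ws, rho, nu)
        for a in range(rows):
            for b in range(cols):
                mean[a][b] += At[a][b] / len(signs)
                var += (At[a][b] - A[a][b]) ** 2 / len(signs)
    return A, mean, var


vs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [3.0, 1.0, 1.0]]
ws = [[2.0, 1.0], [0.0, 4.0], [0.1, -0.2]]
A, mean, var_opt = exact_moments(vs, ws, optimal_rho(vs, ws))
_, _, var_ones = exact_moments(vs, ws, [1.0, 1.0, 1.0])
print(f"variance with optimal rho: {var_opt:.3f}, with rho = 1: {var_ones:.3f}")
```

Only the $i \neq j$ cross terms contribute to the variance, and only their $\|v_i\|^2\|w_j\|^2\rho_i^2/\rho_j^2$ part depends on $\rho$, which is why rescaling each pair $(v_i, w_i)$ to balanced norms is optimal.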

It remains to show that these update rules can be implemented scalably.

3.3 Implementation

Implementing UORO requires maintaining the rank-one approximation and the corresponding estimate of the loss gradient.
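Because the update equations themselves are not reproduced in this excerpt, the following is a sketch under stated assumptions rather than the authors' implementation. It uses a hypothetical two-unit model $s_{t+1} = \tanh(W s_t + x_{t+1})$ with $\theta = W$, applies the rank-one trick to $\partial F/\partial\theta$ through the projection $\nu^\top \partial F/\partial\theta$ with a random sign vector $\nu$, then performs the second rank-one reduction with norm-ratio coefficients $\rho_0, \rho_1$; a small `eps`, also an assumption, avoids division by zero. For such a tiny model every sign sequence can be enumerated, so the exact expectation of $\tilde s_T\tilde\theta_T^\top$ can be compared with RTRL. The dense `dF_dth` matrix is built here only for clarity; a scalable implementation would typically obtain $\nu^\top \partial F/\partial\theta$ as one vector-Jacobian product and $(\partial F/\partial s)\,\tilde s$ as one Jacobian-vector product.

```python
import itertools
import math
import random

# Hypothetical toy RNN (an illustrative assumption, not the paper's model):
#   s_{t+1} = tanh(W s_t + x_{t+1}),  theta = W flattened row-major (n*n entries).


def norm(u):
    return math.sqrt(sum(c * c for c in u))


def step_and_jacobians(W, s, x):
    """Return s_{t+1}, dF/ds (n x n) and dF/dtheta (n x n*n) at (x, s_t, W)."""
    n = len(s)
    s_next = [math.tanh(sum(W[i][k] * s[k] for k in range(n)) + x[i]) for i in range(n)]
    d = [1.0 - v * v for v in s_next]  # tanh derivative at the pre-activation
    dF_ds = [[d[i] * W[i][k] for k in range(n)] for i in range(n)]
    dF_dth = [[d[i] * s[q % n] if q // n == i else 0.0 for q in range(n * n)]
              for i in range(n)]
    return s_next, dF_ds, dF_dth


def uoro_step(W, s, s_tilde, th_tilde, x, nu, eps=1e-7):
    """Update the rank-one estimate s_tilde th_tilde^T of d s / d theta."""
    s_next, dF_ds, dF_dth = step_and_jacobians(W, s, x)
    n, p = len(s), len(th_tilde)
    js = [sum(dF_ds[i][k] * s_tilde[k] for k in range(n)) for i in range(n)]
    nu_dth = [sum(nu[i] * dF_dth[i][q] for i in range(n)) for q in range(p)]
    # Second rank-one reduction of the two rank-one terms, norm-ratio coefficients.
    rho0 = math.sqrt((norm(th_tilde) + eps) / (norm(js) + eps))
    rho1 = math.sqrt((norm(nu_dth) + eps) / (norm(nu) + eps))
    new_s = [rho0 * js[i] + rho1 * nu[i] for i in range(n)]
    new_th = [th_tilde[q] / rho0 + nu_dth[q] / rho1 for q in range(p)]
    return s_next, new_s, new_th


def uoro_gradient(dloss_ds, s_tilde, th_tilde):
    """Estimate of d loss / d theta, assuming the loss depends on theta only via s."""
    c = sum(a * b for a, b in zip(dloss_ds, s_tilde))
    return [c * v for v in th_tilde]


def rtrl_jacobian(W, xs):
    """Exact d s_T / d theta via the RTRL recursion, for comparison."""
    n = len(W)
    s, J = [0.0] * n, [[0.0] * (n * n) for _ in range(n)]
    for x in xs:
        s_next, dF_ds, dF_dth = step_and_jacobians(W, s, x)
        J = [[sum(dF_ds[i][k] * J[k][q] for k in range(n)) + dF_dth[i][q]
              for q in range(n * n)] for i in range(n)]
        s = s_next
    return J


def expected_uoro_jacobian(W, xs):
    """Exact E[s_tilde th_tilde^T], averaging over every sequence of sign vectors."""
    n = len(W)
    signs = list(itertools.product([-1.0, 1.0], repeat=n))
    runs = list(itertools.product(signs, repeat=len(xs)))
    mean = [[0.0] * (n * n) for _ in range(n)]
    for nus in runs:
        s, st, tt = [0.0] * n, [0.0] * n, [0.0] * (n * n)
        for x, nu in zip(xs, nus):
            s, st, tt = uoro_step(W, s, st, tt, x, nu)
        for i in range(n):
            for q in range(n * n):
                mean[i][q] += st[i] * tt[q] / len(runs)
    return mean


W = [[0.5, -0.3], [0.2, 0.4]]
xs = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.2]]
# Online use: draw fresh random signs at each step, never storing the past.
rng = random.Random(0)
s, st, tt = [0.0, 0.0], [0.0, 0.0], [0.0] * 4
for x in xs:
    nu = [rng.choice([-1.0, 1.0]) for _ in range(2)]
    s, st, tt = uoro_step(W, s, st, tt, x, nu)
grad_estimate = uoro_gradient([1.0, -1.0], st, tt)
```

The cross terms vanish in expectation because $\nu$ is independent of the past and $\rho_1$ is unchanged when $\nu$ flips sign, while $\rho_0$ cancels exactly between $\tilde s$ and $\tilde\theta$; the enumeration makes this unbiasedness check deterministic.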

3.4 Memory-T UORO and rank-k UORO

The unbiased gradient estimates of UORO come at the price of noise injection via $\nu$, which requires smaller learning rates. To reduce noise, UORO can also be used on top of truncated BPTT, so that recent gradients are computed exactly.

The resulting algorithm is referred to as memory-$T$ UORO. Its scaling in $T$ is similar to that of $T$-truncated BPTT, both in terms of memory and computation. In the experiments below, memory-$T$ UORO reduced variance early on, but did not significantly impact later performance.

Experiments illustrating truncation bias

The experiments below display specific cases where the biases of truncated BPTT are likely to prevent learning from converging. On these test cases, UORO's unbiasedness provides steady convergence, highlighting the importance of unbiased estimates for general recurrent learning.

The first test case exemplifies learning of a scalar parameter $\theta$ that has a positive influence in the short term but a negative one in the long run. The short-sightedness of truncated algorithms results in abrupt failure, with the parameter exploding in the wrong direction, even with truncation lengths exceeding the temporal dependency range by a factor of $10$ or so.

Learning is performed online with vanilla SGD, using gradient estimates from either UORO or $T$-truncated BPTT with various $T$. Learning rates are of the form $\eta_t = \frac{\eta}{1+\sqrt{t}}$ for suitable values of $\eta$.

As shown in Fig. 1(a), UORO solves the problem while $T$-truncated BPTT fails to converge for any learning rate, even for truncations $T$ well above $n$. Failure is caused by poor balancing of time dependencies: due to truncation, the influence of $\theta$ on the loss is estimated with the wrong sign. For $n = 23$ units, with $13$ minus signs, truncated BPTT requires a truncation $T \geq 200$ to converge.

Next-character prediction.

The next experiment is character-level synthetic text prediction: the goal is to train a recurrent model to predict the $(t+1)$-th character of a text given the first $t$ characters, online, with a single pass over the data sequence.

A single layer of $64$ units, either GRU or LSTM, is used to output a probability vector for the next character. The cross-entropy criterion is used to compute the loss.

Optimization was performed using Adam with the default settings $\beta_1 = 0.9$ and $\beta_2 = 0.999$, and a decreasing learning rate $\eta_t = \frac{\gamma}{1+\alpha\sqrt{t}}$, with $t$ the number of characters processed. As convergence of UORO requires smaller learning rates than truncated BPTT, this favors UORO. Indeed, UORO can fail to converge with non-decreasing learning rates, due to its stochastic nature.
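The decreasing learning-rate schedule is a one-line function; the sketch below (the name `learning_rate` is hypothetical) also recovers the SGD schedule of the first experiment as the special case $\alpha = 1$, $\gamma = \eta$. The example values $\alpha = 0.015$, $\gamma = 10^{-3}$ are those reported for the distant brackets experiment.

```python
import math


def learning_rate(t, gamma, alpha=1.0):
    """eta_t = gamma / (1 + alpha * sqrt(t)), with t the number of characters processed.

    With alpha = 1 and gamma = eta this is the SGD schedule eta / (1 + sqrt(t)).
    """
    return gamma / (1.0 + alpha * math.sqrt(t))


# Distant brackets settings: alpha = 0.015, gamma = 1e-3.
for t in (0, 100, 10_000, 1_000_000):
    print(t, learning_rate(t, gamma=1e-3, alpha=0.015))
```

Because of the $\sqrt{t}$ term, the rate decays slowly: with $\alpha = 0.015$ it has only dropped by a factor $2.5$ after $10^4$ characters.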

The distant brackets dataset is generated by repeatedly outputting a left bracket, generating $s$ random characters from an alphabet of size $a$, outputting a right bracket, generating $k$ random characters from the same alphabet, repeating the same first $s$ characters between brackets, and finally outputting a line break. A sample is shown in Fig. 2(b).
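A minimal generator for the distant brackets data, following the description above. The alphabet (the first $a$ lowercase letters, so $a \le 26$), the use of Python's `random.Random`, and the function name are assumptions for illustration; the paper's actual generator may differ.

```python
import random
import string


def distant_brackets(num_lines, s, k, a, seed=0):
    """Each line: "(" + w + ")" + k noise chars + "(" + w + ")" + newline, |w| = s.

    Assumes the alphabet is the first `a` lowercase letters (a <= 26).
    """
    rng = random.Random(seed)
    alphabet = string.ascii_lowercase[:a]
    lines = []
    for _ in range(num_lines):
        word = "".join(rng.choice(alphabet) for _ in range(s))   # to be remembered
        noise = "".join(rng.choice(alphabet) for _ in range(k))  # distractor span
        lines.append("(" + word + ")" + noise + "(" + word + ")\n")
    return "".join(lines)


print(distant_brackets(3, s=5, k=10, a=10, seed=0), end="")
```

Only the $s$ characters after the second left bracket are predictable, and only from characters $k + s + 2$ positions back, which is what makes a short truncation window insufficient.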

UORO is compared to $4$-truncated BPTT. Truncation is deliberately shorter than the inherent time range of the data, to illustrate how bias can penalize learning if the inherent time range is unknown a priori. The results are given in Fig. 1(b) (with learning rates using $\alpha = 0.015$ and $\gamma = 10^{-3}$). UORO beats $4$-truncated BPTT in the long run, and succeeds in reaching near-optimal behavior with both GRUs and LSTMs. Truncated BPTT remains stuck near a memoryless optimum with LSTMs; with GRUs it keeps learning, but at a slow rate. Still, truncated BPTT displays faster early convergence.

The $a^nb^n(k,l)$ dataset tests memory and counting [GS01]; it is generated by repeatedly picking a random number $n$ between $k$ and $l$, and outputting a string of $n$ $a$'s, a line break, $n$ $b$'s, and a line break (see Fig. 2(c)). The difficulty lies in matching the number of $a$'s and $b$'s.
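A corresponding sketch for $a^nb^n(k,l)$, assuming $n$ is drawn uniformly from $\{k, \dots, l\}$ with both ends included (the text only says "between $k$ and $l$"); the function name is hypothetical.

```python
import random


def anbn(num_blocks, k, l, seed=0):
    """Sketch of an a^n b^n(k, l) generator: n a's, newline, n b's, newline per block."""
    rng = random.Random(seed)
    blocks = []
    for _ in range(num_blocks):
        n = rng.randint(k, l)  # uniform on {k, ..., l}, both ends included
        blocks.append("a" * n + "\n" + "b" * n + "\n")
    return "".join(blocks)


print(anbn(3, 1, 5, seed=0), end="")
```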

Plots for a few setups are given in Fig. 3. The learning rates use $\alpha = 0.03$ and $\gamma = 10^{-3}$.

Numerical results at the end of training are given in Table 1. For reference, the true entropy rate is $0.14$ bits per character (bpc), while the entropy rate of a model that does not understand that the numbers of $a$'s and $b$'s coincide would be double, $0.28$ bpc.
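Under the same uniform assumption on $n$, the true entropy rate follows from a short computation: only the choice of $n$ is random, contributing $\log_2(l-k+1)$ bits per block of $2n+2$ characters, so the rate is the ratio of expectations $\log_2(l-k+1)/(k+l+2)$. A model that treated the length of the $b$ run as a second independent draw would pay those bits twice, consistent with the doubled reference figure. The values $k=1$, $l=32$ below are a hypothetical example, not necessarily the settings of Table 1.

```python
import math


def anbn_entropy_rate(k, l):
    """True entropy rate (bits per character) of a^n b^n(k, l), n uniform on {k, ..., l}.

    Each block "a"*n + "\n" + "b"*n + "\n" carries log2(l - k + 1) bits (only n is
    random) and has 2n + 2 characters, so the rate is a ratio of expectations.
    """
    bits_per_block = math.log2(l - k + 1)
    expected_chars_per_block = k + l + 2  # E[2n + 2] with E[n] = (k + l) / 2
    return bits_per_block / expected_chars_per_block


print(anbn_entropy_rate(1, 32))
```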

Here, in every setup, UORO reliably converges and reaches near-optimal performance. Increasing UORO's range does not significantly improve results: providing an unbiased estimate is enough to ensure reliable convergence in this case. Meanwhile, truncated BPTT performs inconsistently. Notably, with GRUs, it either converges to a poor local optimum corresponding to no understanding of the temporal structure, or exhibits gradient reascent in the long run. Remarkably, with LSTMs rather than GRUs, $16$-truncated BPTT reliably reaches optimal behavior on this problem even with biased gradient estimates.

Conclusion

We introduced UORO, an algorithm for training recurrent neural networks in a streaming, memoryless fashion. UORO is easy to implement, and requires as little computation time as truncated BPTT, at the cost of noise injection. Importantly, contrary to most other approaches, UORO scalably provides unbiasedness of gradient estimates. Unbiasedness is of paramount importance in the current theory of stochastic gradient descent.

Furthermore, UORO is experimentally shown to benefit from its unbiasedness, converging even in cases where truncated BPTT fails to reliably achieve good results or diverges pathologically.

References