Unbiased Online Recurrent Optimization
Corentin Tallec, Yann Ollivier
Related work
A widespread approach to online learning of recurrent neural networks is Truncated Backpropagation Through Time (truncated BPTT) [Jae02], which mimics Backpropagation Through Time but zeroes gradient flows after a fixed number of timesteps. This truncation makes gradient estimates biased; consequently, truncated BPTT does not provide any convergence guarantee. Arguably, truncated BPTT might still learn some dependencies beyond its truncation range, by a mechanism similar to Echo State Networks [Jae02]. However, its gradient estimate has a marked bias towards short-term rather than long-term dependencies, as shown in the first experiment of Section 4. Storage of some past inputs and states is also required.
Online, exact gradient computation methods have long been known (Real Time Recurrent Learning (RTRL) [WZ89, Pea95]), but their computational cost rules them out for reasonably sized networks.
NoBackTrack (NBT) [OTC15] also provides unbiased gradient estimates for recurrent neural networks. However, contrary to UORO, NBT cannot be applied in a blackbox fashion, making it extremely tedious to implement for complex architectures.
Other previous attempts to introduce generic online learning algorithms with a reasonable computational cost all result in biased gradient estimates. Echo State Networks (ESNs) [Jae02, JLPS07] simply set the gradients of recurrent parameters to zero. Others, e.g., [MNM02, Ste04], introduce approaches resembling ESNs, but keep a partial estimate of the recurrent gradients. The original Long Short Term Memory algorithm [HS97] (LSTM now refers to a particular architecture) cuts gradient flows going out of gating units to make gradient computation tractable. Decoupled Neural Interfaces [JCO+16] bootstrap truncated gradient estimates using synthetic gradients generated by feedforward neural networks. The algorithm in [MMW02] provides zeroth-order estimates of recurrent gradients via diffusion networks; it could arguably be turned online by running randomized alternative trajectories. Generally these approaches lack a strong theoretical backing, except arguably ESNs.
Background
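UORO is a learning algorithm for recurrent computational graphs. Formally, the aim is to optimize $\theta$, a parameter controlling the evolution of a dynamical system
$$s_{t+1} = F_{\text{state}}(x_{t+1}, s_t, \theta), \qquad o_{t+1} = F_{\text{out}}(x_{t+1}, s_{t+1}, \theta), \tag{1}$$
where $s_t$ is the state, $x_t$ the input and $o_t$ the output at time $t$, in order to minimize the total loss incurred by the outputs.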
Unbiased Online Recurrent Optimization
Unbiased Online Recurrent Optimization is built on top of a forward computation of the gradients, rather than backpropagation. Forward gradient computation for neural networks (RTRL) is described in [WZ89] and we review it in Section 3.1. The derivation of UORO follows in Section 3.2. Implementation details are given in Section 3.3. UORO’s derivation is strongly connected to [OTC15] but differs in one critical aspect: the sparsity hypothesis made in the latter is lifted, resulting in reduced implementation complexity without any model restriction.
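Forward computation of the gradient for a recurrent model (RTRL) is directly obtained by applying the chain rule to both the loss function and the state equation (1). Writing $F_{\text{state}}$ and $F_{\text{out}}$ for the transition and output functions of (1), and $\ell_{t+1}$ for the loss at time $t+1$, the chain rule on the loss gives
$$\frac{\partial \ell_{t+1}}{\partial \theta} = \frac{\partial \ell_{t+1}}{\partial o}(o_{t+1}) \cdot \left( \frac{\partial F_{\text{out}}}{\partial s}(x_{t+1}, s_{t+1}, \theta)\, \frac{\partial s_{t+1}}{\partial \theta} + \frac{\partial F_{\text{out}}}{\partial \theta}(x_{t+1}, s_{t+1}, \theta) \right).$$
Here, the term $\partial s_t / \partial \theta$ represents the effect on the state at time $t$ of a change of parameter $\theta$ during the whole past trajectory. This term can be computed inductively from time $t$ to $t+1$. Intuitively, looking at the update equation (1), there are two contributions to $\partial s_{t+1} / \partial \theta$:
The direct effect of a change of $\theta$ on the computation of $s_{t+1}$, given $s_t$.
The past effect of $\theta$ on $s_t$ via the whole past trajectory.
With this in mind, differentiating (1) with respect to $\theta$ yields
$$\frac{\partial s_{t+1}}{\partial \theta} = \frac{\partial F_{\text{state}}}{\partial \theta}(x_{t+1}, s_t, \theta) + \frac{\partial F_{\text{state}}}{\partial s}(x_{t+1}, s_t, \theta)\, \frac{\partial s_t}{\partial \theta}.$$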
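A huge disadvantage of RTRL is that $\partial s_t / \partial \theta$ is of size $\dim(\text{state}) \times \dim(\text{params})$. For instance, with a fully connected standard recurrent network with $n$ units, $\partial s_t / \partial \theta$ scales as $n^3$. This makes RTRL impractical for reasonably sized networks.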
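As a rough illustration of the forward recursion (direct effect of the parameter plus the past effect propagated through the state Jacobian), the sketch below runs RTRL on a small tanh recurrent network with NumPy. The model form `s_next = tanh(W @ s + U @ x)`, the restriction of the parameter to the recurrent matrix `W`, and the helper name `rtrl_step` are assumptions made for this example only.

```python
import numpy as np

def rtrl_step(W, U, s, x, G):
    """One RTRL step for the assumed model s_next = tanh(W @ s + U @ x).

    G holds ds/dW flattened to shape (n, n * n); only W is treated as the
    parameter here to keep the sketch short (an assumption of this example).
    """
    n = s.shape[0]
    s_next = np.tanh(W @ s + U @ x)
    d = 1.0 - s_next ** 2                       # derivative of tanh
    # Direct effect: d s_next[i] / d W[j, k] = d[i] * (i == j) * s[k]
    direct = np.einsum("i,ij,k->ijk", d, np.eye(n), s).reshape(n, n * n)
    # Past effect: (d s_next / d s) @ (d s / d W), with d s_next / d s = diag(d) @ W
    past = (d[:, None] * W) @ G
    return s_next, direct + past

rng = np.random.default_rng(0)
n, m = 4, 2
W = 0.5 * rng.standard_normal((n, n))
U = rng.standard_normal((n, m))
s, G = np.zeros(n), np.zeros((n, n * n))
for x in rng.standard_normal((10, m)):
    s, G = rtrl_step(W, U, s, x, G)
print(G.shape)  # (4, 16): n rows times n^2 parameters, hence the n^3 scaling
```

Even in this toy setting the stored matrix has one row per state unit and one column per parameter, which is where the cubic memory cost comes from.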
3.2 Rank-one trick: from RTRL to UORO
However, in general this is no longer rank-one.
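Let $A$ be a real matrix that decomposes as
$$A = \sum_{i=1}^{k} v_i w_i^\top.$$
Let $\nu$ be a vector of $k$ independent random signs, and $\rho$ a vector of $k$ positive numbers. Consider the rank-one matrix
$$\tilde{A} = \left( \sum_{i=1}^{k} \rho_i \nu_i v_i \right) \left( \sum_{i=1}^{k} \frac{\nu_i w_i}{\rho_i} \right)^{\top}.$$
Then $\tilde{A}$ is an unbiased rank-one approximation of $A$, since $\mathbb{E}[\nu_i \nu_j]$ equals $1$ if $i = j$ and $0$ otherwise.
The rank-one trick can be applied for any $\rho$. The choice of $\rho$ influences the variance of the approximation; choosing
$$\rho_i = \sqrt{\|w_i\| / \|v_i\|}$$
minimizes the variance of the approximation.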
This results in a rank-two, unbiased estimate of $\partial s_{t+1} / \partial \theta$ by substituting (10) into (6)
minimizes variance of the second reduction.
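The rank-one trick described in this section can be checked numerically. The sketch below, which assumes NumPy and uses the hypothetical helper name `rank_one_estimate`, builds a matrix as a sum of `k` outer products and averages the rank-one estimate over all `2^k` sign vectors, which gives the exact expectation rather than a Monte Carlo approximation. The scaling `rho` used here is one variance-reducing choice; any positive `rho` keeps the estimate unbiased.

```python
import itertools
import numpy as np

def rank_one_estimate(vs, ws, signs, rho):
    """Rank-one estimate (sum_i rho_i nu_i v_i)(sum_i nu_i w_i / rho_i)^T."""
    left = (rho * signs) @ vs           # shape (p,)
    right = (signs / rho) @ ws          # shape (q,)
    return np.outer(left, right)

rng = np.random.default_rng(0)
k, p, q = 3, 4, 5                       # illustrative sizes (assumptions)
vs = rng.standard_normal((k, p))        # rows are v_1, ..., v_k
ws = rng.standard_normal((k, q))        # rows are w_1, ..., w_k
A = vs.T @ ws                           # A = sum_i v_i w_i^T

# Scaling rho_i = sqrt(||w_i|| / ||v_i||), chosen to reduce variance.
rho = np.sqrt(np.linalg.norm(ws, axis=1) / np.linalg.norm(vs, axis=1))

# Averaging over every sign vector yields the exact expectation.
all_signs = [np.array(s) for s in itertools.product([-1.0, 1.0], repeat=k)]
mean_estimate = sum(rank_one_estimate(vs, ws, s, rho) for s in all_signs) / len(all_signs)

print(np.allclose(mean_estimate, A))    # True: cross terms cancel on average
```

Each individual estimate is rank-one and noisy; only the average over signs recovers the full matrix.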
It remains to show that these update rules can be implemented in a scalable way.
3.3 Implementation
Implementing UORO requires maintaining the rank-one approximation and the corresponding loss gradient estimate.
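One possible way to maintain the rank-one approximation of the state derivative is sketched below, under several assumptions: the Jacobians `J_s` (state with respect to previous state) and `J_theta` (state with respect to parameters) are given explicitly, the function name `uoro_step` is hypothetical, and the scaling factors `rho0`, `rho1` and the constant `eps` are choices made for this example. The sketch combines the forward recursion with two applications of the rank-one trick and covers only the state part of the bookkeeping, not the loss gradient.

```python
import itertools
import numpy as np

def uoro_step(J_s, J_theta, s_tilde, theta_tilde, nu, eps=1e-7):
    """Update a rank-one estimate s_tilde theta_tilde^T of ds/dtheta.

    J_s is dF_state/ds (n x n), J_theta is dF_state/dtheta (n x p) and
    nu is a vector of n random signs. Explicit Jacobians are an
    assumption made for readability only.
    """
    fwd = J_s @ s_tilde                 # past effect pushed one step forward
    back = nu @ J_theta                 # nu^T dF_state/dtheta
    # Norm-balancing factors (assumed form); eps guards against division by zero.
    rho0 = np.sqrt(np.linalg.norm(theta_tilde) / (np.linalg.norm(fwd) + eps)) + eps
    rho1 = np.sqrt(np.linalg.norm(back) / (np.linalg.norm(nu) + eps)) + eps
    s_new = rho0 * fwd + rho1 * nu
    theta_new = theta_tilde / rho0 + back / rho1
    return s_new, theta_new

# Exact check on a tiny example: average over all sign vectors.
rng = np.random.default_rng(1)
n, p = 3, 4
J_s, J_theta = rng.standard_normal((n, n)), rng.standard_normal((n, p))
s_tilde, theta_tilde = rng.standard_normal(n), rng.standard_normal(p)
signs = [np.array(s) for s in itertools.product([-1.0, 1.0], repeat=n)]
mean = sum(np.outer(*uoro_step(J_s, J_theta, s_tilde, theta_tilde, nu))
           for nu in signs) / len(signs)
target = J_s @ np.outer(s_tilde, theta_tilde) + J_theta
print(np.allclose(mean, target))        # True
```

In a scalable implementation the products `J_s @ s_tilde` and `nu @ J_theta` would typically be obtained as Jacobian-vector and vector-Jacobian products from automatic differentiation, so the full Jacobians never need to be stored.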
3.4 Memory-T UORO and rank-k UORO
The unbiased gradient estimates of UORO come at the price of noise injection via the random signs $\nu$. This requires smaller learning rates. To reduce noise, UORO can also be used on top of truncated BPTT so that recent gradients are computed exactly.
The resulting algorithm is referred to as memory-$T$ UORO. Its scaling in $T$ is similar to $T$-truncated BPTT, both in terms of memory and computation. In the experiments below, memory-$T$ UORO reduced variance early on, but did not significantly impact later performance.
Experiments illustrating truncation bias
The set of experiments below aims at displaying specific cases where the biases from truncated BPTT are likely to prevent convergence of learning. On this test set, UORO’s unbiasedness provides steady convergence, highlighting the importance of unbiased estimates for general recurrent learning.
The first test case exemplifies learning of a scalar parameter which has a positive influence in the short term, but a negative one in the long run. Short-sightedness of truncated algorithms results in abrupt failure, with the parameter exploding in the wrong direction, even with truncation lengths exceeding the temporal dependency range by a factor of or so.
Learning is performed online with vanilla SGD, using gradient estimates either from UORO or $T$-truncated BPTT with various $T$. Learning rates are of the form for suitable values of .
As shown in Fig. 1(a), UORO solves the problem while -truncated BPTT fails to converge for any learning rate, even for truncations largely above . Failure is caused by ill balancing of time dependencies: the influence of on the loss is estimated with the wrong sign due to truncation. For units, with minus signs, truncated BPTT requires a truncation to converge.
Next-character prediction.
The next experiment is character-level synthetic text prediction: the goal is to train a recurrent model to predict the $(t+1)$-th character of a text given the first $t$ online, with a single pass on the data sequence.
A single layer of units, either GRU or LSTM, is used to output a probability vector for the next character. The cross entropy criterion is used to compute the loss.
Optimization was performed using Adam with the default setting $\beta_1 = 0.9$ and $\beta_2 = 0.999$, and a decreasing learning rate , with $t$ the number of characters processed. As convergence of UORO requires smaller learning rates than truncated BPTT, this favors UORO. Indeed UORO can fail to converge with non-decreasing learning rates, due to its stochastic nature.
The distant brackets dataset is generated by repeatedly outputting a left bracket, generating random characters from an alphabet of size , outputting a right bracket, generating random characters from the same alphabet, repeating the same first characters between brackets and finally outputting a line break. A sample is shown in Fig. 2(b).
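The generation procedure above can be sketched as follows, reading the description as producing lines of the form `[x]y[x]` where `x` and `y` are random strings. The lengths `s` and `k`, the alphabet, and the function names are placeholder assumptions, not the values used in the experiments.

```python
import random
import string

def distant_brackets_line(rng, s=5, k=5, alphabet=string.ascii_lowercase[:10]):
    """One line of the form [x]y[x] followed by a line break.

    x has length s and y has length k; s, k and the alphabet are
    placeholder values chosen for this sketch.
    """
    x = "".join(rng.choice(alphabet) for _ in range(s))
    y = "".join(rng.choice(alphabet) for _ in range(k))
    return f"[{x}]{y}[{x}]\n"

def distant_brackets(n_lines, seed=0, **kwargs):
    """Concatenate n_lines independent lines into one character stream."""
    rng = random.Random(seed)
    return "".join(distant_brackets_line(rng, **kwargs) for _ in range(n_lines))

print(distant_brackets(2), end="")
```

Predicting the second bracketed string requires remembering the first one across the `k` intervening characters, which is what makes short truncations problematic on this task.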
UORO is compared to -truncated BPTT. Truncation is deliberately shorter than the inherent time range of the data, to illustrate how bias can penalize learning if the inherent time range is unknown a priori. The results are given in Fig. 1(b) (with learning rates using and ). UORO beats -truncated BPTT in the long run, and succeeds in reaching near optimal behaviour both with GRUs and LSTMs. Truncated BPTT remains stuck near a memoryless optimum with LSTMs; with GRUs it keeps learning, but at a slow rate. Still, truncated BPTT displays faster early convergence.
The $a^n b^n$ dataset tests memory and counting [GS01]; it is generated by repeatedly picking a random number $n$ between and , outputting a string of $n$ a’s, a line break, $n$ b’s, and a line break (see Fig. 2(c)). The difficulty lies in matching the number of a’s and b’s.
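A minimal generator for this kind of counting dataset might look like the sketch below. The two characters `a` and `b`, the range bounds `k` and `l`, and the function name `anbn` are assumptions chosen for illustration.

```python
import random

def anbn(n_blocks, k=1, l=8, seed=0):
    """Return n_blocks blocks, each made of n a's, a line break, n b's, a line break.

    n is drawn uniformly from [k, l] (inclusive) for every block; the
    bounds k and l are placeholder values for this sketch.
    """
    rng = random.Random(seed)
    parts = []
    for _ in range(n_blocks):
        n = rng.randint(k, l)           # randint includes both endpoints
        parts.append("a" * n + "\n" + "b" * n + "\n")
    return "".join(parts)

print(anbn(2), end="")
```

The only unpredictable event in each block is the length of the first run; once it is known, the second run is fully determined.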
Plots for a few setups are given in Fig. 3. The learning rates used and .
Numerical results at the end of training are given in Table 1. For reference, the true entropy rate is bits per character, while the entropy rate of a model that does not understand that the numbers of a’s and b’s coincide would be double, bpc.
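The factor of two between the two entropy rates can be seen with a short calculation. The sketch below assumes that the run length `n` is drawn uniformly from the integers `k` to `l`, so that each block carries `log2(l - k + 1)` bits and has `2n + 2` characters; the values `k = 1` and `l = 32` and the function name are illustrative assumptions.

```python
import math

def anbn_entropy_rates(k, l):
    """Entropy rates in bits per character for a uniform run length in [k, l]."""
    bits_per_block = math.log2(l - k + 1)   # uncertainty of one draw of n
    mean_block_len = (k + l) + 2            # E[2n + 2] with E[n] = (k + l) / 2
    true_rate = bits_per_block / mean_block_len
    # A model that treats the second run as independent pays for n twice.
    return true_rate, 2 * true_rate

true_rate, naive_rate = anbn_entropy_rates(1, 32)
print(round(true_rate, 3), round(naive_rate, 3))  # 0.143 0.286
```

A model that has learned the counting structure only needs to guess where the first run ends, whereas a model that has not must also guess where the second run ends.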
Here, in every setup, UORO reliably converges and reaches near optimal performance. Increasing UORO’s range does not significantly improve results: providing an unbiased estimate is enough to provide reliable convergence in this case. Meanwhile, truncated BPTT performs inconsistently. Notably, with GRUs, it either converges to a poor local optimum corresponding to no understanding of the temporal structure, or exhibits gradient reascent in the long run. Remarkably, with LSTMs rather than GRUs, -truncated BPTT reliably reaches optimal behavior on this problem even with biased gradient estimates.
Conclusion
We introduced UORO, an algorithm for training recurrent neural networks in a streaming, memoryless fashion. UORO is easy to implement, and requires as little computation time as truncated BPTT, at the cost of noise injection. Importantly, contrary to most other approaches, UORO provides unbiased gradient estimates in a scalable way. Unbiasedness is of paramount importance in the current theory of stochastic gradient descent.
Furthermore, UORO is experimentally shown to benefit from its unbiasedness, converging even in cases where truncated BPTT fails to reliably achieve good results or diverges pathologically.