Sparse Attentive Backtracking: Temporal Credit Assignment Through Reminding

Nan Rosemary Ke, Anirudh Goyal, Olexa Bilaniuk, Jonathan Binas, Michael C. Mozer, Chris Pal, Yoshua Bengio

Introduction

Humans have a remarkable ability to remember events from the distant past which are associated with the current mental state (Ciaramelli et al., 2008). Most experimental and theoretical analyses of memory have focused on understanding the deliberate route to memory formation and recall. But automatic reminding—when memories pop into one’s head—can have a potent influence on cognition. Reminding is normally triggered by contextual features present at the moment of retrieval which match distinctive features of the memory being recalled (Berntsen et al., 2013; Wharton et al., 1996), and can occur more often following unexpected events (Read & Cesa, 1991). Thus, an individual’s current state of understanding can trigger reminding of a past state. Reminding can provide distracting sources of irrelevant information (Forbus et al., 1995; Novick, 1988), but it can also serve a useful computational role in ongoing cognition by providing information essential to decision making (Benjamin & Ross, 2010).

In this paper, we identify another possible role of reminding: to perform credit assignment across long time spans. Consider the following scenario. As you drive down the highway, you hear an unusual popping sound. You think nothing of it until you stop for gas and realize that one of your tires has deflated, at which point you are suddenly reminded of the pop. The reminding event helps determine the cause of your flat tire, and probably leads to synaptic changes by which a future pop sound while driving would be processed differently. Credit assignment is critical in machine learning: back-propagation is, fundamentally, a credit-assignment procedure. Although some progress has been made toward credit-assignment mechanisms that are functionally equivalent to back-propagation (Lee et al., 2014; Scellier & Bengio, 2016; Whittington & Bogacz, 2017), it remains very unclear how the equivalent of back-propagation through time, used to train recurrent neural networks (RNNs), could be implemented by brains. Here we explore the hypothesis that an associative reminding process could play an important role in propagating credit across long time spans, also known as the problem of learning long-term dependencies in RNNs, i.e., of learning to exploit statistical dependencies between events and variables which occur temporally far from each other.

RNNs are used to process sequences of variable length. They have achieved state-of-the-art results on many machine learning sequence processing tasks. Examples where models based on RNNs shine include speech recognition (Miao et al., 2015; Chan et al., 2016), image captioning (Vinyals et al., 2015; Lu et al., 2017), and machine translation (Luong et al., 2015).

It is common practice to train RNNs using gradients computed with back-propagation through time (BPTT), wherein the network states are unrolled in time over the whole trajectory of discrete time steps and gradients are back-propagated through the unrolled graph. The network unfolding procedure of BPTT does not seem biologically plausible because it requires storing these events and playing them back much later (at the end of a trajectory of $T$ time steps) in reverse order to propagate gradients backwards. If a discrete time instant corresponds to a saccade (about 200–300 ms), then a trajectory of 100 days would require replaying computations back through roughly 29 to 43 million time steps. This is not only inconvenient but, more importantly, a small error in any one of these events could either vanish or blow up and cause catastrophic outcomes. Also, if this unfolding and back-propagation are done only over shorter sequences, then learning typically will not capture longer-term dependencies linking events across larger temporal spans than the length of the back-propagated trajectory.
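As a quick check of the arithmetic behind this estimate, the sketch below counts whole time steps in 100 days, assuming one step per saccade at either end of the stated 200–300 ms range. Integer milliseconds are used to avoid floating-point rounding.

```python
# Back-of-the-envelope count of discrete time steps in a long trajectory,
# assuming one step per saccade of a fixed duration in milliseconds.
MS_PER_DAY = 24 * 60 * 60 * 1000

def num_steps(days: int, step_ms: int) -> int:
    """Number of whole time steps of length step_ms contained in `days` days."""
    return days * MS_PER_DAY // step_ms

for step_ms in (200, 300):
    print(step_ms, num_steps(100, step_ms))
# 200 43200000
# 300 28800000
```

At 200 ms per step this is about 43.2 million steps; at 300 ms it is about 28.8 million.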

What are the alternatives to BPTT? One approach, which we explore here, exploits associative reminding of past events, which may be triggered by the current state and added to it, making it possible to propagate gradients with respect to the current state into approximate gradients in the state corresponding to the recalled event. The approximation comes from not backpropagating through the unfolded ordinary recurrence across long time spans, but only through this memory retrieval mechanism. Completely different approaches are possible, such as methods based on the online estimation of gradients (Ollivier et al., 2015), but they are not currently close to BPTT in terms of learning performance on large networks. Assuming that no exact gradient estimation method is possible (which seems likely), it could well be that brains combine multiple estimators.

In machine learning, the most common practical alternative to full BPTT is truncated BPTT (TBPTT) (Williams & Peng, 1990). In TBPTT, a long sequence is sliced into a number of (possibly overlapping) subsequences, gradients are backpropagated only for a fixed, limited number of time steps into the past, and the parameters are updated after each backpropagation through a subsequence. Unfortunately, this truncation makes capturing dependencies across distant timesteps nearly impossible, because no error signal reaches further back into the past than TBPTT's truncation length.
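To make the effect of truncation concrete, here is a minimal toy sketch, not the authors' code. It assumes a scalar linear RNN $h_t = w\,h_{t-1} + x_t$ with $h_0 = 0$ and loss $L = h_T$, for which the exact gradient is $\partial L/\partial x_i = w^{T-i}$. The function names and the choice of model are our own illustrative assumptions. With truncation, every input more than $k_{trunc}$ steps before the loss receives exactly zero gradient.

```python
# Toy illustration of truncated BPTT (hypothetical example):
# scalar linear RNN h_t = w * h_{t-1} + x_t, h_0 = 0, loss L = h_T.

def input_gradients(w, T, k_trunc=None):
    """Return [dL/dx_1, ..., dL/dx_T]; if k_trunc is set, stop after k_trunc steps."""
    grads = [0.0] * T
    g = 1.0                                  # dL/dh_T
    steps = T if k_trunc is None else min(k_trunc, T)
    for s in range(steps):
        i = T - 1 - s                        # 0-based index of x_{T-s}
        grads[i] = g                         # dh_{T-s}/dx_{T-s} = 1
        g *= w                               # dh_{T-s}/dh_{T-s-1} = w
    return grads

def split_into_chunks(seq, k_trunc):
    """Non-overlapping TBPTT subsequences; parameters would be updated after each."""
    return [seq[i:i + k_trunc] for i in range(0, len(seq), k_trunc)]

print(input_gradients(0.5, 5))               # [0.0625, 0.125, 0.25, 0.5, 1.0]
print(input_gradients(0.5, 5, k_trunc=2))    # [0.0, 0.0, 0.0, 0.5, 1.0]
print(split_into_chunks(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```

The full gradient also shows the geometric decay $w^{T-i}$ that makes long-range credit assignment hard even without truncation.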

Neurophysiological findings support the existence of memory replay and its involvement in credit assignment and learning in biological systems. In particular, hippocampal recordings in rats indicate that brief sequences of prior experience are replayed both in the awake resting state and during sleep, both of which are conditions linked to memory consolidation and learning (Foster & Wilson, 2006; Davidson et al., 2009; Gupta et al., 2010; Ambrose et al., 2016). Thus, the mental look back into the past seems to occur exactly when credit assignment is to be performed. It is therefore plausible that hippocampal replay could be a way of doing temporal credit assignment (and possibly BPTT) on a short time scale, but here we argue for a solution that could handle credit assignment over much longer durations.

Novel Credit Assignment Mechanism: Sparse Attentive Backtracking

Inspired by the ability of brains to selectively reactivate memories of the past based on the current context, we propose here a novel solution called Sparse Attentive Backtracking (SAB) that incorporates a differentiable, sparse (hard) attention mechanism to select from past states. Following the cognitive analogy of reminding, SAB is designed to retrieve one or very few past states. This may also be advantageous in focusing the credit assignment, although this hypothesis remains to be tested. SAB meshes well with TBPTT, yet allows gradients to propagate over distances far in excess of the TBPTT truncation length. We address the following questions experimentally:

Q1. Can Sparse Attentive Backtracking (SAB) capture long-term dependencies? Yes; see the results on 7 tasks supporting this in §4.

Q2. Does SAB generalize and transfer? See the strong transfer results in §4.

Q3. How does SAB perform compared to the Transformer (Vaswani et al., 2017)? SAB outperforms the Transformer on pixel-by-pixel CIFAR10, although the Transformer is better on permuted MNIST (comparison in §4).

Q4. Is sparsity important for SAB, and does it learn to retrieve meaningful memories? See the results on the importance of sparsity and Table 3 in §4.

Related Machine Learning Work

Neural architectures such as Residual Networks (He et al., 2016) and Dense Networks (Huang et al., 2016) allow information to skip over convolutional processing blocks of an underlying convolutional network architecture. This construction provably mitigates the vanishing gradient problem by allowing the gradient at any given layer to be bounded. Densely-connected convolutional networks alleviate the vanishing gradient problem by allowing a direct path from any layer in the network to the output layer. In contrast, in this work we propose and explore what one might regard as a form of dynamic skip connection, modulated by an attention mechanism corresponding to a reminding process, which matches the current state with an older state which is retrieved from memory.

The Transformer network

The Transformer network (Vaswani et al., 2017) takes sequence processing using attention to its logical extreme – using attention only, not relying on RNNs at all. The attention mechanism is a softmax not over the sequence itself but over the outputs of the previous self-attention layer. In order to attend to multiple parts of the layer outputs simultaneously, the Transformer uses 8 small attention “heads” per layer (instead of a single large head) and combines the attention heads’ outputs by concatenation. No attempt is made to make the attention weights sparse, and the authors do not test their models on sequences of length greater than the intermediate representations of the Transformer model. With brains clearly involving a recurrent computation, this approach would seem to miss an important characteristic of biological credit assignment through time. Another implausible aspect of the Transformer architecture is the simultaneous access to (and linear combination of) all past memories (as opposed to a handful with SAB.)

Sparse Attentive Backtracking

Mindful that humans use a very sparse subset of past experiences in credit assignment, and are capable of direct random access to past experiences and their relevance to the present, we present here SAB: the principle of learned, dynamic, sparse access to, and replay of, relevant past states for credit assignment in neural network models, such as RNNs.

In the limit of maximum sparsity (no access to the past), SAB degenerates to the use of a regular static neural network. In the limit of minimum sparsity (full access to the past), SAB degenerates to the use of a full self-attention mechanism. For the purposes of this paper, we explore the gap between these with a specific variety of augmented LSTM models; but SAB does not refer to any particular architecture, and the augmented LSTM described herein is used purely as a vehicle to explore and validate our hypotheses in §1.

Broadly, an SAB neural network is required to do two things:

During the forward pass, manage a memory unit and select at most a sparse subset of past memories at every timestep. We will call this sparse retrieval.

During the backward pass, propagate gradient only to that sparse subset of memory and its local surroundings. We will call this sparse replay.

Just as humans make selective use of past memories to inform their decisions in the present, so must an SAB model learn to remember, and dynamically select, only a few memories that could be potentially useful in the present. There are several alternative implementations of this concept. An important class of them are attention mechanisms, especially self-attention over a model's own past states. Closely linked to the question of dynamic access to memory is the structure of the memory itself; for instance, in the Differentiable Neural Computer (DNC) (Graves et al., 2016), the memory is a fixed-size tensor accessed with explicit read and write operations, while in Bahdanau et al. (2014), the memory is implicitly a list of past hidden states that continuously grows.

For the purposes of this paper, we choose a simple approach similar to Bahdanau et al. (2014). Many other options are possible, and the question of memory representation in humans (faithful to actual brains) and machines (with good computational properties) remains open. Here, to test the principle of SAB without having to answer that question, we use an approach already shown to work well in machine learning. We augment a unidirectional LSTM with a memory of every $k_{att}$-th hidden state from the past, with a modified hard self-attention mechanism limited to selecting at most $k_{top}$ memories at every timestep. Future work should investigate more realistic mechanisms for storing memories, e.g., based on saliency, novelty, etc. But this simple scheme allows us to test the hypothesis that neural network models can still perform well even when compelled to access their past sparsely at every timestep. If they cannot, then it would be meaningless to further encumber them with a bounded-size memory.
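A minimal sketch of this memory, under our own assumptions: timesteps are 1-based, a hidden state is written when $t$ is a multiple of $k_{att}$, and the memory grows without bound. `StridedMemory` is a hypothetical helper name, not taken from the paper.

```python
import numpy as np

class StridedMemory:
    """Hypothetical helper: keeps every k_att-th hidden state as a memory vector."""

    def __init__(self, k_att: int, n: int):
        self.k_att = k_att                     # write stride
        self.n = n                             # hidden-state size
        self.times: list[int] = []             # timesteps that were stored
        self.vectors: list[np.ndarray] = []    # stored hidden states

    def maybe_write(self, t: int, h: np.ndarray) -> None:
        """Store h if t (1-based) is a multiple of k_att."""
        if t % self.k_att == 0:
            self.times.append(t)
            self.vectors.append(h.copy())

    def as_matrix(self) -> np.ndarray:
        """Stack memories into an (M, n) array for the attention step."""
        if not self.vectors:
            return np.empty((0, self.n))
        return np.stack(self.vectors)

mem = StridedMemory(k_att=3, n=4)
for t in range(1, 11):
    mem.maybe_write(t, np.full(4, float(t)))
print(mem.times)              # [3, 6, 9]
print(mem.as_matrix().shape)  # (3, 4)
```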

We now describe the sparse retrieval mechanism that we have settled on. It determines which memories will be selected on the forward pass of the RNN, and therefore also which memories will receive gradient on the backward pass during training.

At time $t$, the underlying LSTM receives a vector of hidden states $\bm{h}^{(t-1)}$, a vector of cell states $\bm{c}^{(t-1)}$, and an input $\bm{x}^{(t)}$, and computes new cell states $\bm{c}^{(t)}$ and a provisional hidden state vector $\bm{\hat{h}}^{(t)}$ that also serves as a provisional output. We next use an attention mechanism similar to that of Bahdanau et al. (2014), but modified to produce sparse attention decisions. First, the provisional hidden state vector $\bm{\hat{h}}^{(t)}$ is concatenated to each memory vector $\bm{m}^{(i)}$ in the memory $\mathcal{M}$. Then, an MLP with one hidden layer maps each such concatenated vector to a scalar, non-sparse, raw attention weight $a^{(t)}_{i}$ representing the salience of memory $i$ at the current time $t$. The MLP is parametrized with weight matrices $\bm{W}_{1}$, $\bm{W}_{2}$ and $\bm{W}_{3}$.
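The scoring and sparsification steps can be sketched in NumPy as follows. This is an illustration under stated assumptions rather than the authors' implementation: the tanh hidden nonlinearity, the array shapes, and the specific sparsification rule (subtract the $(k_{top}+1)$-th largest raw score, clip at zero, renormalize) are our choices; the text above only says that the raw weights are sparsified so that at most $k_{top}$ memories are selected. Applying one weight matrix to the concatenation $[\bm{m}^{(i)}; \bm{\hat{h}}^{(t)}]$ is written here as $\bm{W}_1\bm{m}^{(i)} + \bm{W}_2\bm{\hat{h}}^{(t)}$.

```python
import numpy as np

def raw_attention(h_hat, memories, W1, W2, W3):
    """Raw (non-sparse) score a_i for each memory m_i.

    Assumed shapes: memories (M, n), h_hat (n,), W1 and W2 (d, n), W3 (d,).
    The tanh hidden layer is an assumption.
    """
    hidden = np.tanh(memories @ W1.T + W2 @ h_hat)   # (M, d)
    return hidden @ W3                                # (M,)

def sparsify(a, k_top):
    """Keep at most k_top weights (assumed scheme): subtract the (k_top+1)-th
    largest raw score, clip at zero, and renormalize the survivors to sum to 1."""
    if a.size == 0:
        return a
    if a.size > k_top:
        threshold = np.sort(a)[::-1][k_top]           # (k_top+1)-th largest score
    else:
        threshold = a.min() - 1.0                     # small memory: keep everything
    w = np.maximum(a - threshold, 0.0)
    total = w.sum()
    return w / total if total > 0 else w

rng = np.random.default_rng(0)
n, d, M, k_top = 4, 8, 6, 2
a = raw_attention(rng.normal(size=n), rng.normal(size=(M, n)),
                  rng.normal(size=(d, n)), rng.normal(size=(d, n)), rng.normal(size=d))
w = sparsify(a, k_top)
print(np.count_nonzero(w) <= k_top, np.isclose(w.sum(), 1.0))  # True True
```

Subtracting the $(k_{top}+1)$-th score guarantees that ties at the threshold are dropped, so no more than $k_{top}$ weights survive.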

A summary vector $\bm{s}^{(t)}$ is then computed as a simple sum of the selected memories, weighted by their respective sparsified attention weights. Because this sum is very sparse, the summary operation is very fast. This summary is then added to the provisional hidden state $\bm{\hat{h}}^{(t)}$ computed previously to obtain the final state $\bm{h}^{(t)}$.
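A sketch of the summary step, assuming sparse weights such as those produced by a top-$k_{top}$ rule; `summary_and_state` is a hypothetical name. Only memories with non-zero weight are gathered, so the weighted sum itself touches at most $k_{top}$ vectors.

```python
import numpy as np

def summary_and_state(h_hat, weights, memories):
    """s = sum_i w_i * m_i over the selected (non-zero) memories; h = h_hat + s.

    Assumed shapes: h_hat (n,), weights (M,), memories (M, n).
    """
    idx = np.flatnonzero(weights)          # indices of the selected memories
    s = weights[idx] @ memories[idx]       # (n,); all zeros if nothing was selected
    return s, h_hat + s                    # summary vector and final hidden state

s, h = summary_and_state(np.ones(2), np.array([0.0, 0.6, 0.0, 0.4]),
                         np.arange(8.0).reshape(4, 2))
print(s, h)  # approximately [3.6 4.6] and [4.6 5.6]
```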

Lastly, to compute the SAB-augmented LSTM cell's output $\bm{y}^{(t)}$ at time $t$, we concatenate $\bm{h}^{(t)}$ and the summary vector $\bm{s}^{(t)}$, then apply an affine output transform parametrized with learned weight matrices $\bm{V}_{1}$ and $\bm{V}_{2}$ and bias vector $\bm{b}$.
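Applying one affine map to the concatenation $[\bm{h}^{(t)}; \bm{s}^{(t)}]$ is equivalent to $\bm{y}^{(t)} = \bm{V}_1\bm{h}^{(t)} + \bm{V}_2\bm{s}^{(t)} + \bm{b}$, which is why the transform can be described with two weight matrices. A small NumPy sketch of this identity (the function name and shapes are our assumptions):

```python
import numpy as np

def sab_output(h, s, V1, V2, b):
    """y = [V1 V2] @ [h; s] + b, computed without building the concatenation."""
    return V1 @ h + V2 @ s + b

rng = np.random.default_rng(1)
n, o = 4, 3                                   # hidden size and output size (arbitrary)
h, s = rng.normal(size=n), rng.normal(size=n)
V1, V2, b = rng.normal(size=(o, n)), rng.normal(size=(o, n)), rng.normal(size=o)
y_concat = np.hstack([V1, V2]) @ np.concatenate([h, s]) + b
print(np.allclose(sab_output(h, s, V1, V2, b), y_concat))  # True
```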

The forward pass into a hidden state $\bm{h}^{(t)}$ has two paths contributing to it. One path is the regular sequential forward path in an RNN; the other path is through the dynamic but sparse skip connections in the attention mechanism that connect the present states to potentially very distant past experiences.

Sparse replay

Humans are trivially capable of assigning credit or blame to events even a long time after the fact, and do not need to replay all events from the present to the credited event sequentially and in reverse to do so. But that is effectively what RNNs trained with full BPTT require, and this does not seem biologically plausible when considering events which are far from each other in time. Even less plausible is TBPTT, because it ignores temporal dependencies beyond the truncation length $k_{trunc}$.

SAB networks' twin paths during the forward pass (sequential connection and sparse skip connections) allow gradient to flow not just from $\bm{h}^{(t)}$ to $\bm{h}^{(t-1)}$, but also to the at most $k_{top}$ memories $\bm{m}^{(i)}$ retrieved by the attention mechanism (and to no others). Learning to deliver gradient directly (and sparsely) where it is needed (and nowhere else) (1) avoids competition for the limited information-carrying capacity of the sequential path, (2) is a simple form of credit assignment, and (3) imposes a trade-off that is absent in previous, dense self-attentive mechanisms: opening a connection to an interesting or useful timestep must be made at the price of excluding others. This competition for a limited budget of $k_{top}$ connections results in interesting timesteps being given frequent attention and strong gradient flow, while uninteresting timesteps are ignored and starve.

If we allow gradient to flow not only directly to a past timestep but also to a few local timesteps around it, we obtain mental updates: a type of local credit assignment around a memory. There are various ways of enabling this. In our SAB-augmented LSTM, we perform TBPTT locally before each selected timestep (over the $k_{trunc}$ timesteps preceding it).
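The timesteps that receive gradient from a loss at step $t$ can be sketched as below. This is a simplified, one-hop view under our own assumptions: timesteps are 1-based, local TBPTT covers each selected timestep and the $k_{trunc}$ steps before it (and likewise the current step), and we ignore the further memories that a recalled state may itself have attended to. `gradient_receivers` is a hypothetical name.

```python
def gradient_receivers(t, selected, k_trunc):
    """Timesteps reached by gradient from a loss at step t (one-hop sketch).

    Gradient flows through local TBPTT at the current step and, via the sparse
    skip connections, to each selected memory and the k_trunc steps before it.
    """
    def window(i):
        return range(max(1, i - k_trunc), i + 1)   # steps i - k_trunc .. i (1-based)

    steps = set(window(t))
    for i in selected:
        steps.update(window(i))
    return sorted(steps)

print(gradient_receivers(100, [10, 50], k_trunc=3))
# [7, 8, 9, 10, 47, 48, 49, 50, 97, 98, 99, 100]
```

Everything between the windows, here steps 11 to 46 and 51 to 96, receives no gradient from this loss, which is what keeps the backward pass short.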

Experimental Setup and Results

For all tasks, we compare SAB to two baseline models. The first is an LSTM trained using both full BPTT and TBPTT with various truncation lengths. The second is an LSTM augmented with full self-attention and trained using full BPTT. For the pixel-by-pixel CIFAR10 and permuted MNIST classification tasks, we also compare to the Transformer (Vaswani et al., 2017) architecture.

Copying and Adding problems (Q1)

The copy and adding problems defined in Hochreiter & Schmidhuber (1997) are synthetic tasks specifically designed to evaluate a model's performance on long-term dependencies by testing its ability to remember a sub-sequence for a large number of timesteps. The performance of SAB almost matches that of LSTMs augmented with self-attention and trained using full BPTT. Note that our copy and adding LSTM baselines are more competitive than those reported in the existing literature (Arjovsky et al., 2016). These findings support our hypothesis that at any given timestep, only a few past events need to be recalled to correctly predict the output at the current timestep.

Table 2 reports the cross-entropy (CE) of the model predictions on unseen sequences in the adding task. The LSTM with full self-attention trained using BPTT obtains the lowest CE loss, followed by the LSTM trained using BPTT. The LSTM trained with truncated BPTT performs significantly worse. When $T=200$, SAB's performance is comparable to that of the best baseline models. With longer sequences ($T=400$), SAB outperforms TBPTT but is outperformed by full BPTT. For more details regarding the setup, refer to the supplementary material.

Character level Penn TreeBank (PTB) (Q1)

Details about our experimental setup can be found in the supplementary material. We evaluate the performance of our model using the bits-per-character (BPC) metric. As shown in Table 2, SAB’s performance is significantly better than TBPTT and almost matches BPTT, which is roughly what one expects from an approximate-gradient method like SAB.

Text8 (Q1)

Details about our experimental setup can be found in the supplementary material. Note that we did not carry out any additional hyperparameter search for our model. Table 2 reports the BPC of the models' predictions on the test set. SAB outperforms both an LSTM trained using TBPTT and an LSTM with self-attention trained using TBPTT.

Permuted pixel-by-pixel MNIST (Q1)

This task is a sequential version of the MNIST classification dataset. The task involves predicting the label of an image after being given its pixels as a sequence permuted in a fixed, random order. Our experimental setup can be found in the supplementary material. Table 5 shows that SAB performs well compared to an LSTM trained with BPTT.
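A sketch of how such an input sequence can be built, assuming 28x28 images flattened row by row and a single permutation drawn once (with an arbitrary seed) and shared across all images. The helper names are hypothetical.

```python
import numpy as np

def make_permutation(n_pixels=28 * 28, seed=0):
    """Draw one fixed random pixel order; the same order is reused for every image."""
    return np.random.default_rng(seed).permutation(n_pixels)

def to_permuted_sequence(image, perm):
    """Flatten a 28x28 image row by row, reorder the pixels by perm, and return a
    (784, 1) array: one pixel per timestep, one input feature per step."""
    return image.reshape(-1)[perm][:, None]

perm = make_permutation()
seq = to_permuted_sequence(np.zeros((28, 28)), perm)
print(seq.shape)  # (784, 1)
```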

CIFAR10 classification (Q1,Q3)

We test our model's performance on pixel-by-pixel CIFAR10 (no permutation). This task involves predicting the label of an image after being given it as a sequence of pixels. This task is relatively difficult compared to the other tasks, as sequences are substantially longer (length 1024). Our method outperforms the Transformer and LSTMs trained with BPTT (Table 5).
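For contrast with permuted MNIST, here is a sketch of the unpermuted CIFAR10 sequence, assuming channels-last 32x32x3 images and one RGB pixel per timestep in raster order, which gives the stated length of 1024. The helper name is hypothetical.

```python
import numpy as np

def cifar_to_sequence(image):
    """Turn an (H, W, C) image into an (H*W, C) sequence: one pixel per timestep,
    rows read top to bottom and left to right, no permutation."""
    h, w, c = image.shape
    return image.reshape(h * w, c)

seq = cifar_to_sequence(np.zeros((32, 32, 3)))
print(seq.shape)  # (1024, 3)
```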

Learning long-term dependencies (Q1)

Table 1 reports both the accuracy and the cross-entropy (CE) of the models' predictions on unseen sequences for the copy memory task. The best-performing baseline model is the LSTM with full self-attention trained using BPTT, followed by vanilla LSTMs trained using BPTT. Far behind are LSTMs trained using truncated BPTT. Table 1 demonstrates that SAB is able to learn the task almost perfectly for all copy lengths $T$. Further, SAB outperforms the LSTM baselines without attention and matches the performance of LSTMs with full self-attention trained using BPTT on the copy memory task. The advantage over the LSTM baselines becomes particularly noticeable as the sequence length increases.

Transfer Learning (Q2)

We examine the generalization ability of SAB compared to an LSTM trained with full BPTT and an LSTM with full self-attention. The experiment is set up as follows. For the copy task of length $T=100$, we train SAB, an LSTM with BPTT, and an LSTM with full self-attention to convergence. We then evaluate the trained models on the copy task for an array of larger $T$ values. The results are shown in Table 5. Although all 3 models have similar performance at $T=100$, performance for all 3 models clearly drops as $T$ grows. However, SAB still manages to complete the task at $T=5000$, whereas by $T=2000$ both the vanilla LSTM and the LSTM with full self-attention do no better than random guessing ($1/8=12.5\%$).

Importance of Sparsity and Mental Updates (Q4)

We study the necessity of sparsity and mental updates by running an ablation study on the copying problem. The ablation study focuses on two variants. The first model attends to all events in the past while performing a truncated update. This can be seen either as a dense version of SAB or as an LSTM with full self-attention trained using TBPTT. Empirically, we find that such models are harder to train and do not reach the same performance as SAB. The second ablation experiment tests the necessity of mental updates, without which the model would only attend to past timesteps without passing gradients through them to preceding timesteps. We observe a degradation of model performance when blocking gradients to past events. This effect is most evident when attending to only one timestep in the past ($k_{top}=1$).

We evaluate SAB on language modeling with the Penn TreeBank (PTB) (Marcus et al., 1993) and Text8 (Mahoney, 2011) datasets. For models trained using truncated BPTT, performance drops as $k_{trunc}$ shrinks. We found that on PTB, SAB with $k_{trunc}=20$ and $k_{top}=10$ performs almost as well as full BPTT. On the larger Text8 dataset, SAB with $k_{trunc}=10$ and $k_{top}=5$ outperforms an LSTM trained using BPTT.

Comparison to Transformer (Q3)

We test how SAB compares to the Transformer model (Vaswani et al., 2017), which is based on a self-attention mechanism. On pMNIST, the Transformer model outperforms our best model, as shown in Table 5. On CIFAR10, however, our proposed model performs much better.

Conclusions

By considering how brains could perform long-term temporal credit assignment, we developed an alternative to the traditional method of training recurrent neural networks by unfolding the computational graph and applying BPTT. We explored the hypothesis that a reminding process which uses the current state to evoke a relevant state arbitrarily far back in the past could be used to effectively teleport credit backwards in time to the computations performed to obtain the past state. To test this idea, we developed a novel temporal architecture and credit assignment mechanism called SAB, for Sparse Attentive Backtracking, which aims to combine the strengths of full backpropagation through time and truncated backpropagation through time. It does so by backpropagating gradients only through paths for which the current state and a past state are associated. This allows the RNN to learn long-term dependencies, as with full backpropagation through time, while still allowing it to backtrack for only a few steps, as with truncated backpropagation through time, thus making it possible to update weights as frequently as needed rather than having to wait for the end of very long sequences.

Cognitive processes in reminding serve not only as the inspiration for SAB, but suggest two interesting directions of future research. First, we assumed a simple content-independent rule for selecting microstates for inclusion in the macrostate, whereas humans show a systematic dependence on content: salient, extreme, unusual, and unexpected experiences are more likely to be stored and subsequently remembered. These landmarks of memory should be useful for connecting past to current context, just as an individual learns to map out a city via distinctive geographic landmarks. Second, SAB determines the relevance of past microstates to the current state through a generic, flexible mapping, whereas humans perform similarity-based retrieval. We conjecture that a version of SAB with a strong inductive bias in the mechanism to select past states may further improve its performance.

Acknowledgement

The authors would like to thank Hugo Larochelle, Walter Senn, Alex Lamb, Remi Le Priol, Matthieu Courbariaux, Gaetan Marceau Caron, and Sandeep Subramanian for useful discussions, as well as NSERC, CIFAR, Google, Samsung, SNSF, Nuance, IBM, Canada Research Chairs, and National Science Foundation awards EHR-1631428 and SES-1461535 for funding. We would also like to thank Compute Canada and NVIDIA for computing resources. The authors would also like to thank Alex Lamb for code review. The authors would also like to express a debt of gratitude to those who contributed to Theano over the years (now that it is being sunset), for making it such a great tool.

References

Supplementary material

The copying task

We follow the setup of the copying memory problem from Hochreiter & Schmidhuber (1997). Specifically, the network is given a sequence of $T+20$ inputs consisting of: a) 10 randomly generated digits (digits 1 to 8), followed by b) $T-1$ blank inputs, followed by c) a special end-of-sequence character, followed by d) 10 additional blank inputs. After the end-of-sequence character, the network must output a copy of the initial 10 digits.
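A sketch of a generator for one copying example, assuming token 0 for blanks, tokens 1 to 8 for digits, and token 9 for the end-of-sequence marker (these ids are our choice). We use $T-1$ blanks so that the total length is $T+20$; the target is blank everywhere except the last 10 positions, which hold the digits.

```python
import numpy as np

BLANK, EOS = 0, 9  # assumed token ids; the digits use ids 1..8

def copy_task_example(T, rng):
    """Return (inputs, targets), each of length T + 20, for one copying example."""
    digits = rng.integers(1, 9, size=10)                  # 10 digits drawn from 1..8
    x = np.concatenate([
        digits,                                           # a) digits to remember
        np.full(T - 1, BLANK),                            # b) T - 1 blanks
        [EOS],                                            # c) end-of-sequence marker
        np.full(10, BLANK),                               # d) 10 more blanks
    ])
    y = np.concatenate([np.full(T + 10, BLANK), digits])  # digits expected at the end
    return x, y

x, y = copy_task_example(T=50, rng=np.random.default_rng(0))
print(len(x), len(y))  # 70 70
```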

The adding task

The adding task requires the model to sum two specific entries in a sequence of $T$ (input) entries (Hochreiter & Schmidhuber, 1997). In the spirit of the copying task, larger values of $T$ require the model to keep track of longer-term dependencies. The exact setup is as follows. Each example in the task consists of two input vectors of length $T$. The first is a vector of values drawn uniformly between 0 and 1. The second vector encodes a binary mask which indicates the two entries in the first input to be added (the mask vector consists of $T-2$ zeros and 2 ones). The mask is randomly generated with the constraint that the masked-in entries must come from different halves of the first input vector.
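A sketch of a generator for one adding-task example under this setup; the function name is hypothetical, and we take the halves to be indices below and at or above $\lfloor T/2 \rfloor$ (requires $T \ge 2$).

```python
import numpy as np

def adding_task_example(T, rng):
    """One adding-task example: values in [0, 1), a two-hot mask, and the target sum."""
    values = rng.uniform(0.0, 1.0, size=T)   # first input vector
    mask = np.zeros(T)                       # second input vector: T - 2 zeros, 2 ones
    i = rng.integers(0, T // 2)              # marked entry in the first half
    j = rng.integers(T // 2, T)              # marked entry in the second half
    mask[i] = mask[j] = 1.0
    return values, mask, values[i] + values[j]

values, mask, target = adding_task_example(200, np.random.default_rng(0))
print(int(mask.sum()))  # 2
```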

Hyperparameters

The hyperparameters for both the baselines and SAB are kept the same. All models have 128 hidden units and use the Adam (Kingma & Ba, 2014) optimizer with a learning rate of 1e-3. The first model in the ablation study (the dense version of SAB) was more difficult to train; we therefore explored learning rates ranging from 1e-3 to 1e-5 and report the best-performing model.

Char Level Penn TreeBank

We follow the setup in Cooijmans et al. (2016), and all of our models use 1000 hidden units and a learning rate of 0.002. As in Cooijmans et al. (2016), we used non-overlapping sequences of length 100 in batches of 32. All models were trained for up to 100 epochs with early stopping on the validation set. We evaluate the performance of our model using the bits-per-character (BPC) metric.

Char Level Text8

We follow the setup of Mikolov et al. (2012) and use the first 90M characters for training, the next 5M for validation, and the final 5M for testing. We train on non-overlapping sequences of length 180. Due to computational constraints, all baselines use 1000 hidden units. We trained all models using a batch size of 64, and trained SAB for a maximum of 30 epochs.
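A sketch of the split and of the non-overlapping training windows, assuming the corpus is already loaded as a single string (Text8 has 100M characters, so the test split is the final 5M). Loading and batching are omitted, and the helper names are hypothetical.

```python
def split_text8(text, n_train=90_000_000, n_valid=5_000_000):
    """Split the corpus into train / validation / test by character position."""
    train = text[:n_train]
    valid = text[n_train:n_train + n_valid]
    test = text[n_train + n_valid:]        # the remaining (final) characters
    return train, valid, test

def non_overlapping(seq, length=180):
    """Consecutive, non-overlapping windows of `length` characters; a trailing
    remainder shorter than `length` is dropped (our assumption)."""
    return [seq[i:i + length] for i in range(0, len(seq) - length + 1, length)]

print(split_text8("abcdefghij", n_train=6, n_valid=2))  # ('abcdef', 'gh', 'ij')
print(non_overlapping("abcdefg", length=3))             # ['abc', 'def']
```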

Permuted Pixel-by-pixel MNIST

All models use an LSTM with 128 hidden units. The prediction is produced by passing the final hidden state of the network into a softmax. We used a learning rate of 0.001, trained our model for about 100 epochs, and used early stopping based on the validation set.

Comparison to LSTM + Self-Attention (with truncation)

SAB is trained with truncated BPTT, whereas the vanilla LSTM with self-attention is not. Here we argue that training the vanilla LSTM with self-attention using truncation works less well on the more challenging Text8 language modelling dataset.

Computational Complexity of SAB

If the memory were allowed to grow unbounded in size, then the computational cost of each step would scale linearly with the length of the history. However, humans have a bounded memory. In a computer science context with unbounded memory, the time complexity of the forward pass of both training and inference in SAB is $O(t^{2}n^{2})$, with $t$ the number of timesteps and $n$ the size of the hidden state. The space complexity of the forward pass of training is unchanged at $O(tn)$, but the space complexity of inference in SAB is now $O(tn)$ rather than $O(n)$. The time cost of the backward pass of training, however, is very difficult to formulate. Hidden states depend on a sparse subset of past microstates, but each of those past microstates may itself depend on several other, even earlier microstates. The web of active connections is therefore akin to a directed acyclic graph, and in the worst case a backpropagation starting at the last hidden state may touch all past microstates several times. However, if the number of microstates truly relevant to a task is low, the attention mechanism will repeatedly focus on them to the exclusion of all others, and pathological runtimes will not be encountered.
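One way to see where the quadratic forward cost comes from is the following sketch, assuming every $k_{att}$-th state is stored and that scoring one memory with the attention MLP costs $O(n^{2})$ (hidden width proportional to $n$):

```latex
% Forward-pass attention cost up to step t (sketch under the stated assumptions).
\sum_{t'=1}^{t}
  \underbrace{\left\lfloor t'/k_{att} \right\rfloor}_{\text{memories at step } t'}
  \cdot \underbrace{O(n^{2})}_{\text{score one memory}}
\;\le\; \frac{t(t+1)}{2\,k_{att}}\, O(n^{2})
\;=\; O\!\left(\frac{t^{2} n^{2}}{k_{att}}\right)
\;=\; O(t^{2} n^{2}) \quad \text{for fixed } k_{att}.
```

The LSTM update itself adds only $O(tn^{2})$, and keeping $\lfloor t/k_{att} \rfloor$ stored vectors of size $n$ accounts for the $O(tn)$ space needed at inference.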

Our method approximates the true gradient, but in a sense this is no different from the kind of approximation made with truncated gradients, except that instead of truncating to the last $k_{trunc}$ time steps, we truncate to one skip-step in the past, which can be arbitrarily far in the past. Because long-term dependencies are learned through these short skip paths rather than through long chains of recurrent steps, this provides a way of combating the exploding and vanishing gradient problems. To test this, we ran our model on the Text8, pixel-by-pixel MNIST and character-level PTB datasets with and without gradient clipping. We found empirically that gradient clipping was needed only for Text8; on the other datasets we observed little or no difference with gradient clipping.