FiDO: Fusion-in-Decoder optimized for stronger performance and faster inference
Michiel de Jong, Yury Zemlyanskiy, Joshua Ainslie, Nicholas FitzGerald, Sumit Sanghai, Fei Sha, William Cohen
Introduction
A large body of work has demonstrated that language model performance on downstream tasks can be improved by augmenting the model with relevant retrieved text (Guu et al., 2020; Lewis et al., 2020; Izacard and Grave, 2021; Izacard et al., 2022). In particular, the Fusion-in-Decoder (FiD) architecture (Izacard and Grave, 2021) stands out for strong performance, even outperforming much larger models on many knowledge-intensive tasks (Izacard et al., 2022). However, FiD uses a standard T5 encoder-decoder architecture (Raffel et al., 2020), which was not designed for use as a retrieval-augmented model. In this work we propose FiDO, a modified FiD architecture optimized for the retrieval-augmented setting.
The FiD decoder is responsible for a difficult task: assimilating information from many passages and reasoning over the information to generate an output. However, because the encoder and decoder are of similar size and the encoder is applied to a large number of retrieved passages, FiD devotes an order of magnitude more Floating Point Operations (FLOPs) to the encoder than the decoder. In spite of this, the majority of inference time is actually spent in the decoder, as has been observed in prior work (Hofstätter et al., 2022). This surprising result is shown in Figure 1. Our analysis finds that for typical inference settings the FiD decoder is memory-bandwidth bound (Williams et al., 2009) due to using multi-head cross-attention (Vaswani et al., 2017) over a large input sequence.
Based on this analysis, we propose two sets of architectural changes. We first propose to reduce the cost of cross-attention over retrieved passages by removing most cross-attention layers from the decoder. This reduces cost and yields much smaller losses in performance than FiD-Light (Hofstätter et al., 2022), the best previously-proposed approach for optimizing FiD. We also replace multi-head attention with multi-query attention (Shazeer, 2019). With these modifications the memory-bandwidth bottleneck is eliminated: decoder inference is now orders of magnitude faster and most inference time is spent in the encoder, consistent with the balance of FLOPs between components.
Finally, we propose to partially rebalance compute towards the decoder by massively scaling decoder size, using a smaller encoder to extract information from retrieved passages and a larger decoder to assimilate the information and reason about the desired output. We refer to the resulting series of models as FiDO (Fusion in Decoder Optimized) and show that FiDO strongly outperforms standard FiD models on the question-answering datasets Natural Questions (Kwiatkowski et al., 2019), TriviaQA (Joshi et al., 2017) and WebQuestions (Berant et al., 2013) for a wide range of inference budgets and settings. Figure 2 summarizes some of these results.
Analysis
Figure 3: MAIN RESULT. FiDO achieves much higher performance for any given inference budget. Exact match on Natural Questions (NaturalQ), TriviaQA and WebQuestions (WebQ) test sets as a function of inference budget (log scale). Compares FiD Small, Base and Large models with FiDO Small-Large, Base-XL, Large-XXL and XL-XXL models. Legend: FiD, FiDO.
Retrieval-augmented models generally read many context tokens relative to the number of question or answer tokens, such that processing retrieved text consumes the bulk of FLOPs. However, past work has shown that most inference time for Fusion-in-Decoder (FiD) is spent in the decoder (Hofstätter et al., 2022). Our own experiments support this (Figure 1). This section investigates FiD’s computational structure and decoder inference speed, and finds the slower decoder speed to be the result of memory bandwidth constraints, exacerbated by attention over retrieved documents.
The backbone of the Fusion-in-Decoder model (Izacard and Grave, 2021) is a T5 encoder-decoder architecture. The model is provided a question or other input, as well as a number of relevant retrieved text passages. The question is prepended to each retrieved passage, and then the encoder is applied to each passage separately. The resulting representations are concatenated. Finally, the decoder cross-attends to the large number of concatenated representations and assimilates the information from the different passages to generate an answer, hence Fusion-in-Decoder.
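The data flow described above can be illustrated with a minimal sketch. It is not the authors' implementation: `fid_encode` and `toy_encoder` are hypothetical names, and the toy encoder merely stands in for a T5 encoder by emitting one small vector per input token. The point is only that each passage is encoded separately with the question prepended, and that the decoder would then cross-attend to the single concatenated sequence.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def fid_encode(
    question_tokens: Sequence[str],
    passages: Sequence[Sequence[str]],
    encode: Callable[[List[str]], List[Vector]],
) -> List[Vector]:
    """Encode each question+passage pair separately, then concatenate."""
    fused: List[Vector] = []
    for passage_tokens in passages:
        # The question is prepended to every retrieved passage.
        encoder_input = list(question_tokens) + list(passage_tokens)
        # The encoder sees one passage at a time (no cross-passage attention).
        fused.extend(encode(encoder_input))
    return fused


def toy_encoder(tokens: List[str]) -> List[Vector]:
    # Hypothetical stand-in for the T5 encoder: one 2-d vector per token.
    return [[float(len(tok)), float(pos)] for pos, tok in enumerate(tokens)]


question = ["who", "wrote", "hamlet", "?"]
passages = [["hamlet", "is", "a", "tragedy"], ["shakespeare", "wrote", "plays"]]

# The decoder would cross-attend to this concatenated sequence.
memory = fid_encode(question, passages, toy_encoder)
print(len(memory))  # 15 = (4 + 4) + (4 + 3)
```

Because the concatenated sequence grows linearly with the number of passages, the decoder's cross-attention input is far longer than any single encoder input.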
2.2 FLOPs of FiD model
Model speed is determined by the number of FLOPs and the speed at which computations are performed, typically measured in floating point operations per second (FLOP/s). Operations in a Transformer can be roughly divided into MLP layers, attention projection layers, and attention operations. For simplicity, we count only multiplication operations.
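Let $d$ be the dimension of the model, $n_s$ the total number of tokens across all passages, $n_p$ the number of tokens in a single retrieved passage, $n_t$ the number of tokens in the target, $L$ the number of layers, and assume the MLP dimension is $4d$. The number of FLOPs used in an encoder layer is approximately
$$\text{FLOPs}_{\text{enc}L} = \underbrace{8 n_s d^2}_{\text{MLP}} + \underbrace{4 n_s d^2}_{\text{QKVO projections}} + \underbrace{2 n_s n_p d}_{\text{Attention}}$$
Since the size of each retrieved passage $n_p \ll d$, computation of the attention score is negligible and we can approximate total FLOPs in the encoder as
$$\text{FLOPs}_{\text{enc}} \approx 12 n_s d^2 \cdot L \quad (1)$$
Decoder layers additionally have cross-attention layers, leading to FLOPs of
$$\text{FLOPs}_{\text{dec}L} = \underbrace{8 n_t d^2 + 4 n_t d^2 + 2 n_t^2 d}_{\text{MLP and self-attention}} + \underbrace{2 n_t d^2}_{\text{cross-attention QO}} + \underbrace{2 n_s d^2}_{\text{cross-attention KV}} + \underbrace{2 n_t n_s d}_{\text{cross-attention}}$$
The output length $n_t \ll n_s$, so the only non-negligible term for decoder FLOPs originates from the cross-attention key and value projections, which cost the same FLOPs as encoder key and value projections. We see that the decoder consumes roughly $\frac{1}{6}$ the FLOPs of the encoder:
$$\text{FLOPs}_{\text{dec}} \approx 2 n_s d^2 \cdot L \quad (2)$$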
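As a rough numerical check of this FLOPs accounting, the sketch below counts multiplications per layer, using 8 and 4 multiplications per token per squared model dimension for the MLP and the query/key/value/output projections respectively. The function names are hypothetical, and the model dimension of 768 is an illustrative assumption rather than a value stated in this section; the 40 passages of 256 tokens and 32 output tokens follow the inference setup used in the experiments.

```python
def encoder_layer_flops(n_s: int, n_p: int, d: int) -> int:
    """Multiplications in one encoder layer (n_s source tokens, passages of n_p tokens)."""
    mlp = 8 * n_s * d**2            # two d x 4d matrix multiplications per token
    projections = 4 * n_s * d**2    # query, key, value and output projections
    attention = 2 * n_s * n_p * d   # scores and weighted sum within each passage
    return mlp + projections + attention


def decoder_layer_flops(n_s: int, n_t: int, d: int) -> int:
    """Multiplications in one decoder layer (n_t target tokens attending to n_s source tokens)."""
    mlp_and_self_attention = 8 * n_t * d**2 + 4 * n_t * d**2 + 2 * n_t**2 * d
    cross_qo = 2 * n_t * d**2       # cross-attention query and output projections
    cross_kv = 2 * n_s * d**2       # cross-attention key and value projections
    cross_attention = 2 * n_t * n_s * d
    return mlp_and_self_attention + cross_qo + cross_kv + cross_attention


# Illustrative setting: 40 passages x 256 tokens, 32 output tokens, d = 768 (assumed).
n_p, n_t, d = 256, 32, 768
n_s = 40 * n_p

ratio = decoder_layer_flops(n_s, n_t, d) / encoder_layer_flops(n_s, n_p, d)
print(f"decoder / encoder FLOPs per layer: {ratio:.3f}")
```

With these assumed values the ratio lands close to one sixth, in line with the approximation that only the cross-attention key and value projections matter in the decoder.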
Figure 1 shows that actual measured training time closely mirrors this FLOPs approximation. However, the decoder is much more expensive for inference. We argue below that this is because the decoder, and specifically its cross-attention layers, is memory-bandwidth constrained during inference.
2.3 Effective computational throughput
In order to perform computations, accelerators must transmit data between global memory and registers, which can be a limiting factor. The actual FLOP/s achieved can be usefully modeled with the roofline model (Williams et al., 2009; Ofenbeck et al., 2014; Mohan, 2018) as the lesser of peak FLOP/s the device is capable of and how fast required data can be transferred.
The data constraint is given by the product of device memory bandwidth – how fast data can be transferred – and operational intensity – how many operations are performed per unit of data. The latter is determined by an algorithm’s degree of data reuse, the number of operations that can be performed before new data needs to be fetched.
High operational intensity is necessary for good performance on modern GPU/TPU hardware, for which peak FLOP/s are usually two orders of magnitude larger than memory bandwidth (Google, 2022; NVIDIA, 2022). If operational intensity is too low, the accelerator will spend the majority of its time waiting for data to be transferred to registers. Usually, that happens when the model repeatedly performs minor computations with large tensors, for example in normalization layers or during incremental decoding.
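The roofline model described in this section reduces to taking the minimum of two quantities, which the sketch below makes concrete. The device numbers are hypothetical and chosen only so that peak FLOP/s is two orders of magnitude larger than memory bandwidth; they do not describe any specific accelerator.

```python
def attainable_flops_per_s(
    peak_flops_per_s: float,
    memory_bandwidth: float,
    operational_intensity: float,
) -> float:
    """Roofline model: throughput is capped by compute or by data transfer."""
    data_bound = memory_bandwidth * operational_intensity
    return min(peak_flops_per_s, data_bound)


# Hypothetical device: 100 TFLOP/s peak, 1 TB/s memory bandwidth.
PEAK = 100e12
BANDWIDTH = 1e12

# Intensity below which the device is memory-bandwidth bound.
ridge_point = PEAK / BANDWIDTH

low = attainable_flops_per_s(PEAK, BANDWIDTH, operational_intensity=10)
high = attainable_flops_per_s(PEAK, BANDWIDTH, operational_intensity=500)
print(ridge_point, low, high)
```

In this toy setting, an algorithm performing 10 operations per unit of data reaches only a tenth of peak throughput, while one performing 500 is limited by compute instead.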
2.4 Operational intensity of FiD inference
Shazeer (2019) shows that the speed of incremental Transformer decoding is memory-bandwidth bound due to low operational intensity. Here we follow their analysis and derive the asymptotic inverse of operational intensity (the ratio of memory operations to the compute performed during each incremental decoding step) for FiD. Let $b$ be the batch size, $h$ the number of attention heads and assume that attention heads have dimension $\frac{d}{h}$, where $d$ is the model dimension.
In the MLP layer, for each token the linear projections perform $O(bd^2)$ operations, and load $O(bd + d^2)$ memory, where $bd$ corresponds to activations and $d^2$ to the weight matrices. During training, sequence length effectively multiplies batch size as weights need to be loaded only once for the entire sequence, but for inference each token is processed incrementally. The inverse operational intensity is then
$$R^{\text{MLP}} = \frac{1}{b} + \frac{1}{d} \quad (3)$$
Therefore, obtaining high operational intensity of the MLP layer ($R^{\text{MLP}} \ll 1$) during inference requires a large batch size.
Operational intensity of attention layers.
Memory bandwidth is a more severe bottleneck for attention inference, particularly cross-attention. At each decoding step the model applies projections for a single token, and has to load all cached key and value projections from encoder tokens and prior decoder tokens into memory. This leads to very low operational intensity.
Specifically, query/key/value/output projections for a single position take $O(bd^2)$ operations. As discussed earlier, we can ignore the attention computation itself. The model needs to load projection matrices ($O(d^2)$ memory) and past keys and values ($O(bnd)$ memory for $n$ cached positions). Therefore, the inverse operational intensities for self-attention layers, $R^{\text{S-MHA}}$, and cross-attention layers, $R^{\text{C-MHA}}$, are
$$R^{\text{S-MHA}} = \frac{1}{b} + \frac{n_t}{d}, \qquad R^{\text{C-MHA}} = \frac{1}{b} + \frac{n_s}{d} \quad (4)$$
Because the source input length $n_s$ is extremely long for FiD, the cross-attention operational intensity is very low, which bottlenecks inference.
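To give a sense of scale, the following sketch evaluates the asymptotic inverse operational intensities for the MLP, self-attention and cross-attention layers: weights contribute one over the batch size, MLP activations one over the model dimension, and cached keys and values the cached length divided by the model dimension. Constant factors are dropped, the function names are hypothetical, and the model dimension of 768 is an illustrative assumption; the batch size of 24, the 40 passages of 256 tokens and the 32 output tokens follow the experimental setup.

```python
def inverse_intensity_mlp(b: int, d: int) -> float:
    """Memory loaded per operation for an MLP layer during incremental decoding."""
    return 1 / b + 1 / d


def inverse_intensity_attention(b: int, d: int, n: int) -> float:
    """Memory loaded per operation for multi-head attention over n cached positions."""
    return 1 / b + n / d


b, d = 24, 768          # batch size from the experiments; d is an assumed value
n_t, n_s = 32, 40 * 256  # target length and total source length

r_mlp = inverse_intensity_mlp(b, d)
r_self = inverse_intensity_attention(b, d, n_t)
r_cross = inverse_intensity_attention(b, d, n_s)
print(f"MLP {r_mlp:.3f}  self-attention {r_self:.3f}  cross-attention {r_cross:.3f}")
```

Under these assumptions the cross-attention value is more than two orders of magnitude larger than the MLP value, which is the imbalance the method section sets out to remove.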
Method
We have shown that the encoder accounts for the bulk of FiD FLOPs and training cost, while FiD spends the majority of inference time in the decoder due to low operational intensity of cross-attention layers. Next we propose several ways to alleviate the decoder bottleneck. This allows us to efficiently allocate more compute to the decoder by scaling decoder size without significantly increasing inference time. We denote Fusion-in-Decoder with the proposed optimizations as FiDO (Fusion-in-Decoder Optimized).
The decoder cross-attention layer is the primary bottleneck for inference due to its low operational intensity. FiD-Light (Hofstätter et al., 2022) improves the operational intensity by reducing the effective input length by a factor of $K$. We instead propose to remove cross-attention from some decoder layers entirely, keeping cross-attention only in one out of every $K$ decoder layers. We call this layer-sparse cross-attention (LSA). Section 5 provides evidence that LSA achieves similar speedups without FiD-Light’s drop in quality. For FiDO we use LSA with sparsity $K = 6$, which means that a Large decoder has cross-attention only at layers 6, 12, 18 and 24. In principle LSA and FiD-Light can be combined, but we find that after applying LSA and multi-query attention the remaining cross-attention makes up a small proportion of decoder inference cost and further speedups from reducing cross-attention are modest (Figure 4).
Removing cross-attention layers also reduces FiD’s FLOPs and memory usage. Cross-attention layers make up approximately $\frac{1}{7}$ of total FiD FLOPs (see Eqn 2) and applying LSA-6 leads to a 12% reduction in FLOPs. Table 2 shows the reduction in FLOPs is reflected by an increase in training speed. Moreover, cross-attention keys and values make up a substantial proportion of memory usage during inference, and LSA-6 enables a much larger batch size (Table 1).
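The layer selection and the quoted FLOPs saving can be reproduced with a few lines of arithmetic. This is a sketch with hypothetical helper names; it assumes layers are numbered from 1 and that the cross-attention share of total FLOPs is the ratio of the approximate per-layer decoder cost (coefficient 2) to the combined encoder and decoder cost (coefficients 12 and 2).

```python
from typing import List


def cross_attention_layers(num_layers: int, k: int) -> List[int]:
    """Decoder layers (1-indexed) that keep cross-attention under layer sparsity k."""
    return [layer for layer in range(1, num_layers + 1) if layer % k == 0]


def lsa_flops_reduction(k: int, cross_attention_share: float) -> float:
    """Fraction of total FLOPs removed when only 1 in k layers keeps cross-attention."""
    return cross_attention_share * (1 - 1 / k)


# A 24-layer decoder with sparsity 6, as in the Large example in the text.
kept = cross_attention_layers(num_layers=24, k=6)

# Approximate per-layer costs: encoder coefficient 12, decoder cross-attention 2.
share = 2 / (12 + 2)
reduction = lsa_flops_reduction(k=6, cross_attention_share=share)

print(kept)
print(f"{reduction:.1%}")
```

Five of every six cross-attention layers are dropped, so the saving is five sixths of the one-seventh share, which rounds to the 12% reported.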
3.2 Multi-query attention
Shazeer (2019) proposes to increase the operational intensity of decoder attention layers by applying multi-query attention, in which keys and values share a single head each and only queries have multiple heads. With a single head, keys and values use a factor $h$ less memory and are much faster to load. With multi-query attention, keys and values occupy $O(bnd/h)$ memory, so that the inverse operational intensity of cross-attention becomes
$$R^{\text{C-MQA}} = \frac{1}{b} + \frac{n_s}{dh} \quad (5)$$
which has the problematic term $\frac{n_s}{d}$ reduced by a factor of $h$. Multi-query attention further reduces inference cost (Figure 2) and memory (Table 1) on top of layer-sparse cross-attention, though not training speed (Table 2).
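The memory saving from sharing a single key/value head can be checked with a small sketch. The helper names are hypothetical, and the model dimension of 768 with 12 heads is an illustrative assumption; the batch size and source length follow the experimental setup. The inverse-intensity function keeps only the weight-loading term (one over the batch size) and the key/value term.

```python
def kv_cache_elements(b: int, n: int, d: int, h: int, multi_query: bool) -> int:
    """Number of cached key and value elements for one attention layer."""
    head_dim = d // h
    kv_heads = 1 if multi_query else h   # multi-query shares one key/value head
    return 2 * b * n * kv_heads * head_dim


def inverse_intensity_cross_mqa(b: int, d: int, n_s: int, h: int) -> float:
    """Memory loaded per operation for multi-query cross-attention."""
    return 1 / b + n_s / (d * h)


b, d, h = 24, 768, 12   # d and h are assumed values for illustration
n_s = 40 * 256

multi_head = kv_cache_elements(b, n_s, d, h, multi_query=False)
multi_query = kv_cache_elements(b, n_s, d, h, multi_query=True)
r_mqa = inverse_intensity_cross_mqa(b, d, n_s, h)

print(multi_head // multi_query)   # the cache shrinks by the number of heads
print(f"{r_mqa:.3f}")
```

The cache shrinks by exactly the number of heads, and the source-length term in the inverse intensity shrinks by the same factor.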
3.3 Asymmetric Decoder
The FLOPs analysis above showed that the FiD encoder consumes an order of magnitude more FLOPs than the decoder because the encoder and decoder are the same size but the encoder is applied to many more tokens. After applying layer-sparse cross-attention and multi-query attention, the decoder also takes up much less time for inference. Such an allocation may not be optimal, as the FiD decoder is responsible for a more challenging task than the standard T5 encoder: it has to assimilate and reason over information from many passages.
We propose to partially redress this imbalance by massively scaling up the decoder, by as much as 15x. Because the decoder is applied to fewer tokens, and because increased decoder dimension improves operational efficiency, such scaling only modestly increases inference cost. For example, Figure 2 shows that replacing the Base-sized decoder with an XL-sized decoder increases the total inference time per sample by only 21%. Fine-tuning costs also increase only modestly (Table 2). However, pre-training costs increase more (though still much less than the scaling factor of the decoder), as T5 pre-training uses a much smaller ratio of input length to output length. After reducing the decoder cross-attention memory costs, scaling the decoder only mildly increases activation memory, so that FiDO can still fit much larger batch sizes than vanilla FiD (Table 1). For the FiDO method we use decoders that are typically two T5 sizes larger than the encoder: Small-Large, Base-XL, Large-XXL and XL-XXL (as XXL is the largest T5 model).
Related Work
There exists a large body of retrieval-augmented approaches. Some particularly well known models are REALM (Guu et al., 2020), RAG (Lewis et al., 2020), RETRO (Borgeaud et al., 2022) and Fusion-in-Decoder (Izacard and Grave, 2021). FiD in particular has achieved state-of-the-art performance on a wide variety of tasks (Izacard and Grave, 2021; Izacard et al., 2022; Yu et al., 2022b) and in this work we focus on improving the performance-efficiency trade-offs for FiD. RETRO is another closely related retrieval-augmented model, as it uses a small encoder for retrieved context and a larger primary decoder like FiDO does. Unlike RETRO, FiDO’s efficiency improvements allow it to tractably attend to many retrieved passages with a much larger decoder.
Efficient Transformers
Our work builds heavily on existing insights into neural network and particularly Transformer speed. Previous work has found that data movement is often a constraining factor for computations on modern devices (Williams et al., 2009; Dao et al., 2022; Shazeer, 2019). Shazeer (2019) shows that autoregressive Transformers are particularly bandwidth bound during inference, and proposes multi-query attention as a partial solution. We find that this is exacerbated by the FiD setting, and adopt multi-query attention for FiDO to ameliorate the problem. Pope et al. (2022) also investigate multi-query attention, primarily in the context of efficient inference and parallelization for very large language models, whereas we focus on performance/cost trade-offs for the retrieval-augmented setting.
Another way to alleviate memory bandwidth constraints is to quantize model parameters and possibly activations (Dettmers et al., 2022; Zeng et al., 2022). Quantizing models reduces data that needs to be sent to device registers, and also reduces overall memory usage which allows for larger, more efficient batch sizes. Finally, it is possible to distill (Hinton et al., 2015; Gou et al., 2021) models into a smaller student model, which is cheaper for inference. However, knowledge distillation requires labeling a very large number of samples with the larger model, so reducing the inference costs of larger models is highly valuable.
Efficient retrieval-augmented models
FiDO lies in a body of work that attempts to improve the efficiency of retrieval-augmented or long-input models. One direction focuses on reducing the cost of the attention mechanism. LongT5 (Guo et al., 2022) routes long-range attention through a small number of global tokens. FiD-Light (Hofstätter et al., 2022), the most closely related work to FiDO, employs a similar mechanism for FiD, as the decoder attends to only the first $\frac{1}{K}$ proportion of representations of each retrieved passage. We opt to introduce sparsity in attention layers as in ReadTwice (Zemlyanskiy et al., 2021) instead of attention patterns. FiDO applies cross-attention from the decoder to the encoder in one out of every $K$ layers, which achieves a similar speedup to FiD-Light but with only a minor performance penalty. FiDO also incorporates multi-query attention, leading to a further order of magnitude reduction in decoder inference cost, and takes advantage of this to massively scale the decoder.
A different and complementary direction is to reduce the cost of reading retrieved passages. KG-FiD (Yu et al., 2022a) reranks retrieved passages and reads only the top passages, while Varshney et al. (2022) reads more retrieved passages only if it is not confident in its answer. Another approach is to pre-compute and store encoder representations in a memory and directly retrieve representations from memory, rather than re-encoding retrieved text (de Jong et al., 2022; Wu et al., 2022; Li et al., 2022). For standard FiD, the decoder actually makes up the bulk of the inference cost. FiDO reduces the cost of the decoder such that encoding retrieved passages becomes the bottleneck, increasing the benefit of the above approaches.
Experiments
All models are based on the T5.1.1 architecture (Raffel et al., 2020), pre-trained from scratch on C4 (Dodge et al., 2021) using JAX (Bradbury et al., 2018), FLAX (Heek et al., 2020), and T5X (Roberts et al., 2022). We employ the standard T5 training recipe except for a modified Adafactor (Shazeer and Stern, 2018) optimizer. Appendix A describes training in greater detail.
Downstream evaluation
We evaluate FiDO on open-domain question-answering datasets Natural Questions (Kwiatkowski et al., 2019), TriviaQA (Joshi et al., 2017) and WebQuestions (Berant et al., 2013). We report results on the open-domain QA splits from Lee et al. (2019). For all datasets, each sample is paired with a set of 100-word Wikipedia passages ranked by DPR (Karpukhin et al., 2020) score. The question is prepended to each retrieved passage, and then truncated to 256 tokens. The experiments in the paper use 40 retrieved passages to balance performance and speed, but our results hold across a wide range of retrieved passages.
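The input preparation described here can be sketched as follows. This is only an illustration: `build_fid_inputs` is a hypothetical helper, whitespace splitting stands in for the actual subword tokenizer, and the plain concatenation of question and passage is an assumption, since the exact input template is not given in this section.

```python
from typing import List


def build_fid_inputs(
    question: str,
    ranked_passages: List[str],
    num_passages: int = 40,
    max_tokens: int = 256,
) -> List[List[str]]:
    """Prepend the question to each top-ranked passage, then truncate."""
    inputs = []
    for passage in ranked_passages[:num_passages]:
        # Whitespace split is a stand-in for the real tokenizer (assumption).
        tokens = (question + " " + passage).split()
        inputs.append(tokens[:max_tokens])
    return inputs


# Toy data: 50 retrieved passages of 300 placeholder tokens each.
toy_passages = [" ".join(["tok"] * 300) for _ in range(50)]
encoder_inputs = build_fid_inputs("who wrote hamlet ?", toy_passages)

print(len(encoder_inputs), len(encoder_inputs[0]))
```

Each resulting token list is one independent encoder input, so the total source length seen by the decoder is at most the number of passages times the truncation length.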
Inference setup
For our main results we choose a setting that we believe is most representative for common use of retrieval-augmented models. We perform inference on a single TPUv4 and report inference time per sample (TPS) as measured by xprof (Google, 2020). We use a batch size of 64 (or the largest batch size that fits, if smaller) for the main experiments. Figures 1 and 2 use batch size 24 to ensure a like-for-like comparison, as it is the largest batch size that fits for vanilla FiD. All experiments use 40 passages of 256 tokens and output size of 32 tokens. Predictions are generated with greedy decoding as we found beam search did not meaningfully improve performance for the considered tasks. Analysis in Section 5.4 investigates how trade-offs change with input and output length, low batch size and different sampling methods.
5.2 Main results
Figure 3 shows performance as a function of inference time for FiD and FiDO. FiDO strongly outperforms FiD at any inference budget and achieves the same performance an order of magnitude faster. The following section investigates how each component of FiDO contributes to its performance. Table 5 compares FiDO to published results.
5.3 Components
First, Table 3 shows that layer-sparse cross-attention significantly reduces inference cost with modest performance degradation. Separately, Table 4 compares the inference speed and performance impact of layer-sparse cross-attention with the token-sparse cross-attention from FiD-Light. Reducing cross-attention layers and inducing encoder output sparsity by the same factor lead to similar speedups, but layer-sparse cross-attention achieves the inference speedup with much lower performance penalty.
Note that we find a much larger performance degradation from compressing the encoder output in our setting compared to the experiments in Hofstätter et al. (2022). Some exploratory experiments suggest that multi-task fine-tuning on large amounts of data, as done in FiD-Light, may ameliorate the performance penalty from compressing encoder output; however, even with such training Hofstätter et al. (2022) still report significant performance degradation, in contrast to LSA.
Layer-sparsity over a factor of 6 incurs greater performance penalties. However, as shown in Table 4, with LSA-6 cross-attention already makes up a small proportion of total decoder inference cost.
Multi-query attention
Table 3 shows that multi-query attention achieves a large cost reduction on top of layer-sparse cross-attention with minimal performance degradation, consistent with our analysis and findings from Shazeer (2019).
Decoder scale
We can see in Table 3 that increasing the size of the decoder leads to a significant improvement in performance at the cost of a modest increase in inference time. Figure 5 provides a visual comparison of the performance-inference profile for FiDO with and without asymmetric decoders and shows that asymmetric large decoders achieve a better trade-off.
5.4 Other analysis
Figure 6: Time per sample (TPS) as a function of retrieved passages (left) or the number of generated tokens (right) for Base FiD variants and FiDO-Base-XL. Legend: FiD, FiD + LSA, FiD + LSA + MQ, FiDO.
Our main results use a middle-of-the-road setting for FiD applications with a medium number of retrievals and a relatively short output, reflecting common knowledge-intensive tasks. However, it is interesting to ask how FiDO components affect speed for other settings. Figure 6 shows time per sample as a function of retrieved passages and length of the target output for each step from FiD to FiDO.
We first note that layer-sparse cross-attention and multi-query attention are critical across all settings. For standard output length, the asymmetric decoder is cheap for any reasonable number of retrieved passages, becoming negligible as a fraction of total inference time as the number of retrievals increases. As output length increases, the cost of the disproportionately large decoder rises, although it only becomes a substantial proportion of inference time for output length of 256-512 and above. For tasks with long outputs, such as summarization, one may want to reduce the level of decoder asymmetry (e.g. Base-Large rather than Base-XL).
Low batch size setting
For our primary investigation we focus on medium batch sizes (24+). There are two reasons one might care about smaller batch sizes: either because larger batches do not fit in memory or because they lead to excessive latency. The first constraint is not binding for FiDO: due to FiDO’s memory efficiency we are able to fit larger batches even for the XL-XXL model, and if necessary model size can be further extended with quantization (Zeng et al., 2022) and parallelism (Pope et al., 2022).
For real-time serving, latency can be a constraint, but in those settings it is common practice to use much smaller models which are distilled from larger teacher models (Gou et al., 2021). The student models can utilize a higher batch size, while the teacher models do not have latency constraints, so FiDO also applies to this use case.
For rare cases where a lower batch size is required, layer-sparse and multi-query attention are still important, but cannot fully eliminate the decoder as a bottleneck for inference (Table 6). The $\frac{1}{b}$ term in Equation 5 (where $b$ is the batch size) dominates, reflecting the fact that the model has to repeatedly load model parameters without spreading the cost over many samples.
Instead of scaling the decoder, it would be more cost-effective to apply more expensive sampling methods, because sampling methods increase the effective batch size. For example, beam search with large beams is nearly free at lower batch sizes.
Sampling
We do not apply beam search for our main experiments as decoder inference time is proportional to beam width for medium batch sizes and beam search does not improve performance on the considered set of tasks. Instead, we find that scaling decoder size provides a more cost-efficient way to add decoder capacity. Table 7 compares the performance vs time trade-offs from beam search and scaling the decoder for Natural Questions, and shows that scaling the decoder is significantly more effective. Beam search may be more important for other tasks, such as tasks with longer outputs.
Conclusion
We perform analysis of the performance-inference speed tradeoff for FiD, showing that the encoder uses more FLOPs but most time is spent in the decoder due to memory bandwidth constraints. We propose FiDO, an extension of FiD which removes most cross-attention layers and employs multi-query attention to vastly reduce the cost of the decoder. The resulting model spends most time in the encoder, consistent with compute analysis, which FiDO takes advantage of by strongly increasing the size of the decoder. We show that FiDO achieves much stronger performance for the same inference budget relative to existing FiD models.
Acknowledgements
We thank Livio Baldini Soares, Kenton Lee, Pat Verga, Iftekhar Naim and others at Google Research for insightful advice and discussion. Michiel de Jong is partially supported by NSF Awards IIS-1513966/ 1632803/1833137, CCF-1139148, DARPA Awards#: FA8750-18-2-0117, FA8750-19-1-0504, DARPA-D3M - Award UCB-00009528, Google Research Awards, gifts from Facebook and Netflix, and ARO# W911NF-12-1-0241 and W911NF-15-1-0484.
Limitations
One of the advantages of the Fusion-in-Decoder approach is that it uses the off-the-shelf T5 architecture with publicly available checkpoints. The proposed FiDO modifications strongly improve performance and inference speed for retrieval-augmented question-answering, but require pre-training from scratch. It is in general preferable to have a small number of checkpoints that can be fine-tuned for any application. For example, it may not be feasible to train different giant language models for use in the retrieval-augmented setting. Instead, the architectures for such large models may need to be a compromise for different use cases.
Ethics
In general the ethics concerns for this paper are similar to those for the large body of work studying retrieval-augmented language models. One distinction worth pointing out is that this work proposes a model with faster inference, which makes retrieval-augmented models more feasible to apply in practical settings and serve to users, and therefore inherently carries higher risk.
References
Appendix A Training
All experiments are built on the T5.1.1 architecture with the training recipe from T5 (Raffel et al., 2020). The first exception is the optimizer; we find that the second moment factoring and mixing schedule from Adafactor (Shazeer and Stern, 2018) can lead to instability, especially with unbalanced encoder and decoder sizes. Instead, we disable factoring and second moment mixing, leading to an optimizer that is a hybrid between Adafactor and Adam (Kingma and Ba, 2015).
The second difference from the training recipe arises from the observation that FiDO XL-XXL is unstable under the standard training regimen. We resolve the instability by restarting from a recent healthy checkpoint with a 10x decreased learning rate; this was necessary once.
During fine-tuning, we load not only model weights but also second moment estimates, which we find leads to better fine-tuning in general and particularly for asymmetric models. We fine-tune with learning rate 0.001 and batch size 64 for all datasets. For evaluation on test sets we select the checkpoint with the best validation performance.