Conditional Adapters: Parameter-efficient Transfer Learning with Fast Inference
Tao Lei, Junwen Bai, Siddhartha Brahma, Joshua Ainslie, Kenton Lee, Yanqi Zhou, Nan Du, Vincent Y. Zhao, Yuexin Wu, Bo Li, Yu Zhang, Ming-Wei Chang
Introduction
Large pretrained models have achieved groundbreaking results, but the main impediment to deploying them has been the cost of adaptation and inference. Due to the ever-growing size of pretrained models, finetuning, for example, has become increasingly expensive as it requires a separate copy of the full model and updates to all parameters for every downstream task. Parameter-efficient transfer learning methods such as Adapter (Houlsby et al., 2019) and Prompt Tuning (Lester et al., 2021) have been proposed to address this issue. These methods only update a small subset of parameters for each downstream task, allowing the model to retain knowledge and avoid catastrophic forgetting (Vu et al., 2022). Notably, these methods can match the accuracy of a fully finetuned model, while achieving better accuracy on out-of-domain data distributions (Lester et al., 2021; Awadalla et al., 2022).
Unfortunately, standard parameter-efficient transfer learning methods only bring parameter efficiency, not inference efficiency. For example, while only a few small projection matrices are added into the pretrained model in the Adapter approach, all the model inputs (such as tokens) still use all parameters during inference. Therefore, the inference speed is the same as (or slightly lower than) that of the full finetuning method. Moreover, prior studies have shown that these parameter-efficient learning methods are most effective when the size of the pretrained model is large (Lester et al., 2021), making many advantages of these methods difficult to realize in practice.
In this paper, we propose Conditional Adapter (CoDA), a parameter-efficient transfer learning method that offers both parameter and inference efficiency. CoDA is a generalization of the adapter approach, built with the following intuition – we can treat the pretrained model as a universal source of knowledge but only query against it for necessary inputs. Figure 2 compares CoDA with finetuning and standard adapter approaches. Similar to standard adapter approaches, our model adds and updates a small adapter in each layer, while fixing the pretrained Transformer blocks for downstream adaptation. Unlike previous approaches, however, CoDA assumes that many of the input token representations (of each layer) are not important for the prediction task and therefore do not require heavy computation. In such cases, the pretrained Transformer block can be skipped. Given that many tokens are not processed by the Transformer block, CoDA runs significantly faster than previous methods.
While conditional activation has clear speed benefits, CoDA must learn to select important tokens for heavy computation in order to maintain model accuracy. To this end, we introduce a soft top-$k$ operation for computing the token selection decision. This soft top-$k$ operation, which can be seen as a generalization of softmax and a relaxation of hard top-$k$, utilizes entropy-regularized optimization techniques similar to computational optimal transport (Cuturi, 2013). As a result, its output can be computed using fast and differentiable iterations, allowing token selection to be directly optimized for model performance.
We apply CoDA on encoder-heavy tasks and evaluate its effectiveness on three different domains – natural language processing, computer vision and speech processing. Overall, CoDA achieves 2 to 8 times inference speed-up over the standard adapter approach with moderate to no accuracy loss. Table 2 showcases our results by selecting one of the best performing tasks in each domain. We also conduct comprehensive ablation studies to analyze the effectiveness, efficiency and scalability of CoDA. For example, we found that with little to no router pretraining, existing dense pretrained models such as T5 (Raffel et al., 2020) can be efficiently converted into CoDA models to gain both parameter efficiency and speed advantages.
Related Work
Due to the ever-growing number of parameters in the pretrained Transformer models, various methods have been proposed for transfer learning with minimal parameter updates. Prompt tuning (Lester et al., 2021) and prefix tuning (Li and Liang, 2021) introduce new virtual token embeddings that can be finetuned as model parameters. Adapter approaches (Houlsby et al., 2019; He et al., 2021) add a small number of new, learnable parameters to each layer while keeping the pretrained parameters fixed. Another popular method, Low-Rank Adaptation (LoRA; Hu et al., 2021), injects learnable low-rank decomposition matrices into pretrained model parameters. In addition to requiring less storage cost, parameter-efficient methods have been shown to be more sample-efficient and achieve better out-of-domain generalization than standard finetuning. CoDA is an adapter approach but can be easily combined with other parameter-efficient methods such as LoRA to accelerate their inference.
Conditional computation
The development of sparsely and conditionally activated models has been a very active research area. For example, Mixture-of-Experts (MoE) models (Shazeer et al., 2017) and many recent advances (Du et al., 2022; Fedus et al., 2021) have been proposed to scale up the size of language models without increasing the computation cost. Many recent works have explored better token routing methods for MoE models, for example using random hashing (Roller et al., 2021), balanced assignment (Lewis et al., 2021) and expert-choosing router (Zhou et al., 2022). CoDA applies conditional computation to both attention and feed-forward blocks of the model, whereas MoE models only focus on sparse activation in the feed-forward blocks.
Similar to our approach, various recent methods have achieved computation efficiency by skipping computation on a subset of input tokens. However, the selection mechanism can be very different, such as using pooling (Nawrot et al., 2022), token merging (Bolya et al., 2023), token pruning (Rao et al., 2021; Yin et al., 2022), learned sigmoid gates (Bapna et al., 2020) and early exiting (Schuster et al., 2022). While most of the token merging and pruning methods have been proposed for vision tasks, we show that CoDA is applicable to multiple domains including text, vision and speech. In addition, token merging and our token selection method are built with different inductive biases and intuition. Token merging leverages redundancies in visual tokens, while token selection assumes a spike of token relevance. That is, only a few tokens are necessary for the prediction task. Another major difference is that CoDA dynamically routes and updates token representations in each layer, whereas if a token is pruned (or merged), it will never be re-used by subsequent layers. We believe our token routing mechanism is more suited for text and speech applications, such as question answering, where different tokens might play important roles in different layers, or given different input queries.
Finally, CoDA is closely related to a concurrent work, CoLT5 (Ainslie et al., 2023), which also utilizes conditional activation (token selection) for inference efficiency. The focuses of CoLT5 and CoDA are very different. CoLT5 specifically tailors its model architecture for long text (e.g. over 16k tokens), for example, by combining local attention with routed attention. The CoLT5 models are pre-trained from scratch and all parameters are finetuned for downstream tasks. In comparison, CoDA is directly initialized and adapted from an already pretrained dense model, and we optimize its performance on parameter-efficient transfer learning. The strengths of CoDA and CoLT5 can be combined for long text applications.
Efficient Transformer models
Many efficient Transformer variants have been proposed to accelerate model computation. Examples include creating fast attention variants (Wang et al., 2020a; Beltagy et al., 2020; Guo et al., 2022; Hua et al., 2022), searching network architectures (Press et al., 2019; So et al., 2021; Su et al., 2021) and utilizing non-attention neural modules for efficiency (Gulati et al., 2020; Lei, 2021). CoDA utilizes conditional computation as an orthogonal approach for efficiency.
Model compression
Apart from building efficient model architectures, model compression methods such as pruning (Han et al., 2016; Zhu and Gupta, 2017; Wang et al., 2020b; Xia et al., 2022) and distillation (Hinton et al., 2015; Kim and Rush, 2016; Turc et al., 2019; Lin et al., 2020) can be adopted to speed up model inference. Compared to these methods, CoDA retains all model parameters of the pretrained large model, and therefore avoids retraining a new model from scratch or knowledge forgetting caused by parameter removal. In addition, CoDA can be seen as a dynamic version of layer pruning because it can activate different Transformer layers for each token, and can be further combined with distillation to reduce the loss of accuracy caused by conditional computation.
Method
Throughout this and the experiment section, we build CoDA on top of parallel adapters (He et al., 2021). However, note that our method can be generalized to other types of adapters such as sequential adapters (Houlsby et al., 2019) and LoRA (Hu et al., 2021). We present additional experimental results using LoRA in Appendix B.3. Figure 3 illustrates our architecture and shows how CoDA computes its output by selecting only a small subset of input tokens to query against the pretrained model. When parallel adapters are used, CoDA introduces a small number of learnable parameters in the parallel branches, while the vast majority of model parameters (associated with the pretrained Transformer layers) remain fixed. In addition, CoDA only sends $k = \lceil n/r \rceil$ tokens for heavy processing, where $n$ is the number of input tokens. We define $r > 1$ as the reduction factor, a constant (such as 4) to control the computation saving.
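Given layer input $X \in \mathbb{R}^{n \times d}$ ($n$ tokens with model dimension $d$), we first apply layer normalization, namely $X_{\text{norm}} = \mathrm{LayerNorm}(X)$. The normalized input will be processed by the adapter branch and the conditional Transformer branch. Their outputs are then added and combined as the final output of the layer.
Let $F(\cdot)$ denote the transformation function of the adapter branch. The output is defined as
$$Z_{\text{adapter}} = F(X_{\text{norm}}) \qquad (1)$$
Conditional branch
The computation of the conditional branch takes three steps. First, each CoDA layer defines a router function $R(\cdot)$ to select $k$ tokens for the conditional branch. The router function in each layer returns two outputs
$$m, P = R(X_{\text{norm}}) \qquad (2)$$
where $P \in \{0, 1\}^{k \times n}$ is a matrix of $k$ one-hot rows indicating the selected tokens, and $m \in [0, 1]^n$ is a weight mask whose $j$-th entry is the selection weight of the $j$-th input token (zero if the token is not selected).
After the routing decision is made, the input representations of the selected tokens can be collected using a matrix multiplication,
$$X_{\text{routed}} = P X_{\text{norm}} \in \mathbb{R}^{k \times d} \qquad (3)$$
where $k$ rows in $X_{\text{norm}}$ are selected to construct the $k$-by-$d$ matrix $X_{\text{routed}}$. Similar to a standard Transformer layer, the conditional branch applies attention and feed forward transformations to the selected input:
$$\tilde{Z}_{\text{attn}} = \mathrm{Attention}(X_{\text{routed}}) \qquad (4)$$
$$\tilde{Z}_{\text{ffn}} = \mathrm{FFN}(\mathrm{LN}(X_{\text{routed}} + \tilde{Z}_{\text{attn}})) \qquad (5)$$
where $\tilde{Z}_{\text{attn}}, \tilde{Z}_{\text{ffn}} \in \mathbb{R}^{k \times d}$ denote the outputs of the attention network and the feed-forward network, respectively.
We consider two attention variants which differ in how they compute key-value vectors. One variant applies a $k$-to-$k$ attention using $X_{\text{routed}}$ as both the query vectors and key-value vectors. The other variant applies a $k$-to-all attention using the entire input vectors $X_{\text{norm}}$ as the attention keys and values. The $k$-to-all variant runs slower but obtains higher quality close to the full model. We compare the performance of the two variants in Section 5.
The attention and feed-forward outputs $\tilde{Z}_{\text{attn}}$ and $\tilde{Z}_{\text{ffn}}$ are combined and projected back to the same shape as the original input
$$Z_{\text{cond}} = P^\top (\tilde{Z}_{\text{attn}} + \tilde{Z}_{\text{ffn}}) \in \mathbb{R}^{n \times d} \qquad (6)$$
Finally $Z_{\text{cond}}$ merges with the adapter output and the original input of the current layer to produce the output of the layer:
$$Y = X + Z_{\text{adapter}} + m \odot Z_{\text{cond}} \qquad (7)$$
$m \odot Z_{\text{cond}}$ is an element-wise multiplication that scales the rows of $Z_{\text{cond}}$ using weight $m$. This operation can be seen as a gating operation, where the hidden state $Z_{\text{cond}}[i]$ of the $i$-th token is weighted by the token selection score $m[i]$ assigned by the router. This enables gradient propagation from $m$ to the router parameters, such that the token selection can be jointly optimized with other model components during training.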
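The data flow of a single layer can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper names (`top_k_indices`, `coda_layer`), the toy `adapter_fn` and `heavy_fn`, and the use of lists instead of tensors are all hypothetical, and the router weights are passed in rather than learned.

```python
# Minimal sketch of one conditional-adapter layer on plain Python lists.
# All names and the toy branch functions below are hypothetical.

def top_k_indices(scores, k):
    """Indices of the k largest scores, returned in original token order."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(order[:k])

def coda_layer(x, x_norm, weights, k, adapter_fn, heavy_fn):
    """Layer output = input + adapter output + gated conditional output.

    x, x_norm : lists of n token vectors (layer input and its normalized form)
    weights   : n selection weights in [0, 1] produced by some router
    adapter_fn: cheap per-token function, applied to every token
    heavy_fn  : expensive function mapping a list of k vectors to k vectors
    """
    n, d = len(x), len(x[0])
    selected = top_k_indices(weights, k)          # which tokens are routed
    x_routed = [x_norm[j] for j in selected]      # gather the k selected rows
    z_routed = heavy_fn(x_routed)                 # heavy branch sees k tokens only
    z_cond = [[0.0] * d for _ in range(n)]        # scatter back to n rows
    for row, j in zip(z_routed, selected):
        z_cond[j] = row
    # Selection weight of each token; zero for tokens that were not routed.
    m = [weights[j] if j in selected else 0.0 for j in range(n)]
    y = []
    for j in range(n):
        z_adapter = adapter_fn(x_norm[j])
        y.append([x[j][c] + z_adapter[c] + m[j] * z_cond[j][c] for c in range(d)])
    return y, selected

# Toy usage: 4 tokens of dimension 2, route the 2 highest-weighted tokens.
tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0], [0.5, 0.5]]
out, routed = coda_layer(
    tokens, tokens, [0.9, 0.1, 0.8, 0.2], k=2,
    adapter_fn=lambda v: [0.1 * c for c in v],
    heavy_fn=lambda rows: [[2.0 * c for c in r] for r in rows],
)
```

Tokens that are not routed (here tokens 1 and 3) are updated by the cheap adapter only, which is where the computation saving comes from.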
Learned router
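Given the token representation $X_{\text{norm}}$, the router first computes a dot-product score $s = X_{\text{norm}} w \in \mathbb{R}^n$, where $w \in \mathbb{R}^d$ is the parameter vector of the router in this layer. The score $s$ is then normalized by a function $f(\cdot): \mathbb{R}^n \to [0, 1]^n$ and clipped to produce the selection weight $m$:
$$\lambda = f(s) \qquad (8)$$
$$m = \lambda \odot \mathrm{Top}(\lambda, k) \qquad (9)$$
Here $\mathrm{Top}(\lambda, k) \in \{0, 1\}^n$ is an indicator function which returns a binary mask indicating the top-$k$ values in $\lambda$. The one-hot matrix $P$ defined in (2) can be created according to $\mathrm{Top}(\lambda, k)$. In short, the $k$ highest values of $\lambda$ will be selected by the router.
Function $f(\cdot)$ must remain differentiable with respect to its input ($s$ in this case) such that we can update the router parameters $w$ during training. One possible choice for $f(\cdot)$ is the sigmoid activation function which normalizes the values in $s$ independently. However, this does not explicitly model the constraint that we need to select $k$ tokens from $n$ available tokens. Consider a simple case where $k = 1$; a natural choice for $f(\cdot)$ would be the softmax function. Since softmax provides global normalization over the input scores, a gradient update to increase one of the scores would also decrease the other scores, a desirable effect for learning top-1 selection.
We hypothesize that a soft top-$k$ operator that generalizes softmax should be used for general $k > 1$. This is indeed possible by formalizing soft top-$k$ as the following optimization problem:
$$f(s) := \arg\max_{\lambda} \; s^\top \lambda + \epsilon H(\lambda) \quad \text{s.t.} \quad \mathbf{1}^\top \lambda = k, \;\; \lambda[i] \in [0, 1] \;\; \forall i = 1, \dots, n \qquad (10)$$
Here $H(\lambda) = \sum_i -\lambda[i] \log \lambda[i]$ is a generalized entropy function (applied to any positive vector $\lambda$ instead of a distribution), and $\epsilon > 0$ is a small coefficient.
This optimization problem is closely related to the softmax and top-$k$ operation. Specifically, when $\epsilon = 0$, it becomes a linear program which returns $\mathrm{Top}(s, k)$ as the solution. In addition, when $k = 1$, it can be shown that its solution is $\mathrm{softmax}(s/\epsilon)$. Broadly speaking, (10) will return a soft top-$k$ mask and the smoothness is controlled by $\epsilon$ (and hence $\epsilon$ must be positive to act as a temperature).
Problem (10) does not have a closed-form solution for arbitrary $\epsilon > 0$ and $k > 1$, but its solution can be obtained using an iterative algorithm. Let $a \in \mathbb{R}$ and $b \in \mathbb{R}^n$ be two auxiliary variables, initialized to zeros. The solution has the form $\lambda = \exp((s + b + a)/\epsilon)$, and the values of $a$ and $b$ are obtained using the following iterative updates:
$$a' = \epsilon \ln(k) - \epsilon \ln \Big\{ \sum_{i=1}^{n} \exp\Big( \frac{s[i] + b[i]}{\epsilon} \Big) \Big\}, \qquad b' = \min(-s - a', \, 0) \qquad (11)$$
In practice, we use a fixed number of $T$ iterations and the function $f(s)$ returns $\exp((s + b + a)/\epsilon)$ using $a$ and $b$ from the last iteration. The function $f(s)$ remains differentiable with respect to $s$ using these iterative updates, so we can train the router jointly with other model parameters. We provide additional discussion and the derivation of the updates in Appendix §C.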
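The iterations can be illustrated with a short standard-library sketch. It assumes the update rules $a' = \epsilon \ln k - \epsilon \ln \sum_i \exp((s[i] + b[i])/\epsilon)$ and $b' = \min(-s - a', 0)$, with the returned mask $\exp((s + b + a)/\epsilon)$. The function name `soft_top_k` and the default values of `eps` and `num_iters` are hypothetical choices for illustration, not the settings used in the experiments.

```python
import math

def soft_top_k(s, k, eps=0.1, num_iters=50):
    """Approximate soft top-k mask for scores s (a sketch, not a tuned kernel)."""
    n = len(s)
    a, b = 0.0, [0.0] * n                     # auxiliary variables start at zero
    for _ in range(num_iters):
        # a' = eps*ln(k) - eps*ln(sum_i exp((s[i] + b[i]) / eps)),
        # computed with the log-sum-exp trick for numerical stability.
        z = [(s[i] + b[i]) / eps for i in range(n)]
        z_max = max(z)
        lse = z_max + math.log(sum(math.exp(v - z_max) for v in z))
        a = eps * math.log(k) - eps * lse
        # b' = min(-s - a', 0) keeps every output weight at or below 1.
        b = [min(-s[i] - a, 0.0) for i in range(n)]
    return [math.exp((s[i] + b[i] + a) / eps) for i in range(n)]

# Toy usage with hypothetical router scores.
mask = soft_top_k([2.0, 1.0, 0.5, -1.0], k=2)
```

With these toy scores the two largest entries of `mask` end up close to 1 and the rest close to 0. Setting `k=1` reduces the procedure to a softmax with temperature `eps`, since the clipping variable `b` then stays at zero.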
3.2 Training
CoDA can be directly initialized from an existing Transformer model. Given a pretrained model such as T5 (Raffel et al., 2020), the Transformer layers are directly re-used and copied in the conditional branches of CoDA, and only the adapter and router parameters are randomly initialized. Because pretraining a large dense model can be expensive, our method reduces the overall training cost.
The routers and neural network components in CoDA must co-operate and be optimized for accurate model predictions. When the available finetuning data is limited, a random initialization for the router (and adapter) parameters can be sub-optimal. We demonstrate that CoDA can be further pretrained using the same pretraining objective as the dense model, in order to enhance downstream performance. Importantly, CoDA requires significantly fewer training steps during pretraining, since most of its parameters are taken from an already pretrained model. We show that the cost of CoDA pretraining can be 10-30x lower than the pretraining of its original dense model. We present this analysis in Section 5.
Finally, we train CoDA on downstream tasks by only updating the adapter, router and layer normalization parameters. The size of the adapters is small (e.g. 5M parameters), and each router and layer normalization block only introduces $d$ parameters, where $d$ is the model dimension. As a result, CoDA remains parameter-efficient similar to previous adapter and prompt-tuning methods.
Experimental setup
CoDA is evaluated on three domains including natural language processing (NLP), computer vision and speech processing, and on a range of applications such as classification, question answering, summarization and speech recognition. The experiments are organized as follows: We first demonstrate the effectiveness of CoDA and conduct analyses on its design choices using the publicly available T5 models (§5). In our final results (§6), we pretrain Transformer models from scratch and extend our evaluation to vision and speech domains.
We use the C4 corpus (Raffel et al., 2020) for pretraining text models. For speech models, we use the LibriLight corpus (Kahn et al., 2020) for pretraining. Our vision Transformer models use the same data and training procedure in Pix2Struct (Lee et al., 2022). Our finetuning datasets for text models include the MNLI (Williams et al., 2018), RTE (Dagan et al., 2005; Haim et al., 2006; Giampiccolo et al., 2007; Bentivogli et al., 2009), BoolQ (Clark et al., 2019), SQuAD (Rajpurkar et al., 2016) and XSum (Narayan et al., 2018) datasets. The speech models are evaluated on the speech recognition task using the LibriSpeech dataset (Panayotov et al., 2015). Finally, we use the OCR-VQA (Mishra et al., 2019), DocVQA (Mathew et al., 2021), and Screen2Words (Wang et al., 2021) datasets for vision models.
Understanding and Analyzing CoDA
We present several analyses to validate the design choices of CoDA in this section. We initialize CoDA using the version 1.1 release of T5 checkpoints (https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511), and perform CoDA pretraining using the same setting as the T5 models. During pretraining, we fix the routing capacity $k$ for the given input sequence length $n$. We do not tune the value of $k$ for pretraining, but will report the results of using different $k$ values in finetuning. We perform 100K gradient steps, which is 10% of the total number of steps used to train the T5 dense models. The overall computational cost is over 20x less than the full training of dense models, since CoDA only applies heavy computation on less than half of the tokens.
For simplicity, we evaluate on classification tasks for various ablation studies of CoDA. Specifically, we report results on the MNLI, RTE and BoolQ datasets, and test three different model sizes including the Base, Large and XL size of T5. We will extend our evaluation to generation tasks such as question answering in the full result section (§6).
Can CoDA be fast and accurate?
Table 1 presents the finetuning results of CoDA. As a comparison, we also report the results of Parallel Adapter, which is similar to CoDA except that it applies the expensive Transformer layers to all input tokens. This baseline constitutes an upper bound, and is a strong baseline that has been reported as the best among a range of adapter and prompt tuning methods (He et al., 2021). As shown in Table 1, CoDA can achieve 3-5x computation reduction ($r = 3, 5$) in the Transformer layers at a cost of less than 1.0 point drop on average accuracy. As expected, our $k$-to-all attention variant achieves consistently better accuracy than the $k$-to-$k$ variant, since it can access the full attention context. On the other hand, the $k$-to-$k$ attention variant runs faster in practice, which can be beneficial for tasks with very long inputs. We select the $k$-to-all version in the final result section (§6).
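To make the cost difference between the two attention variants concrete, the following sketch counts query-key score pairs only. It assumes $k = \lceil n/r \rceil$ routed tokens and ignores projections, feed-forward layers and hardware effects; the function name `attention_score_counts` and the example sizes are hypothetical.

```python
import math

def attention_score_counts(n, r):
    """Number of query-key pairs scored per layer, per attention variant."""
    k = math.ceil(n / r)              # routed tokens for reduction factor r
    return {
        "k": k,
        "dense": n * n,               # every token attends to every token
        "k_to_all": k * n,            # routed queries, all tokens as keys/values
        "k_to_k": k * k,              # routed queries, routed keys/values
    }

counts = attention_score_counts(n=512, r=4)
```

Under this count, the $k$-to-all variant reduces attention scoring by a factor of $r$ and the $k$-to-$k$ variant by roughly $r^2$, which is consistent with the latter running faster on long inputs.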
How many pretraining steps are needed?
Figure 4 plots the finetuning accuracy by varying the number of pretraining steps for CoDA. Because CoDA can be initialized using pretrained dense models, it requires as few as 20K steps to obtain competitive finetuning results. Of course, using more pretraining steps can improve the downstream accuracy. The fact that CoDA can be quickly updated without repeating the expensive pretraining will be very beneficial in real-world applications.
Does learned routing matter?
We analyze the impact of learned routing in Table 2 by comparing our soft top-$k$ router with other router implementations. We implement a variant that replaces soft top-$k$ with the sigmoid activation function, so the selection weight of each token activates on its own (without considering the capacity constraint). As shown in the table, this variant achieves worse accuracy on almost all tasks and model sizes tested, getting 2.0 points worse on average. We also implement a “no-learning” baseline that simply selects the first $k$ tokens, which is equivalent to truncating the input sequence. (We always include the question text for BoolQ, to achieve higher accuracy.) This baseline performs much worse, resulting in a decrease of more than 10 points in accuracy for small $k$ (and equivalently large $r$). This analysis confirms the importance of learning a good routing in order to retain strong model performance.
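The three selection strategies compared here can be contrasted on a toy example. This is a sketch with hypothetical helper names and made-up scores; it only illustrates what each strategy selects, not how accuracy is affected.

```python
import math

def top_k_mask(values, k):
    """Binary mask marking the k largest entries (score-based routing)."""
    order = sorted(range(len(values)), key=lambda j: values[j], reverse=True)
    chosen = set(order[:k])
    return [1 if j in chosen else 0 for j in range(len(values))]

def first_k_mask(n, k):
    """'No-learning' baseline: keep the first k tokens, i.e. truncate the input."""
    return [1 if j < k else 0 for j in range(n)]

def sigmoid_weights(scores):
    """Independent per-token weights; nothing ties their total to the capacity k."""
    return [1.0 / (1.0 + math.exp(-v)) for v in scores]

scores = [-1.0, 0.2, 3.0, -0.5, 2.5, 0.1]    # hypothetical router scores
k = 2
learned = top_k_mask(scores, k)               # picks tokens 2 and 4
truncated = first_k_mask(len(scores), k)      # picks tokens 0 and 1
independent = sigmoid_weights(scores)         # weights need not sum to k
```

Truncation ignores the scores entirely, and the sigmoid weights are not coupled across tokens, so raising one score does not lower the others.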
Full Results
In this section, we apply our best training recipe to all tasks and application domains. We first pretrain dense Transformer models, followed by the CoDA training procedure in §3.2. Our speech models are pretrained using a masked language modeling (MLM) objective similar to BERT (Devlin et al., 2019), and random quantized output label space (Chiu et al., 2022). Our vision and text models use an encoder-decoder architecture similar to T5 but incorporate a few changes. Following PaLM (Chowdhery et al., 2022), we use multi-query attention (Shazeer, 2019) that shares the same key and value projection for multiple query heads. We only use 6 decoder layers and increase the feed forward hidden size (to compensate for the decrease in the number of layers). These modifications have a neutral effect on model quality, but speed up auto-regressive decoding significantly. We will show CoDA is compatible with these changes and can further speed up inference by a considerably large factor. We provide more details of our experimental setup in Appendix A.
NLP results
In addition to the classification datasets used in Section 5, we also evaluate our final models on the SQuAD, ReCoRD and XSum datasets which require generating an answer or a summary given the input. Table 6 contains the finetuning results of XL models. Compared to the parallel adapter baseline that uses full computation, CoDA achieves 3x and 5x computation reduction with only 1.0 and 1.7 point loss in average score.
Figures 6 and 7 highlight the scaling trend of CoDA. CoDA runs much faster with slightly worse quality than the parallel adapter baseline. This is expected because the baseline processes all tokens in every layer, whereas CoDA only selects $1/r$ of tokens for heavy processing. Importantly, this quality gap reduces as the model size increases (as shown in Figure 6), making CoDA a computationally efficient choice for large models. Indeed, CoDA can trade off quality for speed by varying the number of selected tokens. Figure 7 (left) demonstrates that CoDA achieves a much stronger speed-quality trade-off compared to dense models without conditional computation. The black line indicates the results of Parallel Adapter when the model size grows from Small to XL, and each blue line represents the speed-quality trade-off of CoDA using different values of $r$. Moreover, Figure 7 (middle) shows that larger CoDA models exhibit higher inference speed-ups. These observations are consistent on other tasks. We provide additional results in Appendix §B.
Speech recognition results
We further validate the performance of CoDA in the speech domain. Our model uses a Transformer encoder and a 2-layer LSTM Transducer (Graves, 2012). Similar to NLP setups, we test the performance of the speech model on 3 scales – Base, Large and XL (see Appendix A for details). Table 9 demonstrates that with sizable reduction ratios, the change in word error rate (WER) is consistently minimal on the test-clean and test-other sets of LibriSpeech across different model sizes (and on other sets in §B.2). Moreover, our results are comparable to the top-performing models, such as w2v-BERT (Chung et al., 2021) and BEST-RQ (Chiu et al., 2022), that are fully finetuned by updating all parameters. Figure 7 (right) highlights again that applying conditional computation leads to a better speed-quality trade-off compared to dense models.
Vision results
We extend our experiments to visual tasks that involve natural language within the image, such as documents and user interfaces. Our experiments are based on Pix2Struct (Lee et al., 2022), where an image-encoder-text-decoder is pretrained by learning to predict simplified HTML from webpage screenshots. Table 9 shows the results on three tasks that were also evaluated in the original Pix2Struct paper. In OCR-VQA and Screen2Words, we observe relatively small drops in performance when reducing the number of routed tokens (i.e. patches). When the capacity is 1/16th of the original sequence length, leading to around 13× speedup, we only lose about 1 point. We speculate that this is due to the high-level sparsity in the inputs for these two tasks. For DocVQA, where there is comparatively more textual information, we observe a steeper performance-speed trade-off but still achieve an 8× speedup with a 4-point drop.
To provide a more intuitive understanding of why CoDA works, we visualize the router behavior for the OCR-VQA model in Figure 10. We show which patches the routers prefer the most (warmest colors) and least (coolest colors), for several layers. The first, immediately obvious, observation is that the router avoids low-frequency patches, i.e. patches likely to be “whitespace”, since they can be adequately handled by the cheap adapter layers. The second, more subtle, observation is that the router progressively converges on a small number of key patches that we hypothesize serve as representations for larger regions. The visualization confirms that CoDA is able to select meaningful and representative patches that are useful for the prediction task.
Conclusion and Limitation
We present CoDA, a parameter-efficient adapter method that enables fast inference. CoDA relies on conditional computation to selectively activate model computation on important input units, providing a novel way to balance model expressivity and efficiency.
In this work, we focus on encoder-heavy applications such as summarization, speech recognition and visual question answering, by applying our method to the encoder. One limitation of CoDA is that the current routing mechanism (i.e. token selection in a given sequence) is not directly applicable to decoder-only models for auto-regressive token generation. Enabling fast token generation using conditional activation in decoder layers is an interesting direction we plan to explore in future work.
Acknowledgements
We would like to thank Rama Pasumarthi, Hongkun Yu, Kelvin Guu, Zhuyun Dai, Timothy Dozat, Raphael Hoffmann, Tao Wang, Tal Schuster, Ziwei Ji, Frederick Liu and Slav Petrov for helpful advice and discussion.
References
Appendix A Experimental details
For our text and vision experiments, we implement our models using JAX [Bradbury et al., 2018]. Specifically, our training and model modules are built on top of the T5X, Flax and Flaxformer frameworks [Roberts et al., 2022, Heek et al., 2020]. Following the T5 v1.1 implementation and PaLM [Chowdhery et al., 2022], our Transformer models use the GLU variant [Shazeer, 2020] as the feed forward network and multi-query attention [Shazeer, 2019] as the attention block. These modifications are shown to enhance modeling capacity and speed up decoding respectively.
For the speech experiments, we use TensorFlow [Abadi et al., 2015] and the Lingvo framework [Shen et al., 2019]. The state-of-the-art Transformer variant for speech recognition is the Conformer architecture [Gulati et al., 2020] which additionally uses depth-wise convolution in each layer. Since the convolution operation is applied to consecutive inputs and does not immediately support routing, we use the standard Transformer architecture [Vaswani et al., 2017] instead. Swish activation is used in the feed forward blocks, following Gulati et al. [2020]. We provide the model configuration details in Table 3.
Model training
We use the same data and procedure described in T5 [Raffel et al., 2020], BEST-RQ [Chiu et al., 2022] and Pix2struct [Lee et al., 2022] for pre-training the respective text, speech and vision models. We use the same training hyper-parameters, such as batch size, input sequence length, the number of pre-training steps and the choice of optimizer and learning rate scheduling. All models have been pre-trained using 128 or 256 TPUv3/TPUv4 chips.
We run CoDA pre-training for text and vision models, using an additional 100K steps and 200K steps respectively. For text models, the input sequence length is and we set the number of selected tokens . For vision models, the input sequence contains image patches and we set . CoDA pre-training is not used for our speech models because there is sufficient fine-tuning data. Following standard practice in speech, we use the 1K hour data from the LibriSpeech dataset [Panayotov et al., 2015] and another 30K hour data generated using the noisy student self-training method [Xie et al., 2020, Zhang et al., 2022].
Table 4 lists the hyper-parameters used for fine-tuning, including the sequence length, learning rate, batch size and the number of fine-tuning steps used. For NLP datasets, we set the maximum input length and decoding length to the 98th percentile of lengths in the training set. For vision datasets, we set the input length following the suggested values in Pix2struct. We also find that annealing the number of routed tokens $k$ can achieve better finetuning results. Specifically, we decrease $k$ linearly from the sequence length down to the target value over an initial portion of the finetuning steps.
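The linear annealing of the number of routed tokens can be sketched as a simple schedule. The function name `annealed_k` and the `anneal_fraction` default are hypothetical; the actual fraction of steps used for annealing is not assumed here.

```python
def annealed_k(step, total_steps, seq_len, target_k, anneal_fraction=0.1):
    """Number of routed tokens at a given finetuning step.

    Starts at seq_len (all tokens routed) and decreases linearly to target_k
    over the first anneal_fraction of the steps, then stays at target_k.
    anneal_fraction is an illustrative assumption, not a reported setting.
    """
    anneal_steps = max(1, int(total_steps * anneal_fraction))
    if step >= anneal_steps:
        return target_k
    progress = step / anneal_steps
    return round(seq_len - progress * (seq_len - target_k))

# Toy usage: 1000 finetuning steps, 512 input tokens, target of 128 routed tokens.
schedule = [annealed_k(t, 1000, 512, 128) for t in (0, 50, 100, 500)]
```

Starting from the full sequence means the model initially behaves like its dense counterpart, and the router is only gradually forced to drop tokens.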
Appendix B Additional results
Table 5 contains the complete fine-tuning results on the 6 language datasets. As discussed in §6, the gap between CoDA and its counterpart without conditional computation is large at Base size. As the model size increases, CoDA retains almost the same level of quality given 3x computation reduction ($r = 3$). The reduction leads to decoding speed-ups, as shown in Figure 11. More importantly, we see that larger models benefit more from CoDA, achieving a speed-up factor close to the reduction factor $r$. These results highlight the potential of CoDA for large-scale models, which we plan to investigate in future work.
B.2 Speech
Table 6 extends Table 9 by including WER results on dev-clean and dev-other splits. From the table, one can observe that the XL model with CoDA is consistently better than the Large parallel adapter model on each split, and the Large model with CoDA is also consistently better than the Base PA model on each split. Given the inference speeds for CoDA models shown in Table 12, larger CoDA models are generally faster and better than smaller dense ones (even with PA) with regard to either time cost or computation GFLOPs. Therefore, CoDA is likely to help scale up ASR models at a reasonable computation and time cost.
B.3 Combining CoDA and LoRA
CoDA can be easily combined with other types of adapter methods. To see this, we additionally implemented a variant that combines with Low-Rank Adapter [LoRA; Hu et al., 2021], which is another parameter-efficient transfer learning method that recently became the most popular choice for LLMs. We incorporate the latest development suggested in the QLoRA paper [Dettmers et al., 2023], which adds low-rank adapters to every linear projection matrix in the Transformer layers. This is found to obtain better fine-tuning performance than the original implementation. Our CoDA variant with LoRA simply removes the parallel adapter branches and instead adds low-rank adapters to the projection matrices of the pretrained layers.
Table 7 shows the finetuning results. The new LoRA baseline achieves stronger accuracy than the Parallel Adapter baseline (84.0 vs. 82.9 on average), highlighting the effectiveness of recent development on LoRA. In addition, our CoDA variant using LoRA still achieves very close accuracy compared to its dense counterpart (84.0 vs. 84.0 or 83.7 on average). We believe the additional results strengthen our claims – that CoDA enables a strong trade-off between accuracy and efficiency using conditional activation, and this technique can be combined with other developments in PETL.
Appendix C Soft top-$k$ algorithm
We present the derivation of the iterative updates (11) for solving the soft top-$k$ problem (10) in Section 3. The soft top-$k$ operation is defined as a maximization problem (10). For the derivation, we rewrite it as an equivalent minimization problem:
Note that the term $\epsilon \mathbf{1}^\top \lambda$ is a constant, $\epsilon k$, under the budget constraint, but we include it in the minimization objective to make the later derivation simpler.
The objective function (12) is strongly convex and the solution space of $\lambda$ is a convex set. As a result, strong duality holds, so we can instead solve the dual problem. The dual problem exchanges the $\min$ and $\max$ operators in (13):
The optimal solution must satisfy the Karush-Kuhn-Tucker (KKT) conditions [Kuhn and Tucker, 2014], namely
Substituting $\lambda$ using the above equation in (14), the dual problem now has a simple form:
We can solve this problem using coordinate descent [Wright, 2015] by successively maximizing the function with either $a$ or $b$ fixed. That is, we find the optimal $a$ that maximizes the dual objective given a fixed $b$, and vice versa. This leads to the iterative updates (11) described in Section 3.
In short, we obtain the iterative updates for the soft top-$k$ problem (10) by solving its dual problem and performing coordinate descent on the dual variables $a$ and $b$. The iterative updates are in fact the coordinate descent steps.
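For readers following the algebra, the chain of steps above can be sketched in LaTeX as follows. This is a reconstruction under stated assumptions rather than a verbatim copy of the displayed equations. It assumes that (10) has the form $\max_\lambda s^\top\lambda + \epsilon H(\lambda)$ subject to $\mathbf{1}^\top\lambda = k$ and $\lambda[i] \in [0,1]$, with $H(\lambda) = -\sum_i \lambda[i]\ln\lambda[i]$, and that the dual variables $a$ (budget constraint) and $b$ (upper bounds) are kept in the same units as the scores $s$. A different scaling of $a$ and $b$ gives an equivalent derivation with $\epsilon$ appearing in different places. Tags (12) to (14) follow the references in the text.

```latex
\begin{align}
& \min_{\lambda}\; -s^\top\lambda - \epsilon H(\lambda) - \epsilon\,\mathbf{1}^\top\lambda
  \quad \text{s.t.}\;\; \mathbf{1}^\top\lambda = k,\;\; \lambda[i]\in[0,1]\;\;\forall i
  \tag{12} \\
& \min_{\lambda}\;\max_{a,\; b\le 0}\; -s^\top\lambda - \epsilon H(\lambda) - \epsilon\,\mathbf{1}^\top\lambda
  + a\,(k - \mathbf{1}^\top\lambda) + b^\top(\mathbf{1} - \lambda)
  \tag{13} \\
& \max_{a,\; b\le 0}\;\min_{\lambda}\; -s^\top\lambda - \epsilon H(\lambda) - \epsilon\,\mathbf{1}^\top\lambda
  + a\,(k - \mathbf{1}^\top\lambda) + b^\top(\mathbf{1} - \lambda)
  \tag{14} \\
& \text{Stationarity in } \lambda[i]:\quad
  -s[i] + \epsilon\ln\lambda[i] - a - b[i] = 0
  \;\Longleftrightarrow\;
  \lambda[i] = \exp\!\Big(\tfrac{s[i] + a + b[i]}{\epsilon}\Big)
  \notag \\
& \text{Dual after substitution:}\quad
  \max_{a,\; b\le 0}\; k\,a + \mathbf{1}^\top b
  - \epsilon\,\mathbf{1}^\top \exp\!\Big(\tfrac{s + a + b}{\epsilon}\Big)
  \notag \\
& \text{Coordinate steps:}\quad
  a' = \epsilon\ln k - \epsilon\ln\sum_{i}\exp\!\Big(\tfrac{s[i] + b[i]}{\epsilon}\Big),
  \qquad
  b'[i] = \min\big(-s[i] - a',\; 0\big)
  \notag
\end{align}
```

In this form the role of the extra term $-\epsilon\,\mathbf{1}^\top\lambda$ is visible: it cancels the constant $\epsilon$ produced by differentiating $\epsilon\,\lambda[i]\ln\lambda[i]$, so the stationarity condition has no stray constant. The constraint $\lambda[i] \ge 0$ needs no multiplier because the exponential form is always positive.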
C.2 The $\epsilon$-scaling trick
The iterations of $a$ and $b$ will converge, but the number of iterations needed can be very large for small $\epsilon$. In practice, we only perform a small number of iterations and return the corresponding $\lambda$, which may be close to but not exactly the solution of (12). To improve convergence given a small number of iterations, we apply an empirical trick called the $\epsilon$-scaling heuristic [Schmitzer, 2019]. Let $\epsilon_t$ denote the value of $\epsilon$ at the $t$-th iteration. We initialize $\epsilon_0$ to a larger value and gradually reduce $\epsilon_t$ to the target $\epsilon$. Specifically, we set $\epsilon_t = \max(\beta \epsilon_{t-1}, \epsilon)$ at the $t$-th iteration, using a scaling constant $\beta$. We use throughout our experiments, and for text and vision models and and for speech models. Using a larger number of iterations leads to better convergence but we found sufficient for our experiments.
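The procedure described above can be sketched in a few lines of plain Python. This is a minimal illustration under assumptions, not the implementation used in our experiments: it assumes the coordinate updates take the form $a' = \epsilon \ln k - \epsilon \ln \sum_i \exp((s[i] + b[i])/\epsilon)$ and $b'[i] = \min(-s[i] - a', 0)$, and that $\epsilon$ is decayed geometrically down to its target. The function name `soft_top_k` and the default values of `num_iters`, `eps_init` and `beta` are placeholders chosen for the example, not the settings reported in the paper.

```python
import math


def _logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def soft_top_k(scores, k, eps, num_iters=20, eps_init=None, beta=0.7):
    """Approximate soft top-k weights lambda[i] in [0, 1] with sum close to k.

    scores:    list of per-token scores s[i]
    eps:       target entropy weight (smaller means closer to hard top-k)
    eps_init:  optional larger starting value for the epsilon-scaling heuristic
    beta:      geometric decay factor for epsilon (placeholder value)
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    eps_t = eps if eps_init is None else max(eps_init, eps)
    a = 0.0
    b = [0.0] * len(scores)
    for _ in range(num_iters):
        # a-step: choose a so that the weights sum to k for the current b.
        a = eps_t * (math.log(k) - _logsumexp([(s + bi) / eps_t
                                               for s, bi in zip(scores, b)]))
        # b-step: clip so that no weight exceeds 1 for the current a.
        b = [min(-s - a, 0.0) for s in scores]
        eps_used = eps_t
        # epsilon-scaling: shrink epsilon toward its target value.
        eps_t = max(beta * eps_t, eps)
    return [math.exp((s + a + bi) / eps_used) for s, bi in zip(scores, b)]
```

Because the function returns after a b-step, every weight is at most 1 by construction, while the budget constraint holds only approximately until the iterations converge. Keeping $a$ and $b$ in score units is what lets them be reused as $\epsilon$ shrinks.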
C.3 Overhead of soft top-$k$ iterations
The soft top-$k$ iterations are performed for every routed Transformer layer. Although this may seem computationally expensive, the actual overhead is very small compared to the overall decoding latency. The complexity scales only linearly with the number of layers and the sequence length, and does not depend on the model dimension $d$. Table 8 reports the latency numbers on the BoolQ and XSum datasets when performing batched decoding on a single TPUv4 chip. We observe that the cost of the iterations is less than 2% of the total decoding latency. Moreover, the relative cost decreases dramatically as the model size increases, since it does not depend on the model dimension.
C.4 Additional discussion
This iterative algorithm is closely related to the Sinkhorn algorithm of Optimal Transport (OT). Specifically, the Sinkhorn algorithm solves the entropy-regularized version of Optimal Transport [Cuturi, 2013]. However, our work concerns a different optimization instance. While OT solves a transportation problem whose solution space is defined by marginal constraints over the rows and columns of a transportation matrix, our optimization problem is constrained by a total budget ($k$) and upper bounds ($\lambda[i] \le 1$). This leads to different iterative updates.
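The difference in constraint sets can be made concrete by writing the two problems side by side. The sketch below uses standard OT notation that does not appear elsewhere in this paper and is introduced only for illustration: a cost matrix $C$, a transport plan $P$, and row and column marginals $r$ and $c$. The soft top-$k$ line assumes that (10) has the entropy-regularized form shown.

```latex
\begin{align*}
\text{Entropy-regularized OT:}\quad
  & \min_{P \ge 0}\; \langle C, P\rangle - \epsilon H(P)
  && \text{s.t.}\;\; P\,\mathbf{1} = r,\;\; P^\top\mathbf{1} = c \\
\text{Soft top-}k\text{:}\quad
  & \max_{\lambda}\; s^\top\lambda + \epsilon H(\lambda)
  && \text{s.t.}\;\; \mathbf{1}^\top\lambda = k,\;\; \lambda[i] \in [0,1]\;\;\forall i
\end{align*}
```

Sinkhorn alternates between rescaling rows and columns so that each set of equality constraints is met in turn. In the soft top-$k$ case the alternation is instead between one scalar dual variable for the budget equality and one clipped dual variable per token for the upper-bound inequalities.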
Concurrent with our work, Tai et al. have used a similar linear program (LP) formulation for the soft top-$k$ operation, and have applied the operator to learning sparse neural networks (i.e., model pruning). Compared to our formulation (12), they first reduce the LP to an equivalent instance of the optimal transport problem before introducing the entropy term. As a result, the derived updates are different. In addition, Tai et al. have introduced an initialization for the dual variables to improve the convergence of their algorithm, whereas we use $\epsilon$-scaling instead. Their implementation can be explored for CoDA as well.
Besides formulating soft top-$k$ using entropy-regularized optimization, there are other possible variants for trainable sparsity. One example is sparsemax [Martins and Astudillo, 2016], which can learn sparse multi-label probabilities. We believe that the sparsemax formulation can generalize from the top-1 to the top-$k$ case, but this is beyond the scope of this work. We use the current soft top-$k$ implementation because it is a natural extension of softmax (see discussions in §3), and because it can be solved using simple iterative updates.
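For reference, the standard sparsemax of Martins and Astudillo (the top-1 case, where the output sums to 1) can be sketched as below. It is the Euclidean projection of the score vector onto the probability simplex, computed with the usual sort-and-threshold procedure. This is only an illustration of the alternative mentioned above and is not part of CoDA. The function name `sparsemax` is chosen for the example, and the hypothesized top-$k$ generalization is not shown.

```python
def sparsemax(z):
    """Project scores z onto the probability simplex (top-1 sparsemax)."""
    z_sorted = sorted(z, reverse=True)
    cumsum = 0.0
    tau = 0.0
    for j, zj in enumerate(z_sorted, start=1):
        cumsum += zj
        # The support condition holds for a prefix of the sorted scores,
        # so the last j that satisfies it determines the threshold tau.
        if 1.0 + j * zj > cumsum:
            tau = (cumsum - 1.0) / j
    # Entries below the threshold become exactly zero.
    return [max(zi - tau, 0.0) for zi in z]
```

Unlike softmax, the output can contain exact zeros. For instance, the scores `[1.0, 0.8, 0.1]` map to approximately `[0.6, 0.4, 0.0]`.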
Appendix D Author Contributions
All authors have contributed to running experiments and discussing research ideas. Tao led the project, developed the conditional architecture, and designed the experiments and analyses. Kenton, Yu and Ming-Wei proposed the idea of applying conditional computation to large model adaptation. Joshua demonstrated that the conditional architecture is applicable to attention, and implemented the initial version of the conditional attention block. Tao, Yanqi, Nan, Vincent, Yuexin, Ming-Wei and Yu conducted the NLP experiments, including model pre-training, fine-tuning and various ablation analyses. Siddhartha conducted the majority of the vision experiments. Kenton conducted the vision analysis and advised on the vision experiments. Junwen conducted the majority of the speech experiments. Bo and Yu assisted in troubleshooting the speech models, ran the model pre-training and provided guidance on the speech experiments. Finally, Tao, Ming-Wei, Junwen and Kenton made the primary contributions to the writing of the paper.