Replacing softmax with ReLU in Vision Transformers

Mitchell Wortsman, Jaehoon Lee, Justin Gilmer, Simon Kornblith

Introduction

The transformer architecture is ubiquitous in modern machine learning. Attention, a central component of the transformer, includes a softmax that produces a probability distribution over tokens. Softmax is costly because it requires an exponentiation and a sum over the sequence length, and this sum makes parallelization challenging.

In this report we explore point-wise alternatives to the softmax operation which do not necessarily output a probability distribution. As a highlight, we observe that attention with ReLU divided by sequence length can approach or match traditional softmax attention in terms of scaling behavior as a function of compute for vision transformers. This result presents new opportunities for parallelization, as ReLU-attention can be parallelized over the sequence length dimension with fewer gather operations than traditional attention.

Related work

Previous research has explored substituting softmax with ReLU or squared ReLU. However, these approaches do not divide by sequence length, which we find experimentally to be important for reaching accuracy comparable to softmax. In addition, previous research has replaced softmax while still requiring normalization over the sequence length axis to ensure the attention weights sum to one. This retains the downside of requiring a gather. After writing an initial version of this note, it was brought to our attention that the variant of ReLU-attention we study was also explored with a theoretical motivation.

Method

Attention. Attention transforms $d$-dimensional queries, keys, and values $\{q_i, k_i, v_i\}_{i=1}^{L}$ with a two-step procedure. First, attention weights $\alpha_{ij}$ are produced via

$$\alpha_{ij} = \phi\left(\frac{1}{\sqrt{d}}\left[\langle q_i, k_1\rangle, \ldots, \langle q_i, k_L\rangle\right]\right)_j, \tag{1}$$

where $\phi$ is typically $\mathsf{softmax}$. Next, the attention weights are used to compute outputs $o_i = \sum_{j=1}^{L} \alpha_{ij} v_j$. This report explores point-wise alternatives to $\phi$.
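To make Equation 1 concrete, the following is a minimal pure-Python sketch of attention with a pluggable $\phi$, defaulting to softmax. It is written for readability rather than speed, and the function names are illustrative, not taken from the BigVision code.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(scores: Sequence[float]) -> Vector:
    # Subtracting the max keeps exp() from overflowing; both the max and the
    # normalizing sum are reductions over all L keys.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(
    q: Sequence[Vector],
    k: Sequence[Vector],
    v: Sequence[Vector],
    phi: Callable[[Sequence[float]], Vector] = softmax,
) -> List[Vector]:
    d = len(q[0])
    outputs = []
    for q_i in q:
        # Scaled dot products <q_i, k_j> / sqrt(d) for j = 1..L (Equation 1).
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d) for k_j in k]
        alpha_i = phi(scores)
        # o_i = sum_j alpha_ij * v_j, computed one output channel at a time.
        o_i = [sum(a * v_j[c] for a, v_j in zip(alpha_i, v)) for c in range(len(v[0]))]
        outputs.append(o_i)
    return outputs
```

Because softmax weights sum to one, passing identical value vectors returns (up to rounding) that same vector for every query; the point-wise alternatives studied in this report do not have this property.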

ReLU-attention. We observe that $\phi = L^{-1}\mathsf{relu}$ is a promising alternative to $\phi = \mathsf{softmax}$ in Equation 1. We refer to attention with $\phi = L^{-1}\mathsf{relu}$ as ReLU-attention.

Scaled point-wise attention. More generally, our experiments explore $\phi = L^{-\alpha}h$ for $\alpha \in [0, 1]$ and $h \in \{\mathsf{relu}, \mathsf{relu}^2, \mathsf{gelu}, \mathsf{softplus}, \mathsf{identity}, \mathsf{relu6}, \mathsf{sigmoid}\}$.
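The scaled point-wise family, with ReLU-attention as the case $\alpha = 1$, $h = \mathsf{relu}$, can be sketched as a factory that returns $\phi$. The activation names are illustrative dictionary keys, GELU uses the exact erf form, and nothing here is claimed to match the implementation used in the experiments.

```python
import math
from typing import Callable, Dict, List, Sequence


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Point-wise nonlinearities h; the keys are illustrative names.
ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "relu": lambda x: max(x, 0.0),
    "relu2": lambda x: max(x, 0.0) ** 2,
    "gelu": lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))),
    "softplus": lambda x: max(x, 0.0) + math.log1p(math.exp(-abs(x))),
    "identity": lambda x: x,
    "relu6": lambda x: min(max(x, 0.0), 6.0),
    "sigmoid": _sigmoid,
}


def scaled_pointwise(h_name: str, alpha: float) -> Callable[[Sequence[float]], List[float]]:
    """Return phi = L^{-alpha} * h, applied independently to each score."""
    h = ACTIVATIONS[h_name]

    def phi(scores: Sequence[float]) -> List[float]:
        # L is the number of keys; no max or sum over the scores is needed.
        scale = len(scores) ** (-alpha)
        return [scale * h(s) for s in scores]

    return phi


# ReLU-attention is the special case alpha = 1, h = relu.
relu_attention_phi = scaled_pointwise("relu", 1.0)
```

The returned `phi` can stand in for softmax in a standard attention routine. Unlike softmax, it needs only the count $L$, not a max or a sum over scores, so each weight depends on a single score.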

Sequence length scaling. We observe that scaling by a term involving the sequence length $L$ is beneficial for high accuracy. This scaling is absent from prior work that removes softmax. While the central justification for sequence length scaling is empirical, we provide a brief analytical motivation.

Transformers are currently designed with softmax attention, for which $\sum_{j=1}^{L}\alpha_{ij} = 1$. This implies that $\mathbb{E}_j[\alpha_{ij}] = L^{-1}$. While it is unlikely that this is a necessary condition, $\phi = L^{-1}\mathsf{relu}$ does ensure that $\mathbb{E}_j[\alpha_{ij}]$ is $O(L^{-1})$ at initialization. Preserving this condition may alleviate the need to change other hyperparameters when replacing softmax.

At initialization the elements of $q$ and $k$ are $O(1)$, so $\frac{\langle q_i, k_j\rangle}{\sqrt{d}}$ will also be $O(1)$. Activation functions such as ReLU preserve $O(1)$ values (with the exception of squared ReLU), so a factor of $L^{-1}$ is necessary for $\mathbb{E}_j[\alpha_{ij}]$ to be $O(L^{-1})$.
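The scaling argument above can be checked numerically. The sketch below assumes, purely for illustration, that query and key entries are i.i.d. standard normal at initialization, so each score has mean 0 and variance 1; it then estimates $\mathbb{E}_j[\alpha_{ij}]$ for $\phi = L^{-\alpha}\mathsf{relu}$ by Monte Carlo. The function name and the choices $d = 16$ and $L \in \{49, 196\}$ are hypothetical.

```python
import math
import random


def mean_attention_weight(L: int, d: int, alpha: float, trials: int = 50, seed: int = 0) -> float:
    """Monte Carlo estimate of E_j[alpha_ij] for phi = L^{-alpha} relu.

    Assumes (for illustration) that query and key entries are i.i.d. N(0, 1),
    so each score <q_i, k_j> / sqrt(d) has mean 0 and variance 1.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(L):
            k = [rng.gauss(0.0, 1.0) for _ in range(d)]
            s = sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
            total += L ** (-alpha) * max(s, 0.0)
    return total / (trials * L)


for L in (49, 196):
    with_scaling = mean_attention_weight(L, d=16, alpha=1.0)
    without_scaling = mean_attention_weight(L, d=16, alpha=0.0)
    # With alpha = 1, L * E_j[alpha_ij] stays near E[relu(s)] for every L,
    # i.e. E_j[alpha_ij] is O(1/L); with alpha = 0 the average weight is O(1).
    print(L, L * with_scaling, without_scaling)
```

Under this assumption $L \cdot \mathbb{E}_j[\alpha_{ij}]$ stays close to $\mathbb{E}[\mathsf{relu}(s)] = 1/\sqrt{2\pi} \approx 0.4$ for a standard normal $s$, which matches the $O(L^{-1})$ behavior of softmax attention up to a constant, while $\alpha = 0$ leaves the average weight at $O(1)$ for every $L$.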

Experiments

Experimental setup. Our experiments use the ImageNet-21k and ImageNet-1k training configurations from the BigVision codebase without modifying hyperparameters (for ImageNet-1k we use the base config https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_i1k.py, and for ImageNet-21k the base config https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_i21k.py). In our experiments on ImageNet-21k we train for 30 epochs, and in our experiments on ImageNet-1k we train for 300 epochs. As a result, both training runs use a roughly similar number of steps, around 9e5. We use ViTs with qk-layernorm, as this was previously observed to be necessary to prevent instability when scaling model size. However, our ablation indicates that it is not an important component at the scales we test. We use i21k and i1k to mean ImageNet-21k and ImageNet-1k, respectively, and report ImageNet-1k accuracy for ImageNet-21k models by taking the top class among those that are in ImageNet-1k, without fine-tuning. When evaluating transfer performance on downstream tasks we use a 10-shot linear probe averaged over three seeds. The downstream tasks are Caltech Birds, Caltech-101, Stanford Cars, CIFAR-100, DTD, ColHist, Pets, and UC Merced.
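Evaluating an ImageNet-21k model on ImageNet-1k without fine-tuning amounts to restricting the argmax to the ImageNet-1k classes. A minimal sketch, assuming a hypothetical `subset` list that maps each ImageNet-1k class index to the corresponding ImageNet-21k output index (constructing this mapping is not shown):

```python
from typing import Sequence


def top1_within_subset(logits: Sequence[float], subset: Sequence[int]) -> int:
    """Index into `subset` of the highest-scoring class, ignoring all other classes.

    `subset[c]` is assumed to be the ImageNet-21k output index of ImageNet-1k class c.
    """
    return max(range(len(subset)), key=lambda c: logits[subset[c]])


def subset_accuracy(batch_logits: Sequence[Sequence[float]], labels: Sequence[int],
                    subset: Sequence[int]) -> float:
    # labels are ImageNet-1k class indices (0 .. len(subset) - 1).
    hits = sum(top1_within_subset(lg, subset) == y for lg, y in zip(batch_logits, labels))
    return hits / len(labels)
```

A class outside the subset can have the largest logit overall and still never be predicted, since it is simply excluded from the argmax.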

Main experiment. Figure 1 illustrates that ReLU-attention matches the scaling trends of softmax attention for ImageNet-21k training. The x-axis displays the total core hours required for each experiment. As an additional advantage, ReLU-attention enables parallelization over the sequence length dimension with fewer gather operations than softmax attention.
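The parallelization advantage can be illustrated with a blockwise computation of one output row. In this sketch (an illustration of the idea, not the implementation used for the experiments), keys and values are split into blocks that could live on different devices; because each ReLU-attention weight depends only on $q_i$, $k_j$, and $L$, each block produces a partial output independently and the partials are merged by a plain sum. All function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]


def relu_weight(q_i: Sequence[float], k_j: Sequence[float], L: int) -> float:
    # alpha_ij = relu(<q_i, k_j> / sqrt(d)) / L depends only on q_i, k_j and L.
    s = sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(len(q_i))
    return max(s, 0.0) / L


def partial_output(q_i: Sequence[float], k_block: Sequence[Vector],
                   v_block: Sequence[Vector], L: int) -> Vector:
    # Contribution of one block of keys/values to o_i. It needs no statistics
    # (max or normalizer) from the other blocks.
    out = [0.0] * len(v_block[0])
    for k_j, v_j in zip(k_block, v_block):
        w = relu_weight(q_i, k_j, L)
        for c, x in enumerate(v_j):
            out[c] += w * x
    return out


def blockwise_relu_attention_row(q_i: Sequence[float], k: Sequence[Vector],
                                 v: Sequence[Vector], num_blocks: int) -> Vector:
    L = len(k)
    size = math.ceil(L / num_blocks)
    # Each block could be processed on a different device.
    partials = [partial_output(q_i, k[s:s + size], v[s:s + size], L)
                for s in range(0, L, size)]
    # Merging partial results is a plain elementwise sum.
    return [sum(p[c] for p in partials) for c in range(len(v[0]))]
```

By contrast, a blockwise softmax must also exchange a per-query maximum and normalizer across blocks and rescale the partial outputs before they can be summed.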

Effect of sequence length scaling. Figure 2 examines the effect of sequence length scaling for various point-wise alternatives to softmax. Concretely, we replace softmax with $L^{-\alpha}h$ for $\alpha \in [0, 1]$ and $h \in \{\mathsf{relu}, \mathsf{relu}^2, \mathsf{gelu}, \mathsf{softplus}, \mathsf{identity}\}$. The x-axis displays $\alpha$, and the y-axis displays accuracy for the S/32, S/16, and S/8 vision transformer models. The best results are typically achieved when $\alpha$ is close to 1. Since there is no clear best nonlinearity, we use ReLU in our main experiment, as it is faster.
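The choice of $\alpha$ matters more as the sequence length grows. Assuming $224 \times 224$ inputs and no class token (assumptions not stated in this report), S/32, S/16, and S/8 models see $L = 49$, $196$, and $784$ patch tokens. The sketch below computes $L^{1-\alpha}$, the factor by which the prefactor $L^{-\alpha}$ exceeds the average softmax weight $L^{-1}$; the helper names are hypothetical.

```python
def num_tokens(image_size: int, patch_size: int, class_token: bool = False) -> int:
    # Non-overlapping square patches, optionally plus one extra class token.
    per_side = image_size // patch_size
    return per_side * per_side + (1 if class_token else 0)


def scale_vs_softmax_mean(L: int, alpha: float) -> float:
    # Ratio of the L^{-alpha} prefactor to L^{-1}, the average softmax weight: L^{1 - alpha}.
    return L ** (-alpha) / L ** (-1.0)


for name, patch in (("S/32", 32), ("S/16", 16), ("S/8", 8)):
    L = num_tokens(224, patch)
    ratios = [round(scale_vs_softmax_mean(L, a), 2) for a in (0.0, 0.5, 1.0)]
    print(name, L, ratios)
```

For $\alpha = 0$ the factor equals $L$ itself, so under these assumptions the mismatch with softmax is sixteen times larger for S/8 than for S/32.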

Effect of qk-layernorm. Our main experiments use qk-layernorm, in which queries and keys are passed through LayerNorm before computing attention weights. We use qk-layernorm by default as it was found to be necessary to prevent instability when scaling up model size. Figure 3 shows the effect of removing qk-layernorm. The results indicate that qk-layernorm does not have a large effect for these models, but this may change at scale.
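qk-layernorm normalizes each query and key vector over its feature dimension before the dot product. The sketch below omits LayerNorm's learned gain and bias (a simplification on our part) and shows the property that motivates it: after normalization each vector has squared norm at most $d$, so by the Cauchy-Schwarz inequality every score is bounded by $\sqrt{d}$ no matter how large the raw queries and keys become. Function names are hypothetical.

```python
import math
from typing import List, Sequence


def layer_norm(x: Sequence[float], eps: float = 1e-6) -> List[float]:
    # LayerNorm over the feature axis, without the learned gain and bias
    # (an illustrative simplification).
    mean = sum(x) / len(x)
    var = sum((a - mean) ** 2 for a in x) / len(x)
    return [(a - mean) / math.sqrt(var + eps) for a in x]


def attention_scores(q_i: Sequence[float], k: Sequence[Sequence[float]],
                     qk_layernorm: bool = True) -> List[float]:
    d = len(q_i)
    if qk_layernorm:
        # Normalize the query and every key before taking dot products.
        q_i = layer_norm(q_i)
        k = [layer_norm(k_j) for k_j in k]
    return [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d) for k_j in k]
```

With normalization enabled, multiplying a query by a large constant leaves its scores essentially unchanged; without it, the scores grow in proportion.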

Effect of adding a gate. Previous work removing softmax adds a gated unit and does not scale by sequence length. Concretely, in the gated attention unit an extra projection produces an output that is combined with the attention output through elementwise multiplication before the output projection. In Figure 4 we investigate whether the presence of a gate removes the need for sequence length scaling. Overall, we observe that the best accuracy is still achieved with sequence length scaling, with or without the gate. Note that for the S/8 model with ReLU, gating increases the core hours required for the experiment by roughly 9.3%.
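A gated output path of the kind described above can be sketched as follows. The gate is assumed here to be a linear projection of the layer input $x_i$ with no nonlinearity; the original gated attention unit may apply an activation to this projection, and the function and matrix names are hypothetical.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def matvec(W: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(w * a for w, a in zip(row, x)) for row in W]


def gated_attention_output(o_i: Sequence[float], x_i: Sequence[float],
                           W_gate: Matrix, W_out: Matrix) -> List[float]:
    # Extra projection of the layer input x_i produces the gate g_i
    # (assumed linear here).
    g_i = matvec(W_gate, x_i)
    # Combine with the attention output o_i by elementwise multiplication ...
    gated = [a * b for a, b in zip(o_i, g_i)]
    # ... and only then apply the output projection.
    return matvec(W_out, gated)
```

The gate adds one more projection per token, which is consistent with the extra core hours reported for gating, although this sketch says nothing about their exact magnitude.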

Conclusion

This report leaves many open questions. In particular, we are unsure why the factor $L^{-1}$ improves performance or whether this term could be learned. Moreover, it is likely that there is a better activation function that we do not explore.

We thank Lucas Beyer, Mostafa Dehghani, and David Fleet for their helpful comments and suggestions.

We thank the members of the Google DeepMind PAGI team for their support of this effort, Jascha Sohl-Dickstein, Noah Fiedel, Aaron Parisi, Abhishek Kumar, Alex Alemi, Alex Rizkowsky, Avi Singh, Azade Nova, Ben Adlam, Bernd Bohnet, Daniel Freeman, Gamaleldin Elsayed, Gaurav Mishra, Hanie Sedghi, Isabelle Simpson, Izzeddin Gur, JD Co-Reyes, James Harrison, Jeffrey Pennington, Jiri Hron, Kathleen Kenealy, Kelvin Xu, Kevin Swersky, Kshiteej Mahajan, Laura Culp, Lechao Xiao, Max Bileschi, Merrie Morris, Roman Novak, Rosanne Liu, Sharad Vikram, Tris Warkentin, Yundi Qian.

References