Symbolic Discovery of Optimization Algorithms

Xiangning Chen, Chen Liang, Da Huang, Esteban Real, Kaiyuan Wang, Yao Liu, Hieu Pham, Xuanyi Dong, Thang Luong, Cho-Jui Hsieh, Yifeng Lu, Quoc V. Le

Introduction

Optimization algorithms, i.e., optimizers, play a fundamental role in training neural networks. There are a large number of handcrafted optimizers, mostly adaptive ones, introduced in recent years (Zhuang et al., 2020; Balles and Hennig, 2018; Liu et al., 2020; Bernstein et al., 2018; Dozat, 2016; Anil et al., 2020). However, Adam (Kingma and Ba, 2014) with decoupled weight decay (Loshchilov and Hutter, 2019), also referred to as AdamW, and Adafactor with factorized second moments (Shazeer and Stern, 2018), are still the de facto standard optimizers for training most deep neural networks, especially the recent state-of-the-art language (Brown et al., 2020; Vaswani et al., 2017; Devlin et al., 2019), vision (Dosovitskiy et al., 2021; Dai et al., 2021; Zhai et al., 2021) and multimodal (Radford et al., 2021; Saharia et al., 2022; Yu et al., 2022) models.

Another direction is to automatically discover such optimization algorithms. The learning to optimize (L2O) approach proposes to discover optimizers by training parameterized models, e.g., neural networks, to output the updates (Andrychowicz et al., 2016; Metz et al., 2019; Li and Malik, 2017; Metz et al., 2022). However, those black-box optimizers, typically trained on a limited number of small tasks, struggle to generalize to state-of-the-art settings where much larger models are trained with significantly more training steps. Another line of methods (Bello et al., 2017; Wang et al., 2022) applies reinforcement learning or Monte Carlo sampling to discover new optimizers, where the search space is defined by trees composed from predefined operands (e.g., gradient and momentum) and operators (e.g., unary and binary math operations). However, to make the search manageable, they often limit the search space by using fixed operands and restricting the size of the tree, thereby limiting the potential for discovery. For example, they are unable to modify the tracking of momentum or how it contributes to the update, which is an essential component of Lion. Consequently, the algorithms discovered have not yet reached the state of the art. AutoML-Zero (Real et al., 2020) is an ambitious effort that attempts to search every component of a machine learning pipeline while evaluating on toy tasks. This work follows the research direction of automatically discovering optimizers and is in particular inspired by AutoML-Zero, but aims at discovering effective optimization algorithms that can improve the state-of-the-art benchmarks.

In this paper, we present a method to formulate algorithm discovery as program search and apply it to discover optimization algorithms. There are two primary challenges. The first one is to find high-quality algorithms in the infinite and sparse program space. The second one is to further select out the algorithms that can generalize from small proxy tasks to much larger, state-of-the-art tasks. To tackle these challenges, we employ a range of techniques including evolutionary search with warm-start and restart, abstract execution, funnel selection, and program simplification.

Our method discovers a simple and effective optimization algorithm: Lion, short for EvoLved Sign Momentum. This algorithm differs from various adaptive algorithms by only tracking momentum and leveraging the sign operation to calculate updates, leading to lower memory overhead and uniform update magnitudes across all dimensions. Despite its simplicity, Lion demonstrates outstanding performance across a range of models (Transformer, MLP, ResNet, U-Net, and Hybrid) and tasks (image classification, vision-language contrastive learning, diffusion, language modeling, and fine-tuning). Notably, we achieve 88.3% zero-shot and 91.1% fine-tuning accuracy on ImageNet by replacing Adafactor with Lion in BASIC (Pham et al., 2021), surpassing the previous best results by 2% and 0.1%, respectively. Additionally, Lion reduces the pre-training compute on JFT by up to 5x, improves training efficiency on diffusion models by 2.3x and achieves a better FID score, and offers similar or better performance on language modeling with up to 2x compute savings.

We analyze the properties and limitations of Lion. Users should be aware that the uniform update calculated using the sign function usually yields a larger norm compared to those generated by SGD and adaptive methods. Therefore, Lion requires a smaller learning rate $lr$ and a larger decoupled weight decay $\lambda$ to maintain the effective weight decay strength. For detailed guidance, please refer to Section 5. Additionally, our experiments show that the gain of Lion increases with the batch size and that it is more robust to different hyperparameter choices compared to AdamW. As for limitations, the difference between Lion and AdamW is not statistically significant on some large-scale language and image-text datasets. The advantage of Lion is smaller when using strong augmentations or a small batch size ($<$64) during training. See Section 6 for details.

Symbolic Discovery of Algorithms

We present an approach that formulates algorithm discovery as program search (Koza, 1994; Brameier et al., 2007; Real et al., 2020). We use a symbolic representation in the form of programs for the following advantages: (1) it aligns with the fact that algorithms must be implemented as programs for execution; (2) symbolic representations like programs are easier to analyze, comprehend and transfer to new tasks compared to parameterized models such as neural networks; (3) program length can be used to estimate the complexity of different programs, making it easier to select the simpler, often more generalizable ones. This work focuses on optimizers for deep neural network training, but the method is generally applicable to other tasks.

We adhere to the following three criteria while designing the program search space: (1) the search space should be flexible enough to enable the discovery of novel algorithms; (2) the programs should be easy to analyze and incorporate into a machine learning workflow; (3) the programs should focus on the high-level algorithmic design rather than low-level implementation details. We define the programs to contain functions operating over n-dimensional arrays, including structures like lists and dictionaries containing such arrays, in an imperative language. They are similar to Python code using NumPy / JAX (Harris et al., 2020; Bradbury et al., 2018) as well as pseudo code of optimization algorithms. The details of the design are outlined below, with an example representation of AdamW in Program LABEL:lst:p1.

Input / output signature. The program defines a train function, which encodes the optimization algorithm being searched for, where the main inputs are the model weight (w), the gradient (g) and the learning rate schedule value (lr) at the current training step. The main output is the update to the weight. The program also incorporates extra variables initialized as zeros to collect historical information during training. For example, AdamW requires two extra variables to estimate first and second moments. Note that those variables can be used arbitrarily; we use the names m and v in Program LABEL:lst:p1 just for better readability. The simplified code snippet in Program LABEL:loop uses the same signature as AdamW to ensure that the discovered algorithms have smaller or equal memory footprints. As opposed to previous optimizer search attempts (Bello et al., 2017; Wang et al., 2022), our method allows discovering better ways of updating the extra variables.
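
To make the signature concrete, the following is a minimal sketch of a train function in the AdamW style, written with NumPy. It is not the paper's program listing: the hyperparameter defaults (`beta1`, `beta2`, `eps`, `wd`) are illustrative assumptions, and bias correction is left out for brevity. The point is only the shape of the interface: weight, gradient, two extra state variables, and the learning rate go in; the update and the new state come out.

```python
import numpy as np

def train(w, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW-style step. Returns (update, m, v).

    Hyperparameter defaults are assumed for illustration; bias correction is omitted.
    """
    m = beta1 * m + (1.0 - beta1) * g          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g      # second-moment estimate
    update = m / (np.sqrt(v) + eps) + wd * w   # adaptive step plus decoupled weight decay
    return lr * update, m, v

w = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
m = np.zeros_like(w)  # extra variables start at zero
v = np.zeros_like(w)
update, m, v = train(w, g, m, v, lr=1e-3)
w = w - update
```

Because the searched programs share this signature, any discovered optimizer keeps at most the two extra arrays that AdamW keeps.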

Building blocks. The train function consists of a sequence of assignment statements, with no restrictions on the number of statements or local variables. Each statement calls a function using constants or existing variables as inputs, and the resulting value is stored in a new or existing variable. For the program, we select 45 common math functions, most of which correspond to a function in NumPy or an operation in linear algebra. Some functions are introduced to make the program more compact, such as the linear interpolation function interp(x, y, a), which is made equivalent to (1 - a) * x + a * y. Preliminary experiments have investigated the inclusion of more advanced features such as conditional and loop statements, and defining and calling new functions, but these do not yield improved results, so we leave them out. Detailed descriptions of the functions are summarized in Appendix H. When necessary, the types and shapes of the function arguments are automatically cast, e.g., in the case of adding a dictionary of arrays to a scalar.
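
One way to picture this representation is as a list of assignment statements run by a tiny interpreter. The sketch below is an assumption-laden illustration, not the paper's implementation: the tuple encoding `(output, function, arguments)`, the four-entry `FUNCTIONS` table, and the `run` helper are all hypothetical, and only `interp` follows the definition given in the text.

```python
import numpy as np

# A small, hypothetical subset of the function vocabulary.
FUNCTIONS = {
    "interp": lambda x, y, a: (1.0 - a) * x + a * y,  # as defined in the text
    "sign": np.sign,
    "mul": np.multiply,
    "add": np.add,
}

# Each statement: (output variable, function name, arguments).
# String arguments name variables; float arguments are constants.
program = [
    ("m", "interp", ("g", "m", 0.9)),
    ("update", "mul", ("m", "lr")),
]

def run(program, env):
    """Execute the statements in order and return the final variable bindings."""
    env = dict(env)
    for out, fn, args in program:
        values = [env[a] if isinstance(a, str) else a for a in args]
        env[out] = FUNCTIONS[fn](*values)
    return env

result = run(program, {"g": np.array([1.0, -1.0]), "m": np.zeros(2), "lr": 0.1})
```

Here `result["m"]` holds the interpolated state and `result["update"]` the scaled step; longer programs are just longer lists.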

Mutations and redundant statements. The design of mutations utilized in evolutionary search is tightly intertwined with the representation of the program. We include three types of mutations: (1) inserting a new statement at a random location with randomly chosen functions and arguments, (2) deleting a randomly chosen statement, and (3) modifying a random statement by randomly altering one of its function arguments, which may be either variables or constants. To mutate an argument, we replace it with an existing variable or a newly generated constant obtained by sampling from a normal distribution $X\sim\mathcal{N}(0, 1)$. Additionally, we can mutate an existing constant by multiplying it by a random factor $2^{a}$, where $a\sim\mathcal{N}(0, 1)$. These constants serve as tunable hyperparameters in the optimization algorithm, such as the peak learning rate and weight decay in AdamW. Note that we allow a program to include redundant statements during search, i.e., statements that do not impact the final program outputs. This is necessary as mutations are limited to only affecting a single statement. Redundant statements therefore serve as intermediate steps towards future substantial modifications in the program.
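
The three mutation types can be sketched as follows. This is a hedged illustration under stated assumptions: the statement encoding `(output, function, arguments)`, the `FUNCTION_ARITY` table, and the 50/50 choices between alternatives are hypothetical. Only the mutation kinds, the $\mathcal{N}(0, 1)$ constant sampling, and the $2^{a}$ rescaling come from the text.

```python
import random

# Hypothetical function vocabulary with the number of arguments each takes.
FUNCTION_ARITY = {"interp": 3, "sign": 1, "mul": 2, "add": 2}

def random_argument(rng, variables):
    """Pick an existing variable or a fresh constant drawn from N(0, 1)."""
    if rng.random() < 0.5:  # assumed probability
        return rng.choice(variables)
    return rng.gauss(0.0, 1.0)

def mutate(program, variables, rng):
    """Return a mutated copy of the program; the parent is left untouched."""
    program = list(program)
    kind = rng.choice(["insert", "delete", "modify"])
    if kind == "insert" or not program:
        fn = rng.choice(sorted(FUNCTION_ARITY))
        args = tuple(random_argument(rng, variables) for _ in range(FUNCTION_ARITY[fn]))
        out = rng.choice(variables)
        program.insert(rng.randrange(len(program) + 1), (out, fn, args))
    elif kind == "delete":
        del program[rng.randrange(len(program))]
    else:  # modify one argument of one statement
        i = rng.randrange(len(program))
        out, fn, args = program[i]
        j = rng.randrange(len(args))
        old = args[j]
        if isinstance(old, float) and rng.random() < 0.5:
            new = old * 2.0 ** rng.gauss(0.0, 1.0)  # rescale an existing constant by 2^a
        else:
            new = random_argument(rng, variables)
        program[i] = (out, fn, args[:j] + (new,) + args[j + 1:])
    return program

parent = [("m", "interp", ("g", "m", 0.9)), ("update", "mul", ("m", "lr"))]
child = mutate(parent, ["w", "g", "m", "v", "lr", "update"], random.Random(0))
```

Since each call touches a single statement, a child differs from its parent by at most one inserted, deleted, or edited line, which is why redundant statements are useful as stepping stones.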

Infinite and sparse search space. Given the limitless number of statements and local variables, as well as the presence of mutable constants, the program search space is infinite. Even if we ignore the constants and bound the program length and number of variables, the number of potential programs is still intractably large. A rough estimate of the number of possible programs is $n_{p}=n_{f}^{l}\,n_{v}^{n_{a} l}$, where $n_{f}$ is the number of possible functions, $n_{v}$ is the number of local variables, $n_{a}$ is the average number of arguments per statement, and $l$ is the program length. More importantly, the challenge comes from the sparsity of high-performing programs in the search space. To illustrate this point, we conduct a random search that evaluates over 2M programs on a low-cost proxy task. The best program among them is still significantly inferior to AdamW.
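
To get a feel for the size of this estimate, the snippet below plugs numbers into the formula. Only the function count of 45 comes from the text; the number of variables, the average arity, and the program length are hypothetical values chosen for illustration.

```python
import math

n_f = 45   # number of functions (from the text)
n_v = 10   # local variables (assumed)
n_a = 2    # average arguments per statement (assumed)
l = 20     # program length (assumed)

# log10 of n_p = n_f^l * n_v^(n_a * l), computed in log space to avoid huge integers.
log10_n_p = l * math.log10(n_f) + n_a * l * math.log10(n_v)
print(f"n_p is roughly 10^{log10_n_p:.0f}")
```

Even with these modest settings the count has more than seventy digits, so the 2M programs sampled by random search cover a vanishing fraction of the space.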

Efficient Search Techniques

We employ the following techniques to address the challenges posed by the infinite and sparse space.

Evolution with warm-start and restart. We apply regularized evolution as it is simple, scalable, and has shown success on many AutoML search tasks (Real et al., 2020, 2019; Ying et al., 2019; So et al., 2019; Holland, 1992). It keeps a population of $P$ algorithms that are gradually improved through cycles. Each cycle picks $T<P$ algorithms at random and the best performer is chosen as the parent, i.e., tournament selection (Goldberg and Deb, 1991). This parent is then copied and mutated to produce a child algorithm, which is added to the population, while the oldest algorithm is removed. Normally, evolutionary search starts with random candidates, but we warm-start the initial population with AdamW to accelerate the search. By default, we use a tournament size of two and a population size of 1K. To further improve the search efficiency, we apply two types of restart: (1) restarting from the initial program, which can lead to different local optima due to the randomness in evolution and encourage exploration (this can be done by running multiple searches in parallel); and (2) restarting from the best algorithm found thus far to further optimize it, encouraging exploitation. Figure 2 (Left) displays the mean and standard error of five evolutionary search experiments. We run hyperparameter tuning based on AdamW by only allowing mutations of constants in the evolution, and run random search by sampling random programs, both with 4x more compute. Our search significantly outperforms the best results achieved by both baselines, shown as the two dashed lines in the figure. The high variance in the search fitness necessitates running multiple repeats through restarting from the initial program. When the search fitness plateaus after $\sim$300K progress, restarting from the best program found thus far further improves the fitness, as shown by the orange curve.
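
The cycle described above (tournament selection, mutate the winner, add the child, drop the oldest) can be sketched in a few lines. This is a generic illustration of regularized evolution with a warm start, not the authors' code: the function name, its parameters, and the toy scalar "program" used to exercise it are assumptions.

```python
import collections
import random

def regularized_evolution(initial, mutate, fitness, population_size,
                          tournament_size, cycles, rng):
    """Aging evolution with a warm start: the population begins as copies of `initial`."""
    f0 = fitness(initial)
    population = collections.deque([(initial, f0)] * population_size)
    best = (initial, f0)
    for _ in range(cycles):
        contestants = rng.sample(list(population), tournament_size)
        parent = max(contestants, key=lambda p: p[1])[0]   # tournament winner
        child = mutate(parent, rng)
        scored = (child, fitness(child))
        population.append(scored)   # add the child
        population.popleft()        # remove the oldest member
        if scored[1] > best[1]:
            best = scored
    return best

# Toy stand-in: a "program" is a single number and fitness peaks at 3.0.
toy_fitness = lambda x: -(x - 3.0) ** 2
toy_mutate = lambda x, rng: x + rng.gauss(0.0, 0.5)
best = regularized_evolution(0.0, toy_mutate, toy_fitness, population_size=20,
                             tournament_size=2, cycles=500, rng=random.Random(0))
```

Restarting from the best algorithm found so far amounts to calling the function again with `initial=best[0]`; restarting from the initial program amounts to rerunning it with a different seed.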

Pruning through abstract execution. We propose to prune the redundancies in the program space from three sources: programs with syntax or type / shape errors, functionally equivalent programs, and redundant statements in the programs. Before a program is actually executed, we perform an abstract execution step that (1) infers variable types and shapes to detect programs with errors, and keeps mutating the parent program until a valid child program is generated; (2) produces a hash that uniquely identifies how the outputs are computed from the inputs, allowing us to cache and look up semantically duplicate programs (Gillard et al., 2023); (3) identifies redundant statements that can be ignored during actual execution and analysis. For instance, Program LABEL:lst:p2 is obtained after removing all redundant statements in Program LABEL:lst:raw. Abstract execution has negligible cost compared to the actual execution, with each input and function replaced by customized values, e.g., hash. See Appendix I for details of abstract execution. Preliminary experiments have shown that the search process can become overwhelmed with invalid programs and cannot make progress without filtering them out. As seen in Figure 2 (Right), the percentage of redundant statements and cache hit rate both increase as the search proceeds. Based on five search runs, each covering 300K programs, there are $69.8\pm 1.9\%$ redundant statements towards the end, implying that removing redundant statements makes the program $\sim$3x shorter on average, and thus easier to analyze. The cache hit rate is $89.1\pm 0.6\%$, indicating that using the hash table as cache brings a $\sim$10x reduction in the search cost.
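
Two of the three pruning steps can be illustrated with a short sketch: a backward pass that drops statements whose results never reach an output, and a structural hash built by replacing every input and function call with a hash value. Both helpers, the statement encoding, and the example program are hypothetical. The hash here only recognizes programs with the same computation graph and would not, for example, detect algebraically equal but differently written expressions.

```python
import hashlib

def prune_redundant(program, outputs):
    """Keep only statements whose result can influence one of the outputs."""
    live = set(outputs)
    kept = []
    for out, fn, args in reversed(program):
        if out in live:
            live.discard(out)  # this statement defines `out`...
            live.update(a for a in args if isinstance(a, str))  # ...from these variables
            kept.append((out, fn, args))
    return kept[::-1]

def program_hash(program, inputs, outputs):
    """Hash how the outputs are computed from the inputs, without running any math."""
    digest = lambda s: hashlib.sha256(s.encode()).hexdigest()
    env = {name: digest(name) for name in inputs}
    for out, fn, args in program:
        parts = [env[a] if isinstance(a, str) else repr(a) for a in args]
        env[out] = digest(fn + "(" + ",".join(parts) + ")")
    return digest("|".join(env[o] for o in outputs))

inputs, outputs = ("w", "g", "m", "lr"), ("update", "m")
program = [
    ("tmp", "sign", ("g",)),            # redundant: never reaches an output
    ("m", "interp", ("g", "m", 0.9)),
    ("update", "mul", ("m", "lr")),
]
pruned = prune_redundant(program, outputs)
same = program_hash(program, inputs, outputs) == program_hash(pruned, inputs, outputs)
```

With such a hash as a dictionary key, a previously measured fitness can be looked up instead of re-running the proxy task.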

Proxy tasks and search cost. To reduce search cost, we create low-cost proxies by decreasing the model size, number of training examples, and steps from the target tasks. Evaluation on the proxies can be completed on one TPU V2 chip within 20min. We use the accuracy or perplexity on the validation set as the fitness. Each search experiment utilizes 100 TPU V2 chips and runs for $\sim$72h. There are a total of 200-300K programs generated during each search experiment. However, the number of programs that are actually evaluated is around 20-30K, thanks to the use of the cache through abstract execution. To incorporate restart, we start five repeats of search experiments, followed by another round of search initializing from the best algorithm found thus far. This results in a total cost of $\sim$3K TPU V2 days. See Appendix F for the details of proxy tasks.
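
The cost figures above can be cross-checked with a little arithmetic. The snippet only combines numbers quoted in the text; how the total splits between the first round and the restart round is not spelled out here, so the "experiment equivalents" value is just the ratio of the two quoted costs.

```python
chips = 100
hours = 72
tpu_days_per_experiment = chips * hours / 24             # 300.0 TPU V2 days

total_tpu_days = 3_000                                    # quoted total, approximate
experiment_equivalents = total_tpu_days / tpu_days_per_experiment  # 10.0

cache_hit_rate = 0.891                                    # quoted cache hit rate
evaluation_reduction = 1.0 / (1.0 - cache_hit_rate)       # about 9x fewer evaluations
```

The last value is in line with the gap between the 200-300K generated programs and the 20-30K that are actually evaluated.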

Generalization: Program Selection and Simplification

The search experiments can discover promising programs on proxy tasks. We use performance on meta-validation tasks, which are larger than the proxy tasks in model size and training steps, to select the programs that generalize beyond proxy tasks, and then simplify them further. The phenomenon of meta-overfitting occurs when the search fitness keeps growing, but the meta-validation metric declines, indicating that the discovered algorithms have overfit the proxy tasks. Two examples are shown in Figure 3 (Left), where the blue curve represents early meta-overfitting and the orange curve represents later meta-overfitting.

Large generalization gap. The discovered algorithms face a significant challenge due to the substantial gap between the proxy tasks during search and the target tasks. While proxy tasks can typically be completed within 20min on one TPU V2 chip, target tasks can be $>10^{4}$x larger and require days of training on 512 TPU V4 chips. Furthermore, we expect the optimizer to perform well on different architectures, datasets and even different domains, so the discovered algorithms need to show strong out-of-distribution generalization. The sparse search space and inherent noise in the evolution process further compound this challenge, leading to inconsistent generalization properties between different runs. Our observation suggests that evolutionary search experiments that meta-overfit later tend to uncover optimization algorithms that generalize better. See more details in Figure 3 (Right).

Funnel selection. To mitigate the generalization gap, we collect promising programs based on search fitness and add an extra selection step using a series of meta-validation tasks to select those that generalize better. To save compute, we apply a funnel selection process that gradually increases the scale of the meta-validation tasks. For example, starting with proxy task A, we create a 10x larger task B by increasing the model size and the training steps. Only algorithms that surpass the baseline on task B will be evaluated on task C, which is 100x larger. This approach allows us to gradually filter out algorithms that show poor generalization performance, ultimately leading to the selection of algorithms that generalize well to larger tasks.
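
The funnel can be sketched as a loop over tasks of increasing scale, where only candidates that beat the baseline on one task are evaluated on the next. The function name, its arguments, and the score table are hypothetical; the scores are made up purely to exercise the logic.

```python
def funnel_select(candidates, tasks, evaluate, baseline):
    """Filter candidates through tasks ordered from smallest to largest."""
    survivors = list(candidates)
    for task in tasks:
        threshold = evaluate(baseline, task)
        survivors = [c for c in survivors if evaluate(c, task) > threshold]
        if not survivors:
            break
    return survivors

# Made-up scores for illustration (higher is better). Candidate "b" has no
# entry for task C because it is filtered out on task B and never run there.
scores = {
    ("adamw", "B"): 0.5, ("adamw", "C"): 0.7,
    ("a", "B"): 0.7, ("a", "C"): 0.6,
    ("b", "B"): 0.4,
    ("c", "B"): 0.8, ("c", "C"): 0.9,
}
selected = funnel_select(["a", "b", "c"], ["B", "C"],
                         evaluate=lambda prog, task: scores[(prog, task)],
                         baseline="adamw")
```

The compute saving comes from the expensive tasks only ever seeing the shrinking list of survivors.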

Simplification. Simpler programs are easier to understand and our intuition is that they are more likely to generalize, so we simplify the programs with the following steps. First, we remove redundant statements that do not contribute to the final output as identified through abstract execution. Second, we remove statements that are non-redundant but produce minimal differences when removed. This step can also be achieved through evolution by disabling the insertion of new statements in the mutation process. Finally, we rearrange the statements manually, assign clear and descriptive names to variables, and convert the program into its simpler, mathematically equivalent form.

Derivation and Analysis of Lion

We arrive at the optimizer Lion due to its simplicity, memory efficiency, and strong performance in search and meta-validation. Note that the search also discovers other existing or novel algorithms shown in Appendix D, e.g., some with better regularization and some resembling AdaBelief (Zhuang et al., 2020) and AdaGrad (Duchi et al., 2011).

The search and funnel selection process leads to Program LABEL:lst:p2, which is obtained by automatically removing redundant statements from the raw Program LABEL:lst:raw (in the Appendix). We further simplify it to get the final algorithm (Lion) in Program LABEL:lst:p0. Several unnecessary elements are removed from Program LABEL:lst:p2 during the simplification process. The cosh function is removed since m would be reassigned in the next iteration (line 3). The statements using arcsin and clip are also removed as we observe no quality drop without them. The three red statements translate to a single sign function. Although both m and v are utilized in Program LABEL:lst:p2, v only changes how the momentum is updated (two interp functions with constants $\sim$0.9 and $\sim$1.1 are equivalent to one with $\sim$0.99) and does not need to be separately tracked. Note that the bias correction is no longer needed, as it does not change the direction. Algorithm 2 shows the pseudocode.
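
For readers who prefer code to pseudocode, a single Lion step can be sketched with NumPy as below. The defaults for `beta1` and `beta2` follow the values quoted in the analysis of momentum tracking (0.9 and 0.99); the function name, the zero default for the weight decay `wd`, and the toy inputs are assumptions for illustration.

```python
import numpy as np

def lion_step(w, g, m, lr, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update. Returns the new weight and the new momentum."""
    # Interpolate momentum and current gradient, then keep only the sign.
    update = np.sign(beta1 * m + (1.0 - beta1) * g)
    # Momentum is tracked with a separate, slower EMA factor.
    m = beta2 * m + (1.0 - beta2) * g
    # Decoupled weight decay is added before scaling by the learning rate.
    w = w - lr * (update + wd * w)
    return w, m

w = np.array([1.0, -1.0])
g = np.array([0.5, -2.0])
m = np.zeros_like(w)
w, m = lion_step(w, g, m, lr=0.1)
```

Only one extra array (`m`) is carried between steps, and every coordinate moves by exactly `lr` when the weight decay is zero, regardless of the gradient magnitude.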

Analysis

Sign update and regularization. The Lion algorithm produces updates with uniform magnitude across all dimensions by taking the sign operation, which is in principle different from various adaptive optimizers. Intuitively, the sign operation adds noise to the updates, which acts as a form of regularization and helps with generalization (Neelakantan et al., 2017; Foret et al., 2021; Chen et al., 2022). Evidence is shown in Figure 11 (Right) in the Appendix, where the ViT-B/16 trained by Lion on ImageNet has a higher training error compared to AdamW but a 2% higher accuracy on the validation set (as shown in Table 2). Additionally, the results in Appendix G demonstrate that Lion leads to convergence in smoother regions, which usually results in better generalization.

Momentum tracking. The default EMA factor used to track the momentum in Lion is 0.99 ($\beta_{2}$), compared to the commonly used 0.9 in AdamW and momentum SGD. The current gradient and momentum are interpolated with a factor of 0.9 ($\beta_{1}$) before the sign operation is applied. This choice of EMA factor and interpolation allows Lion to balance between remembering a $\sim$10x longer history of the gradient in momentum and putting more weight on the current gradient in the update. The necessity of both $\beta_{1}$ and $\beta_{2}$ is further discussed in Section 4.6.
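
The "10x longer history" and "more weight on the current gradient" statements follow from standard properties of an exponential moving average, sketched below. The helper names are hypothetical, and the effective horizon $1/(1-\beta)$ is the usual rule of thumb rather than a quantity defined in the text.

```python
def ema_weights(beta, num_steps):
    """Weight an EMA with factor beta puts on the gradient from k steps ago."""
    return [(1.0 - beta) * beta ** k for k in range(num_steps)]

def effective_horizon(beta):
    """Rule-of-thumb number of past steps an EMA averages over."""
    return 1.0 / (1.0 - beta)

# Momentum history: beta2 = 0.99 in Lion versus the common 0.9.
history_ratio = effective_horizon(0.99) / effective_horizon(0.9)

# Weight on the current gradient: 1 - beta1 in the update, 1 - beta2 in the momentum.
weight_in_update = 1.0 - 0.9
weight_in_momentum = 1.0 - 0.99
```

So the tracked momentum looks back roughly ten times further, while the quantity passed to the sign gives the fresh gradient about ten times the weight it has inside the momentum itself.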

Hyperparameter and batch size choices. Lion is simpler and has fewer hyperparameters compared to AdamW and Adafactor as it does not require $\epsilon$ and factorization-related ones. The update is an element-wise binary $\pm 1$ if we omit the weight decay term, with larger norm than those produced by other optimizers like SGD and adaptive algorithms. As a result, Lion needs a smaller learning rate and in turn a larger decoupled weight decay to achieve a similar effective weight decay strength (lr * $\lambda$). Detailed information on tuning Lion can be found in Section 5. Additionally, the advantage of Lion over AdamW enlarges as the batch size increases, which fits the common practice of scaling up model training through data parallelism (Section 4.6).
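
The coupling between the two hyperparameters can be written down directly: if the learning rate is divided by some factor, multiplying the weight decay by the same factor keeps the product lr * $\lambda$ unchanged. The helper and the factor of 10 below are hypothetical and meant as a starting point for a sweep, not a tuning rule taken from the paper.

```python
def rescale_for_lion(adamw_lr, adamw_wd, factor):
    """Shrink lr and grow weight decay by the same factor, preserving lr * wd."""
    return adamw_lr / factor, adamw_wd * factor

lion_lr, lion_wd = rescale_for_lion(adamw_lr=1e-3, adamw_wd=0.1, factor=10.0)
effective_strength = lion_lr * lion_wd  # same as 1e-3 * 0.1
```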

Memory and runtime benefits. Lion only saves the momentum and thus has a smaller memory footprint than popular adaptive optimizers like AdamW, which is beneficial when training large models and / or using a large batch size. As an example, AdamW needs at least 16 TPU V4 chips to train a ViT-B/16 with image resolution 224 and batch size 4,096, while Lion only needs 8 (both with bfloat16 momentum). Another practical benefit is that Lion has faster runtime (steps / sec) in our experiments due to its simplicity, usually a 2-15% speedup compared to AdamW and Adafactor depending on the task, codebase, and hardware.

Relation to existing optimizers. The sign operation has been explored in previous optimizers (Riedmiller and Braun, 1993; Bernstein et al., 2018). The closest to ours is the handcrafted optimizer signSGD (Bernstein et al., 2018) (and its momentum variant) that also utilizes the sign operation to calculate the update but has a different momentum update rule from Lion. Their focus is to mitigate communication costs between agents in distributed training, and they observe inferior performance when training ConvNets on image classification tasks. On the other hand, NAdam (Dozat, 2016) combines the updated first moment and the gradient to compute the update, but Lion decouples the momentum tracking and how it is applied to the update through $\beta_{2}$. A comparison of Lion with related optimizers can be found in Section 4.5.

Evaluation of Lion

In this section, we present evaluations of Lion on various benchmarks. We mainly compare it to AdamW (or Adafactor when memory is a bottleneck) as it is exceedingly popular and the de facto standard optimizer on a majority of learning tasks. The result of momentum SGD is only included for ResNet since it performs worse than AdamW elsewhere. We also benchmark other popular optimizers in Section 4.5, including handcrafted and automatically discovered ones. We make sure that every optimizer is well-tuned for each task (see Section 5 for tuning details). By default, the learning rate schedule is cosine decay with 10K steps warmup, and the momentum is saved as bfloat16 to reduce the memory footprint.
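
As a reference for the default schedule, here is a minimal sketch of cosine decay with warmup. The text only specifies "cosine decay with 10K steps warmup"; the linear shape of the warmup, the start at zero, the decay to `final_lr=0.0`, and the function name are assumptions.

```python
import math

def lr_schedule(step, peak_lr, total_steps, warmup_steps=10_000, final_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to final_lr (assumed shapes)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

lr_start = lr_schedule(0, peak_lr=1e-3, total_steps=100_000)
lr_peak = lr_schedule(10_000, peak_lr=1e-3, total_steps=100_000)
lr_end = lr_schedule(100_000, peak_lr=1e-3, total_steps=100_000)
```

The value returned at each step is what the train function receives as `lr`.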

Train from scratch on ImageNet. Following previous works (Dosovitskiy et al., 2021; He et al., 2016), we train ResNet-50 for 90 epochs with a batch size of 1,024, and other models for 300 epochs with a batch size of 4,096. As shown in Table 2, Lion significantly outperforms AdamW on various architectures. Empirically, the improvement is more substantial on models with larger capacity, with accuracy increases of 1.96% and 0.58% for ViT-B/16 and ViT-S/16, respectively. The performance gaps also tend to enlarge with fewer inductive biases. When strong augmentations are applied, the gain of Lion over AdamW shrinks, but it still outperforms AdamW by 0.42% on CoAtNet-3, despite the strong regularization during training (Dai et al., 2021).

Pre-train on ImageNet-21K. We pre-train ViT-B/16 and ViT-L/16 on ImageNet-21K for 90 epochs with a batch size of 4,096. Table 2 shows that Lion still surpasses AdamW even when the training set is enlarged by 10x. The gaps on larger models are consistently bigger, with +0.52% vs. +0.33% (ImageNet), +0.57% vs. +0.23% (ReaL), and +0.74% vs. +0.25% (V2) for ViT-L/16 and ViT-B/16, respectively.

Pre-train on JFT. To push the limit, we conduct extensive experiments on JFT. We follow the settings of Dosovitskiy et al. (2021) and Zhai et al. (2021) for both pre-training and fine-tuning. Figure 1 (Left) and Figure 4 present the accuracy of three ViT models (ViT-B/16, ViT-L/16, and ViT-H/14) under different pre-training budgets on JFT-300M. Lion enables the ViT-L/16 to match the performance of ViT-H/14 trained by AdamW on ImageNet and ImageNet V2 but with 3x less pre-training cost. On ImageNet ReaL, the compute saving further becomes 5x. Further evidence is that even when a ViT-L/16 is trained by AdamW for 4M steps by Zhai et al. (2021), its performance still lags behind the same model trained by Lion for 1M steps.

Table 3 shows the fine-tuning results, with higher resolution and Polyak averaging. Our ViT-L/16 matches the previous ViT-H/14 results trained by AdamW, while being 2x smaller. The advantage is larger on more challenging benchmarks, such as +1.33% (V2), +6.08% (A), +5.54% (R) for ViT-L/16. After we scale up the pre-training dataset to JFT-3B, the ViT-g/14 trained by Lion outperforms the previous ViT-G/14 results (Zhai et al., 2021), with 1.8x fewer parameters. Our ViT-G/14 further achieves a 90.71% accuracy on ImageNet.

Vision-Language Contrastive Learning

This section focuses on the vision-language contrastive training (Radford et al., 2021). We compare Lion with AdamW (Adafactor) on zero-shot image classification and image-text retrieval benchmarks. Instead of learning all the parameters from scratch, we initialize the image encoder with a strong pre-trained model as it is suggested to be more efficient (Zhai et al., 2022).

Locked-image text Tuning (LiT). We perform a comparison between Lion and AdamW on LiT (Zhai et al., 2022) by training the text encoder (Zhai et al., 2022) in a contrastive manner using the same frozen pre-trained ViT. All models are trained for 1B image-text pairs with a batch size of 16,384. Table 4 shows the zero-shot image classification results on three model scales, where the name specifies the size, e.g., LiT-B/16-B denotes a ViT-B/16 and a base size Transformer as the text encoder. Our method, Lion, demonstrates consistent improvement over AdamW with gains of +1.10%, +1.13%, and +0.66% on zero-shot ImageNet accuracy for LiT-B/32-B, LiT-B/16-B, and LiT-g/14$_{288}$-L, respectively. Figure 5 (Left) depicts an example zero-shot learning curve of LiT-B/16-B. Similar results are obtained on the other two datasets. The zero-shot image-text retrieval results on MSCOCO (Lin et al., 2014) and Flickr30K (Plummer et al., 2015) can be found in Figure 9 (in the Appendix). The evaluation metric is Recall@K, calculated based on whether the ground truth label of the query appears in the top-K retrieved examples. Lion outperforms AdamW on both datasets, with a larger gain in Recall@1 than Recall@10 on Flickr30K, implying more accurate retrieval results: +1.70% vs. +0.60% for image $\rightarrow$ text and +2.14% vs. +0.20% for text $\rightarrow$ image.
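
The Recall@K metric described above can be sketched for a query-by-candidate similarity matrix. This is a simplified illustration: it assumes exactly one ground-truth candidate per query, stored at the same index, which is not how every retrieval benchmark is organized. The function name and the toy similarity values are hypothetical.

```python
import numpy as np

def recall_at_k(similarity, k):
    """Fraction of queries whose ground-truth candidate is among the top-k scores.

    similarity[i, j] is the score between query i and candidate j;
    the ground truth for query i is assumed to be candidate i.
    """
    topk = np.argsort(-similarity, axis=1)[:, :k]
    truth = np.arange(similarity.shape[0])[:, None]
    return (topk == truth).any(axis=1).mean()

similarity = np.array([
    [0.9, 0.1, 0.0],   # query 0: correct candidate ranked first
    [0.8, 0.2, 0.1],   # query 1: correct candidate ranked second
    [0.1, 0.2, 0.3],   # query 2: correct candidate ranked first
])
r1 = recall_at_k(similarity, k=1)
r2 = recall_at_k(similarity, k=2)
```

A larger gain at K=1 than at K=10 means the improvement comes from ranking the correct item at the very top, not merely somewhere in the shortlist.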

BASIC. Pham et al. (2021) propose to scale up batch size, dataset, and model size simultaneously, achieving drastic improvements over CLIP. It uses a sophisticated CoAtNet (Dai et al., 2021) pre-trained on JFT-5B as the image encoder. Furthermore, the contrastive training is performed on 6.6B image-text pairs with a larger 65,536 batch size. To push the limit, we only experiment on the largest BASIC-L, and use Lion on both image encoder pre-training and contrastive learning stages. As illustrated in Table 1, we achieve a significant 2.6% gain over the baseline, reaching 88.3% accuracy on zero-shot ImageNet classification. Note that this result is 2.0% higher than the previous best result (Yu et al., 2022). The performance gain is consistent on five other robustness benchmarks. After fine-tuning the image encoder (CoAtNet-7) in BASIC-L obtained by Lion, we further achieve a 91.1% top-1 accuracy on ImageNet, which is 0.1% better than the previous SOTA.

Diffusion Model

Recently, diffusion models have achieved huge success in image generation (Ho et al., 2020; Song et al., 2021; Dhariwal and Nichol, 2021; Ho and Salimans, 2022; Saharia et al., 2022). Given their enormous potential, we test the performance of Lion on unconditional image synthesis and multimodal text-to-image generation.

Image synthesis on ImageNet. We utilize the improved U-Net architecture introduced in Dhariwal and Nichol (2021) and perform $64\times 64$, $128\times 128$, and $256\times 256$ image generation on ImageNet. The batch size is set as 2,048 and the learning rate remains constant throughout training. For decoding, we apply DDPM (Ho et al., 2020) for 1K sampling steps without classifier-free guidance. The evaluation metric is the standard FID score. As illustrated in Figure 1 (Right) and Figure 5 (Middle and Right), Lion enables both better quality and faster convergence on the FID score. Note that the gap between Lion and AdamW tends to increase with the image resolution, where the generation task becomes more challenging. When generating $256\times 256$ images, Lion achieves the final performance of AdamW at 440K steps, a 2.3x reduction in iterations. The final FID scores are 4.1 (Lion) vs. 4.7 (AdamW), and for reference, the FID of ADM (Dhariwal and Nichol, 2021) is 10.94.

Text-to-image generation. We follow the Imagen (Saharia et al., 2022) setup to train a base $64\times 64$ text-to-image model and a $64\times 64\rightarrow 256\times 256$ super-resolution model. All models are trained on a high-quality internal image-text dataset with a batch size of 2,048 and a constant learning rate. Due to computational constraints, our base U-Net has a width of 192 compared to 512 in the original 2B model, while the 600M super-resolution model is identical to the original Imagen setup. Along with the training, 2K images are sampled from the MSCOCO (Lin et al., 2014) validation set for real-time evaluation. We use the CLIP score to measure image-text alignment and the zero-shot FID-30K to measure image fidelity. Classifier-free guidance (Ho and Salimans, 2022) with a weight of 5.0 is applied as it has been shown to improve image-text alignment. Figure 7 depicts the learning curve. While there is no clear improvement on the base $64\times 64$ model, Lion outperforms AdamW on the text-conditional super-resolution model. It achieves a higher CLIP score and has a less noisy FID metric compared to AdamW.

Language Modeling and Fine-tuning

This section focuses on language modeling and fine-tuning. On language-only tasks, we find that tuning $\beta_{1}$ and $\beta_{2}$ can improve the quality for both AdamW and Lion. See Section 5 for tuning details.

Autoregressive language modeling We first experiment on two smaller-scale academic datasets, Wiki-40B (Guo et al., 2020) and PG-19 (Rae et al., 2020), following Hua et al. (2022). The employed Transformer spans three scales: small (110M), medium (336M), and large (731M). The architecture details can be found in Appendix E. All models are trained with $2^{18}$ tokens per batch for 125K steps, with a learning rate schedule of 10K steps warmup followed by linear decay. The context length is set to 512 for Wiki-40B and 1,024 for PG-19. Figure 7 illustrates the token-level perplexity for Wiki-40B and word-level perplexity for PG-19. Lion consistently achieves lower validation perplexity than AdamW. It achieves 1.6x and 1.5x speedup when training the medium size model on Wiki-40B and PG-19, respectively. When the model is increased to the large size, the speedup on PG-19 further increases to 2x.

Scaling up language models and pre-training datasets has revolutionized the field of NLP, so we further perform larger-scale experiments. Our pre-training dataset, similar to that used in GLaM (Du et al., 2022), consists of 1.6 trillion tokens spanning a wide range of natural language use cases. Following GPT-3 (Brown et al., 2020), we train three models, ranging from 1.1B to 7.5B parameters, for 300B tokens with a batch size of 3M tokens and a context length of 1K. We evaluate them on three natural language generative (NLG) and 21 natural language understanding (NLU) tasks (see Appendix C for task details). On this massive dataset, we observe no perplexity difference throughout training. Nevertheless, Lion outperforms Adafactor on the average in-context learning ability, as shown in Table 5. Our 7.5B baseline model, trained for 300B tokens, outperforms the 8B PaLM, trained for 780B tokens, demonstrating the strength of our setup. Lion outperforms Adafactor on both NLG and NLU tasks, particularly on the NLG tasks, with an exact match improvement of +1.0, +0.9, and +0.6 for the 1.1B, 2.1B, and 7.5B models, respectively.

Masked language modeling We also perform BERT training on the C4 dataset (Raffel et al., 2020). It requires the language models to reconstruct randomly masked out tokens in the input sequence. We use the same architectures and training setups as the smaller-scale autoregressive experiments. Lion performs slightly better than AdamW regarding the validation perplexity: 4.18 vs. 4.25 (small), 3.42 vs. 3.54 (medium), and 3.18 vs. 3.25 (large). See Figure 11 (Left) in the Appendix for the learning curves.

Fine-tuning We fine-tune Base (220M), Large (770M), and the largest 11B T5 (Raffel et al., 2020) models on the GLUE benchmark (Wang et al., 2019a). Every model is fine-tuned for 500K steps with a batch size of 128 and a constant learning rate. Table 6 shows the results on the GLUE dev set. For MRPC and QQP, we report the F1 / Accuracy scores, for STS-B, we report the Pearson / Spearman correlation, and for the other datasets, we report their default metric. On average, Lion beats AdamW across all three model scales. It achieves 10, 12, and 10 wins out of 12 scores for T5 Base, Large, and 11B models, respectively.

5 Comparison with Other Popular Optimizers

We also employ four popular handcrafted optimizers: RAdam (Liu et al., 2020), NAdam (Dozat, 2016), AdaBelief (Zhuang et al., 2020), and AMSGrad (Reddi et al., 2018), and two optimizers discovered by AutoML: PowerSign (Bello et al., 2017) and AddSign (Bello et al., 2017), to train ViT-S/16 and ViT-B/16 on ImageNet (with RandAug and Mixup). We thoroughly tune the peak learning rate $lr$ and decoupled weight decay $\lambda$ (Loshchilov and Hutter, 2019) of every optimizer, while other hyperparameters are set as the default values in Optax (https://github.com/deepmind/optax). As shown in Table 7, Lion is still the best performing one. We notice that there is no clear winner amongst the baselines. AMSGrad performs the best on ViT-S/16 but the worst on ViT-B/16. The inferior performance of PowerSign and AddSign compared to other optimizers is consistent with previous observations that automatically discovered optimizers have difficulty generalizing to real-world learning tasks. Figure 10 (in the Appendix) further shows that the learning curves of the five adaptive optimizers are pretty similar, whereas Lion has a unique one that learns faster.

6 Ablations

Momentum tracking To ablate the effects of both $\beta_1$ and $\beta_2$, we compare to a simple update rule: m = interp(g, m, $\beta$); update = sign(m). Two optimizers, Ablation0.9 and Ablation0.99, are created with $\beta$ values of 0.9 and 0.99, respectively. As Table 7 shows, the two ablated optimization algorithms perform worse than all five compared baselines, let alone our Lion. Further ablation studies on the language modeling task (as depicted in Figure 12 in the Appendix) yield similar conclusions. Those results validate the effectiveness and necessity of using two linear interpolation functions, which let Lion remember a longer gradient history while assigning a higher weight to the current gradient.
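
To make the comparison concrete, the sketch below contrasts the ablated single-interpolation rule with the two-interpolation Lion update in plain Python. It assumes the interp(x, y, a) = (1 - a) * x + a * y convention from Appendix H, works on flat lists of floats, and omits the learning rate and weight decay; the names `ablation_step` and `lion_step` are illustrative and do not come from the paper's code.

```python
def sign(x):
    """Sign of a float: -1.0, 0.0, or 1.0."""
    return float((x > 0) - (x < 0))

def interp(x, y, a):
    """Element-wise linear interpolation (1 - a) * x + a * y on lists."""
    return [(1 - a) * xi + a * yi for xi, yi in zip(x, y)]

def ablation_step(g, m, beta):
    """Ablated rule: a single beta sets both the memory and the update."""
    m = interp(g, m, beta)
    update = [sign(v) for v in m]
    return update, m

def lion_step(g, m, beta1=0.9, beta2=0.99):
    """Lion direction: beta1 weights the current gradient in the update,
    beta2 controls how long the tracked momentum remembers."""
    update = [sign(v) for v in interp(g, m, beta1)]
    m = interp(g, m, beta2)
    return update, m

# A large gradient that opposes the accumulated momentum.
g, m = [-20.0], [1.0]
print("Ablation0.9 :", ablation_step(g, m, 0.9))
print("Ablation0.99:", ablation_step(g, m, 0.99))
print("Lion        :", lion_step(g, m))
```

In this toy case Ablation0.99 keeps the long memory but its update still follows the old momentum, Ablation0.9 follows the new gradient but forgets faster, and the Lion step follows the new gradient while keeping the slowly decaying momentum.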

Effect of batch size Some may question whether Lion requires a large batch size to accurately determine the direction due to the added noise from the sign operation. To address this concern, we train a ViT-B/16 model on ImageNet using various batch sizes while keeping the total number of training epochs at 300 and incorporating RandAug and Mixup techniques. As shown in Figure 8 (Left), the optimal batch size for AdamW is 256, while for Lion it is 4,096. This indicates that Lion indeed prefers a larger batch size, but its performance remains robust even with a small batch size of 64. Furthermore, when the batch size is enlarged to 32K, leading to only 11K training steps, Lion achieves a significant 2.5% accuracy gain over AdamW (77.9% vs. 75.4%), demonstrating its effectiveness in the large batch training setting.
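
As a rough check on the step count quoted above, the following sketch computes the number of optimizer steps under a fixed 300-epoch budget. It assumes the standard ImageNet-1k training split of 1,281,167 images and reads 32K as 32,768; neither detail is stated in this section.

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # assumed: standard ImageNet-1k training split
EPOCHS = 300

def training_steps(batch_size, epochs=EPOCHS, num_examples=IMAGENET_TRAIN_IMAGES):
    """Number of full batches processed over the whole training budget."""
    return epochs * num_examples // batch_size

for batch_size in (64, 256, 4096, 32768):
    print(f"batch size {batch_size:>6}: {training_steps(batch_size):>9,} steps")
```

Under these assumptions a batch size of 32,768 gives roughly 11.7K steps, in line with the 11K figure in the text.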

Hyperparameter Tuning

To ensure a fair comparison, we tune the peak learning rate $lr$ and decoupled weight decay $\lambda$ for both AdamW (Adafactor) and our Lion using a logarithmic scale. The default values for $\beta_1$ and $\beta_2$ in AdamW are set as 0.9 and 0.999, respectively, with an $\epsilon$ of $1e-8$, while in Lion, the default values for $\beta_1$ and $\beta_2$ are discovered through the program search process and set as 0.9 and 0.99, respectively. We only tune those hyperparameters in Section 4.4, where $\beta_1=0.9$, $\beta_2=0.99$ in AdamW, and $\beta_1=0.95$, $\beta_2=0.98$ in Lion. In our experience, reducing $\beta_2$ results in shorter memorization of historical information and enhanced training stability. Additionally, the $\epsilon$ in AdamW is set as $1e-6$ instead of the default $1e-8$ as it improves stability in our experiments, similar to the observations in RoBERTa (Liu et al., 2019b).

The update generated by Lion is an element-wise binary $\pm 1$ as a result of the sign operation; it therefore has a larger norm than those generated by other optimizers. Based on our experience, a suitable learning rate for Lion is typically 3-10x smaller than that for AdamW. Note that the initial value, peak value, and end value of the learning rate should be changed simultaneously with the same ratio compared to AdamW. We do not modify other training settings such as the learning rate schedule, gradient and update clipping. Since the effective weight decay is lr * $\lambda$ (update += w * $\lambda$; update *= lr), the value of $\lambda$ used for Lion is 3-10x larger than that for AdamW in order to maintain a similar strength. For instance,

$lr=1e-4$, $\lambda=10.0$ in Lion and $lr=1e-3$, $\lambda=1.0$ in AdamW when training ViT-B/16 on ImageNet with strong augmentations,

$lr=3e-5$, $\lambda=0.1$ in Lion and $lr=3e-4$, $\lambda=0.01$ in AdamW for diffusion models,

$lr=1e-4$, $\lambda=0.01$ in Lion and $lr=1e-3$, $\lambda=0.001$ in Adafactor for the 7.5B language modeling.
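
The three settings above can be checked against the rule that the effective weight decay is the product of the learning rate and the weight decay coefficient. The short sketch below uses only the values listed in the text; the helper name `effective_decay` is illustrative. Each Lion setting pairs a 10x smaller learning rate with a 10x larger weight decay, so the product matches the baseline.

```python
import math

# (task, (lr, weight decay) for Lion, (lr, weight decay) for the baseline)
settings = [
    ("ViT-B/16 on ImageNet", (1e-4, 10.0), (1e-3, 1.0)),
    ("diffusion models", (3e-5, 0.1), (3e-4, 0.01)),
    ("7.5B language model", (1e-4, 0.01), (1e-3, 0.001)),
]

def effective_decay(lr, weight_decay):
    """Per-step shrinkage applied to the weights: lr * lambda."""
    return lr * weight_decay

for task, lion, baseline in settings:
    lion_decay = effective_decay(*lion)
    base_decay = effective_decay(*baseline)
    # The two products agree up to floating-point rounding.
    matches = math.isclose(lion_decay, base_decay, rel_tol=1e-9)
    print(f"{task}: Lion {lion_decay:.1e}, baseline {base_decay:.1e}, equal: {matches}")
```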

Please see Table 12 (in the Appendix) for all hyperparameters.

Apart from the peak performance, the sensitivity to hyperparameters and the difficulty in tuning them are also critical for the adoption of an optimizer in practice. In Figure 8 (Middle and Right), we alter both $lr$ and $\lambda$ when training ViT-B/16 from scratch on ImageNet. As the heatmaps suggest, Lion is more robust to different hyperparameter choices than AdamW.

Limitations

Limitations of search Despite the efforts to make the search space less restrictive, it remains inspired by the popular first-order optimization algorithms, leading to a bias towards similar algorithms. It also lacks the functions required to construct advanced second-order algorithms (Anil et al., 2020; Gupta et al., 2018; Martens and Grosse, 2015). The search cost is still quite large and the algorithm simplification requires manual intervention. Further reducing the bias in the search space to discover more novel algorithms and improving the search efficiency are important future directions. The current program structure is quite simplistic, as we do not find a good usage of more advanced program constructs such as conditionals, loops, and new function definitions. Exploring how to incorporate these elements has the potential to unlock new possibilities.

Limitations of Lion While we endeavour to evaluate Lion on as many tasks as possible, the assessment is limited to the chosen tasks. On vision tasks, the discrepancies between Lion, AdamW, and momentum SGD are pretty small on ResNets, likely due to the fact that ConvNets are easier to optimize compared to Transformers. The performance gain brought by Lion decreases when strong augmentations are utilized. There are also several tasks where Lion performs similarly to AdamW, including: (1) the Imagen text-to-image base model, (2) the perplexity of the autoregressive language model trained on the large-scale internal dataset, which is arguably a more reliable metric than the in-context learning benchmarks, and (3) masked language modeling on C4. These tasks have a common characteristic in that the datasets are massive and of high quality, which results in a reduced difference between optimizers. Another potential limitation is the batch size. Though people often scale up the batch size to enable more parallelism, it is likely that Lion performs no better than AdamW if the batch size is small (<64). Additionally, Lion still requires momentum tracking in bfloat16, which can be expensive for training giant models. One potential solution is to factorize the momentum to save memory.

Related Work

Our work lies in the area of AutoML and meta-learning that includes learning to learn (Andrychowicz et al., 2016; Ravi and Larochelle, 2017; Wichrowska et al., 2017; Bello et al., 2017; Xiong et al., 2022; Metz et al., 2019, 2022), neural architecture search (Real et al., 2019; Zoph and Le, 2017; Pham et al., 2018; Liu et al., 2019a; Chen and Hsieh, 2020; Wang et al., 2021b; So et al., 2019; Chen et al., 2021; Yang et al., 2022; Wang et al., 2021a) and hyperparameter optimization (Li et al., 2017; Jamieson and Talwalkar, 2016; Hutter et al., 2011; Dong et al., 2021), etc. There is also a long history of using evolutionary methods to search for programs, i.e., genetic programming (Koza, 1994; Brameier et al., 2007; Holland, 1992). Our approach builds upon a symbolic search space similar to AutoML-Zero (Real et al., 2020; Peng et al., 2020). However, instead of discovering programs with fixed-dimensional matrices, vectors, and scalars for toy tasks, our goal is to develop programs that operate on n-dimensional arrays and can generalize to state-of-the-art tasks. Other related works include numerous handcrafted optimizers (Kingma and Ba, 2014; Bernstein et al., 2018; Duchi et al., 2011; Shazeer and Stern, 2018; Zhuang et al., 2020; Dozat, 2016; Anil et al., 2020; Liu et al., 2020; Reddi et al., 2018; Gupta et al., 2018; Riedmiller and Braun, 1993; Ma and Yarats, 2019), which we discuss in Section 3.2.

Conclusion

This paper proposes to discover optimization algorithms via program search. We propose techniques to address the challenges of searching an infinite and sparse search space and of the large generalization gap between the proxy and target tasks. Our method discovers a simple and effective optimizer, Lion, that is memory-efficient and achieves strong generalization across architectures, datasets and tasks.

Acknowledgements

We would like to thank (in alphabetical order) Angel Yu, Boqing Gong, Chen Cheng, Chitwan Saharia, Daiyi Peng, David So, Hanxiao Liu, Hanzhao Lin, Jeff Lund, Jiahui Yu, Jingru Xu, Julian Grady, Junyang Shen, Kevin Regan, Li Sheng, Liu Yang, Martin Wicke, Mingxing Tan, Mohammad Norouzi, Qiqi Yan, Rakesh Shivanna, Rohan Anil, Ruiqi Gao, Steve Li, Vlad Feinberg, Wenbo Zhang, William Chan, Xiao Wang, Xiaohua Zhai, Yaguang Li, Yang Li, Zhuoshu Li, Zihang Dai, Zirui Wang for helpful discussions, and the Google Brain team at large for providing a supportive research environment.

References

Appendix A Pseudocode for AdamW and Lion

Appendix B Image Classification Tasks

Our evaluation covers various benchmarks: ImageNet, ImageNet ReaL (Beyer et al., 2020), ImageNet V2 (Recht et al., 2019), ImageNet A (Hendrycks et al., 2021b), ImageNet R (Hendrycks et al., 2021a), ImageNet Sketch (Wang et al., 2019b), ObjectNet (Barbu et al., 2019), CIFAR-100 (Krizhevsky, 2009), and Oxford-IIIT Pet (Parkhi et al., 2012).

Appendix C NLP Tasks

This section shows all the natural language generation (NLG) and natural language understanding (NLU) tasks where we evaluate the large-scale language models in Section 4.4. Those tasks include Open-Domain Question Answering, Cloze and Completion Tasks, Winograd-Style Tasks, Common Sense Reasoning, In-Context Reading Comprehension, SuperGLUE, and Natural Language Inference.

NLG: TriviaQA (Joshi et al., 2017), Natural Questions (Kwiatkowski et al., 2019), Web Questions (Berant et al., 2013).

NLU: HellaSwag (Zellers et al., 2019), StoryCloze (Mostafazadeh et al., 2016), Winograd (Levesque et al., 2012), Winogrande (Sakaguchi et al., 2020), RACE (Lai et al., 2017), PIQA (Bisk et al., 2020), ARC (Clark et al., 2018), OpenbookQA (Mihaylov et al., 2018), BoolQ (Clark et al., 2019), Copa (Gordon et al., 2012), RTE (Dagan et al., 2006), WiC (Pilehvar and Camacho-Collados, 2019), Multirc (Khashabi et al., 2018), WSC (Levesque et al., 2012), ReCoRD (Zhang et al., 2018), CB (de Marneffe et al., 2019), Adversarial NLI (Nie et al., 2020).

Appendix D Other Discovered Programs

By varying the task setting, different types of algorithms can be discovered. For example, if we reduce the amount of data in the proxy task, we are more likely to discover algorithms with better regularization (Program LABEL:lst:reg), and if we reduce the search progress, we are likely to find simple variants of AdamW (Program LABEL:lst:adagrad and LABEL:lst:belief). Future work can explore this potential to discover optimizers specialized for different tasks.

Appendix E Architecture Details for Language Modeling

Table 10 shows the Transformer architecture details for language modeling (Section 4.4). The dimension of the feed-forward layer is $4\times d_{model}$. We use vocabulary size 32K for small-scale and 256K for large-scale models.

Appendix F Details of Proxy Tasks

For vision tasks, we train a ViT with three layers, 96 hidden units and three heads, on 10% ImageNet for 30K steps with batch size 64. The image size is $64\times 64$ and the patch size is 16. For language tasks, we train a Transformer with two layers, 128 hidden units and two heads on LM1B (Chelba et al., 2013) for 20K steps with batch size 64, sequence length 32 and vocabulary size 3K. The evaluation time may vary for different programs, but typically an evaluation can be done on one TPU V2 chip within 20 minutes. The validation accuracy or perplexity is used as the fitness.

Appendix G Analysis of Loss Landscape

Appendix H Available Functions

We include 43 available functions that can be used in the program during search. Note that the input of the functions can be one n-dimensional array, dictionaries or lists of arrays, similar to the pytrees in JAX.

Basic math functions from NumPy / JAX This includes unary functions like abs, cos, sin, tan, arcsin, arccos, arctan, exp, log, sinh, cosh, tanh, arcsinh, arccosh, arctanh, sign, exp2, exp10, expm1, log10, log2, log1p, square, sqrt, cube, cbrt, reciprocal and binary functions like +, -, *, /, power, maximum, minimum with the same semantics as the corresponding functions in NumPy / JAX.

Linear algebra functions commonly used in first-order optimization algorithms This includes: (1) unary function norm that computes the norm of each array in the input; (2) unary function global_norm that computes the global norm by treating all the numbers in the input as one vector; (3) binary function dot that treats the two inputs as two vectors and computes their dot product; (4) binary function cosine_sim that treats the two inputs as two vectors and computes their cosine similarity; (5) binary function clip_by_global_norm (clip) that clips the global norm of the first input to the value of the second input, which is required to be a scalar; (6) ternary function interpolate (interp) that uses the third argument a, required to be a scalar, to compute a linear interpolation of the first two arguments x and y with (1 - a) * x + a * y.
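
A minimal pure-Python sketch of these six functions is given below for reference. It assumes every input is a list of flat lists of floats (a stand-in for a pytree of arrays), uses the L2 norm throughout, and rescales in clip_by_global_norm only when the global norm exceeds the threshold; these choices follow common practice and are not spelled out in the text.

```python
import math

def _flatten(tree):
    """Concatenate a list of flat float lists into one vector."""
    return [v for array in tree for v in array]

def norm(tree):
    """L2 norm of each array in the input (L2 is an assumption)."""
    return [math.sqrt(sum(v * v for v in array)) for array in tree]

def global_norm(tree):
    """L2 norm of all numbers in the input treated as one vector."""
    return math.sqrt(sum(v * v for v in _flatten(tree)))

def dot(x, y):
    """Dot product of the two inputs treated as vectors."""
    return sum(a * b for a, b in zip(_flatten(x), _flatten(y)))

def cosine_sim(x, y):
    """Cosine similarity of the two inputs; undefined for zero vectors."""
    return dot(x, y) / (global_norm(x) * global_norm(y))

def clip_by_global_norm(tree, max_norm):
    """Rescale the input so that its global norm is at most max_norm."""
    scale = min(1.0, max_norm / max(global_norm(tree), 1e-12))
    return [[v * scale for v in array] for array in tree]

def interp(x, y, a):
    """Linear interpolation (1 - a) * x + a * y with a scalar a."""
    return [[(1 - a) * xv + a * yv for xv, yv in zip(xa, ya)]
            for xa, ya in zip(x, y)]

grads = [[3.0], [4.0]]
print(norm(grads), global_norm(grads))
print(clip_by_global_norm(grads, 1.0))
```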

Functions producing commonly used constants This includes get_pi, get_e, and get_eps, which generate $\pi$, $e$, and $\epsilon=10^{-8}$, respectively.

Appendix I Abstract Execution

We propose to prune the large search space with abstract execution. Our approach is motivated by the fact that a large number of programs are invalid, functionally equivalent, or contain redundant statements that waste compute during evaluation. To address this, we introduce an abstract execution step that checks the type and shape of each variable, and computes a hash for each unique computation from inputs to outputs to detect redundant statements. The abstract execution can be seen as a static analysis of the program, achieved by replacing functions and inputs with customized values. We outline the specifics of the customized values and abstract execution procedure for three use cases below. The cost of the abstract execution is usually negligible compared to the actual execution of the program.

Detecting errors with type / shape inference To detect programs containing errors, we infer the type and shape of each variable in the program through the following steps: (1) replace each input with an abstract object that only contains type and shape information, and replace each statement with a type and shape inference function; (2) iterate through all statements. Instead of executing the original statement, we validate a function call by checking the function signature and type and shape information of its arguments. If valid, we compute the type and shape information of the output and assign it to the new variable; (3) verify the validity of the derived type and shape of the output. This process essentially performs a static analysis of the program, exposing errors caused by type and shape mismatch. Note that there are still run-time errors, such as division by zero, that cannot be detected in this manner. Without such filtering of invalid programs, the search would be overwhelmed with invalid programs, making it difficult to achieve meaningful progress.
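
The three steps above can be sketched as a small static checker. The representation here is hypothetical: a program is a list of (output, function, argument names) tuples, an abstract value is a (dtype, shape) pair with () as the scalar shape, and the inference rules cover only two element-wise binary functions and interp. The paper's actual implementation is not shown in this section and may differ.

```python
class AbstractError(Exception):
    """Raised when a statement fails type / shape checking."""

def infer_elementwise(x, y):
    """Binary element-wise op: dtypes must match; shapes must match
    unless one operand is a scalar."""
    (x_type, x_shape), (y_type, y_shape) = x, y
    if x_type != y_type:
        raise AbstractError(f"dtype mismatch: {x_type} vs {y_type}")
    if x_shape != y_shape and x_shape != () and y_shape != ():
        raise AbstractError(f"shape mismatch: {x_shape} vs {y_shape}")
    return (x_type, x_shape if x_shape != () else y_shape)

def infer_interp(x, y, a):
    """interp(x, y, a): a must be a scalar; x and y must be compatible."""
    if a[1] != ():
        raise AbstractError("third argument of interp must be a scalar")
    return infer_elementwise(x, y)

INFERENCE_RULES = {
    "add": infer_elementwise,
    "mul": infer_elementwise,
    "interp": infer_interp,
}

def abstract_check(program, inputs, expected_output):
    """Return True if the program type-checks and `update` has the
    expected abstract value."""
    env = dict(inputs)  # step 1: inputs are abstract (dtype, shape) pairs
    try:
        for out, fn, args in program:  # step 2: infer instead of executing
            env[out] = INFERENCE_RULES[fn](*(env[a] for a in args))
    except (AbstractError, KeyError, TypeError):
        # Mismatch, unknown function or variable, or wrong argument count.
        return False
    return env.get("update") == expected_output  # step 3: validate the output

inputs = {"g": ("f32", (3, 4)), "m": ("f32", (3, 4)), "beta": ("f32", ())}
valid = [("update", "interp", ("g", "m", "beta"))]
invalid = [("update", "interp", ("g", "beta", "m"))]  # non-scalar third argument
print(abstract_check(valid, inputs, ("f32", (3, 4))))
print(abstract_check(invalid, inputs, ("f32", (3, 4))))
```

As the text notes, a checker of this kind cannot catch run-time errors such as division by zero.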

Deduplicating with functional hash Among the valid programs that execute without errors, there are still lots of duplicates due to functionally equivalent programs that have different surface forms but the same underlying functionality. To address this issue, we calculate a functional hash value for every unique computation from the inputs to the outputs as follows: (1) a unique hash value is assigned to each input and function; (2) iterate through all statements, calculating the hash value of the outputs by combining the hash values of the functions and arguments; (3) compute the hash value of the program by combining the hash values of all outputs. We then build a hash table that maps each unique functional hash value to the fitness of the corresponding program. When a new program is generated, we first look up its hash value and only perform evaluation if it is not found or if we want to evaluate it multiple times to reduce measurement noise. In our experiments, this technique reduces the search cost by $\sim$10x, as depicted in Figure 2 (Right).
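
A minimal sketch of this procedure follows, using a hypothetical representation in which a program is a list of (output, function, argument names) tuples. SHA-256 is an arbitrary choice of hash, and variable names never enter the hash, so two programs that differ only in naming or in unused statements map to the same key. This simple version does not canonicalize commutative arguments or other algebraic identities, which a fuller implementation might.

```python
import hashlib

def _h(*parts):
    """Combine strings into one stable hash value."""
    return hashlib.sha256("|".join(parts).encode()).hexdigest()

def functional_hash(program, inputs, outputs):
    """Hash the computation from inputs to outputs, ignoring variable names."""
    # Step 1: a unique hash per input, keyed by position rather than name.
    env = {name: _h("input", str(i)) for i, name in enumerate(inputs)}
    # Step 2: each result combines the function name and argument hashes.
    for out, fn, args in program:
        env[out] = _h("call", fn, *(env[a] for a in args))
    # Step 3: the program hash combines the hashes of all outputs.
    return _h("program", *(env[o] for o in outputs))

fitness_cache = {}  # functional hash -> fitness

def evaluate_once(program, inputs, outputs, evaluate):
    """Call the expensive `evaluate` callback only for unseen hashes."""
    key = functional_hash(program, inputs, outputs)
    if key not in fitness_cache:
        fitness_cache[key] = evaluate(program)
    return fitness_cache[key]

p1 = [("a", "mul", ("g", "g")), ("update", "add", ("a", "m"))]
# Same computation with a renamed temporary and an unused statement.
p2 = [("t", "mul", ("g", "g")), ("unused", "sign", ("m",)),
      ("update", "add", ("t", "m"))]
same = functional_hash(p1, ["g", "m"], ["update"]) == functional_hash(p2, ["g", "m"], ["update"])
print("p1 and p2 share a hash:", same)
```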

Identifying redundant statements by tracking dependencies In program evolution, redundant statements are included to enable combining multiple mutations to make larger program changes. However, these redundant statements increase the evaluation cost and make program analysis more challenging. To identify redundant statements, we need to determine the set of statements that the outputs depend on, which can be computed in a recursive manner using the following steps: (1) replace the value of each input with an empty set, as they do not depend on any statement; (2) iterate through each statement. Note that each statement is an assignment that calls a function and assigns the result to a variable, which in turn depends on the current statement and on all statements that the function arguments depend on. Therefore we replace the value of the variable with its dependency, i.e., the set of all statements it depends on; (3) compute the union of all statements that each output depends on, which contains all non-redundant statements. By filtering out redundant statements, we obtain a simplified version of the program that is cheaper to execute and easier to analyze. In our experiments, this reduces the program length by $\sim$3x on average, as shown in Figure 2 (Right).
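
The dependency-tracking steps can be sketched as follows, with a hypothetical representation in which a program is a list of (output, function, argument names) tuples. The abstract value of each variable is the set of statement indices it depends on, and any statement outside the union over all outputs is redundant. The function name `remove_redundant` and the example program are illustrative only.

```python
def remove_redundant(program, inputs, outputs):
    """Drop statements that no output depends on."""
    # Step 1: inputs depend on no statement.
    deps = {name: frozenset() for name in inputs}
    # Step 2: a variable depends on its own statement plus everything
    # its arguments depend on.
    for index, (out, _fn, args) in enumerate(program):
        deps[out] = frozenset({index}).union(*(deps[a] for a in args))
    # Step 3: the union over all outputs is the set of needed statements.
    needed = frozenset().union(*(deps[o] for o in outputs))
    return [stmt for index, stmt in enumerate(program) if index in needed]

program = [
    ("a", "square", ("g",)),            # 0: never used by an output
    ("c", "interp", ("g", "m", "b1")),  # 1
    ("update", "sign", ("c",)),         # 2
    ("m", "interp", ("g", "m", "b2")),  # 3
]
simplified = remove_redundant(program, ["g", "m", "b1", "b2"], ["update", "m"])
for statement in simplified:
    print(statement)
```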