DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining

Sang Michael Xie, Hieu Pham, Xuanyi Dong, Nan Du, Hanxiao Liu, Yifeng Lu, Percy Liang, Quoc V. Le, Tengyu Ma, Adams Wei Yu

Introduction

Datasets for training language models (LMs) are typically sampled from a mixture of many domains (Gao et al., 2020, Du et al., 2021, Chowdhery et al., 2022, Brown et al., 2020). For example, The Pile (Gao et al., 2020), a large publicly available dataset, is composed of 24% web data, 9% Wikipedia, 4% GitHub, etc. (The domain weights, which are based on token count in this paper, vary by tokenizer; see Appendix C.) The composition of the pretraining data greatly affects the effectiveness of an LM (Du et al., 2021, Hoffmann et al., 2022, Xie et al., 2023). However, it is unclear how much of each domain to include to produce a model that performs well for a wide variety of downstream tasks.

Existing works determine domain weights (the sampling probabilities for each domain) by using intuition or a set of downstream tasks. For example, The Pile uses heuristically-chosen domain weights, which could be suboptimal. On the other hand, existing LMs such as PaLM (Chowdhery et al., 2022) and GLaM (Du et al., 2021) tune the domain weights based on a set of downstream tasks, but this requires training potentially thousands of LMs on different domain weights and risks overfitting to the particular set of downstream tasks.

Instead of optimizing domain weights based on a set of downstream tasks, our approach aims to find domain weights which lead to models that perform well on all domains by minimizing the worst-case excess loss over domains, following Oren et al. (2019), Mindermann et al. (2022). The excess loss is the loss gap between the model being evaluated and a pretrained reference model.

This motivates our algorithm, Domain Reweighting with Minimax Optimization (DoReMi), which leverages distributionally robust optimization (DRO) to tune the domain weights without knowledge of downstream tasks (Figure 1). First, DoReMi trains a small reference model (e.g., 280M parameters) in a standard way. Second, DoReMi trains a small distributionally robust language model (DRO-LM) (Oren et al., 2019), which minimizes the worst-case excess loss (relative to the reference model’s loss) across all domains. Notably, rather than using the robust LM, we take the domain weights produced by DRO training. Finally, we train a large (8B) LM on a new dataset defined by these domain weights.

Our approach adapts the DRO-LM framework (Oren et al., 2019) to optimize domain weights instead of producing a robust model. To do this, DoReMi uses the online learning-based optimizer from Group DRO (Sagawa et al., 2020, Nemirovski et al., 2009), which dynamically updates domain weights according to the loss on each domain for rescaling the training objective, instead of sub-selecting examples from a minibatch as in Oren et al. (2019), Mindermann et al. (2022). Finally, DoReMi takes the averaged domain weights over DRO training steps.

In Section 3, we run DoReMi on 280M proxy and reference models to optimize domain weights on The Pile (Gao et al., 2020) and the GLaM dataset (Du et al., 2021) (used in PaLM (Chowdhery et al., 2022)). The DoReMi domain weights are used to train an 8B parameter LM (over 30x larger). On The Pile, DoReMi reduces perplexity on all domains over baseline domain weights, even when it downweights a domain. DoReMi improves average downstream accuracy over a baseline model trained on The Pile’s default domain weights by 6.5 percentage points on generative few-shot tasks and achieves the baseline downstream accuracy 2.6x faster (Figure 2). In Section 4, we find that DoReMi consistently improves LM training when varying the sizes of the proxy model and the main model trained with optimized domain weights. On the GLaM dataset, where domain weights tuned on downstream tasks are available, DoReMi even performs comparably to tuning domain weights on downstream task performance. (A public re-implementation of DoReMi and optimized domain weights for The Pile can be found at https://github.com/sangmichaelxie/doremi.)

Domain Reweighting with Minimax Optimization (DoReMi)

In this section we define DoReMi, an algorithm for using a small proxy model to optimize the domain weights of a language modeling dataset, which then improves the training of a large model.

Suppose that we have $k$ domains (e.g., Wikipedia, GitHub), where for each domain $i$, we have a set of examples $D_{i}$. Domain weights $\alpha\in\Delta^{k}$ specify a probability distribution over the $k$ domains, and consequently a distribution over the training data: $P_{\alpha}=\sum_{i=1}^{k}\alpha_{i}\cdot\textup{unif}(D_{i})$, where $\textup{unif}(D)=\frac{1}{|D|}\sum_{x\in D}\delta_{x}$ is the uniform distribution over the examples in $D$ and $\delta_{x}(x^{\prime})$ is 1 if $x^{\prime}=x$ and 0 otherwise.
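To make the definition of $P_\alpha$ concrete, the following minimal Python sketch draws from it hierarchically (first a domain $i$ with probability $\alpha_i$, then an example uniformly from $D_i$) and also evaluates $P_\alpha(x)$ exactly from the formula. The function names and toy domains are hypothetical, chosen only for illustration.

```python
import random


def sample_from_mixture(domains, alpha, num_samples, seed=0):
    """Draw (domain index, example) pairs from P_alpha = sum_i alpha_i * unif(D_i).

    domains: list of k lists; domains[i] holds the examples of D_i.
    alpha:   k nonnegative weights summing to 1 (a point in the simplex).
    """
    if abs(sum(alpha) - 1.0) > 1e-9 or min(alpha) < 0:
        raise ValueError("alpha must lie in the probability simplex")
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        # First pick a domain i with probability alpha_i ...
        i = rng.choices(range(len(domains)), weights=alpha, k=1)[0]
        # ... then pick an example uniformly at random from D_i.
        samples.append((i, rng.choice(domains[i])))
    return samples


def mixture_prob(domains, alpha, x):
    """Exact P_alpha(x); x may appear in several domains."""
    return sum(a * sum(1 for y in D if y == x) / len(D)
               for a, D in zip(alpha, domains))
```

For example, with `domains = [["w1", "w2"], ["g1"]]` and `alpha = [0.75, 0.25]`, `mixture_prob` gives 0.375 for `"w1"` and 0.25 for `"g1"`: a domain's weight is split evenly over its examples, so larger domains spread their weight thinner.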

DoReMi.

The inputs of DoReMi are the data $D_{1},\dots,D_{k}$, reference domain weights $\alpha_{\text{ref}}$ (e.g., uniform or based on raw token count of each domain), and training hyperparameters for the large, full-size model (number of training steps $T$ and batch size $b$). DoReMi returns optimized domain weights $\bar{\alpha}$ and, ultimately, a large model trained on $P_{\bar{\alpha}}$.

Step 1: Obtain a small reference model.

We first train a model $p_{\text{ref}}$ on some reference domain weights $\alpha_{\text{ref}}$ (e.g., based on raw token count as a default) for $T$ steps with batch size $b$. This model serves as the reference model for Step 2 and captures a baseline level of difficulty of each example/domain. The reference model can be a relatively small model (280M parameters in our experiments).

Step 2: Train proxy model with Group DRO to obtain domain weights.

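To obtain domain weights, we train a small proxy model $p_{\theta}$ in the distributionally robust language modeling (DRO-LM) (Oren et al., 2019) framework with the Group DRO optimizer (Sagawa et al., 2020), where $\theta$ are the weights of the proxy model. This framework trains a robust model by optimizing the worst-case excess loss over domains, which is equivalent to the following minimax objective:

$$\min_{\theta}\max_{\alpha\in\Delta^{k}} L(\theta,\alpha) := \sum_{i=1}^{k}\alpha_{i}\cdot\left[\frac{1}{\sum_{x\in D_{i}}|x|}\sum_{x\in D_{i}}\big(\ell_{\theta}(x)-\ell_{\text{ref}}(x)\big)\right]$$

where $\ell_{\theta}(x)=-\log p_{\theta}(x)$ and $\ell_{\text{ref}}(x)=-\log p_{\text{ref}}(x)$ are the negative log-likelihoods of the proxy and reference models on example $x$, and $|x|$ is the number of tokens in $x$.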
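The quantity that Step 2 optimizes can be made concrete with a small Python sketch. It assumes we already have, for every example $x$ in every domain, the summed token losses $-\log p_\theta(x)$ and $-\log p_{\text{ref}}(x)$ and the token count $|x|$; the helper names are hypothetical. It also shows why the inner maximization is a worst case: a function that is linear in $\alpha$ is maximized over the simplex by putting all weight on the domain with the largest excess loss.

```python
def domain_excess_losses(proxy_losses, ref_losses, lengths):
    """Per-domain excess loss of the proxy over the reference, per token.

    proxy_losses[i][j]: -log p_theta(x) for the j-th example of domain i (summed over tokens)
    ref_losses[i][j]:   -log p_ref(x) for the same example
    lengths[i][j]:      |x|, the number of tokens in that example
    """
    excess = []
    for lp, lr, n in zip(proxy_losses, ref_losses, lengths):
        total_tokens = sum(n)
        excess.append(sum(a - b for a, b in zip(lp, lr)) / total_tokens)
    return excess


def objective(alpha, excess):
    """L(theta, alpha) = sum_i alpha_i * excess_i for fixed proxy weights theta."""
    return sum(a * e for a, e in zip(alpha, excess))


def worst_case_alpha(excess):
    """Exact inner max over the simplex: a vertex on the worst domain."""
    k = len(excess)
    worst = max(range(k), key=lambda i: excess[i])
    return [1.0 if i == worst else 0.0 for i in range(k)]
```

With toy numbers where domain 0 has excess loss 0.5 per token and domain 1 has 1.0, uniform weights give an objective of 0.75, while `worst_case_alpha` puts all weight on domain 1 and gives 1.0. Because the exact maximizer jumps between vertices, an online optimizer that moves the weights gradually is used in practice.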

Step 3: Train large model with new domain weights.

The tuned domain weights $\bar{\alpha}$ define a new training distribution $P_{\bar{\alpha}}$. We resample the data from this new distribution to train a main model (larger than the reference/proxy models), using a standard training procedure.

Details for Step 2.
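Group DRO handles the inner maximization online rather than exactly. At each step it shifts the domain weights toward domains whose current excess loss is high, using a multiplicative (exponentiated-gradient) update, and DoReMi keeps the average of these weights over training. Below is a minimal Python sketch of that weight update. It assumes the per-domain excess losses for the step are already computed (for example, with each token's excess loss clipped at zero). The step size and the constant for mixing in the uniform distribution are illustrative assumptions, and the proxy model's own gradient step on the reweighted objective is omitted.

```python
import math


def update_domain_weights(alpha_prev, excess, step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient (Group DRO-style) step on the domain weights.

    alpha_prev: current domain weights (sum to 1)
    excess:     per-domain excess loss at this step
    step_size:  learning rate for the weights (illustrative value)
    smoothing:  weight of the uniform distribution mixed in (illustrative value)
    """
    k = len(alpha_prev)
    # Multiplicative update: domains with larger excess loss gain weight.
    unnorm = [a * math.exp(step_size * e) for a, e in zip(alpha_prev, excess)]
    total = sum(unnorm)
    # Renormalize onto the simplex, then mix with uniform so that no
    # domain weight collapses to exactly zero.
    return [(1 - smoothing) * u / total + smoothing / k for u in unnorm]


def run_weight_updates(alpha_init, excess_per_step, **kwargs):
    """Apply one update per step and return the average of the iterates,
    which plays the role of the averaged domain weights alpha_bar."""
    alpha = list(alpha_init)
    running_sum = [0.0] * len(alpha)
    for excess in excess_per_step:
        alpha = update_domain_weights(alpha, excess, **kwargs)
        running_sum = [s + a for s, a in zip(running_sum, alpha)]
    steps = len(excess_per_step)
    return [s / steps for s in running_sum]
```

Averaging the iterates, rather than keeping only the last one, smooths out the step-to-step oscillation that multiplicative updates tend to show.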

Iterated DoReMi.

We extend DoReMi by running it for multiple rounds, setting the reference domain weights $\alpha_{\text{ref}}$ for the next round to be $\bar{\alpha}$ from the previous round. We call this iterated DoReMi. The entire iterated process still only uses small models for tuning domain weights. We stop iterating when the domain weights converge, which we define as the maximum change in any domain weight, $\|\bar{\alpha}-\alpha_{\text{ref}}\|_{\infty}$, falling below 1e-3. Empirically, this takes only 3 rounds on the GLaM dataset (Section 3.2).
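The stopping rule for iterated DoReMi fits in a short loop. In this sketch `run_doremi_round` is a hypothetical stand-in for one full DoReMi run (train a reference model on $\alpha_{\text{ref}}$, then a proxy model with Group DRO) that returns $\bar{\alpha}$. The toy round function in the usage note is made up purely to exercise the loop.

```python
def iterated_doremi(run_doremi_round, alpha_init, tol=1e-3, max_rounds=10):
    """Repeat DoReMi, feeding each round's alpha_bar back in as alpha_ref.

    run_doremi_round: hypothetical callable mapping alpha_ref -> alpha_bar.
    Stops once max_i |alpha_bar_i - alpha_ref_i| < tol (the L-infinity change).
    Returns (final domain weights, number of rounds run).
    """
    alpha_ref = list(alpha_init)
    for round_idx in range(1, max_rounds + 1):
        alpha_bar = run_doremi_round(alpha_ref)
        change = max(abs(b - r) for b, r in zip(alpha_bar, alpha_ref))
        if change < tol:
            return alpha_bar, round_idx
        alpha_ref = alpha_bar  # next round uses this round's output as reference
    return alpha_ref, max_rounds
```

With a toy round that moves halfway toward a fixed mixture, `iterated_doremi(lambda a: [(x + t) / 2 for x, t in zip(a, [0.9, 0.1])], [0.5, 0.5])` stops after 9 rounds: the per-round change halves each time (0.2, 0.1, ...) and first drops below 1e-3 at round 9. Real DoReMi rounds need not contract this way; the loop only encodes the convergence test.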

DoReMi Improves LM Training Efficiency and Performance

In this section, we use DoReMi domain weights optimized with a 280M-parameter proxy model to train an 8B-parameter main model (30x larger). We consider two datasets, The Pile (Gao et al., 2020) and the GLaM dataset (Du et al., 2021). On The Pile, DoReMi reduces perplexity significantly on every domain, improves average downstream accuracy on generative one-shot tasks by 6.5 percentage points, and achieves the baseline accuracy 2.6x faster. On the GLaM dataset, where domain weights tuned on downstream datasets are available, DoReMi finds domain weights with comparable performance to downstream-tuned domain weights.

The Pile (Gao et al., 2020) is an 800GB text dataset with 22 domains (Table 1). The default domain weights were determined heuristically. We use the default domain weights from The Pile dataset to train the baseline and as the reference domain weights $\alpha_{\text{ref}}$ in DoReMi (see Appendix C).

GLaM dataset.

The GLaM dataset (Du et al., 2021) (also used in training PaLM (Chowdhery et al., 2022)) includes text from 8 domains (Table 2). For comparison, the GLaM domain weights (downstream-tuned) were tuned according to the downstream performance of models trained on each domain and the size of each domain (Du et al., 2021). We consider this an oracle comparison, since these domain weights are tuned on downstream tasks that are in our evaluation set. We use uniform domain weights both for training the baseline and as the reference domain weights $\alpha_{\text{ref}}$ for DoReMi.

Training setup.

We train Transformer (Vaswani et al., 2017) decoder-only LMs with the standard next-token language modeling loss. We conduct a controlled comparison by equalizing the amount of compute, measured by the number of tokens processed during training. For The Pile, we train each model for 200k steps; for the GLaM dataset, we train each model for 300k steps. All models use a batch size of 512 and maximum token length of 1024. The proxy and reference models have 280M parameters. All models are trained from scratch (other hyperparameters are in Appendix C).
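Equalizing compute by tokens processed makes the budget easy to check. Assuming every sequence in a batch is a full 1024 tokens (packing, described in Appendix C, makes this close to true), the totals for the two datasets work out as follows.

```python
def tokens_processed(steps, batch_size=512, seq_len=1024):
    """Training tokens seen: steps x sequences per batch x tokens per sequence.
    Assumes full-length sequences, so any padding tokens are counted too."""
    return steps * batch_size * seq_len


pile_tokens = tokens_processed(200_000)  # The Pile runs
glam_tokens = tokens_processed(300_000)  # GLaM dataset runs
print(f"The Pile: {pile_tokens:,} tokens (~{pile_tokens / 1e9:.0f}B)")
print(f"GLaM:     {glam_tokens:,} tokens (~{glam_tokens / 1e9:.0f}B)")
# The Pile: 104,857,600,000 tokens (~105B)
# GLaM:     157,286,400,000 tokens (~157B)
```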

Evaluation.

We use held-out validation data to measure the perplexity on each domain. For downstream evaluation, we use the generative one-shot tasks from the GPT-3 paper (Brown et al., 2020): TriviaQA (Joshi et al., 2017), NaturalQuestions (Kwiatkowski et al., 2019), WebQuestions (Berant et al., 2013), SQuADv2 (Rajpurkar et al., 2018), and LAMBADA (Paperno et al., 2016). We use the standard exact-match accuracy metric for these datasets. The performance on these datasets (particularly TriviaQA) has been shown to correlate well with model scale even at the 100M–1B range (Brown et al., 2020).

Compute used for optimizing domain weights.

We train two 280M models (the reference and proxy models) to optimize the domain weights. This is 8% of the FLOPs required to train the main 8B model. All FLOPs come from standard forward and backward passes.
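The 8% figure can be sanity-checked with a back-of-the-envelope estimate. This sketch assumes the common approximations of about $6N$ FLOPs per token for a training step and $2N$ per token for a forward-only pass of an $N$-parameter model, the same token count for every run, and that the proxy run also pays for reference-model forward passes to compute excess losses. That last accounting choice is our assumption; without it the ratio comes out to 7%.

```python
def doremi_flops_fraction(small_params, main_params):
    """Rough FLOPs of the two small runs relative to the main run.

    Assumes ~6N FLOPs/token for training, ~2N FLOPs/token for a forward pass,
    and equal token counts for all three runs (so tokens cancel out).
    """
    reference_run = 6 * small_params                  # train the reference model
    proxy_run = 6 * small_params + 2 * small_params   # train proxy + reference forward passes
    main_run = 6 * main_params                        # train the main model
    return (reference_run + proxy_run) / main_run


print(f"{doremi_flops_fraction(280e6, 8e9):.1%}")  # 8.2%
```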

Notation for model sizes in DoReMi.

We denote the size of the reference/proxy models (which are always the same size in our experiments) and the size of the main model trained with DoReMi domain weights as “DoReMi (size of reference/proxy→size of main model)”: for example, DoReMi (280M→8B). When we are discussing the optimized domain weights independently of the main model, we only include one number (e.g., DoReMi (280M)), which refers to the reference/proxy model size.

DoReMi improves perplexity and downstream accuracy

We show that DoReMi significantly improves both the perplexity and downstream accuracy of 8B models trained on The Pile and the GLaM dataset over their respective baseline domain weights.

Figure 3 (left) shows the average downstream performance for baseline and DoReMi (280M→8B) models on The Pile. DoReMi improves the downstream accuracy by 6.5 percentage points and achieves the baseline accuracy within 75k steps, 2.6x faster than the baseline (200k steps). Thus, DoReMi can dramatically speed up training and improve downstream performance.

DoReMi can reduce perplexity across all domains without a tradeoff.

Figure 4 shows the per-domain log-perplexity of the 8B models on The Pile. DoReMi significantly reduces the perplexity over the baseline across all domains, despite allocating lower weight to some domains. How can this occur? One hypothesis is that the domains with the lowest and highest entropy can be downweighted without impacting the perplexity much. The lowest-entropy domains statistically require few samples to learn. The highest-entropy domains have token distributions that are close to common uniform priors; for example, models at random initialization tend to output a uniform next-token distribution. Thus, we need fewer samples to fit these domains. Positive transfer from allocating more samples to medium-entropy domains can then improve perplexity on all domains. In Appendix D, we provide a simple example where reweighting domains can improve perplexity on all domains and DoReMi finds such domain weights in simulations.

Iterated DoReMi achieves performance of downstream-tuned weights on the GLaM dataset.

We employ iterated DoReMi on the GLaM dataset over 3 rounds. We find that the second and third round domain weights are almost identical (Table 2). Figure 3 (right) shows one-shot results for the first two rounds of iterated DoReMi. After the first round, the DoReMi main model has comparable downstream accuracy to the baseline (uniform domain weights). After the second round, the DoReMi main model achieves comparable downstream accuracy to oracle domain weights tuned on downstream tasks in our evaluation set. Overall, domain reweighting has a smaller effect on GLaM, possibly because there are only 8 domains compared to 22 in The Pile.

Inspecting the DoReMi domain weights.

Tables 1 and 2 present the DoReMi domain weights for The Pile and the GLaM dataset. When running DoReMi on a 280M proxy model (DoReMi (280M)), most weight is put on the diverse Pile-CC web text domain. Note that Wikipedia is downweighted in comparison to the baseline, but DoReMi still improves the downstream accuracy on tasks derived from Wikipedia (e.g., TriviaQA, Appendix Table 5). Domain weights for a 1B proxy model (Appendix Table 8) show a different trend, where OpenWebText2 is upweighted the most instead of Pile-CC. This suggests that there may be multiple possible local minima in the domain weight space. On the GLaM dataset, the DoReMi weights have the same general pattern as the downstream-tuned domain weights. DoReMi is able to recover a similar set of domain weights by starting from uniform initial reference domain weights, without any use of downstream data.

Ablations and Analysis Across Scales

Previously in Section 3, we showed that DoReMi finds domain weights using 280M models that can improve training of 8B models. In this section, we conduct an analysis of DoReMi where we vary the scale of the proxy model in relation to the main model and ablate the components of the excess loss objective.

We consider using proxy and main models of the same size to analyze DoReMi’s behavior in a simple setting, without the need for the domain weights to generalize across scales. Note that this is just for scientific purposes since this does not save compute in practice. In particular, we run DoReMi (X→X) where X is 280M, 510M, 760M, or 1B on The Pile. Figure 5 shows that DoReMi consistently improves downstream accuracy over the baseline by 2% and achieves the baseline accuracy 4x faster on average across scales, and this improvement does not shrink with larger model size. DoReMi improves the worst-case perplexity on all scales and improves 18 of 22 individual domain perplexities on average across scales (Appendix Table 6). These experiments give a rough picture of how much is lost when using a smaller proxy model; our DoReMi (280M→8B) model achieves the baseline accuracy 2.6x faster, while matching the proxy and main model sizes results in a 4x average speedup.

Proxy model underperforms main model, especially at larger sizes.

Recall that DoReMi uses Group DRO to train a proxy model, which reweights the objective with the domain weights. In contrast, the main model is trained by resampling on the domain weights from DoReMi. When the proxy model and the main model are the same size, which one is the better model? Table 6(b) shows that the proxy model typically underperforms the main model in this case. The gap between the proxy and main model increases with scale, as the 1B proxy model not only underperforms the 1B main model but also the 1B baseline model, while the 280M proxy model achieves better perplexity than the 280M baseline model on 19/22 domains. Despite the relatively poor quality of the 1B proxy model, the domain weights still allow the 1B main model to achieve the baseline performance over 2x faster. This suggests that DoReMi can succeed even if the proxy model is not trained well. However, we hypothesize that the mismatch between the proxy and main model training (loss reweighting vs. resampling) explains their performance difference and therefore a resampling-based Group DRO optimizer may improve DoReMi for larger proxy models.

Effect of proxy model scale on larger main model’s performance.

We consider 70M, 150M, 280M, and 1B scales for the DoReMi proxy model while fixing the main model size at 8B (DoReMi (X→8B)). From 70M to 280M, increasing the proxy model size improves downstream accuracy at 8B (Figure 6 left). We hypothesize that this trend does not continue for the 1B proxy model because the Group DRO optimizer is worse at larger scales (Table 6(b)). While DoReMi (280M→8B) results in the most improvement at 8B, DoReMi (150M→8B) and DoReMi (1B→8B) still achieve the baseline accuracy almost 2x faster. This suggests that DoReMi is robust to the proxy model scale. In practice, we suggest choosing a relatively small proxy model size (280M) to save compute.

Choosing the easiest or hardest domains does not suffice.

Related Work

Most closely related is the GLaM dataset (Du et al., 2021) (also used for training PaLM (Chowdhery et al., 2022)), which has domain weights that are tuned using downstream data. Optimizing domain weights for downstream tasks can be expensive and could require search/zero-order optimization (Snoek et al., 2012), RL (Zoph and Le, 2016), or heuristic assumptions on how positive/negative transfer between domains works. Example-level filtering also brings benefits for LM training. The C4 dataset (Raffel et al., 2019) shows gains over CommonCrawl via heuristic data cleaning methods. Du et al. (2021), Xie et al. (2023) show that filtering the data at an example level for high-quality text that looks like Wikipedia and books can significantly improve downstream performance for LMs. In contrast to these works, DoReMi sets domain weights automatically with only two small LM training runs and does not make assumptions about the type of data to prefer (Wikipedia-like, etc.).

General data selection methods.

Moore-Lewis selection (Moore and Lewis, 2010, Axelrod, 2017, Feng et al., 2022) selects examples with high cross-entropy difference (similar to excess log-perplexity) between language models trained on target and raw data. In contrast, DoReMi reweights the data without a target distribution. Coleman et al. (2020) select examples based on the uncertainty of a small proxy model for active learning, while DoReMi uses DRO on the excess loss with respect to a reference model, and focuses on data mixture reweighting. Mindermann et al. (2022) select examples in an online fashion by taking the top $k$ examples in a minibatch according to excess loss. DoReMi optimizes the data mixture before training, allowing the larger main model to train in a standard way. Many other works on data selection are in vision (Sorscher et al., 2022, Kaushal et al., 2019, Killamsetty et al., 2021b, a, c, Wang et al., 2020, Wei et al., 2015, Paul et al., 2021, Mirzasoleiman et al., 2020, Sener and Savarese, 2018) and mainly focus on example-level subset selection with metrics such as gradient matching. Overall, these methods do not address data selection for pretraining, where the downstream data distribution may be very different from the pretraining distribution. DoReMi aims to address the pretraining/downstream distribution shift with a robust optimization approach. To the best of our knowledge, we are the first to show that reweighting the data according to losses of a small proxy LM can improve the training efficiency of a much larger LM.

Distributionally robust optimization.

Within DRO methods for deep learning (Ben-Tal et al., 2013, Sinha et al., 2018, Oren et al., 2019, Sagawa et al., 2020), we target a restricted form of shift called group shifts (Duchi et al., 2019, Oren et al., 2019, Sagawa et al., 2020), where the test distribution can be an unknown mixture of groups (domains). We follow DRO-LM (Oren et al., 2019), which employs DRO for LMs in the group shift setting. DRO-LM also uses a baselined loss, but with a simple bigram reference model. DoReMi uses a reference model of the same size and architecture as the proxy model to ensure that the losses are on a similar scale. During optimization, DRO-LM takes a worst-case subset of each minibatch to update the model on, while we use the Group DRO optimizer (Sagawa et al., 2020), which doesn’t require online subselection. If we equalize the number of examples in each minibatch used for gradient updates, online subselection is more expensive than Group DRO since it requires running forward passes on a larger minibatch (e.g., double the minibatch size) before selecting a subset to update the model with. In comparison, the Group DRO optimizer updates the model on all examples in a weighted fashion. Overall, in contrast to these DRO methods, which aim to produce robust models, we use DRO to optimize the data for training larger models more efficiently.

Data-centric AI.

Large-scale datasets and benchmarks have driven much of the recent progress in AI, including vision, NLP, and multimodal models (Deng et al., 2009, Russakovsky et al., 2015, Wang et al., 2019, Rajpurkar et al., 2016, Raffel et al., 2019, Gao et al., 2020, Schuhmann et al., 2022, Gadre et al., 2023). However, most datasets are still painstakingly created with human-generated data, manual work, and heuristics (Deng et al., 2009, Raffel et al., 2019, Gao et al., 2020, Schuhmann et al., 2022, Gadre et al., 2023). DoReMi is a principled data-centric method that aims to improve language model training efficiency. We hope that DoReMi can provide a starting point for a general data-centric framework for language modeling via robust optimization.

Discussion and Limitations

In Section 2, we run DoReMi for the number of training steps that will be used to train the final model, which could be unnecessarily expensive. A future direction for saving compute would be to stop running DoReMi at an early step and extrapolate the domain weights for the desired number of steps, since we found that most of the variation in the domain weights during a DoReMi run seems to occur in the beginning of training (Appendix Figure 8).

Choice of reference model.

The choice of reference model can affect the domain weights found by DoReMi. For example, iterated DoReMi (Section 3) improves performance by using a reference model trained on the tuned domain weights from a previous round of DoReMi. Further directions include varying the reference model size and using specialized reference models to optimize domain weights for a specific application area.

What is a domain?

We define a domain by data provenance in our experiments, but this only enables coarse-grained control. Using fine-grained domains could improve the gains from DoReMi. For example, DoReMi is more effective on The Pile (22 domains) than the GLaM dataset (8 domains). Open directions include automatically finding fine-grained domains (e.g., via clustering as in DRO-LM (Oren et al., 2019)) and reweighting the data at an example level. When domains are very fine-grained, it will be important to control the pessimism of DRO (e.g., DRO can put all the weight on a small set of worst-case examples).

Transferability of domain weights across scales.

We optimized the domain weights with a small proxy model (280M) and directly used these domain weights to improve training at a larger scale (8B). Understanding why the domain weights can be transferred across scales and the limits of how far these domain weights transfer are important questions to answer in future work.

Broader impacts.

We hope to improve training efficiency and reduce the environmental impact of training large LMs (Strubell et al., 2019, Lacoste et al., 2019, Patterson et al., 2021, Ligozat et al., 2021). In particular, by halving the training time, we can halve the cost and energy consumption of training large language models. Since such efficiency improvements may be used to develop even larger models, there may be no absolute improvement in energy consumption. Ultimately, we hope to improve the training efficiency and cost of developing future language models relative to existing methods.

Large LMs have also been well-documented to have risks and biases (Abid et al., 2021, Nadeem et al., 2020, Bommasani et al., 2021, Blodgett and O’Connor, 2017, Gehman et al., 2020). For example, GPT-3 tends to have an anti-Muslim bias, where Muslims are frequently related to violence or terrorism in analogy and completion tasks (Abid et al., 2021). As large language models are increasingly relied upon in applications, the magnitude of the risks increases (Bommasani et al., 2022). Distributionally robust optimization (DRO), which is used in DoReMi to optimize the data mixture, can have a favorable impact on fairness (Hashimoto et al., 2018). While the standard approach of minimizing the average loss can lead to disparate performance on minority subgroups that do not contribute heavily to the loss (Amodei et al., 2016), DRO promotes good performance on all groups via a worst-case loss. In this way, DRO-style data-centric methods such as DoReMi can improve the representation disparity between majority and minority subgroups in a dataset.

Conclusion

We introduced DoReMi, an algorithm for reweighting data domains for training language models. DoReMi is able to run on small models and transfer the benefits to 30x larger models, resulting in a 2.6x speedup in training on The Pile just by changing the sampling probabilities of the domains. We hope to instigate more research on data-centric approaches for improving language model training efficiency.

Acknowledgments

We thank Xiangning Chen, Andrew Dai, Zoubin Ghahramani, Balaji Lakshminarayanan, Paul Michel, Yonghui Wu, Steven Zheng, Chen Zhu and the broader Google Bard team members for insightful discussions and pointers.

References

Appendix A Results Across Scales on the GLaM dataset

Figure 7 presents results across different scales (280M, 510M, 760M, 1B) on the GLaM dataset, where the proxy/reference models are the same size as the main model trained with DoReMi domain weights. Across all scales, DoReMi is comparable to or better than both the baseline (uniform) domain weights and downstream-tuned domain weights. Interestingly, for iterated DoReMi at the 280M scale, the second-round weights achieve slightly worse downstream accuracy than the first-round weights when used to train 280M models, but transfer better to training 8B models.

Appendix B Detailed Results for The Pile

Table 4 shows per-domain perplexities for 8B models trained on The Pile. The reference/proxy models in this case are 70M, 150M, 280M, and 1B. DoReMi improves the perplexity on each domain compared to the baseline domain weights.

Per-task accuracies for 8B models.

Table 5 shows the accuracies on one-shot generative tasks for various reference/proxy model sizes from 70M to 1B. All DoReMi models improve downstream performance significantly over the baseline.

Summary of perplexity results across scales.

Table 6 shows a summary of per-domain perplexities for DoReMi across 4 scales (280M, 510M, 760M, 1B). Here, the reference/proxy models are the same size as the main model trained with DoReMi domain weights. On average, DoReMi improves perplexity on 18.25 out of 22 domains from The Pile. The worst-case perplexity is always reduced (or comparable in the 510M case) with respect to the baseline domain weights.

Perplexity results for ablations.

Table 7 shows the perplexities for ablations on the DRO objective. We change the DRO objective and use these to tune domain weights on 280M reference/proxy models. These tuned domain weights are then used to train a main 280M model. Hardest refers to optimizing the domain-level log-perplexity without baselining with a reference model. Easiest refers to optimizing for the domains with lowest log-perplexity under the reference model. Neither ablation improves perplexity on any domain over the baseline. Optimizing for the “hardest” domain does not actually result in improving worst-case perplexity, supporting the results of Oren et al. (2019), which also employs DRO for language modeling with a baselined loss.
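The three objectives compared in this ablation differ only in the per-domain signal that the DRO step pushes weight toward. The hypothetical Python sketch below makes the difference explicit; the function name and the toy numbers are illustrative, not measurements.

```python
def domain_scores(proxy_loss, ref_loss, objective="excess"):
    """Per-domain signal that the DRO step up-weights.

    proxy_loss, ref_loss: per-domain mean log-perplexity (lists of floats)
    objective: "excess"  -> proxy minus reference (DoReMi's baselined loss)
               "hardest" -> proxy loss alone, with no reference baseline
               "easiest" -> negated reference loss, so domains that are easy
                            for the reference model score highest
    """
    if objective == "excess":
        return [p - r for p, r in zip(proxy_loss, ref_loss)]
    if objective == "hardest":
        return list(proxy_loss)
    if objective == "easiest":
        return [-r for r in ref_loss]
    raise ValueError(f"unknown objective: {objective}")
```

With toy losses `proxy = [2.0, 3.0, 5.0]` and `ref = [1.9, 2.5, 4.95]`, the excess objective favors domain 1 (where the proxy lags the reference most), "hardest" favors domain 2 (which is also hard for the reference, so extra weight may not help), and "easiest" favors domain 0. The baseline subtraction is what separates learnable headroom from intrinsic difficulty.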

Trajectory of domain weights.

Figure 8 shows the exponential moving average (smoothing parameter 0.99) of domain weights during a run of DoReMi. In both cases, some domains start with very high weight and then decrease in weight very quickly (within 50k steps). Since we compute the final domain weights by integrating these curves over steps and normalizing, this suggests that with a smaller compute budget, these domains could become more important; this highlights the dependence of the mixture weights on the compute budget. At the same time, the domain weights tend to quickly stabilize after 50k steps, suggesting that the optimal domain weights should be similar for larger compute budgets. We may also be able to take advantage of this stability after 50k steps to run DoReMi for a smaller number of steps and extrapolate the domain weights to save compute.
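The two aggregations mentioned here, the EMA used for plotting and the normalized integral over steps used for the final weights, can be sketched as follows. Initializing the EMA at the first value is our assumption, and the toy trajectory is made up.

```python
def ema(trajectory, smoothing=0.99):
    """EMA of a domain-weight trajectory: s_t = smoothing*s_{t-1} + (1-smoothing)*x_t.

    trajectory: list of per-step weight vectors; the EMA starts at the first one.
    """
    smoothed, s = [], list(trajectory[0])
    for x in trajectory:
        s = [smoothing * a + (1 - smoothing) * b for a, b in zip(s, x)]
        smoothed.append(s)
    return smoothed


def final_weights(trajectory):
    """Integrate each domain's weight curve over steps, then renormalize."""
    totals = [sum(step[i] for step in trajectory) for i in range(len(trajectory[0]))]
    z = sum(totals)
    return [t / z for t in totals]
```

The compute-budget dependence shows up directly: for a toy run where domain 0 starts at weight 0.9 and drops to 0.2 after one step, `final_weights([[0.9, 0.1]] + [[0.2, 0.8]] * 9)` is about `[0.27, 0.73]`, but truncating the run to its first two steps gives about `[0.55, 0.45]`.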

Comparison of domain weights for 280M and 1B.

Table 8 presents the DoReMi domain weights for The Pile at 280M and 1B proxy models. Different proxy model sizes can result in different domain weights, which suggests that there may be multiple local minima in domain weight space. With a 280M proxy model, most of the weight is put on the Pile-CC web text domain, while DoReMi with a 1B proxy model puts most of the weight on OpenWebText2. The overall pattern of the domain weights for the remaining domains is similar.

Appendix C Training Details

For all datasets, we preprocessed the data by chunking into length 1024 examples with respect to a SentencePiece tokenizer with 256k vocabulary size. The examples are separated by domain to facilitate hierarchical sampling (first sample a domain according to some domain weights, then sample an example from that domain at random). To reduce the amount of padding tokens, we made an effort to pack examples (possibly from different domains) together into the same sequence. When doing such a packing, we compute the domain perplexities on a per-token level in DoReMi.
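With packing, a single sequence can mix tokens from several domains, so per-domain losses are naturally accumulated token by token. Here is a minimal sketch, assuming each token carries the domain id of the example it came from; handling of padding tokens is omitted, and the function name is hypothetical.

```python
from collections import defaultdict


def per_domain_token_loss(token_losses, token_domains):
    """Mean loss per domain, computed at the token level.

    token_losses:  per-token losses of a packed batch (flattened)
    token_domains: domain id of the source example of each token (same length)
    Returns {domain: mean loss over that domain's tokens}.
    """
    sums, counts = defaultdict(float), defaultdict(int)
    for loss, dom in zip(token_losses, token_domains):
        sums[dom] += loss
        counts[dom] += 1
    return {d: sums[d] / counts[d] for d in sums}
```

For a packed sequence with three Wikipedia tokens (losses 1.0, 2.0, 3.0) followed by two GitHub tokens (0.5, 1.5), this returns a mean of 2.0 for Wikipedia and 1.0 for GitHub, regardless of where the example boundary falls in the sequence.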

Baseline domain weights for The Pile.

The baseline domain weights for The Pile were computed from The Pile dataset and the number of epochs for each domain given in Gao et al. (2020). After chunking into length-1024 examples, we counted the number of examples in each domain and multiplied it by the number of epochs for that domain specified in Gao et al. (2020). We then normalized these counts to obtain the baseline domain weights.
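The baseline weighting reduces to a simple normalization. In the sketch below, the domain names and counts are hypothetical placeholders, not The Pile's actual example counts or epoch multipliers.

```python
def baseline_domain_weights(num_examples, epochs):
    """Weights proportional to (#length-1024 examples) x (#epochs) per domain.

    num_examples: {domain: number of chunked examples}
    epochs:       {domain: number of epochs for that domain}
    """
    effective = {d: num_examples[d] * epochs[d] for d in num_examples}
    total = sum(effective.values())
    return {d: n / total for d, n in effective.items()}
```

For instance, a domain with half as many examples but twice the epochs ends up with the same weight: `baseline_domain_weights({"A": 100, "B": 50}, {"A": 1.0, "B": 2.0})` returns `{"A": 0.5, "B": 0.5}`.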

Training setup.

For all training runs (including DRO runs), we train with a batch size of 512, initial learning rate of 1e-3, weight decay of 1e-2, and gradient clipping to norm 1. We decay the learning rate exponentially until it reaches a minimum of 1e-4 at the end of training, with a linear warmup of 6% of the total training steps. We train for 200k steps on The Pile and 300k steps on the GLaM dataset. Models under 1B parameters were trained with TPUv3 accelerators, while 1B and 8B models were trained with TPUv4.
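One way to realize this learning-rate schedule is sketched below. It assumes the exponential decay starts right after warmup and reaches exactly 1e-4 on the last step. The precise parameterization is not specified here, so treat the function as an illustration rather than the exact schedule used.

```python
def learning_rate(step, total_steps, peak_lr=1e-3, final_lr=1e-4, warmup_frac=0.06):
    """Linear warmup to peak_lr, then exponential decay to final_lr at the last step.

    step is 0-indexed, running from 0 to total_steps - 1.
    """
    warmup_steps = max(1, round(warmup_frac * total_steps))
    if step < warmup_steps:
        # Linear warmup: reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, from 0 (end of warmup) to 1 (last step).
    progress = (step - warmup_steps) / max(1, total_steps - 1 - warmup_steps)
    # Geometric interpolation between peak_lr and final_lr.
    return peak_lr * (final_lr / peak_lr) ** progress
```

For a 200k-step run on The Pile, warmup covers the first 12,000 steps, and the rate then decays geometrically, falling by a factor of 10 over the remaining 188,000 steps.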

Model architectures.

Table 9 shows the architecture hyperparameters for the model sizes used in the paper. All the models we use are vanilla Transformer decoder-only models with a 256k vocab size.

Appendix D Simple Example Where Data Reweighting Has No Tradeoff

Motivated by the findings in Section 3.2, we present a simple language modeling example where reweighting the training data from different domains improves perplexity on all domains. The example shows that DoReMi downweights domains that are extremely high or low entropy.

Suppose the ground-truth distribution of text $p^*$ is a mixture over $k$ domains, where each domain $z \in \{1, \dots, k\}$ is defined by a different unigram distribution $p^*(x \mid z)$ over $m$ tokens. Given a budget of $n$ training samples, the goal is to choose domain weights $p(z)$ ($k$ scalars that sum to 1) with which to sample training data, such that we learn the parameters of the unigram distributions $p^*(\cdot \mid z)$ well for all $z$ from $1$ to $k$. Notably, we do not aim to estimate the ground-truth mixture proportions across domains.

Data.

Given some domain weights $p(z)$, we sample training data hierarchically: first, we determine the number of samples $n_z$ per domain $z$ by drawing from a multinomial distribution over $k$ possibilities with probabilities $p(z)$ and $n$ total trials. Then, for each domain $z$, we sample $n_z$ tokens from $p^*(\cdot \mid z)$, forming a vector of tokens $X_z$ of length $n_z$.
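The two-stage sampling above can be sketched with NumPy's `Generator.multinomial` and `Generator.choice`. Tokens are 0-indexed here (token 1 is index 0). The domain-2 row in the usage is a hypothetical distribution chosen for illustration; the other two rows follow the later no-tradeoff example (domain 1 puts all mass on token 1, domain 3 matches the uniform prior).

```python
import numpy as np

def sample_unigram_data(domain_weights, true_params, n, rng):
    """Draw (n_1..n_k) ~ Multinomial(n, p(z)), then n_z tokens from each domain's unigram p*(.|z)."""
    counts = rng.multinomial(n, domain_weights)
    m = true_params.shape[1]
    data = [rng.choice(m, size=n_z, p=true_params[z]) for z, n_z in enumerate(counts)]
    return counts, data

rng = np.random.default_rng(0)
theta = np.array([[1.0, 0.0, 0.0],        # domain 1: always token 1
                  [0.7, 0.2, 0.1],        # domain 2: HYPOTHETICAL row
                  [1 / 3, 1 / 3, 1 / 3]]) # domain 3: uniform
counts, data = sample_unigram_data(np.full(3, 1 / 3), theta, n=500, rng=rng)
```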

Model.

where $s_z = \sum_x \lambda_z(x)$ is the sum of pseudocounts.

For a domain $z$, we can write the parameter error of this estimator as a function of the “difficulty” $H_z$ of predicting the next token and the “quality” $\Delta_z$ of the prior, defined below.

For domain index $z$ with $n_z$ samples, the parameter error is

Putting it all together, the parameter error can be written as
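For reference, one self-consistent way to write the estimator, the two quantities $H_z$ and $\Delta_z$, and the parameter error, assuming the model is the posterior mean under a Dirichlet prior with pseudocounts $\lambda_z$ and that parameter error means expected squared $\ell_2$ error. This reading is an assumption, but it agrees with the values used below ($H_1 = 0$, $\Delta_3 = 0$, and $H_2 = 0.46$, $\Delta_2 \approx 0.207$ for any domain-2 distribution with $\sum_x p^*(x \mid 2)^2 = 0.54$). Here $c_z(x)$ counts occurrences of token $x$ in $X_z$.

```latex
\hat{\theta}_z(x) = \frac{c_z(x) + \lambda_z(x)}{n_z + s_z},
\qquad c_z(x) = \sum_{i=1}^{n_z} \mathbf{1}[X_{z,i} = x]

H_z = \sum_{x} p^*(x \mid z)\bigl(1 - p^*(x \mid z)\bigr) = 1 - \sum_{x} p^*(x \mid z)^2,
\qquad
\Delta_z = \sum_{x} \bigl(\lambda_z(x) - s_z\, p^*(x \mid z)\bigr)^2

\mathbb{E}\bigl\|\hat{\theta}_z - p^*(\cdot \mid z)\bigr\|_2^2
  = \frac{n_z H_z + \Delta_z}{(n_z + s_z)^2}

\frac{\partial}{\partial n_z}\,\frac{n_z H_z + \Delta_z}{(n_z + s_z)^2}
  = \frac{H_z (s_z - n_z) - 2\Delta_z}{(n_z + s_z)^3}
```

The error formula is a bias-variance split: $c_z(x)$ has mean $n_z p^*(x \mid z)$ and variances summing to $n_z H_z$, while $\lambda_z - s_z p^*(\cdot \mid z)$ is the bias numerator. Under this reading the derivative is negative whenever $n_z > s_z - 2\Delta_z / H_z$. With $\lambda_z(x) = 1/3$, domain 1 gives $\Delta_1 = 2/3$ and error $\tfrac{2/3}{(n_1 + 1)^2}$, and domain 3 gives $H_3 = 2/3$ and error $\tfrac{(2/3)\, n_3}{(n_3 + 1)^2}$.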

No-tradeoff example.

Suppose there are 3 domains $z \in \{1, 2, 3\}$ and $m = 3$ vocabulary tokens $x \in \{1, 2, 3\}$. We use a symmetric Dirichlet prior (preferring a uniform token distribution) where $\lambda_z(x) = 1/3$ for all tokens $x$ and domains $z$. Here, $s_z = \sum_x \lambda_z(x) = 1$. In this setting, we show that there is a set of domain weights with strictly lower parameter error than the baseline that samples the same number of tokens from each domain (i.e., $n_z$ is equal for all domains $z$).

Suppose the ground-truth parameters for the unigram distributions are

where row $z$ contains the parameters for domain $z$. For example, token 1 has probability 1 under domain 1’s unigram distribution.

For domain $z = 1$ (non-noisy domain), we have $H_1 = 0$, so the parameter error (according to Lemma 1) is

which is strictly decreasing in the number of samples $n_1$.

For domain $z = 3$ (noisy domain), we have $\Delta_3 = 0$, so the parameter error is

by Lemma 1. This error is minimized, reaching zero, at $n_3 = 0$ (no samples). This means that we can allocate domain 3’s samples elsewhere while still reducing its error.

For $z = 2$ (intermediate-entropy domain), we have $\Delta_2 = 0.207$ and $H_2 = 0.46$. The derivative of the parameter error with respect to the number of samples $n_2$ is

This inequality holds in this case since $\frac{2\Delta_2}{H_2} < 1$ and $s_2 = 1$. Therefore, the parameter error is decreasing in the number of samples $n_2$.

Thus, any domain weights that reallocate the examples from domain 3 to domains 1 and 2 reduce the parameter error on all domains.
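The argument above can be checked numerically. This sketch assumes the parameter error is $(n_z H_z + \Delta_z)/(n_z + s_z)^2$ with $H_z = 1 - \sum_x p^*(x \mid z)^2$ and $\Delta_z = \|\lambda_z - s_z p^*(\cdot \mid z)\|^2$ (a reconstruction consistent with the stated values, not quoted from Lemma 1), and uses a hypothetical domain-2 row $[0.7, 0.2, 0.1]$: any row with squared norm 0.54 reproduces $H_2 = 0.46$ and $\Delta_2 \approx 0.207$. The reallocation $[250, 250, 0]$ is one illustrative choice.

```python
import numpy as np

# Symmetric Dirichlet prior from the text: lambda_z(x) = 1/3, so s_z = 1.
LAMBDA = np.full(3, 1 / 3)

# Rows are domains 1..3, columns tokens 1..3 (0-indexed). The domain-2 row is HYPOTHETICAL.
THETA = np.array([[1.0, 0.0, 0.0],
                  [0.7, 0.2, 0.1],
                  [1 / 3, 1 / 3, 1 / 3]])

def difficulty(theta_z):
    """H_z = 1 - sum_x p*(x|z)^2."""
    return 1.0 - np.sum(theta_z ** 2)

def prior_mismatch(theta_z, lam=LAMBDA):
    """Delta_z = ||lambda_z - s_z p*(.|z)||^2."""
    return np.sum((lam - lam.sum() * theta_z) ** 2)

def param_error(n_z, theta_z, lam=LAMBDA):
    """Assumed expected squared error of the posterior-mean estimator after n_z samples."""
    s = lam.sum()
    return (n_z * difficulty(theta_z) + prior_mismatch(theta_z, lam)) / (n_z + s) ** 2

def d_param_error(n_z, theta_z, lam=LAMBDA):
    """Derivative of param_error with respect to n_z."""
    s = lam.sum()
    h, delta = difficulty(theta_z), prior_mismatch(theta_z, lam)
    return (h * (s - n_z) - 2 * delta) / (n_z + s) ** 3

# Equal samples per domain vs. moving domain 3's samples to domains 1 and 2.
n = 500
uniform_alloc = [n / 3, n / 3, n / 3]
realloc = [n / 2, n / 2, 0]
before = [param_error(nz, THETA[z]) for z, nz in enumerate(uniform_alloc)]
after = [param_error(nz, THETA[z]) for z, nz in enumerate(realloc)]
```

Under these assumptions, the domain-2 derivative is negative for every $n_2 \ge 1$, and every entry of `after` is below the matching entry of `before`.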

What kind of domains are downweighted?

Intuitively, we can downweight the very noisy (high-entropy, high-difficulty) domain 3 because the initialization (the prior) already matches its ground truth exactly. This allows us to reallocate samples to domains 1 and 2. Between these, domain 1 requires fewer additional samples, since its parameter error decreases very quickly with the number of samples $n_1$ (its difficulty $H_1$ is zero). Thus, the easiest domains should also receive relatively less weight. In practice, positive transfer between domains (which this example does not capture) can also contribute to scenarios where reweighting results in no tradeoff across domains.

Simulation with DoReMi.

We consider running DoReMi on the above no-tradeoff instance of the simple example, with the ground-truth unigram distributions in Equation 14. Note that DoReMi’s domain reweighting step (Step 2, Algorithm 1) involves a loop over $T$ iterative model updates, whereas the estimator from Equation 2 is computed in closed form. To adapt the estimator for DoReMi, we consider an iterative version where the average is computed in an online fashion. We run DoReMi for $T = 500$ steps using minibatch size 1 over the $n = 500$ training examples, with domain weight update rate $\eta = 0.5$. For the model update at step $t$ on an example $x$ from domain $z$, we increase the pseudo-count $\hat{\theta}_z(x)$ by the current domain weight $\alpha_t$ corresponding to domain $z$. Instead of using the examples in the minibatch (which contains a single example and therefore cannot represent all domains), we compute the per-domain excess log-perplexities in Algorithm 1 on a fixed, independent evaluation set of 30 examples.
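A rough sketch of this online adaptation, for illustration only. It assumes the standard DoReMi update (exponentiated-gradient step on excess losses clipped at zero, then mixing with the uniform distribution) with a smoothing constant of 1e-3, a reference model equal to the closed-form estimator on the uniformly sampled training data, an evaluation set of 10 examples per domain, and the hypothetical domain-2 row $[0.7, 0.2, 0.1]$; none of these details are fixed by the text. It is not guaranteed to reproduce the reported weights.

```python
import numpy as np

def log_perplexity(pseudo_counts, tokens):
    """Mean negative log-likelihood of tokens under the normalized pseudo-counts."""
    probs = pseudo_counts / pseudo_counts.sum()
    return float(-np.mean(np.log(probs[tokens])))

def doremi_unigram(train, eval_sets, k, m, eta=0.5, smoothing=1e-3, prior=1 / 3):
    """Online DoReMi-style reweighting for k unigram domains over m tokens.

    train: list of (domain, token) pairs, processed one per step (minibatch size 1).
    eval_sets: list of k integer arrays of held-out tokens, one per domain.
    """
    # ASSUMED reference model: closed-form posterior-mean counts on all training data.
    ref = np.full((k, m), prior)
    for z, x in train:
        ref[z, x] += 1.0
    ref_loss = np.array([log_perplexity(ref[z], eval_sets[z]) for z in range(k)])

    proxy = np.full((k, m), prior)  # proxy model pseudo-counts, starting at the prior
    alpha = np.full(k, 1.0 / k)     # current domain weights
    history = []
    for z, x in train:
        proxy_loss = np.array([log_perplexity(proxy[d], eval_sets[d]) for d in range(k)])
        excess = np.maximum(proxy_loss - ref_loss, 0.0)
        alpha = alpha * np.exp(eta * excess)                          # exponentiated-gradient step
        alpha = (1 - smoothing) * alpha / alpha.sum() + smoothing / k  # mix with uniform
        history.append(alpha)
        proxy[z, x] += alpha[z]  # increase the pseudo-count by the current domain weight
    return np.mean(history, axis=0)  # averaged domain weights over all steps

rng = np.random.default_rng(0)
theta = np.array([[1.0, 0.0, 0.0], [0.7, 0.2, 0.1], [1 / 3, 1 / 3, 1 / 3]])  # row 2 HYPOTHETICAL
domains = rng.integers(0, 3, size=500)  # baseline: uniform over domains
train = [(int(z), int(rng.choice(3, p=theta[z]))) for z in domains]
eval_sets = [rng.choice(3, size=10, p=theta[z]) for z in range(3)]
weights = doremi_unigram(train, eval_sets, k=3, m=3)
```

Returning the average of the per-step weights rather than the final weights follows how DoReMi summarizes Step 2. The uniform mixing keeps every weight strictly positive, so a weight that "rounds to 0.0" can still be slightly above zero.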

We compare DoReMi against a model trained with baseline domain weights, which are uniform over the 3 domains. All models are trained on $n = 500$ training examples. We evaluate the log-perplexity of a model on each domain in closed form using the ground-truth unigram distribution parameters.
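For a unigram model, the closed-form log-perplexity on domain $z$ is the cross-entropy $-\sum_x p^*(x \mid z) \log \hat{\theta}_z(x)$, i.e., the expected per-token negative log-likelihood. A short sketch, assuming the model is given as strictly positive pseudo-counts (true for a positive Dirichlet prior); the function name is hypothetical.

```python
import numpy as np

def closed_form_log_perplexity(theta_true, pseudo_counts):
    """Per-domain cross-entropy (nats) of normalized pseudo-counts under the true unigrams."""
    theta_hat = pseudo_counts / pseudo_counts.sum(axis=1, keepdims=True)
    return -(theta_true * np.log(theta_hat)).sum(axis=1)

theta_true = np.array([[1.0, 0.0, 0.0], [1 / 3, 1 / 3, 1 / 3]])
counts = np.array([[2.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
losses = closed_form_log_perplexity(theta_true, counts)
```

Here the first domain assigns probability 0.5 to its only token, giving $\log 2$, and the uniform domain gives $\log 3$.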

On this simple example, DoReMi returns domain weights $[0.39, 0.61, 0.0]$ after rounding to two decimal places. These weights match our intuitions: the weight of the first (non-noisy) domain increases by a small amount, the third (noisy) domain is downweighted to 0, and most of the weight goes to the second domain. We use these domain weights to generate a new dataset of 500 examples. The model trained on this new dataset improves over the baseline model in perplexity on all domains.