Symbolic Chain-of-Thought Distillation: Small Models Can Also "Think" Step-by-Step

Liunian Harold Li, Jack Hessel, Youngjae Yu, Xiang Ren, Kai-Wei Chang, Yejin Choi

Introduction

Empirical scaling laws suggest that the accuracy of Large Language Models (LLMs) on benchmark tasks can be improved by increasing model size and pre-training data volume Hoffmann et al. (2022). Beyond these training-time improvements, however, an inference-time strategy dubbed “chain-of-thought” (CoT) prompting (sometimes called “self-rationalization” or “prompting with explanations”; we use these terms interchangeably in this paper), i.e., eliciting verbalizations of predictive processes via key-phrases like “Let’s think step-by-step” Kojima et al. (2022), can similarly improve performance. For example, Suzgun et al. (2022) demonstrate additional performance gains on a hard subset of the BigBench tasks BIG-bench collaboration (2022) using chain-of-thought.

However, chain-of-thought prompting has only been shown to be beneficial for models of sufficient scale (e.g., with more than 60B parameters Wei et al. (2022b)). In this work, we study whether small language models can be “taught” the capacity for chain-of-thought reasoning by larger language models. We adopt a simple strategy, which we call Symbolic Chain-of-thought Distillation (SCoTD): first, we sample chain-of-thought rationales from a large language model given (unlabeled) input instances from a dataset; then, we train a smaller language model to predict the sampled rationale and sampled label. This process follows the “symbolic knowledge distillation” paradigm of West et al. (2022), wherein corpora sampled from a larger language model serve as training data for a smaller one.

We find that through SCoTD, smaller language models learn to self-rationalize and perform significantly better on 3 commonsense QA tasks than when learning without rationalizations. This result holds for both supervised and few-shot settings, and across student models of varying scales (125M–1.3B parameters). Performance gains are especially pronounced when applying distilled chain-of-thought models to difficult scenarios such as contrast sets Gardner et al. (2020) (§3.4; SCoTD significantly outperforms supervised learning on labels) and fully held-out tasks (§3.5; few-shot SCoTD significantly outperforms in-context learning).

Key to the success of this process is sampling a relatively large number of rationales per example from the teacher model (e.g., 30 rationales/example; Figure 2). This differs from many prior practices that train with one rationale per example Camburu et al. (2018); Li et al. (2022a). In ablation studies, we investigate several competing hypotheses about which properties of the corpus matter most: we filter the corpus to CoTs that are assigned high probability by GPT-3, to CoTs that are diverse, or to CoTs that explain more open-ended input instances. While diversity and high probability are reasonable filters that perform well on average, the “null hypothesis” of random downsampling also performs well, suggesting that the sheer volume of rationales is also a key contributing factor.

We will release code and the corpus of sampled chain-of-thoughts at https://github.com/allenai/cot_distillation.

Symbolic Chain-of-Thought Distillation

Our primary goal is to improve the accuracy of a (relatively small) student language model $\mathcal{S}$ on a target classification task $\mathcal{D}_{\texttt{Test}}=\{(x_{i},y_{i})\}$. (Future work would be well suited to consider whether chain-of-thought prompting can be useful for generative tasks.) In practice, we primarily consider CommonsenseQA Talmor et al. (2019), OpenBookQA Mihaylov et al. (2018), and QuaRel Tafjord et al. (2019) as $\mathcal{D}$. We assume access to 1) an unlabeled training set $\mathcal{D}_{\texttt{Train}}=\{(x_{i})\}$; and 2) a large teacher language model $\mathcal{T}$ (e.g., GPT-3 Brown et al. (2020)) capable of generating chain-of-thoughts in a few-shot fashion.

Our first step is to curate a set of labeled chain-of-thoughts to serve as few-shot prompts $\mathcal{P}$ for $\mathcal{T}$. For each target task, we sample a small number (e.g., 10) of examples $x_{i}$ from $\mathcal{D}_{\texttt{Train}}$, provide a gold classification label $y_{i}$, and manually author a chain-of-thought $z_{i}$ for each, forming the prompt set $\mathcal{P}=\{(x_{i},y_{i},z_{i})\}$. In addition to authoring our own, we reuse chain-of-thought prompts from prior work (Wei et al., 2022b; Wang et al., 2022b) when available.
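A minimal sketch of how the prompt set $\mathcal{P}$ might be turned into a few-shot prompt for $\mathcal{T}$. The `Demo` class, the `Q:`/`A:` layout, and the "So the answer is" connective are hypothetical choices for illustration; the exact template is not given in this section, and the example rationale is made up.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Demo:
    """One element (x_i, y_i, z_i) of the prompt set P."""
    question: str   # x_i: the input instance
    answer: str     # y_i: the gold label
    rationale: str  # z_i: the hand-written chain-of-thought


def build_cot_prompt(demos: List[Demo], query: str) -> str:
    """Concatenate the demonstrations, then append the unlabeled query.

    The "Q:/A:" layout and the "So the answer is" connective are
    hypothetical, not the paper's exact template.
    """
    blocks = [
        f"Q: {d.question}\nA: {d.rationale} So the answer is {d.answer}."
        for d in demos
    ]
    # The teacher is expected to continue after "A:" with a rationale and a label.
    blocks.append(f"Q: {query}\nA:")
    return "\n\n".join(blocks)


demo = Demo(
    question="What form of alcohol is made from grapes?",
    answer="wine",
    rationale="Wine is made by fermenting grape juice.",
)
prompt = build_cot_prompt([demo], "Why might someone purposefully be going into trance?")
print(prompt)
```

Sampling $N$ completions of such a prompt per unlabeled $x_i$ yields the teacher rationales and labels used for distillation.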

Experiments

We evaluate primarily on 3 target tasks: 1) CommonsenseQA (CSQA) Talmor et al. (2019), a 5-way multi-choice dataset; 2) OpenBookQA Mihaylov et al. (2018); and 3) QuaRel Tafjord et al. (2019). While any model capable of few-shot chain-of-thought could be substituted, we use the code-davinci-002 version of GPT-3 Brown et al. (2020) as our teacher model $\mathcal{T}$ (Wang et al. (2022a) report better CoT performance from this version compared to other GPT-3 models). We use OPT Zhang et al. (2022) as our student model $\mathcal{S}$. Our standard student model is OPT-1.3B (though we explore a range of student model sizes in §3.3).

We sample from GPT-3 with a temperature of 1.0. For each training example, we sample $N=30$ rationales. OPT is fine-tuned with a batch size of 32 and a learning rate of $2\times 10^{-5}$. We use HuggingFace Transformers Wolf et al. (2019), PyTorch Paszke et al. (2019), and Accelerate (https://github.com/huggingface/accelerate) for the implementation. Main experiments can be reproduced on one GPU with 48GB of memory.

We consider both a few-shot learning setting and a supervised setting. For the few-shot setting, the only labeled examples available to our teacher/student models are contained in the prompt set $\mathcal{P}$ (but we use the unlabeled examples and teacher-generated chain-of-thoughts/labels for training). In this setting, teacher samples can contain incorrect labels, thus preserving the few-shot nature of the task. In the supervised setting, we assume access to labels in $\mathcal{D}_{\texttt{Train}}$. Supervised SCoTD simply discards the samples within the distillation corpus $\mathcal{C}$ that do not have the correct label prior to fine-tuning the student: for CommonsenseQA, OpenBookQA, and QuaRel, this results in discarding 40.4%, 45.0%, and 34.2% of chain-of-thoughts, respectively. For the few-shot setting, we decode with the self-consistency approach; for the supervised setting, we use greedy decoding (introduced in § 2; see the discussion in § 3.2).
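The supervised filtering step can be sketched in a few lines. This assumes each teacher sample has already been parsed into an (instance id, rationale, predicted label) triple; the `Sample` layout and the function name `supervised_filter` are hypothetical.

```python
from typing import Dict, List, Tuple

# One teacher sample: (instance id, sampled chain-of-thought, label the teacher predicted)
Sample = Tuple[str, str, str]


def supervised_filter(corpus: List[Sample], gold: Dict[str, str]) -> Tuple[List[Sample], float]:
    """Keep only samples whose predicted label matches the gold label.

    Returns the filtered corpus and the fraction of samples discarded.
    """
    kept = [s for s in corpus if s[2] == gold[s[0]]]
    discarded = 1.0 - len(kept) / len(corpus) if corpus else 0.0
    return kept, discarded


corpus = [
    ("q1", "rationale A", "a"),
    ("q1", "rationale B", "b"),  # wrong label, dropped
    ("q2", "rationale C", "c"),
    ("q2", "rationale D", "c"),
]
kept, discarded = supervised_filter(corpus, {"q1": "a", "q2": "c"})
print(len(kept), discarded)
```

In the few-shot setting this step is skipped, so the student also sees samples whose teacher label is wrong.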

We compare SCoTD to 2 baselines: 1) Label-Only: the student is fine-tuned on just the label, without any CoT (in the few-shot setting, the label comes from the teacher and could be wrong; in the supervised setting, we use the gold label); 2) Greedy-CoT: for each training example, we greedily decode a single CoT from $\mathcal{T}$ instead of sampling $N=30$ CoTs. For additional reference, Table 2 (a) reports the performance of the student (and teacher) in a variety of few-shot settings prior to applying any distillation: No CoT is few-shot prompting with labeled instances from $\mathcal{P}$ but no $z_{i}$; Greedy and Self-Consistency are prompting with CoT but with different decoding strategies (§ 2).
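To make the difference between the baselines and SCoTD concrete, the sketch below builds student (input, target) pairs for each variant. The input template and the "So the answer is" phrasing are assumptions, not the paper's exact format; all function names are hypothetical.

```python
from typing import List, Tuple

Pair = Tuple[str, str]  # (student input, student target)


def format_input(question: str) -> str:
    # Hypothetical input template for the student.
    return f"Q: {question}\nA:"


def label_only_pairs(question: str, label: str) -> List[Pair]:
    """Label-Only baseline: the target is the label alone, with no CoT."""
    return [(format_input(question), f" {label}")]


def cot_pairs(question: str, cots: List[Tuple[str, str]]) -> List[Pair]:
    """SCoTD / Greedy-CoT: one (input, rationale + label) pair per teacher CoT.

    SCoTD passes all N sampled (rationale, label) pairs for the question;
    Greedy-CoT passes only the single greedily decoded pair.
    """
    return [(format_input(question), f" {z} So the answer is {y}.") for z, y in cots]


scotd = cot_pairs("Q?", [("r1", "a"), ("r2", "b"), ("r3", "a")])
greedy = cot_pairs("Q?", [("r1", "a")])
label_only = label_only_pairs("Q?", "a")
```

Under this framing, SCoTD simply contributes $N$ training pairs per input where Greedy-CoT contributes one, which is the data-volume difference studied in Figure 2.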

Table 2 (b) gives the performance of the student model after distillation in the supervised and few-shot settings. In all cases, distillation significantly improves the student model, and in all but one case, learning with CoT outperforms the Label-Only distillation baseline. While the student model initially fails to perform CoT through prompting (Table 2 (a)), it learns to do so through distillation.

In our default setting, to serve as our distillation corpus $\mathcal{C}$, we sample $N=30$ rationales from the teacher $\mathcal{T}$ for each (unlabeled) training instance. Figure 2 shows the performance of the student model when it is trained on corpora with fewer sampled CoTs per instance: results suggest that learning with multiple sampled (albeit noisier) chain-of-thoughts per example is more beneficial than learning with one (most likely) rationale. Do more rationales bring further improvement? We sampled more rationales from GPT-3 to train the student model, but this does not bring further gains: with $N=50$, performance is similar to $N=30$, with accuracy of 67.0 on OpenBookQA (vs. 67.0), 67.2 on CommonsenseQA (vs. 67.0), and 84.9 on QuaRel (vs. 83.8).
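For concreteness, the corpus and training objective can be written as follows. This is a sketch under the assumption that the student is trained with standard token-level maximum likelihood on the rationale followed by the label; the notation $(z_i^{(j)}, y_i^{(j)})$ for the $j$-th teacher sample of instance $x_i$ is ours.

$$
\mathcal{C}=\bigl\{(x_i, z_i^{(j)}, y_i^{(j)}) : x_i\in\mathcal{D}_{\texttt{Train}},\ j=1,\dots,N\bigr\},\qquad (z_i^{(j)}, y_i^{(j)})\sim\mathcal{T}(\cdot\mid\mathcal{P}, x_i),
$$

$$
\mathcal{L}(\mathcal{S}) = -\sum_{(x,z,y)\in\mathcal{C}} \log p_{\mathcal{S}}(z, y \mid x).
$$

In the supervised setting, samples with $y_i^{(j)} \neq y_i$ are removed from $\mathcal{C}$ before training; varying $N$ changes only the number of terms per instance.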

1.1 Human Evaluations

While SCoTD improves task accuracy significantly, we additionally conduct human evaluations to assess the generated chain-of-thoughts themselves (see Table 1 for samples). We sample instances from the CommonsenseQA, OpenBookQA, and QuaRel validation sets (300 instances per dataset) and conduct head-to-head human evaluations: we remove the final prediction from each chain-of-thought and ask crowdworkers which is more coherent, fluent, and (importantly) likely to lead to a correct answer. We use Amazon Mechanical Turk and pay a minimum of $15/hr (see Appendix A for more details, including a screenshot of the HIT). We use these evaluations to assess:

Does SCoTD improve chain-of-thought quality over the original student? Test: OPT-1.3B versus OPT-1.3B + SCoTD. Result: Yes. We assess this hypothesis on two subsets of instances: 1) a pure random sample (N=900); and 2) a set of instances for which both models eventually predicted the correct label (N=654). The second setting focuses more closely on the chain-of-thoughts themselves rather than on the predictive accuracy of the model. SCoTD is superior in both settings: in the random sample setting, SCoTD won in 59% of cases ($p<.001$), and in the correctness-controlled setting, SCoTD won in 61% of cases ($p<.001$). Results hold with $p<.05$ for each QA dataset individually.

Are the student's chain-of-thoughts comparable in quality to the teacher's? Test: OPT-1.3B + SCoTD versus text-davinci-002. While the task accuracy of the teacher is still higher in most cases, the student-generated CoTs are comparable (see the Limitations section for more discussion of the disparity between CoT quality and task accuracy). We again evaluate on: 1) a pure random sample (N=900); and 2) a correctness-controlled setting (N=659). The generations of the 100x smaller SCoTD model are competitive in both cases; we cannot reject the null hypothesis that the crowd has equal preferences (OPT-1.3B + SCoTD wins in 47% and 51% of cases respectively, $p>.01$). Results hold for each dataset individually, as well.
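The significance levels above are reported without naming the test. One plausible choice is an exact two-sided binomial (sign) test against a 50/50 null; the sketch below assumes that test and that ties are excluded. The win counts in the usage lines are reconstructed from the reported percentages for illustration only, and the function name is hypothetical.

```python
from math import comb


def binomial_two_sided_p(wins: int, n: int) -> float:
    """Exact two-sided sign test of H0: each comparison is a 50/50 coin flip.

    Ties are assumed to have been excluded before counting.
    """
    k = max(wins, n - wins)
    # P(X >= k) for X ~ Binomial(n, 0.5), computed with exact integers
    upper_tail = sum(comb(n, i) for i in range(k, n + 1))
    # The null distribution is symmetric, so the two-sided p-value doubles one tail
    return min(1.0, 2 * upper_tail / 2 ** n)


# Illustrative counts reconstructed from the reported percentages:
p_vs_base = binomial_two_sided_p(531, 900)     # 59% of 900 judgments
p_vs_teacher = binomial_two_sided_p(423, 900)  # 47% of 900 judgments
print(p_vs_base, p_vs_teacher)
```

With 900 judgments, a 59% win rate lies far outside what a fair coin produces, while 47% does not, which matches the qualitative conclusions above.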

2 Self-Consistency for the Student

Wang et al. (2022b) find that, for chain-of-thought prompted models, taking a majority vote over a large set of sampled labels (resulting from a diverse range of CoTs) can improve performance. Our finding that sampling $N=30$ rationales from the teacher is effective during SCoTD is similar in spirit: we also show performance gains from sampling multiple rationalization chains per instance.

A natural question is whether the student model $\mathcal{S}$ exhibits the same phenomenon, i.e., can we sample multiple chain-of-thoughts from it and take a majority vote? We find that the student model can benefit from “self-consistency,” but not in all cases. In Table 3, we report performance with and without self-consistency (majority vote among 30 sampled reasoning paths with a temperature of 0.7). When training with filtered CoTs (Table 3 (a), bottom rows) or with few CoTs per example (Table 3 (b), when #CoTs/Example is small), the student model does not benefit from self-consistency. Only when we train with multiple rationales per example without filtering (the few-shot setting) is self-consistency beneficial, on CSQA and OpenBookQA. Overall, the results show that student models benefit from being shown a diverse/noisy set of rationales, and that self-consistency can be effectively applied after distillation.
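Self-consistency decoding for the student amounts to parsing a label out of each sampled chain-of-thought and taking a majority vote. A minimal sketch, assuming answers are introduced by a phrase like "the answer is"; the regular expression `ANSWER_RE` is a hypothetical pattern that would need to match whatever format the prompts actually use. Sampling the 30 reasoning paths (temperature 0.7) is assumed to happen elsewhere.

```python
import re
from collections import Counter
from typing import List, Optional

# Hypothetical answer pattern; the real phrasing depends on the prompts used.
ANSWER_RE = re.compile(r"answer is:?\s*\(?([^).\n]+)\)?", re.IGNORECASE)


def extract_answer(cot: str) -> Optional[str]:
    """Return the last 'answer is ...' span in a chain-of-thought, if any."""
    matches = ANSWER_RE.findall(cot)
    return matches[-1].strip().lower() if matches else None


def self_consistency(cots: List[str]) -> Optional[str]:
    """Majority vote over answers parsed from sampled chain-of-thoughts."""
    votes = Counter(a for a in map(extract_answer, cots) if a is not None)
    if not votes:
        return None
    # Ties go to the answer seen first (Counter preserves insertion order).
    return votes.most_common(1)[0][0]


samples = [
    "Grapes ferment. So the answer is (b).",
    "So the answer is (a).",
    "Hmm, the answer is b.",
    "No conclusion here.",  # unparseable samples are ignored
]
print(self_consistency(samples))
```

Greedy decoding, used in the supervised setting, corresponds to generating a single chain-of-thought and calling `extract_answer` on it directly.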

3 SCoTD across Model and Dataset Sizes

We also verify the effectiveness of SCoTD across model and dataset sizes; in these experiments, we consider the supervised setting.

Figure 3 shows the effect of varying the size of $\mathcal{D}_{\texttt{Train}}$ (for simplicity, we show only performance on CSQA as an example). Learning with CoTs is beneficial at all data scales. Interestingly, SCoTD trained with access to only 40% of the labeled data can surpass the direct supervised label-only model trained on 100% of the labeled corpus; this result aligns with the argument in Zaidan et al. (2007) that providing more explanations from the teacher model could be more beneficial than providing more labels.

Figure 4 presents results when varying the size of the student model from 125M to 1.3B parameters for CSQA. For all three model sizes, SCoTD outperforms the standard supervised fine-tuning baseline (Label-Only). Sampling multiple rationales per input instance is an effective strategy for all model sizes.

4 SCoTD on Challenging Contrast Sets

Can learning with explanations help generalization, as hypothesized by Zaidan et al. (2007)? As a preliminary study, we show that SCoTD enables better generalization to contrast sets. Contrast sets (Gardner et al., 2020) were proposed to evaluate a model’s robustness to perturbations around the decision boundary: annotators modify the original test instances in small but meaningful ways that (typically) change the gold label.

We experiment on the IMDB Maas et al. (2011) sentiment analysis task in the supervised setting, using the corresponding IMDB contrast set proposed by Gardner et al. (2020). We train two models on the training set of IMDB: Label-Only and SCoTD. For efficiency, we sub-sample 100K examples from the training set of IMDB and truncate input sequences to 700 tokens. As shown in Figure 5, while both models achieve high performance on the original IMDB test set (96.1% vs. 95.5%, with the Label-Only model performing slightly better), the model with SCoTD achieves significantly higher performance on the contrast set: 92.0% vs. 81.6%. This result supports the hypothesis of Zaidan et al. (2007) that explanations can support more robust generalization.

5 SCoTD on Unseen, Out-of-domain Tasks

Large language models can perform few-shot, in-context learning with chain-of-thought prompting, i.e., generating reasonable chain-of-thoughts on unseen tasks with a few demonstrations Suzgun et al. (2022). We conduct a preliminary experiment, inspired by Min et al. (2021)’s MetaICL, to test whether student models trained with SCoTD acquire the same ability. We train a supervised SCoTD model on ANLI, CommonsenseQA, and OpenBookQA, and evaluate it on SST-2 Socher et al. (2013), a sentiment analysis task.

The SCoTD model achieves a few-shot accuracy of 79.6% on the validation set (an example prediction is shown in Figure 6); for reference, GPT-3 text-curie-001 (~6.7B parameters) achieves 74.5% with the same prompt. In contrast, a baseline model that learns with no CoT (i.e., a re-implementation of MetaICL trained on the 3 source tasks) fails to recognize the input/output format of the new task and predicts answers outside the desired label set, achieving an effective 0% accuracy on SST-2. This suggests the potential of including CoTs during instruction/in-context tuning (Wei et al., 2022a; Min et al., 2021).
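The "effective 0%" accuracy of the MetaICL baseline reflects scoring in which any prediction outside the task's label set counts as wrong. A sketch of such a metric, which also reports how often predictions fall outside the label set; the function name and the lowercase/strip normalization are assumptions.

```python
from typing import Iterable, List, Tuple


def effective_accuracy(
    predictions: List[str], golds: List[str], label_set: Iterable[str]
) -> Tuple[float, float]:
    """Return (accuracy, invalid rate); out-of-label-set predictions count as errors."""
    labels = {label.lower() for label in label_set}
    correct = invalid = 0
    for pred, gold in zip(predictions, golds):
        p = pred.strip().lower()
        if p not in labels:
            invalid += 1  # e.g., an answer in a source task's multiple-choice format
            continue
        correct += int(p == gold.lower())
    n = len(golds)
    return correct / n, invalid / n


acc, invalid_rate = effective_accuracy(
    ["positive", "(a)", "Negative"],
    ["positive", "negative", "negative"],
    {"positive", "negative"},
)
print(acc, invalid_rate)
```

A model whose every prediction is out of the label set therefore scores exactly zero, regardless of what it intended.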

What Factors are Important for Distillation?

An important factor underlying the performance gains highlighted in §3 was the number of chain-of-thoughts we sampled from the teacher model per-instance (more samples = better; Figure 2). Here we ask: is data volume the key contributing factor to the performance improvement? Or, are specific aspects of chain-of-thought samples key for the performance improvements?

We design several filters to identify potentially important examples/CoTs among the correct rationales. We apply these filters (described below) to $\mathcal{C}'$, the corpus sampled from the teacher with incorrect CoTs dropped; each filter operationalizes a different hypothesis about which factors are important to distill. We control for dataset size when filtering, i.e., all filtered corpora have the same number of training CoTs. We downsample with a budget of 5 CoTs per instance on average (in rare cases, an instance ends up with fewer because it has fewer than 5 correct CoTs). Then, we train the same student model on each of the filtered corpora and compare performance on downstream tasks. If a student model trained on filtered corpus A tends to outperform the student model trained on filtered corpus B, then we argue that the property that produced corpus A is more important. The hypotheses we consider are:

Random. As a null hypothesis, we randomly sub-sample 5 CoTs per instance; this filter operationalizes the assumption that an arbitrary set of samples is sufficient.

Diversity. For each instance, we compute S-BERT Reimers and Gurevych (2019) embeddings (using paraphrase-MiniLM-L6-v2) of each of the chain-of-thoughts, and cluster the resulting embeddings into $k=5$ clusters using hierarchical clustering. Then, we randomly sample a single CoT from each cluster: the resulting sample covers all clusters, and thus represents a diverse and representative sample.

Teacher likelihood. For each instance, we keep the 5 CoT samples with the highest per-token log-likelihood according to the teacher model.

Open-endedness. Some instances in each dataset lead to a broader range of chain-of-thought samples than others. For example, on CommonsenseQA, the question “What form of alcohol is made from grapes?” leads to a narrower range of rationalizations than “Why might someone purposefully be going into trance?” We hypothesize that open-ended instances could benefit from relatively more sampled rationales. We sort instances into quintiles based on the number of unique bigrams in their corresponding 30 CoTs; for high-ranking instances (more unique CoT bigrams, like the “trance” example above), we keep more rationales, and for low-ranking instances, we keep fewer. We keep 1, 3, 5, 7, and 9 rationales for instances in the five quintiles, from lowest to highest (thus controlling for the total number of CoTs).
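The four filters can be sketched as follows, each operating on the correct CoTs of a single instance (or, for open-endedness, on all instances at once). Assumptions: embeddings are precomputed (e.g., with the sentence-transformers model named above) and passed in as plain non-zero vectors; average linkage over cosine distance is one choice of hierarchical clustering, since the linkage is not specified; per-token teacher log-probabilities are assumed available for each sampled CoT; and bigrams are computed over lowercased whitespace tokens. All function names are hypothetical.

```python
import math
import random
from typing import Dict, List, Sequence

Vec = Sequence[float]


def random_filter(cots: List[str], k: int, rng: random.Random) -> List[str]:
    """Null hypothesis: k CoTs chosen uniformly at random (fewer if unavailable)."""
    return rng.sample(cots, min(k, len(cots)))


def likelihood_filter(cots: List[str], token_logprobs: List[List[float]], k: int) -> List[str]:
    """Keep the k CoTs with the highest mean per-token teacher log-probability."""
    score = [sum(lps) / len(lps) for lps in token_logprobs]  # assumes non-empty lists
    order = sorted(range(len(cots)), key=lambda i: score[i], reverse=True)
    return [cots[i] for i in order[:k]]


def _cos_dist(u: Vec, v: Vec) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return 1.0 - dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def diversity_filter(cots: List[str], embs: List[Vec], k: int, rng: random.Random) -> List[str]:
    """Average-linkage agglomerative clustering into k clusters, then one CoT per cluster."""
    if len(cots) <= k:
        return list(cots)
    dist = [[_cos_dist(a, b) for b in embs] for a in embs]
    clusters = [[i] for i in range(len(cots))]
    while len(clusters) > k:
        best = None  # (average distance, i, j) of the closest pair of clusters
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                avg = sum(dist[a][b] for a in clusters[i] for b in clusters[j])
                avg /= len(clusters[i]) * len(clusters[j])
                if best is None or avg < best[0]:
                    best = (avg, i, j)
        _, i, j = best
        clusters[i].extend(clusters.pop(j))  # j > i, so index i stays valid
    return [cots[rng.choice(c)] for c in clusters]


def unique_bigrams(cots: List[str]) -> int:
    """Distinct word bigrams over all CoTs of one instance (whitespace tokens)."""
    grams = set()
    for cot in cots:
        toks = cot.lower().split()
        grams.update(zip(toks, toks[1:]))
    return len(grams)


def open_endedness_budgets(
    cots_by_instance: Dict[str, List[str]], budgets: Sequence[int] = (1, 3, 5, 7, 9)
) -> Dict[str, int]:
    """Map each instance to a CoT budget by its quintile of bigram diversity."""
    ranked = sorted(cots_by_instance, key=lambda iid: unique_bigrams(cots_by_instance[iid]))
    n = len(ranked)
    # r * len(budgets) // n sends ranks 0..n-1 to bins 0..len(budgets)-1 (0 = least open-ended)
    return {iid: budgets[r * len(budgets) // n] for r, iid in enumerate(ranked)}


rng = random.Random(0)
cots = ["a1", "a2", "b1", "b2", "c1", "c2"]
embs = [[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.01, 0.99], [-1.0, 0.0], [-0.99, -0.01]]
picked = diversity_filter(cots, embs, k=3, rng=rng)
```

How rationales are chosen within each open-endedness budget is not specified; one option is to pass each instance's budget as `k` to `random_filter`. With equal-size quintiles, the average budget is 5 CoTs per instance, matching the other filters.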

Figure 7 reports the accuracy of the student model when fine-tuned on the different subsampled corpora for the three tasks we consider. Overall, random subsampling is a strong baseline, but we see some evidence that diversity among the rationales is important. None of the models trained on the subsampled data approaches the model trained on the full 30x/instance CoT set, suggesting that the sheer volume of CoTs is a key driving force for the performance improvement.

Related Work

As an extension of few-shot prompting Brown et al. (2020), chain-of-thought has proven more generally applicable than the algorithmic/structured reasoning settings in which intermediate step generation was initially studied, e.g., by Roy and Roth (2015); Ling et al. (2017); Chiang and Chen (2019); Nye et al. (2021). Recent studies seek to improve and analyze CoTs from different perspectives: Wang et al. (2022b) improve the original CoTs by marginalizing over diverse reasoning paths, while Wang et al. (2022a) marginalize over diverse prompts; Zelikman et al. (2022) and Huang et al. (2022) improve CoT by bootstrapping, i.e., training on self-generated CoTs; Li et al. (2022b) introduce voting classifiers to filter sampled CoTs before final prediction; and Golovneva et al. (2022) introduce metrics for the automatic assessment of chain-of-thoughts. This study instead focuses on enabling CoT for smaller models via distillation.

Hase and Bansal (2022) discuss how explanations can serve as inputs Talmor et al. (2020), targets Hendricks et al. (2016); Fidler et al. (2017); Camburu et al. (2018); Zhou et al. (2020); Narang et al. (2020); Kayser et al. (2021); Wiegreffe et al. (2022), and priors Zhang et al. (2016); Srivastava et al. (2018) for machine learning models. Chain-of-thought extends earlier efforts that treat explanations as intermediate structures generated at inference time Rajani et al. (2019). Most related to our work is Li et al. (2022a), who also learn with GPT-3-generated explanations; we show that multiple samples improve significantly over their single-sample method, and we also use chain-of-thought prompting at inference time rather than predicting explanations and labels via independent multitasking.

Recent work, inspired by Knowledge Distillation Hinton et al. (2015), has considered symbolic knowledge distillation West et al. (2022): instead of distilling from soft representations like logits, large language models serve as training data generators Xiong et al. (2019); Petroni et al. (2019); Schick and Schütze (2021); West et al. (2022); Liu et al. (2022); Meng et al. (2022); Bhagavatula et al. (2022). This paper continues this line of work.

There are several contemporaneous papers: Huang et al. (2022), Magister et al. (2022), and Ho et al. (2022) all show that smaller models can benefit from large models’ chains of thought. We contribute beyond these by: 1) showing that sampling a large number of chain-of-thoughts is paramount; 2) exploring transfer performance to challenge sets/unseen tasks; and 3) analyzing which factors in the teacher corpus are important.

Conclusion

We demonstrate the effectiveness of Symbolic Chain-of-thought Distillation (SCoTD): a method that enables smaller language models to effectively use chain-of-thought-style reasoning. We demonstrate the method’s effectiveness across several downstream tasks, different student model sizes, different levels of supervision, and in difficult settings (challenge sets, unseen tasks). Our ablations shed light on what factors are particularly important to distill in these chain-of-thoughts.

Our concrete recommendations are: 1) sampling multiple and diverse CoTs for each input instance, and 2) performing self-consistency when the teacher CoTs are noisy. Several promising avenues for future work include:

Exploring SCoTD for generation tasks in addition to classification tasks;

Scaling up the number of source tasks in § 3.5 to generalize to more tasks;

Using the down-sampling setup introduced in §4 to explore additional hypotheses about what other factors may be of importance in CoTs.

Limitations

Several limitations of our study include:

only English-language chain-of-thoughts/tasks considered;

reliance on GPT-3, which is a closed-source product with an unknown training set (which could itself include some explanations); and

focusing only on a single type of student model, OPT.

More broadly, learning from and with explanations carries specific risks related to automation bias. A model might rationalize its predictions using a seemingly coherent string of natural language steps, but even when its final prediction is correct, there is no guarantee that the prediction actually results from the process the rationalization describes. A user might therefore place excessive confidence in the system based on its chain-of-thought. We observed many cases where the chain-of-thought seemed promising, only for the model to make an incorrect prediction in the final few tokens. Caution should be taken when displaying chain-of-thoughts to users.

Acknowledgment

We thank anonymous reviewers for their comments. This work is supported in part by the DARPA MCS program, NCSOFT NLP Center and a Sloan research fellowship.

References

Appendix A Crowdworking details

A screenshot of the interface we use to collect the pairwise human judgments from §3.1.1 is given in Figure 8. We conduct a post-hoc analysis using a JavaScript timer to ensure that annotators were paid at least $15/hr: crowdworkers whose effective hourly rate fell below this during annotation were awarded post-hoc bonuses to ensure they were paid that rate. We select crowdworkers with IP addresses in the US, CA, NZ, AU, and GB.

Our IRB does not require review of crowdworking studies of standard NLP corpora (involving no personal disclosures). While the authors of this work are not lawyers and this is not legal advice, this opinion is based on United States federal regulation 45 CFR 46, under which this study qualifies as exempt. We do not release crowdworker IDs, so annotations cannot be traced back to individual workers.