Generating Sequences by Learning to Self-Correct
Sean Welleck, Ximing Lu, Peter West, Faeze Brahman, Tianxiao Shen, Daniel Khashabi, Yejin Choi
Introduction
The standard practice for natural language generation tasks is inherently single-pass: applying a decoding procedure to either a few-shot prompted language model or one tuned for a given task, then considering the generation as “finished” (e.g. Radford et al. (2019); Brown et al. (2020); Chen et al. (2021)). Powerful generation models often meet most of the task requirements, yet miss a few (e.g., omitting a subset of keywords), or generate incorrect hypotheses that nevertheless provide useful structure (e.g., a correct problem solving strategy with a missing step). However, after generating even a slightly sub-optimal sequence, the single-pass paradigm requires models to “start from scratch”, effectively discarding work already done. A more natural, intuitive approach is leveraging the generation as a useful starting point to refine into a higher quality output.
To formalize this intuition, we introduce Self-Correction for Sequence Generation. Figure 1 demonstrates its central principle: a generation model is re-framed as a base generator, which produces a reasonable initial hypothesis but does not need to solve the task in one pass, and a second module–the corrector–trained to make up the difference between the hypothesis and an optimal solution. Neither the generator nor the corrector must solve the full task in one pass, and the corrector can be applied multiple times to iteratively improve the output (§3.6). We propose a simple, general procedure for training the corrector (Figure 2) by pairing generator outputs with carefully selected targets. The result is a system which self-corrects, producing outputs through multiple generation passes and breaking the task into steps that can be solved by dedicated and efficient sub-systems.
We find that Self-Correction is broadly applicable. Training a corrector model improves the base generator on 3 diverse tasks: mathematical program synthesis (§3.1), lexically constrained generation (§3.2), and toxicity reduction (§3.3). The trained corrector model can even be applied to a larger generator with similar performance to training a new corrector (§3.4), showing that the sub-task of correction is transferable, even to stronger generators. Finally, we explore the prospect of introducing a third module to the Self-Correction system (§3.5)–explicitly using natural language feedback to guide corrections–with promising results. Self-Correction offers an exciting opportunity to build on existing generation models and the sequences they generate, with efficient, effective, and transferable corrector networks.
Self-correcting sequence generators
A typical autoregressive text generator (e.g. GPT-3 (Brown et al., 2020)) maps an input prompt $x$ to a distribution over outputs $y$ using a single parameterized module (e.g. a large transformer), $p_0(y \mid x)$. We explore an alternative that decomposes into two modules, a base generator and a corrector,
$$p(y \mid x) = \sum_{y_0} p_0(y_0 \mid x)\, p_\theta(y \mid y_0, x), \tag{1}$$
where the generator $p_0$ provides an initial hypothesis $y_0$ that is refined by the corrector $p_\theta$. In practice, the corrector can be applied multiple times, $p(y_T \mid x) = \sum_{y_0} \sum_{y_1} \cdots \sum_{y_{T-1}} p_0(y_0 \mid x) \prod_{t=0}^{T-1} p_\theta(y_{t+1} \mid y_t, x)$. Since a model of this form can both generate and correct its generations, we call it a Self-Corrector.
Self-correctors have several unique properties compared to typical generators. First, a self-corrector decouples generation and correction, allowing us to freely parameterize each module – for instance, by prompting a single language model or using two different language models. In this paper, we develop a framework to train a separate corrector model (§2.1). We find that the resulting self-corrector improves upon the generator alone (§3), even when the corrector is much smaller (§3.4).
Second, since the generator and the corrector are separated, we can keep the generator as a general-purpose language model and train the corrector with different objectives for different task requirements. In §2.1, we propose a training algorithm for the corrector that is dedicated to improving generations, where the improvement can be in any aspect, measured by scalar values.
Third, the corrector can receive explicit feedback about intermediate generations to guide subsequent generations. Formally, $p_\theta(y' \mid y, x, f(y))$, where $f(y)$ is the feedback. The feedback can take many forms, e.g. a sentence or a compiler trace. In contrast, a typical generator that generates in a single pass does not leverage feedback on its own generation. In this paper, we show that the corrector can learn to exploit explicit natural language feedback to achieve better performance (§3.5). Next, we describe our training framework for the corrector.
Our goal is to have the generator generate an initial hypothesis, then improve the hypothesis with the corrector (Eq. 1). We train the corrector to improve the quality of a hypothesis, while staying as close as possible to the original hypothesis. Here, quality is measured with a scalar value function $v(y)$ which we assume is accessible at training time (e.g. a classifier).
Since direct supervision on how to improve hypotheses is not available, we design a new algorithm to train the corrector, which we refer to as self-corrective learning. The algorithm collects a pool of generations, groups them, selects pairs of generations that increase in value and are nearby, then updates the corrector on these pairs. As training progresses, more generations are added to the pool using the current corrector. Algorithm 1 summarizes self-corrective learning, detailed below.
Initialization. The algorithm initializes a datapool $D$ of (input, output, value, feedback) examples by using the generator to generate multiple outputs for each input. Formally,
$$D_x = \{(x, y, v(y), f(y)) \mid y \in y^{1:N} \sim q(p_0(\cdot \mid x))\}, \qquad D = \bigcup_{x \in X} D_x, \tag{2}$$
where $y^{1:N}$ denotes $N$ outputs generated with decoding algorithm $q$ (e.g. temperature sampling). When available, examples from another source (e.g. a dataset) can also be added.
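The initialization step can be sketched as follows. This is a minimal illustration, not the authors' implementation: `init_datapool`, the toy generator, and the two-argument `value(x, y)` and `feedback(x, y)` callables are hypothetical stand-ins for a real language model, value function, and feedback function.

```python
import random
from typing import Callable, Dict, List, Tuple

# One datapool entry: (input, output, value, feedback).
Example = Tuple[str, str, float, str]


def init_datapool(
    inputs: List[str],
    generate: Callable[[str], str],
    value: Callable[[str, str], float],
    feedback: Callable[[str, str], str],
    n_samples: int = 4,
) -> Dict[str, List[Example]]:
    """Sample n_samples outputs per input and store them with value and feedback."""
    pool: Dict[str, List[Example]] = {}
    for x in inputs:
        outputs = [generate(x) for _ in range(n_samples)]
        pool[x] = [(x, y, value(x, y), feedback(x, y)) for y in outputs]
    return pool


# Toy stand-ins (assumptions for illustration only).
ANSWERS = {"2+1": 3, "4+1": 5}
_rng = random.Random(0)


def toy_generate(x: str) -> str:
    # Pretend "sampling": guess a small integer.
    return str(_rng.randint(0, 5))


def toy_value(x: str, y: str) -> float:
    # 1.0 if the guess matches the known answer, else 0.0.
    return float(int(y) == ANSWERS[x])


def toy_feedback(x: str, y: str) -> str:
    # No explicit feedback in this toy setting.
    return ""


pool = init_datapool(list(ANSWERS), toy_generate, toy_value, toy_feedback)
```

Grouping the pool by input keeps the later pairing step simple, since pairs are only formed between outputs for the same input.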
Pairing. Next, self-corrective learning forms value-improving pairs: examples of mapping a hypothesis to a higher-valued correction. We use the datapool $D$ to form a set of (input, hypothesis, correction) pairs. A pair is formed when an output has a higher value than another (we also store the value and feedback for $y$ and $y'$ along with $(x, y, y')$, which we omit to reduce clutter):
$$P_x = \{(x, y, y') \mid v(y) < v(y'),\ y, y' \in D_x\}, \qquad P = \bigcup_{x \in X} P_x. \tag{3}$$
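As a rough sketch of the pairing step, the function below forms every (input, hypothesis, correction) triple in which the correction has strictly higher value than the hypothesis. The name `form_pairs` and the flat `(x, y, value)` tuples are assumptions made for this example.

```python
from itertools import permutations
from typing import List, Tuple


def form_pairs(examples: List[Tuple[str, str, float]]) -> List[Tuple[str, str, str]]:
    """examples: (x, y, value) entries that all share the same input x.

    Returns (x, hypothesis, correction) triples where the correction's
    value is strictly higher than the hypothesis's value.
    """
    return [
        (x, y, y2)
        for (x, y, v), (_, y2, v2) in permutations(examples, 2)
        if v < v2
    ]


examples = [("q", "a", 0.0), ("q", "b", 1.0), ("q", "c", 0.5)]
pairs = form_pairs(examples)
```

Here the hypothesis `"a"` can be paired with either `"c"` or `"b"`, and `"c"` can be paired with `"b"`; the highest-valued output `"b"` never appears as a hypothesis.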
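Learning. Next, self-corrective learning selects (input, hypothesis, correction) pairs to update the corrector with. We sample a pair proportional to its improvement in value as well as the proximity between the hypothesis $y$ and the correction $y'$:
$$\mathbb{P}[(x, y, y') \mid x] \propto \exp\big(\alpha \cdot (v(y') - v(y)) + \beta \cdot s(y, y')\big), \tag{4}$$
where $s(y, y')$ is a similarity function, and $\alpha, \beta \geq 0$ weight the improvement and proximity terms, respectively.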
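One way to realize this kind of sampling is sketched below. It assumes an exponential weighting of value improvement plus similarity, uses `difflib.SequenceMatcher` as a stand-in similarity function, and normalizes over all pairs passed in; the weights `alpha` and `beta`, the function names, and these choices are illustrative assumptions rather than the paper's exact configuration.

```python
import math
import random
from difflib import SequenceMatcher
from typing import Dict, List, Tuple

Pair = Tuple[str, str, str]  # (input, hypothesis, correction)


def similarity(y: str, y2: str) -> float:
    # Stand-in proximity measure in [0, 1]; any similarity function could be used.
    return SequenceMatcher(None, y, y2).ratio()


def pair_weights(
    pairs: List[Pair], values: Dict[str, float], alpha: float = 1.0, beta: float = 1.0
) -> List[float]:
    """Normalized weights proportional to exp(alpha * improvement + beta * similarity)."""
    scores = [
        alpha * (values[y2] - values[y]) + beta * similarity(y, y2)
        for _, y, y2 in pairs
    ]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_pair(
    pairs: List[Pair], values: Dict[str, float], rng: random.Random,
    alpha: float = 1.0, beta: float = 1.0,
) -> Pair:
    weights = pair_weights(pairs, values, alpha, beta)
    return rng.choices(pairs, weights=weights, k=1)[0]


values = {"a b": 0.0, "a b c": 1.0, "a b d": 0.5}
pairs = [("q", "a b", "a b c"), ("q", "a b", "a b d")]
weights = pair_weights(pairs, values)
chosen = sample_pair(pairs, values, random.Random(0))
```

Both corrections are equally similar to the hypothesis in this toy case, so the pair with the larger value improvement receives the larger weight. Raising `beta` relative to `alpha` would instead favor corrections that stay close to the hypothesis.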
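Exploration. During exploration, self-corrective learning adds new generations to the datapool by generating from the current corrector:
$$D'_x = \{(x, y', v(y'), f(y')) \mid y' \in y'^{1:N} \sim q(p_\theta(\cdot \mid y, x, f(y)))\}, \qquad D' = \bigcup_{x \in X} D'_x, \tag{5}$$
and updating the datapool $D \leftarrow D \cup D'$. The hypotheses $y$ to correct can come from any source, e.g. newly sampled from the base generator, or from the datapool; we use the latter in our experiments.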
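The exploration step can be illustrated with a short sketch that corrects hypotheses drawn from the datapool and appends the results. The `explore` function and the toy corrector are hypothetical; a real corrector would be a trained language model decoded with some sampling algorithm.

```python
from typing import Callable, List, Tuple

Example = Tuple[str, str, float, str]  # (input, output, value, feedback)


def explore(
    pool: List[Example],
    correct: Callable[[str, str, str], str],
    value: Callable[[str, str], float],
    feedback: Callable[[str, str], str],
    n_samples: int = 2,
) -> List[Example]:
    """Correct each pooled hypothesis n_samples times and add the results to the pool."""
    new_examples: List[Example] = []
    for x, y, _, f in pool:
        for _ in range(n_samples):
            y2 = correct(x, y, f)
            new_examples.append((x, y2, value(x, y2), feedback(x, y2)))
    return pool + new_examples


# Toy corrector: nudges a numeric guess upward by one (illustration only).
def toy_correct(x: str, y: str, f: str) -> str:
    return str(int(y) + 1)


pool = [("q", "1", 0.0, "")]
grown = explore(pool, toy_correct, lambda x, y: float(y == "2"), lambda x, y: "")
```

Because the new outputs are scored and stored like any other example, they can form fresh value-improving pairs in the next round of pairing.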
Inference. We use the trained corrector along with a generator to generate a trajectory $y_0, y_1, \ldots, y_T$, and consider $y_T$ the final output. Since marginalizing over the intermediate generations in Eq. 1 is intractable, we approximate each summation with a single sequence generated with a decoding algorithm $q(\cdot)$. That is, we decode from the generator, then repeatedly from the corrector:
Generation: $y_0 \sim q(p_0(y_0 \mid x))$;
Correction: $y_{t+1} \sim q(p_\theta(y_{t+1} \mid y_t, x, f(y_t)))$, $t = 0, 1, \ldots, T-1$.
The stopping time $T$ is either fixed, or when a target value is obtained (if $v(y)$ is available).
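The inference procedure amounts to a short loop, sketched here under stated assumptions: `self_correct` is a hypothetical name, and `generate`, `correct`, `value`, and `feedback` stand in for decoding from the generator, decoding from the corrector, the value function, and the feedback function.

```python
from typing import Callable, List, Optional


def self_correct(
    x: str,
    generate: Callable[[str], str],
    correct: Callable[[str, str, Optional[str]], str],
    max_corrections: int = 3,
    value: Optional[Callable[[str, str], float]] = None,
    target: Optional[float] = None,
    feedback: Optional[Callable[[str, str], str]] = None,
) -> List[str]:
    """Return the trajectory [y_0, ..., y_T]; the last element is the final output."""
    y = generate(x)
    trajectory = [y]
    for _ in range(max_corrections):
        # Optional early stop once the target value is reached.
        if value is not None and target is not None and value(x, y) >= target:
            break
        f = feedback(x, y) if feedback is not None else None
        y = correct(x, y, f)
        trajectory.append(y)
    return trajectory


# Toy modules: the "generator" starts at 0, the "corrector" adds 1.
toy_generate = lambda x: "0"
toy_correct = lambda x, y, f: str(int(y) + 1)
toy_value = lambda x, y: int(y) / 2

fixed = self_correct("q", toy_generate, toy_correct, max_corrections=3)
early = self_correct("q", toy_generate, toy_correct, max_corrections=3,
                     value=toy_value, target=1.0)
```

With a fixed budget the loop always applies three corrections, while the early-stopping variant halts as soon as the toy value reaches the target.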
Experiments
We evaluate self-correction on a diversity of tasks: mathematical program synthesis, in which generations are strictly correct or incorrect, and generators typically have low performance; lexically-constrained generation, which allows for partial credit, and generators usually give partially-correct solutions (e.g. matching 3 out of 5 constraints); and toxicity control, where ‘correctness’ is more loosely defined, and the output space is much more open-ended. Our experiments are organized to study three settings:
Using self-correctors to improve upon generators (§3.1,3.2,3.3).
Correcting generators that are much larger than the corrector (§3.4).
Leveraging explicit feedback during training and inference (§3.5).
Next, we describe the self-correction setup and baselines for each task, along with their results. Code will be publicly available upon acceptance.
3.1 Mathematical Program Synthesis
First, we consider mathematical program synthesis (Austin et al., 2021; Mishra et al., 2022). Given a natural language problem specification $x$, the task is to generate a program $y$ that upon execution returns the correct answer to $x$. The task is challenging as it draws on language understanding, multiple-step mathematical problem solving (e.g. identifying a solution strategy, decomposing a problem), and leveraging symbolic tools (e.g. built-in operations, variables). Furthermore, the task demands a high level of precision, e.g. a single misplaced operation makes the program incorrect.
Experimental setup. As the corrector we use GPT-Neo 1.3B (Black et al., 2021), an open-source autoregressive language model. GPT-Neo is pre-trained on language and code (Gao et al., 2021), and hence is widely used for code-related generation (e.g. Chen et al. (2021); Ni et al. (2022); Mishra et al. (2022)). We consider two settings for the initial generator: (1) a separate fine-tuned instance of GPT-Neo 1.3B, and (2) few-shot prompted GPT-3 (Brown et al., 2020). For GPT-3, we evaluate the davinci and text-davinci-002 engines, representative of large generators that are state-of-the-art in related tasks (Wei et al., 2022). The size of davinci is an estimate (https://blog.eleuther.ai/gpt3-model-sizes); further details are not available. See the Appendix for additional details.
Self-correction setup. As the value function we use correctness, which is 1 when the program executes and outputs the ground-truth answer and 0 otherwise. Our main experiments do not use explicit feedback, i.e. $f(y) = \emptyset$. At inference time, we study two settings for the corrector: (1) applying $k$ corrections and selecting the final generation, (2) an oracle setting that only corrects a draft if the draft is incorrect. We use greedy decoding for the generator and corrector, and $k = 1$.
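A correctness value function of this kind might look like the sketch below. It assumes, purely for illustration, that each generated program stores its result in a variable named `answer`; that convention, the function name, and the tolerance are hypothetical and not taken from the paper. Running untrusted generated code with `exec` is unsafe outside a sandbox, and this sketch does not guard against non-terminating programs.

```python
def correctness(program: str, gold_answer: float, tol: float = 1e-6) -> float:
    """Return 1.0 if the program runs and its `answer` matches the gold answer, else 0.0."""
    namespace: dict = {}
    try:
        # Assumption: the program assigns its final result to `answer`.
        exec(program, namespace)
        result = float(namespace["answer"])
    except Exception:
        # Syntax errors, runtime errors, or a missing `answer` all count as incorrect.
        return 0.0
    return float(abs(result - gold_answer) <= tol)


good = correctness("a = 3\nb = 4\nanswer = a * b", 12.0)
wrong = correctness("answer = 3 + 4", 12.0)
broken = correctness("answer = 3 +", 12.0)
```

Because the value is binary, every value-improving pair in this task maps an incorrect program to a correct one.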
Datasets. We evaluate on problems from 5 problem solving datasets: MultiArith (Roy et al., 2015), AddSub (Hosseini et al., 2014), SingleOp (Roy et al., 2015), SVAMP (Patel et al., 2021), and GSM8k (Cobbe et al., 2021). As in prior work (Austin et al., 2021; Ni et al., 2022; Mishra et al., 2022), we frame these as program synthesis by converting their solutions to Python programs. We use data from the Lila benchmark (https://github.com/allenai/Lila). We separate our experiments into three increasingly difficult settings:
MultiArith, using problems from the MultiArith arithmetic word problem dataset.
Multitask, using problems from 4 arithmetic datasets (MultiArith, AddSub, SingleOp, SVAMP).
GSM, using problems from the challenging GSM8k dataset.
For the MultiArith and Multitask settings, we make train/valid/test splits using 60/20/20% of the respective datasets. Similar to Ni et al. (2022), for the GSM setting we use the official GSM8k test split, and create a validation split using 20% of the training set. Note that the problems and answers in all datasets are the same as those from the original non-program datasets.
Baselines. We compare self-correct with its baseline generator (GPT-Neo 1.3B) in all three settings. For the GSM setting, we compare with existing work that uses models within the same order of magnitude in scale, including NEO FCP+PCP (Ni et al., 2022), which tunes GPT-Neo 2.7B with additional self-sampled programs, and their fine-tuned GPT-Neo 2.7B baseline. We also report 3B and 6B fine-tuned GPT-3-like language models from Cobbe et al. (2021), which were trained on the non-program version of GSM8k. We evaluate larger models later (§3.4).
Results. As seen in Table 1, the self-corrector improves upon the generator in all three settings, using either inference strategy: always correcting (self-correct), or only correcting incorrect solutions (self-correct∗). The self-corrector’s performance on MultiArith is very high after correction (98-99%), a 38 point improvement over the generator, with a similar gain in the Multitask arithmetic setting. On the challenging GSM dataset, the self-corrector achieves 21%, and 24% when only correcting incorrect solutions, up from 8.57% for the generator. Notably, this is higher than previous work based on the larger 2.7B GPT-Neo, or larger models tuned on the language version of GSM. The results show that self-corrective learning can improve task performance via training a corrector. Qualitatively, the self-corrector can correct values in a correctly structured solution, fix the order of operations within a multistep solution, adjust unit conversions, and make larger multipart revisions (see Figures 3 and 8). Notably, these are learned automatically through self-corrective learning.
3.2 Lexically Constrained Generation
Next, we consider lexically constrained generation. Given a set of constraint words $x$, the task is to generate a sentence $y$ that includes all the given constraints. Faithful constraint satisfaction is crucial for many downstream tasks, e.g., those that require converting information to text (McKeown, 1985).
Datasets and Metrics. We experiment on CommonGen (Lin et al., 2020) and E2E (Novikova et al., 2017). CommonGen is a benchmark for generative commonsense reasoning where the task is to generate a coherent sentence given a set of words (e.g., dog, catch). E2E involves converting structured inputs into natural language. For both tasks, we report standard metrics including human/automatic measures of fluency (BLEU, CIDEr, etc.) as well as constraint coverage. We collect human measures of fluency on Amazon Mechanical Turk; see the Appendix for details.
Setup. We parameterize the base generator with GPT-2 (Radford et al., 2019), using the large size for CommonGen and the medium size for E2E. We fine-tune the generator for each task. As the value function for self-corrective learning we use coverage, i.e. the percentage of constraints that are present in the output. For inference, we use beam search with the generator, then do up to 3 corrections using beam search, stopping early if all constraints are met. See the Appendix for additional details.
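A coverage value function can be sketched in a few lines. This version uses exact, case-insensitive word matching, which is a simplifying assumption: it does not handle inflected forms (e.g. "catches" for "catch"), and the name `coverage` is chosen only for this example.

```python
import re
from typing import List


def coverage(constraints: List[str], sentence: str) -> float:
    """Fraction of constraint words that appear as whole words in the sentence."""
    if not constraints:
        return 1.0  # nothing to satisfy
    tokens = set(re.findall(r"[a-z]+", sentence.lower()))
    hits = sum(1 for c in constraints if c.lower() in tokens)
    return hits / len(constraints)


partial = coverage(["dog", "catch", "frisbee"], "The dog runs to catch the ball.")
full = coverage(["dog", "catch"], "The dog runs to catch the ball.")
```

Unlike the binary correctness value used for math, coverage gives partial credit, so a correction that adds even one missing word yields a value-improving pair.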
Results. Table 2 shows the evaluation results. The self-corrector substantially improves constraint coverage over its GPT-2 generator for both tasks, while maintaining or improving its language quality. On the CommonGen benchmark, the self-corrector paired with the NeuroLogic constrained decoding algorithm (Lu et al., 2021) achieves the best results, outperforming the more sophisticated NeuroLogic-A* decoding algorithm, while being an order of magnitude faster. Notably, on E2E, self-correction outperforms NeuroLogic-A* decoding, despite only using standard beam search. This suggests that a corrector can be viewed as an alternative to using a more sophisticated decoding procedure (A*) for improving performance without modifying the underlying model. See Figure 9 for qualitative examples.
3.3 Toxicity Reduction
Next, we consider the task of toxicity reduction (Gehman et al., 2020; Liu et al., 2021). Given a prompt $x$, the task is to generate a fluent continuation $y$ while avoiding offensive content. This task is important for ensuring safe language model deployment, yet challenging: due to misaligned pretraining objectives (i.e. modeling internet text vs. non-toxic text), language models are susceptible to generating toxic completions, even when prompted with seemingly innocuous text (Gehman et al., 2020). Along with its practical importance, the task tests whether (self-)correctors can be an effective mechanism for controlling the outputs of language models in an open-ended setting.
Datasets and Metrics. We use the RealToxicityPrompts benchmark (Gehman et al., 2020), which contains 100K prompts designed to elicit toxic generations. Following the experimental setup of Liu et al. (2021), during training we use 85K prompts from the training set, and for evaluation we use the same 10K non-toxic prompts from the test set as Liu et al. (2021). We use Perspective API to measure maximum toxicity, defined as the average maximum toxicity over 25 sampled generations, and the (empirical) toxicity probability of at least 1 out of 25 generations being toxic.
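The two toxicity metrics can be computed from per-generation scores as in the sketch below. It assumes scores in [0, 1] are already available for each sampled generation and treats a generation as toxic when its score is at least 0.5; that threshold and the function name are assumptions for illustration, since the text above does not state them.

```python
from typing import List, Tuple


def toxicity_metrics(
    scores_per_prompt: List[List[float]], threshold: float = 0.5
) -> Tuple[float, float]:
    """scores_per_prompt[i] holds toxicity scores for the generations sampled for prompt i.

    Returns (average maximum toxicity, empirical probability that at least
    one generation for a prompt is toxic).
    """
    max_scores = [max(scores) for scores in scores_per_prompt]
    avg_max = sum(max_scores) / len(max_scores)
    prob = sum(m >= threshold for m in max_scores) / len(max_scores)
    return avg_max, prob


# Two prompts with two generations each (the paper samples 25 per prompt).
avg_max, prob = toxicity_metrics([[0.1, 0.7], [0.2, 0.3]])
```

Both metrics depend only on the worst generation per prompt, which is why they are sensitive to rare toxic samples.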
Baselines. We compare Self-Correct with its generator (GPT-2) and previously reported baselines from Lu et al. (2022a), including PPLM (Dathathri et al., 2020), GeDi (Krause et al., 2021), DExpert (Liu et al., 2021), DAPT (Gururangan et al., 2020), PPO (Lu et al., 2022a), and Quark (Lu et al., 2022a). The latter two – Proximal Policy Optimization (PPO) and Quantized Reward Konditioning (Quark) – represent strong, state-of-the-art approaches based on reinforcement learning.
Setup. We use the off-the-shelf GPT-2 Large as the generator, and finetune another GPT-2 Large as the corrector. During inference, we use nucleus sampling to generate 25 samples for all baselines. As the value function, we use the Perspective API score, $v(y)$, which measures the toxicity of the completed sequence. We do up to three corrections with the corrector model.
Results. Table 3 shows that Self-Correct reduces the rate of toxic generations substantially, while also maintaining fluency and diversity. Self-Correct outperforms all baselines. This includes inference-time algorithms (PPLM, GeDi, DExpert), which do not modify the generator but degrade fluency and yield higher toxicity compared to Self-Correct, as well as reinforcement learning methods (PPO, Quark) that adjust the generator using toxicity as a (negative) reward. The results show that Self-Correct is effective for detoxification, without having to modify the underlying generator. We study implications of this latter property further in the next section.
3.4 Changing Modules – Correcting GPT-3
Next, we show that a self-corrector can improve the outputs of a generator that is much larger than the corrector. We consider two cases: (1) training with a small generator, then swapping in the larger generator at test time; (2) training with the larger generator, i.e. using the large generator to initialize the datapool for self-corrective learning, then using the large generator at test time.
Toxicity. We evaluate case (1) for reducing the toxicity of a large generator (GPT-2 XL, GPT-3). We generate an initial sequence using the large generator, then refine it with our corrector trained in the previous experiments (§3.3). Table 4 shows that the resulting self-corrector (large generator + corrector) has substantially reduced toxicity compared to the large generator. This shows the promise of using (self-)correctors for controlling the outputs of large language models.
Math program synthesis. Table 4 shows results for math. Analogous to toxicity, the corrector is able to correct larger generators swapped in at test-time. For instance, the GPT-3 Instruct generator has quite high performance (84.90 Multitask, 36.80 GSM), which improves to 90.90 and 45.00, respectively, by adding in a corrector. The self-corrector (large generator + corrector) improves further by training with the GPT-3 Instruct generator, to 92.75 and 45.92, respectively.
3.5 Leveraging Explicit Feedback
Next, we demonstrate Self-Correct’s capacity to incorporate explicit natural language feedback. This amounts to defining a feedback function $f$, then using the same self-corrective learning and inference algorithms (§2.1) as in our preceding experiments (in those experiments, $f$ returned $\emptyset$). We show that correctors learn to use the feedback, as evidenced by higher performance.
Toxicity. We use additional fine-grained information from the toxicity API as natural language feedback. Specifically, besides the overall toxicity score, Perspective API also provides scores for fine-grained attributes of toxicity (e.g. identity attack, profanity, flirtation). At training time, we compare the attribute scores from a hypothesis and its selected correction, and use the attribute with the largest decrease as natural language feedback (e.g. “decrease toxicity in profanity”). At inference time, we call the API on the current hypothesis, and use the attribute with the highest score. Note that this setting calls the API at inference time, which our previous experiments did not require.
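The attribute-selection logic described here can be sketched as two small functions. The attribute scores are passed in as plain dictionaries, so no API call is made; the function names and the reuse of the same message template at inference time are assumptions for illustration.

```python
from typing import Dict


def toxicity_training_feedback(
    hyp_scores: Dict[str, float], corr_scores: Dict[str, float]
) -> str:
    """Pick the attribute whose score drops the most from hypothesis to correction."""
    attribute = max(hyp_scores, key=lambda a: hyp_scores[a] - corr_scores[a])
    return f"decrease toxicity in {attribute}"


def toxicity_inference_feedback(hyp_scores: Dict[str, float]) -> str:
    """Pick the attribute with the highest score on the current hypothesis."""
    attribute = max(hyp_scores, key=hyp_scores.get)
    return f"decrease toxicity in {attribute}"


hyp = {"profanity": 0.8, "identity attack": 0.3, "flirtation": 0.1}
corr = {"profanity": 0.2, "identity attack": 0.25, "flirtation": 0.1}
train_fb = toxicity_training_feedback(hyp, corr)
infer_fb = toxicity_inference_feedback({"profanity": 0.2, "identity attack": 0.6})
```

The training-time feedback describes what the selected correction actually changed, while the inference-time feedback points at the most prominent problem in the current hypothesis, since no correction exists yet to compare against.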
Lexical constraints. At training time, we generate natural language feedback for every example pair $(x, y, y')$ by listing the extra lexical constraints satisfied by $y'$ but not $y$, e.g. “adding constraint word: read”. At inference time, we list all missing constraints in the current hypothesis.
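A rough sketch of this feedback construction is given below. The "adding constraint word:" template follows the example in the text; exact word matching, joining several words with commas, and the function names are assumptions made for illustration.

```python
import re
from typing import List


def _words(sentence: str) -> set:
    return set(re.findall(r"[a-z]+", sentence.lower()))


def lexical_training_feedback(
    constraints: List[str], hypothesis: str, correction: str
) -> str:
    """List constraints satisfied by the correction but not by the hypothesis."""
    hyp_words, corr_words = _words(hypothesis), _words(correction)
    gained = [c for c in constraints if c in corr_words and c not in hyp_words]
    return "adding constraint word: " + ", ".join(gained)


def lexical_inference_feedback(constraints: List[str], hypothesis: str) -> str:
    """List every constraint still missing from the current hypothesis."""
    hyp_words = _words(hypothesis)
    missing = [c for c in constraints if c not in hyp_words]
    return "adding constraint word: " + ", ".join(missing)


constraints = ["book", "read", "chair"]
train_fb = lexical_training_feedback(
    constraints, "A book sits on the chair.", "I read a book on the chair."
)
infer_fb = lexical_inference_feedback(constraints, "A book is on the table.")
```

At training time the feedback names only what the correction added, whereas at inference time it names everything still missing.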
Math program synthesis. Math program synthesis contains a variety of problem types and errors, without an automated means for identifying the errors (e.g. an API). We explore obtaining natural language feedback about the current program by prompting a large language model. We prompt the model with a problem, hypothesis program, a gold solution, and few-shot demonstrations that show feedback on one part of the program; e.g. “In the initial guess, 3 should be subtracted.” When the program is correct, the feedback is “Correct.” At inference time, we also use feedback from the language model. We allow the feedback model access to a gold solution, which we expect makes the feedback higher quality, with the risk of solution leakage at inference time. Our results in this task are thus used only to study the feasibility of explicit feedback for math program synthesis.
Setup. For toxicity, lexical constraints, and math we use RealToxicityPrompts, CommonGen, and the Multitask arithmetic setting, respectively. We follow the setup of each task’s previous experiments (§3.3,§3.2,§3.1), except for math we use 5 correction iterations (previously 1). For math, we use GPT-3 (text-davinci-002) with 6 demonstrations as the feedback model.
Results. Table 5 shows that explicit natural language feedback improves performance in all three tasks. For toxicity, this means that providing fine-grained attributes (e.g. identity attack, profanity, etc.) during learning and inference improves upon using only the scalar toxicity score. Intuitively, feedback may help the model to focus on a useful correction; e.g., see Figure 5.
3.6 Additional Ablations and Analysis
Effect of multiple corrections. Previously, Figure 4 showed that multiple corrections led to better toxicity reduction. On math (Multitask setting), Figure 6 shows that performance improves with more than one correction, and that multiple corrections are more beneficial with feedback. Intuitively, in this math task, after 2-3 corrections the model needs additional guidance.
Effect of pairing and proportional sampling. Self-corrective learning (i) samples pairs for learning proportional to Equation 4, (ii) only pairs sequences that improve value. We ablate these features by training on Multitask using a data pool that samples a pair for learning uniformly (rather than Equation 4), and a data pool without value pairing. Table 7 shows that both improve performance.
Effect of exploration. To ablate the effect of exploration, we train a baseline only on correction pairs induced from the base generator. Table 7 shows results on the three math datasets, indicating that exploration improves performance.
Related Work
Self-correction relates to recent works on editing text, including modeling Wikipedia edits (Reid & Neubig, 2022; Faltings et al., 2021; Schick et al., 2022), which relies on supervised edits, unsupervised methods (Miao et al., 2019; Liu et al., 2020) that perturb sequences with simple operations (e.g. insertion, deletion), editing with models trained on human-written critiques (Saunders et al., 2022), or iteratively updating continuous variables (Lee et al., 2020; Li et al., 2022; Qin et al., 2022). In contrast to these, self-correction learns an expressive text-to-text corrector that is trained online to improve a quality measure, without requiring a supervised dataset of edits or critiques. Separately, denoising ground-truth sequences is a common pretraining objective (Devlin et al., 2019; Lewis et al., 2020; Raffel et al., 2020), while self-correction ‘denoises’ generations to improve a scalar quality measure. Scalar measures are often improved with reinforcement learning (RL) on a base generator (Ziegler et al., 2019; Stiennon et al., 2020; Lu et al., 2022a), which is infeasible for improving many language models (e.g. those accessed through an API), and uses only scalar feedback. Moreover, self-correction learns the difference between a generation and solution, and is complementary to RL-tuned generators, which can be used within a self-corrector. Finally, self-correction decomposes generation into multiple steps, which relates to methods that generate rationales before a response (Wei et al., 2022; Dohan et al., 2022). Self-correction also produces intermediate steps, but each step is of the same form as the output, allowing for re-using previous generations.
Conclusion
We introduced self-correctors, a class of models that decompose generation into initial generation and correction steps. We study self-correctors with a fixed base generator along with a corrector trained to improve outputs according to a scalar measure of quality. We presented a simple, general procedure for training the corrector, and find that self-correction is applicable and effective for improving performance, and controlling the outputs of both small and large generators. Moreover, we found that self-correction along with our learning framework provides a promising mechanism for using natural language feedback to improve generation. These findings, along with exploring alternative self-correctors, open up many avenues that we leave for future work.
Acknowledgments
This work was funded in part by the DARPA MCS program through NIWC Pacific (N66001-19-2-4031), and the Allen Institute for AI.
Appendix A Additional Experimental Details
We fine-tune a separate instance of GPT-Neo 1.3B as an initial generator, using the Huggingface library with default hyperparameters, except for evaluation steps, which we set to a small number to ensure a strong checkpoint is selected for each dataset. We use the fine-tuned initial generator as initialization for the corrector, and tune the corrector on sequences built from a problem $x$ and a residual pair $y$, $y'$, delimited by special tokens such as [START]. The loss is on tokens after [START].
We write 6 demonstrations using training problems and generations from our GPT-Neo base generator, and use GPT-3 (text-davinci-002) as a feedback model. We use the same training procedure and hyperparameters, except that the sequences now include feedback: each sequence contains a problem $x$, a residual pair $y$, $y'$, and the feedback on $y$, marked by a [FEEDBACK] token. We include loss on tokens after [FEEDBACK].
A.2 Lexically-constrained Generation
Hyper-parameters. Table 9 and Table 9 show hyperparameters for CommonGen and E2E.
Human Evaluation. We evaluate the fluency of generations in the E2E task using human annotators on Amazon Mechanical Turk (AMT). We randomly sampled 100 instances, along with generations from different baselines and self-corrections. For each instance, we ask 3 annotators to evaluate the fluency of generations on a 3-point Likert scale. We aggregate annotations from the 3 annotators using majority vote. We restricted the pool of annotators to those who are located in the US or CA, and had a 98% approval rate for at least 5,000 previous annotations.