Diffusion-LM Improves Controllable Text Generation

Xiang Lisa Li, John Thickstun, Ishaan Gulrajani, Percy Liang, Tatsunori B. Hashimoto

Introduction

Large autoregressive language models (LMs) are capable of generating high-quality text, but in order to reliably deploy these LMs in real-world applications, the text generation process needs to be controllable: we need to generate text that satisfies desired requirements (e.g., topic, syntactic structure). A natural approach for controlling an LM would be to fine-tune it using supervised data of the form (control, text). However, updating the LM parameters for each control task can be expensive and does not allow for compositions of multiple controls (e.g., generating text that is both positive in sentiment and non-toxic). This motivates lightweight and modular plug-and-play approaches that keep the LM frozen and steer the generation process using an external classifier that measures how well the generated text satisfies the control. But steering a frozen autoregressive LM has been shown to be difficult, and existing successes have been limited to simple, attribute-level controls (e.g., sentiment or topic).

In order to tackle more complex controls, we propose Diffusion-LM, a new language model based on continuous diffusions. Diffusion-LM starts with a sequence of Gaussian noise vectors and incrementally denoises them into vectors corresponding to words, as shown in Figure 1. These gradual denoising steps produce a hierarchy of continuous latent representations. We find that these hierarchical, continuous latent variables enable simple, gradient-based methods to perform complex control tasks such as constraining the parse tree of a generated sequence.

Continuous diffusion models have been extremely successful in vision and audio domains, but they have not been applied to text because of the inherently discrete nature of text (§ 3). Adapting this class of models to text requires several modifications to standard diffusions: we add an embedding step and a rounding step to the standard diffusion process, design a training objective to learn the embedding, and propose techniques to improve rounding (§ 4). We control Diffusion-LM using a gradient-based method, as shown in Figure 1. This method enables us to steer the text generation process towards outputs that satisfy target structural and semantic controls. It iteratively performs gradient updates on the continuous latent variables of Diffusion-LM to balance fluency and control satisfaction (§ 5.1).

To demonstrate control of Diffusion-LM, we consider six control targets ranging from fine-grained attributes (e.g., semantic content) to complex structures (e.g., parse trees). Our method almost doubles the success rate of previous plug-and-play methods and matches or outperforms the fine-tuning oracle on all these classifier-guided control tasks (§ 7.1). In addition to these individual control tasks, we show that we can successfully compose multiple classifier-guided controls to generate sentences with both desired semantic content and syntactic structure (§ 7.2). Finally, we consider span-anchored controls, such as length control and infilling. Diffusion-LM allows us to perform these control tasks without a classifier, and our Diffusion-LM significantly outperforms prior plug-and-play methods and is on-par with an autoregressive LM trained from scratch for the infilling task (§ 7.3).

Related Work

Diffusion models have demonstrated great success in continuous data domains, producing images and audio that have state-of-the-art sample quality. To handle discrete data, past works have studied text diffusion models on discrete state spaces, which define a corruption process on discrete data (e.g., each token has some probability of being corrupted to an absorbing or random token). In this paper, we focus on continuous diffusion models for text; to the best of our knowledge, our work is the first to explore this setting. In contrast to discrete diffusion LMs, our continuous diffusion LMs induce continuous latent representations, which enable efficient gradient-based methods for controllable generation.

Autoregressive and Non-autoregressive LMs.

Most large pre-trained LMs are left-to-right autoregressive (e.g., GPT-3, PaLM). The fixed generation order limits the models' flexibility in many controllable generation settings, especially those that impose controls globally on both left and right contexts. One example is infilling, which imposes lexical control on the right context; another example is syntactic structure control, which controls global properties involving both left and right contexts. Since autoregressive LMs cannot directly condition on right contexts, prior works have developed specialized training and decoding techniques for these tasks. For example, Qin et al. proposed a decoding method that relaxes the discrete LM outputs to continuous variables and backpropagates gradient information from the right context. In contrast, Diffusion-LM can condition on arbitrary classifiers that look at complex, global properties of the sentence. Other non-autoregressive LMs have been developed for machine translation and speech-to-text tasks. However, these methods are specialized for speech and translation settings, where the entropy over valid outputs is low, and it has been shown that these approaches fail for language modeling.

Plug-and-Play Controllable Generation.

Plug-and-play controllable generation aims to keep the LM frozen and steer its output using potential functions (e.g., classifiers). Given a probabilistic potential function that measures how well the generated text satisfies the desired control, the generated text should be optimized for both control satisfaction (measured by the potential function) and fluency (measured by LM probabilities). There are several plug-and-play approaches based on autoregressive LMs: FUDGE reweights the LM prediction at each token with an estimate of control satisfaction for the partial sequence; GeDi and DExperts reweight the LM prediction at each token with a smaller LM fine-tuned or trained for the control task.

The closest work to ours is PPLM, which runs gradient ascent on an autoregressive LM's hidden activations to steer the next token to satisfy the control and maintain fluency. Because PPLM is based on autoregressive LMs, it can only generate left-to-right. This prevents PPLM from repairing and recovering from errors made in previous generation steps. Despite their success on attribute (e.g., topic) controls, we will show in § 7.1 that these plug-and-play methods for autoregressive LMs fail on more complex control tasks, such as controlling syntactic structure and semantic content. We demonstrate that Diffusion-LM is capable of plug-and-play controllable generation by applying classifier-guided gradient updates to the continuous sequence of latent variables induced by the Diffusion-LM.

Problem Statement and Background

We first define controllable generation (§ 3.1), then review autoregressive language models (§ 3.2) and continuous diffusion models (§ 3.3).

Text generation is the task of sampling $\mathbf{w}$ from a trained language model $p_{\text{lm}}(\mathbf{w})$, where $\mathbf{w}=[w_1\cdots w_n]$ is a sequence of discrete words and $p_{\text{lm}}(\mathbf{w})$ is a probability distribution over sequences of words. Controllable text generation is the task of sampling $\mathbf{w}$ from a conditional distribution $p(\mathbf{w}\mid\mathbf{c})$, where $\mathbf{c}$ denotes a control variable. For syntactic control, $\mathbf{c}$ can be a target syntax tree (Figure 1), while for sentiment control, $\mathbf{c}$ could be a desired sentiment label. The goal of controllable generation is to generate $\mathbf{w}$ that satisfies the control target $\mathbf{c}$.

Consider the plug-and-play controllable generation setting: we are given a language model $p_{\text{lm}}(\mathbf{w})$ trained on a large amount of unlabeled text data, and for each control task, we are given a classifier $p(\mathbf{c}\mid\mathbf{w})$ trained on a smaller amount of labeled text data (e.g., for syntactic control, the classifier is a probabilistic parser). The goal is to utilize these two models to approximately sample from the posterior $p(\mathbf{w}\mid\mathbf{c})$ via Bayes' rule: $p(\mathbf{w}\mid\mathbf{c})\propto p_{\text{lm}}(\mathbf{w})\cdot p(\mathbf{c}\mid\mathbf{w})$. Here, $p_{\text{lm}}(\mathbf{w})$ encourages $\mathbf{w}$ to be fluent, and $p(\mathbf{c}\mid\mathbf{w})$ encourages $\mathbf{w}$ to fulfill the control.
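As a toy illustration of the Bayes-rule formulation, the sketch below normalizes $p_{\text{lm}}(\mathbf{w})\cdot p(\mathbf{c}\mid\mathbf{w})$ over a tiny candidate set that can be enumerated exactly. The candidates and every probability are made up for illustration; real plug-and-play methods cannot enumerate all sequences, which is why they must sample approximately.

```python
import math

def posterior_over_candidates(lm_logprob, clf_logprob):
    """Normalize p(w | c) ∝ p_lm(w) * p(c | w) over a small, fully enumerated candidate set.

    lm_logprob:  dict mapping sentence -> log p_lm(w)   (fluency)
    clf_logprob: dict mapping sentence -> log p(c | w)  (control satisfaction)
    """
    scores = {w: lm_logprob[w] + clf_logprob[w] for w in lm_logprob}
    m = max(scores.values())  # log-sum-exp trick for numerical stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scores.values()))
    return {w: math.exp(s - log_z) for w, s in scores.items()}

# Hypothetical numbers: the LM prefers "a", the classifier prefers "b".
lm = {"a": math.log(0.6), "b": math.log(0.3), "c": math.log(0.1)}
clf = {"a": math.log(0.1), "b": math.log(0.9), "c": math.log(0.5)}
post = posterior_over_candidates(lm, clf)
print(max(post, key=post.get))  # b
```

The posterior favors "b" because it is reasonably fluent and strongly satisfies the control, even though the LM alone ranks "a" first.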

3.2 Autoregressive Language Models

The canonical approach to language modeling factors $p_{\text{lm}}$ in an autoregressive left-to-right manner, $p_{\text{lm}}(\mathbf{w})=p_{\text{lm}}(w_1)\prod_{i=2}^{n}p_{\text{lm}}(w_i\mid w_{<i})$. In this case, text generation is reduced to the task of repeatedly predicting the next token conditioned on the partial sequence generated so far. The next-token prediction $p_{\text{lm}}(w_i\mid w_{<i})$ is often parametrized by a Transformer architecture.
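The chain-rule factorization and left-to-right decoding can be sketched with a hypothetical bigram model standing in for the Transformer (an assumption: here $p(w_i\mid w_{<i})$ depends only on $w_{i-1}$). The vocabulary and probabilities are invented for illustration.

```python
import math

# Hypothetical toy LM: a bigram model, i.e. p(w_i | w_<i) depends only on w_{i-1}.
FIRST = {"the": 0.7, "a": 0.3}
NEXT = {
    "the": {"food": 0.6, "pub": 0.4},
    "a": {"pub": 1.0},
    "food": {"is": 1.0},
    "pub": {"is": 1.0},
    "is": {"cheap": 0.6, "good": 0.4},
}

def log_prob(words):
    """Chain rule: log p(w) = log p(w_1) + sum_{i>=2} log p(w_i | w_<i)."""
    lp = math.log(FIRST[words[0]])
    for prev, cur in zip(words, words[1:]):
        lp += math.log(NEXT[prev][cur])
    return lp

def greedy_generate(n):
    """Left-to-right decoding: repeatedly append the most probable next token.
    Earlier tokens are never revisited once emitted."""
    words = [max(FIRST, key=FIRST.get)]
    while len(words) < n:
        dist = NEXT[words[-1]]
        words.append(max(dist, key=dist.get))
    return words

print(greedy_generate(4))  # ['the', 'food', 'is', 'cheap']
```

Because each token is committed before its right context exists, a model of this form cannot directly condition on a right context, which is the limitation discussed in the related work.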

3.3 Diffusion Models for Continuous Domains

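To train the diffusion model, we define a forward process that constructs the intermediate latent variables $\mathbf{x}_{1:T}$. The forward process incrementally adds Gaussian noise to data $\mathbf{x}_0$ until, at diffusion step $T$, samples $\mathbf{x}_T$ are approximately Gaussian. Each transition $\mathbf{x}_{t-1}\rightarrow\mathbf{x}_t$ is parametrized by $q(\mathbf{x}_t\mid\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\beta_t\mathbf{I})$, where the hyperparameter $\beta_t$ is the amount of noise added at diffusion step $t$. This parametrization of the forward process $q$ contains no trainable parameters and allows us to define a training objective that involves generating noisy data according to the pre-defined forward process $q$ and training a model $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$ to reverse the process and reconstruct the data. The canonical objective is the variational bound on the negative log-likelihood $-\log p_\theta(\mathbf{x}_0)$:

$$\mathcal{L}_{\text{vlb}}(\mathbf{x}_0)=\mathbb{E}_{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}\left[\log\frac{q(\mathbf{x}_T\mid\mathbf{x}_0)}{p_\theta(\mathbf{x}_T)}+\sum_{t=2}^{T}\log\frac{q(\mathbf{x}_{t-1}\mid\mathbf{x}_0,\mathbf{x}_t)}{p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)}-\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)\right]. \qquad (1)$$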
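Because every forward transition is Gaussian, composing $t$ of them gives the closed-form marginal $q(\mathbf{x}_t\mid\mathbf{x}_0)=\mathcal{N}(\sqrt{\bar\alpha_t}\,\mathbf{x}_0,(1-\bar\alpha_t)\mathbf{I})$ with $\bar\alpha_t=\prod_{s\le t}(1-\beta_s)$. The sketch below checks this numerically. The linear $\beta$ schedule is an illustrative assumption, not the square-root schedule Diffusion-LM uses.

```python
import numpy as np

T = 2000
# Illustrative linear beta schedule (an assumption; not the schedule used by Diffusion-LM).
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # alpha_bar[t-1] = prod_{s=1}^{t} (1 - beta_s)

def q_step(x_prev, t, rng):
    """One forward transition q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I)."""
    b = betas[t - 1]
    return np.sqrt(1.0 - b) * x_prev + np.sqrt(b) * rng.standard_normal(x_prev.shape)

def q_sample(x0, t, rng):
    """Jump straight to step t with the closed-form marginal
    q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)."""
    ab = alpha_bar[t - 1]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * rng.standard_normal(x0.shape)

def chained_moments(t):
    """Mean coefficient and variance obtained by composing t Gaussian transitions."""
    mean_coef, var = 1.0, 0.0
    for s in range(1, t + 1):
        b = betas[s - 1]
        mean_coef = np.sqrt(1.0 - b) * mean_coef
        var = (1.0 - b) * var + b
    return mean_coef, var

rng = np.random.default_rng(0)
x0 = rng.standard_normal((64, 16))  # e.g., n = 64 positions, d = 16 dimensions
xT = q_sample(x0, T, rng)           # essentially pure Gaussian noise at t = T
```

The closed form is what makes training cheap: a random step $t$ can be sampled directly from $\mathbf{x}_0$ without simulating the chain.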

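However, this objective can be unstable and require many optimization tricks to stabilize. To circumvent this issue, Ho et al. devised a simple surrogate objective that expands and reweights each KL-divergence term in $\mathcal{L}_{\text{vlb}}$ to obtain a mean-squared error loss (derivation in Appendix E), which we will refer to as $\mathcal{L}_{\text{simple}}$:

$$\mathcal{L}_{\text{simple}}(\mathbf{x}_0)=\sum_{t=1}^{T}\mathbb{E}_{q(\mathbf{x}_t\mid\mathbf{x}_0)}\left\|\mu_\theta(\mathbf{x}_t,t)-\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)\right\|^2,$$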

where $\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)$ is the mean of the posterior $q(\mathbf{x}_{t-1}\mid\mathbf{x}_0,\mathbf{x}_t)$, which is a closed-form Gaussian, and $\mu_\theta(\mathbf{x}_t,t)$ is the predicted mean of $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$ computed by a neural network. While $\mathcal{L}_{\text{simple}}$ is no longer a valid lower bound, prior work has found that it empirically makes training more stable and improves sample quality. (Our definition of $\mathcal{L}_{\text{simple}}$ uses a different parametrization from Ho et al.: we define the squared loss in terms of $\mu_\theta(\mathbf{x}_t,t)$, while they express it in terms of $\epsilon_\theta(\mathbf{x}_t,t)$.) We will make use of similar simplifications in Diffusion-LM to stabilize training and improve sample quality (§ 4.1).
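The mean-squared objective $\sum_t\mathbb{E}_{q(\mathbf{x}_t\mid\mathbf{x}_0)}\|\mu_\theta(\mathbf{x}_t,t)-\hat\mu(\mathbf{x}_t,\mathbf{x}_0)\|^2$ can be sketched as follows. The posterior-mean coefficients are the standard closed form for Gaussian forward processes; the linear schedule, the convention $\bar\alpha_0=1$, and the uniform sampling of $t$ are assumptions of this sketch, and `mu_theta` stands in for the neural network.

```python
import numpy as np

T = 2000
betas = np.linspace(1e-4, 0.02, T)  # illustrative schedule (assumption)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def ab(t):
    """alpha_bar_t, with the convention alpha_bar_0 = 1."""
    return 1.0 if t == 0 else alpha_bar[t - 1]

def posterior_mean(xt, x0, t):
    """Mean mu_hat(x_t, x_0) of the Gaussian posterior q(x_{t-1} | x_0, x_t), for t >= 1."""
    coef_x0 = np.sqrt(ab(t - 1)) * betas[t - 1] / (1.0 - ab(t))
    coef_xt = np.sqrt(alphas[t - 1]) * (1.0 - ab(t - 1)) / (1.0 - ab(t))
    return coef_x0 * x0 + coef_xt * xt

def l_simple(mu_theta, x0, rng, n_steps=8):
    """Monte Carlo estimate of L_simple(x0): sample t uniformly from 1..T, draw
    x_t ~ q(x_t | x_0), and compare mu_theta(x_t, t) with the posterior mean.
    Multiplying the average by T turns it into an estimate of the sum over t."""
    total = 0.0
    for _ in range(n_steps):
        t = int(rng.integers(1, T + 1))
        xt = np.sqrt(ab(t)) * x0 + np.sqrt(1.0 - ab(t)) * rng.standard_normal(x0.shape)
        total += np.sum((mu_theta(xt, t) - posterior_mean(xt, x0, t)) ** 2)
    return T * total / n_steps

rng = np.random.default_rng(0)
x0 = rng.standard_normal((64, 16))
oracle = lambda xt, t: posterior_mean(xt, x0, t)  # cheats by knowing x0, so the loss is zero
print(l_simple(oracle, x0, rng))  # 0.0
```

A trained network has to infer the posterior mean from $\mathbf{x}_t$ alone; the oracle only confirms that a perfect predictor attains zero loss.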

Diffusion-LM: Continuous Diffusion Language Modeling

Constructing Diffusion-LM requires several modifications to the standard diffusion model. First, we must define an embedding function that maps discrete text into a continuous space. To address this, we propose an end-to-end training objective for learning embeddings (§ 4.1). Second, we require a rounding method to map vectors in embedding space back to words. To address this, we propose training-time and decoding-time methods to facilitate rounding (§ 4.2).

We propose a modification of the diffusion model training objective (Equation 1) that jointly learns the diffusion model's parameters and word embeddings. In preliminary experiments, we explored random Gaussian embeddings, as well as pre-trained word embeddings. We found that these fixed embeddings are suboptimal for Diffusion-LM compared to end-to-end training. (While trainable embeddings perform best on control and generation tasks, we found that fixed embeddings onto the vocabulary simplex were helpful when optimizing for held-out perplexity. We leave discussion of this approach and perplexity results to Appendix F, as the focus of this work is generation quality rather than perplexity.)

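As shown in Figure 2, our approach adds a Markov transition from discrete words $\mathbf{w}$ to $\mathbf{x}_0$ in the forward process, parametrized by $q_\phi(\mathbf{x}_0\mid\mathbf{w})=\mathcal{N}(\textsc{Emb}(\mathbf{w}),\sigma_0 I)$. In the reverse process, we add a trainable rounding step, parametrized by $p_\theta(\mathbf{w}\mid\mathbf{x}_0)=\prod_{i=1}^{n}p_\theta(w_i\mid x_i)$, where $p_\theta(w_i\mid x_i)$ is a softmax distribution. The training objectives introduced in § 3 now become

$$\mathcal{L}^{\text{e2e}}_{\text{vlb}}(\mathbf{w})=\mathbb{E}_{q_\phi(\mathbf{x}_0\mid\mathbf{w})}\left[\mathcal{L}_{\text{vlb}}(\mathbf{x}_0)+\log q_\phi(\mathbf{x}_0\mid\mathbf{w})-\log p_\theta(\mathbf{w}\mid\mathbf{x}_0)\right],$$

$$\mathcal{L}^{\text{e2e}}_{\text{simple}}(\mathbf{w})=\mathbb{E}_{q_\phi(\mathbf{x}_{0:T}\mid\mathbf{w})}\left[\mathcal{L}_{\text{simple}}(\mathbf{x}_0)+\left\|\textsc{Emb}(\mathbf{w})-\mu_\theta(\mathbf{x}_1,1)\right\|^2-\log p_\theta(\mathbf{w}\mid\mathbf{x}_0)\right].$$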

We derive $\mathcal{L}^{\text{e2e}}_{\text{simple}}(\mathbf{w})$ from $\mathcal{L}^{\text{e2e}}_{\text{vlb}}(\mathbf{w})$ following the simplification in § 3.3; derivation details are in Appendix E. Since we are training the embedding function, $q_\phi$ now contains trainable parameters, and we use the reparametrization trick to backpropagate through this sampling step. Empirically, we find that the learned embeddings cluster meaningfully: words with the same part-of-speech tags (syntactic role) tend to be clustered, as shown in Figure 3.
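The embedding and rounding steps can be sketched as below: a reparametrized draw from $q_\phi(\mathbf{x}_0\mid\mathbf{w})=\mathcal{N}(\textsc{Emb}(\mathbf{w}),\sigma_0 I)$ and the cross-entropy term $-\log p_\theta(\mathbf{w}\mid\mathbf{x}_0)$. The vocabulary size, dimension, $\sigma_0$, and especially the choice of negative squared distances as softmax logits are assumptions; the text only states that $p_\theta(w_i\mid x_i)$ is a softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d, sigma0 = 50, 16, 0.1       # hypothetical vocab size, embedding dim, noise variance
E = rng.standard_normal((V, d))  # Emb: one trainable row per word (random init here)

def embed_sample(word_ids, rng):
    """Reparametrized draw from q_phi(x_0 | w) = N(Emb(w), sigma0 I).
    Writing x_0 = Emb(w) + sqrt(sigma0) * eps keeps x_0 differentiable in E."""
    eps = rng.standard_normal((len(word_ids), d))
    return E[word_ids] + np.sqrt(sigma0) * eps

def rounding_logprobs(x0):
    """log p_theta(w_i | x_i) for every position i and word v.
    Assumption: logits are negative squared distances to each embedding."""
    logits = -((x0[:, None, :] - E[None, :, :]) ** 2).sum(axis=-1)
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

def rounding_loss(word_ids, x0):
    """The -log p_theta(w | x_0) term: sum of per-position cross-entropies."""
    lp = rounding_logprobs(x0)
    return -lp[np.arange(len(word_ids)), word_ids].sum()

def round_argmax(x0):
    """Round each position to its most probable word."""
    return rounding_logprobs(x0).argmax(axis=1)

word_ids = np.array([3, 7, 7, 42])
x0 = embed_sample(word_ids, rng)
print(rounding_loss(word_ids, x0))
```

NumPy only shows the forward computation here; in training, this term would be added to the diffusion loss and both $E$ and the denoising network optimized jointly with an autodiff framework.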

4.2 Reducing Rounding Errors

The learned embeddings define a mapping from discrete text to the continuous $\mathbf{x}_0$. We now describe the inverse process of rounding a predicted $\mathbf{x}_0$ back to discrete text. Rounding is achieved by choosing the most probable word for each position, $\hat{w}_i=\operatorname{argmax}_{w}\,p_\theta(w\mid x_i)$, which maximizes $p_\theta(\mathbf{w}\mid\mathbf{x}_0)=\prod_{i=1}^{n}p_\theta(w_i\mid x_i)$ because the distribution factorizes over positions. Ideally, this argmax-rounding would be sufficient to map back to discrete text, as the denoising steps should ensure that $\mathbf{x}_0$ lies exactly on the embedding of some word. However, empirically, the model fails to generate an $\mathbf{x}_0$ that commits to a single word.

Re-parametrizing the model to predict $\mathbf{x}_0$ helps model training, and we found that the same intuition can also be used at decoding time, in a technique that we call the clamping trick. In the standard generation approach for an $\mathbf{x}_0$-parametrized model, the model denoises $\mathbf{x}_t$ to $\mathbf{x}_{t-1}$ by first computing an estimate of $\mathbf{x}_0$ via $f_\theta(\mathbf{x}_t,t)$ and then sampling $\mathbf{x}_{t-1}$ conditioned on this estimate: $\mathbf{x}_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\,f_\theta(\mathbf{x}_t,t)+\sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon$, where $\bar{\alpha}_t=\prod_{s=0}^{t}(1-\beta_s)$ and $\epsilon\sim\mathcal{N}(0,I)$. (This follows from the marginal distribution $q(\mathbf{x}_t\mid\mathbf{x}_0)$, which is a closed-form Gaussian since all the Markov transitions are Gaussian.) In the clamping trick, the model additionally maps the predicted vector $f_\theta(\mathbf{x}_t,t)$ to its nearest word embedding sequence. The sampling step then becomes $\mathbf{x}_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\cdot\operatorname{Clamp}(f_\theta(\mathbf{x}_t,t))+\sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon$. The clamping trick forces the predicted vector to commit to a word at intermediate diffusion steps, making the vector predictions more precise and reducing rounding errors. (Intuitively, applying the clamping trick at early diffusion steps, with $t$ near $T$, may be suboptimal because the model has not yet figured out which words to commit to. Empirically, applying the clamping trick at all diffusion steps does not hurt performance much, but to follow this intuition, one could also treat the starting step of the clamping trick as a hyperparameter.)
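A minimal sketch of one $\mathbf{x}_0$-parametrized reverse step with optional clamping, following the sampling equation described above. The predictor `f_theta`, the toy embedding table, the schedule, and the indexing convention $\bar\alpha_0=1$ are hypothetical stand-ins.

```python
import numpy as np

def clamp(x, E):
    """Map each row of x (one vector per position) to its nearest word embedding in E."""
    d2 = ((x[:, None, :] - E[None, :, :]) ** 2).sum(axis=-1)
    return E[d2.argmin(axis=1)]

def denoise_step(xt, t, f_theta, E, alpha_bar, rng, use_clamp=True):
    """One reverse step of an x_0-parametrized model:
    x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1}) * eps,
    where x0_hat = f_theta(x_t, t), optionally clamped to the nearest embeddings.
    alpha_bar[k] holds alpha_bar_{k+1}; alpha_bar_0 = 1 is assumed."""
    x0_hat = f_theta(xt, t)
    if use_clamp:
        x0_hat = clamp(x0_hat, E)
    ab_prev = 1.0 if t == 1 else alpha_bar[t - 2]
    eps = rng.standard_normal(xt.shape)
    return np.sqrt(ab_prev) * x0_hat + np.sqrt(1.0 - ab_prev) * eps

rng = np.random.default_rng(0)
E = 2.0 * np.eye(5)                    # toy vocabulary of 5 orthogonal embeddings
ids = np.array([1, 4, 0])
f_theta = lambda xt, t: E[ids] + 0.1   # hypothetical, slightly imprecise x_0 predictor
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 200))
x_prev = denoise_step(rng.standard_normal((3, 5)), 1, f_theta, E, alpha_bar, rng)
print(np.allclose(x_prev, E[ids]))     # True: the final step lands exactly on embeddings
```

To follow the intuition about early steps, a caller could pass `use_clamp=(t <= t_start)` for a hypothetical threshold `t_start`.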

Decoding and Controllable Generation with Diffusion-LM

Having described the Diffusion-LM, we now consider the problem of controllable text generation (§ 5.1) and decoding (§ 5.2).

We now describe a procedure that enables plug-and-play control on Diffusion-LM. Our approach to control is inspired by the Bayesian formulation in § 3.1, but instead of performing control directly on the discrete text, we perform control on the sequence of continuous latent variables $\mathbf{x}_{0:T}$ defined by Diffusion-LM, and apply the rounding step to convert these latents into text.

Controlling $\mathbf{x}_{0:T}$ is equivalent to decoding from the posterior $p(\mathbf{x}_{0:T}\mid\mathbf{c})=\prod_{t=1}^{T}p(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{c})$, and we decompose this joint inference problem into a sequence of control problems at each diffusion step: $p(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{c})\propto p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\cdot p(\mathbf{c}\mid\mathbf{x}_{t-1},\mathbf{x}_t)$. We further simplify $p(\mathbf{c}\mid\mathbf{x}_{t-1},\mathbf{x}_t)=p(\mathbf{c}\mid\mathbf{x}_{t-1})$ via conditional independence assumptions from prior work on controlling diffusions. Consequently, for the $t$-th step, we run gradient updates on $\mathbf{x}_{t-1}$:

$$\nabla_{\mathbf{x}_{t-1}}\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{c})=\nabla_{\mathbf{x}_{t-1}}\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)+\nabla_{\mathbf{x}_{t-1}}\log p(\mathbf{c}\mid\mathbf{x}_{t-1}),$$

where both $\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$ and $\log p(\mathbf{c}\mid\mathbf{x}_{t-1})$ are differentiable: the first term is parametrized by Diffusion-LM, and the second term is parametrized by a neural network classifier.

Similar to work in the image setting, we train the classifier on the diffusion latent variables and run gradient updates on the latent space $\mathbf{x}_{t-1}$ to steer it towards fulfilling the control. These image diffusion works take one gradient step towards $\nabla_{\mathbf{x}_{t-1}}\log p(\mathbf{c}\mid\mathbf{x}_{t-1})$ per diffusion step. To improve performance on text and speed up decoding, we introduce two key modifications: fluency regularization and multiple gradient steps.

To generate fluent text, we run gradient updates on a control objective with fluency regularization: $\lambda\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)+\log p(\mathbf{c}\mid\mathbf{x}_{t-1})$, where $\lambda$ is a hyperparameter that trades off fluency (the first term) and control (the second term). While existing controllable generation methods for diffusions do not include the $\lambda\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$ term in the objective, we found this term to be instrumental for generating fluent text. The resulting controllable generation process can be viewed as a stochastic decoding method that balances maximizing and sampling $p(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{c})$, much like popular text generation techniques such as nucleus sampling or sampling with low temperature. To improve control quality, we take multiple gradient steps for each diffusion step: we run 3 steps of the Adagrad update for each diffusion step. (We tried ablations that replaced Adagrad with SGD, but found Adagrad to be substantially less sensitive to hyperparameter tuning.) To offset the increased computational cost, we downsample the diffusion steps from 2000 to 200, which speeds up our controllable generation algorithm without hurting sample quality much.
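The per-step update can be sketched in a toy setting where both gradients are analytic. Assumptions, not the paper's models: the Diffusion-LM transition is taken to be $p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)=\mathcal{N}(\mu,\sigma^2 I)$, the classifier is logistic, $p(\mathbf{c}\mid\mathbf{x})=\operatorname{sigmoid}(w\cdot\mathbf{x}+b)$, and the learning rate is arbitrary. The 3 Adagrad ascent steps on $\lambda\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)+\log p(\mathbf{c}\mid\mathbf{x}_{t-1})$ mirror the text.

```python
import numpy as np

def controlled_update(mu, sigma2, w, b, lam=0.1, n_steps=3, lr=0.1, eps=1e-8):
    """Adagrad ascent on lam * log p(x | x_t) + log p(c | x) for one diffusion step.

    Toy assumptions: p(x_{t-1} | x_t) = N(mu, sigma2 I) and p(c | x) = sigmoid(w . x + b).
    Starts from the unconditional mean mu.
    """
    x = mu.copy()
    accum = np.zeros_like(x)  # Adagrad's running sum of squared gradients
    for _ in range(n_steps):
        grad_fluency = -(x - mu) / sigma2                      # d/dx log N(x; mu, sigma2 I)
        z = float(np.dot(w, x) + b)
        grad_control = (1.0 - 1.0 / (1.0 + np.exp(-z))) * w    # d/dx log sigmoid(z)
        g = lam * grad_fluency + grad_control
        accum += g ** 2
        x = x + lr * g / (np.sqrt(accum) + eps)                # per-coordinate ascent step
    return x

mu = np.zeros(4)  # unconditional mean proposed by the diffusion model
w, b = np.array([1.0, -1.0, 0.5, 0.0]), -1.0
x = controlled_update(mu, sigma2=1.0, w=w, b=b)
print(np.dot(w, mu) + b, np.dot(w, x) + b)  # the classifier logit rises above -1.0
```

Adagrad divides each coordinate's step by its accumulated gradient magnitude, so the step size depends little on the raw gradient scale; that property is one plausible reason for the lower hyperparameter sensitivity reported above.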

5.2 Minimum Bayes Risk Decoding

Many conditional text generation tasks, such as machine translation or sentence infilling, require a single high-quality output sequence. In these settings, we apply Minimum Bayes Risk (MBR) decoding to aggregate a set of samples $\mathcal{S}$ drawn from the Diffusion-LM, and select the sample that achieves the minimum expected risk under a loss function $\mathcal{L}$ (e.g., negative BLEU score): $\hat{\mathbf{w}}=\operatorname{argmin}_{\mathbf{w}\in\mathcal{S}}\sum_{\mathbf{w}'\in\mathcal{S}}\frac{1}{|\mathcal{S}|}\mathcal{L}(\mathbf{w},\mathbf{w}')$. We found that MBR decoding often returned high-quality outputs, since a low-quality sample would be dissimilar from the remaining samples and penalized by the loss function.
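A minimal MBR sketch of the argmin above. Negative unigram-overlap F1 stands in for negative BLEU here, a simplifying assumption; any pairwise loss can be passed in.

```python
from collections import Counter

def unigram_f1(a, b):
    """Token-overlap F1 between two token lists (a simple stand-in for BLEU)."""
    overlap = sum((Counter(a) & Counter(b)).values())
    if overlap == 0:
        return 0.0
    p, r = overlap / len(a), overlap / len(b)
    return 2 * p * r / (p + r)

def mbr_select(samples, loss=lambda a, b: -unigram_f1(a, b)):
    """Return argmin_{w in S} (1/|S|) sum_{w' in S} loss(w, w')."""
    def risk(w):
        return sum(loss(w, w2) for w2 in samples) / len(samples)
    return min(samples, key=risk)

samples = [
    "he went to the store".split(),
    "he went to the shop".split(),
    "he walked to the store".split(),
    "purple idea sleeps furiously".split(),  # a low-quality outlier
]
print(" ".join(mbr_select(samples)))  # he went to the store
```

The outlier shares no tokens with the other samples, so its expected loss is the worst, which matches the intuition given in the text.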

Experimental Setup

With the above improvements to training (§ 4) and decoding (§ 5), we train Diffusion-LM for two language modeling tasks. We then apply the controllable generation method to 5 classifier-guided control tasks, and apply MBR decoding to a classifier-free control task (i.e., infilling).

We train Diffusion-LM on two datasets: E2E and ROCStories. The E2E dataset consists of 50K restaurant reviews labeled by 8 fields, including food type, price, and customer rating. The ROCStories dataset consists of 98K five-sentence stories, capturing a rich set of causal and temporal commonsense relations between daily events. This dataset is more challenging to model than E2E because the stories contain a larger vocabulary of 11K words and more diverse semantic content.

Our Diffusion-LM is based on a Transformer architecture with 80M parameters, with sequence length $n=64$, $T=2000$ diffusion steps, and a square-root noise schedule (see Appendix A for details). We treat the embedding dimension as a hyperparameter, setting $d=16$ for E2E and $d=128$ for ROCStories. See Appendix B for hyperparameter details. At decoding time, we downsample to 200 diffusion steps for E2E and keep 2000 steps for ROCStories. Decoding Diffusion-LM for 200 steps is still 7x slower than decoding autoregressive LMs. For controllable generation, our method based on Diffusion-LM is 1.5x slower than FUDGE but 60x faster than PPLM.

6.2 Control tasks

We consider 6 control tasks, shown in Table 1: the first 4 tasks rely on a classifier, and the last 2 tasks are classifier-free. (Length is classifier-free for our Diffusion-LM-based methods, but other methods still require a classifier.) For each control task (e.g., semantic content), we sample 200 control targets $\mathbf{c}$ (e.g., rating=5 star) from the validation splits, and we generate 50 samples for each control target. To evaluate the fluency of the generated text, we follow prior work and feed the generated text to a teacher LM (i.e., a carefully fine-tuned GPT-2 model) and report the perplexity of the generated text under the teacher LM. We call this metric lm-score (denoted as lm): a lower lm-score indicates better sample quality. Prior works use GPT as the teacher LM, whereas we use a fine-tuned GPT-2 model because our base autoregressive LM and Diffusion-LM both generate UNK tokens, which do not exist in the pretrained vocabularies of GPT. We define success metrics for each control task as follows:
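The lm-score is a perplexity, $\exp(-\text{mean token log-probability})$ under the teacher LM. The sketch takes per-token log-probabilities as input (obtaining them from the teacher LM is not shown). Whether tokens are pooled across sentences or perplexity is averaged per sentence is not specified in the text; pooling is an assumption here.

```python
import math

def lm_score(token_logprobs):
    """Perplexity exp(-mean log p) of generated text under a teacher LM.

    token_logprobs: one list per generated sentence of natural-log token probabilities,
    assumed to come from the teacher LM. Tokens are pooled across sentences
    (corpus-level perplexity); per-sentence averaging would be another option.
    """
    flat = [lp for sent in token_logprobs for lp in sent]
    return math.exp(-sum(flat) / len(flat))

# Every token gets probability 1/4, so the perplexity is 4 (lower is better).
print(lm_score([[math.log(0.25)] * 3, [math.log(0.25)] * 5]))
```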

Semantic Content. Given a field (e.g., rating) and value (e.g., 5 star), generate a sentence that covers field=value, and report the success rate by exact match of ‘value’.

Parts-of-speech. Given a sequence of parts-of-speech (POS) tags (e.g., Pronoun Verb Determiner Noun), generate a sequence of words of the same length whose POS tags (under an oracle POS tagger) match the target (e.g., I ate an apple). We quantify success via word-level exact match.

Syntax Tree. Given a target syntactic parse tree (see Figure 1), generate text whose syntactic parse matches the given parse. To evaluate success, we parse the generated text with an off-the-shelf parser and report F1 scores.

Syntax Spans. Given a target (span, syntactic category) pair, generate text whose parse tree over span $[i,j]$ matches the target syntactic category (e.g., prepositional phrase). We quantify success via the fraction of spans that match exactly.

Length. Given a target length between 10 and 40, generate a sequence whose length is within $\pm 2$ of the target. In the case of Diffusion-LM, we treat this as a classifier-free control task.

Infilling. Given a left context ($O_1$) and a right context ($O_2$) from the aNLG dataset, the goal is to generate a sentence that logically connects $O_1$ and $O_2$. For evaluation, we report both automatic and human evaluation from the Genie leaderboard.
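The success metrics for the first five tasks can be sketched as plain functions. Several details are assumptions not fixed by the text: semantic success is taken as verbatim substring match, POS sequences of unequal length count extra positions as errors, and syntax-tree F1 is a simplified labeled-bracket F1 over `(label, i, j)` constituents rather than a full parser-evaluation script. The tagger and parser that produce these inputs are not shown.

```python
def semantic_success(sentence, value):
    """Semantic content: success if the target value string appears verbatim (assumed rule)."""
    return value in sentence

def pos_accuracy(pred_tags, target_tags):
    """Parts-of-speech: word-level exact match against the target tag sequence.
    Assumption: positions beyond the shorter sequence count as mismatches."""
    matches = sum(p == t for p, t in zip(pred_tags, target_tags))
    return matches / max(len(pred_tags), len(target_tags))

def bracket_f1(pred_spans, gold_spans):
    """Syntax tree: labeled-bracket F1 between sets of (label, i, j) constituents."""
    pred, gold = set(pred_spans), set(gold_spans)
    overlap = len(pred & gold)
    if overlap == 0:
        return 0.0
    p, r = overlap / len(pred), overlap / len(gold)
    return 2 * p * r / (p + r)

def span_success(pred_spans, label, i, j):
    """Syntax spans: success if the generated parse has a `label` constituent over [i, j]."""
    return (label, i, j) in set(pred_spans)

def length_success(tokens, target_len, tol=2):
    """Length: success if the generated length is within +/- tol of the target."""
    return abs(len(tokens) - target_len) <= tol

print(pos_accuracy(["PRON", "VERB", "NOUN"], ["PRON", "VERB", "DET", "NOUN"]))  # 0.5
```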

6.3 Classifier-Guided Control Baselines

For the first 5 control tasks, we compare our method with PPLM, FUDGE, and a fine-tuning oracle. Both PPLM and FUDGE are plug-and-play controllable generation approaches based on an autoregressive LM, which we train from scratch using the GPT-2 small architecture.

PPLM. This method runs gradient ascent on the LM activations to increase the classifier probabilities and language model probabilities, and has been successful on simple attribute control. We apply PPLM to control semantic content, but not the remaining 4 tasks which require positional information, as PPLM’s classifier lacks positional information.

FUDGE. For each control task, FUDGE requires a future discriminator that takes in a prefix sequence and predicts whether the complete sequence would satisfy the constraint. At decoding time, FUDGE reweights the LM prediction by the discriminator scores.
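A minimal sketch of the reweighting step described above, over a tiny hypothetical vocabulary. The discriminator scores are given as inputs; training the future discriminator is not shown.

```python
import numpy as np

def fudge_reweight(lm_probs, disc_probs):
    """Multiply next-token LM probabilities by the future discriminator's estimate that
    each extended prefix will satisfy the control, then renormalize."""
    scores = np.asarray(lm_probs, dtype=float) * np.asarray(disc_probs, dtype=float)
    return scores / scores.sum()

# Hypothetical 3-token vocabulary: the LM prefers token 0, the discriminator token 1.
p = fudge_reweight([0.5, 0.3, 0.2], [0.1, 0.9, 0.5])
print(p.argmax())  # 1
```

Because the reweighting only ever looks at a prefix, an early token that derails the target structure lowers the discriminator score of every continuation, which is the failure mode analyzed for structured controls.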

FT. For each control task, we fine-tune GPT-2 on (control, text) pairs, yielding an oracle conditional language model that is not plug-and-play. We report both the sampling (with temperature 1.0) and beam search (with beam size 4) outputs of the fine-tuned models, denoted as FT-sample and FT-search.

6.4 Infilling Baselines

We compare to 3 specialized baseline methods developed in past work for the infilling task.

DELOREAN. This method continuously relaxes the output space of a left-to-right autoregressive LM, and iteratively performs gradient updates on the continuous space to enforce fluent connection to the right context. This yields a continuous vector, which is rounded back to text.

COLD. COLD specifies an energy-based model that includes fluency constraints (from left-to-right and right-to-left LMs) and coherence constraints (from lexical overlap). It samples continuous vectors from this energy-based model and rounds them to text.

AR-infilling. We train an autoregressive LM from scratch to perform the sentence infilling task. Similar to training Diffusion-LM, we train on the ROCStories dataset, but pre-process it by reordering sentences from $(O_1,O_{\text{middle}},O_2)$ to $(O_1,O_2,O_{\text{middle}})$. At evaluation time, we feed in $O_1, O_2$, and the model generates the middle sentence.
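The reordering can be sketched as two string-building helpers. The separator token `<sep>` and the function names are hypothetical; the text does not specify how the sentences are delimited.

```python
def make_infilling_example(o1, o_middle, o2, sep=" <sep> "):
    """Training text: reorder (O1, O_middle, O2) -> (O1, O2, O_middle) so a left-to-right
    LM has already read both contexts when it generates the middle sentence."""
    return o1 + sep + o2 + sep + o_middle

def make_infilling_prompt(o1, o2, sep=" <sep> "):
    """Evaluation prompt: O1 and O2 only; the model's continuation is the middle sentence."""
    return o1 + sep + o2 + sep

ex = make_infilling_example("Kim went hiking.", "She got lost.", "A ranger found her.")
print(ex)  # Kim went hiking. <sep> A ranger found her. <sep> She got lost.
```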

Main Results

We train Diffusion-LMs on the E2E and ROCStories datasets. In terms of negative log-likelihood (NLL, lower is better), we find that the variational upper bound on Diffusion-LM's NLL underperforms the equivalent autoregressive Transformer model (2.28 vs. 1.77 for E2E, 3.88 vs. 3.05 for ROCStories), although scaling up model and dataset size partially bridges the gap (3.88 → 3.10 on ROCStories). (Exact log-likelihoods are intractable for Diffusion-LM, so we report the variational bound $\mathcal{L}^{\text{e2e}}_{\text{vlb}}$, an upper bound on the NLL.) Our best log-likelihoods required several modifications to § 4; we explain these and give detailed log-likelihood results in Appendix F. Despite worse likelihoods, controllable generation based on our Diffusion-LM results in significantly better outputs than systems based on autoregressive LMs, as we will show in § 7.1, § 7.2, and § 7.3.

As shown in Table 2, Diffusion-LM achieves high success and fluency across all classifier-guided control tasks. It significantly outperforms the PPLM and FUDGE baselines across all 5 tasks. Surprisingly, our method outperforms the fine-tuning oracle on controlling syntactic parse trees and spans, while achieving similar performance on the remaining 3 tasks.

Controlling syntactic parse trees and spans are challenging tasks for fine-tuning, because conditioning on the parse tree requires reasoning about the nested structure of the parse tree, and conditioning on spans requires lookahead planning to ensure the right constituent appears at the target position.

We observe that PPLM fails in semantic content controls and conjecture that this is because PPLM is designed to control coarse-grained attributes, and may not be useful for more targeted tasks such as enforcing that a restaurant review contains a reference to Starbucks.

FUDGE performs well on semantic content control but does not perform well on the remaining four tasks. Controlling a structured output (Parts-of-speech and Syntax Tree) is hard for FUDGE because making one mistake anywhere in the prefix makes the discriminator assign low probabilities to all continuations. In other control tasks that require planning (Length and Syntax Spans), the future discriminator is difficult to train, as it must implicitly perform lookahead planning.

The non-autoregressive nature of our Diffusion-LM allows it to easily solve all the tasks that require precise future planning (Syntax Spans and Length). We believe that it works well for complex controls that involve global structures (Parts-of-speech, Syntax Tree) because the coarse-to-fine representations allow the classifier to exert control on the entire sequence (near $t=T$) as well as on individual tokens (near $t=0$).

Table 3 shows samples of Syntax Tree control. Our method and fine-tuning both provide fluent sentences that mostly satisfy controls, whereas FUDGE deviates from the constraints after the first few words. One key difference between our method and fine-tuning is that Diffusion-LM is able to correct for a failed span and have suffix spans match the target. In the first example, the generated span (“Family friendly Indian food”) is wrong because it contains 1 more word than the target. Fortunately, this error doesn’t propagate to later spans, since Diffusion-LM adjusts by dropping the conjunction. Analogously, in the second example, the FT model generates a failed span (“The Mill”) that contains 1 fewer word. However, the FT model fails to adjust in the suffix, leading to many misaligned errors in the suffix.

7.2 Composition of Controls

One unique capability of plug-and-play controllable generation is its modularity. Given classifiers for multiple independent tasks, gradient guided control makes it simple to generate from the intersection of multiple controls by taking gradients on the sum of the classifier log-probabilities.
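Composition works because the gradient of a sum of log-probabilities is the sum of the individual gradients, so each classifier contributes its own term. The sketch uses two hypothetical logistic classifiers on a 3-dimensional latent; the weights and the latent are made up.

```python
import numpy as np

def grad_log_sigmoid(w, b, x):
    """Gradient w.r.t. x of log sigmoid(w . x + b), a logistic classifier's log-probability."""
    z = float(np.dot(w, x) + b)
    return (1.0 - 1.0 / (1.0 + np.exp(-z))) * w

def composed_control_grad(x, classifiers):
    """Gradient of sum_k log p_k(c_k | x); each classifier contributes its own term."""
    return sum(grad_log_sigmoid(w, b, x) for w, b in classifiers)

# Two hypothetical classifiers, e.g. one for semantic content and one for syntax.
clfs = [(np.array([1.0, 0.0, -0.5]), 0.2), (np.array([0.0, 2.0, 1.0]), -1.0)]
x = np.array([0.3, -0.4, 0.1])
g = composed_control_grad(x, clfs)
print(g)
```

This composed gradient would replace the single-classifier term in the per-step update; no classifier needs to be retrained for the combination.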

We evaluate this setting on the combination of Semantic Content + Syntax Tree control and Semantic Content + Parts-of-speech control. As shown in Table 4, our Diffusion-LM achieves a high success rate for both of the two components, whereas FUDGE gives up on the more global syntactic control. This is expected because FUDGE fails to control syntax on its own.

Fine-tuned models are good at POS and semantic content control individually but do not compose these two controls well by product of experts (PoE), leading to a large drop in success rates for both constraints.

7.3 Infilling Results

As shown in Table 5, Diffusion-LM significantly outperforms continuous relaxation based methods for infilling (COLD and DELOREAN). Moreover, our method achieves performance comparable to fine-tuning a specialized model for this task: our method has slightly better automatic evaluation scores, and the human evaluation found no statistically significant improvement for either method. These results suggest that Diffusion-LM can solve many types of controllable generation tasks that depend on generation order or lexical constraints (such as infilling) without specialized training.

7.4 Ablation Studies

We verify the importance of our proposed design choices in § 4 through two ablation studies. We measure the sample quality of Diffusion-LM using the lm-score (§ 6.2) on 500 samples.

Learned vs. Random Embeddings (§ 4.1). Learned embeddings outperform random embeddings on ROCStories, which is the harder language modeling task. The same trend holds for the E2E dataset, but with a smaller margin.

Objective Parametrization (§ 4.2). We propose to let the diffusion model predict $\mathbf{x}_0$ directly. Here, we compare this with the standard parametrization in image generation, which parametrizes by the noise term $\epsilon$. Figure 4 (right) shows that parametrizing by $\mathbf{x}_0$ consistently attains good performance across dimensions, whereas parametrizing by $\epsilon$ works fine for small dimensions but quickly collapses for larger dimensions.

Conclusion and Limitations

We proposed Diffusion-LM, a novel and controllable language model based on continuous diffusions, which enables new forms of complex fine-grained control tasks. We demonstrate Diffusion-LM’s success in 6 fine-grained control tasks: our method almost doubles the control success rate of prior methods and is competitive with baseline fine-tuning methods that require additional training.

We find the complex controls enabled by Diffusion-LM compelling, and we are excited that Diffusion-LM is a substantial departure from the current paradigm of discrete autoregressive generation. As with any new technology, the Diffusion-LMs we constructed have drawbacks: (1) they have higher perplexity; (2) decoding is substantially slower; and (3) training converges more slowly. We believe that with more follow-up work and optimization, many of these issues can be addressed, and that this approach will turn out to be a compelling way to do controllable generation at scale.

Acknowledgments and Disclosure of Funding

We thank Yang Song, Jason Eisner, Tianyi Zhang, Rohan Taori, Xuechen Li, Niladri Chatterji, and the members of the p-lambda group for early discussions and feedback. We gratefully acknowledge the support of a PECASE award. Xiang Lisa Li is supported by a Stanford Graduate Fellowship.

References

Appendix A Diffusion Noise Schedule

Because a diffusion model shares parameters across all diffusion steps, the noise schedule (parametrized by $\bar{\alpha}_{1:T}$) is an important hyperparameter that determines how much weight we assign to each denoising problem. We find that standard noise schedules for continuous diffusions are not robust for text data. We hypothesize that the discrete nature of text and the rounding step make the model insensitive to noise near $t=0$. Concretely, adding a small amount of Gaussian noise to a word embedding is unlikely to change its nearest neighbor in the embedding space, which makes denoising an easy task near $t=0$.
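A toy illustration of this hypothesis, under assumed settings (50 random Gaussian embeddings of dimension 16, not learned ones): perturb a word's embedding with Gaussian noise and check whether nearest-neighbor rounding still recovers the word. Small noise essentially never changes the nearest neighbor, while large noise often does. All names and sizes here are hypothetical.

```python
import random


def nearest_word(x: list[float], embeddings: list[list[float]]) -> int:
    """Rounding: index of the embedding closest to x in Euclidean distance."""
    return min(range(len(embeddings)),
               key=lambda i: sum((a - b) ** 2 for a, b in zip(x, embeddings[i])))


def fraction_unchanged(noise_std: float, vocab: int = 50, dim: int = 16,
                       trials: int = 200, seed: int = 0) -> float:
    """Share of noisy embeddings whose nearest neighbor is still the original word."""
    rng = random.Random(seed)
    emb = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(vocab)]  # toy random embeddings
    kept = 0
    for _ in range(trials):
        w = rng.randrange(vocab)
        noisy = [v + rng.gauss(0.0, noise_std) for v in emb[w]]
        kept += nearest_word(noisy, emb) == w
    return kept / trials


small, large = fraction_unchanged(0.1), fraction_unchanged(3.0)
print(small, large)
```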

To address this, we introduce a new sqrt noise schedule that is better suited for text, shown in Figure 5 and defined by $\bar{\alpha}_{t}=1-\sqrt{t/T+s}$, where $s$ is a small constant that corresponds to the starting noise level (we set $s=10^{-4}$ and $T=2000$, which sets the initial standard deviation to $0.1$). Compared to the standard linear and cosine schedules, our sqrt schedule starts with a higher noise level and increases the noise rapidly over the first 50 steps. It then slows down the rate of noise injection to avoid spending too many steps on high-noise problems, which may be too difficult to solve well.
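The schedule is easy to tabulate. The sketch below computes the noise standard deviation $\sqrt{1-\bar{\alpha}_t}$ of $q(\mathbf{x}_t\mid\mathbf{x}_0)$ for the sqrt schedule and, for comparison, for a cosine schedule. Two assumptions: $\bar{\alpha}_t$ is clamped at 0 because the raw sqrt formula is slightly negative at $t=T$, and the cosine schedule uses the common image-diffusion form with offset $0.008$, which may not match the exact baseline compared against.

```python
import math


def sqrt_alpha_bar(t: int, T: int = 2000, s: float = 1e-4) -> float:
    """sqrt schedule: alpha_bar_t = 1 - sqrt(t/T + s), clamped at 0 (assumption)."""
    return max(0.0, 1.0 - math.sqrt(t / T + s))


def cosine_alpha_bar(t: int, T: int = 2000, s: float = 0.008) -> float:
    """Cosine schedule from image diffusion (assumed form), normalized so alpha_bar_0 = 1."""
    def f(u: float) -> float:
        return math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0)


def noise_std(alpha_bar: float) -> float:
    # q(x_t | x_0) = N(sqrt(alpha_bar) x_0, (1 - alpha_bar) I), so the noise std is sqrt(1 - alpha_bar)
    return math.sqrt(1.0 - alpha_bar)


for t in (0, 10, 50, 500, 1000):
    print(t, round(noise_std(sqrt_alpha_bar(t)), 3), round(noise_std(cosine_alpha_bar(t)), 3))
```

At $t=0$ the sqrt schedule already has standard deviation $(10^{-4})^{1/4}=0.1$, matching the value stated above, while the cosine schedule starts noise-free.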

Appendix B Hyperparameters

The hyperparameters that are specific to Diffusion-LM include the number of diffusion steps, the architecture of the Diffusion-LM, the embedding dimension, and the noise schedule. We set the number of diffusion steps to $2000$, the architecture to BERT-base, and the sequence length to $64$. For the embedding dimension, we select from $d\in\{16,64,128,256\}$ and choose $d=16$ for the E2E dataset and $d=128$ for ROCStories. For the noise schedule, we design the sqrt schedule (Appendix A), which is more robust to different parametrizations and embedding dimensions, as shown in Appendix H. However, once we adopted the $\mathbf{x}_{0}$-parametrization (§ 4.2), the advantage of the sqrt schedule is no longer salient.

Training hyperparameters.

We train Diffusion-LMs using the AdamW optimizer with a linearly decaying learning rate starting at 1e-4, dropout of 0.1, and a batch size of 64. The total number of training iterations is 200K for the E2E dataset and 800K for the ROCStories dataset. Each Diffusion-LM is trained on a single GPU (NVIDIA RTX A5000, NVIDIA GeForce RTX 3090, or NVIDIA A100). Training for 200K iterations takes approximately 5 hours on a single A100 GPU.
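The settings in this appendix can be collected into a single configuration object. This is a sketch only: the class and field names are hypothetical, and the linear decay is assumed to reach zero at the final iteration, which the text does not state.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DiffusionLMConfig:
    """Hyperparameters from this appendix (field names are hypothetical)."""
    diffusion_steps: int = 2000
    seq_len: int = 64
    embed_dim: int = 16               # chosen from {16, 64, 128, 256}: 16 for E2E, 128 for ROCStories
    noise_schedule: str = "sqrt"
    backbone: str = "bert-base"
    peak_lr: float = 1e-4             # AdamW, linearly decayed
    dropout: float = 0.1
    batch_size: int = 64
    total_iters: int = 200_000        # 200K for E2E, 800K for ROCStories
    grad_clip: Optional[float] = None  # set to 1.0 when training with L_vlb^e2e


def linear_decay_lr(cfg: DiffusionLMConfig, step: int) -> float:
    """Learning rate at `step`, assuming a linear decay from peak_lr to 0 over total_iters."""
    frac = min(step, cfg.total_iters) / cfg.total_iters
    return cfg.peak_lr * (1.0 - frac)


E2E = DiffusionLMConfig()
ROCSTORIES = DiffusionLMConfig(embed_dim=128, total_iters=800_000)
```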

To stabilize training under the $\mathcal{L}_{\text{vlb}}^{\text{e2e}}$ objective, we find that we need to set gradient clipping to 1.0 and apply importance sampling to reweight each term in $\mathcal{L}_{\text{vlb}}$. Neither trick is necessary for the $\mathcal{L}_{\text{simple}}^{\text{e2e}}$ objective.
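A minimal sketch of both tricks. The importance sampler is assumed to be a loss-aware timestep sampler in the style of Nichol & Dhariwal (sampling step $t$ with probability proportional to the root mean squared loss of that step and weighting by $1/(T p_t)$ so the estimate stays unbiased); the citation in the text is missing, so this is an assumption, and the mixing with a uniform distribution often used in practice is omitted. Function names are hypothetical.

```python
import math
import random


def timestep_probs(mean_sq_loss: list[float]) -> list[float]:
    """p_t proportional to sqrt(E[L_t^2]), estimated from a running history of per-step losses."""
    scores = [math.sqrt(m) for m in mean_sq_loss]
    total = sum(scores)
    return [s / total for s in scores]


def sample_timestep(probs: list[float], rng: random.Random) -> tuple[int, float]:
    """Draw t ~ p and return the weight 1 / (T * p_t); E[weight * L_t] equals the uniform mean of L_t."""
    t = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return t, 1.0 / (len(probs) * probs[t])


def clip_by_global_norm(grads: list[float], max_norm: float = 1.0) -> list[float]:
    """Rescale the whole gradient vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grads]


probs = timestep_probs([1.0, 4.0, 9.0])  # noisier (higher-variance) steps are sampled more often
t, weight = sample_timestep(probs, random.Random(0))
clipped = clip_by_global_norm([3.0, 4.0])  # norm 5 is rescaled to norm 1
```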

Controllable Generation hyperparameters.

To achieve controllable generation, we run gradient updates on the continuous latents of Diffusion-LM. We use the AdaGrad optimizer to update the latent variables, and we tune the learning rate $\text{lr}\in\{0.05,0.1,0.15,0.2\}$ and the trade-off parameter $\lambda\in\{0.1,0.01,0.001,0.0005\}$. Different plug-and-play controllable generation approaches trade off fluency and control by tuning different hyperparameters: PPLM uses the number of gradient updates per token, denoted $k$, and we tune $k\in\{10,30\}$; FUDGE uses the trade-off parameter $\lambda_{\text{FUDGE}}$, and we tune $\lambda_{\text{FUDGE}}\in\{16,8,4,2\}$. Table 6 lists the selected hyperparameters for each control task. Both PPLM and FUDGE have additional hyperparameters, which we set following the instructions in the original papers. For PPLM, we set the learning rate to 0.04 and the KL-scale to 0.01. For FUDGE, we set the precondition top-K to 200 and the post top-K to 10.
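A minimal sketch of the latent update for a single diffusion step, assuming the objective from § 5.1 has the form $\lambda\log p(\mathbf{x}_{t-1}\mid\mathbf{x}_t)+\log p(\mathbf{c}\mid\mathbf{x}_{t-1})$, maximized by gradient ascent with AdaGrad. To stay self-contained, both terms are toy quadratics with closed-form gradients: a Gaussian denoising term with mean `mu` and variance `sigma2`, and a hypothetical classifier that prefers latents near `target`. A real implementation would obtain these gradients by automatic differentiation through $f_\theta$ and the neural classifier.

```python
import math


def controlled_update(x, mu, sigma2, target, lam=0.01, lr=0.1, n_steps=3, eps=1e-10):
    """AdaGrad ascent on lam * log p(x | x_t) + log p(c | x) for one diffusion step.

    Toy stand-ins (assumptions): log p(x | x_t) = -||x - mu||^2 / (2 * sigma2) and
    log p(c | x) = -||x - target||^2 / 2.
    """
    x = list(x)
    accum = [0.0] * len(x)  # AdaGrad: running sum of squared gradients per coordinate
    for _ in range(n_steps):  # 3 steps, matching the 3 classifier calls per diffusion step in Appendix C
        for i in range(len(x)):
            grad = lam * (-(x[i] - mu[i]) / sigma2) - (x[i] - target[i])
            accum[i] += grad * grad
            x[i] += lr * grad / (math.sqrt(accum[i]) + eps)
    return x


x_new = controlled_update([0.0, 0.0], mu=[0.0, 0.0], sigma2=1.0, target=[1.0, -1.0])
```

Raising $\lambda$ pulls the update back toward the denoiser's proposal `mu` (favoring fluency), while lowering it lets the classifier term dominate; this is the fluency/control trade-off that the grid over $\lambda$ explores.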

Appendix C Decoding Speed

Sampling from Diffusion-LMs requires iterating through the 2000 diffusion steps, yielding $O(2000)$ $f_{\theta}$ model calls. In contrast, sampling from autoregressive LMs takes $O(n)$ calls, where $n$ is the sequence length. Therefore, decoding Diffusion-LM is slower than decoding autoregressive LMs for short and medium-length sequences. Concretely, it takes around 1 minute to decode 50 sequences of length 64.

To speed up decoding, we tried skipping steps in the generative diffusion process, downsampling 2000 steps to 200 steps. Concretely, we set $T=200$ and downsample the noise schedule to $\bar{\alpha}'_{t}=\bar{\alpha}_{10t}$ (primes denote the new 200-step schedule), so that each unit transition of the new schedule corresponds to the transition $\mathbf{x}_{t}\rightarrow\mathbf{x}_{t+10}$ of the original process. We decode Diffusion-LM using this new noise schedule and discretization. We find that this naive approach does not hurt sample quality for simple language modeling tasks like E2E, but it does hurt sample quality for harder language modeling tasks like ROCStories.
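A sketch of this downsampling for the sqrt schedule: the 200-step schedule keeps every 10th value of $\bar{\alpha}$, and each coarse step's $\beta$ is recomputed from the ratio of consecutive retained values, since $\bar{\alpha}_t=\prod_{s\le t}(1-\beta_s)$. Clamping $\bar{\alpha}$ at 0 at the final step is an assumption, and the function names are hypothetical.

```python
import math


def sqrt_alpha_bar(t: int, T: int = 2000, s: float = 1e-4) -> float:
    # sqrt schedule from Appendix A, clamped at 0 (assumption) since the raw value dips below 0 at t = T
    return max(0.0, 1.0 - math.sqrt(t / T + s))


def respace(T_full: int = 2000, stride: int = 10) -> tuple[list[float], list[float]]:
    """Keep alpha_bar at every `stride`-th step and recompute per-step betas for the coarse chain."""
    alpha_bar = [sqrt_alpha_bar(stride * k, T_full) for k in range(T_full // stride + 1)]
    # alpha_bar'_k = alpha_bar'_{k-1} * (1 - beta'_k)  =>  beta'_k = 1 - alpha_bar'_k / alpha_bar'_{k-1}
    betas = [1.0 - alpha_bar[k] / alpha_bar[k - 1] for k in range(1, len(alpha_bar))]
    return alpha_bar, betas


alpha_bar_200, betas_200 = respace()
print(len(betas_200))  # 200
```

Each coarse $\beta'_k$ absorbs the noise of ten original steps, so the decoder takes larger denoising jumps; this is consistent with the observation that the approach is harmless on E2E but hurts on the harder ROCStories task.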

For plug-and-play controllable generation tasks, extant approaches are even slower. PPLM takes around 80 minutes to generate 50 samples (without batching), because it needs to run 30 gradient updates for each token. FUDGE takes 50 seconds to generate 50 samples (with batching), because it needs to call the lightweight classifier on each partial sequence, requiring 200 classifier calls per token, i.e., $200\times$ the sequence length in total. The classifier calls can be batched, but limited GPU memory sometimes restricts batching across samples. Our Diffusion-LM takes around 80 seconds to generate 50 samples (with batching). Our method downsamples the number of diffusion steps to 200 and makes 3 classifier calls per diffusion step, yielding 600 model calls in total.
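The call counts above reduce to simple arithmetic. The sketch below evaluates them for sequence lengths 16 and the length 64 used in Appendix B (function names are hypothetical). The contrast it shows is that the Diffusion-LM count is fixed by the number of diffusion steps, since each classifier call scores the whole sequence, while the FUDGE and PPLM counts grow with sequence length.

```python
def diffusion_lm_calls(diffusion_steps: int = 200, calls_per_step: int = 3) -> int:
    """Classifier calls for Diffusion-LM: fixed by the (downsampled) number of diffusion steps."""
    return diffusion_steps * calls_per_step


def fudge_calls(seq_len: int, candidates_per_token: int = 200) -> int:
    """FUDGE scores every candidate continuation of each partial sequence."""
    return seq_len * candidates_per_token


def pplm_updates(seq_len: int, updates_per_token: int = 30) -> int:
    """PPLM runs a fixed number of gradient updates for every generated token."""
    return seq_len * updates_per_token


for n in (16, 64):
    print(n, diffusion_lm_calls(), fudge_calls(n), pplm_updates(n))
# 16 600 3200 480
# 64 600 12800 1920
```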

Appendix D Classifiers for Classifier-Guided Controls

Semantic Content. We train an autoregressive LM (GPT-2 small architecture) to predict the (field, value) pair conditioned on text. To parametrize $\log p(\mathbf{c}\mid\mathbf{x}_{t})$, we compute the per-token log-probability of the “value”.

Parts-of-speech. The classifier is parametrized by a parts-of-speech tagger, which estimates the probability of the target POS sequence conditioned on the latent variables. This tagger uses a BERT-base architecture: the input is the concatenated word embeddings, and the output is a softmax distribution over all POS tags for each input word. $\log p(\mathbf{c}\mid\mathbf{x}_{t})$ is the sum of the POS log-probabilities of the words in the sequence.
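A minimal sketch of how such a tagger yields $\log p(\mathbf{c}\mid\mathbf{x}_t)$: apply a log-softmax to each word's tag scores and sum the entries at the target tags. The scores are plain lists standing in for the tagger's outputs, and the tag set and indices are illustrative assumptions.

```python
import math


def log_softmax(scores: list[float]) -> list[float]:
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def pos_control_logprob(tag_scores: list[list[float]], target_tags: list[int]) -> float:
    """Sum over words of log p(target tag | latents); tag_scores[i] are the tagger's logits for word i."""
    return sum(log_softmax(scores)[tag] for scores, tag in zip(tag_scores, target_tags))


# Three words, four possible tags (e.g., DET, NOUN, VERB, ADJ); uniform scores give 3 * log(1/4).
uniform = [[0.0] * 4 for _ in range(3)]
print(pos_control_logprob(uniform, [0, 1, 2]))
```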

Syntax Tree. We train a Transformer-based constituency parser. Our parser makes a locally normalized prediction for each span, predicting either “not a constituent” or a label for the constituent (e.g., Noun Phrase). $\log p(\mathbf{c}\mid\mathbf{x}_{t})$ is the sum of the log-probabilities of the target annotation (a constituent label or “not a constituent”) over all spans in the sequence.

Syntax Span. We use the same parser trained for the syntax tree. $\log p(\mathbf{c}\mid\mathbf{x}_{t})$ is the log-probability that the target span is annotated with the target label.
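Both syntactic classifiers can be read off the same table of per-span label distributions. The sketch assumes the parser returns, for every span $(i,j)$, scores over a label set whose entry `"NULL"` means “not a constituent”; the tree score sums the log-probability of the target label over all spans, while span control reads a single entry. The label set and data layout are hypothetical.

```python
import math

LABELS = ["NULL", "NP", "VP"]  # "NULL" = not a constituent (hypothetical label set)


def log_softmax(scores):
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def tree_logprob(span_scores, target_tree):
    """Sum over all spans of log p(target label); spans absent from target_tree are 'not a constituent'."""
    total = 0.0
    for span, scores in span_scores.items():
        label = target_tree.get(span, "NULL")
        total += log_softmax(scores)[LABELS.index(label)]
    return total


def span_logprob(span_scores, span, label):
    """Log-probability that one target span carries the target label."""
    return log_softmax(span_scores[span])[LABELS.index(label)]


n = 3  # sequence length; spans are (i, j) with 0 <= i < j <= n
scores = {(i, j): [0.0, 0.0, 0.0] for i in range(n) for j in range(i + 1, n + 1)}
tree = {(0, 3): "VP", (0, 1): "NP"}
```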

Appendix E End-to-end Objective Derivations

For continuous diffusion models (§ 3.3), $\mathcal{L}_{\text{simple}}$ is derived from the canonical objective $\mathcal{L}_{\text{vlb}}$ by reweighting each term. The first $T$ terms in $\mathcal{L}_{\text{vlb}}$ are all KL divergences between two Gaussian distributions, which have closed-form solutions. Take the $t$-th term as an example:
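For reference, assume the reverse transition has a fixed isotropic variance, $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)=\mathcal{N}(\mu_\theta(\mathbf{x}_t,t),\sigma_t^2\mathbf{I})$, and write $\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)$ for the mean of the Gaussian posterior $q(\mathbf{x}_{t-1}\mid\mathbf{x}_0,\mathbf{x}_t)$. Under these assumptions the standard closed form of such a term is:

$$
\mathbb{E}_{q}\Big[D_{\mathrm{KL}}\big(q(\mathbf{x}_{t-1}\mid\mathbf{x}_0,\mathbf{x}_t)\,\big\Vert\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\big)\Big]
=\mathbb{E}_{q}\Big[\frac{1}{2\sigma_t^2}\big\lVert\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)-\mu_\theta(\mathbf{x}_t,t)\big\rVert^2\Big]+C,
$$

where $C$ does not depend on $\theta$. Dropping the weight $1/(2\sigma_t^2)$ leaves the unweighted squared error used in $\mathcal{L}_{\text{simple}}$.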

To apply continuous diffusion to discrete text, we design Diffusion-LM (§ 4.1) and propose a new end-to-end training objective (equation 2) that learns the diffusion model and the embedding parameters jointly. The $\mathcal{L}_{\text{vlb}}^{\text{e2e}}$ objective can be written out as

We apply the same simplification that transforms $\mathcal{L}_{\text{vlb}}\rightarrow\mathcal{L}_{\text{simple}}$ to transform $\mathcal{L}^{\text{e2e}}_{\text{vlb}}\rightarrow\mathcal{L}^{\text{e2e}}_{\text{simple}}$:
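As a sketch of where this simplification leads, assume $q_\phi(\mathbf{x}_0\mid\mathbf{w})$ is centered at the word embeddings $\mathrm{EMB}(\mathbf{w})$, $\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)$ denotes the posterior mean, and every KL term is replaced by its unweighted squared-error counterpart. The resulting objective then takes the form:

$$
\mathcal{L}^{\text{e2e}}_{\text{simple}}(\mathbf{w})=\mathbb{E}_{q_\phi(\mathbf{x}_{0:T}\mid\mathbf{w})}\Big[\big\lVert\sqrt{\bar{\alpha}_T}\,\mathbf{x}_0\big\rVert^2+\sum_{t=2}^{T}\big\lVert\hat{\mu}(\mathbf{x}_t,\mathbf{x}_0)-\mu_\theta(\mathbf{x}_t,t)\big\rVert^2+\big\lVert\mathrm{EMB}(\mathbf{w})-\mu_\theta(\mathbf{x}_1,1)\big\rVert^2-\log p_\theta(\mathbf{w}\mid\mathbf{x}_0)\Big]
$$

The first term comes from the prior term $D_{\mathrm{KL}}\big(q(\mathbf{x}_T\mid\mathbf{x}_0)\,\Vert\,\mathcal{N}(\mathbf{0},\mathbf{I})\big)$, whose $\mathbf{x}_0$-dependent part is $\tfrac{1}{2}\bar{\alpha}_T\lVert\mathbf{x}_0\rVert^2$; the last term is the rounding term.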

Note that the first term is constant if the noise schedule satisfies $\bar{\alpha}_{T}=0$, which guarantees that $\mathbf{x}_{T}$ is pure Gaussian noise. In contrast, if the noise schedule does not go all the way to pure Gaussian noise at $\mathbf{x}_{T}$, we need to include this regularization term to prevent the embeddings from learning excessively large norms. Embeddings with large norms are a degenerate solution: they make all the other denoising transitions easily predictable, but they make it impossible to sample from $p(\mathbf{x}_{T})$ accurately.

Combining these terms yields $\mathcal{L}^{\text{e2e}}_{\text{simple}}$.

There are other parametrizations that are equivalent to the $\mu_{\theta}$-parametrization up to a scaling constant. For example, in § 4.2 we can train the Transformer model to directly predict $\mathbf{x}_{0}$ via $f_{\theta}(\mathbf{x}_{t},t)$, and use the tractable Gaussian posterior $q(\mathbf{x}_{t-1}\mid\mathbf{x}_{0},\mathbf{x}_{t})$ to compute the mean of $\mathbf{x}_{t-1}$ conditioned on the predicted $\mathbf{x}_{0}$ and the observed $\mathbf{x}_{t}$, which has the closed form $\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}\mathbf{x}_{0}+\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}\mathbf{x}_{t}$.
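A direct transcription of this closed form, with plain lists standing in for tensors and $\alpha_t=\bar{\alpha}_t/\bar{\alpha}_{t-1}$, $\beta_t=1-\alpha_t$ (function names are hypothetical). A useful sanity check: if $\mathbf{x}_t$ equals $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ exactly (no noise), the posterior mean reduces to $\sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0$.

```python
import math


def posterior_mean_coeffs(alpha_bar_t: float, alpha_bar_prev: float) -> tuple[float, float]:
    """Coefficients of x_0 and x_t in the mean of q(x_{t-1} | x_0, x_t)."""
    alpha_t = alpha_bar_t / alpha_bar_prev
    beta_t = 1.0 - alpha_t
    c_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    c_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return c_x0, c_xt


def posterior_mean(x0_pred, x_t, alpha_bar_t, alpha_bar_prev):
    """Mean of x_{t-1} given a predicted x_0 (e.g. f_theta(x_t, t)) and the observed x_t."""
    c_x0, c_xt = posterior_mean_coeffs(alpha_bar_t, alpha_bar_prev)
    return [c_x0 * a + c_xt * b for a, b in zip(x0_pred, x_t)]


# Noise-free check: x_t = sqrt(0.5) * x_0 gives a posterior mean of sqrt(0.8) * x_0.
x0 = [1.0, -2.0]
mean = posterior_mean(x0, [math.sqrt(0.5) * v for v in x0], alpha_bar_t=0.5, alpha_bar_prev=0.8)
```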

These two parametrizations differ by a constant scaling, and we apply the $\mathbf{x}_{0}$-parametrization to all terms in $\mathcal{L}^{\text{e2e}}_{\text{simple}}$ to reduce rounding errors, as discussed in § 4.2:

To generate samples from a Diffusion-LM with the $\mathbf{x}_{0}$-parametrization, at each diffusion step the model estimates $\mathbf{x}_{0}$ via $f_{\theta}(\mathbf{x}_{t},t)$, and we then sample $\mathbf{x}_{t-1}$ from $q(\mathbf{x}_{t-1}\mid f_{\theta}(\mathbf{x}_{t},t),\mathbf{x}_{t})$, which is fed as input to the next diffusion step.
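The sampling loop can be sketched as follows, under stated assumptions: a toy linear $\beta$ schedule with $T=50$ (not the sqrt schedule), the convention $\bar{\alpha}_0=1$, and a hypothetical stand-in `f_theta` for the trained Transformer. The posterior variance $\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$ is the standard Gaussian-posterior value, and the rounding step that maps the final vectors to words is omitted.

```python
import math
import random


def linear_alpha_bars(T: int, beta_start: float = 1e-4, beta_end: float = 0.05) -> list[float]:
    """Toy schedule (assumption): linear betas; alpha_bars[0] = 1 so that x_0 is noise-free."""
    alpha_bars = [1.0]
    for t in range(1, T + 1):
        beta = beta_start + (beta_end - beta_start) * (t - 1) / (T - 1)
        alpha_bars.append(alpha_bars[-1] * (1.0 - beta))
    return alpha_bars


def sample_x0_param(f_theta, dim: int, T: int = 50, seed: int = 0) -> list[float]:
    rng = random.Random(seed)
    ab = linear_alpha_bars(T)
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        x0_hat = f_theta(x, t)  # the model estimates x_0 directly
        alpha_t = ab[t] / ab[t - 1]
        beta_t = 1.0 - alpha_t
        c_x0 = math.sqrt(ab[t - 1]) * beta_t / (1.0 - ab[t])
        c_xt = math.sqrt(alpha_t) * (1.0 - ab[t - 1]) / (1.0 - ab[t])
        var = (1.0 - ab[t - 1]) / (1.0 - ab[t]) * beta_t  # posterior variance of q(x_{t-1} | x_0, x_t)
        # sample x_{t-1} ~ q(x_{t-1} | x0_hat, x_t); it becomes the input to the next step
        x = [c_x0 * a + c_xt * b + math.sqrt(var) * rng.gauss(0.0, 1.0) for a, b in zip(x0_hat, x)]
    return x


# Hypothetical oracle that always predicts the same x_0.
target = [0.5, -1.0, 2.0]
sample = sample_x0_param(lambda x, t: list(target), dim=3)
```

With $\bar{\alpha}_0=1$, the last step has $\tilde{\beta}_1=0$ and a zero coefficient on $\mathbf{x}_1$, so the loop returns the model's final $\mathbf{x}_0$ estimate exactly; with the oracle above, `sample` equals `target`.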

Appendix F Log-Likelihood Models and Results

To investigate Diffusion-LM’s log-likelihood performance, we make several departures from the training procedure of § 4. Ultimately the log-likelihood improvements described in this section did not translate into better generation quality in our experiments and therefore we focus on the original method in the rest of the paper. Our likelihood models are trained as follows:

Instead of training a diffusion model on sequences of low-dimensional token embeddings, we train a model directly on sequences of one-hot token vectors.

Following the setup of Kingma et al., we train a continuous-time diffusion model against the log-likelihood bound and learn the noise schedule simultaneously with the rest of the model to minimize the loss variance.

Because our model predicts sequences of one-hot vectors, we use a softmax nonlinearity at its output and replace all squared-error terms in the loss function with cross-entropy terms. This choice of surrogate loss led to better optimization, even though we evaluate against the original loss with squared-error terms.
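At a single sequence position, the substitution looks like this: with softmax outputs $\mathbf{p}$ and one-hot target $\mathbf{e}_y$, the squared error $\lVert\mathbf{p}-\mathbf{e}_y\rVert^2$ used for evaluation is replaced during training by the cross-entropy $-\log p_y$. A sketch with hypothetical function names; any per-timestep weighting of the loss is omitted.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def squared_error(probs: list[float], target_idx: int) -> float:
    """Evaluation loss: squared distance between the softmax output and the one-hot target."""
    return sum((p - (1.0 if i == target_idx else 0.0)) ** 2 for i, p in enumerate(probs))


def cross_entropy(logits: list[float], target_idx: int) -> float:
    """Training surrogate: -log softmax(logits)[target_idx], computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target_idx]


logits = [0.0, 0.0, 0.0, 0.0]  # uniform prediction over a 4-token vocabulary
print(cross_entropy(logits, 1), squared_error(softmax(logits), 1))
```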

At inference time, we omit the rounding procedure in § 4.2.

For exact model architecture and training hyperparameter details, please refer to our released code.

We train these diffusion models, as well as baseline autoregressive Transformers, on E2E and ROCStories and report log-likelihoods in Table 7. We train two sizes of Transformers: “small” models with roughly 100M parameters and “medium” models with roughly 300M parameters. Both E2E and ROCStories are small enough datasets that all of our models reach their minimum test loss early in training (and overfit after that). To additionally compare model performance in a large-dataset regime, we also present “ROCStories (+GPT-J)” experiments, in which we generate 8M examples of synthetic ROCStories training data by finetuning GPT-J on the original ROCStories data, pretrain our models on the synthetic dataset, and then finetune and evaluate them on the original ROCStories data.

Appendix G Qualitative Examples

We show randomly sampled outputs of Diffusion-LM both for unconditional generation and for the 5 control tasks. Table 8 shows the unconditional generation results. Table 9, Table 10, Table 12, and Table 3 show qualitative samples from span control, POS control, semantic content control, and syntax tree control, respectively. Table 11 shows the results of length control.

Appendix H Additional Ablation Studies

In addition to the 2 ablation studies in § 7.4, we provide more ablation results in Figure 6 about architecture choices and noise schedule.

Learned vs. Random Embeddings (§ 4.1). Learned embeddings outperform random embeddings on both ROCStories and the E2E dataset, by xx percent and xx percent respectively, as shown in the first row of Figure 6.

Noise Schedule (Appendix A). We compare the sqrt schedule with the cosine and linear schedules proposed for image modeling. The middle row of Figure 6 demonstrates that the sqrt schedule attains consistently good and stable performance across all dimension and parametrization choices. While the sqrt schedule is less important with the $\mathbf{x}_{0}$-parametrization, it provides a substantially more robust noise schedule under alternative parametrizations such as $\epsilon$.

The U-Net architecture in Ho et al. uses 2D-convolutional layers; we replicate this architecture except that we replace the 2D convolutions with 1D convolutions, which are suited to text data. Figure 6 (last row) shows that the Transformer architecture outperforms U-Net.

Appendix I Societal Impacts

On the one hand, strong controllability in language models will help mitigate toxicity, making language models more reliable to deploy. Controllability could also be used to make the model more truthful, reducing the inaccurate information it generates. On the other hand, one could also imagine more powerful targeted disinformation (e.g., narrative wedging) derived from fine-grained controllability.

Towards this end, it might be worth considering generation methods that can watermark generated outputs without affecting their fluency; this type of watermarking could itself be framed as a controllable generation problem, with distinguishability and fluency as the constraints.
