Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models
Rohit Prabhavalkar, Tara N. Sainath, Yonghui Wu, Patrick Nguyen, Zhifeng Chen, Chung-Cheng Chiu, Anjuli Kannan
Introduction
There has been growing interest in the automatic speech recognition (ASR) community in building end-to-end trained, sequence-to-sequence models which directly output a word sequence given input speech frames, without requiring explicit alignments between the speech frames and labels. Examples of such approaches include the recurrent neural network transducer (RNN-T), the recurrent neural aligner (RNA), attention-based models, and connectionist temporal classification (CTC) with word-based targets. Such approaches are motivated by their simplicity: since these models directly output graphemes, word-pieces, or words, they do not require expertly curated pronunciation dictionaries; since they can be trained to directly output normalized text, they do not require separate modules to map recognized text from the spoken to the written domain. In our recent work, we have shown that such approaches are comparable to traditional state-of-the-art speech recognition systems.
Most sequence-to-sequence models are trained to optimize the cross-entropy (CE) loss function, which corresponds to maximizing the log-likelihood of the training data. During inference, however, model performance is commonly measured using task-specific criteria, not log-likelihood: e.g., word error rate (WER) for ASR, or BLEU score for machine translation. Traditional ASR systems account for this mismatch through discriminative sequence training of neural network acoustic models (AMs), which fine-tunes a cross-entropy trained AM with criteria such as state-level minimum Bayes risk (sMBR) that are more closely related to word error rate.
In the context of sequence-to-sequence models, there have been a few previous proposals to optimize task-specific losses. In their seminal work, Graves and Jaitly minimize expected WER of an RNN-T model by approximating the expectation with samples drawn from the model. This approach is similar to the edit-based minimum Bayes risk (EMBR) approach proposed by Shannon, which was used for minimum expected WER training of conventional ASR systems and the recurrent neural aligner. An alternative approach is based on reinforcement learning, where the label output at each step can be viewed as an action, so that the task of learning consists of learning the optimal policy (i.e., optimal output label sequence) which results in the greatest expected reward (lowest expected task-specific loss). Ranzato et al. apply a variant of the REINFORCE algorithm to optimize task-specific losses for summarization and machine translation. More recently, Bahdanau et al. use an actor-critic approach, which was shown to improve BLEU scores for machine translation.
In the present work, we consider techniques to optimize attention-based sequence-to-sequence models in order to directly minimize WER. Our proposed approach is similar to the expected-WER approaches above in that we approximate the expected WER using hypotheses from the model. We consider both sampling-based approaches and approximating the loss over N-best lists of recognition hypotheses, as is commonly done in ASR. However, unlike Sak et al., we find that the process is more effective if we approximate the expectation using N-best hypotheses decoded from the model using beam search rather than sampling from the model (see Section 5.1). We apply the proposed techniques on an English mobile voice-search task, to optimize grapheme-based models with uni- and bi-directional encoders, where we find that we can improve WER by up to 8.2% relative to a CE-trained baseline model. Minimum word error rate training allows us to train grapheme-based sequence-to-sequence models which are comparable in performance to a strong state-of-the-art context-dependent (CD) phoneme-based speech recognition system.
The organization of the rest of the paper is as follows. We describe the particular attention-based model used in this work in Section 2 and describe the proposed approach for minimum WER training of attention models in Section 3. We describe our experimental setup and our results in Sections 4 and 5, respectively, before concluding in Section 6.
Attention-Based Models
An attention-based model consists of three components: an encoder network which maps input acoustic vectors into a higher-level representation, an attention model which summarizes the output of the encoder based on the current state of the decoder, and a decoder network which models an output distribution over the next target conditioned on the sequence of previous predictions: $P(y_u \mid y_{u-1}, \ldots, y_0, \mathbf{x})$. The model is depicted in Figure 1. The encoder network consists of a deep recurrent neural network which receives as input the sequence of acoustic feature vectors, $\mathbf{x} = (\mathbf{x}_1, \ldots, \mathbf{x}_T)$, and computes a sequence of encoded features, $\mathbf{h}^{\text{enc}} = (\mathbf{h}^{\text{enc}}_1, \ldots, \mathbf{h}^{\text{enc}}_T)$; it is analogous to an acoustic model in a traditional ASR system. The decoder network, which is analogous to the pronunciation and language modeling components in a traditional ASR system, consists of a deep recurrent neural network which is augmented with an attention mechanism. The decoder network predicts a single label at each step, conditioned on the history of previous predictions. At each prediction step, the attention mechanism summarizes the encoded features based on the decoder state to compute a context vector, $\mathbf{c}_u$, as described in Section 2.1. The attention model thus corresponds to the component of a traditional ASR system which learns the alignments between the input acoustics and the output labels. This context vector is input to the decoder along with the previous label, $y_{u-1}$. The final decoder layer produces a set of logits which are input to a softmax layer which computes a distribution over the set of output labels: $P(y_u \mid y_{u-1}, \ldots, y_0, \mathbf{x})$.
The individual attention values are then transformed into soft attention weights through a softmax operation, and used to compute a summary of the encoder features, :
The matrices and the vector, , are parameters of the model. Finally, the overall context vector is computed by concatenating together the individual summaries: .
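The softmax-and-summarize step for several attention heads can be illustrated with a minimal sketch. It assumes the per-head attention values have already been computed; the scoring function and any per-head projection of the encoder features are left out, and the names `head_summary` and `context_vector` are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def head_summary(scores, enc_feats):
    """Weighted sum of encoder features for one attention head."""
    weights = softmax(scores)  # soft attention weights over the T frames
    dim = len(enc_feats[0])
    return [sum(w * h[d] for w, h in zip(weights, enc_feats)) for d in range(dim)]

def context_vector(per_head_scores, enc_feats):
    """Concatenate the per-head summaries into one context vector."""
    context = []
    for scores in per_head_scores:
        context.extend(head_summary(scores, enc_feats))
    return context

# Toy input (assumption): T=3 encoder frames with 2-dim features, 2 heads.
enc_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
per_head_scores = [[0.0, 0.0, 0.0],       # head 1 attends uniformly
                   [10.0, -10.0, -10.0]]  # head 2 attends almost only to frame 1
c = context_vector(per_head_scores, enc_feats)
```

With two heads over 2-dimensional features, the concatenated context vector has four entries: the first two are the uniform average of the frames, and the last two are close to the first frame.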
2.2 Training and Inference
Most attention-based models are trained by optimizing the cross-entropy (CE) loss function, which maximizes the log-likelihood of the training data:

$$\mathcal{L}_{\text{CE}} = \sum_{(\mathbf{x}, \mathbf{y}^*)} \sum_{u} -\log P(y^*_u \mid y^*_{u-1}, \ldots, y^*_0, \mathbf{x})$$

where $\mathbf{y}^*$ denotes the ground-truth label sequence for the utterance $\mathbf{x}$, which we always input during training (i.e., we do not use scheduled sampling). Inference in the model is performed using a beam-search algorithm, where the model's predictions are fed back until the model outputs the <eos> symbol, which indicates that inference is complete.
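The inference procedure can be sketched as a small beam search. This is an illustrative sketch rather than the decoder used in the paper: `step_fn` is a hypothetical stand-in for the decoder network, `toy_step` is a made-up three-label model, and finished hypotheses are kept in a separate list once they emit the end-of-sentence label.

```python
import math

EOS = "<eos>"

def beam_search(step_fn, beam_size, max_len):
    """Return up to beam_size finished hypotheses as (labels, log_prob), best first.

    step_fn(prefix) maps a tuple of previously emitted labels to a dict
    {label: probability}; it plays the role of the decoder (assumption).
    """
    beam = [((), 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, logp in beam:
            # Feed the previous predictions back in and expand by one label.
            for label, p in step_fn(prefix).items():
                if p > 0.0:
                    candidates.append((prefix + (label,), logp + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beam = []
        for hyp, logp in candidates[:beam_size]:
            if hyp[-1] == EOS:
                finished.append((hyp, logp))
            else:
                beam.append((hyp, logp))
        if not beam:
            break
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_size]

def toy_step(prefix):
    """Made-up next-label distribution, for illustration only."""
    if len(prefix) == 0:
        return {"a": 0.6, "b": 0.4}
    if len(prefix) == 1:
        return {"a": 0.1, "b": 0.2, EOS: 0.7}
    return {EOS: 1.0}

nbest = beam_search(toy_step, beam_size=2, max_len=4)
```

The returned list doubles as an N-best list: each entry carries the hypothesis and its log-probability under the model.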
Minimum Word Error Rate Training of Attention-based Models
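In this section we describe how an attention-based model can be trained to minimize the expected number of word errors, and thus the word error rate. We denote by $\mathcal{W}(\mathbf{y}, \mathbf{y}^*)$ the number of word errors in a hypothesis, $\mathbf{y}$, relative to the ground-truth sequence, $\mathbf{y}^*$. In order to minimize word error rates on test data, we consider as our loss function the expected number of word errors over the training set:

$$\mathcal{L}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) = \mathbb{E}\left[\mathcal{W}(\mathbf{y}, \mathbf{y}^*)\right] = \sum_{\mathbf{y}} P(\mathbf{y} \mid \mathbf{x})\, \mathcal{W}(\mathbf{y}, \mathbf{y}^*) \tag{4}$$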
Computing the loss in (4) exactly is intractable since it involves a summation over all possible label sequences. We therefore consider two possible approximations which ensure tractability: approximating the expectation in (4) with samples, or restricting the summation to an N-best list as is commonly done during sequence training for ASR.
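To make the quantity being minimized concrete, the following sketch counts word errors with a word-level edit distance and computes the expected number of word errors exactly for a toy distribution over three hypotheses. The hypotheses and their probabilities are invented for illustration; for a real model, the sum would run over all label sequences, which is what makes the exact computation intractable.

```python
def word_errors(hyp, ref):
    """Word-level Levenshtein distance: substitutions + insertions + deletions."""
    h, r = hyp.split(), ref.split()
    prev = list(range(len(r) + 1))
    for i in range(1, len(h) + 1):
        cur = [i] + [0] * len(r)
        for j in range(1, len(r) + 1):
            sub = prev[j - 1] + (h[i - 1] != r[j - 1])
            cur[j] = min(sub, prev[j] + 1, cur[j - 1] + 1)
        prev = cur
    return prev[len(r)]

def expected_word_errors(dist, ref):
    """Sum of P(hyp) * word_errors(hyp, ref) over an explicit distribution."""
    return sum(p * word_errors(hyp, ref) for hyp, p in dist.items())

ref = "play some music"
# Toy distribution over hypotheses (assumption, not model output).
dist = {"play some music": 0.5, "play sum music": 0.3, "lay some": 0.2}
expected = expected_word_errors(dist, ref)
```

Here the three hypotheses have 0, 1, and 2 word errors, so the expectation is 0.5·0 + 0.3·1 + 0.2·2 = 0.7.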
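We can approximate the expectation in (4) using an empirical average over samples drawn from the model:

$$\mathcal{L}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) \approx \mathcal{L}^{\text{Sample}}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) = \frac{1}{N} \sum_{\mathbf{y}_i \sim P(\mathbf{y} \mid \mathbf{x})} \mathcal{W}(\mathbf{y}_i, \mathbf{y}^*) \tag{5}$$

where $\mathbf{y}_1, \ldots, \mathbf{y}_N$ are $N$ samples drawn from the model distribution. Critically, the gradient of the expectation in (5) can itself be expressed as an expectation, which allows it to be approximated using samples:

$$\nabla\, \mathbb{E}\left[\mathcal{W}(\mathbf{y}, \mathbf{y}^*)\right] = \sum_{\mathbf{y}} P(\mathbf{y} \mid \mathbf{x})\, \mathcal{W}(\mathbf{y}, \mathbf{y}^*)\, \nabla \log P(\mathbf{y} \mid \mathbf{x}) \approx \frac{1}{N} \sum_{\mathbf{y}_i \sim P(\mathbf{y} \mid \mathbf{x})} \mathcal{W}(\mathbf{y}_i, \mathbf{y}^*)\, \nabla \log P(\mathbf{y}_i \mid \mathbf{x})$$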
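The score-function identity behind this estimator (the gradient of an expected loss equals the expectation of the loss times the gradient of the log-probability) can be checked numerically on a toy problem. The sketch below assumes a categorical distribution over three fixed hypotheses, parameterized directly by logits; this stands in for the sequence model, and the word-error counts are made up.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def exact_gradient(logits, errors):
    """d E[W] / d logit_k = p_k * (W_k - E[W]) for a softmax distribution."""
    p = softmax(logits)
    expected = sum(pk * wk for pk, wk in zip(p, errors))
    return [pk * (wk - expected) for pk, wk in zip(p, errors)]

def sampled_gradient(logits, errors, num_samples, rng):
    """Average of W_i * grad log p_i over samples i drawn from the model."""
    p = softmax(logits)
    grad = [0.0] * len(p)
    for _ in range(num_samples):
        i = rng.choices(range(len(p)), weights=p)[0]
        for k in range(len(p)):
            indicator = 1.0 if k == i else 0.0
            # grad of log p_i with respect to logit_k is (indicator - p_k)
            grad[k] += errors[i] * (indicator - p[k]) / num_samples
    return grad

logits = [1.0, 0.0, -1.0]   # toy model parameters (assumption)
errors = [0, 1, 2]          # made-up word-error counts per hypothesis
exact = exact_gradient(logits, errors)
approx = sampled_gradient(logits, errors, 20000, random.Random(0))
```

With enough samples the two gradients agree closely. With only a handful of samples the estimate is noisy, which is the practical drawback of the sampling-based approach.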
3.2 Approximation Using N-best Lists
One of the potential disadvantages of the sampling-based approach is that a large number of samples might be required in order to approximate the expectation well. However, since the probability mass is likely to be concentrated on the top-N hypotheses, it is reasonable to approximate the loss function by restricting the sum to just the top N hypotheses. We note that this is typically done in traditional discriminative sequence training approaches as well, where the summation is restricted to paths in a lattice.
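Denote by $\text{Beam}(\mathbf{x}, N) = \{\mathbf{y}_1, \ldots, \mathbf{y}_N\}$ the set of N-best hypotheses computed using beam-search decoding for the input utterance $\mathbf{x}$, with a beam size of $N$. We can then approximate the loss function in (4) by assuming that the probability mass is concentrated on just the N-best hypotheses, as follows:

$$\mathcal{L}^{\text{N-best}}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) = \sum_{\mathbf{y}_i \in \text{Beam}(\mathbf{x}, N)} \widehat{P}(\mathbf{y}_i \mid \mathbf{x}) \left[ \mathcal{W}(\mathbf{y}_i, \mathbf{y}^*) - \widehat{W} \right]$$

where $\widehat{P}(\mathbf{y}_i \mid \mathbf{x}) = P(\mathbf{y}_i \mid \mathbf{x}) \,/ \sum_{\mathbf{y}_j \in \text{Beam}(\mathbf{x}, N)} P(\mathbf{y}_j \mid \mathbf{x})$ represents the distribution re-normalized over just the N-best hypotheses, and $\widehat{W} = \frac{1}{N} \sum_{i} \mathcal{W}(\mathbf{y}_i, \mathbf{y}^*)$ is the average number of word errors over the N-best hypotheses. Subtracting $\widehat{W}$ is applied as a form of variance reduction, since it does not affect the gradient.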
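A minimal sketch of the N-best approximation, assuming the beam search has already produced a log-probability and a word-error count for each hypothesis. The function name `nbest_werr_loss` and the numbers are hypothetical.

```python
import math

def nbest_werr_loss(log_probs, errors):
    """Renormalized expected word errors over an N-best list, mean-subtracted.

    log_probs[i] is the model log-probability of hypothesis i, and
    errors[i] is its number of word errors against the ground truth.
    """
    m = max(log_probs)
    unnorm = [math.exp(lp - m) for lp in log_probs]
    z = sum(unnorm)
    p_hat = [u / z for u in unnorm]        # renormalized over the N-best list
    mean_err = sum(errors) / len(errors)   # average word errors in the list
    return sum(p * (w - mean_err) for p, w in zip(p_hat, errors))

# Toy 4-best list (assumption): model probabilities and word-error counts.
log_probs = [math.log(0.42), math.log(0.28), math.log(0.12), math.log(0.08)]
errors = [0, 1, 1, 2]
loss = nbest_werr_loss(log_probs, errors)
```

The loss is negative here because most of the renormalized mass already sits on the hypothesis with fewer errors than the list average; training pushes it lower by shifting more mass toward such hypotheses.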
3.3 Initialization and Training
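Based on the two schemes for approximating the expected word error rate, we can define two possible loss functions:

$$\mathcal{L}^{\text{Sample}} = \sum_{(\mathbf{x}, \mathbf{y}^*)} \mathcal{L}^{\text{Sample}}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) + \lambda\, \mathcal{L}_{\text{CE}}$$

$$\mathcal{L}^{\text{N-best}} = \sum_{(\mathbf{x}, \mathbf{y}^*)} \mathcal{L}^{\text{N-best}}_{\text{werr}}(\mathbf{x}, \mathbf{y}^*) + \lambda\, \mathcal{L}_{\text{CE}}$$

In both cases, we interpolate with the CE loss function using a hyperparameter $\lambda$, which we find is important to stabilize training (see Section 5). We note that interpolation with the CE loss function is similar to the f-smoothing approach in ASR. Training the model directly to optimize $\mathcal{L}^{\text{Sample}}$ or $\mathcal{L}^{\text{N-best}}$ with random initialization is hard, since the model is not directly provided with the ground-truth label sequence. Therefore, we initialize the model with the parameters obtained after CE training.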
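The interpolation itself is a weighted sum, sketched below for a single utterance. The per-step probabilities, the word-error loss value, and the interpolation weight are all hypothetical; the paper's actual setting of the weight is not reproduced here.

```python
import math

def cross_entropy(step_probs):
    """Negative log-likelihood, given the probability assigned to each ground-truth label."""
    return -sum(math.log(p) for p in step_probs)

def interpolated_loss(werr_loss, ce_loss, lam):
    """Expected-word-error loss plus lam times the CE loss."""
    return werr_loss + lam * ce_loss

# Hypothetical values for one utterance.
ce = cross_entropy([0.9, 0.8, 0.95])
total = interpolated_loss(werr_loss=-0.25, ce_loss=ce, lam=0.1)
```

Setting `lam` to zero recovers the pure expected-word-error objective, which is the configuration the text reports as less stable.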
Experimental Setup
The proposed approach is evaluated by conducting experiments on a mobile voice-search task. Models are trained on the same datasets as in our previous works. The training set consists of 15M hand-transcribed anonymized utterances extracted from Google voice-search traffic (12,500 hours). In order to improve robustness to noise, multi-style training data (MTR) are constructed by artificially distorting training utterances with reverberation and noise drawn from environmental recordings of daily events and from YouTube using a room simulator, where the overall SNR ranges from 0 to 30 dB with an average SNR of 12 dB. Model hyperparameters are tuned on a development set of 12.9K utterances (63K words) and results are reported on a set of 14.8K utterances (71.6K words).
The acoustic input is parameterized into 80-dimensional log-Mel filterbank features extracted over the 16kHz frequency range, computed with a 25ms window and a 10ms frame shift. Following prior work, three consecutive frames are stacked together, and every third stacked frame is presented as input to the encoder. The same frontend is used for all models reported in this work.
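The stacking and subsampling step can be sketched as follows. This is one plausible reading of the description: non-overlapping groups of three frames with no padding, so trailing frames that do not fill a group are dropped. The handling of utterance edges is an assumption, and `stack_and_subsample` is a hypothetical name.

```python
def stack_and_subsample(frames, stack=3, stride=3):
    """Concatenate `stack` consecutive frames, keeping every `stride`-th stacked frame."""
    stacked = []
    for start in range(0, len(frames) - stack + 1, stride):
        merged = []
        for frame in frames[start:start + stack]:
            merged.extend(frame)
        stacked.append(merged)
    return stacked

# Ten toy 80-dimensional frames; frame t is filled with the value t.
frames = [[float(t)] * 80 for t in range(10)]
encoder_inputs = stack_and_subsample(frames)
```

Each encoder input is 240-dimensional and covers three 10ms shifts, so the encoder sees one input every 30ms.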
Two attention-based models are trained in this work, differing only in the structure of the encoder network: the first model (Uni-LAS) uses 5 layers of 1,400 uni-directional LSTM cells, whereas the second model (Bidi-LAS) uses 5 layers of 1,024 bi-directional LSTM cells (i.e., 1,024 cells in the forward and backward directions, for each layer). The decoder network of both models consists of two layers of 1,024 LSTM cells in each layer. Both models use multi-headed attention as described in Section 2.1. Models are trained to output a probability distribution over grapheme symbols: the 26 lowercase letters a-z, the numerals 0-9, punctuation symbols ,’! etc., and the special symbols <sos> and <eos>. All models are trained using the TensorFlow toolkit, with asynchronous stochastic gradient descent (ASGD) using the Adam optimizer.
Results
We investigate the impact of various hyperparameters, and the choice of approximation scheme by conducting detailed experiments on the uni-directional LAS model. Results on the bi-directional LAS model, along with a comparison to a traditional CD-phone based state-of-the-art system are deferred until Section 5.2.
Our first set of experiments evaluates the effectiveness of approximating the expected number of word errors using samples (i.e., optimizing $\mathcal{L}^{\text{Sample}}$) versus the approximation using N-best lists (i.e., optimizing $\mathcal{L}^{\text{N-best}}$), as described in Section 3.3. Our observations are illustrated in Figure 2, where we plot various metrics on a held-out portion of the training data.
As can be seen in Figure 2(a), optimizing the sample-based approximation, $\mathcal{L}^{\text{Sample}}$, reduces the expected number of word errors by 50% after training, with performance appearing to improve as the number of samples, $N$, used in the approximation increases. Unlike Sak et al., however, as can be seen in Figure 2(b), the WER for the top hypothesis computed using beam search does not improve, but instead degrades as a result of training. We hypothesize that this is a result of the mismatch between the beam-search decoding procedure, which focuses on the head of the distribution during each next-label prediction, and the sampling procedure, which also considers lower-probability paths.
As illustrated in Figure 2(c), optimizing $\mathcal{L}^{\text{N-best}}$ (i.e., using the N-best list-based approximation) significantly improves WER by about 10.4% on the held-out portion of the training set. Further, performance seems to be similar even when just the top four hypotheses are considered during the optimization.
As a final note, we find that it is important to also interpolate with the CE loss function during optimization (i.e., setting the interpolation weight $\lambda > 0$). This is illustrated in Figure 3 for the case where we optimize using hypotheses in the N-best list.
5.2 Improvements from Minimum WER Training for LAS Models
We present results after expected minimum WER training (MWER) of the uni- and bi-directional LAS models described in Section 4 in Table 1, where we set and . We report results after directly decoding the models to produce grapheme sequences using a beam-search decoding with 8 beams (column 2) as well as after rescoring the 8-best list using a very large 5-gram language model (column 3). For comparison, we also report results using a traditional state-of-the-art low frame rate (LFR) CD-phone based system, which uses an acoustic model composed of four layers of 1,024 uni-directional LSTM cells, followed by one layer of 768 uni-directional cells. The model is first trained to optimize the CE loss function, followed by discriminative sequence training to optimize the state-level minimum Bayes risk (sMBR) criterion . The model is decoded using a pruned, first-pass, 5-gram language model, which uses a vocabulary of millions of words, as well as an expert-curated pronunciation dictionary. As before, we report results both before and after second-pass lattice rescoring.
As can be seen in Table 1, when decoded without second-pass rescoring (i.e., end-to-end training), MWER training improves performance of the uni- and bi-directional LAS systems by 7.4% and 4.2% respectively. The gains after MWER training are even larger after second-pass rescoring, improving the baseline uni- and bi-directional LAS systems by 8.2% and 6.1%, respectively. Finally, we note that after MWER training the grapheme-based uni-directional LAS system matches the performance of a state-of-the-art traditional CD-phoneme-based ASR system.
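The improvements quoted here are relative reductions (the introduction describes the 8.2% figure as relative to the CE-trained baseline). The arithmetic is sketched below with hypothetical WER values, chosen only to illustrate the calculation; they are not the values in Table 1.

```python
def relative_wer_reduction(baseline_wer, new_wer):
    """Relative WER reduction in percent: 100 * (baseline - new) / baseline."""
    return 100.0 * (baseline_wer - new_wer) / baseline_wer

# Hypothetical WERs, for illustration only.
reduction = relative_wer_reduction(baseline_wer=10.0, new_wer=9.0)
```

In this made-up case, a one-point absolute drop from a 10% baseline corresponds to a 10% relative reduction.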
Conclusions
We described a technique for training sequence-to-sequence systems to optimize the expected test error rate, which was applied to attention-based systems. Unlike Sak et al., we find that sampling-based approximations are not as effective as approximations based on N-best decoded hypotheses. Overall, we find that the proposed approach allows us to improve WER by up to 8.2% relative. The proposed techniques allow us to train grapheme-based sequence-to-sequence models which match the performance of a traditional CD-phone-based state-of-the-art system on a voice-search task. Viewed jointly with our previous works, this adds further evidence to the effectiveness of sequence-to-sequence modeling approaches.