JEIT: Joint End-to-End Model and Internal Language Model Training for Speech Recognition

Zhong Meng, Weiran Wang, Rohit Prabhavalkar, Tara N. Sainath, Tongzhou Chen, Ehsan Variani, Yu Zhang, Bo Li, Andrew Rosenberg, Bhuvana Ramabhadran

Introduction

End-to-end (E2E) models have achieved strong performance for automatic speech recognition (ASR) by directly mapping the speech signal to word sequences. However, even when trained with a large amount of audio-transcript pairs, E2E models still perform poorly on utterances containing words that appear infrequently in the training data (rare words). Moreover, supervised speech data obtained via human transcription is expensive. To overcome this, utilizing knowledge from large-scale unpaired text during training or inference is a promising solution, since unpaired text is orders of magnitude more plentiful than audio-transcript pairs and covers a much larger vocabulary.

Language model (LM) fusion is a common approach to improving E2E ASR with unpaired text. An external LM is first trained on unpaired text. In shallow fusion, a log-linear interpolation between the E2E model score and the LM score is computed at each step of the beam search. To improve on shallow fusion, internal LM estimation (ILME)-based fusion was proposed, which estimates an internal LM (ILM) score and subtracts it from the shallow fusion score. However, all these methods require an external LM during inference, which increases decoding time and computational cost.
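The per-token scoring used by the two fusion methods can be sketched as follows. This is a minimal illustration, assuming the E2E, external LM and ILM log-probabilities of one candidate token are already computed; the weight names `lm_weight` and `ilm_weight` and the toy probabilities are hypothetical, and in practice such weights are typically tuned on development data.

```python
import math


def shallow_fusion_score(log_p_e2e: float, log_p_lm: float, lm_weight: float) -> float:
    """Log-linear interpolation of the E2E score and the external LM score."""
    return log_p_e2e + lm_weight * log_p_lm


def ilme_fusion_score(log_p_e2e: float, log_p_lm: float, log_p_ilm: float,
                      lm_weight: float, ilm_weight: float) -> float:
    """Shallow fusion score with the weighted internal LM score subtracted."""
    return shallow_fusion_score(log_p_e2e, log_p_lm, lm_weight) - ilm_weight * log_p_ilm


if __name__ == "__main__":
    # Hypothetical probabilities for one candidate token during beam search.
    e2e, lm, ilm = math.log(0.4), math.log(0.2), math.log(0.5)
    print(shallow_fusion_score(e2e, lm, lm_weight=0.3))
    print(ilme_fusion_score(e2e, lm, ilm, lm_weight=0.3, ilm_weight=0.2))
```

Setting `ilm_weight` to zero recovers plain shallow fusion, which makes the relationship between the two methods explicit.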

To overcome this, various studies have incorporated unpaired text into the training stage of E2E models. One intuitive solution is to synthesize speech from unpaired text and use it to train the E2E model. However, training a text-to-speech (TTS) model and synthesizing speech are both computationally expensive. To circumvent this, modality matching approaches were proposed that map unpaired text to a latent space shared by speech and text and then use the latent embeddings to train the E2E model.

Alternatively, the decoder (and the joint network, for a transducer model) of an E2E model behaves like an ILM when the encoder output is zeroed out. To achieve fast text-only adaptation, unpaired text can be injected into the decoder of a well-trained E2E model. These methods add an extra adaptation step after one or two stages of E2E training, in which the ILM of the E2E model is fine-tuned on text-only data to minimize a cross-entropy loss. In addition, Kullback-Leibler divergence (KLD) regularization is applied to maintain the source-domain ASR performance.

The novel contributions of this work are: (1) We propose joint E2E model and ILM training (JEIT), which simplifies decoder text injection by folding it into a single stage of E2E training. Without requiring KLD regularization, JEIT outperforms text-only adaptation. (2) We further propose combined JEIT and JOIST training (CJJT) and demonstrate that decoder text injection via the ILM is complementary to encoder text injection via JOIST, and that the improvements are additive. (3) We show that all text-injection methods enable more effective LM fusion. (4) We validate our methods on Google's large-scale streaming production task, where JEIT and CJJT achieve up to 10.2% and 16.4% relative WER reductions, respectively, on rare-word test sets without affecting Voice Search performance.

Related Work

2.2 Modular HAT (MHAT)

To achieve more robust text-only adaptation, we previously proposed MHAT, which structurally separates the ILM score prediction from the acoustic model score and blank score predictions. As shown in Fig. 2, MHAT introduces a blank decoder that takes in the same previous labels as the label decoder to generate the current label embeddings.

2.3 ILM Training (ILMT) and ILM Adaptation (ILMA)

ILMT minimizes an additional ILM loss during E2E model training. While the E2E loss is computed on audio-transcript pairs, the ILM loss is derived from the training transcripts only. ILMT encourages the ILM to also behave like a standalone neural LM, so that (1) accurate ILM scores can be estimated to improve ILME-based fusion, and (2) the ILM can be further adapted to text-only data. ILMT makes no use of unpaired text, and it does not improve ASR performance on either source-domain or rare-word test sets. Unlike ILMT, JEIT injects unpaired text into the ILM during E2E training with the goal of improving rare-word recognition.

ILMA performs fast text-only adaptation of an E2E model to improve rare-word ASR. In ILMA, we first perform ILMT of the E2E model and then fine-tune the ILM to minimize a cross-entropy ILM loss on unpaired text. To prevent the source-domain ASR performance from degrading, we also minimize the KLD between the output distributions of the unadapted and adapted ILMs during ILMA. To simplify ILMA, JEIT merges the two stages of ILMT and ILMA into a single training stage and obviates the need for KLD regularization.
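A per-token version of the ILMA objective can be sketched as below. It is a minimal illustration, assuming the unadapted ILM is frozen and both ILMs return full next-token distributions; the KLD direction KL(unadapted || adapted) and all function names are assumptions, while the default weight of 0.5 matches the KLD weight used in the experiments.

```python
import math
from typing import Sequence


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Negative log probability that the adapted ILM assigns to the target token."""
    return -math.log(probs[target])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def ilma_token_loss(adapted: Sequence[float], unadapted: Sequence[float],
                    target: int, kld_weight: float = 0.5) -> float:
    """Cross-entropy on unpaired text plus a KLD penalty that keeps the adapted
    ILM close to the frozen, unadapted ILM (the KLD direction is assumed)."""
    return cross_entropy(adapted, target) + kld_weight * kl_divergence(unadapted, adapted)


if __name__ == "__main__":
    unadapted = [0.7, 0.2, 0.1]  # frozen ILM before adaptation (toy values)
    adapted = [0.5, 0.3, 0.2]    # ILM being fine-tuned on unpaired text (toy values)
    print(ilma_token_loss(adapted, unadapted, target=1))
```

The KLD term is zero only when the adapted ILM still matches the unadapted one, which is why it acts as a brake on drifting away from the source domain.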

2.4 Joint Speech and Text Modeling (JOIST)

JOIST incorporates unpaired text into E2E training and significantly improves rare-word recognition. It injects unpaired text through the encoder so that text data can benefit the entire E2E model. In JOIST, unpaired text is first tokenized into word-piece or phoneme sequences and then upsampled by replicating each token a fixed or random number of times. The upsampled text is masked and fed into a text encoder to generate token embeddings, which are further passed to the decoder input or to a layer of the encoder. JOIST minimizes a weighted sum of two E2E losses derived from audio-transcript pairs $\mathcal{D}_{\text{P}}$ and unpaired text $\mathcal{D}_{\text{UP}}$, respectively:

$$\mathcal{L}_{\text{JOIST}} = \mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{P}};\theta_{\text{E2E}}) + \alpha\,\mathcal{L}_{\text{E2E}}\big(F(\mathcal{D}_{\text{UP}});\theta_{\text{E2E}}\big)$$

where $F(\cdot)$ is a function that tokenizes, upsamples, and masks an unpaired sentence in $\mathcal{D}_{\text{UP}}$, and $\alpha>0$ is the weight of the unpaired E2E loss. In this work, we incorporate an ILM loss into JOIST to further improve ASR performance.
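The upsample-and-mask part of $F(\cdot)$ can be sketched as follows. This is a toy illustration under stated assumptions: tokenization is assumed to have happened already (whitespace splitting stands in for word-piece or phoneme tokenization), and the mask symbol, `max_repeat`, and `mask_prob` are hypothetical values, not settings from this work.

```python
import random
from typing import List, Optional

MASK = "<mask>"  # hypothetical mask symbol


def upsample_and_mask(tokens: List[str], repeat: Optional[int] = None,
                      max_repeat: int = 4, mask_prob: float = 0.15,
                      rng: Optional[random.Random] = None) -> List[str]:
    """Sketch of the upsampling and masking in F(.).

    If `repeat` is given, every token is replicated a fixed number of times;
    otherwise each token is replicated a random number of times in [1, max_repeat].
    """
    if rng is None:
        rng = random.Random()
    upsampled: List[str] = []
    for tok in tokens:
        n = repeat if repeat is not None else rng.randint(1, max_repeat)
        upsampled.extend([tok] * n)
    # Replace each upsampled position with the mask symbol with probability mask_prob.
    return [MASK if rng.random() < mask_prob else tok for tok in upsampled]


if __name__ == "__main__":
    # Whitespace splitting stands in for word-piece or phoneme tokenization.
    sentence = "play music by the beatles".split()
    print(upsample_and_mask(sentence, repeat=2, mask_prob=0.2, rng=random.Random(0)))
```

Upsampling roughly matches the length of the text sequence to the frame rate of speech features, so the text encoder output can be fed into the speech encoder.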

Joint E2E and ILM Training (JEIT)

The ILM probability can be estimated from the E2E model output after zeroing out the encoder output. The ILM consists of the decoder and joint network for HAT, the label decoder and output projection $\mathbf{W}_4$ for MHAT, and the decoder for an attention-based encoder-decoder (AED) model.
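To make the zero-out procedure concrete for HAT, the toy sketch below assumes a joint network of the form tanh(W_enc f + W_dec g) followed by a bias-free output projection over non-blank labels. The matrix names, dimensions, and the absence of biases are illustrative assumptions, and the HAT blank probability is omitted because the ILM covers only the label distribution.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def hat_label_log_probs(enc: Sequence[float], dec: Sequence[float],
                        w_enc: Matrix, w_dec: Matrix, w_out: Matrix) -> List[float]:
    """Label (non-blank) log-probabilities from a toy HAT joint network:
    tanh(W_enc f_t + W_dec g_u) followed by an output projection."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(w_enc, enc), matvec(w_dec, dec))]
    return log_softmax(matvec(w_out, hidden))


def hat_ilm_log_probs(dec: Sequence[float], w_enc: Matrix, w_dec: Matrix,
                      w_out: Matrix) -> List[float]:
    """ILM estimate: the same joint network with the encoder output zeroed out."""
    zero_enc = [0.0] * len(w_enc[0])
    return hat_label_log_probs(zero_enc, dec, w_enc, w_dec, w_out)


if __name__ == "__main__":
    w_enc = [[1.0, 2.0], [0.5, -1.0]]              # 2 hidden units x 2 encoder dims
    w_dec = [[0.3, 0.1], [-0.2, 0.4]]              # 2 hidden units x 2 decoder dims
    w_out = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # 3 labels x 2 hidden units
    print(hat_ilm_log_probs([1.0, -1.0], w_enc, w_dec, w_out))
```

Because the encoder contribution vanishes, the resulting distribution depends only on the label history, which is what allows it to be trained as an LM on text alone.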

Our goal is to improve ASR accuracy on rare-word test sets by making use of large-scale unpaired text while maintaining WER on the source-domain task (e.g., voice search). In this work, we propose JEIT, a joint training of the E2E model and ILM that injects unpaired text into the ILM during E2E training.

As shown in Figs. 1 and 2, the ILM is trained with unpaired text to minimize an ILM loss, while the entire E2E model is trained with audio-transcript pairs to minimize an E2E loss. Minimizing the ILM loss makes the ILM a strong neural LM in the target domain, while the E2E loss serves as a regularizer that ensures the ILM works well with the other E2E model components to predict accurate E2E scores. Specifically, JEIT minimizes a weighted sum of the E2E loss and the ILM loss:

$$\mathcal{L}_{\text{JEIT}} = \mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{P}};\theta_{\text{E2E}}) + \beta\,\mathcal{L}_{\text{ILM}}(\mathcal{D}_{\text{UP}};\theta_{\text{ILM}})$$

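where $\beta>0$ is the weight of the ILM loss. The ILM loss is the summed negative log probability of all label sequences predicted by the ILM on the unpaired text $\mathcal{D}_{\text{UP}}$:

$$\mathcal{L}_{\text{ILM}}(\mathcal{D}_{\text{UP}};\theta_{\text{ILM}}) = -\sum_{\mathbf{y}\in\mathcal{D}_{\text{UP}}} \log P(\mathbf{y};\theta_{\text{ILM}})$$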

where $\theta_{\text{ILM}}\subseteq\theta_{\text{E2E}}$ denotes the ILM parameters.
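A minimal sketch of the JEIT objective follows, assuming the ILM is exposed as a callable returning p(token | history) and that the sequence probability factorizes over tokens. The callable interface, the `<s>` start symbol, and the uniform stand-in ILM are hypothetical; in actual training both losses would be computed on minibatches and differentiated with respect to the shared parameters.

```python
import math
from typing import Callable, Iterable, List, Sequence

# Hypothetical interface: returns p(token | history) under the ILM.
IlmProb = Callable[[Sequence[str], str], float]


def ilm_loss(unpaired_text: Iterable[List[str]], ilm_prob: IlmProb,
             bos: str = "<s>") -> float:
    """Summed negative log probability of every unpaired label sequence."""
    total = 0.0
    for tokens in unpaired_text:
        history: List[str] = [bos]
        for tok in tokens:
            total -= math.log(ilm_prob(history, tok))
            history.append(tok)
    return total


def jeit_loss(e2e_loss_paired: float, ilm_loss_unpaired: float, beta: float) -> float:
    """Weighted sum of the paired E2E loss and the unpaired ILM loss."""
    if beta <= 0:
        raise ValueError("beta must be positive")
    return e2e_loss_paired + beta * ilm_loss_unpaired


if __name__ == "__main__":
    vocab_size = 4096

    def uniform_ilm(history: Sequence[str], tok: str) -> float:
        return 1.0 / vocab_size  # stand-in ILM that ignores the history

    text = [["play", "beatles"], ["navigate", "home", "now"]]
    print(jeit_loss(e2e_loss_paired=12.0, ilm_loss_unpaired=ilm_loss(text, uniform_ilm), beta=4.0))
```

Only the ILM term sees unpaired text, so gradients from it reach just the parameters in $\theta_{\text{ILM}}$, while the E2E term keeps the whole model consistent.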

Compared to text-only adaptation, JEIT significantly simplifies the entire learning process: (1) JEIT reduces the two steps of audio-transcript training and unpaired-text adaptation to a single step of joint training, decreasing computational cost and training and adaptation time. (2) JEIT avoids the need for KLD regularization of the ILM output distribution.

To improve rare-word recognition, JEIT injects unpaired text into the label decoder of an E2E model to minimize $\mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{P}};\theta_{\text{E2E}})$ and $\mathcal{L}_{\text{ILM}}(\mathcal{D}_{\text{UP}};\theta_{\text{ILM}})$, while JOIST injects it through the encoder to minimize $\mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{P}};\theta_{\text{E2E}})$ and $\mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{UP}};\theta_{\text{E2E}})$. To benefit from both methods, we propose combined JEIT and JOIST training (CJJT), which minimizes a weighted sum of an E2E loss derived from audio-transcript pairs, an E2E loss derived from unpaired text, and an ILM loss derived from unpaired text:

$$\mathcal{L}_{\text{CJJT}} = \mathcal{L}_{\text{E2E}}(\mathcal{D}_{\text{P}};\theta_{\text{E2E}}) + \alpha\,\mathcal{L}_{\text{E2E}}\big(F(\mathcal{D}_{\text{UP}});\theta_{\text{E2E}}\big) + \beta\,\mathcal{L}_{\text{ILM}}(\mathcal{D}_{\text{UP}};\theta_{\text{ILM}})$$

We show in the experiments that JEIT and JOIST are complementary to each other and CJJT achieves better ASR performance than either method alone.

During inference, we can integrate an external LM into the E2E model after JEIT or CJJT to further improve rare-word ASR. We show that LM fusion is complementary to both JEIT and CJJT, even when the external LM is trained on the same unpaired text $\mathcal{D}_{\text{UP}}$ as used in JEIT.

Experiments

We use ~650M multi-domain English audio-transcript pairs as supervised training data. The data covers multiple domains, including Voice Search, Dictation, YouTube, and Telephony. YouTube transcripts are generated in a semi-supervised fashion, while the other data is anonymized and hand-transcribed. In addition, multi-condition training, random 8 kHz down-sampling, and SpecAug are applied to augment and diversify the data.

The unpaired text used in training or adaptation consists of 100B anonymized sentences across the Maps, Google Play, Web, and YouTube domains, and is more than two orders of magnitude larger than the audio-transcript data. The external LM is trained on a 50/50 mix of transcripts from the paired data and unpaired text to ensure that quality on the base Voice Search task does not degrade.

We evaluate our models on a Voice Search (VS) test set containing ~12K anonymized, hand-transcribed voice search utterances with an average duration of 5.5 s. To evaluate ASR performance on long-tail words, we construct a rare-word test set for each of 4 domains: Maps, Google Play, Web, and YouTube (YT). The rare-word test sets contain rare proper nouns that appear fewer than 5 times in the training set, and their utterances are synthesized by a TTS system. Our goal is to improve ASR accuracy on the 4 rare-word test sets without degrading WER on Voice Search.

4.2 Modeling

We train HAT and MHAT with 2-pass cascaded encoders and separate decoders, as in prior work. Both models share the same front-end and encoder architecture. Specifically, 128-dim log Mel filterbank features are extracted from the speech signal and subsampled to form a 512-dim feature every 30 ms. Each speech feature is appended with a 16-dim domain ID. The causal encoder is a 7-layer conformer with causal convolution and left-context attention. The non-causal encoder is a 10-layer conformer with right-context attention that processes 900 ms of future speech. Each conformer layer uses 512-dim, 8-head self-attention and a convolution kernel of size 15.
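The front-end arithmetic can be illustrated as follows. This sketch assumes a 10 ms frame shift and reaches the stated 512-dim, 30 ms features by stacking 4 consecutive 128-dim frames and keeping every 3rd stack; the stacking scheme and the one-hot encoding of the 16-dim domain ID are assumptions, not details given in the text.

```python
from typing import List, Sequence

Frame = List[float]


def stack_and_subsample(frames: Sequence[Frame], stack: int = 4, stride: int = 3) -> List[Frame]:
    """Concatenate `stack` consecutive frames, emitting one output every `stride` frames."""
    out: List[Frame] = []
    for start in range(0, len(frames) - stack + 1, stride):
        stacked: Frame = []
        for f in frames[start:start + stack]:
            stacked.extend(f)
        out.append(stacked)
    return out


def append_domain_id(features: Sequence[Frame], domain: int, num_domains: int = 16) -> List[Frame]:
    """Append a one-hot domain vector (assumed encoding) to every feature frame."""
    one_hot = [1.0 if i == domain else 0.0 for i in range(num_domains)]
    return [list(f) + one_hot for f in features]


if __name__ == "__main__":
    # 1 s of dummy 128-dim log-Mel frames, assuming a 10 ms frame shift.
    logmel = [[0.0] * 128 for _ in range(100)]
    feats = append_domain_id(stack_and_subsample(logmel), domain=2)
    print(len(feats), len(feats[0]))
```

Under these assumptions, 4 × 128 = 512 dimensions per stacked feature and a stride of 3 frames × 10 ms = 30 ms between features.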

The causal and non-causal decoders of HAT or MHAT decode using the outputs of the causal and non-causal encoders, respectively. The label decoders of HAT and MHAT are 2-layer LSTMs with 2048 hidden units per layer. In HAT, the label decoder output passes through a 640-dim feedforward joint network before being projected to 4096 output units representing word pieces. In MHAT, the label decoder output is directly projected to an output layer of the same size. The ILMs of HAT and MHAT have 30.7M and 30M parameters, respectively. The blank decoder of MHAT is a 320-dim $V^2$ embedding decoder with a look-up table shared between the last 2 tokens, and has 1.5M parameters. Overall, HAT and MHAT have 205M and 210M parameters in total, respectively. We report only 2nd-pass WERs in this paper. We train baselines with only audio-transcript pairs and show their WERs in Table 1.

Moreover, we train a 12-layer conformer LM with 384-dim self-attention and 3072-dim feedforward layers. The external LM has a left attention context of 31 and 70M parameters in total.

4.3 ILMA of HAT and MHAT

We first train an ILMT model with an ILM loss weight of 0.1 and use it as the seed for ILMA. For both ILMA and JEIT, we use minibatch sizes of 4,096 for paired audio-transcript data and 32,768 for unpaired text. During ILMA, KLD regularization with a weight of 0.5 is applied for both HAT and MHAT. In Fig. 3, we plot the WERs of HAT ILMA and MHAT ILMA against the number of training steps. The WER of HAT ILMA increases sharply after reaching its minimum at 5K training steps, while the WER of MHAT ILMA decreases gradually until after 200K steps. This is because, without structural factorization, HAT cannot work with an increasingly strong ILM and loses its ability to perform E2E ASR. This shows that MHAT is superior to HAT for ILMA, because its structurally independent ILM allows MHAT to keep improving its ASR capability as the ILM becomes stronger. We list the best WERs of ILMA in Table 1.

4.4 JEIT of HAT and MHAT

We perform JEIT with the same minibatch sizes as ILMA, using ILM loss weights $\beta$ of 0.2 and 4.0 for HAT and MHAT, respectively. The significantly larger optimal ILM loss weight for MHAT reflects its advantage over HAT due to factorization: MHAT can work with an increasingly strong ILM to perform better ASR, while HAT cannot. As shown in Table 1, MHAT JEIT performs the best among all methods, achieving 4.8%–10.2% relative WER reductions from the baseline HAT on the rare-word test sets. For MHAT, JEIT achieves better WERs than ILMA on all test sets. MHAT JEIT consistently outperforms HAT JEIT, with up to 3.5% relative WER reduction. As JEIT progresses, the WERs of both HAT and MHAT decrease steadily without any sudden increase. This implies that the E2E loss in JEIT serves as a much better regularizer than the KLD in ILMA. Overall, we show for the first time that joint training of the ILM is better than adaptation.
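The relative WER reductions quoted throughout are conventionally computed as sketched below. The example WERs are made-up values for illustration only, not numbers from Table 1.

```python
def relative_wer_reduction(baseline_wer: float, new_wer: float) -> float:
    """Relative WER reduction in percent: 100 * (baseline - new) / baseline.

    A negative result means the new system is worse than the baseline.
    """
    if baseline_wer <= 0:
        raise ValueError("baseline WER must be positive")
    return 100.0 * (baseline_wer - new_wer) / baseline_wer


if __name__ == "__main__":
    # Illustrative values only: a drop from 10.0% to 9.0% WER.
    print(relative_wer_reduction(10.0, 9.0))
```

Because the reduction is relative to the baseline, the same absolute WER drop counts for more on a test set with a lower baseline WER.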

4.5 JEIT with Different Decoders

We vary the type and size of the MHAT label decoder while keeping the cascaded encoders described in Section 4.2 unchanged. Besides the LSTM, we explore simpler and smaller label decoders: $V^2$ and $V^4$ embedding decoders, which have 640-dim embeddings and condition on the last 2 and 4 tokens, respectively. Each previous token position has its own look-up table. The same blank decoder as in Section 4.2 is used. MHATs with $V^2$ and $V^4$ embedding decoders have 8.6M and 14.6M ILM parameters and 169M and 182M parameters in total, respectively. As shown in Tables 1 and 2, JEIT of MHATs with $V^2$ embedding, $V^4$ embedding, and LSTM decoders achieves 1.6%–4.1%, 1.6%–4.9%, and 4.3%–8.8% relative WER reductions, respectively, from the baseline MHAT on Maps, Play, Web, and YT. For all 3 decoders, JEIT obtains no WER degradation on rare-word test sets. This shows that JEIT is beneficial to label decoders of various types and sizes. The effectiveness of JEIT increases as the ILM grows in size and as the label decoder's conditioning history lengthens.
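The difference between the shared-table blank decoder and the separate-table label decoders can be sketched as below. It is a toy illustration under stated assumptions: the per-position embeddings are combined by concatenation, short histories are left-padded with a hypothetical `pad_id`, and the random initialization and small sizes are placeholders. How MHAT actually combines and projects these embeddings is not specified here.

```python
import random
from typing import List, Sequence


class EmbeddingDecoder:
    """Toy V^N embedding decoder conditioned on the last N tokens.

    shared=True: one lookup table serves all N history positions (as in the blank decoder).
    shared=False: each history position has its own table (as in the V^2 / V^4 label decoders).
    The N looked-up embeddings are concatenated (an assumption of this sketch).
    """

    def __init__(self, vocab_size: int, dim: int, context: int, shared: bool, seed: int = 0):
        rng = random.Random(seed)
        num_tables = 1 if shared else context
        self.context = context
        self.shared = shared
        self.tables = [[[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(vocab_size)]
                       for _ in range(num_tables)]

    def num_params(self) -> int:
        return sum(len(t) * len(t[0]) for t in self.tables)

    def __call__(self, history: Sequence[int], pad_id: int = 0) -> List[float]:
        # Keep the last N token IDs, left-padding short histories with pad_id.
        last = ([pad_id] * self.context + list(history))[-self.context:]
        out: List[float] = []
        for pos, tok in enumerate(last):
            table = self.tables[0 if self.shared else pos]
            out.extend(table[tok])
        return out


if __name__ == "__main__":
    label_dec = EmbeddingDecoder(vocab_size=16, dim=8, context=4, shared=False)
    blank_dec = EmbeddingDecoder(vocab_size=16, dim=8, context=2, shared=True)
    print(label_dec.num_params(), blank_dec.num_params(), len(label_dec([3, 5, 7])))
```

Separate tables multiply the lookup-table size by the context length, which is one reason the $V^4$ label decoder has more ILM parameters than the $V^2$ one.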

4.6 Combining JEIT with Other Text Injection Methods

We train JOIST MHAT with phoneme-based unpaired text following the setup of prior work. The text encoder output is fed to the 3rd conformer layer of the causal encoder. The unpaired E2E loss weight $\alpha$ is 0.25. We conduct combined JEIT and JOIST training (CJJT) with an ILM loss weight $\beta$ of 1.5. We subtract ILM scores during LM fusion.

CJJT consistently outperforms both JOIST and JEIT, indicating that text injection into the encoder and into the decoder is complementary and that the gains are additive. Notably, CJJT achieves similar or even better WERs than LM fusion with the base MHAT on rare-word test sets, despite having 70M fewer model parameters. LM fusion with JEIT or CJJT MHAT yields an additional 2.3%–12.8% relative WER reduction, so performing JEIT or CJJT during training is highly beneficial to LM fusion. LM fusion with CJJT performs better than with JEIT, suggesting that JEIT, JOIST, and LM fusion are complementary to each other. Finally, combining CJJT, minimum word error rate (MWER) training, and LM fusion yields the best WERs among all systems, with 17.9%–31.3% relative WER reductions from the baseline.

Conclusion

We propose JEIT to inject unpaired text into the ILM via single-stage joint training. JEIT simplifies two-stage ILMA and eliminates KLD regularization, achieving up to 10.2% relative WER reduction from the baseline on rare-word test sets. MHAT performs better than HAT after JEIT and is much more robust than HAT during ILMA. Text injection into the encoder and decoder is complementary: combining them (CJJT) achieves up to a 16.4% relative gain. LM fusion further improves all text-injection methods by up to 12.8% relative.

References