vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations
Alexei Baevski, Steffen Schneider, Michael Auli
Introduction
Learning discrete representations of speech has attracted much recent interest (Versteegh et al., 2016; Dunbar et al., 2019). A popular approach to discovering discrete units is autoencoding (Tjandra et al., 2019; Eloff et al., 2019; Chorowski et al., 2019), sometimes coupled with an autoregressive model (Chung et al., 2019). Another line of research learns continuous speech representations in a self-supervised way by predicting context information (Chung & Glass, 2018; van den Oord et al., 2018; Schneider et al., 2019).
In this paper, we combine these two lines of research by learning discrete representations of speech via a context prediction task instead of reconstructing the input. This enables us to directly apply well-performing NLP algorithms to speech data (Figure 1(a)).
Our new discretization algorithm, vq-wav2vec, learns discrete representations of fixed-length segments of audio signal by utilizing the wav2vec loss and architecture (Schneider et al., 2019; §2). To choose the discrete variables, we consider a Gumbel-Softmax approach (Jang et al., 2016) as well as online k-means clustering, similar to VQ-VAE (van den Oord et al., 2017; Eloff et al., 2019; §3).
We then train a Deep Bidirectional Transformer (BERT; Devlin et al., 2018; Liu et al., 2019) on the discretized unlabeled speech data and input these representations to a standard acoustic model (Figure 1(b); §4). Our experiments show that BERT representations perform better than log-mel filterbank inputs as well as dense wav2vec representations on both TIMIT and WSJ benchmarks. Discretization of audio enables the direct application of a whole host of algorithms from the NLP literature to speech data. For example, we show that a standard sequence-to-sequence model from the NLP literature can be used to perform speech recognition over discrete audio tokens (§5, §6).
Background
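$$\mathcal{L}_k^{\text{wav2vec}} = -\sum_{i=1}^{T-k} \Big( \log \sigma(z_{i+k}^\top h_k(c_i)) + \lambda \, \mathbb{E}_{\tilde{z} \sim p_n} \big[ \log \sigma(-\tilde{z}^\top h_k(c_i)) \big] \Big) \qquad (1)$$
where $T$ is the sequence length, $\sigma(x) = 1/(1 + \exp(-x))$, and where $\sigma(z_{i+k}^\top h_k(c_i))$ is the probability of $z_{i+k}$ being the true sample. We consider a step-specific affine transformation $h_k(c_i) = W_k c_i + b_k$ that is applied to $c_i$ (van den Oord et al., 2018). We optimize the loss $\mathcal{L} = \sum_{k=1}^{K} \mathcal{L}_k^{\text{wav2vec}}$, summing (1) over different step sizes. After training, the representations produced by the context network $c_i$ are input to the acoustic model instead of log-mel filterbank features.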
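The per-step contrastive loss can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names, the weight `lam` on the distractor term, the averaging over distractors as an estimate of the expectation, and the toy inputs are all hypothetical.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def affine(W, b, c):
    # Step-specific affine transformation h_k(c) = W c + b.
    return [dot(row, c) + bias for row, bias in zip(W, b)]


def step_loss(z, c, negatives, W, b, k, lam=1.0):
    """Contrastive loss for predicting k steps ahead.

    z, c: lists of T vectors (encoder outputs and context outputs).
    negatives[i]: list of distractor vectors used at position i.
    lam: assumed weight on the distractor term.
    """
    total = 0.0
    for i in range(len(z) - k):
        h = affine(W, b, c[i])
        # Log-probability that the true future sample is recognised as true.
        total += log_sigmoid(dot(z[i + k], h))
        # Average over sampled distractors approximates the expectation.
        neg = [log_sigmoid(-dot(zt, h)) for zt in negatives[i]]
        total += lam * sum(neg) / len(neg)
    return -total


# Toy usage: identity transformation, 3 time steps, one zero distractor each.
identity = [[1.0, 0.0], [0.0, 1.0]]
frames = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
distractors = [[[0.0, 0.0]]] * 3
loss = step_loss(frames, frames, distractors, identity, [0.0, 0.0], k=1)
```

The full training objective would sum `step_loss` over the step sizes being predicted.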
2.2 BERT
BERT (Devlin et al., 2018) is a pre-training approach for NLP tasks, which uses a transformer encoder model to build a representation of text. Transformers use self-attention to encode the input sequence as well as an optional source sequence (Vaswani et al., 2017). The original BERT model combined two tasks for training: first, masked language modeling randomly removes some of the input tokens and the model has to predict those missing tokens. Second, next sentence prediction splices two different text passages together into a single example and the model needs to predict whether the passages are from the same document.
VQ-Wav2Vec
Our approach, vq-wav2vec, learns vector quantized (VQ) representations of audio data using a future time-step prediction task. We follow the same architectural choices as wav2vec (§2.1) with two convolutional networks $f: \mathcal{X} \mapsto \mathcal{Z}$ and $g: \hat{\mathcal{Z}} \mapsto \mathcal{C}$ for feature extraction and aggregation, as well as a new quantization module $q: \mathcal{Z} \mapsto \hat{\mathcal{Z}}$ to build discrete representations (Figure 1(a)).
We first map 30ms segments of raw speech to a dense feature representation $z$ at a stride of 10ms using the encoder network $f$. Next, the quantizer ($q$) turns these dense representations into discrete indices which are mapped to a reconstruction $\hat{z}$ of the original representation $z$. We feed $\hat{z}$ into the aggregator $g$ and optimize the same context prediction task as wav2vec outlined in §2.1.
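$$p_j = \frac{\exp\big((l_j + v_j)/\tau\big)}{\sum_{k=1}^{V} \exp\big((l_k + v_k)/\tau\big)} \qquad (2)$$
where $l$ are the logits, $\tau$ is the temperature, $v = -\log(-\log(u))$, and $u$ are uniform samples from $\mathcal{U}(0, 1)$. During the forward pass, $i = \operatorname{argmax}_j p_j$, and in the backward pass, the true gradient of the Gumbel-Softmax outputs is used.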
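A single Gumbel-Softmax draw can be sketched in plain Python as below. This is only an illustration of the sampling step: the function name, the clamping of the uniform sample, and the toy logits are assumptions, and the straight-through gradient itself is not shown because it requires an automatic differentiation framework.

```python
import math
import random


def gumbel_softmax(logits, tau, rng):
    """Return (probabilities, hard index) for one Gumbel-Softmax draw."""
    noise = []
    for _ in logits:
        # u ~ U(0, 1); clamp away from 0 so that log(u) is defined.
        u = max(rng.random(), 1e-12)
        noise.append(-math.log(-math.log(u)))
    scores = [(l + v) / tau for l, v in zip(logits, noise)]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Forward pass: hard selection of the most probable entry.
    index = max(range(len(probs)), key=probs.__getitem__)
    return probs, index


rng = random.Random(0)
probs, index = gumbel_softmax([1.0, 2.0, 3.0], tau=0.5, rng=rng)
```

In a framework with automatic differentiation, the hard index would be used in the forward pass while gradients flow through `probs`.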
3.2 K-Means
The vector quantization approach of van den Oord et al. (2017) is an alternative to making the index selection procedure fully differentiable. In contrast to their setup, we optimize a future time step prediction loss instead of the reconstruction loss of an autoencoder.
We choose the codebook variable representation by finding the closest variable to the input features $z$ in terms of the Euclidean distance, yielding $i = \operatorname{argmin}_j \| z - e_j \|_2^2$. During the forward pass, we select $\hat{z} = e_i$ by choosing the corresponding variable from the codebook. We obtain gradients for the encoder network by back-propagating $d\mathcal{L}^{\text{wav2vec}}/d\hat{z}$ (van den Oord et al., 2017). The final loss has two additional terms:
$$\mathcal{L} = \sum_{k=1}^{K} \mathcal{L}_k^{\text{wav2vec}} + \Big( \| \mathrm{sg}(z) - \hat{z} \|^2 + \gamma \| z - \mathrm{sg}(\hat{z}) \|^2 \Big) \qquad (3)$$
where $\mathrm{sg}(x) \equiv x$, $\frac{d}{dx} \mathrm{sg}(x) \equiv 0$ is the stop gradient operator and $\gamma$ is a hyperparameter. The first term is the future prediction task and gradients do not change the codebook because of the straight-through gradient estimation of mapping $z$ to $\hat{z}$. The second term moves the codebook vectors closer to the encoder output, and the third term makes sure that the encoder outputs are close to a centroid (codeword).
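The nearest-codeword selection and the values of the two auxiliary terms can be sketched as follows. This is a plain-Python illustration with hypothetical function names and a toy codebook; the weight passed as `gamma` is an illustrative value, and gradient routing through the stop-gradient operator is only described in comments since no autodiff is used here.

```python
def sq_dist(a, b):
    # Squared Euclidean distance between two vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z, codebook):
    """Pick the codebook entry closest to z in Euclidean distance."""
    index = min(range(len(codebook)), key=lambda j: sq_dist(z, codebook[j]))
    return index, codebook[index]


def auxiliary_terms(z, z_hat, gamma):
    """Values of the two extra loss terms for one time step.

    Both terms share the same distance; they differ only in where the
    stop-gradient sits. The first would update the codebook entry, the
    second (weighted by gamma) would update the encoder output.
    """
    d = sq_dist(z, z_hat)
    return d, gamma * d


codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
index, z_hat = quantize([0.9, 1.2], codebook)
codebook_term, commitment_term = auxiliary_terms([0.9, 1.2], z_hat, gamma=0.25)
```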
3.3 Vector Quantization with multiple variable groups
So far, we considered replacing the encoder feature vector $z$ by a single entry $e_i$ in the codebook. This is prone to mode collapse where only some of the codewords are actually used. Previously, this problem has been mitigated by workarounds such as re-initializing codewords or applying additional regularizers to the loss function (Caron et al., 2019). In the following, we describe another strategy where we independently quantize partitions of $z$, similar to product quantization (Jegou et al., 2011). This results in larger dictionaries and increased downstream performance (Appendix A).
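Quantizing partitions independently can be sketched as below. This is a hedged illustration: the function name, the use of one separate codebook per group, and the nearest-neighbour selection are assumptions made for the example, and the toy codebooks are made up.

```python
def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize_groups(z, codebooks):
    """Quantize each of G equal partitions of z with its own codebook.

    codebooks[g] is a list of V vectors of size len(z) // G.
    Returns the G selected indices and the concatenated reconstruction.
    """
    num_groups = len(codebooks)
    if len(z) % num_groups != 0:
        raise ValueError("feature size must be divisible by the number of groups")
    size = len(z) // num_groups
    indices, z_hat = [], []
    for g, codebook in enumerate(codebooks):
        part = z[g * size:(g + 1) * size]
        j = min(range(len(codebook)), key=lambda idx: sq_dist(part, codebook[idx]))
        indices.append(j)
        z_hat.extend(codebook[j])
    return indices, z_hat


codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],  # group 0, V = 2 entries
    [[0.0, 0.0], [5.0, 5.0]],  # group 1, V = 2 entries
]
indices, z_hat = quantize_groups([0.9, 0.8, 4.0, 6.0], codebooks)
# With G groups of V entries each, V ** G index combinations are possible.
num_combinations = len(codebooks[0]) ** len(codebooks)
```

The tuple of per-group indices acts as one discrete token, which is why the dictionary grows exponentially in the number of groups.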
BERT Pre-Training on Quantized Speech
Once we have trained a vq-wav2vec model, we can discretize audio data and make it applicable to algorithms that require discrete inputs. One possibility is to use the discretized training data and apply BERT pre-training where the task is to predict masked input tokens based on an encoding of the surrounding context (Devlin et al., 2018). Once the BERT model is trained, we can use it to build representations and feed them into an acoustic model to improve speech recognition. We follow recent advances in BERT training which only use the masked input token prediction (Liu et al., 2019).
Since each of the discretized tokens represents around 10 ms of audio, it is likely too easy to predict a single masked input token. We therefore change BERT training by masking spans of consecutive discretized speech tokens, similar to Joshi et al. (2019). To mask the input sequence, we randomly sample $p = 0.05$ of all tokens to be a starting index, without replacement, and mask $M = 10$ consecutive tokens from every sampled index; spans may overlap. This makes the masked token prediction harder and we show later that it improves accuracy over masking individual tokens (§6.5).
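The span masking procedure can be sketched as follows. The function name, the rounding of the number of start positions, and the clipping of spans at the end of the sequence are assumptions of this sketch, and the values used in the usage lines are illustrative.

```python
import random


def span_mask(num_tokens, p, span, rng):
    """Return a boolean mask in which spans of `span` tokens are masked.

    A fraction p of positions is sampled without replacement as span
    starts; spans may overlap and are clipped at the sequence end.
    """
    num_starts = int(round(p * num_tokens))
    starts = rng.sample(range(num_tokens), num_starts)
    mask = [False] * num_tokens
    for s in starts:
        for i in range(s, min(s + span, num_tokens)):
            mask[i] = True
    return mask


rng = random.Random(0)
mask = span_mask(num_tokens=1000, p=0.05, span=10, rng=rng)
masked_fraction = sum(mask) / len(mask)  # at most p * span because spans may overlap
single = span_mask(num_tokens=1000, p=0.05, span=1, rng=random.Random(0))
```

Setting `span=1` recovers masking of individual tokens, which is the setting the span-based variant is compared against.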
Experimental Setup
We generally pre-train vq-wav2vec and BERT on the full 960h of Librispeech (Panayotov et al., 2015); after vq-wav2vec training, the corpus is discretized to 345M tokens. Where indicated, we perform ablations on a clean 100h subset which is discretized to 39.9M tokens. We evaluate models on two benchmarks: TIMIT (Garofolo et al., 1993b) is a 5h dataset with phoneme labels and Wall Street Journal (WSJ; Garofolo et al. 1993a) is an 81h dataset for speech recognition. For TIMIT, we apply the standard evaluation protocol and consider 39 different phonemes. For WSJ, we train acoustic models directly on 31 graphemes, including the English alphabet, the apostrophe, the silence token and tokens for repeating characters.
5.2 vq-wav2vec
We adapt the fairseq implementation of wav2vec (Schneider et al., 2019; Ott et al., 2019) and use vq-wav2vec/wav2vec models with $34 \times 10^6$ parameters. The encoder has 8 layers with 512 channels each, kernel sizes (10,8,4,4,4,1,1,1) and strides (5,4,2,2,2,1,1,1), yielding a total stride of 160. Each layer contains a convolution, followed by dropout, group normalization with a single group (Wu & He, 2018) and a ReLU non-linearity. The aggregator is composed of 12 layers, with 512 channels, stride 1, and kernel sizes starting at 2 and increasing by 1 for every subsequent layer. The block structure is the same as for the encoder network, except we introduce skip connections between each subsequent block.
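The total stride quoted above can be checked from the kernel sizes and strides with a short calculation. The sketch below assumes plain (undilated) 1-D convolutions and ignores padding; the function name is hypothetical.

```python
def stride_and_receptive_field(kernels, strides):
    """Total stride and receptive field (in input steps) of stacked 1-D convolutions."""
    receptive_field, jump = 1, 1
    for k, s in zip(kernels, strides):
        # Each layer widens the receptive field by (k - 1) steps of the current spacing.
        receptive_field += (k - 1) * jump
        jump *= s
    return jump, receptive_field


sample_rate = 16000

# Encoder: operates on raw audio samples.
enc_stride, enc_rf = stride_and_receptive_field(
    (10, 8, 4, 4, 4, 1, 1, 1), (5, 4, 2, 2, 2, 1, 1, 1)
)
enc_stride_ms = 1000 * enc_stride / sample_rate  # 10.0 ms between encoder outputs
enc_rf_ms = 1000 * enc_rf / sample_rate          # about 29 ms of audio per output

# Aggregator: 12 layers, stride 1, kernel sizes 2, 3, ..., 13 (in encoder steps).
agg_kernels = tuple(range(2, 14))
agg_stride, agg_rf = stride_and_receptive_field(agg_kernels, (1,) * 12)
```

Under these assumptions the encoder sees 465 samples per output at a stride of 160 samples, which matches the roughly 30ms segments at a 10ms stride for 16kHz audio.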
We train with the wav2vec context prediction loss (Equation 1) for 400k updates, predicting $K = 8$ steps into the future and sampling 10 negatives from the same audio example. Training is warmed up for 500 steps where the learning rate is increased from $1 \times 10^{-7}$ to $5 \times 10^{-3}$, and then annealed to $1 \times 10^{-6}$ using a cosine schedule (Loshchilov & Hutter, 2016). The batch size is 10, and we crop a random section of 150k frames for each example (approximately 9.3 seconds for 16kHz sampling rate). All models are trained on 8 GPUs.
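A linear warm-up followed by cosine annealing can be sketched as below. The function name and the exact interpolation formula are assumptions, and the three learning-rate values passed in the usage lines should be read as illustrative arguments rather than a statement of the fairseq configuration.

```python
import math


def learning_rate(step, total_steps, warmup_steps, lr_start, lr_peak, lr_end):
    """Linear warm-up from lr_start to lr_peak, then cosine annealing to lr_end."""
    if step < warmup_steps:
        return lr_start + (lr_peak - lr_start) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return lr_end + 0.5 * (lr_peak - lr_end) * (1.0 + math.cos(math.pi * progress))


def lr_at(step):
    # 400k updates with 500 warm-up steps; learning-rate values are illustrative.
    return learning_rate(step, 400_000, 500, 1e-7, 5e-3, 1e-6)


# Duration of a 150k-sample crop at 16 kHz, in seconds.
crop_seconds = 150_000 / 16_000
```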
For ablations and experiments on the 100h Librispeech subset, we use a smaller model with kernels (10,8,4,4,4) and strides (5,4,2,2,2) in the encoder and seven convolutional layers with stride one and kernel size three in the aggregator. This model is trained for 40k updates.
Gumbel-Softmax Models. We use $G = 2$ groups and $V = 320$ latents per group and the linear layer projects the features produced by the encoder into $G \cdot V = 640$ logits. The Gumbel-Softmax produces a one-hot vector for each group $G$. The temperature $\tau$ is linearly annealed from 2 to 0.5 over the first 70% of updates and then kept constant at 0.5. This enables the model to learn which latents work best for each input before committing to a single latent. After training this model on 960h of Librispeech and quantizing the training dataset, we are left with 13.5k unique codeword combinations (out of $V^G$ = 102k possible codewords).
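The temperature schedule described above can be sketched as a small function. The function name and argument names are hypothetical; the group and latent counts in the last lines are assumed values chosen so that the number of combinations agrees with the 102k figure quoted above.

```python
def gumbel_temperature(step, total_steps, start=2.0, end=0.5, anneal_fraction=0.7):
    """Linearly anneal from `start` to `end` over the first part of training."""
    anneal_steps = anneal_fraction * total_steps
    if step >= anneal_steps:
        return end
    return start + (end - start) * step / anneal_steps


total_updates = 400_000
temperatures = [gumbel_temperature(s, total_updates) for s in (0, 140_000, 280_000, 400_000)]

# Assumed group and latent counts; V ** G index combinations are possible.
G, V = 2, 320
possible_codewords = V ** G
used_fraction = 13_500 / possible_codewords  # share of combinations that are actually used
```

A high temperature early in training keeps the selection close to uniform, while the low final temperature makes the soft probabilities approach a one-hot choice.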
k-means Models. We use $G = 2$ groups and $V = 320$ variables per group. vq-wav2vec on full Librispeech yields 23k unique codewords. Following van den Oord et al. (2017), we found $\gamma = 0.25$ to be a robust choice for balancing the VQ auxiliary loss.
5.3 BERT
BERT small. For ablations we use a smaller setup with model dimension 512, FFN size 2048, 8 attention heads and dropout 0.05. Models are trained for 250k updates with a batch size of 2 examples per GPU.
5.4 Acoustic Model
Results
We first evaluate on the WSJ speech recognition benchmark. We train a vq-wav2vec model on the unlabeled version of Librispeech, then discretize the same data with the resulting model to estimate a BERT model. Finally, we train a wav2letter acoustic model on WSJ by inputting either the BERT or vq-wav2vec representations instead of log-mel filterbanks. For vq-wav2vec we input the dense representations corresponding to the learned discrete units.
We compare to various results from the literature, including wav2vec (Schneider et al., 2019) and we consider three setups: performance without any language model (No LM), with an n-gram LM (4-gram LM) and with a character convolutional LM (Char ConvLM). We report the accuracy of wav2letter with log-mel filterbanks as input (Baseline) and wav2vec. For vq-wav2vec we first experiment with the Gumbel-Softmax, with and without a BERT base model (§5.3).
Table 1 shows that vq-wav2vec together with BERT training can achieve a new state of the art of 2.34 WER on nov92. Gains are largest when no language model is used which is the fastest setting. vq-wav2vec with Gumbel-Softmax uses only 13.5k distinct codewords to represent the audio signal and this limited set of codewords is not sufficient to outperform the baseline. However, it does enable training BERT models which require a relatively small vocabulary.
Next, we compare Gumbel-Softmax to k-means for vector quantization. For this experiment we use the faster-to-train BERT small configuration (§5.3). We also train a vq-wav2vec k-means model with a very large number of codewords (39.9M) to test whether a more expressive model can close the gap to wav2vec. Table 2 shows that Gumbel-Softmax and k-means clustering perform comparably: in the no language model setup without BERT, Gumbel-Softmax is more accurate than k-means, but these differences disappear with BERT. For the 4-gram LM setup, k-means is better, but those differences disappear again after BERT training. Finally, the large codeword model can substantially reduce the gap to the original wav2vec model.
6.2 TIMIT Phoneme Recognition
Next, we experiment on the much smaller TIMIT phoneme recognition task where we also pre-train vq-wav2vec on the full Librispeech corpus. Table 3 shows that vq-wav2vec and BERT achieve a new state of the art of 11.64 PER which corresponds to a 21% reduction in error over the previous best result of wav2vec.
6.3 Sequence to Sequence Modeling
So far we used vq-wav2vec to train BERT on discretized speech. However, once the audio is discretized we can also train a standard sequence to sequence model to perform speech recognition. In preliminary experiments, we trained an off-the-shelf Big Transformer (Vaswani et al., 2017; Ott et al., 2019) on the vq-wav2vec Gumbel-Softmax discretized Librispeech corpus and evaluated on the Librispeech dev/test sets; we use a 4k BPE output vocabulary (Sennrich et al., 2016). Table 4 shows that results are promising, even though they are not as good as the state of the art (Park et al., 2019) which depends on data augmentation that we do not use.
6.4 Accuracy vs. Bitrate
Next, we investigate how well vq-wav2vec can compress the audio data. Specifically, we train models with different numbers of groups and variables to vary the possible codebook size and measure accuracy on TIMIT phoneme recognition without BERT training.
We measure compression with the bitrate $r \cdot G \log_2 V$ at sampling rate $r = 100$ Hz and report the trade-off between bitrate and accuracy on our phoneme recognition task. We experiment with vq-wav2vec k-means and train models with 1, 2, 4, 8, 16 and 32 groups, using 40, 80, 160, …, 1280 variables, spanning a bitrate range from 0.53 kbit/s (G = 1, V = 40) to 33.03 kbit/s (G = 32, V = 1280). We place the quantization module after the aggregator module and train all models in the small vq-wav2vec setup (§5.2) on the 100h clean Librispeech subset.
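The two bitrate endpoints quoted above can be reproduced with a short calculation. The sketch assumes that each time step emits one index per group, that each index costs log2(V) bits, and that tokens are produced at 100 Hz (one per 10 ms stride); the function name is hypothetical.

```python
import math


def bitrate_kbit_per_s(groups, variables, rate_hz=100.0):
    """Bitrate when each time step emits `groups` indices of log2(variables) bits."""
    return rate_hz * groups * math.log2(variables) / 1000.0


lowest = round(bitrate_kbit_per_s(1, 40), 2)      # G = 1, V = 40
highest = round(bitrate_kbit_per_s(32, 1280), 2)  # G = 32, V = 1280
print(lowest, highest)  # 0.53 33.03
```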
As baselines, we consider various lossy compression algorithms applied to the TIMIT audio data and train wav2letter models on the resulting audio: Codec2 (https://github.com/drowe67/codec2) as a low bitrate codec, Opus (Terriberry & Vos, 2012) as a medium bitrate codec and MP3 and Ogg Vorbis (Montgomery, 2004) as high bitrate codecs. We use the whole spectrum of both variable and constant bitrate settings of the codecs; we encode and decode with ffmpeg (ffmpeg developers, 2016). Figure 3 shows the trade-off between the bitrate and TIMIT accuracy. Acoustic models on vq-wav2vec achieve the best results across most bitrate settings.
6.5 Ablations
Figure 4(a) shows that masking entire spans of tokens performs significantly better than individual tokens ($M = 1$). Furthermore, BERT training on discretized audio data is fairly robust to masking large parts of the input (Figure 4(b)).
Conclusion
vq-wav2vec is a self-supervised algorithm that quantizes unlabeled audio data, which makes it amenable to algorithms requiring discrete data. This approach improves the state of the art on the WSJ and TIMIT benchmarks by leveraging BERT pre-training. In future work, we plan to apply other algorithms requiring discrete inputs to audio data and to explore self-supervised pre-training algorithms which mask part of the continuous audio input. Another avenue for future work is to fine-tune the pre-trained model to output transcriptions instead of feeding the pre-trained features to a custom ASR model.
References
Appendix A Number of variables vs. Groups
We investigate the relationship between the number of variables $V$ and groups $G$. Table 6 shows that multiple groups are beneficial compared to a single group with a large number of variables. Table 7 shows that with a single group and many variables, only a small number of codewords survive.