BEST-RQ Walkthrough

by ryan | May 18, 2024

This post summarizes my work during the first four months as a PhD student, which was published at the IEEE ICASSP 2024 SASB Workshop as Open Implementation and Study of BEST-RQ for Speech Processing (link to paper here). Code here (this is a pull request and I'm currently working on cleaning it up).

Outline of the walkthrough:

  • What is BEST-RQ?
  • Why BEST-RQ?
  • The Walkthrough
  • Experiments and Results
  • Future Work and Concluding thoughts

What is BEST-RQ?

BEST-RQ is a recent model for automatic speech recognition (ASR), i.e. the task where a computer takes in audio and generates a transcription. It was developed by researchers from Google (2022) and is the model behind Google USM (2023).

Why BEST-RQ?

My PhD topic is about Efficient and Effective Self-Supervised Learning Models for Speech.

If this doesn’t make sense to you, that’s ok; just know my topic involves efficiency (i.e. the speed of these models, the time it takes to train them, the amount of data these models need, …).

So when you start a PhD, what do you do? I don’t know what others do, but I decided to use these smart people called my advisors and ask them for advice. They recommended looking at the recent research in this field and starting from there.

So after discussions with my advisors and going through around 100 papers (most of them I just looked at the title and abstract, so this didn’t take too too long), I thought Google’s BEST-RQ would be a good place to start.

Here’s why:

  • It looked fast/efficient
  • Had really good performance (i.e. the model produced accurate transcripts).
  • It looked simple (good for someone new to the field and starting a PhD, right?? It actually turned out to be a lot harder than what I thought)

I struggled quite a bit with getting this to work, and in my paper I couldn’t go into much detail. So, I thought this walkthrough could serve as an easy-to-read and more in-depth description of my work.

The Walkthrough

In this section I give an overview of the architecture of BEST-RQ along with how I implemented each part. For my experiments, I compare my implementation of BEST-RQ with wav2vec 2.0 (a widely used model invented by Meta researchers). After describing the implementation, I conclude with a bullet-point summary of my experiments and results.

Architecture and my implementation in SpeechBrain

Mel-Spectrograms:

  • Unlike wav2vec 2.0, which takes in the raw audio signal as input, BEST-RQ starts off with Mel-Spectrograms (for more info on Mel-Spectrograms see this post).
  • In SpeechBrain, this is pretty easy to do and can be done with the Fbank class.

Normalization:

  • In my implementation I use SpeechBrain’s InputNormalization class.
  • The input data is normalized. This is important to prevent the random projection from collapsing to only use a small set of the codebook (more on this in the Random Projection Quantizer Section below).
  • I use the global type, which calculates a moving mean and standard deviation over the whole dataset.
  • Now, writing this post and reflecting on it, I realize that this might not have been the best option (although I did get decent results), and I would like to try it with sentence as the norm_type.
    • The reason I think this is that when you transcribe data from a different dataset, chances are that the mean and standard deviation will be different, yet with global you will normalize with the mean and standard deviation of the training dataset.
  • After the input is converted to mel-spectrograms and normalized, this new normalized input will be sent down two different paths:
    • 1) the Random Projection Quantizer (which generates targets for the model)
    • 2) the mask and the model (the output of which is used to attempt to predict the targets of the quantizer)
    • The figure below shows these two paths and how they fit into the whole architecture:

  • We can define both the mel-spectrogram and normalization functions in a .yaml file, which will be used by SpeechBrain to create Python objects.
# in yaml file

# define variables
sample_rate: 16000
n_fft: 400
n_mels: 80

# for calculating mel-spectrogram
compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>
    
normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4
    
    
  • Then this is used in the train.py file in the following way.
# in train.py

feats = self.hparams.compute_features(wavs)
# get current epoch
current_epoch = self.hparams.epoch_counter.current
# normalize input
feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

Masking:

In BEST-RQ, the authors describe a ‘masking’ strategy where they randomly choose one percent of frames to be ‘starting’ frames for a mask; each chosen frame and the following three frames are then ‘masked’.

In the case of BEST-RQ, ‘to mask’ means to replace the selected frames with random noise from a normal distribution (mean of 0 and standard deviation of 0.1). Basically, the goal of the pre-training is to get the model to use the unmasked audio sections to reconstruct the masked sections. The idea is that if a model can reconstruct or ‘guess’ a masked section, then the model must have learned good representations of audio that can then be used for transcribing or other audio tasks.

Although conceptually simple to understand, the masking part was one of the hardest parts for me to code. The trouble for me was that the time dimension gets reduced 4x by the CNN layers. My biggest question was: if we randomly mask frames and then reduce the time dimension by 4x, there is a chance that a mask will fall between indices of the reduced output.

For example, if you randomly select frame 2 to be the start of a mask and then mask the next 3 frames, you will mask frames [2,3,4,5]. But then the dimensions get reduced by 4x, and we want to predict the targets of the masked frames (more on how these targets are created in the next section).

So do we predict for index 0 (originally consisting of frames [0,1,2,3]) or index 1 (originally consisting of frames [4,5,6,7])? We could even predict both!?

We asked the authors (#thanksauthorsforbeingsoresponsive), and they said they only predicted on indices that were completely masked. So I decided to code up my mask in the simplest way I could think of.

  • The starting frames of the mask are randomly chosen from the subset of indices that are divisible by 4 (i.e. 0, 4, 8, …).
  • The starting frames and the following three frames are used to replace those sections of the input with noise.
  • The starting frames are divided by 4 to get the indices of the output that need to be predicted.

This way, every index that the model predicts will be a fully masked section. And to make sure that everything behaves well we pad the spectrogram, if needed, to always have a length divisible by 4.

Here is the code for the padding in the train.py file. (For the full code for the mask see mask.py; a simplified sketch follows after the padding code.)

# Calculate the amount of padding needed to make the tensor divisible by 4
current_dim_size = feats.shape[dim_to_pad]
padding_needed = (4 - (current_dim_size % 4)) % 4  # Ensure positive padding

# Define the padding
padding = [0, 0, 0, 0, 0, 0]  # Initialize padding for all dimensions
padding[dim_to_pad * 2] = padding_needed  # Set padding for the chosen dimension

# add in padding to features and mask
feats = torch.nn.functional.pad(feats, padding)
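
And here is a minimal sketch of the masking strategy from the bullet points above. This is my own simplified illustration, not the actual code from mask.py (the function name get_mask_sketch and the exact selection logic are made up for this post):

# sketch of the masking idea (illustrative only, not mask.py)
import torch

def get_mask_sketch(num_frames, mask_prob=0.01, span=4):
    # candidate starting frames: only indices divisible by 4
    candidates = torch.arange(0, num_frames - span + 1, span)
    # pick roughly `mask_prob` of all frames as starting frames
    num_starts = max(1, int(num_frames * mask_prob))
    starts = candidates[torch.randperm(len(candidates))[:num_starts]]
    # each starting frame covers itself plus the next 3 frames
    mask_frames = (starts.unsqueeze(1) + torch.arange(span)).flatten()
    # indices of the 4x-downsampled output that the loss is computed on
    target_idx = starts // span
    return mask_frames, target_idx

# toy usage: replace the masked frames with noise (mean 0, std 0.1)
feats = torch.randn(8, 200, 80)  # (batch, time, n_mels), time divisible by 4
mask_frames, target_idx = get_mask_sketch(feats.shape[1])
feats[:, mask_frames, :] = torch.randn(len(mask_frames), feats.shape[-1]) * 0.1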

Random Projection Quantizer:

  • The Random Projection Quantizer is basically the main part of the original BEST-RQ paper.
  • Like the mask, the concept was pretty simple to understand, yet it was hard for me to implement. In the end, all my hard work that went into implementing the Random Projection Quantizer turned into only 24 lines of code, so I’ll just paste it at the end of this section.
  • The whole idea of the random projection quantizer is to create labels or targets for the model to try to predict. The input to the quantizer is the unmasked spectrogram stacked every 4 frames.
  • (Side note: I do this stacking with the PyTorch .view() function and then give that to the quantizer. This is found in the train.py file; see the small sketch at the end of this list.)
  • The stacking of 4 frames makes it so that the dimensions match (because the CNN layers reduce the time dimension by 4x). In other words, for each frame in the model’s output, the quantizer will have generated one target.
  • The random projection quantizer has two main parts, both of which are randomly initialized:
    • the projection matrix
    • and a codebook
  • The stacked spectrogram is multiplied by the projection matrix, and then there is a lookup in the codebook.
  • To do the codebook lookup, we take every frame of the projected spectrogram and find the row in the codebook that is closest to it. See the code below for the implementation.
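  • As a quick illustration of the frame stacking mentioned above, here is a small sketch (the shapes are illustrative and the code is mine, not the exact line from train.py):

# sketch of stacking every 4 frames before the quantizer
# feats: normalized mel-spectrogram of shape (batch, time, n_mels),
# already padded so that time is divisible by 4
B, T, n_mels = feats.shape
stacked = feats.view(B, T // 4, n_mels * 4)  # group 4 consecutive frames
targets = quantizer(stacked)                 # one codebook index per stacked frame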

💡 Tricky parts to remember

  • the projection matrix is initialized with the Xavier uniform distribution
  • the codebook is normalized
  • the projected features (x @ self.P) are normalized

Here is the code.

# quantizer.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm

class RandomProjectionQuantizer(nn.Module):

    def __init__(self, input_dim, cb_dim, cb_vocab):
        super().__init__()

        self.input_dim = input_dim
        self.cb_dim = cb_dim
        self.cb_vocab = cb_vocab

        # Section 3.1 "projection matrix A use Xavier initialization"
        P_init = torch.empty((input_dim, cb_dim))
        self.register_buffer("P", nn.init.xavier_uniform_(P_init))

        # normalize random matrix for codebook
        self.register_buffer("CB", F.normalize(torch.randn(cb_vocab, cb_dim)))

    def forward(self, x):
        x = F.normalize(x @ self.P, dim=2)
        return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
        
        # this line of code above is very condensed and thus confusing
        # basically this is to avoid doing for loops and find the closest
        # entry in the code book for each projected frame
        # this single line of code honestly merits its own tutorial
        # let me know if you'd like me to write one :)
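
That last line is worth unpacking. Here is my own step-by-step equivalent (the intermediate variable names are mine; the shapes assume x is the stacked spectrogram of shape (batch, time, input_dim)):

# step-by-step equivalent of forward() above (my own unpacking)
xp = F.normalize(x @ self.P, dim=2)            # project + normalize: (batch, time, cb_dim)
diff = self.CB.unsqueeze(1) - xp.unsqueeze(1)  # broadcasts to (batch, cb_vocab, time, cb_dim)
dists = vector_norm(diff, dim=-1)              # distance to every codebook row: (batch, cb_vocab, time)
targets = dists.argmin(dim=1)                  # index of the closest row: (batch, time)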

CNN and Conformer Layers:

  • BEST-RQ then takes the masked spectrogram and puts it through a series of 2 CNN layers and 24 Conformer layers (for my experiments I only use 12 Conformer layers in order to be able to run preliminary experiments faster).
  • CNN and Conformer layers are already implemented in SpeechBrain, so I use them with the following code, following the same pattern as before: define in the .yaml file and then use in train.py.
# Transformer parameters
d_model: 576
nhead: 8
num_encoder_layers: 12
num_decoder_layers: 0
d_ffn: 2048
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 5000
encoder_layerdrop: 0.05

CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (128, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
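    # input_size of 640: I believe this comes from flattening the CNN output
    # (32 channels x 20 frequency bins, after the 80 mels are downsampled twice by stride 2)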
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    d_ffn: !ref <d_ffn>
    dropout: !ref <transformer_dropout>
    activation: !ref <activation>
    encoder_module: conformer
    attention_type: RelPosMHAXL
    normalize_before: True
    causal: False
    layerdrop_prob: !ref <encoder_layerdrop>
    
# I use the following wrapper so the decoder isn't run 
# This is because by default the TransformerASR will try to run a decoder
# but we don't have any decoder layers
wrapper: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper
   transformer: !ref <Transformer>

# in train.py

# convolutions
src = self.modules.CNN(feats)
# conformer layers (use wrapper so that decoder isn't used)
enc_out = self.modules.wrapper(src, wav_lens)

Loss:

  • The output of the last Conformer layer is then put through a linear layer that projects the output of each time step of the Conformer into the same dimension as the size of the codebook.
  • This is then put through a cross entropy loss function and voila we have finished with all the main parts of the architecture!
    • For those who aren’t familiar with cross entropy, basically it takes a vector (in our case the output of the linear layer), turns it into a probability distribution, and then compares it to the target.
    • The idea is that we want the model to produce vectors that give the highest probability to the index that is the same as the target.
      • Ex. suppose the model outputs
        • output of model / logits = [2.0, 1.0, 0.1]
      • PyTorch’s cross entropy performs a softmax operation, turning this into a probability distribution (i.e. all the items in the vector sum to 1)
        • probabilities = [0.6590, 0.2424, 0.0986]
      • Now suppose the random projection quantizer produced a target of 0 (i.e. the 0th index is the correct target). This would result in a target vector of [1.0, 0.0, 0.0], where the 0th index is 1 (meaning correct target) and the rest are 0 (meaning incorrect target).
      • Then you take the negative log-likelihood of the target index.
        • -log(0.6590) = 0.4170
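
To make the example above concrete, here is a tiny PyTorch snippet that reproduces those numbers:

# worked cross-entropy example (illustrative)
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])  # model output for one masked frame
target = torch.tensor([0])                # codebook index from the quantizer

probs = F.softmax(logits, dim=-1)         # tensor([[0.6590, 0.2424, 0.0986]])
loss = F.cross_entropy(logits, target)    # tensor(0.4170), i.e. -log(0.6590)

And here is the corresponding code in the .yaml and train.py files.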
linear: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <cb_vocab>
# in compute_forward function

	# linear layer to get logits
	logits = self.modules.linear(enc_out)
	
	# get starting indices of masked area
	mask_idx = mask[::divis_by] // divis_by
	
	# get logits of masked area
	logits = logits[:,mask_idx,:]
	# get targets of masked area
	targets = targets[:,mask_idx]
	B, T, C = logits.shape
	
	# reshape, flattening out the batch dimension
	# this makes it easier to use the loss function
	# then we return these two values (the logits and targets)
	# to be passed on to the compute_objectives function
	return logits.view(B * T, C), targets.view(B*T)

# in compute_objectives function
	pred, targets = predictions      
	loss = F.cross_entropy(pred, targets)	
	return loss

And voila! Those are all the main components of BEST-RQ.

Experiments and Results

The experiments are described in more detail in the paper, but I give a list below to summarize the experiments and results.

  • Models
    • My implementation of BEST-RQ
    • wav2vec 2.0 (SpeechBrain Implementation)
  • Pre-training Setting (data, GPUs, and epochs)
    • Data: Librispeech
    • GPUs: 8 x 32GB V100
    • 42 epochs (or roughly 200k steps)
    • Batch Size: ~13 minutes of audio
  • Preliminary Experiments
    • experimented with varying the mask percentage
    • GPUs: 4 x 11GB 2080Ti
    • 18 epochs or about 87k steps
  • MP3S Benchmark
    • A series of downstream tasks for benchmarking where the pre-trained model is frozen. The tasks include the following:
      • ASR (done with LibriSpeech dataset)
        • goal: accurate transcriptions
      • Speaker Verification (Voxceleb dataset)
        • goal: given two audio files, predict if the speaker is the same
      • Intent Classification (SLURP dataset)
        • goal: given an utterance predict intent (i.e. create appointment, play music, …)
      • Emotion Recognition (IEMOCAP)
        • goal: given an utterance predict emotion (angry, sad, …)
  • Finetuning
    • ASR on LibriSpeech
  • Main Results
    • Masking percentage and learning rate were very important
    • Codebook size didn’t matter too much
      • This could matter more when dealing with a lot more data or more diverse data (e.g. data with background noise or multiple languages)
    • BEST-RQ performs similarly to wav2vec 2.0 yet trains 2.4x faster

Future Work and Concluding thoughts

Although conceptually simple, BEST-RQ has a lot of tricky details that, if not implemented properly, will make it so that the model doesn’t perform well. I hope this post can help clarify some of the details.

In future work, I would like to scale up the size of my BEST-RQ implementation and experiment with altering the architecture and with different datasets.
