Whisper is an automatic speech recognition (ASR) model trained on 680,000 hours of supervised data collected from the web, spanning a variety of languages and tasks. One of its limitations is poor performance on low-resource languages such as Marathi and the Dravidian languages, which can be remedied with fine-tuning. However, fine-tuning a Whisper model is a considerable challenge in terms of both computational resources and storage requirements. Five to ten full fine-tuning runs of a Whisper model require approximately 100 hours of A100 GPU time (40 GB SXM4), depending on model size and parameters, and each fine-tuned checkpoint requires about 7 GB of storage. This combination of high computational and storage demands can pose significant obstacles, especially in resource-constrained environments, and often makes it exceptionally difficult to achieve meaningful results.
Low-rank adaptation, also known as LoRA, takes a unique approach to fine-tuning. It keeps the weights of the pre-trained model frozen and injects trainable rank decomposition matrices into each layer of the Transformer architecture. Compared with fine-tuning GPT-3 175B with Adam, this method reduces the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times. In terms of model quality, LoRA has been shown to match or even exceed the performance of traditional fine-tuning, despite training far fewer parameters (see the results in the original LoRA paper). It also offers higher training throughput. Unlike adapter methods, LoRA does not introduce additional inference latency, so the model stays efficient during deployment. Fine-tuning Whisper using LoRA has shown promising results. Take Whisper-Large-v2 as an example: running 3 epochs on a 12-hour speech dataset on a GPU with 8 GB of memory takes 6 to 8 hours, which is 5 times faster than full fine-tuning, with comparable performance.
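To make the mechanism concrete, here is a minimal NumPy sketch of the LoRA update for a single weight matrix. NumPy, the rank r=32, the scaling alpha=64, and the 1280 hidden size are illustrative assumptions rather than values from this post. The frozen weight W0 is augmented with the product of two small trainable matrices B and A; because B starts at zero, training begins exactly from the pre-trained model, and after training the update can be merged into W0, which is why LoRA adds no inference latency.

```python
import numpy as np

def lora_forward(x, W0, A, B, alpha):
    """h = x W0^T + (alpha / r) * x A^T B^T, with W0 frozen and only A, B trained."""
    r = A.shape[0]
    return x @ W0.T + (alpha / r) * (x @ A.T) @ B.T

def merge(W0, A, B, alpha):
    """Fold the low-rank update into one dense matrix: W = W0 + (alpha / r) * B A."""
    return W0 + (alpha / A.shape[0]) * (B @ A)

def trainable_params(d_out, d_in, r):
    """Trainable parameters for one d_out x d_in matrix: full fine-tuning vs. LoRA."""
    return d_out * d_in, r * (d_in + d_out)

rng = np.random.default_rng(0)
d_in = d_out = 1280          # hidden size of Whisper large
r, alpha = 32, 64            # illustrative rank and scaling (assumed values)
W0 = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((r, d_in)) * 0.01   # small random init
B = np.zeros((d_out, r))                    # zero init: the adapter starts as a no-op
x = rng.standard_normal((4, d_in))

h = lora_forward(x, W0, A, B, alpha)
full, lora = trainable_params(d_out, d_in, r)
print(full, lora, full / lora)   # 1638400 81920 20.0
```

For one 1280 x 1280 projection at rank 32, LoRA trains 20 times fewer parameters; the far larger reductions reported in the paper come from much bigger models, where the hidden size dwarfs r and only some weight matrices are adapted.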
Amazon SageMaker is an ideal platform for implementing LoRA fine-tuning of Whisper. SageMaker lets you build, train, and deploy machine learning models for any use case with fully managed infrastructure, tools, and workflows. Additional training benefits include lower costs with Managed Spot Training, distributed training libraries for splitting models and training datasets across AWS GPU instances, and more. Trained models can be deployed for inference directly in SageMaker. In this post, we present a step-by-step guide to implementing LoRA fine-tuning in SageMaker. The source code associated with this implementation is available on GitHub.
Prepare the dataset for fine-tuning
We used the low-resource Marathi language for the fine-tuning task. Using the Hugging Face Datasets library, you can download the Common Voice dataset and split it into training and test sets. See the following code:
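A minimal sketch of this step, assuming the Common Voice 11.0 release on the Hugging Face Hub (mozilla-foundation/common_voice_11_0) with the Marathi configuration "mr"; the post does not state which release it used. The dataset is gated, so you must accept its terms on the Hub and authenticate (for example, with huggingface-cli login) before downloading, and depending on your datasets version, script-based datasets like this one may also require trust_remote_code=True. The helper names are hypothetical.

```python
# Assumed dataset ID and release; adjust to the Common Voice version you use.
DATASET_ID = "mozilla-foundation/common_voice_11_0"
LANGUAGE = "mr"  # Marathi
KEEP_COLUMNS = {"audio", "sentence"}  # the waveform and its transcript

def columns_to_drop(column_names):
    """Hypothetical helper: every column except the audio and its transcript."""
    return [c for c in column_names if c not in KEEP_COLUMNS]

def load_marathi_common_voice():
    # Imported lazily so columns_to_drop() works without the library installed.
    from datasets import DatasetDict, load_dataset

    common_voice = DatasetDict()
    # Merge the official train and validation splits for training; keep test held out.
    common_voice["train"] = load_dataset(DATASET_ID, LANGUAGE, split="train+validation")
    common_voice["test"] = load_dataset(DATASET_ID, LANGUAGE, split="test")
    # Drop speaker metadata (age, gender, votes, ...) that training does not need.
    return common_voice.remove_columns(columns_to_drop(common_voice["train"].column_names))
```

Calling load_marathi_common_voice() downloads the data, so it needs network access and Hub credentials.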
The Whisper speech recognition model requires audio inputs to be 16 kHz, mono, 16-bit signed integer WAV files. Because the Common Voice dataset is sampled at 48 kHz, you first need to downsample the audio files. Next, apply Whisper’s feature extractor to the audio to compute log-Mel spectrogram features, and apply Whisper’s tokenizer to each transcript to convert its text into a sequence of token IDs. See the following code:
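A minimal sketch of the preprocessing, following the pattern of the Hugging Face Whisper fine-tuning examples. Casting the audio column to Audio(sampling_rate=16000) makes datasets resample each clip from 48 kHz to 16 kHz when it is decoded; prepare_dataset then produces the log-Mel input_features and the token-ID labels. The function names, the Marathi language setting, and passing the feature extractor and tokenizer through fn_kwargs are our assumptions.

```python
def prepare_dataset(batch, feature_extractor, tokenizer):
    """Map one Common Voice example to Whisper model inputs and labels."""
    audio = batch["audio"]  # decoded by the datasets Audio feature: {"array", "sampling_rate", ...}
    # Log-Mel spectrogram features computed from the 16 kHz waveform.
    batch["input_features"] = feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # Token IDs of the transcript become the training labels.
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

def preprocess(common_voice, model_id="openai/whisper-large-v2", language="Marathi"):
    # Lazy imports: requires the datasets and transformers libraries.
    from datasets import Audio
    from transformers import WhisperFeatureExtractor, WhisperTokenizer

    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
    tokenizer = WhisperTokenizer.from_pretrained(model_id, language=language, task="transcribe")
    # Resample on the fly from 48 kHz to the 16 kHz that Whisper expects.
    common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
    return common_voice.map(
        prepare_dataset,
        fn_kwargs={"feature_extractor": feature_extractor, "tokenizer": tokenizer},
        remove_columns=common_voice.column_names["train"],  # keep only the new columns
    )
```

Keeping prepare_dataset free of global state makes it easy to test with stand-in objects and to reuse for the test split.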
After you have processed all the training samples, upload the processed data to Amazon S3. Then, when you use the processed training data in the fine-tuning stage, you can use FastFile mode to access the S3 files directly through a file system mount instead of copying them to the local disk:
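A minimal sketch of the upload, assuming the processed data is a datasets object saved with save_to_disk and uploaded with the SageMaker Python SDK’s S3Uploader. The bucket, prefix, and helper names are hypothetical. FastFile is selected here on the TrainingInput for the channel.

```python
def s3_uri(bucket, prefix):
    """Hypothetical helper: join a bucket name and key prefix into an s3:// URI."""
    return f"s3://{bucket}/{prefix.strip('/')}"

def upload_processed(dataset, local_dir, bucket, prefix):
    # Lazy imports so s3_uri() works without the SageMaker SDK installed.
    from sagemaker.inputs import TrainingInput
    from sagemaker.s3 import S3Uploader

    dataset.save_to_disk(local_dir)                              # Arrow files plus metadata
    uri = S3Uploader.upload(local_dir, s3_uri(bucket, prefix))   # returns the S3 URI
    # FastFile exposes the S3 objects as files and streams them on demand,
    # instead of copying the whole dataset to the training instance first.
    return TrainingInput(s3_data=uri, input_mode="FastFile")
```

Inside the training job, the channel appears as a local path (for a channel named train, in the SM_CHANNEL_TRAIN environment variable) that datasets.load_from_disk can read.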
Train the model
For the demo, we use Whisper-large-v2 as the pre-trained model (Whisper large-v3 is now available), which can be imported through the Hugging Face Transformers library. You can use 8-bit quantization to further improve training efficiency. 8-bit quantization reduces memory usage by rounding floating-point values to 8-bit integers. It is a model compression technique commonly used to save memory without sacrificing too much accuracy during inference.
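The rounding idea can be sketched in a few lines of NumPy with symmetric absmax quantization of a single tensor. This is an illustration only: bitsandbytes, which Transformers uses for load_in_8bit, implements LLM.int8(), a vector-wise scheme that additionally keeps outlier features in 16-bit.

```python
import numpy as np

def quantize_absmax_int8(w):
    """Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127]."""
    scale = 127.0 / np.max(np.abs(w))
    q = np.round(w * scale).astype(np.int8)  # 1 byte per value instead of 4
    return q, scale

def dequantize(q, scale):
    """Approximate reconstruction; the error per value is at most 0.5 / scale."""
    return q.astype(np.float32) / scale

w = np.array([0.5, -1.27, 0.003, 1.0], dtype=np.float32)
q, scale = quantize_absmax_int8(w)
w_hat = dequantize(q, scale)
print(q)  # [  50 -127    0  100]
```

The small value 0.003 rounds to zero, which shows where the accuracy loss of quantization comes from.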
To load the pre-trained model in 8-bit quantized format, we simply add the load_in_8bit=True argument when instantiating the model, as shown in the following code. This loads the model weights quantized to 8 bits, reducing the memory footprint.
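A minimal sketch of the loading step, assuming the transformers, accelerate, and bitsandbytes packages and a CUDA GPU. Recent Transformers releases deprecate the bare load_in_8bit argument in favor of quantization_config=BitsAndBytesConfig(load_in_8bit=True). The weight_memory_gb helper is a hypothetical back-of-the-envelope estimate using Whisper large’s roughly 1.55 billion parameters; it counts weights only, not activations, gradients, or optimizer state.

```python
MODEL_ID = "openai/whisper-large-v2"
WHISPER_LARGE_PARAMS = 1_550_000_000  # approximate parameter count of Whisper large

def weight_memory_gb(n_params, bytes_per_param):
    """Rough size of the weights alone (no activations, gradients, or optimizer state)."""
    return n_params * bytes_per_param / 1e9

def load_8bit_whisper(model_id=MODEL_ID):
    # Lazy import: requires transformers, accelerate, and bitsandbytes plus a CUDA GPU.
    from transformers import WhisperForConditionalGeneration

    return WhisperForConditionalGeneration.from_pretrained(
        model_id,
        load_in_8bit=True,   # replace linear layers with int8 versions while loading
        device_map="auto",   # let accelerate place the weights on the available GPU(s)
    )

print(weight_memory_gb(WHISPER_LARGE_PARAMS, 4))  # fp32 weights, about 6.2 GB
print(weight_memory_gb(WHISPER_LARGE_PARAMS, 1))  # int8 weights, about 1.55 GB
```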
We use the LoRA implementation in Hugging Face’s PEFT (peft) package. There are four steps to fine-tuning a model using LoRA:
- Create an instance of a base model (as we did in the last step).
- Create a configuration (LoraConfig) where LoRA-specific parameters are defined.
- Wrap the base model with get_peft_model() to get a trainable PeftModel.
- Train the PeftModel as you would the base model.
See the following code:
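A minimal sketch of steps 2 and 3 above, assuming a recent peft release (older releases name the preparation helper prepare_model_for_int8_training). The hyperparameter values are illustrative assumptions, not necessarily those used for the results in this post; q_proj and v_proj are the query and value projection modules in the Transformers Whisper attention layers.

```python
# Illustrative LoRA hyperparameters (assumed values, not necessarily this post's).
LORA_KWARGS = dict(
    r=32,                                 # rank of the update matrices
    lora_alpha=64,                        # the update is scaled by lora_alpha / r
    target_modules=["q_proj", "v_proj"],  # query and value projections in Whisper attention
    lora_dropout=0.05,
    bias="none",                          # keep all bias terms frozen
)

def wrap_with_lora(base_model):
    """Steps 2 and 3: build a LoraConfig and wrap the 8-bit base model from step 1."""
    # Lazy import: requires the peft package.
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    # Prepare the quantized model for training (for example, freeze base weights
    # and upcast layer norms) before attaching the adapters.
    base_model = prepare_model_for_kbit_training(base_model)
    peft_model = get_peft_model(base_model, LoraConfig(**LORA_KWARGS))
    peft_model.print_trainable_parameters()  # reports trainable vs. total parameters
    return peft_model  # step 4: train it like any other model
```

Adapting only the query and value projections keeps the adapter small; adding more target modules trades memory for capacity.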
To run a SageMaker training job, we bring our own Docker container. You can download the Docker image from GitHub; it packages ffmpeg4 and git-lfs together with the other Python requirements. For more information about adapting your own Docker container to work with SageMaker, see Adapting your own training container. You can then use the Hugging Face estimator to start a training job in SageMaker:
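A minimal sketch with the SageMaker Python SDK’s HuggingFace estimator, assuming a custom image already pushed to Amazon ECR (image_uri), an IAM execution role, and a FastFile TrainingInput for the processed data. The entry point train.py, the src directory, and the hyperparameter names are hypothetical placeholders that must match whatever your training script parses. The SDK documents transformers_version and py_version as required only when image_uri is not provided.

```python
# Hypothetical hyperparameters; SageMaker passes them to the script as command-line arguments.
HYPERPARAMETERS = {
    "model_name_or_path": "openai/whisper-large-v2",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "learning_rate": 1e-3,
}

def launch_training(role, image_uri, train_input):
    # Lazy import: requires the sagemaker package and AWS credentials.
    from sagemaker.huggingface import HuggingFace

    estimator = HuggingFace(
        entry_point="train.py",         # hypothetical training script
        source_dir="src",               # hypothetical directory containing it
        image_uri=image_uri,            # the custom container in Amazon ECR
        role=role,
        instance_type="ml.g5.2xlarge",  # single-GPU instance
        instance_count=1,
        hyperparameters=HYPERPARAMETERS,
        use_spot_instances=True,        # Managed Spot Training for lower cost
        max_run=24 * 3600,              # training time limit in seconds
        max_wait=36 * 3600,             # must be >= max_run when using spot
    )
    estimator.fit({"train": train_input}, wait=False)  # returns without blocking
    return estimator
```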
The LoRA implementation allowed us to run the Whisper large fine-tuning task on a single GPU instance (for example, ml.g5.2xlarge). In comparison, full fine-tuning of Whisper large requires multiple GPUs (for example, ml.p4d.24xlarge) and a much longer training time. More specifically, our experiment showed that full fine-tuning requires 24 times more GPU hours than the LoRA approach.
Evaluate model performance
To evaluate the performance of the fine-tuned Whisper model, we calculated the word error rate (WER) on a held-out test set. WER measures the difference between the predicted transcription and the reference transcription: it is the number of word substitutions, deletions, and insertions needed to turn the reference into the prediction, divided by the number of words in the reference. A lower WER indicates better performance. You can run the following script on the pre-trained model and the fine-tuned model and compare their WERs:
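As a minimal sketch of the metric itself, the following pure-Python functions compute corpus-level WER from word-level edit distance (the function names are our own). In practice you would more likely call an existing implementation such as evaluate.load("wer") from the Hugging Face Evaluate library, and apply the same text normalization to predictions and references before scoring.

```python
def edit_distance(ref_words, hyp_words):
    """Minimum substitutions + deletions + insertions turning ref_words into hyp_words."""
    prev = list(range(len(hyp_words) + 1))  # distances from an empty reference prefix
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            curr[j] = min(
                prev[j] + 1,                           # deletion
                curr[j - 1] + 1,                       # insertion
                prev[j - 1] + (ref_word != hyp_word),  # substitution (0 if words match)
            )
        prev = curr
    return prev[-1]

def corpus_wer(references, hypotheses):
    """Total word errors divided by total reference words across the test set."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(references, hypotheses))
    n_words = sum(len(r.split()) for r in references)
    return errors / n_words

# 1 substitution (sat -> sit) + 1 deletion (the) over 6 reference words
print(corpus_wer(["the cat sat on the mat"], ["the cat sit on mat"]))
```

Summing errors and words over the whole test set, rather than averaging per-utterance WERs, keeps short utterances from dominating the score.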
Conclusion
In this post, we demonstrated how to fine-tune Whisper, a next-generation speech recognition model. In particular, we used Hugging Face’s PEFT LoRA and enabled 8-bit quantization for efficient training. We also demonstrated how to run the training job in SageMaker.
Although this is an important first step, there are several ways to build on this work to further improve the Whisper model. Going forward, consider using SageMaker distributed training to scale training to a much larger dataset. This will allow the model to be trained on more varied and comprehensive data, improving accuracy. You can also optimize latency when serving the Whisper model to enable real-time speech recognition. Additionally, you could extend the work to handle longer audio transcriptions, which requires changes to the model architecture and training schemes.
Acknowledgments
The authors extend their gratitude to Paras Mehra, John Sol, and Evandro Franco for their valuable comments and review of this post.
About the authors
Jun Shi is a Senior Solutions Architect at Amazon Web Services (AWS). His current focus areas are AI/ML infrastructure and applications. He has over a decade of experience in the FinTech industry as a software engineer.
Dr. Changsha Ma is an AI/ML specialist at AWS. She is a technologist with a PhD in Computer Science, a master’s degree in Educational Psychology, and years of experience in data science and independent AI/ML consulting. She is passionate about researching methodological approaches to human and machine intelligence. Outside of work, she loves hiking, cooking, hunting, and spending time with friends and family.