# Amazon SageMaker Real-Time Hosting with Whisper Transcription

This notebook shows how to use [SageMaker's real-time inference endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to host [OpenAI's Whisper](https://github.com/openai/whisper) model for real-time audio-to-text transcription. In this notebook you will:

1. Install the Whisper library
2. Load a Whisper model
3. Run inference locally on an example audio dataset
4. Serialize the Whisper model to S3
5. Create a SageMaker model
6. Deploy the SageMaker model to a real-time endpoint
7. Run inference on the SageMaker endpoint
8. Tear down the SageMaker endpoint

# Install Whisper and Import Libraries

The cells below will install the Python packages needed to use Whisper models and evaluate the transcription results.

In [None]:
%pip install openai-whisper --quiet
%pip install torchaudio --quiet

In [None]:
import os
import multiprocessing
import numpy as np
import torch
import pandas as pd
import whisper
import torchaudio
import sagemaker
import time
from tqdm.notebook import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Loading the LibriSpeech dataset

Once the libraries have been installed, we can use `torchaudio` to load a dataset which will provide example audio inputs for transcription.

In [None]:
class LibriSpeech(torch.utils.data.Dataset):
    """
    A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
    It will drop the last few seconds of a very small portion of the utterances.
    """
    def __init__(self, split="test-clean", device=DEVICE):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"),
            url=split,
            download=True,
        )
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000
        audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
        mel = whisper.log_mel_spectrogram(audio)
        
        return (mel, text)

In [None]:
dataset = LibriSpeech("test-clean")
loader = torch.utils.data.DataLoader(dataset, batch_size=16)
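
The wrapper above relies on `whisper.pad_or_trim` to force every utterance to a fixed 30-second window. At LibriSpeech's 16 kHz sample rate that is 480,000 samples. The following is a minimal, pure-Python sketch of that idea under those assumptions; `pad_or_trim_sketch` is a hypothetical helper for illustration, not the Whisper implementation.

```python
SAMPLE_RATE = 16_000   # LibriSpeech audio is sampled at 16 kHz
CHUNK_SECONDS = 30     # Whisper operates on 30-second windows
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS  # 480,000 samples per window


def pad_or_trim_sketch(samples, length=N_SAMPLES):
    """Return exactly `length` samples: cut the tail or append zeros."""
    if len(samples) >= length:
        # Longer clips lose everything past the window.
        return list(samples[:length])
    # Shorter clips are padded with silence (zeros) at the end.
    return list(samples) + [0.0] * (length - len(samples))
```

This is why the docstring notes that a small portion of utterances lose their last few seconds: anything past the 30-second window is discarded.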

# Run example inference locally using a base Whisper model

Now that the dataset has been created, we can download a Whisper model using the `whisper.load_model` function. This example uses the `base.en` model, but larger models are also available for download. Once the model is downloaded, you will use it to run an illustrative inference call locally in the notebook before deploying the model to a SageMaker endpoint.

In [None]:
model = whisper.load_model("base.en")
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

In [None]:
audio, sample_rate, text, _, _, _ = dataset.dataset[0]
audio = whisper.pad_or_trim(audio.flatten()).to(DEVICE)
mel = whisper.log_mel_spectrogram(audio)
options = whisper.DecodingOptions(language="en", without_timestamps=True, fp16=False)
out = model.decode(mel, options)
print(f'Example Transcription: \n{out.text}')
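
To put a number on how close a transcription is to the LibriSpeech reference `text`, one common metric is word error rate (WER): the word-level edit distance divided by the number of reference words. The sketch below is one simple way to compute it with the standard library. The function names and the normalization (lowercasing and stripping punctuation) are assumptions made for illustration, not something this notebook prescribes.

```python
import re


def _words(text):
    """Lowercase, drop punctuation (keeping apostrophes), and split on whitespace."""
    return re.sub(r"[^\w\s']", "", text.lower()).split()


def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = _words(reference), _words(hypothesis)
    # prev[j] holds the distance between the ref words seen so far and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / max(len(ref), 1)
```

With the variables from the cell above, a call such as `word_error_rate(text, out.text)` would compare the reference transcript with the model output.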

# SageMaker Inference

In this section, you will deploy the Whisper model from the previous section to a real-time API endpoint on Amazon SageMaker. Start by instantiating a SageMaker session and defining an Amazon S3 path for your model artifacts.

In [None]:
sess = sagemaker.session.Session()
bucket = sess.default_bucket()
prefix = 'whisper-demo-deploy/'
s3_uri = f's3://{bucket}/{prefix}'

## Create Model Artifacts in S3

You can now take the Whisper model that was loaded previously and save it using PyTorch. Make sure you save both the model state and the model dimensions so that the checkpoint is compatible with the `whisper` library.

In [None]:
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'dims': model.dims.__dict__,
    },
    'base.en.pt'
)

Once the model has been saved, you will package it into a `model.tar.gz` archive and upload it to Amazon S3. This archive is the model artifact that the endpoint references for real-time inference.

In [None]:
!mkdir -p model
!mv base.en.pt model
!cd model && tar -czvf model.tar.gz base.en.pt
!mv model/model.tar.gz .
!tar -tvf model.tar.gz
model_uri = sess.upload_data('model.tar.gz', bucket=bucket, key_prefix=f"{prefix}model")
!rm model.tar.gz
!rm -rf model
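
The shell commands above produce an archive with `base.en.pt` at its root. If you prefer to stay in Python, the same layout can be sketched with the standard-library `tarfile` module. `package_model` is a hypothetical helper name used only for this illustration.

```python
import os
import tarfile


def package_model(checkpoint_path, archive_path="model.tar.gz"):
    """Write the checkpoint to the root of a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname drops any leading directories so the file sits at the archive root.
        tar.add(checkpoint_path, arcname=os.path.basename(checkpoint_path))
    # Re-open the archive and list its members, similar to `tar -tvf`.
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()
```

Calling `package_model("base.en.pt")` would leave a `model.tar.gz` in the working directory that can be passed to `sess.upload_data` as in the cell above.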

## Create SageMaker Model Object

Once the model artifact has been uploaded to S3, you will use the SageMaker SDK to create a `model` object that references the model artifact in S3, one of SageMaker's PyTorch inference containers, and the inference code stored in the `src` directory of this repository. `inference.py` is the code executed at runtime, while `requirements.txt` tells SageMaker to install the `whisper` library inside its Docker container.

In [None]:
image = sagemaker.image_uris.retrieve(
    framework='pytorch',
    region=sess.boto_region_name,  # use the session's region instead of a hardcoded one
    image_scope='inference',
    version='1.12',
    instance_type='ml.g4dn.xlarge',
)

model_name = f'whisper-model-{int(time.time())}'
whisper_model_sm = sagemaker.model.Model(
    model_data=model_uri,
    image_uri=image,
    role=sagemaker.get_execution_role(),
    entry_point="inference.py",
    source_dir='src',
    name=model_name,
)
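
The `src/inference.py` script is not reproduced in this notebook. As a rough guide, the sketch below shows what such a script might look like, assuming the handler functions of SageMaker's PyTorch inference containers (`model_fn`, `input_fn`, `predict_fn`, `output_fn`), the `base.en.pt` checkpoint layout saved earlier, and the `application/x-npy` content type sent by `NumpySerializer`. It is an illustration under those assumptions, not the repository's actual code. Heavy dependencies are imported inside the functions so the sketch can be read and imported on its own.

```python
import io
import os

NPY_CONTENT_TYPE = "application/x-npy"  # content type sent by NumpySerializer


def model_fn(model_dir):
    """Rebuild the Whisper model from the saved state dict and dimensions."""
    import torch
    import whisper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(os.path.join(model_dir, "base.en.pt"), map_location=device)
    dims = whisper.model.ModelDimensions(**checkpoint["dims"])
    model = whisper.model.Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.to(device).eval()


def input_fn(request_body, request_content_type):
    """Decode the request payload into a NumPy array of audio samples."""
    if request_content_type != NPY_CONTENT_TYPE:
        raise ValueError(f"Unsupported content type: {request_content_type}")
    import numpy as np

    return np.load(io.BytesIO(request_body), allow_pickle=False)


def predict_fn(audio, model):
    """Run the same decode steps used for local inference in this notebook."""
    import torch
    import whisper

    samples = torch.from_numpy(audio).float().flatten()
    samples = whisper.pad_or_trim(samples).to(model.device)
    mel = whisper.log_mel_spectrogram(samples)
    options = whisper.DecodingOptions(language="en", without_timestamps=True, fp16=False)
    return model.decode(mel, options).text


def output_fn(prediction, accept):
    """Return the transcription text as the response body."""
    return prediction
```

Under this sketch, the accompanying `requirements.txt` would need to list `openai-whisper` so that the `whisper` imports resolve inside the container.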

## Deploy to a Real Time Endpoint

You can deploy the `model` object to SageMaker with its `deploy` method. Notice that you will be using an `ml.g4dn.xlarge` instance type to take advantage of AWS's low-cost GPU instances for accelerated inference.

In [None]:
endpoint_name = f'whisper-endpoint-{int(time.time())}'
whisper_model_sm.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name=endpoint_name,
    wait=True,
)

## Test Inference

Once the model has been deployed, you can connect to the endpoint using the `Predictor` class in the SageMaker SDK. The predictor's `predict` method can then transcribe the same audio signal used previously in this notebook. Notice how the results are consistent across the local execution and the API call.

In [None]:
whisper_endpoint = sagemaker.predictor.Predictor(endpoint_name)
whisper_endpoint.serializer = sagemaker.serializers.NumpySerializer()

assert whisper_endpoint.endpoint_context().properties['Status'] == 'InService'

In [None]:
inp = audio.cpu().numpy()  # move the tensor to the CPU first; .numpy() fails on CUDA tensors
out = whisper_endpoint.predict(inp)
print(f'Example Transcription: \n{out}')
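
To check programmatically that the local and endpoint transcriptions agree, it helps to normalize both before comparing, since the endpoint response may arrive as bytes and may differ in quoting or whitespace. The helper below is a hypothetical sketch using only the standard library. The exact response format depends on the inference script, so treat the normalization rules as assumptions.

```python
import string


def normalize_transcript(value):
    """Decode bytes if needed, lowercase, strip punctuation, collapse whitespace."""
    if isinstance(value, (bytes, bytearray)):
        text = value.decode("utf-8")
    else:
        text = str(value)
    # Remove ASCII punctuation, including any surrounding quotes.
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())
```

Because the cell above reuses the name `out`, you would need to keep the local result in its own variable first (for example a hypothetical `local_text = out.text` before calling the endpoint) and then compare `normalize_transcript(local_text)` with `normalize_transcript(out)`.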

## Sequential Latency Test

You can also run a latency test to see how fast the g4dn instance can process single-input requests. The first cell ensures the instance is warmed up, and the next cell times individual requests to the endpoint.

In [None]:
# warm up the instance
for i in range(10):
    out = whisper_endpoint.predict(inp)

In [None]:
%%timeit
out = whisper_endpoint.predict(inp)
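
`%%timeit` reports an average and spread. If you also want percentile latencies, a small timing loop is one option. The sketch below uses only the standard library; `measure_latency` and its default of 50 requests are assumptions for illustration.

```python
import math
import statistics
import time


def measure_latency(call, n_requests=50):
    """Time `call()` sequentially and summarize the latencies in milliseconds."""
    latencies_ms = []
    for _ in range(n_requests):
        start = time.perf_counter()
        call()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    latencies_ms.sort()
    # Nearest-rank 95th percentile on the sorted samples.
    p95_index = max(0, math.ceil(0.95 * len(latencies_ms)) - 1)
    return {
        "mean_ms": statistics.mean(latencies_ms),
        "p50_ms": statistics.median(latencies_ms),
        "p95_ms": latencies_ms[p95_index],
        "max_ms": latencies_ms[-1],
    }
```

For the endpoint in this notebook, a call such as `measure_latency(lambda: whisper_endpoint.predict(inp))` would time sequential requests, including network round trips from the notebook instance.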

## Optional: Clean Up Endpoint

Once you have finished testing your endpoint, you have the option to delete it. This is good practice: removing experimental endpoints when they are not in use reduces your SageMaker costs.

In [None]:
whisper_endpoint.delete_endpoint()