In [None]:
# Initialize sagemaker session and get the training data s3 uri
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface
import os

# Resolve the execution role, the SageMaker session and the S3 location of the
# pre-processed training data.
ROLE = sagemaker.get_execution_role()
sess = sagemaker.Session()
BUCKET = sess.default_bucket()
# BUCKET = "[BUCKET_NAME]"  # uncomment and set your bucket name if you are not using the default bucket
PREFIX = "whisper/data/zhtw-common-voice-processed"

# S3 URIs always use "/" as the separator. os.path.join uses the separator of
# the local OS (a backslash on Windows), so the URI is built explicitly instead.
s3uri = f"s3://{BUCKET}/{PREFIX}"

print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"data uri: {s3uri}")


In [None]:
# Training parameters.
#
# Alternative settings for distributed training (kept for reference):
#   distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}
#   instance_type = 'ml.p3.16xlarge'
#   training_batch_size  = 4
#   eval_batch_size = 2

# Active settings: single-instance training, so no distribution config is set.
distribution = None
instance_type = "ml.p3.2xlarge"
training_batch_size, eval_batch_size = 16, 8

In [None]:
from sagemaker.huggingface import HuggingFace

# Use the current Unix time (whole seconds) as a unique suffix for the
# training job name.
# NOTE(review): `id` shadows Python's built-in id(); the name is left unchanged
# because other cells may reference it -- confirm before renaming.
id = int(time.time())

TRAINING_JOB_NAME = "whisper-zhtw-" + str(id)
print('Training job name: ', TRAINING_JOB_NAME)

# Hyperparameters passed to the estimator. Batch sizes come from the
# training-parameters cell.
hyperparameters = dict(
    max_steps=1000,  # increase the max steps to improve model accuracy
    train_batch_size=training_batch_size,
    eval_batch_size=eval_batch_size,
    model_name="openai/whisper-small",
    language="Chinese",
    dataloader_num_workers=16,
)

# Metric definitions: each regex is applied to the training script's printed
# logs and the first captured group is sent to CloudWatch as the metric value.
#
# The shared number pattern is a raw string (so the backslash is not treated as
# a Python string escape) and matches an integer part, an optional fractional
# part and an optional exponent, e.g. "3", "0.25", "1e-05" or "1.5e-05".
# The outermost group is the full number; the inner groups are only there to
# make the fraction and the exponent optional.
_METRIC_VALUE = r"([0-9]+(\.[0-9]+)?(e[-+]?[0-9]+)?)"

metric_definitions = [
    {'Name': metric_name, 'Regex': f"'{metric_name}': {_METRIC_VALUE},?"}
    for metric_name in (
        'eval_loss',
        'eval_wer',
        'eval_runtime',
        'eval_samples_per_second',
        'epoch',
    )
]

In [None]:
# Point the training data to the s3 uri. Use FastFile to "mount" the s3 files directly instead of copying to local disk
from sagemaker.inputs import TrainingInput
training_input_path = s3uri

# Input channel for the training job, reading every object under the S3 prefix.
training = TrainingInput(
    s3_data=training_input_path,
    s3_data_type='S3Prefix',         # Available options: S3Prefix | ManifestFile | AugmentedManifestFile
    input_mode='FastFile',           # "mount" the S3 files directly instead of copying to local disk
    distribution='FullyReplicated',  # Available options: FullyReplicated | ShardedByS3Key
)

In [None]:
# Build the HuggingFace estimator and start training with `fit`.
# At the time of writing, the latest Hugging Face training image ships
# transformers_version='4.17.0' and pytorch_version='1.10.2'; the transformers
# version can be upgraded through requirements.txt.
# For details on the available training images, see
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md
OUTPUT_PATH = f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'

huggingface_estimator = HuggingFace(
    # Training code
    entry_point='train.py',
    source_dir='./scripts',
    # Framework versions of the training image
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    # Infrastructure
    role=ROLE,
    instance_type=instance_type,
    instance_count=1,
    volume_size=200,
    distribution=distribution,
    # Job configuration
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    output_path=OUTPUT_PATH,
)

# Start the training job; training takes approximately 2 hours to complete.
huggingface_estimator.fit({'train': training}, job_name=TRAINING_JOB_NAME)