\n",
" \n",
"This notebook shows how to implement a basic spam classifier for SMS messages using Apache MXNet as the deep learning framework.\n",
"We use the SMS Spam Collection dataset, available at https://archive.ics.uci.edu/ml/datasets/sms+spam+collection, to train and deploy a neural network model, leveraging the built-in open-source Apache MXNet container available in Amazon SageMaker."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get started by setting some configuration variables and retrieving the current execution role with the Amazon SageMaker Python SDK. Set `bucket_name` to the name of the S3 bucket that will hold the data, the model artifacts and the training code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import get_execution_role\n",
"\n",
"bucket_name = ''\n",
"\n",
"role = get_execution_role()\n",
"bucket_key_prefix = 'sms-spam-classifier'\n",
"vocabulary_length = 9013\n",
"\n",
"print(role)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now download the SMS Spam Collection dataset, unzip it and print the first 10 rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p dataset\n",
"!curl https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip -o dataset/smsspamcollection.zip\n",
"!unzip -o dataset/smsspamcollection.zip -d dataset\n",
"!head -10 dataset/SMSSpamCollection"
]
},
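If `curl` or `unzip` are not available (for example, outside a notebook instance), the same download-and-extract step can be done in pure Python. The sketch below uses only the standard library (`urllib.request`, `zipfile`); the helper names `download`, `extract` and `head` are hypothetical, and splitting each line on the first tab assumes the `label<TAB>message` layout that `head -10` prints.

```python
import urllib.request
import zipfile
from pathlib import Path

# Same archive URL as the curl command in the cell above.
DATASET_URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/"
               "00228/smsspamcollection.zip")


def download(url: str, dest: Path) -> Path:
    """Download url to dest, creating the parent directory (like mkdir -p + curl -o)."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return dest


def extract(zip_path: Path, out_dir: Path) -> list:
    """Extract every member of zip_path into out_dir (like unzip -o -d) and return the member names."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)
        return zf.namelist()


def head(path: Path, n: int = 10) -> list:
    """Return the first n lines of the file as (label, message) tuples.

    Assumes each line is 'label<TAB>message'; only the first tab is used as separator.
    """
    rows = []
    with open(path, encoding="utf-8", errors="replace") as f:
        for i, line in enumerate(f):
            if i >= n:
                break
            label, message = line.rstrip("\r\n").split("\t", 1)
            rows.append((label, message))
    return rows
```

For example, `extract(download(DATASET_URL, Path('dataset/smsspamcollection.zip')), Path('dataset'))` mirrors the `curl`/`unzip` commands, and `head(Path('dataset/SMSSpamCollection'))` returns the same ten rows as `head -10`.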
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now load the dataset into a pandas DataFrame and prepare the data.\n",
"More specifically, we have to:\n",
"\n",
"- replace the target column values (ham/spam) with numeric values (0/1);\n",
"- tokenize the SMS messages and one-hot encode them into vectors of `vocabulary_length` elements;\n",
"- split the data into training and validation sets;\n",
"- upload the resulting files to an S3 bucket for training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import pickle\n",
"from sms_spam_classifier_utilities import one_hot_encode\n",
"from sms_spam_classifier_utilities import vectorize_sequences\n",
"\n",
"df = pd.read_csv('dataset/SMSSpamCollection', sep='\\t', header=None)\n",
"df[df.columns[0]] = df[df.columns[0]].map({'ham': 0, 'spam': 1})\n",
"\n",
"targets = df[df.columns[0]].values\n",
"messages = df[df.columns[1]].values\n",
"\n",
"# one hot encoding for each SMS message\n",
"one_hot_data = one_hot_encode(messages, vocabulary_length)\n",
"encoded_messages = vectorize_sequences(one_hot_data, vocabulary_length)\n",
"\n",
"df2 = pd.DataFrame(encoded_messages)\n",
"df2.insert(0, 'spam', targets)\n",
"\n",
"# Split into training and validation sets (80%/20% split)\n",
"split_index = int(np.ceil(df.shape[0] * 0.8))\n",
"train_set = df2[:split_index]\n",
"val_set = df2[split_index:]\n",
"\n",
"train_set.to_csv('dataset/sms_train_set.gz', header=False, index=False, compression='gzip')\n",
"val_set.to_csv('dataset/sms_val_set.gz', header=False, index=False, compression='gzip')"
]
},
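The helpers `one_hot_encode` and `vectorize_sequences` come from `sms_spam_classifier_utilities`, whose source is not shown here. The sketch below is a hypothetical stand-in illustrating one common way to do this: hash each word to an index below `vocabulary_length` (the hashing trick), then turn each list of indices into a multi-hot vector of length `vocabulary_length`. The tokenizer, the MD5 hash and the function names are assumptions. A deterministic hash matters because Python's built-in `hash()` for strings is salted per process, so the same word could map to different indices at training and inference time.

```python
import hashlib
import re

import numpy as np


def tokenize(text: str) -> list:
    """Lower-case the text and split on anything that is not a letter, digit or apostrophe."""
    return [t for t in re.split(r"[^a-z0-9']+", text.lower()) if t]


def stable_index(word: str, vocabulary_length: int) -> int:
    """Map a word to an index in [1, vocabulary_length - 1]; index 0 is left unused.

    MD5 gives the same index in every process, unlike the salted built-in hash().
    """
    digest = hashlib.md5(word.encode("utf-8")).hexdigest()
    return int(digest, 16) % (vocabulary_length - 1) + 1


def hashed_one_hot_encode(messages, vocabulary_length: int) -> list:
    """Hypothetical stand-in for one_hot_encode: one list of word indices per message."""
    return [[stable_index(w, vocabulary_length) for w in tokenize(m)] for m in messages]


def multi_hot_vectorize(sequences, vocabulary_length: int) -> np.ndarray:
    """Hypothetical stand-in for vectorize_sequences: row i has 1.0 at each index in sequence i."""
    results = np.zeros((len(sequences), vocabulary_length), dtype=np.float32)
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results
```

Whatever the real utilities do, the same encoding and the same `vocabulary_length` must be applied at inference time, as the test-message cell further down does.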
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now upload the two files to Amazon S3 so that the Amazon SageMaker training cluster can access them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"s3 = boto3.resource('s3')\n",
"target_bucket = s3.Bucket(bucket_name)\n",
"\n",
"with open('dataset/sms_train_set.gz', 'rb') as data:\n",
" target_bucket.upload_fileobj(data, '{0}/train/sms_train_set.gz'.format(bucket_key_prefix))\n",
" \n",
"with open('dataset/sms_val_set.gz', 'rb') as data:\n",
" target_bucket.upload_fileobj(data, '{0}/val/sms_val_set.gz'.format(bucket_key_prefix))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training the model with MXNet\n",
"\n",
"We can now train the model using the built-in Amazon SageMaker container for MXNet. First, let's have a look at the script that defines our neural network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cat 'sms_spam_classifier_mxnet_script.py'"
]
},
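The script itself is not reproduced here, and the notebook does not show whether it uses SageMaker's *script mode* interface or the older legacy `train()` function interface. As a hedged illustration of script mode only: the estimator's `hyperparameters` arrive as command-line arguments (`--batch_size 100`, ...) and the channel and model directories are exposed as `SM_CHANNEL_<NAME>` and `SM_MODEL_DIR` environment variables. The parser below is a sketch under that assumption; the fallback paths are the standard `/opt/ml` locations inside the container.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and data locations the way a script-mode entry point might."""
    parser = argparse.ArgumentParser()
    # Hyperparameters set on the estimator are passed as --name value pairs.
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--learning_rate", type=float, default=0.01)
    # Model and channel directories come from SM_* environment variables.
    parser.add_argument("--model_dir",
                        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train",
                        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--val",
                        default=os.environ.get("SM_CHANNEL_VAL", "/opt/ml/input/data/val"))
    # parse_known_args ignores any extra arguments the platform may add.
    args, _ = parser.parse_known_args(argv)
    return args
```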
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now launch the training job with the `MXNet` estimator of the Amazon SageMaker Python SDK, passing the S3 locations of the training and validation sets as the `train` and `val` input channels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.mxnet import MXNet\n",
"\n",
"output_path = 's3://{0}/{1}/output'.format(bucket_name, bucket_key_prefix)\n",
"code_location = 's3://{0}/{1}/code'.format(bucket_name, bucket_key_prefix)\n",
"\n",
"m = MXNet('sms_spam_classifier_mxnet_script.py',\n",
" role=role,\n",
" train_instance_count=1,\n",
" train_instance_type='ml.c5.2xlarge',\n",
" output_path=output_path,\n",
" base_job_name='sms-spam-classifier-mxnet',\n",
" framework_version='1.2',\n",
" code_location=code_location,\n",
" hyperparameters={'batch_size': 100,\n",
" 'epochs': 20,\n",
" 'learning_rate': 0.01})\n",
"\n",
"inputs = {'train': 's3://{0}/{1}/train/'.format(bucket_name, bucket_key_prefix),\n",
" 'val': 's3://{0}/{1}/val/'.format(bucket_name, bucket_key_prefix)}\n",
"\n",
"m.fit(inputs)"
]
},
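The `batch_size` and `epochs` hyperparameters fix how many gradient updates the job performs: about `ceil(n_train / batch_size)` updates per epoch, times `epochs`. The sketch below does that arithmetic, together with the `ceil(n * 0.8)` training-set size used in the data-preparation cell. That the script keeps the last, partial batch of each epoch is an assumption.

```python
import math


def train_split_size(n_rows: int, fraction: float = 0.8) -> int:
    """Training-set size under the same rule as the data-preparation cell: ceil(n_rows * fraction)."""
    return int(math.ceil(n_rows * fraction))


def training_steps(n_examples: int, batch_size: int, epochs: int) -> int:
    """Total mini-batch updates, assuming each epoch visits every example once
    and the final partial batch is kept."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch * epochs
```

With the hyperparameters above (`batch_size=100`, `epochs=20`), a training set of 801 rows would give 9 updates per epoch and 180 in total.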
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**THE FOLLOWING STEPS ARE NOT MANDATORY IF YOU PLAN TO DEPLOY TO AWS LAMBDA; THEY ARE INCLUDED IN THIS NOTEBOOK FOR EDUCATIONAL PURPOSES.**\n",
"\n",
"Let's deploy the trained model to a real-time inference endpoint fully managed by Amazon SageMaker."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mxnet_pred = m.deploy(initial_instance_count=1,\n",
" instance_type='ml.m5.large')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Executing Inferences\n",
"\n",
"Now we can invoke the Amazon SageMaker real-time endpoint to run inferences: we send SMS messages and get back the predicted label (SPAM = 1, HAM = 0) and the related probability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.mxnet.model import MXNetPredictor\n",
"from sms_spam_classifier_utilities import one_hot_encode\n",
"from sms_spam_classifier_utilities import vectorize_sequences\n",
"\n",
"# Uncomment the following line to connect to an existing endpoint.\n",
"# mxnet_pred = MXNetPredictor('')\n",
"\n",
"test_messages = [\"FreeMsg: Txt: CALL to No: 86888 & claim your reward of 3 hours talk time to use from your phone now! ubscribe6GBP/ mnth inc 3hrs 16 stop?txtStop\"]\n",
"one_hot_test_messages = one_hot_encode(test_messages, vocabulary_length)\n",
"encoded_test_messages = vectorize_sequences(one_hot_test_messages, vocabulary_length)\n",
"\n",
"result = mxnet_pred.predict(encoded_test_messages)\n",
"print(result)"
]
},
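The endpoint returns both a label and a probability. How the script derives the label is not shown; a common choice, assumed here, is to threshold the spam probability at 0.5. The hypothetical helpers below make that rule explicit; lowering `threshold` flags more messages as spam at the cost of more false positives.

```python
def to_label(probability: float, threshold: float = 0.5) -> int:
    """Return 1 (SPAM) if the spam probability reaches the threshold, else 0 (HAM)."""
    if not 0.0 <= probability <= 1.0:
        raise ValueError("probability must be in [0, 1]")
    return int(probability >= threshold)


def describe(probabilities, threshold: float = 0.5) -> list:
    """Pair a readable label with each spam probability."""
    return [("SPAM" if to_label(p, threshold) else "HAM", p) for p in probabilities]
```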
{
"cell_type": "markdown",
"metadata": {},
"source": [
"