\n",
" \n",
"This notebook shows how to implement a basic spam classifier for SMS messages using the Amazon SageMaker built-in Linear Learner algorithm.\n",
"The idea is to use the SMS Spam Collection dataset, available at https://archive.ics.uci.edu/ml/datasets/sms+spam+collection, to train and deploy a binary classification model by leveraging the built-in Linear Learner algorithm available in Amazon SageMaker.\n",
"\n",
"Amazon SageMaker's Linear Learner algorithm extends typical linear models by training many models in parallel in a computationally efficient manner. Each model has a different set of hyperparameters, and the algorithm finds the set that optimizes a specific criterion. This can provide substantially more accurate models than typical linear algorithms at the same, or lower, cost."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get started by setting some configuration variables and getting the Amazon SageMaker session and the current execution role, using the Amazon SageMaker high-level SDK for Python."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import get_execution_role\n",
"\n",
"bucket_name = ''\n",
"\n",
"role = get_execution_role()\n",
"bucket_key_prefix = 'sms-spam-classifier'\n",
"vocabulary_length = 9013\n",
"\n",
"print(role)"
]
},
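{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `bucket_name` is left empty, one option is to fall back to the session's default Amazon SageMaker bucket. The cell below is an optional sketch of that fallback; it assumes you are fine with the auto-created default bucket rather than a specific one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"\n",
"# Fall back to the session default bucket if no bucket name was provided.\n",
"if not bucket_name:\n",
"    bucket_name = sagemaker.Session().default_bucket()\n",
"print(bucket_name)"
]
},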
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now download the spam collection dataset, unzip it and read the first 10 rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p dataset\n",
"!curl https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip -o dataset/smsspamcollection.zip\n",
"!unzip -o dataset/smsspamcollection.zip -d dataset\n",
"!head -10 dataset/SMSSpamCollection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now load the dataset into a Pandas dataframe and execute some data preparation.\n",
"More specifically we have to:\n",
"\n",
"- replace the target column values (ham/spam) with numeric values (0/1)\n",
"- tokenize the SMS messages and encode them based on word counts\n",
"- split into train and validation sets\n",
"- upload to an S3 bucket for training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import pickle\n",
"from sms_spam_classifier_utilities import one_hot_encode\n",
"from sms_spam_classifier_utilities import vectorize_sequences\n",
"\n",
"df = pd.read_csv('dataset/SMSSpamCollection', sep='\\t', header=None)\n",
"df[df.columns[0]] = df[df.columns[0]].map({'ham': 0, 'spam': 1})\n",
"\n",
"targets = df[df.columns[0]].values\n",
"messages = df[df.columns[1]].values\n",
"\n",
"# one hot encoding for each SMS message\n",
"one_hot_data = one_hot_encode(messages, vocabulary_length)\n",
"encoded_messages = vectorize_sequences(one_hot_data, vocabulary_length)\n",
"\n",
"df2 = pd.DataFrame(encoded_messages)\n",
"df2.insert(0, 'spam', targets)\n",
"\n",
"# Split into training and validation sets (80%/20% split)\n",
"split_index = int(np.ceil(df.shape[0] * 0.8))\n",
"train_set = df2[:split_index]\n",
"val_set = df2[split_index:]\n",
"\n",
"train_set.to_csv('dataset/sms_train_set.csv', header=False, index=False)\n",
"val_set.to_csv('dataset/sms_val_set.csv', header=False, index=False)"
]
},
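{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check before uploading, we can inspect the shape of each set and its class balance. This is just a quick verification sketch using the variables defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify the split sizes and the ham/spam distribution in each set.\n",
"print('Train set shape:', train_set.shape)\n",
"print('Validation set shape:', val_set.shape)\n",
"print('Train class counts:')\n",
"print(train_set['spam'].value_counts())\n",
"print('Validation class counts:')\n",
"print(val_set['spam'].value_counts())"
]
},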
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have to upload the two files to Amazon S3 so that they can be accessed by the Amazon SageMaker training cluster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"s3 = boto3.resource('s3')\n",
"target_bucket = s3.Bucket(bucket_name)\n",
"\n",
"with open('dataset/sms_train_set.csv', 'rb') as data:\n",
" target_bucket.upload_fileobj(data, '{0}/train/sms_train_set.csv'.format(bucket_key_prefix))\n",
" \n",
"with open('dataset/sms_val_set.csv', 'rb') as data:\n",
" target_bucket.upload_fileobj(data, '{0}/val/sms_val_set.csv'.format(bucket_key_prefix))"
]
},
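{
"cell_type": "markdown",
"metadata": {},
"source": [
"For later reference, the S3 URIs of the uploaded channels and an output path can be built from the bucket name and key prefix. This is a small convenience sketch; the variable names below are illustrative, and the URIs simply follow the s3://bucket/key convention used by the uploads above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the S3 URIs of the train and validation channels and an output path.\n",
"s3_train_data = 's3://{0}/{1}/train/'.format(bucket_name, bucket_key_prefix)\n",
"s3_val_data = 's3://{0}/{1}/val/'.format(bucket_name, bucket_key_prefix)\n",
"output_location = 's3://{0}/{1}/output'.format(bucket_name, bucket_key_prefix)\n",
"\n",
"print(s3_train_data)\n",
"print(s3_val_data)\n",
"print(output_location)"
]
},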
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training the model with Linear Learner\n",
"\n",
"We are now ready to run the training using the Amazon SageMaker Linear Learner built-in algorithm. First, let's get the Linear Learner container."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"container = get_image_uri(boto3.Session().region_name, 'linear-learner', repo_version=\"latest\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we'll configure the estimator, making sure to pass in the necessary hyperparameters. Notice:\n",
"\n",
"- feature_dim is set to the dimension of the vocabulary.\n",
"- predictor_type is set to 'binary_classifier' since we are trying to predict whether an SMS message is spam or not.