{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to Basic Functionality of NTM\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_**Finding Topics in Synthetic Document Data with the Neural Topic Model**_\n", "\n", "---\n", "\n", "---\n", "\n", "## Contents\n", "***\n", "\n", "1. [Introduction](#Introduction)\n", "1. [Setup](#Setup)\n", "1. [Data](#Data)\n", "1. [Train](#Train)\n", "1. [Host](#Host)\n", "1. [Extensions](#Extensions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "***\n", "\n", "Amazon SageMaker NTM (Neural Topic Model) is an unsupervised learning algorithm that attempts to describe a set of observations as a mixture of distinct categories. NTM is most commonly used to discover a user-specified number of topics shared by documents within a text corpus. Here each observation is a document, the features are the presence (or occurrence count) of each word, and the categories are the topics. Since the method is unsupervised, the topics are not specified up front, and are not guaranteed to align with how a human may naturally categorize documents. The topics are learned as a probability distribution over the words that occur in each document. Each document, in turn, is described as a mixture of topics.\n", "\n", "In this notebook we will use the Amazon SageMaker NTM algorithm to train a model on some example synthetic data. We will then use this model to classify (perform inference on) the data. The main goals of this notebook are to,\n", "\n", "* learn how to obtain and store data for use in Amazon SageMaker,\n", "* create an AWS SageMaker training job on a data set to produce a NTM model,\n", "* use the model to perform inference with an Amazon SageMaker endpoint." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "***\n", "\n", "_This notebook was tested in Amazon SageMaker Studio on a ml.t3.medium instance with Python 3 (Data Science) kernel._\n", "\n", "Let's start by specifying:\n", "\n", "- The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting.\n", "- The IAM role arn used to give training and hosting access to your data. See the documentation for how to create these. Note, if more than one role is required for notebook instances, training, and/or hosting, please replace the boto regexp with a the appropriate full IAM role arn string(s)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true, "tags": [ "parameters" ] }, "outputs": [], "source": [ "# Define IAM role\n", "import sagemaker\n", "import boto3\n", "import re\n", "from sagemaker import get_execution_role\n", "\n", "sess = sagemaker.Session()\n", "bucket = sess.default_bucket()\n", "prefix = \"ntm_demo\"\n", "role = get_execution_role()\n", "region = boto3.Session().region_name" ] }, { "cell_type": "markdown", "metadata": { "isConfigCell": true, "tags": [ "parameters" ] }, "source": [ "Next we download the synthetic data generation module if it does not exist." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true, "tags": [ "parameters" ] }, "outputs": [], "source": [ "from os import path\n", "\n", "data_generation_file = \"generate_example_data.py\" # Synthetic data generation module\n", "\n", "# Uncomment if you don't have the repository cloned with the generate_example_data.py file\n", "# if not path.exists(data_generation_file):\n", "# tools_bucket = f\"jumpstart-cache-prod-{region}\" # Bucket containing the data generation module.\n", "# tools_prefix = \"1p-algorithms-assets/ntm\" # Prefix for the data generation module\n", "# s3 = boto3.client(\"s3\")\n", "# s3.download_file(tools_bucket, f\"{tools_prefix}/{data_generation_file}\", data_generation_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we'll import the libraries we'll need throughout the remainder of the notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from generate_example_data import generate_griffiths_data, plot_topic_data\n", "import io\n", "import os\n", "import time\n", "import json\n", "import sys\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from IPython.display import display\n", "import scipy\n", "import sagemaker.amazon.common as smac\n", "from sagemaker.serializers import CSVSerializer\n", "from sagemaker.deserializers import JSONDeserializer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "***\n", "\n", "We generate some example synthetic document data. For the purposes of this notebook we will omit the details of this process. All we need to know is that each piece of data, commonly called a \"document\", is a vector of integers representing \"word counts\" within the document. In this particular example there are a total of 25 words in the \"vocabulary\"." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# generate the sample data\n", "num_documents = 5000\n", "num_topics = 5\n", "vocabulary_size = 25\n", "known_alpha, known_beta, documents, topic_mixtures = generate_griffiths_data(\n", " num_documents=num_documents, num_topics=num_topics, vocabulary_size=vocabulary_size\n", ")\n", "\n", "# separate the generated data into training and tests subsets\n", "num_documents_training = int(0.8 * num_documents)\n", "num_documents_test = num_documents - num_documents_training\n", "\n", "documents_training = documents[:num_documents_training]\n", "documents_test = documents[num_documents_training:]\n", "\n", "topic_mixtures_training = topic_mixtures[:num_documents_training]\n", "topic_mixtures_test = topic_mixtures[num_documents_training:]\n", "\n", "data_training = (documents_training, np.zeros(num_documents_training))\n", "data_test = (documents_test, np.zeros(num_documents_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspect Example Data\n", "\n", "*What does the example data actually look like?* Below we print an example document as well as its corresponding *known* topic mixture. Later, when we perform inference on the training data set we will compare the inferred topic mixture to this known one.\n", "\n", "As we can see, each document is a vector of word counts from the 25-word vocabulary" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"First training document = {documents[0]}\")\n", "print(f\"\\nVocabulary size = {vocabulary_size}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.set_printoptions(precision=4, suppress=True)\n", "\n", "print(f\"Known topic mixture of first training document = {topic_mixtures_training[0]}\")\n", "print(f\"\\nNumber of topics = {num_topics}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because we are visual creatures, let's try plotting the documents. In the below plots, each pixel of a document represents a word. The greyscale intensity is a measure of how frequently that word occurs. Below we plot the first tes documents of the training set reshaped into 5x5 pixel grids." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "fig = plot_topic_data(documents_training[:10], nrows=2, ncols=5, cmap=\"gray_r\", with_colorbar=False)\n", "fig.suptitle(\"Example Documents\")\n", "fig.set_dpi(160)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Store Data on S3\n", "\n", "A SageMaker training job needs access to training data stored in an S3 bucket. 
Although training can accept data in various formats, RecordIO-wrapped protobuf is the most performant.\n", "\n", "_Note, since NTM is an unsupervised learning algorithm, we simply put 0 in for all label values._" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# convert the training documents to RecordIO protobuf and upload them to S3\n", "buf = io.BytesIO()\n", "smac.write_numpy_to_dense_tensor(buf, data_training[0].astype(\"float32\"))\n", "buf.seek(0)\n", "\n", "key = \"ntm.data\"\n", "boto3.resource(\"s3\").Bucket(bucket).Object(os.path.join(prefix, \"train\", key)).upload_fileobj(buf)\n", "s3_train_data = f\"s3://{bucket}/{prefix}/train/{key}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "***\n", "\n", "Once the data is preprocessed and available in a recommended format, the next step is to train our model on the data. There are a number of parameters required by the NTM algorithm to configure the model and define the computational environment in which training will take place. The first of these is to point to a container image which holds the algorithm's training and hosting code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker import image_uris\n", "\n", "container = image_uris.retrieve(region=boto3.Session().region_name, framework=\"ntm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An NTM model uses the following hyperparameters:\n", "\n", "* **`num_topics`** - The number of topics or categories in the NTM model. This has been pre-defined in our synthetic data to be 5.\n", "\n", "* **`feature_dim`** - The size of the *\"vocabulary\"*, in topic modeling parlance. In this case, this has been set to 25 by `generate_griffiths_data()`.\n", "\n", "In addition to these NTM model hyperparameters, we also provide parameters defining things like the EC2 instance type on which training will run, the S3 bucket containing the data, and the AWS access role." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sess = sagemaker.Session()\n", "\n", "ntm = sagemaker.estimator.Estimator(\n", " container,\n", " role,\n", " instance_count=1,\n", " instance_type=\"ml.c4.xlarge\",\n", " output_path=f\"s3://{bucket}/{prefix}/output\",\n", " sagemaker_session=sess,\n", ")\n", "ntm.set_hyperparameters(num_topics=num_topics, feature_dim=vocabulary_size)\n", "\n", "ntm.fit({\"train\": s3_train_data})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n", "\n", "***\n", "\n", "A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means predicting the topic mixture that represents a given document.\n", "\n", "This is simplified by the `deploy` function provided by the Amazon SageMaker Python SDK." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ntm_predictor = ntm.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Perform Inference\n", "\n", "With this real-time endpoint at our fingertips we can finally perform inference on our training and test data. We should first discuss the meaning of the SageMaker NTM inference output.\n", "\n", "For each document we wish to compute its corresponding `topic_weights`. Each set of topic weights is a probability distribution over the number of topics, which is 5 in this example. 
Each element of the topic weights is the proportion to which the input document is represented by the corresponding one of the 5 topics discovered during NTM training.\n", "\n", "For example, if the topic weights of an input document $\mathbf{w}$ are,\n", "\n", "$$\theta = \left[ 0.3, 0.2, 0, 0.5, 0 \right]$$\n", "\n", "then $\mathbf{w}$ is 30% generated from Topic #0, 20% from Topic #1, and 50% from Topic #3 (indexing topics from zero, as in the analysis below). Below, we compute the topic mixtures for the first ten training documents.\n", "\n", "First, we set up our serializer and deserializer, which convert NumPy arrays to the CSV strings we pass in the HTTP POST request to our hosted endpoint." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ntm_predictor.serializer = CSVSerializer()\n", "ntm_predictor.deserializer = JSONDeserializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's check results for a small sample of records." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results = ntm_predictor.predict(documents_training[:10], initial_args={\"ContentType\": \"text/csv\"})\n", "print(results)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the output of the SageMaker NTM inference endpoint is a Python dictionary with the following format.\n", "\n", "```\n", "{\n", " 'predictions': [\n", " {'topic_weights': [ ... ] },\n", " {'topic_weights': [ ... ] },\n", " {'topic_weights': [ ... ] },\n", " ...\n", " ]\n", "}\n", "```\n", "\n", "We extract the topic weights themselves, corresponding to each of the input documents." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = np.array([prediction[\"topic_weights\"] for prediction in results[\"predictions\"]])\n", "\n", "print(predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you decide to compare these results to the known topic weights generated above, keep in mind that SageMaker NTM discovers topics in no particular order. That is, the approximate topic mixtures computed above may be (approximate) permutations of the known topic mixtures corresponding to the same documents." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(topic_mixtures_training[0]) # known topic mixture\n", "print(predictions[0]) # computed topic mixture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With that said, let's look at how our learned topic weights map to the known topic mixtures for the entire training set. Because NTM inherently creates a soft clustering (meaning that documents can belong partially to multiple topics), we'll evaluate the correlation of topic weights. This gives us a more relevant picture than just selecting, for each document, the single topic that happens to have the highest probability.\n", "\n", "To do this, we'll first need to generate predictions for all of our training data. Because our endpoint has a ~6MB per POST request limit, let's break the training data up into mini-batches and loop over them, creating a full dataset of predictions."
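] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before looping, we can sanity-check the payload size for a single batch of 1,000 documents. This is a rough, added sketch (real requests carry a little extra HTTP overhead), just to confirm that the batch size sits comfortably below the limit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Rough, illustrative size check: serialize 1000 documents to CSV and confirm\n", "# the payload is well under the ~6MB per-request limit mentioned above.\n", "sample_payload = CSVSerializer().serialize(documents_training[:1000])\n", "print(f\"CSV payload for 1000 documents: {len(sample_payload) / 1e6:.2f} MB\")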
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def predict_batches(data, rows=1000):\n", " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", " predictions = []\n", " for array in split_array:\n", " results = ntm_predictor.predict(array, initial_args={\"ContentType\": \"text/csv\"})\n", " predictions += [r[\"topic_weights\"] for r in results[\"predictions\"]]\n", " return np.array(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = predict_batches(documents_training)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll look at how the actual and predicted topics correlate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = pd.DataFrame(\n", " np.concatenate([topic_mixtures_training, predictions], axis=1),\n", " columns=[f\"actual_{i}\" for i in range(5)] + [f\"predictions_{i}\" for i in range(5)],\n", ")\n", "display(data.corr())\n", "pd.plotting.scatter_matrix(\n", " pd.DataFrame(np.concatenate([topic_mixtures_training, predictions], axis=1)), figsize=(12, 12)\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see:\n", "- The upper left quadrant of 5 * 5 cells illustrates that the data are synthetic as the correlations are all slightly negative, but too perfectly triangular to occur naturally.\n", "- The upper right quadrant, which tells us about our model fit, shows some similarities, with many correlations having very near triangular shape, and negative correlations of a similar magnitude.\n", " - Notice, actual topic #2 maps to predicted topic #2. Similarly actual topic #3 maps to predicted topic #3, and #4 to #4. However, there's a slight bit of uncertainty in topics #0 and #1. Actual topic #0 appears to map to predicted topic #1, but actual topic #1 also correlates most highly with predicted topic #1. This is not unexpected given that we're working with manufactured data and unsupervised algorithms. The important part is that NTM is picking up aggregate structure well and with increased tuning of hyperparameters may fit the data even more closely.\n", "\n", "_Note, specific results may differ due to randomized steps in the data generation and algorithm, but the general story should remain unchanged._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stop / Close the Endpoint\n", "\n", "Finally, we should delete the endpoint before we close the notebook.\n", "\n", "To restart the endpoint you can follow the code above using the same `endpoint_name` we created or you can navigate to the \"Endpoints\" tab in the SageMaker console, select the endpoint with the name stored in the variable `endpoint_name`, and select \"Delete\" from the \"Actions\" dropdown menu. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ntm_predictor.delete_model()\n", "ntm_predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extensions\n", "\n", "***\n", "\n", "This notebook was a basic introduction to the NTM . It was applied on a synthetic dataset merely to show how the algorithm functions and represents data. Obvious extensions would be to train the algorithm utilizing real data. We skipped the important step of qualitatively evaluating the outputs of NTM. Because it is an unsupervised model, we want our topics to make sense. 
There is a great deal of subjectivity involved in this, and whether or not NTM is more suitable than another topic modeling algorithm like Amazon SageMaker LDA will depend on your use case." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|ntm_synthetic|ntm_synthetic.ipynb)\n" ] } ], "metadata": { "celltoolbar": "Tags", "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 3.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }