# Create a COVID news classifier using Amazon Comprehend

> *This notebook has been tested with the `Python 3 (Data Science)` Kernel in SageMaker Studio.*

[Amazon Comprehend](https://aws.amazon.com/comprehend/) is a natural language processing (NLP) service that uses machine learning to analyze text.

In this example notebook, we'll demonstrate how you can build and use a custom text classifier using an example news dataset. You can read more about how this works in the [Custom Classification section](https://docs.aws.amazon.com/comprehend/latest/dg/how-document-classification.html) of the [Amazon Comprehend Developer Guide](https://docs.aws.amazon.com/comprehend/latest/dg/).

## Install and import libraries

In [None]:
# Python Built-Ins:
import itertools
import json
import sys
from time import sleep

# External Dependencies:
import boto3
from botocore.exceptions import WaiterError
from botocore.waiter import create_waiter_with_client, WaiterModel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sagemaker
from sklearn import metrics
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

pd.set_option("display.max_colwidth", None)

## Initialize Session and Get Bucket

In [None]:
sess = sagemaker.Session()
region = sess.boto_session.region_name
print(region)

Before we start using Amazon Comprehend, ensure that your SageMaker Studio Execution role has the permissions to access Amazon S3 and Amazon Comprehend. Read more about the execution role [here](https://docs.aws.amazon.com/sagemaker/latest/dg/security_iam_service-with-iam.html).

Refer to this [link](https://docs.aws.amazon.com/comprehend/latest/dg/access-control-managing-permissions.html) for more information regarding managing access to Amazon Comprehend.

In [None]:
# Get execution role. Search for the execution role in the IAM console to modify permissions.
role = sagemaker.get_execution_role()
print(role)

In [None]:
# Get the S3 Bucket associated with this SageMaker Studio session. We will use this bucket to store and retrieve data.
bucket = sess.default_bucket()
print(bucket)

## Load, examine and transform data

This notebook uses data from various social media platforms, manually labeled as either 'fake' or 'real' news. It was initially published in a [paper](https://link.springer.com/chapter/10.1007/978-3-030-73696-5_3) and used in this [competition](https://competitions.codalab.org/competitions/26655). 
#### Dataset Reference:
Patwa P. et al. (2021) Fighting an Infodemic: COVID-19 Fake News Dataset. In: Chakraborty T., Shu K., Bernard H.R., Liu H., Akhtar M.S. (eds) Combating Online Hostile Posts in Regional Languages during Emergency Situation. CONSTRAINT 2021. Communications in Computer and Information Science, vol 1402. Springer, Cham. https://doi.org/10.1007/978-3-030-73696-5_3

In [None]:
!git clone https://github.com/diptamath/covid_fake_news/

We will combine the source training and validation sets into a single dataset and let Amazon Comprehend hold out part of it for its own internal evaluation:

In [None]:
df_train_data = pd.read_csv("covid_fake_news/data/Constraint_Train.csv")
df_val_data = pd.read_csv("covid_fake_news/data/Constraint_Val.csv")
df_data = pd.concat([df_train_data,df_val_data], ignore_index=True)
df_data.head(5)

#### Examine the distribution of data

If the data is skewed, we will need to adjust the proportion of training data with 'real' and 'fake' labels.
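
One simple way to put a number on skew is the ratio between the largest and smallest class counts. The sketch below uses only the standard library and made-up labels; the helper name `class_balance` is hypothetical, and what ratio counts as "too skewed" is a judgment call rather than a fixed rule.

```python
from collections import Counter


def class_balance(labels):
    """Return per-class proportions and the majority/minority count ratio."""
    counts = Counter(labels)
    total = sum(counts.values())
    proportions = {name: n / total for name, n in counts.items()}
    imbalance_ratio = max(counts.values()) / min(counts.values())
    return proportions, imbalance_ratio


# Made-up labels for illustration only (not the real dataset counts)
example_labels = ["real"] * 6 + ["fake"] * 4
proportions, ratio = class_balance(example_labels)
print(proportions)
print(f"Imbalance ratio: {ratio:.2f}")
```

A ratio close to 1 means the classes are roughly balanced, in which case shuffling alone is a reasonable preparation step.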

In [None]:
df_data["label"].value_counts()

#### Shuffle data

The data is not highly skewed, so we will simply randomize it before loading it in for training.

In [None]:
df_data = df_data.sample(frac=1, random_state=1)
df_data.head(5)

#### Convert to the format that Amazon Comprehend expects for training

Read more about creating a classifier on Amazon Comprehend [here](https://docs.aws.amazon.com/comprehend/latest/dg/getting-started-console-classifier.html).

For training, Amazon Comprehend expects a CSV file that meets the following requirements:
1. Exactly two columns per line: the label first, then the document text.
2. No header row.
3. UTF-8 encoding, with "\n" as the line separator.

We will remove the indices and headers before we upload the file to S3.
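
Before uploading, it can help to confirm that a file really has this two-column shape. The following is a minimal sketch using only the standard library and made-up rows; `check_training_csv` is a hypothetical helper, not part of any AWS SDK, and it does not detect an accidental header row.

```python
import csv
import io


def check_training_csv(text):
    """Check that every row has exactly two non-empty fields: label, document."""
    rows = list(csv.reader(io.StringIO(text)))
    problems = []
    for line_no, row in enumerate(rows, start=1):
        if len(row) != 2:
            problems.append((line_no, f"expected 2 columns, found {len(row)}"))
        elif not row[0].strip() or not row[1].strip():
            problems.append((line_no, "empty label or document"))
    return problems


# Made-up rows; the second document contains a comma, so it must be quoted
good = 'real,Wash your hands often.\nfake,"Garlic cures it, experts say."\n'
bad = "real,Wash your hands often.\nlabel_only\n"

print(check_training_csv(good))
print(check_training_csv(bad))
```

Documents that contain commas stay in a single column only because the CSV writer quotes them, which `DataFrame.to_csv` does by default.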

In [None]:
df_out = df_data[["label", "tweet"]]
df_out.head(5)

#### Output data to a CSV file and upload to S3 bucket

In [None]:
! mkdir -p data/processed-data/
path = "data/processed-data/processed.csv"
prefix = "COVIDcomprehend"
# Output to CSV file
df_out.to_csv(path, index=False, header=False)
# Upload to S3 Bucket
processed_data_s3 = sess.upload_data(
    path=path,
    key_prefix=prefix,
    bucket=bucket 
)
print("Training data has been uploaded to S3.")
print("Location:", processed_data_s3)

## Train your classifier

#### Create IAM role and attach policies

To let Amazon Comprehend read from and write to your S3 bucket, create an IAM role in your account with a trust policy that allows Comprehend to assume it and a permissions policy that grants access to the bucket. Read more about it [here](https://docs.aws.amazon.com/comprehend/latest/dg/access-control-overview.html). To better understand how to manage S3 permissions, please refer to this [documentation](https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-access-control.html).

In [None]:
comprehend_role_name = "ComprehendRole"

# Instantiate Boto3 SDK
client_iam = boto3.client("iam")

# Allow Comprehend to assume this role
assume_role_policy_document = json.dumps({
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "comprehend.amazonaws.com",
            },
            "Action": "sts:AssumeRole",
        },
    ],
})

# Create IAM Role if it does not exist
try:
    client_iam.get_role(RoleName=comprehend_role_name)
    print("Role already exists.")
except client_iam.exceptions.NoSuchEntityException:
    response = client_iam.create_role(
        RoleName=comprehend_role_name,
        AssumeRolePolicyDocument=assume_role_policy_document,
    )
    print("Created a new role.")
    print("Comprehend Role ID:", response["Role"]["RoleId"])
    
comprehend_role_arn = client_iam.get_role(RoleName=comprehend_role_name)["Role"]["Arn"]
print("Comprehend Role ARN:", comprehend_role_arn)

# Attach permission policy to IAM role
policy_arn = "arn:aws:iam::aws:policy/AmazonS3FullAccess"

print("Attaching IAM role policy")
attach_response = client_iam.attach_role_policy(
    RoleName=comprehend_role_name,
    PolicyArn=policy_arn,
)
sleep(60)  # Give the new role and policy time to propagate before Comprehend uses them
print("IAM policy is attached.")
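
`AmazonS3FullAccess` grants access to every bucket in the account, which is broader than this example needs. As a hedged alternative, the sketch below builds a policy document scoped to one bucket prefix. The helper name `build_scoped_s3_policy` and the bucket name `example-bucket` are hypothetical, and whether these three actions are sufficient for your setup is an assumption to verify.

```python
import json


def build_scoped_s3_policy(bucket_name, key_prefix):
    """Build an IAM policy document limited to one bucket prefix."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Listing is granted on the bucket itself, limited to the prefix
                "Effect": "Allow",
                "Action": "s3:ListBucket",
                "Resource": f"arn:aws:s3:::{bucket_name}",
                "Condition": {"StringLike": {"s3:prefix": [f"{key_prefix}/*"]}},
            },
            {
                # Reading inputs and writing outputs is granted on objects only
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": f"arn:aws:s3:::{bucket_name}/{key_prefix}/*",
            },
        ],
    }


# "example-bucket" is a placeholder, not the bucket used in this notebook
policy = build_scoped_s3_policy("example-bucket", "COVIDcomprehend")
print(json.dumps(policy, indent=2))
```

A document like this could be attached as an inline policy with the IAM client's `put_role_policy` call instead of attaching the managed policy.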

#### Create Classifier and Commence Training
Read about training a classifier [here](https://docs.aws.amazon.com/comprehend/latest/dg/how-document-classification-training.html).

In [None]:
doc_classifier_name = "COVIDInfoClassifer"
version_name = "v0"

# Instantiate Boto3 SDK
client_comp = boto3.client("comprehend")

# Create a document classifier
create_response = client_comp.create_document_classifier(
    InputDataConfig={ "S3Uri": processed_data_s3 },
    DataAccessRoleArn=comprehend_role_arn,
    DocumentClassifierName=doc_classifier_name,
    VersionName=version_name,
    LanguageCode="en",
)
doc_classifier_arn = create_response["DocumentClassifierArn"]
print("Create response:", create_response["ResponseMetadata"]["HTTPStatusCode"])
print("Classifier ARN:", doc_classifier_arn)

In [None]:
describe_classifier = client_comp.describe_document_classifier(
    DocumentClassifierArn=doc_classifier_arn,
)
try:
    classifier_initial_status = describe_classifier["DocumentClassifierProperties"]["Status"]
    print("Describe classifier response:", classifier_initial_status)
except KeyError:
    print("Status error")
# Creating Waiter to manage waiting for training to complete
classifier_trained_waiter = create_waiter_with_client(
    waiter_name="ClassifierTrainedWaiter",
    waiter_model=WaiterModel({
        "version": 2,
        "waiters": {
            "ClassifierTrainedWaiter": {
                "operation": "DescribeDocumentClassifier",
                "delay": 30,
                "maxAttempts": 300,
                "acceptors": [
                    {
                        "matcher": "path",
                        "expected": "TRAINED",
                        "argument": "DocumentClassifierProperties.Status",
                        "state": "success",
                    },
                    {
                        "matcher": "path",
                        "expected": "SUBMITTED",
                        "argument": "DocumentClassifierProperties.Status",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "TRAINING",
                        "argument": "DocumentClassifierProperties.Status",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "IN_ERROR",
                        "argument": "DocumentClassifierProperties.Status",
                        "state": "failure",
                    },
                    {
                        "matcher": "path",
                        "expected": "STOPPED",
                        "argument": "DocumentClassifierProperties.Status",
                        "state": "failure",
                    },
                ],
            },
        },
    }),
    client=client_comp,
)
try:
    print("Waiting for training...")
    classifier_trained_waiter.wait(DocumentClassifierArn=doc_classifier_arn)
    print("TRAINING COMPLETE")
except WaiterError as e:
    print(e)
    raise e
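
Conceptually, the custom waiter above is a polling loop: call the describe operation, compare the status against success, retry, and failure states, and sleep between attempts. The sketch below shows that logic in plain Python with a stand-in status function, so it runs without AWS access. `wait_for_status` is a hypothetical helper and not a reimplementation of botocore's waiter.

```python
import time


def wait_for_status(get_status, success, failure, delay=30, max_attempts=300,
                    sleep=time.sleep):
    """Poll get_status() until it returns a success or failure state."""
    for attempt in range(1, max_attempts + 1):
        status = get_status()
        if status in success:
            return status
        if status in failure:
            raise RuntimeError(f"Reached failure state: {status}")
        if attempt < max_attempts:
            sleep(delay)
    raise TimeoutError(f"Gave up after {max_attempts} attempts")


# Stand-in for describe_document_classifier: replays a fixed status sequence
fake_statuses = iter(["SUBMITTED", "TRAINING", "TRAINING", "TRAINED"])
final = wait_for_status(
    get_status=lambda: next(fake_statuses),
    success={"TRAINED"},
    failure={"IN_ERROR", "STOPPED"},
    delay=0,
)
print(final)
```

With the settings used in the notebook (a 30-second delay and 300 attempts), such a loop would give up after roughly two and a half hours.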

Once your classifier has been trained, open the Amazon Comprehend console to inspect it.

Read more about model performance [here](https://docs.aws.amazon.com/comprehend/latest/dg/cer-doc-class.html).

## Deploy your classifier to an endpoint

In [None]:
endpoint_name = f"{doc_classifier_name}-{version_name}-endpoint"
# Create endpoint
create_endpoint_resp = client_comp.create_endpoint(
    EndpointName=endpoint_name,
    ModelArn=doc_classifier_arn,
    DesiredInferenceUnits=1
)
endpoint_arn = create_endpoint_resp["EndpointArn"]
print("Created endpoint.")
print("Endpoint ARN", endpoint_arn)

In [None]:
describe_endpoint = client_comp.describe_endpoint(EndpointArn=endpoint_arn)
try:
    endpoint_initial_status = describe_endpoint["EndpointProperties"]["Status"]
    print("Describe endpoint response:", endpoint_initial_status)
except KeyError:
    print("Status error")
# Creating Waiter to manage waiting for endpoint to create
endpoint_creation_waiter = create_waiter_with_client(
    waiter_name="EndpointCreationWaiter",
    waiter_model=WaiterModel({
        "version": 2,
        "waiters": {
            "EndpointCreationWaiter": {
                "operation": "DescribeEndpoint",
                "delay": 30,
                "maxAttempts": 100,
                "acceptors": [
                    {
                        "matcher": "path",
                        "expected": "IN_SERVICE",
                        "argument": "EndpointProperties.Status",
                        "state": "success",
                    },
                    {
                        "matcher": "path",
                        "expected": "UPDATING",
                        "argument": "EndpointProperties.Status",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "CREATING",
                        "argument": "EndpointProperties.Status",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "FAILED",
                        "argument": "EndpointProperties.Status",
                        "state": "failure",
                    },
                ],
            },
        },
    }),
    client=client_comp,
)
try:
    print("Waiting for endpoint...")
    endpoint_creation_waiter.wait(EndpointArn=endpoint_arn)
    print("ENDPOINT CREATED")
except WaiterError as e:
    print(e)
    raise e

## Predict classification

We will use the labeled test data to run some predictions and obtain some metrics.

In [None]:
# Load and examine test data
df_test = pd.read_csv("covid_fake_news/data/english_test_with_labels.csv")
df_test = df_test[["label", "tweet"]]
df_test["label"].value_counts()

In [None]:
# Split into two arrays, where X_test contains tweets and Y_test contains labels
X_test = df_test["tweet"].values
Y_test = df_test["label"].values

### Generating real-time predictions using Comprehend's endpoint
Read more about real time analysis using an endpoint [here](https://docs.aws.amazon.com/comprehend/latest/dg/cc-real-time-analysis.html).

In [None]:
test = X_test[0]
print(test)
client_comp.classify_document(Text=test, EndpointArn=endpoint_arn)

In [None]:
test1 = X_test[1]
print(test1)
client_comp.classify_document(Text=test1, EndpointArn=endpoint_arn)
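
For a multi-class classifier, the response from `classify_document` includes a `Classes` list of name and score pairs. The sketch below picks the highest-scoring class from a hand-written dictionary in that shape, so it runs without an endpoint. The scores are made up and `top_class` is a hypothetical helper.

```python
def top_class(response):
    """Return (name, score) for the highest-scoring class in a response."""
    best = max(response["Classes"], key=lambda c: c["Score"])
    return best["Name"], best["Score"]


# Hand-written stand-in for a classify_document response (scores are made up)
example_response = {
    "Classes": [
        {"Name": "real", "Score": 0.12},
        {"Name": "fake", "Score": 0.88},
    ]
}
name, score = top_class(example_response)
print(f"Predicted '{name}' with confidence {score:.2f}")
```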

### Creating an asynchronous classification job
We will create a job to perform predictions for the entire test dataset. 

Read more about the data requirements and job inputs [here](https://docs.aws.amazon.com/comprehend/latest/dg/how-class-run.html).

Comprehend's asynchronous classification job expects that for multiple documents in a single text file, each document is separated by a single line break.

We will remove all line breaks from the original dataset, then add a line feed character at the end of every document in the text file.
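
The same idea can be sketched with the standard library alone. Note one deliberate difference from the cell below: this version replaces embedded line breaks with a space rather than deleting them, so the words on either side stay separate. The helper name `write_one_doc_per_line` and the sample documents are hypothetical.

```python
import os
import tempfile


def write_one_doc_per_line(documents, path):
    """Write each document on its own line, collapsing internal whitespace."""
    with open(path, "w", encoding="utf-8", newline="\n") as f:
        for doc in documents:
            # str.split() with no arguments splits on any whitespace run,
            # including "\n" and "\r", so joining with spaces keeps words apart
            f.write(" ".join(doc.split()) + "\n")


# Made-up documents; the first contains an embedded line break
docs = ["Masks are\nrecommended indoors.", "A second  document."]
out_path = os.path.join(tempfile.mkdtemp(), "example.txt")
write_one_doc_per_line(docs, out_path)

with open(out_path, encoding="utf-8") as f:
    lines = f.read().splitlines()
print(lines)
```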

#### Convert to the format that Amazon Comprehend expects for asynchronous classification jobs

In [None]:
test_path = "data/processed-data/test.txt"
test_prefix = "COVIDcomprehend"
# Remove all line breaks
df_test["tweet"] = df_test["tweet"].apply(lambda s: s.replace("\n", ""))
# Output to text file while separating each document with a line feed character '\n'
# Note: pandas 1.5+ names this parameter `lineterminator`; older versions use `line_terminator`
df_test["tweet"].to_csv(test_path, index=False, header=False, lineterminator="\n")
# Upload to S3 Bucket
test_data_s3 = sess.upload_data(
    path=test_path,
    key_prefix=test_prefix,
    bucket=bucket,
)
print("Test data has been uploaded to S3.")
print("Location:", test_data_s3) 

#### Start classification job

In [None]:
job_name = f"{doc_classifier_name}-{version_name}-Job"
test_data_output=f"s3://{bucket}/{test_prefix}/{job_name}"
# Create classification job
create_job_desc = client_comp.start_document_classification_job(
    JobName=job_name,
    DocumentClassifierArn=doc_classifier_arn,
    InputDataConfig={
        "S3Uri": test_data_s3,
        "InputFormat": "ONE_DOC_PER_LINE",
    },
    OutputDataConfig={ "S3Uri": test_data_output },
    DataAccessRoleArn=comprehend_role_arn,
)
job_id=create_job_desc["JobId"]
print("Started classification job.")
print("Job ID is", job_id)

In [None]:
describe_job=client_comp.describe_document_classification_job(JobId=job_id)
try:
    job_initial_status=describe_job["DocumentClassificationJobProperties"]["JobStatus"]
    print("Describe classification job response:", job_initial_status)
except KeyError:
    print("Status error")
# Creating Waiter to manage waiting for job to complete
job_waiter = create_waiter_with_client(
    waiter_name="JobWaiter",
    waiter_model=WaiterModel({
        "version": 2,
        "waiters": {
            "JobWaiter": {
                "operation": "DescribeDocumentClassificationJob",
                "delay": 30,
                "maxAttempts": 300,
                "acceptors": [
                    {
                        "matcher": "path",
                        "expected": "COMPLETED",
                        "argument": "DocumentClassificationJobProperties.JobStatus",
                        "state": "success",
                    },
                    {
                        "matcher": "path",
                        "expected": "SUBMITTED",
                        "argument": "DocumentClassificationJobProperties.JobStatus",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "IN_PROGRESS",
                        "argument": "DocumentClassificationJobProperties.JobStatus",
                        "state": "retry",
                    },
                    {
                        "matcher": "path",
                        "expected": "FAILED",
                        "argument": "DocumentClassificationJobProperties.JobStatus",
                        "state": "failure",
                    },
                    {
                        "matcher": "path",
                        "expected": "STOPPED",
                        "argument": "DocumentClassificationJobProperties.JobStatus",
                        "state": "failure",
                    },
                ],
            },
        },
    }),
    client=client_comp,
)
try:
    print("Waiting for classification job...")
    job_waiter.wait(JobId=job_id)
    print(f"JOB COMPLETE: ID {job_id}")
    complete_job = job_id
except WaiterError as e:
    print(e)
    raise e

#### Examine the results of the classification job

In [None]:
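# Re-describe the job now that it has finished, then read the output location
describe_job = client_comp.describe_document_classification_job(JobId=job_id)
batch_output = describe_job["DocumentClassificationJobProperties"]["OutputDataConfig"]["S3Uri"]
print("Results from Job ID", job_id)
batch_output_bucket, _, batch_output_file = batch_output[len("s3://"):].partition("/")

# Instantiate Boto3 SDK
s3_client = boto3.client("s3")
# Download output file containing predictions (use the bucket parsed from the output URI)
s3_client.download_file(batch_output_bucket, batch_output_file, "data/output.tar.gz")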
# Unzip
!cd data && tar -xvf output.tar.gz
# Load predictions
df_batch = pd.read_json("data/predictions.jsonl", lines=True)
preds = df_batch["Classes"]

preds
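
The predictions file is in JSON Lines format, with one JSON object per input document. The sketch below parses a hand-written sample with the standard library and sorts the records by line number so they align with the input order. The `File`, `Line` and `Classes` field names are assumed from the shape used above, the values are made up, and `load_predictions` is a hypothetical helper.

```python
import json


def load_predictions(jsonl_text):
    """Parse JSON Lines output into (line_number, top_name, top_score) tuples."""
    records = []
    for raw in jsonl_text.splitlines():
        if not raw.strip():
            continue  # skip blank lines
        item = json.loads(raw)
        best = max(item["Classes"], key=lambda c: c["Score"])
        records.append((int(item["Line"]), best["Name"], best["Score"]))
    # Sort by line number so predictions align with the input file order
    return sorted(records)


# Hand-written sample in the assumed output shape (values are made up)
sample = "\n".join([
    json.dumps({"File": "test.txt", "Line": "1",
                "Classes": [{"Name": "real", "Score": 0.7},
                            {"Name": "fake", "Score": 0.3}]}),
    json.dumps({"File": "test.txt", "Line": "0",
                "Classes": [{"Name": "fake", "Score": 0.9},
                            {"Name": "real", "Score": 0.1}]}),
])
for record in load_predictions(sample):
    print(record)
```

Sorting by line number is a defensive choice: it guards against records arriving out of order before they are joined to the test labels by position.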

Next, combine the predictions with the original test data in a single dataframe so we can compare predicted and true labels:

In [None]:
df_data_pred = pd.DataFrame({
    "tweet": df_test["tweet"],
    "label": df_test["label"],
    "label_pred": preds.apply(lambda p: p[0]["Name"]),
    "score_pred": preds.apply(lambda p: p[0]["Score"]),
})
df_data_pred["label vs label pred"] = np.where(
    df_data_pred["label"] == df_data_pred["label_pred"],
    "same",
    "different",
)
df_data_pred.head(10)

In [None]:
def plot_confusion_matrix(
    confusion_matrix,
    class_names_list=["Class1", "Class2"],
    axis=None,
    title="Confusion matrix",
    plot_style="ggplot",
    colormap=plt.cm.Blues,
):

    if axis is None:  # for standalone plot
        plt.figure()
        ax = plt.gca()
    else:  # for plots inside a subplot
        ax = axis

    plt.style.use(plot_style)

    # normalizing matrix to [0,100%]
    confusion_matrix_norm = (
        confusion_matrix.astype("float") / confusion_matrix.sum(axis=1)[:, np.newaxis]
    )
    confusion_matrix_norm = np.round(100 * confusion_matrix_norm, 2)

    ax.imshow(
        confusion_matrix_norm,
        interpolation="nearest",
        cmap=colormap,
        vmin=0,  # to make sure colors are scaled between [0,100%]
        vmax=100,
    )

    ax.set_title(title)
    tick_marks = np.arange(len(class_names_list))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names_list, rotation=0)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names_list)
    
    for i, j in itertools.product(
        range(confusion_matrix.shape[0]),
        range(confusion_matrix.shape[1])
    ):
        ax.text(
            j,
            i,
            str(confusion_matrix[i, j])+'\n('+str(confusion_matrix_norm[i,j])+'%)',
            horizontalalignment="center",
            color="white" if confusion_matrix_norm[i, j] > 50 else "black"
        )

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    ax.grid(False)
    
    if axis is None:  # for standalone plots
        plt.tight_layout()
        plt.show()

In [None]:
y_real = df_data_pred["label"]
y_pred = df_data_pred["label_pred"]

plot_confusion_matrix(
    metrics.confusion_matrix(y_real, y_pred),
    class_names_list=["Fake", "Real"],
)

How well did we identify what's fake? We will calculate performance scores and plot some graphs to examine this.
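
To make the scores concrete, here is a small worked sketch that computes them by hand from true and predicted labels, treating "fake" as the positive class. The labels are made up, and `binary_metrics` is a hypothetical helper shown for illustration; the notebook itself relies on `sklearn.metrics`.

```python
def binary_metrics(y_true, y_pred, positive="fake"):
    """Compute accuracy, precision, recall and F1 for one positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    correct = sum(t == p for t, p in pairs)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"accuracy": correct / len(pairs), "precision": precision,
            "recall": recall, "f1": f1}


# Made-up labels for illustration only
y_true = ["fake", "fake", "fake", "real", "real", "real", "real", "real"]
y_pred = ["fake", "fake", "real", "fake", "real", "real", "real", "real"]
print(binary_metrics(y_true, y_pred))
```

In this toy case there are 2 true positives, 1 false positive and 1 false negative, so precision and recall are both 2/3 while accuracy is 6/8.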

In [None]:
# Treat "fake" as the positive class (pos_label="fake") for precision, recall and F1
metrics_ACC = metrics.accuracy_score(y_real, y_pred)
metrics_P_fake = metrics.precision_score(y_real, y_pred, average="binary", pos_label="fake")
metrics_R_fake = metrics.recall_score(y_real, y_pred, average="binary", pos_label="fake")
metrics_f1 = metrics.f1_score(y_real, y_pred, average="binary", pos_label="fake")
print(
    "Accuracy", metrics_ACC,
    "Precision (fake)", metrics_P_fake,
    "Recall (fake)", metrics_R_fake,
    "F1 score", metrics_f1,
)

In [None]:
def plot_precision_recall_curve(
    y_real,
    y_predict,
    axis=None,
    plot_style="ggplot",
):
    """Plot a nice precision/recall curve for a binary classification model"""

    if axis is None:  # for standalone plot
        plt.figure()
        ax = plt.gca()
    else:  # for plots inside a subplot
        ax = axis

    plt.style.use(plot_style)

    metrics_P, metrics_R, _ = metrics.precision_recall_curve(y_real, y_predict)
    metrics_AP = metrics.average_precision_score(y_real, y_predict)

    ax.set_aspect(aspect=0.95)
    ax.step(metrics_R, metrics_P, color="b", where="post", linewidth=0.7)
    ax.fill_between(metrics_R, metrics_P, step="post", alpha=0.2, color="b")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_ylim([0.0, 1.05])
    ax.set_xlim([0.0, 1.05])
    ax.set_title("Precision-Recall curve: AP={0:0.3f}".format(metrics_AP))
    
    if axis is None:  # for standalone plots
        plt.tight_layout()
        plt.show()

In [None]:
# Encode labels as integers, with "fake" as the positive class (1)
y_real = df_data_pred["label"].map({"fake": 1, "real": 0})
y_pred = df_data_pred["label_pred"].map({"fake": 1, "real": 0})

plot_precision_recall_curve(y_real, y_pred)
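
Because `y_pred` holds hard 0/1 predictions, the resulting curve has only one non-trivial operating point. A fuller curve needs a continuous score for the positive class. The sketch below derives such a score from the top-class label and score, assuming exactly two classes whose scores sum to 1. That assumption, the helper name `fake_probability` and the sample values are all illustrative.

```python
def fake_probability(label_pred, score_pred, positive="fake"):
    """Convert a top-class (label, score) pair into a score for `positive`."""
    # Assumes exactly two classes whose scores sum to 1
    return score_pred if label_pred == positive else 1.0 - score_pred


# Made-up predictions for illustration only
top_predictions = [("fake", 0.95), ("real", 0.80), ("fake", 0.55), ("real", 0.60)]
fake_scores = [fake_probability(label, score) for label, score in top_predictions]
print([round(s, 2) for s in fake_scores])
```

Applied to the `label_pred` and `score_pred` columns, scores like these could be passed to the plotting functions in place of the 0/1 predictions.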

In [None]:
def plot_roc_curve(
    y_real,
    y_predict,
    axis=None,
    plot_style="ggplot",
):
    """Plot a nice ROC curve for a binary classification model"""

    if axis is None:  # for standalone plot
        plt.figure()
        ax = plt.gca()
    else:  # for plots inside a subplot
        ax = axis

    plt.style.use(plot_style)

    metrics_FPR, metrics_TPR, _ = metrics.roc_curve(y_real, y_predict)
    metrics_AUC = metrics.roc_auc_score(y_real, y_predict)

    ax.set_aspect(aspect=0.95)
    ax.plot(metrics_FPR, metrics_TPR, color="b", linewidth=0.7)
    
    ax.fill_between(
        metrics_FPR,
        metrics_TPR,
        step="post",
        alpha=0.2,
        color="b",
    )
    
    ax.plot([0, 1], [0, 1], color="k", linestyle="--", linewidth=1)
    ax.set_xlim([-0.05, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC curve: AUC={0:0.3f}".format(metrics_AUC))
    
    if axis is None:  # for standalone plots
        plt.tight_layout()
        plt.show()

In [None]:
plot_roc_curve(y_real, y_pred)

## Cleanup

To clean up your environment after running through this example and avoid ongoing charges, you can uncomment and run the cells below:

In [None]:
# # Delete Endpoint
# print("Deleting endpoint...")
# endpoint_delete_response = client_comp.delete_endpoint(EndpointArn=endpoint_arn)
# try: 
#     while client_comp.describe_endpoint(
#         EndpointArn=endpoint_arn
#     )["EndpointProperties"]["Status"] == "DELETING":
#         sleep(60)
# except:
#     pass
# print("Endpoint Deleted")

In [None]:
# # Delete Classifier
# classifier_delete_response = client_comp.delete_document_classifier(DocumentClassifierArn=doc_classifier_arn)
# print("Classifier Deleted")

In [None]:
# # Delete S3 Files (upload_data stores each file at "<key_prefix>/<filename>")
# s3_delete_response = s3_client.delete_object(Bucket=bucket, Key=f"{prefix}/processed.csv")
# s3_delete_response_1 = s3_client.delete_object(Bucket=bucket, Key=f"{test_prefix}/test.txt")
# print(
#     "Data in S3 has been deleted. HTTP Code:",
#     s3_delete_response["ResponseMetadata"]["HTTPStatusCode"],
#     ",",
#     s3_delete_response_1["ResponseMetadata"]["HTTPStatusCode"],
# )

In [None]:
# # Delete IAM Role
# try:
#     role_detach_response = client_iam.detach_role_policy(
#         RoleName=comprehend_role_name,
#         PolicyArn=policy_arn,
#     )
#     print("Policy is detached")
# except:
#     print("Policy could not be detached")
# role_delete_response = client_iam.delete_role(RoleName=comprehend_role_name)
# print("Role is deleted", role_delete_response)