{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# This notebook converts the AG News dataset into a format that can be used by Comprehend for custom classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade s3fs pandas tqdm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import tqdm\n", "import boto3\n", "region_name='us-east-1'\n", "import matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get our data. Our data lives in the Amazon S3 open datasets. Many times, you can stream data right from S3 without downloading.\n", "## In this case, since its small and in a tar file, lets download and look at it.\n", "\n", "### The messages from perssions in the untar operation can be ignored." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! wget https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz\n", "! tar xvzf ag_news_csv.tgz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Read in the files in to \"Pandas to see what is happening\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train=pd.read_csv(\"ag_news_csv/train.csv\", names=['category','title','text'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### This is our training dataset. it has 3 columns, a label, title and text." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### To reduce the training time to a reasonable amount for the excercise, we'll limit the data to just 1000 rows." ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "train = train.sample(axis='index',n=1000,random_state=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### In order to make things prettier, let's change our labels from a number to a string. The dataset provider told us what the data looks like in the classes.txt file" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "labeldict={'1': 'WORLD', '2' : 'SPORTS', '3' : 'BUSINESS', '4': 'SCI_TECH'}\n", "trainstr=train.astype(str)\n", "trainstr['label']=trainstr['category'].replace(labeldict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Put the title and the text in one column for our training. Normally this might be the result of some experimentation on our data. But it is generally the best practice to start to give a text classifier \"all\" relevant data to start." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Now, only write out our label and text, because that's what Comprehend expects as input." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dfout=trainstr[[\"label\", 'text']] " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dfout" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Let's look at a quick histogram and see what our labels look like. They are very balanced." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dfout['label'].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Copy the data to an S3 bucket\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get the account ID from STS so we can all have unique bucket names\n", "client = boto3.client(\"sts\")\n", "account_id = client.get_caller_identity()[\"Account\"]\n", "bucket_name = \"comprehend-labs\" + account_id + \"-2\"\n", "print (\"Bucket name used is \" + bucket_name )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3 = boto3.resource('s3')\n", "s3_client = boto3.client('s3')\n", "\n", "if (s3.Bucket(bucket_name).creation_date is None):\n", " #location = {'LocationConstraint': region_name}\n", " s3_client.create_bucket(Bucket=bucket_name)#, CreateBucketConfiguration=location)\n", " print (\"Created bucket \" + bucket_name)\n", "else:\n", " print (\"Bucket Exists\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "file_name=\"s3://\" + bucket_name + \"/custom_news_classification.csv\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dfout.to_csv(file_name, header=False, index=False )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Copy the below to Comprehend to use for a classifier!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(file_name)" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }