# This notebook converts the AG News dataset into a format that can be used by Comprehend for custom classification.

## Install and import libraries

In [None]:
!pip install --upgrade s3fs pandas tqdm

In [None]:
import boto3
import matplotlib
import pandas as pd
import tqdm

region_name = 'us-east-1'

## Get our data. Our data lives in the Amazon S3 open datasets. Often you can stream data directly from S3 without downloading it.
## In this case, since it's small and packaged as a tar file, let's download it and look at it.

### Any permission-related messages printed by the untar operation can be ignored.

In [None]:
! wget https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
! tar xvzf ag_news_csv.tgz
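As an alternative to shelling out to `tar`, a CSV member can be read straight out of the archive with Python's standard `tarfile` module. The sketch below is illustrative only: it builds a tiny in-memory archive with made-up rows so that it runs without network access, and the helper names `build_demo_archive` and `read_member_rows` are hypothetical. The member path `ag_news_csv/train.csv` mirrors the path read later in this notebook.

```python
import csv
import io
import tarfile


def build_demo_archive() -> bytes:
    """Build a tiny in-memory .tgz laid out like ag_news_csv.tgz (demo rows only)."""
    rows = [["3", "Demo title", "Demo text"], ["2", "Another title", "More text"]]
    text = io.StringIO()
    csv.writer(text).writerows(rows)
    payload = text.getvalue().encode("utf-8")

    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        info = tarfile.TarInfo(name="ag_news_csv/train.csv")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    return buf.getvalue()


def read_member_rows(archive_bytes: bytes, member: str) -> list:
    """Read one CSV member from a gzipped tar archive without extracting it to disk."""
    with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:gz") as tar:
        extracted = tar.extractfile(member)
        reader = csv.reader(io.TextIOWrapper(extracted, encoding="utf-8", newline=""))
        return list(reader)


demo_rows = read_member_rows(build_demo_archive(), "ag_news_csv/train.csv")
print(demo_rows)
```

With the real download, the same `read_member_rows` idea would apply to the bytes of `ag_news_csv.tgz`, assuming the archive is gzip-compressed as its extension suggests.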

#### Read the training file into pandas to see what it contains

In [None]:
train=pd.read_csv("ag_news_csv/train.csv", names=['category','title','text'])

#### This is our training dataset. It has three columns: a category label, a title, and the text.

In [None]:
train

#### To keep the training time reasonable for this exercise, we'll limit the data to a random sample of 1000 rows.

In [93]:
train = train.sample(axis='index',n=1000,random_state=100)

#### To make things more readable, let's change our labels from numbers to strings. The dataset provider lists the class names in the classes.txt file.

In [None]:
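# Map the numeric category codes to the class names listed in classes.txt.
labeldict = {'1': 'WORLD', '2': 'SPORTS', '3': 'BUSINESS', '4': 'SCI_TECH'}
# Cast every column to str so the category values match the string keys in labeldict.
trainstr = train.astype(str)
trainstr['label'] = trainstr['category'].replace(labeldict)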
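One thing to be aware of with `replace` is that any category value missing from the dictionary passes through unchanged, so a bad code would silently become a label. A minimal sketch of a sanity check, using `Series.map` on a small made-up frame (the value `'9'` is deliberately invalid and the `demo` frame is hypothetical):

```python
import pandas as pd

labeldict = {'1': 'WORLD', '2': 'SPORTS', '3': 'BUSINESS', '4': 'SCI_TECH'}

# Hypothetical rows; '9' is deliberately not a valid AG News category.
demo = pd.DataFrame({'category': ['1', '4', '9']})

# replace() passes unknown values through unchanged ...
replaced = demo['category'].replace(labeldict)
# ... while map() turns them into NaN, which is easy to detect.
mapped = demo['category'].map(labeldict)

unmapped = demo.loc[mapped.isna(), 'category'].tolist()
print(replaced.tolist())
print(unmapped)
```

If `unmapped` is empty for the real data, every row received one of the four class names.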

#### Put the title and the text in one column for our training. Normally a choice like this would come out of some experimentation on the data, but a good starting point for a text classifier is to give it all of the relevant text.

#### Now, only write out our label and text, because that's what Comprehend expects as input.

In [None]:
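# Combine the title and text into one column, then keep only the label and text.
dfout = trainstr.assign(text=trainstr['title'] + ' ' + trainstr['text'])[['label', 'text']]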

In [None]:
dfout

#### Let's look at the count of rows per label to see how our classes are distributed. They should be roughly balanced.

In [None]:
dfout['label'].value_counts()

### Copy the data to an S3 bucket


In [None]:
# Get the account ID from STS so we can all have unique bucket names
client = boto3.client("sts")
account_id = client.get_caller_identity()["Account"]
bucket_name = "comprehend-labs" + account_id + "-2"
print ("Bucket name used is " + bucket_name )
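Because the bucket name is assembled from string pieces, it can be worth checking it against the basic S3 naming rules (3 to 63 characters; lowercase letters, digits, dots, and hyphens; starting and ending with a letter or digit) before calling S3. The helper below is a hypothetical sketch that checks only those basic rules, not every S3 restriction; the default prefix and suffix are the ones used in this notebook.

```python
import re

# Basic S3 naming rules only: 3-63 chars, lowercase letters, digits, dots and
# hyphens, beginning and ending with a letter or digit.
_BUCKET_RE = re.compile(r"[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]")


def make_bucket_name(account_id: str, prefix: str = "comprehend-labs", suffix: str = "-2") -> str:
    """Build the per-account bucket name and reject obviously invalid results."""
    name = prefix + account_id + suffix
    if not _BUCKET_RE.fullmatch(name):
        raise ValueError("Invalid S3 bucket name: " + name)
    return name


# The account ID here is a placeholder, not a real account.
print(make_bucket_name("123456789012"))
```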

In [None]:
s3 = boto3.resource('s3')
s3_client = boto3.client('s3')

if s3.Bucket(bucket_name).creation_date is None:
    # location = {'LocationConstraint': region_name}
    s3_client.create_bucket(Bucket=bucket_name)  # , CreateBucketConfiguration=location)
    print("Created bucket " + bucket_name)
else:
    print("Bucket Exists")
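The commented-out `LocationConstraint` matters when the notebook runs outside `us-east-1`: S3 rejects a `LocationConstraint` of `us-east-1`, while buckets in other regions need one. A minimal sketch of one way to handle both cases, using a hypothetical helper named `create_bucket_kwargs` that only builds the arguments and makes no AWS calls:

```python
def create_bucket_kwargs(bucket_name: str, region_name: str) -> dict:
    """Build keyword arguments for create_bucket that suit the given region."""
    kwargs = {"Bucket": bucket_name}
    # us-east-1 is the default location and must not be passed as a constraint.
    if region_name != "us-east-1":
        kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region_name}
    return kwargs


# Placeholder bucket names, for illustration only.
print(create_bucket_kwargs("example-bucket", "us-east-1"))
print(create_bucket_kwargs("example-bucket", "us-west-2"))
```

In the cell above, the call would then become `s3_client.create_bucket(**create_bucket_kwargs(bucket_name, region_name))`, assuming the S3 client was created for the same region.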

In [None]:
file_name="s3://" + bucket_name + "/custom_news_classification.csv"

In [None]:
dfout.to_csv(file_name, header=False, index=False )
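To see what the uploaded file looks like, the same `to_csv` options can be applied to an in-memory buffer. The sketch below uses a small made-up frame (`demo_out` is hypothetical) and shows that the output has no header row, puts the label first, and quotes text that contains commas or quotation marks so it survives a round trip.

```python
import io

import pandas as pd

# Hypothetical rows standing in for dfout.
demo_out = pd.DataFrame({
    'label': ['SPORTS', 'BUSINESS'],
    'text': ['Team wins, again', 'Shares "rally" on news'],
})

# Same options as the S3 write: no header row and no index column.
buffer = io.StringIO()
demo_out.to_csv(buffer, header=False, index=False)
csv_text = buffer.getvalue()
print(csv_text)

# Read the text back to confirm the quoting round-trips cleanly.
round_trip = pd.read_csv(io.StringIO(csv_text), names=['label', 'text'])
```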

### Copy the S3 path printed below into Comprehend to use as the training data for a classifier!

In [None]:
print(file_name)