# Textract Visualization with Neptune

[Amazon Textract](https://aws.amazon.com/textract/) is a fully managed machine learning service that automatically extracts text and data from scanned documents. It goes beyond simple optical character recognition (OCR) to identify, understand, and extract data from forms and tables.

The raw output from Textract is a series of JSON blocks representing pages, lines, words, tables, and table cells. When you are first exploring a PDF document, it's useful to visualize the relationships between these blocks to help you interpret the output.
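To make the block structure concrete, here is a minimal sketch that tallies blocks by type. The `sample_blocks` list is a hand-written, heavily trimmed stand-in for real Textract output (real blocks carry more fields, such as geometry), and `count_block_types` is a hypothetical helper, not part of any AWS SDK.

```python
from collections import Counter

# Hand-written, trimmed stand-in for Textract output (illustrative only).
sample_blocks = [
    {"Id": "p1", "BlockType": "PAGE",
     "Relationships": [{"Type": "CHILD", "Ids": ["l1"]}]},
    {"Id": "l1", "BlockType": "LINE", "Text": "Playing Atari",
     "Relationships": [{"Type": "CHILD", "Ids": ["w1", "w2"]}]},
    {"Id": "w1", "BlockType": "WORD", "Text": "Playing"},
    {"Id": "w2", "BlockType": "WORD", "Text": "Atari"},
]


def count_block_types(blocks):
    """Return a Counter mapping each BlockType to the number of blocks of that type."""
    return Counter(block["BlockType"] for block in blocks)


block_counts = count_block_types(sample_blocks)
print(dict(block_counts))
```

Once the jobs later in this notebook have run, the same helper could be applied to `detect_blocks` or `analysis_blocks` to get a quick overview of what Textract found.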

In this notebook, we show how to take the raw JSON output from a sample PDF file, insert it into [Amazon Neptune](https://aws.amazon.com/neptune/), a managed graph database, and then use Neptune to visualize part of the data.

Some parts of the Python code are taken from the [Textract samples](https://github.com/aws-samples/amazon-textract-code-samples) and the [Textract documentation](https://docs.aws.amazon.com/textract/latest/dg/async-analyzing-with-sqs.html).

## License

Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

SPDX-License-Identifier: MIT-0

## Prerequisites

* Install the `wget` package into the Python kernel
* The notebook role needs permission to create SNS topics and SQS queues
* The notebook role needs permission to use Textract
* The notebook role needs permission to upload to the S3 bucket
* Create an S3 bucket to hold the PDF file
* Create a [Textract service role](https://docs.aws.amazon.com/textract/latest/dg/api-async-roles.html)
* Create a Neptune database (see the [quick start](https://docs.aws.amazon.com/neptune/latest/userguide/intro.html) guide)
* Run this notebook in the [Neptune Workbench](https://docs.aws.amazon.com/neptune/latest/userguide/notebooks.html)

In [None]:
# Set these values to reflect your S3 bucket, Textract service role ARN, and Neptune database endpoint

bucket = ''
in_prefix = 'in'
role_arn = ''
neptune_endpoint = ''

## Download PDF

In this example we'll process a PDF file published in 2013 on Reinforcement Learning. The official source citation is:

    @misc{mnih2013playing,
        title={Playing Atari with Deep Reinforcement Learning},
        author={Volodymyr Mnih and Koray Kavukcuoglu and David Silver and Alex Graves and Ioannis Antonoglou and Daan Wierstra and Martin Riedmiller},
        year={2013},
        eprint={1312.5602},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
    }
 
I downloaded the PDF file from:

 https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
 
You can also find the file at:

 https://arxiv.org/abs/1312.5602

In [None]:
%load_ext autoreload
%autoreload

In [None]:
!pip install wget

In [None]:
pdf_url = 'https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf'

In [None]:
import wget
filename = wget.download(pdf_url)

In [None]:
filename

## Create SQS queue and SNS topic for job notifications

When processing PDFs, Textract runs asynchronously. It sends job status notifications to SNS, and we'll forward those to SQS so we can poll SQS for job status.
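Because the notification passes through SNS before it reaches SQS, the SQS message body is an SNS envelope whose `Message` field is itself a JSON string. The sketch below shows the double decoding that the polling cells later in this notebook perform. The sample message is hand-built, the job ID and receipt handle are placeholders, and `parse_textract_notification` is a hypothetical helper name.

```python
import json


def parse_textract_notification(sqs_message):
    """Unwrap the SNS envelope in an SQS message and return the inner Textract payload."""
    envelope = json.loads(sqs_message["Body"])
    return json.loads(envelope["Message"])


# Hand-built sample; real messages carry more fields than shown here.
inner = {"JobId": "example-job-id", "Status": "SUCCEEDED"}
sample_message = {
    "ReceiptHandle": "example-receipt-handle",
    "Body": json.dumps({"Type": "Notification", "Message": json.dumps(inner)}),
}

payload = parse_textract_notification(sample_message)
print(payload["JobId"], payload["Status"])
```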

In [None]:
import boto3
textract = boto3.client('textract')
sqs = boto3.client('sqs')
sns = boto3.client('sns')

In [None]:
import time
millis = str(int(round(time.time() * 1000)))
snsTopicName="AmazonTextractTopic" + millis
topicResponse=sns.create_topic(Name=snsTopicName)
snsTopicArn = topicResponse['TopicArn']

In [None]:
sqsQueueName="AmazonTextractQueue" + millis
sqs.create_queue(QueueName=sqsQueueName)
sqsQueueUrl = sqs.get_queue_url(QueueName=sqsQueueName)['QueueUrl']

In [None]:
attribs = sqs.get_queue_attributes(QueueUrl=sqsQueueUrl,
 AttributeNames=['QueueArn'])['Attributes']
 
sqsQueueArn = attribs['QueueArn']
sns.subscribe(
 TopicArn=snsTopicArn,
 Protocol='sqs',
 Endpoint=sqsQueueArn)

In [None]:
policy = """{{
 "Version":"2012-10-17",
 "Statement":[
 {{
 "Sid":"MyPolicy",
 "Effect":"Allow",
 "Principal" : {{"AWS" : "*"}},
 "Action":"SQS:SendMessage",
 "Resource": "{}",
 "Condition":{{
 "ArnEquals":{{
 "aws:SourceArn": "{}"
 }}
 }}
 }}
 ]
}}""".format(sqsQueueArn, snsTopicArn)
 
response = sqs.set_queue_attributes(
 QueueUrl = sqsQueueUrl,
 Attributes = {
 'Policy' : policy
 })
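As an alternative to the format string with doubled braces, the same policy can be built as a Python dictionary and serialized with `json.dumps`, which avoids brace escaping. This is only a sketch: `build_queue_policy` is a hypothetical helper, and the ARNs in the example are placeholders rather than real resources.

```python
import json


def build_queue_policy(queue_arn, topic_arn):
    """Return a JSON policy string letting the given SNS topic send to the given SQS queue."""
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "MyPolicy",
                "Effect": "Allow",
                "Principal": {"AWS": "*"},
                "Action": "SQS:SendMessage",
                "Resource": queue_arn,
                # Only messages originating from this topic are allowed.
                "Condition": {"ArnEquals": {"aws:SourceArn": topic_arn}},
            }
        ],
    }
    return json.dumps(policy)


# Placeholder ARNs for illustration only.
example_policy = build_queue_policy(
    "arn:aws:sqs:us-east-1:123456789012:ExampleQueue",
    "arn:aws:sns:us-east-1:123456789012:ExampleTopic",
)
print(example_policy)
```

In this notebook the call would be `build_queue_policy(sqsQueueArn, snsTopicArn)`, with the result passed as the `Policy` attribute.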

## Upload PDF to bucket

In [None]:
s3 = boto3.client('s3')
s3.upload_file(filename, bucket, "{0}/{1}".format(in_prefix, filename))

## Run detection job

In [None]:
def GetResults(processType, jobId):
    # Fetch every block for a finished Textract job, following pagination tokens
    maxResults = 1000
    paginationToken = None
    finished = False

    blocks_to_save = []
    while not finished:

        response = None

        if processType == 'analysis':
            if paginationToken is None:
                response = textract.get_document_analysis(JobId=jobId,
                                                          MaxResults=maxResults)
            else:
                response = textract.get_document_analysis(JobId=jobId,
                                                          MaxResults=maxResults,
                                                          NextToken=paginationToken)

        if processType == 'detect':
            if paginationToken is None:
                response = textract.get_document_text_detection(JobId=jobId,
                                                                MaxResults=maxResults)
            else:
                response = textract.get_document_text_detection(JobId=jobId,
                                                                MaxResults=maxResults,
                                                                NextToken=paginationToken)

        blocks = response['Blocks']

        # Collect the blocks from this page of results
        for block in blocks:
            blocks_to_save.append(block)

        if 'NextToken' in response:
            paginationToken = response['NextToken']
        else:
            finished = True

    return blocks_to_save
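The pagination pattern in `GetResults` can also be written once and shared by both job types by passing in the fetch function. The following is a sketch under that assumption: `collect_blocks` is a hypothetical helper, and `fake_fetch_page` is a stand-in that returns two hand-written pages so the example runs without AWS access.

```python
def collect_blocks(fetch_page, job_id, max_results=1000):
    """Call fetch_page repeatedly, following NextToken, and return all blocks."""
    blocks = []
    kwargs = {"JobId": job_id, "MaxResults": max_results}
    while True:
        response = fetch_page(**kwargs)
        blocks.extend(response["Blocks"])
        token = response.get("NextToken")
        if not token:
            return blocks
        kwargs["NextToken"] = token


def fake_fetch_page(JobId, MaxResults, NextToken=None):
    """Stand-in for a Textract get_* call: serves two hand-written pages."""
    pages = {
        None: {"Blocks": [{"Id": "a"}, {"Id": "b"}], "NextToken": "page-2"},
        "page-2": {"Blocks": [{"Id": "c"}]},
    }
    return pages[NextToken]


all_blocks = collect_blocks(fake_fetch_page, "example-job-id")
print([block["Id"] for block in all_blocks])
```

In the notebook, `fetch_page` would be `textract.get_document_text_detection` or `textract.get_document_analysis`, the two calls that `GetResults` already uses.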

In [None]:
import json
jobFound = False
docname = "{0}/{1}".format(in_prefix, filename)
response = textract.start_document_text_detection(
    DocumentLocation={'S3Object': {'Bucket': bucket, 'Name': docname}},
    NotificationChannel={'RoleArn': role_arn, 'SNSTopicArn': snsTopicArn})

print('Start Job Id: ' + response['JobId'])
detect_blocks = []
while not jobFound:
    sqsResponse = sqs.receive_message(QueueUrl=sqsQueueUrl, MessageAttributeNames=['ALL'],
                                      MaxNumberOfMessages=10)

    if sqsResponse:

        if 'Messages' not in sqsResponse:
            print("Waiting...")
            time.sleep(10)
            continue

        for message in sqsResponse['Messages']:
            notification = json.loads(message['Body'])
            textMessage = json.loads(notification['Message'])
            print(textMessage['JobId'])
            print(textMessage['Status'])
            if str(textMessage['JobId']) == response['JobId']:
                print('Matching Job Found:' + textMessage['JobId'])
                jobFound = True
                detect_blocks = GetResults('detect', textMessage['JobId'])
                sqs.delete_message(QueueUrl=sqsQueueUrl,
                                   ReceiptHandle=message['ReceiptHandle'])
            else:
                print("Job didn't match:" +
                      str(textMessage['JobId']) + ' : ' + str(response['JobId']))
                sqs.delete_message(QueueUrl=sqsQueueUrl,
                                   ReceiptHandle=message['ReceiptHandle'])

print('Done!')

In [None]:
print(f"Found {len(detect_blocks)} blocks")

## Run analysis job

In [None]:
jobFound = False
response = textract.start_document_analysis(
    DocumentLocation={'S3Object': {'Bucket': bucket, 'Name': docname}},
    FeatureTypes=["TABLES", "FORMS"],
    NotificationChannel={'RoleArn': role_arn, 'SNSTopicArn': snsTopicArn})

print('Start Job Id: ' + response['JobId'])
analysis_blocks = []
while not jobFound:
    sqsResponse = sqs.receive_message(QueueUrl=sqsQueueUrl, MessageAttributeNames=['ALL'],
                                      MaxNumberOfMessages=10)

    if sqsResponse:

        if 'Messages' not in sqsResponse:
            print("Waiting...")
            time.sleep(10)
            continue

        for message in sqsResponse['Messages']:
            notification = json.loads(message['Body'])
            textMessage = json.loads(notification['Message'])
            print(textMessage['JobId'])
            print(textMessage['Status'])
            if str(textMessage['JobId']) == response['JobId']:
                print('Matching Job Found:' + textMessage['JobId'])
                jobFound = True
                analysis_blocks = GetResults('analysis', textMessage['JobId'])
                sqs.delete_message(QueueUrl=sqsQueueUrl,
                                   ReceiptHandle=message['ReceiptHandle'])
            else:
                print("Job didn't match:" +
                      str(textMessage['JobId']) + ' : ' + str(response['JobId']))
                sqs.delete_message(QueueUrl=sqsQueueUrl,
                                   ReceiptHandle=message['ReceiptHandle'])

print('Done!')
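The detection and analysis cells share the same polling loop, so it could be factored into one function. The sketch below is one way to do that, not the notebook's own code: `wait_for_job`, `FakeSqs`, and `make_message` are hypothetical names, the fake client only imitates the two SQS calls used above, and the `sleep` parameter exists so the example runs without real delays.

```python
import json
import time


def wait_for_job(sqs_client, queue_url, job_id, poll_seconds=10, sleep=time.sleep):
    """Poll the queue until a notification for job_id arrives, then return its Status."""
    while True:
        response = sqs_client.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=10)
        messages = response.get("Messages", [])
        if not messages:
            sleep(poll_seconds)
            continue
        status = None
        for message in messages:
            payload = json.loads(json.loads(message["Body"])["Message"])
            # Delete every message that was read, matching or not, as the cells above do.
            sqs_client.delete_message(QueueUrl=queue_url,
                                      ReceiptHandle=message["ReceiptHandle"])
            if payload["JobId"] == job_id:
                status = payload["Status"]
        if status is not None:
            return status


class FakeSqs:
    """Minimal stand-in for the boto3 SQS client, serving canned responses."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.deleted = []

    def receive_message(self, **kwargs):
        return self.batches.pop(0) if self.batches else {}

    def delete_message(self, QueueUrl, ReceiptHandle):
        self.deleted.append(ReceiptHandle)


def make_message(job_id, status, handle):
    """Build a sample SQS message wrapping an SNS envelope."""
    body = json.dumps({"Message": json.dumps({"JobId": job_id, "Status": status})})
    return {"Body": body, "ReceiptHandle": handle}


fake = FakeSqs([
    {},  # first poll: queue is empty
    {"Messages": [make_message("other-job", "SUCCEEDED", "h1"),
                  make_message("my-job", "SUCCEEDED", "h2")]},
])
final_status = wait_for_job(fake, "example-queue-url", "my-job", sleep=lambda seconds: None)
print(final_status, fake.deleted)
```

Returning the status lets the caller decide whether to fetch results, for example by calling `GetResults` only when the notification reports success.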

In [None]:
print(f"Found {len(analysis_blocks)} blocks")

## Save Textract output to disk

If you want to reprocess the output later, you can reload the data from the pickle files instead of rerunning the Textract jobs.

In [None]:
import pickle

# Use context managers so the files are flushed and closed
with open("blocks-detect.pkl", "wb") as F:
    pickle.dump(detect_blocks, F)
with open("blocks-analysis.pkl", "wb") as F:
    pickle.dump(analysis_blocks, F)

In [None]:
import pickle
with open( "blocks-detect.pkl", "rb" ) as F:
 detect_blocks = pickle.load(F)
with open( "blocks-analysis.pkl", "rb" ) as F:
 analysis_blocks = pickle.load(F)
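If you would rather keep the saved output human-readable, JSON is a possible alternative to pickle, since the blocks are plain dictionaries, lists, strings, and numbers. The sketch below assumes that holds for your data. `save_blocks` and `load_blocks` are hypothetical helpers, and the sample block is hand-written.

```python
import json
import os
import tempfile


def save_blocks(blocks, path):
    """Write a list of block dictionaries to a JSON file."""
    with open(path, "w") as f:
        json.dump(blocks, f)


def load_blocks(path):
    """Read a list of block dictionaries back from a JSON file."""
    with open(path) as f:
        return json.load(f)


# Hand-written sample block (illustrative only).
sample = [{
    "Id": "w1",
    "BlockType": "WORD",
    "Text": "Atari",
    "Geometry": {"BoundingBox": {"Top": 0.1, "Left": 0.2}},
}]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "blocks-detect.json")
    save_blocks(sample, path)
    restored = load_blocks(path)

print(restored == sample)
```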

## Load into Neptune

You only need to publish the data into Neptune once.

In [None]:
!pip install gremlinpython

In [None]:
from gremlin_python import statics
from gremlin_python.structure.graph import Graph
from gremlin_python.process.graph_traversal import __
from gremlin_python.process.strategies import *
from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection

graph = Graph()

remoteConn = DriverRemoteConnection('wss://' + neptune_endpoint + ':8182/gremlin','g')
g = graph.traversal().withRemote(remoteConn)

In [None]:
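# Caution: this deletes every vertex (and its edges) in the database.
# Run it only against a Neptune cluster you can safely empty.
g.V().drop().iterate()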

In [None]:
from gremlin_python.process.traversal import T, P, Operator
map_block_id = {}
cnt = 0
for block in (detect_blocks + analysis_blocks):

    btype = block['BlockType']
    bid = block['Id']
    uuid = str(cnt)
    g.addV(btype).property(T.id, uuid).property('block_id', bid).iterate()

    for attr in ['Text', 'RowIndex', 'ColumnIndex', 'Page']:
        if attr in block:
            g.V(uuid).property(attr, block[attr]).iterate()
    if 'Text' in block:
        tableprops = block['Text']
    elif 'RowIndex' in block and 'ColumnIndex' in block:
        tableprops = "{0},{1}".format(str(block['RowIndex']), str(block['ColumnIndex']))
    else:
        tableprops = ''
    g.V(uuid).property('tableprops', tableprops).iterate()
    bbox = block['Geometry']['BoundingBox']
    g.V(uuid).property('top', bbox['Top']).iterate()
    g.V(uuid).property('left', bbox['Left']).iterate()
    map_block_id[bid] = uuid
    cnt = cnt + 1
    if cnt % 100 == 0:
        print(f"Cnt = {cnt}")
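The property logic in the loading cell can be separated from the database calls, which makes it easy to check without a Neptune connection. The sketch below mirrors that logic in a pure function. `block_properties` is a hypothetical helper, and the sample cell block is hand-written.

```python
def block_properties(block):
    """Return the vertex properties the loading cell would set for one block."""
    props = {"block_id": block["Id"]}
    for attr in ("Text", "RowIndex", "ColumnIndex", "Page"):
        if attr in block:
            props[attr] = block[attr]
    # 'tableprops' is a display label: the text if present, else "row,column".
    if "Text" in block:
        props["tableprops"] = block["Text"]
    elif "RowIndex" in block and "ColumnIndex" in block:
        props["tableprops"] = "{0},{1}".format(block["RowIndex"], block["ColumnIndex"])
    else:
        props["tableprops"] = ""
    bbox = block["Geometry"]["BoundingBox"]
    props["top"] = bbox["Top"]
    props["left"] = bbox["Left"]
    return props


# Hand-written sample cell block (illustrative only).
sample_cell = {
    "Id": "c1",
    "BlockType": "CELL",
    "RowIndex": 2,
    "ColumnIndex": 3,
    "Page": 1,
    "Geometry": {"BoundingBox": {"Top": 0.5, "Left": 0.25}},
}

cell_props = block_properties(sample_cell)
print(cell_props)
```

With such a dictionary, the per-property round trips could be folded into one traversal by calling `.property(key, value)` repeatedly on the `g.addV(...)` traversal before a single `.iterate()`.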

In [None]:
g.V('0').toList()

In [None]:
def get_v(v_id):
    l = g.V(v_id).toList()
    return l[-1]

for block in (detect_blocks + analysis_blocks):
    bid = block['Id']
    v1 = get_v(map_block_id[bid])
    if 'Relationships' in block:
        for r in block['Relationships']:
            rtype = r['Type']
            rlist = r['Ids']
            for rid in rlist:
                v2 = get_v(map_block_id[rid])
                g.V(v1).addE(rtype).to(v2).next()
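Each entry in a block's `Relationships` list becomes one edge per target ID, labeled with the relationship type. The sketch below lists those edges as plain tuples so you can inspect or count them before writing to Neptune. `iter_edges` is a hypothetical helper, and the sample blocks are hand-written.

```python
def iter_edges(blocks):
    """Yield (source_id, relationship_type, target_id) for every relationship."""
    for block in blocks:
        for relationship in block.get("Relationships", []):
            for target_id in relationship["Ids"]:
                yield (block["Id"], relationship["Type"], target_id)


# Hand-written sample: one table with two cells, one of which holds a word.
sample_blocks = [
    {"Id": "t1", "BlockType": "TABLE",
     "Relationships": [{"Type": "CHILD", "Ids": ["c1", "c2"]}]},
    {"Id": "c1", "BlockType": "CELL",
     "Relationships": [{"Type": "CHILD", "Ids": ["w1"]}]},
    {"Id": "c2", "BlockType": "CELL"},
    {"Id": "w1", "BlockType": "WORD", "Text": "Atari"},
]

edges = list(iter_edges(sample_blocks))
print(edges)
```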

In [None]:
remoteConn.close()

## Visualize

Here are a few example queries to look at tables, which have a hierarchical relationship (table to cell to words).

The first one looks at all tables and their cells.

In [None]:
%%gremlin -p v,oute,inv

g.V().hasLabel('TABLE').outE().inV().path()

![All Tables](images/all_tables.png)

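This next example drills into a single table. The vertex ID `'12427'` comes from the counter assigned when the data was loaded, so substitute the ID of a `TABLE` vertex from your own graph.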

In [None]:
%%gremlin -p v,oute,inv

g.V().hasLabel('TABLE').has(id, '12427').outE().inV().path()

![One Table](images/one_table.png)

Next we'll look at a single table and go out to the words included in the cells. You can use Gremlin syntax to further refine the query to focus on specific rows, for example.

In [None]:
%%gremlin -p v,oute,inv,oute,inv

g.V().hasLabel('TABLE').has(id, '12427').outE().inV().outE().inV().path()

![Table with words](images/one_table_with_words.png)
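The same table-to-cell-to-word hierarchy can also be walked directly in Python, which is handy for checking what a row-focused query should return. The sketch below rebuilds a table's rows from the block list. `table_rows` is a hypothetical helper, the sample blocks are hand-written, and it assumes cells link to their words through `CHILD` relationships.

```python
from collections import defaultdict


def table_rows(blocks, table_id):
    """Return the table as a list of rows, each a list of cell texts in column order."""
    by_id = {block["Id"]: block for block in blocks}

    def children(block):
        ids = []
        for relationship in block.get("Relationships", []):
            if relationship["Type"] == "CHILD":
                ids.extend(relationship["Ids"])
        return [by_id[i] for i in ids if i in by_id]

    rows = defaultdict(dict)
    for cell in children(by_id[table_id]):
        if cell["BlockType"] != "CELL":
            continue
        words = [w["Text"] for w in children(cell) if w["BlockType"] == "WORD"]
        rows[cell["RowIndex"]][cell["ColumnIndex"]] = " ".join(words)
    return [[cols[c] for c in sorted(cols)] for _, cols in sorted(rows.items())]


def cell(cell_id, row, col, word_ids):
    """Build a sample CELL block (test data helper)."""
    return {"Id": cell_id, "BlockType": "CELL", "RowIndex": row, "ColumnIndex": col,
            "Relationships": [{"Type": "CHILD", "Ids": word_ids}]}


def word(word_id, text):
    """Build a sample WORD block (test data helper)."""
    return {"Id": word_id, "BlockType": "WORD", "Text": text}


# Hand-written 2x2 table (illustrative only).
sample_blocks = [
    {"Id": "t1", "BlockType": "TABLE",
     "Relationships": [{"Type": "CHILD", "Ids": ["c11", "c12", "c21", "c22"]}]},
    cell("c11", 1, 1, ["w1"]), cell("c12", 1, 2, ["w2"]),
    cell("c21", 2, 1, ["w3", "w4"]), cell("c22", 2, 2, ["w5"]),
    word("w1", "Name"), word("w2", "Value"),
    word("w3", "alpha"), word("w4", "beta"), word("w5", "42"),
]

rows = table_rows(sample_blocks, "t1")
print(rows)
```

Note that `table_id` here is the Textract block `Id`, not the counter-based vertex ID used in the Gremlin queries.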