{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Textract Visualization with Neptune\n", "\n", "[Amazon Textract](https://aws.amazon.com/textract/) is a fully managed machine learning service that automatically extracts text and data from scanned documents that goes beyond simple optical character recognition (OCR) to identify, understand, and extract data from forms and tables.\n", "\n", "The raw output from Textract is a series of JSON blocks representing pages, lines, words, tables, and cells in tables. When you are first exploring a PDF document, it's useful to visualize the relationship between these blocks to help you interpret the output.\n", "\n", "In this notebook, we show how to take the raw JSON output from a sample PDF file, insert it into [Amazon Neptune](https://aws.amazon.com/neptune/), a managed graph database, and then use Neptune to visualize part of the data.\n", "\n", "Some parts of the Python code are taken from the [Textract samples](https://github.com/aws-samples/amazon-textract-code-samples) and the [Textract documentation](https://docs.aws.amazon.com/textract/latest/dg/async-analyzing-with-sqs.html).\n", "\n", "## License\n", "\n", "Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n", "\n", "SPDX-License-Identifier: MIT-0\n", "\n", "## Prerequisites\n", "\n", "* Install the `wget` package into the python kernel\n", "* Notebook role needs permission to create SNS topics and SQS queues\n", "* Notebook role needs permission to use Textract\n", "* Notebook role needs permission to upload to S3 bucket\n", "* Create S3 bucket to hold the PDF file\n", "* Create [Textract service role](https://docs.aws.amazon.com/textract/latest/dg/api-async-roles.html)\n", "* Create a Neptune database (see the [quick start](https://docs.aws.amazon.com/neptune/latest/userguide/intro.html) guide)\n", "* Run this notebook in the [Neptune Workbench](https://docs.aws.amazon.com/neptune/latest/userguide/notebooks.html)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set these values to reflect your S3 bucket, Textract service role ARN, and Neptune database endpoint\n", "\n", "bucket = ''\n", "in_prefix = 'in'\n", "role_arn = ''\n", "neptune_endpoint = ''" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download PDF\n", "\n", "In this example we'll process a PDF file published in 2013 on Reinforcement Learning. The official source citation is:\n", "\n", " @misc{mnih2013playing,\n", " title={Playing Atari with Deep Reinforcement Learning}, \n", " author={Volodymyr Mnih and Koray Kavukcuoglu and David Silver and Alex Graves and Ioannis Antonoglou and Daan Wierstra and Martin Riedmiller},\n", " year={2013},\n", " eprint={1312.5602},\n", " archivePrefix={arXiv},\n", " primaryClass={cs.LG}\n", " }\n", " \n", "I downloaded the PDF file from:\n", "\n", " https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf\n", " \n", "You can also find the file at:\n", "\n", " https://arxiv.org/abs/1312.5602" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install wget" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pdf_url = 'https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wget\n", "filename = wget.download(pdf_url)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "filename" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create SQS queue and SNS topic for job notifications\n", "\n", "When processing PDFs, Textract runs asynchronously. It sends job status notifications to SNS, and we'll forward those to SQS so we can poll SQS for job status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "textract = boto3.client('textract')\n", "sqs = boto3.client('sqs')\n", "sns = boto3.client('sns')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "millis = str(int(round(time.time() * 1000)))\n", "snsTopicName=\"AmazonTextractTopic\" + millis\n", "topicResponse=sns.create_topic(Name=snsTopicName)\n", "snsTopicArn = topicResponse['TopicArn']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sqsQueueName=\"AmazonTextractQueue\" + millis\n", "sqs.create_queue(QueueName=sqsQueueName)\n", "sqsQueueUrl = sqs.get_queue_url(QueueName=sqsQueueName)['QueueUrl']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "attribs = sqs.get_queue_attributes(QueueUrl=sqsQueueUrl,\n", " AttributeNames=['QueueArn'])['Attributes']\n", " \n", "sqsQueueArn = attribs['QueueArn']\n", "sns.subscribe(\n", " TopicArn=snsTopicArn,\n", " Protocol='sqs',\n", " Endpoint=sqsQueueArn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "policy = \"\"\"{{\n", " \"Version\":\"2012-10-17\",\n", " \"Statement\":[\n", " {{\n", " \"Sid\":\"MyPolicy\",\n", " \"Effect\":\"Allow\",\n", " \"Principal\" : {{\"AWS\" : \"*\"}},\n", " \"Action\":\"SQS:SendMessage\",\n", " \"Resource\": \"{}\",\n", " \"Condition\":{{\n", " \"ArnEquals\":{{\n", " \"aws:SourceArn\": \"{}\"\n", " }}\n", " }}\n", " }}\n", " ]\n", "}}\"\"\".format(sqsQueueArn, snsTopicArn)\n", " \n", "response = sqs.set_queue_attributes(\n", " QueueUrl = sqsQueueUrl,\n", " Attributes = {\n", " 'Policy' : policy\n", " })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Upload PDF to bucket" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3 = boto3.client('s3')\n", "s3.upload_file(filename, bucket, \"{0}/{1}\".format(in_prefix, filename))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run detection job" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def GetResults(processType, jobId):\n", " maxResults = 1000\n", " paginationToken = None\n", " finished = False\n", "\n", " blocks_to_save = []\n", " while finished == False:\n", "\n", " response=None\n", "\n", " if processType=='analysis':\n", " if paginationToken==None:\n", " response = textract.get_document_analysis(JobId=jobId,\n", " MaxResults=maxResults)\n", " else: \n", " response = textract.get_document_analysis(JobId=jobId,\n", " MaxResults=maxResults,\n", " NextToken=paginationToken) \n", "\n", " if processType=='detect':\n", " if paginationToken==None:\n", " response = textract.get_document_text_detection(JobId=jobId,\n", " MaxResults=maxResults)\n", " else: \n", " response = textract.get_document_text_detection(JobId=jobId,\n", " MaxResults=maxResults,\n", " NextToken=paginationToken) \n", "\n", " blocks=response['Blocks'] \n", " \n", " # Display block information\n", " for block in blocks:\n", " blocks_to_save.append(block)\n", "\n", " if 'NextToken' in response:\n", " paginationToken = response['NextToken']\n", " else:\n", " finished = True\n", "\n", " return blocks_to_save" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "jobFound = False\n", "docname = \"{0}/{1}\".format(in_prefix, filename)\n", "response = textract.start_document_text_detection(DocumentLocation={'S3Object': {'Bucket': bucket, \n", " 'Name': docname}},\n", " NotificationChannel={'RoleArn': role_arn, 'SNSTopicArn': snsTopicArn})\n", " \n", "\n", "print('Start Job Id: ' + response['JobId'])\n", "detect_blocks = []\n", "while jobFound == False:\n", " sqsResponse = sqs.receive_message(QueueUrl=sqsQueueUrl, MessageAttributeNames=['ALL'],\n", " MaxNumberOfMessages=10)\n", "\n", " if sqsResponse:\n", "\n", " if 'Messages' not in sqsResponse:\n", " print(\"Waiting...\")\n", " time.sleep(10)\n", " continue\n", "\n", " for message in sqsResponse['Messages']:\n", " notification = json.loads(message['Body'])\n", " textMessage = json.loads(notification['Message'])\n", " print(textMessage['JobId'])\n", " print(textMessage['Status'])\n", " if str(textMessage['JobId']) == response['JobId']:\n", " print('Matching Job Found:' + textMessage['JobId'])\n", " jobFound = True\n", " detect_blocks = GetResults('detect', textMessage['JobId'])\n", " sqs.delete_message(QueueUrl=sqsQueueUrl,\n", " ReceiptHandle=message['ReceiptHandle'])\n", " else:\n", " print(\"Job didn't match:\" +\n", " str(textMessage['JobId']) + ' : ' + str(response['JobId']))\n", " sqs.delete_message(QueueUrl=sqsQueueUrl,\n", " ReceiptHandle=message['ReceiptHandle'])\n", "\n", " print('Done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Found {len(detect_blocks)} blocks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run analysis job" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jobFound = False\n", "response = textract.start_document_analysis(DocumentLocation={'S3Object': {'Bucket': bucket, \n", " 'Name': docname}},\n", " FeatureTypes=[\"TABLES\", \"FORMS\"],\n", " NotificationChannel={'RoleArn': role_arn, 'SNSTopicArn': snsTopicArn})\n", " \n", "\n", "print('Start Job Id: ' + response['JobId'])\n", "analysis_blocks = []\n", "while jobFound == False:\n", " sqsResponse = sqs.receive_message(QueueUrl=sqsQueueUrl, MessageAttributeNames=['ALL'],\n", " MaxNumberOfMessages=10)\n", "\n", " if sqsResponse:\n", "\n", " if 'Messages' not in sqsResponse:\n", " print(\"Waiting...\")\n", " time.sleep(10)\n", " continue\n", "\n", " for message in sqsResponse['Messages']:\n", " notification = json.loads(message['Body'])\n", " textMessage = json.loads(notification['Message'])\n", " print(textMessage['JobId'])\n", " print(textMessage['Status'])\n", " if str(textMessage['JobId']) == response['JobId']:\n", " print('Matching Job Found:' + textMessage['JobId'])\n", " jobFound = True\n", " analysis_blocks = GetResults('analysis', textMessage['JobId'])\n", " sqs.delete_message(QueueUrl=sqsQueueUrl,\n", " ReceiptHandle=message['ReceiptHandle'])\n", " else:\n", " print(\"Job didn't match:\" +\n", " str(textMessage['JobId']) + ' : ' + str(response['JobId']))\n", " sqs.delete_message(QueueUrl=sqsQueueUrl,\n", " ReceiptHandle=message['ReceiptHandle'])\n", "\n", " print('Done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Found {len(analysis_blocks)} blocks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save Textract output to Disk\n", "\n", "If you want to reprocess the output later on, you can just reload the data from the Pickle files." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "pickle.dump( detect_blocks, open( \"blocks-detect.pkl\", \"wb\" ) )\n", "pickle.dump( analysis_blocks, open( \"blocks-analysis.pkl\", \"wb\" ) )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "with open( \"blocks-detect.pkl\", \"rb\" ) as F:\n", " detect_blocks = pickle.load(F)\n", "with open( \"blocks-analysis.pkl\", \"rb\" ) as F:\n", " analysis_blocks = pickle.load(F)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load into Neptune\n", "\n", "You only need to publish the data into Neptune once." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install gremlinpython" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gremlin_python import statics\n", "from gremlin_python.structure.graph import Graph\n", "from gremlin_python.process.graph_traversal import __\n", "from gremlin_python.process.strategies import *\n", "from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection\n", "\n", "graph = Graph()\n", "\n", "remoteConn = DriverRemoteConnection('wss://' + neptune_endpoint + ':8182/gremlin','g')\n", "g = graph.traversal().withRemote(remoteConn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g.V().drop().iterate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gremlin_python.process.traversal import T, P, Operator\n", "map_block_id = {}\n", "cnt = 0\n", "for block in (detect_blocks + analysis_blocks):\n", " \n", " btype = block['BlockType']\n", " bid = block['Id']\n", " uuid = str(cnt)\n", " g.addV(btype).property(T.id,uuid).property('block_id', bid).iterate()\n", " \n", " for attr in ['Text', 'RowIndex', 'ColumnIndex', 'Page']:\n", " if attr in block:\n", " g.V(uuid).property(attr, block[attr]).iterate()\n", " if 'Text' in block:\n", " tableprops = block['Text']\n", " elif 'RowIndex' in block and 'ColumnIndex' in block:\n", " tableprops = \"{0},{1}\".format(str(block['RowIndex']), str(block['ColumnIndex']))\n", " else:\n", " tableprops = ''\n", " g.V(uuid).property('tableprops', tableprops).iterate()\n", " bbox = block['Geometry']['BoundingBox']\n", " g.V(uuid).property('top', bbox['Top']).iterate()\n", " g.V(uuid).property('left', bbox['Left']).iterate()\n", " map_block_id[bid] = uuid\n", " cnt = cnt + 1\n", " if cnt % 100 == 0:\n", " print(f\"Cnt = {cnt}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g.V('0').toList()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_v(v_id):\n", " l = g.V(v_id).toList()\n", " return l[-1]\n", "\n", "for block in (detect_blocks + analysis_blocks):\n", " bid = block['Id']\n", " v1 = get_v(map_block_id[bid])\n", " if 'Relationships' in block:\n", " for r in block['Relationships']:\n", " rtype = r['Type']\n", " rlist = r['Ids']\n", " for rid in rlist: \n", " v2 = get_v(map_block_id[rid])\n", " g.V(v1).addE(rtype).to(v2).next()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "remoteConn.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize\n", "\n", "Here are a few example queries to look at tables, which have a hierarchical relationship (table to cell to words).\n", "\n", "The first one looks at all tables and their cells." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%gremlin -p v,oute,inv\n", "\n", "g.V().hasLabel('TABLE').outE().inV().path()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![All Tables](images/all_tables.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This next example drills into a single table." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%gremlin -p v,oute,inv,ine\n", "\n", "g.V().hasLabel('TABLE').has(id, '12427').outE().inV().path()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![One Table](images/one_table.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we'll look at a single table and go out to the words included in the cells. You can use Gremlin syntax to further refine the query to focus on specific rows, for example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%gremlin -p v,oute,inv,ine\n", "\n", "g.V().hasLabel('TABLE').has(id, '12427').outE().inV().outE().inV().path()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Table with words](images/one_table_with_words.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }