"""
Amazon Aurora Labs for MySQL
Read load generator using multiple threads. The script will randomize a set of query patterns,
including point query, range query, aggregation and expensive stored procedure.

Dependencies:
none

License:
This sample code is made available under the MIT-0 license. See the LICENSE file.
"""

# Dependencies
import sys
import argparse
import time
import threading
import _thread
import socket
import random
import pymysql
import datetime
import json
import urllib3
from os import environ

# Define parser
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--endpoint', help="The database endpoint", required=True)
parser.add_argument('-p', '--password', help="The database user password", required=True)
parser.add_argument('-u', '--username', help="The database user name", required=True)
parser.add_argument('-d', '--database', help="The schema (database) to use", required=True)
parser.add_argument('-t', '--threads', help="The number of threads to use", type=int, default=64)
args = parser.parse_args()

# Global variables
query_count = 0
max_id = 2500000
query_iterations = 100
lock = threading.Lock()

# Track this lab for usage analytics, if user has explicitly or implicitly agreed
def track_analytics():
    http = urllib3.PoolManager()
    if environ["AGREETRACKING"] == 'Yes':
        # try/catch
        try:
            # build tracker payload
            payload = {
                'stack_uuid': environ["STACKUUID"],
                'stack_name': environ["STACKNAME"],
                'stack_region': environ["STACKREGION"],
                'deployed_cluster': None,
                'deployed_ml':  None,
                'deployed_gdb': None,
                'is_secondary': None,
                'event_timestamp': datetime.datetime.utcnow().isoformat() + 'Z',
                'event_scope': 'Script',
                'event_action': 'Execute',
                'event_message': 'reader_loadtest.py',
                'ee_event_id': None,
                'ee_team_id': None,
                'ee_module_id': None,
                'ee_module_version': None
            }

            # Send the tracking data
            r = http.request('POST', environ["ANALYTICSURI"], body=json.dumps(payload).encode('utf-8'), headers={'Content-Type': 'application/json'})
        except Exception as e:
            # Errors in tracker interaction should not prevent operation of the function in critical path
            print("[ERROR]", e)

# Query thread
def thread_func(endpoint, username, password, schema, max_id, iterations):
    # Specify that query_count is a global variable
    global query_count
    global lock

    # Loop Indefinitely
    while True:
        try:
            # Resolve the endpoint
            host = socket.gethostbyname(endpoint)

            # Connect to the reader endpoint
            conn = pymysql.connect(host=host, user=username, password=password, database=schema, autocommit=True)

            # Run multiple queries per connection
            for iter in range(iterations):
                # Generate a random number to use as the lookup value
                # we will arbitrarily switch between a few query types
                key_value = random.randrange(1, max_id)
                key_offset = random.randrange(1, 1000)
                query_type = random.randrange(0,5)

                # queries of multiple types
                if query_type == 0:
                    # Point query
                    sql_command = "SELECT SQL_NO_CACHE * FROM sbtest1 WHERE id= %d;" % key_value
                elif query_type == 1:
                    # Range query
                    sql_command = "SELECT SQL_NO_CACHE *, SHA2(c, 512), SQRT(k) FROM sbtest1 WHERE id BETWEEN %d AND %d ORDER BY id DESC LIMIT 10;" % (key_value, key_value + key_offset)
                elif query_type == 2:
                    # Aggregation
                    sql_command = "SELECT SQL_NO_CACHE k, COUNT(k), SQRT(SUM(k)), SQRT(AVG(k)) FROM sbtest1 WHERE id BETWEEN %d AND %d GROUP BY k ORDER BY k;" % (key_value, key_value + key_offset)
                elif query_type == 3:
                    # Point query with hashing
                    sql_command = "SELECT SQL_NO_CACHE id, SHA2(c, 512) AS token FROM sbtest1 WHERE id= %d;" % key_value
                elif query_type == 4:
                    # Point query with hashing
                    sql_command = "CALL minute_rollup(%d);" % (key_offset * 10)


                # run query
                with conn.cursor() as cursor:
                    cursor.execute(sql_command)
                    cursor.close()

                # Increment the executed query count
                with lock:
                    query_count += 1

            # Close the connection
            conn.close()
        except:
            # Display any exception information
            print(sys.exc_info()[1])


# Progress thread
def progress_func():
    # Specify that query_count is a global variable
    global query_count
    global lock

    # Start timing
    start_time = time.time()
    initial = True

    # Loop indefinitely
    while True:
        if initial != True:
            # Format an output string
            end_time = time.time()
            output = "Queries/sec: {0} (press Ctrl+C to quit)\r".format(int(query_count / (end_time-start_time)))
            start_time = end_time

            # Reset the executed query count
            with lock:
                query_count = 0

            # Write to STDOUT and flush
            sys.stdout.write(output)
            sys.stdout.flush()

            # Sleep this thread for 1 second
            time.sleep(5)

        # No longer initial pass
        initial = False


# Invoke tracking function
track_analytics()

# Start progress thread
_thread.start_new_thread(progress_func, ())

# Start readers
for thread_id in range(args.threads):
    _thread.start_new_thread(thread_func, (args.endpoint, args.username, args.password, args.database, max_id, query_iterations))

# Loop indefinitely to prevent application exit
while 1:
    pass