{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Retail Data Discovery and Processing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Background \n",
"\n",
"This notebook was published as part of an AWS ML blog . It is the second notebook of two published as part of the blog. The first notebook presents the bulk of the data discovery and processing steps described in the blog. This second notebook takes the dataset generated by the first notebook, and demonstrates the training and model serving steps involved in a ML process to deploy a custom Tensorflow model in SageMaker.\n",
"\n",
"The dataset referenced by this notebook originates from the public UCI Machine Learning Repository: http://archive.ics.uci.edu/ml/datasets/online+retail\n",
"\n",
" Source: \n",
"\n",
"Dr Daqing Chen, Director: Public Analytics group. chend '@' lsbu.ac.uk, School of Engineering, London South Bank University, London SE1 0AA, UK.\n",
"\n",
" Data Set Information: \n",
"\n",
"This is a transnational data set which contains all the transactions occurring between 01/12/2010 and 09/12/2011 for a UK-based and registered non-store online retail.The company mainly sells unique all-occasion gifts. Many customers of the company are wholesalers.\n",
"\n",
" Attribute Information: \n",
"\n",
"InvoiceNo: Invoice number. Nominal, a 6-digit integral number uniquely assigned to each transaction. If this code starts with letter 'c', it indicates a cancellation. StockCode: Product (item) code. Nominal, a 5-digit integral number uniquely assigned to each distinct product. Description: Product (item) name. Nominal. Quantity: The quantities of each product (item) per transaction. Numeric.\n",
"InvoiceDate: Invice Date and time. Numeric, the day and time when each transaction was generated. UnitPrice: Unit price. Numeric, Product price per unit in sterling. CustomerID: Customer number. Nominal, a 5-digit integral number uniquely assigned to each customer. Country: Country name. Nominal, the name of the country where each customer resides."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Setup\n",
"\n",
"This notebook was created and tested on an ml.t2.medium notebook instance running on the Sparkmagic (PySpark3) kernel, and an external multi-node EMR cluster. Follow the documention to have your Sagemaker notebook instance connect to EMR.\n",
"\n",
"Begin by...\n",
"1. Downloading the dataset .\n",
"2. Upload the data set to an S3 bucket.\n",
"3. Start up a notebook instance and ensure the notebook instance has access to the data set. "
]
},
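{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a minimal sketch of step 2 using boto3. It assumes boto3 and AWS credentials are available wherever you run it (for example, a local Python session rather than this Sparkmagic kernel); the file, bucket, and key names are placeholders to replace with your own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: upload the downloaded dataset to S3. The bucket and key names\n",
"# below are placeholders, not values assumed by the rest of this notebook.\n",
"import boto3\n",
"\n",
"s3 = boto3.client(\"s3\")\n",
"s3.upload_file(Filename=\"ecommerce-data.csv.gz\",   # local copy of the dataset\n",
"               Bucket=\"your-ml-datasets-bucket\",   # your S3 bucket\n",
"               Key=\"raw/ecommerce-data.csv.gz\")    # key referenced later via S3_TARGET_PREFIX"
]
},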
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from pyspark import SparkContext, SparkConf\n",
"from pyspark.sql.types import StructType\n",
"from pyspark.sql.types import StructField\n",
"from pyspark.sql.types import StringType, IntegerType, FloatType\n",
"from pyspark.sql import functions as F\n",
"from pyspark.sql.functions import col, udf, lit\n",
"from pyspark.sql import Row\n",
"from pyspark.ml.feature import StringIndexer\n",
"from time import mktime, strptime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set S3_BUCKET to the name of your S3 BUCKET. Set the S3_TARGET_PREFIX to the location of your dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"s3://dtong-ml-datasets/raw/ecommerce-data.csv.gz"
]
}
],
"source": [
"S3_BUCKET = \"dtong-ml-datasets\"\n",
"S3_TARGET_PREFIX = \"/raw/ecommerce-data.csv.gz\"\n",
"S3_LOCATION = \"s3://\"+S3_BUCKET+S3_TARGET_PREFIX \n",
"print(S3_LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the dataset into a Spark dataframe."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+---------+--------------------+--------+-----------+-----+----------+--------------+\n",
"|InvoiceNo|StockCode| Description|Quantity|InvoiceDate|Price|CustomerID| Country|\n",
"+---------+---------+--------------------+--------+-----------+-----+----------+--------------+\n",
"| 536365| 85123A|WHITE HANGING HEA...| 6| 1291191960| 2.55| 17850|United Kingdom|\n",
"| 536365| 71053| WHITE METAL LANTERN| 6| 1291191960| 3.39| 17850|United Kingdom|\n",
"+---------+---------+--------------------+--------+-----------+-----+----------+--------------+\n",
"only showing top 2 rows\n",
"\n",
"root\n",
" |-- InvoiceNo: integer (nullable = true)\n",
" |-- StockCode: string (nullable = true)\n",
" |-- Description: string (nullable = true)\n",
" |-- Quantity: integer (nullable = true)\n",
" |-- InvoiceDate: integer (nullable = true)\n",
" |-- Price: float (nullable = true)\n",
" |-- CustomerID: integer (nullable = true)\n",
" |-- Country: string (nullable = true)"
]
}
],
"source": [
"dateToTsUdf = udf(lambda date: int(mktime(strptime(date,\"%m/%d/%Y %H:%M\"))) if date is not None else None)\n",
"\n",
"invoiceSchema = StructType([StructField('InvoiceNo', IntegerType(), False),\n",
" StructField('StockCode', StringType(), False),\n",
" StructField('Description', StringType(), True),\n",
" StructField('Quantity', IntegerType(), False),\n",
" StructField('InvoiceDate', StringType(), False),\n",
" StructField('Price', FloatType(), False),\n",
" StructField('CustomerID', IntegerType(), True),\n",
" StructField('Country', StringType(), False)])\n",
"\n",
"invoicesDf= spark.read.options(header=True).schema(invoiceSchema).csv(S3_LOCATION) \n",
"invoicesDf=invoicesDf.withColumn(\"InvoiceDate\", dateToTsUdf(invoicesDf.InvoiceDate).cast(IntegerType()))\n",
"\n",
"invoicesDf.show(2)\n",
"invoicesDf.printSchema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data set contains dirty data that we want to include in our training and test sets. This include transcations related to refunds, discounts, postage and such. After cleansing, the data set contains 396,485 transcations."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"396485"
]
}
],
"source": [
"invoicesDf.registerTempTable(\"invoices\")\n",
"invoicesCleanDf = sqlContext.sql(\n",
" \"select * \\\n",
" from invoices \\\n",
" where Description is not null \\\n",
" and Description != 'Manual' \\\n",
" and Description != 'POSTAGE' \\\n",
" and Description != 'Discount' \\\n",
" and CustomerID is not null \\\n",
" and Quantity > 0 \\\n",
" and Price > 0 \\\n",
" and Description != 'DOTCOM POSTAGE' \\\n",
" and Description != 'AMAZON FEE'\"\n",
")\n",
"invoicesCleanDf.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\"Description\" can be used as the unique indentifer for products in this dataset. We need to convert the text description to integer values, so that we have an index for a vector of products that can be used as inputs and labels for a ML model."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---------+----------+--------------------+------------+--------+\n",
"|InvoiceDate|InvoiceNo|CustomerID| Description|ProductIndex|Quantity|\n",
"+-----------+---------+----------+--------------------+------------+--------+\n",
"| 1291191960| 536365| 17850|WHITE HANGING HEA...| 0.0| 6|\n",
"| 1291191960| 536365| 17850| WHITE METAL LANTERN| 434.0| 6|\n",
"| 1291191960| 536365| 17850|CREAM CUPID HEART...| 452.0| 8|\n",
"+-----------+---------+----------+--------------------+------------+--------+\n",
"only showing top 3 rows"
]
}
],
"source": [
"productIndexer = StringIndexer(inputCol=\"Description\", outputCol=\"ProductIndex\") \\\n",
".fit(invoicesCleanDf)\n",
"\n",
"timeSeriesDF = productIndexer.transform(invoicesCleanDf) \\\n",
".select([\"InvoiceDate\",\"InvoiceNo\",\"CustomerID\",\"Description\",\"ProductIndex\",\"Quantity\"])\n",
"timeSeriesDF.show(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The time range for the dataset is from December 01, 2010 00:26:00 to December 09, 2011 04:50:00"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------------+----------------+\n",
"|max(InvoiceDate)|min(InvoiceDate)|\n",
"+----------------+----------------+\n",
"| 1323435000| 1291191960|\n",
"+----------------+----------------+"
]
}
],
"source": [
"timeSeriesDF.select(\"InvoiceDate\").agg(F.max(\"InvoiceDate\"),F.min(\"InvoiceDate\")).show()"
]
},
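{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the sketch below converts the epoch bounds back to readable timestamps. It interprets them in UTC; mktime above parsed dates in the cluster's local timezone, so adjust if your cluster is not set to UTC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert the min/max epoch seconds found above back to readable UTC datetimes.\n",
"from datetime import datetime\n",
"for ts in (1291191960, 1323435000):\n",
"    print(datetime.utcfromtimestamp(ts))"
]
},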
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This dataset has 4335 customers"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4335"
]
}
],
"source": [
"timeSeriesDF.select(\"CustomerID\").distinct().count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are a total of 3874 products"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3874"
]
}
],
"source": [
"timeSeriesDF.select(\"ProductIndex\").distinct().count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are 4335 customers with an order history ranging from 1 to 207. The average number of transactions per customer is 4.26 with a standard deviation of 7.67."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+---+---+------------------+-----------------+\n",
"|total customers|max|min| avg| std|\n",
"+---------------+---+---+------------------+-----------------+\n",
"| 4335|207| 1|4.2551326412918105|7.668065002405116|\n",
"+---------------+---+---+------------------+-----------------+"
]
}
],
"source": [
"txnPerCustomer = timeSeriesDF.groupBy(\"CustomerID\") \\\n",
".agg(F.countDistinct(col(\"InvoiceDate\"),col(\"InvoiceNo\")).alias(\"TxnPerCustomer\")).sort(col(\"TxnPerCustomer\").desc())\n",
"\n",
"txnPerCustomer.agg(F.countDistinct(\"CustomerID\").alias(\"total customers\"), \n",
" F.max(\"TxnPerCustomer\").alias(\"max\"),\n",
" F.min(\"TxnPerCustomer\").alias(\"min\"),\n",
" F.avg(\"TxnPerCustomer\").alias(\"avg\"),\n",
" F.stddev(\"TxnPerCustomer\").alias(\"std\")).show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I chose to trim the dataset, so we only consider customers with some minimum amount of order history. I also chose to cap the maximum number of orders, so outliers are removed. RNNs with long sequences suffer from the vanishing gradient problem . Keeping outliers in the dataset isn't worthwhile. The maximum and minimum of 50 and 5 was selected with little testing. I leave it up to the reader to search for more optimial parameters if desired.\n",
"\n",
"We are now left with 1090 customers with order history between 50 and 5. The average number of transactions within this group of customers is about 10 with a standard deviation of 6.76."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+---+---+----------------+-----------------+\n",
"|total customers|max|min| avg| std|\n",
"+---------------+---+---+----------------+-----------------+\n",
"| 1090| 50| 5|9.93302752293578|6.762405598972053|\n",
"+---------------+---+---+----------------+-----------------+"
]
}
],
"source": [
"MAX_TXNS = 50\n",
"MIN_TXNS = 5\n",
"\n",
"#transactions per customer statistics\n",
"#txnPerCustomer = timeSeriesDF.groupBy(\"CustomerID\") \\\n",
"#.agg(F.countDistinct(col(\"InvoiceDate\"),col(\"InvoiceNo\")).alias(\"TxnPerCustomer\"))\n",
"\n",
"txnPerCustomer = txnPerCustomer \\\n",
".filter(txnPerCustomer.TxnPerCustomer <= MAX_TXNS) \\\n",
".filter((txnPerCustomer.TxnPerCustomer >= MIN_TXNS))\n",
"\n",
"customerBasket = txnPerCustomer.select(col(\"CustomerID\").astype(\"float\")).collect()\n",
"customerBasket = np.array(customerBasket)\n",
"N = customerBasket.shape[0]\n",
"customerBasket= customerBasket.reshape(N)\n",
"\n",
"txnStats = txnPerCustomer.agg( \\\n",
" F.countDistinct(\"CustomerID\").alias(\"total customers\"), \n",
" F.max(\"TxnPerCustomer\").alias(\"max\"),\n",
" F.min(\"TxnPerCustomer\").alias(\"min\"),\n",
" F.avg(\"TxnPerCustomer\").alias(\"avg\"),\n",
" F.stddev(\"TxnPerCustomer\").alias(\"std\")).show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we filter the dataset to only include the orders from customers who have an order history of 5 to 50. We are left with 223881 orders."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---------+----------+--------------------+------------+--------+\n",
"|InvoiceDate|InvoiceNo|CustomerID| Description|ProductIndex|Quantity|\n",
"+-----------+---------+----------+--------------------+------------+--------+\n",
"| 1291191960| 536365| 17850|WHITE HANGING HEA...| 0.0| 6|\n",
"| 1291191960| 536365| 17850| WHITE METAL LANTERN| 434.0| 6|\n",
"| 1291191960| 536365| 17850|CREAM CUPID HEART...| 452.0| 8|\n",
"| 1291191960| 536365| 17850|KNITTED UNION FLA...| 276.0| 6|\n",
"| 1291191960| 536365| 17850|RED WOOLLY HOTTIE...| 268.0| 6|\n",
"+-----------+---------+----------+--------------------+------------+--------+\n",
"only showing top 5 rows\n",
"\n",
"223881"
]
}
],
"source": [
"basketTimeSeries = timeSeriesDF \\\n",
".filter(timeSeriesDF.CustomerID.isin(*customerBasket) == True)\n",
"\n",
"basketTimeSeries.show(5)\n",
"basketTimeSeries.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The products need to be re-index after filtering the dataset."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---------+----------+--------------------+------------+--------+--------+\n",
"|InvoiceDate|InvoiceNo|CustomerID| Description|ProductIndex|NewIndex|Quantity|\n",
"+-----------+---------+----------+--------------------+------------+--------+--------+\n",
"| 1291191960| 536365| 17850|WHITE HANGING HEA...| 0.0| 0.0| 6|\n",
"| 1291191960| 536365| 17850| WHITE METAL LANTERN| 434.0| 391.0| 6|\n",
"| 1291191960| 536365| 17850|CREAM CUPID HEART...| 452.0| 344.0| 8|\n",
"+-----------+---------+----------+--------------------+------------+--------+--------+\n",
"only showing top 3 rows\n",
"\n",
"223881"
]
}
],
"source": [
"productReIndexer = StringIndexer(inputCol=\"ProductIndex\", outputCol=\"NewIndex\") \\\n",
".fit(basketTimeSeries)\n",
"\n",
"basketTimeSeriesReIndexed = productReIndexer.transform(basketTimeSeries) \\\n",
".select([\"InvoiceDate\",\"InvoiceNo\",\"CustomerID\",\"Description\",\"ProductIndex\",\"NewIndex\",\"Quantity\"])\n",
"basketTimeSeriesReIndexed.show(3)\n",
"basketTimeSeriesReIndexed.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The most popular product by quantity bought is product 138: \"WORLD WAR 2 GLIDERS ASSTD DESIGNS.\" This is a useful datapoint that is used later when evaluating our predictive model."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------+----------------------------------+--------+\n",
"|ProductIndex|Description |TotalQty|\n",
"+------------+----------------------------------+--------+\n",
"|138.0 |WORLD WAR 2 GLIDERS ASSTD DESIGNS |36890 |\n",
"|2.0 |JUMBO BAG RED RETROSPOT |29850 |\n",
"|63.0 |POPCORN HOLDER |26280 |\n",
"|3.0 |ASSORTED COLOUR BIRD ORNAMENT |24673 |\n",
"|186.0 |PACK OF 12 LONDON TISSUES |23371 |\n",
"|0.0 |WHITE HANGING HEART T-LIGHT HOLDER|22667 |\n",
"|8.0 |PACK OF 72 RETROSPOT CAKE CASES |19052 |\n",
"|56.0 |PACK OF 60 PINK PAISLEY CAKE CASES|17889 |\n",
"|279.0 |MINI PAINT SET VINTAGE |17291 |\n",
"|30.0 |VICTORIAN GLASS HANGING T-LIGHT |16588 |\n",
"+------------+----------------------------------+--------+"
]
}
],
"source": [
"basketTimeSeriesReIndexed.select(\"ProductIndex\",\"Description\",\"Quantity\") \\\n",
".groupBy(\"ProductIndex\",\"Description\").agg(F.sum(col(\"Quantity\")).alias(\"TotalQty\")) \\\n",
".orderBy(col(\"TotalQty\").desc()).limit(10).show(truncate=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The most popular product by times bought is product 0: \"WHITE HANGING HEART T-LIGHT HOLDER.\" This is a useful datapoint that is used later when evaluating our predictive model."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------+----------------------------------+-----------+\n",
"|ProductIndex|Description |TimesBought|\n",
"+------------+----------------------------------+-----------+\n",
"|0.0 |WHITE HANGING HEART T-LIGHT HOLDER|1227 |\n",
"|2.0 |JUMBO BAG RED RETROSPOT |1054 |\n",
"|1.0 |REGENCY CAKESTAND 3 TIER |995 |\n",
"|5.0 |LUNCH BAG RED RETROSPOT |924 |\n",
"|4.0 |PARTY BUNTING |852 |\n",
"|3.0 |ASSORTED COLOUR BIRD ORNAMENT |825 |\n",
"|7.0 |LUNCH BAG BLACK SKULL. |743 |\n",
"|6.0 |SET OF 3 CAKE TINS PANTRY DESIGN |690 |\n",
"|11.0 |LUNCH BAG SPACEBOY DESIGN |646 |\n",
"|17.0 |LUNCH BAG SUKI DESIGN |640 |\n",
"+------------+----------------------------------+-----------+"
]
}
],
"source": [
"basketTimeSeriesReIndexed.select(\"ProductIndex\",\"Description\",\"Quantity\") \\\n",
".groupBy(\"ProductIndex\",\"Description\") \\\n",
".agg(F.count(lit(1)).alias(\"TimesBought\")) \\\n",
".orderBy(col(\"TimesBought\").desc()).limit(10).show(truncate=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of products bought per order ranges from 1 to 540 within our reduced dataset. On average, 20 products are bought with a standard deviation of 24.5. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---+---+-----------------+------------------+\n",
"|total txns|max|min| avg| std|\n",
"+----------+---+---+-----------------+------------------+\n",
"| 10807|540| 1|20.24032511314307|24.492032572420484|\n",
"+----------+---+---+-----------------+------------------+"
]
}
],
"source": [
"MAX_PRODUCTS = 600\n",
"productsPerTxn = basketTimeSeriesReIndexed \\\n",
".groupBy(\"InvoiceDate\", \"CustomerID\", \"InvoiceNo\") \\\n",
".agg(F.countDistinct(col(\"NewIndex\")).alias(\"ProductsPerTxn\")) \\\n",
".orderBy(col(\"CustomerID\").desc(), col(\"InvoiceNo\").desc(), col(\"InvoiceDate\").desc())\n",
"\n",
"productsPerTxn= productsPerTxn.filter(productsPerTxn.ProductsPerTxn <= MAX_PRODUCTS)\n",
"\n",
"productStats = productsPerTxn.agg( \\\n",
" F.countDistinct(\"InvoiceNo\").alias(\"total txns\"),\n",
" F.max(\"ProductsPerTxn\").alias(\"max\"),\n",
" F.min(\"ProductsPerTxn\").alias(\"min\"),\n",
" F.avg(\"ProductsPerTxn\").alias(\"avg\"),\n",
" F.stddev(\"ProductsPerTxn\").alias(\"std\")).show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following will serve as a product name lookup table for the remaining products in our filtered dataset. 3648 products remain."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----+\n",
"| Description|Index|\n",
"+--------------------+-----+\n",
"|WHITE HANGING HEA...| 0.0|\n",
"|JUMBO BAG RED RET...| 1.0|\n",
"|REGENCY CAKESTAND...| 2.0|\n",
"|LUNCH BAG RED RET...| 3.0|\n",
"| PARTY BUNTING| 4.0|\n",
"|ASSORTED COLOUR B...| 5.0|\n",
"|LUNCH BAG BLACK ...| 6.0|\n",
"|SET OF 3 CAKE TIN...| 7.0|\n",
"|LUNCH BAG SPACEBO...| 8.0|\n",
"|LUNCH BAG SUKI DE...| 9.0|\n",
"|LUNCH BAG PINK PO...| 10.0|\n",
"|PACK OF 72 RETROS...| 11.0|\n",
"|ALARM CLOCK BAKEL...| 12.0|\n",
"| LUNCH BAG CARS BLUE| 13.0|\n",
"| SPOTTY BUNTING| 14.0|\n",
"|WOODEN PICTURE FR...| 15.0|\n",
"|LUNCH BAG APPLE D...| 16.0|\n",
"|JUMBO BAG PINK PO...| 17.0|\n",
"|ALARM CLOCK BAKEL...| 18.0|\n",
"|WOODEN FRAME ANTI...| 19.0|\n",
"+--------------------+-----+\n",
"only showing top 20 rows\n",
"\n",
"3648"
]
}
],
"source": [
"releventProducts = basketTimeSeriesReIndexed \\\n",
".select(\"Description\", col(\"NewIndex\").alias(\"Index\")).distinct() \\\n",
".sort(col(\"Index\").asc())\n",
"\n",
"releventProducts.show()\n",
"releventProducts.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's persist this lookup table for use later. Set S3_TARGET_PREFIX to where you want to write out the lookup table. "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"s3://dtong-ml-datasets/processed/ecomm/customer_basket_list.csv.gz"
]
}
],
"source": [
"S3_TARGET_PREFIX = \"/processed/ecomm/customer_basket_list.csv.gz\"\n",
"S3_LOCATION = \"s3://\"+S3_BUCKET+S3_TARGET_PREFIX\n",
"\n",
"print(S3_LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the file is written to the S3 location you have provided, you may want to rename the file. Spark will write out the file into a folder with the name of the provided prefix. The actual file will be named by Spark in the form part-xxxx-xxxxxxx."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'path s3://dtong-ml-datasets/processed/ecomm/customer_basket_list.csv.gz already exists.;'\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/readwriter.py\", line 766, in csv\n",
" self._jwrite.csv(path)\n",
" File \"/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py\", line 1133, in __call__\n",
" answer, self.gateway_client, self.target_id, self.name)\n",
" File \"/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py\", line 69, in deco\n",
" raise AnalysisException(s.split(': ', 1)[1], stackTrace)\n",
"pyspark.sql.utils.AnalysisException: 'path s3://dtong-ml-datasets/processed/ecomm/customer_basket_list.csv.gz already exists.;'\n",
"\n"
]
}
],
"source": [
"releventProducts.coalesce(1).write.option(\"compression\",\"gzip\") \\\n",
".csv(S3_LOCATION)"
]
},
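{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a sketch of the rename using boto3, assuming boto3 is available where this runs and the attached role can read and write the bucket. S3 has no rename operation, so the part file is copied to a new key and the original is deleted; the target key below is a hypothetical name, so choose your own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: \"rename\" the part file Spark wrote by copying it to a new key\n",
"# and deleting the original. The target key below is a hypothetical name.\n",
"import boto3\n",
"\n",
"s3 = boto3.client(\"s3\")\n",
"folder = S3_TARGET_PREFIX.lstrip(\"/\")  # Spark wrote the part file under this \"folder\"\n",
"for obj in s3.list_objects_v2(Bucket=S3_BUCKET, Prefix=folder + \"/part-\").get(\"Contents\", []):\n",
"    s3.copy_object(Bucket=S3_BUCKET,\n",
"                   CopySource={\"Bucket\": S3_BUCKET, \"Key\": obj[\"Key\"]},\n",
"                   Key=\"processed/ecomm/customer_basket_list_single.csv.gz\")\n",
"    s3.delete_object(Bucket=S3_BUCKET, Key=obj[\"Key\"])"
]
},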
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is an intermediate process step. our goal is to prepare a dataset for a RNN where we have a sequence of orders for each customer as input and targets. This intermediate step involves rolling up all the order line items into orders where we have a \"cart\" column that contains a dense vector of product index values. This vector stores the information about the products that were purchased in the order."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+-----------+---------+--------------------+\n",
"|CustomerID|InvoiceDate|InvoiceNo| Cart|\n",
"+----------+-----------+---------+--------------------+\n",
"| 18283| 1294323240| 540350|[238.0, 703.0, 6....|\n",
"| 18283| 1295794680| 541854|[238.0, 1193.0, 1...|\n",
"| 18283| 1298889000| 545079|[0.0, 261.0, 1213...|\n",
"| 18283| 1303403820| 550957|[238.0, 621.0, 34...|\n",
"| 18283| 1306150380| 554157|[1570.0, 6.0, 10....|\n",
"| 18283| 1308051660| 556731|[0.0, 1455.0, 130...|\n",
"| 18283| 1308856800| 557956|[0.0, 293.0, 6.0,...|\n",
"| 18283| 1310648400| 560025|[0.0, 621.0, 127....|\n",
"| 18283| 1310649600| 560032| [444.0]|\n",
"| 18283| 1315226100| 565579|[1147.0, 392.0, 1...|\n",
"| 18283| 1319726280| 573093|[293.0, 1537.0, 2...|\n",
"| 18283| 1320937140| 575668|[293.0, 306.0, 66...|\n",
"| 18283| 1320937620| 575675| [290.0]|\n",
"| 18283| 1322054820| 578262|[238.0, 1537.0, 6...|\n",
"| 18283| 1322657940| 579673|[662.0, 318.0, 6....|\n",
"| 18283| 1323172920| 580872|[1144.0, 621.0, 6...|\n",
"| 18272| 1302168900| 549185|[947.0, 21.0, 119...|\n",
"| 18272| 1304014260| 551507|[2151.0, 1764.0, ...|\n",
"| 18272| 1310485320| 559813|[119.0, 21.0, 161...|\n",
"| 18272| 1313669160| 563680|[375.0, 816.0, 22...|\n",
"+----------+-----------+---------+--------------------+\n",
"only showing top 20 rows\n",
"\n",
"10827"
]
}
],
"source": [
"carts= basketTimeSeriesReIndexed \\\n",
".groupBy(\"InvoiceDate\", \"CustomerID\", \"InvoiceNo\") \\\n",
".agg(F.collect_set(\"NewIndex\").alias(\"Cart\")) \\\n",
".orderBy(col(\"CustomerID\").desc(),col(\"InvoiceDate\").asc(), col(\"InvoiceNo\").desc()) \\\n",
".select(\"CustomerID\",\"InvoiceDate\",\"InvoiceNo\",\"Cart\")\n",
"\n",
"carts.show()\n",
"carts.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RNNs requires a fixed size input per timeslice, so we need to convert the dense vectors above to a sparse vector. The sparse vector's length is equal to the number of products, and is made up of zeros and ones. This allows us to use each position in the vector to represent whether a product was bought in the order. For instance, a one in position N means that the product with index value N was bought in the order. Zero indicates the product wasn't bought.\n",
"\n",
"This isn't an efficient representation of the data, and could be problematic for very large product catalogs. In such a situation, implementing an autoencoder maybe worthwhile. In this case, we have 3648 products, and is manageable. "
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-----------+----------------------------------------------------------------------------------------------------+\n",
"|CustomerID|InvoiceNo|InvoiceDate| Cart|\n",
"+----------+---------+-----------+----------------------------------------------------------------------------------------------------+\n",
"| 18283| 540350| 1294323240|[0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, ...|\n",
"| 18283| 541854| 1295794680|[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...|\n",
"| 18283| 545079| 1298889000|[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, ...|\n",
"| 18283| 550957| 1303403820|[0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, ...|\n",
"| 18283| 554157| 1306150380|[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...|\n",
"+----------+---------+-----------+----------------------------------------------------------------------------------------------------+\n",
"only showing top 5 rows"
]
}
],
"source": [
"nProducts= releventProducts.count()\n",
"\n",
"def encodeCart(cart) :\n",
" encoding = [\"0\"]*nProducts\n",
" \n",
" for idx in cart : \n",
" encoding[int(idx)] = \"1\"\n",
" \n",
" return encoding\n",
"\n",
"cartsSparseVecs = carts.rdd.map(lambda r: \\\n",
" Row(InvoiceNo=r[\"InvoiceNo\"], InvoiceDate=r[\"InvoiceDate\"], \\\n",
" CustomerID=r[\"CustomerID\"], Cart=encodeCart(r[\"Cart\"]))) \\\n",
".toDF().select(\"CustomerID\",\"InvoiceNo\",\"InvoiceDate\",\"Cart\")\n",
"cartsSparseVecs.show(5,truncate=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's write out this processed dataset. Update S3_TARGET_PREFIX to the location where you want to write out this dataset."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"s3://dtong-ml-datasets/processed/ecomm/cart_sparse_vecs.parquet.gz"
]
}
],
"source": [
"S3_TARGET_PREFIX = \"/processed/ecomm/cart_sparse_vecs.parquet.gz\"\n",
"S3_LOCATION = \"s3://\"+S3_BUCKET+S3_TARGET_PREFIX\n",
"\n",
"print(S3_LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Write out the dataset in parquet format. Again, you probably want to rename the file manually afterwards."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'path s3://dtong-ml-datasets/processed/ecomm/cart_sparse_vecs.parquet.gz already exists.;'\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/readwriter.py\", line 691, in parquet\n",
" self._jwrite.parquet(path)\n",
" File \"/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py\", line 1133, in __call__\n",
" answer, self.gateway_client, self.target_id, self.name)\n",
" File \"/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py\", line 69, in deco\n",
" raise AnalysisException(s.split(': ', 1)[1], stackTrace)\n",
"pyspark.sql.utils.AnalysisException: 'path s3://dtong-ml-datasets/processed/ecomm/cart_sparse_vecs.parquet.gz already exists.;'\n",
"\n"
]
}
],
"source": [
"cartsSparseVecs.coalesce(1).write.option(\"compression\",\"gzip\") \\\n",
".parquet(\"s3://dtong-ml-datasets/processed/ecomm/cart_sparse_vecs.parquet.gz\")"
]
}
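,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final sanity check, the sketch below reads the written dataset back from S3_LOCATION and displays a couple of rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: confirm the Parquet output is readable and looks as expected.\n",
"spark.read.parquet(S3_LOCATION).show(2, truncate=80)"
]
}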
],
"metadata": {
"kernelspec": {
"display_name": "Sparkmagic (PySpark3)",
"language": "",
"name": "pyspark3kernel"
},
"language_info": {
"codemirror_mode": {
"name": "python",
"version": 3
},
"mimetype": "text/x-python",
"name": "pyspark3",
"pygments_lexer": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}