{ "cells": [ { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "import os\n", "import csv\n", "import pandas as pd\n", "import hashlib\n", "from io import BytesIO\n", "import pickle, gzip\n", "import random as rand\n", "\n", "from PIL import Image\n", "import numpy as np\n", "\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms, utils\n", "\n", "import sagemaker\n", "\n", "WORKING_DIR = os.getcwd()\n", "DATA_DIR = WORKING_DIR+'/ut-zap50k-images-square'\n", " \n", "ZAPPOS50K_INDEX = WORKING_DIR+'/zappos50k-index.csv'\n", "DOWNLOAD_S3URI = \"s3://reinvent2018-sagemaker-pytorch\"\n", "\n", "WEIGHT_SAME_IMG = 0.0\n", "WEIGHT_DIFF_IMG = 1.0\n", "PARAM_SAME_CATEGORY_WEIGHTING = 0.05\n", "PARAM_SAME_SUBCATEGORY_WEIGHTING = 0.01 \n", "\n", "ZAPPOS50K_INDEX = WORKING_DIR+'/zappos50k-index.csv'\n", "ZAPPOS50K_INDEX_TRAIN = WORKING_DIR+'/zappos50k-index-train.csv'\n", "ZAPPOS50K_INDEX_TEST = WORKING_DIR+'/zappos50k-index-test.csv'\n", "\n", "ZAPPOS50K_TUPLES_INDEX_TRAIN = WORKING_DIR+'/zappos50k-tuples-index-train.csv'\n", "ZAPPOS50K_TUPLES_INDEX_TEST = WORKING_DIR+'/zappos50k-tuples-index-test.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash -s \"$DOWNLOAD_S3URI\"\n", "aws s3 cp $1/ut-zap50k-images-square.zip . --quiet\n", "unzip -nq ut-zap50k-images-square.zip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate Sample Indices" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "TRAIN_IMG_PATHS = [\"Boots/Knee High/Anne Klein\",\n", " \"Boots/Knee High/Ariat\",\n", " \"Boots/Mid-Calf/UGG\",\n", " \"Sandals/Athletic/Keen Kids\",\n", " \"Sandals/Heel/Annie\",\n", " \"Sandals/Heel/Fly Flot\",\n", " \"Sandals/Heel/Onex\",\n", " \"Shoes/Oxfords/Calvin Klein\",\n", " \"Shoes/Oxfords/Rockport\"]\n", "\n", "TEST_IMG_PATHS = ['Boots/Knee High/Tommy Hilfiger Kids/',\n", " 'Boots/Over the Knee/Calvin Klein Collection/',\n", " 'Shoes/Oxfords/Bass']\n", "\n", "def getImageTensor(img_path, transform):\n", " \n", " image = Image.open(img_path)\n", " image_tensor = transform(image)\n", " \n", " return image_tensor\n", "\n", "def get_categories(img_loc) :\n", "\n", " path, file = os.path.split(img_loc)\n", " path_parts = path.split(os.sep)\n", " category = path_parts[0]\n", " subcategory = path_parts[1]\n", "\n", " return {'category': category, 'sub': subcategory}\n", " \n", "def generate_sample_index(idxFile, img_paths) : \n", " \n", " with open(idxFile, 'w') as csvfile:\n", "\n", " try:\n", "\n", " csvwriter = csv.writer(csvfile)\n", " for paths in img_paths:\n", " \n", " c = get_categories(paths)\n", " cid = int(hashlib.sha256(c['category'].encode('utf-8')).hexdigest(), 16) % 10**9\n", " scid = int(hashlib.sha256(c['sub'].encode('utf-8')).hexdigest(), 16) % 10**9\n", " \n", " files = os.listdir(os.path.join(DATA_DIR,paths))\n", "\n", " row = []\n", " for f in files:\n", " csvwriter.writerow([os.path.join(paths,f),cid,scid])\n", "\n", " except csv.Error as e:\n", " print(e)\n", "\n", " finally:\n", " csvfile.close()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "ZAPPOS50K_PARTIAL_INDEX = WORKING_DIR+'/zappos50k-partial-index.csv'\n", "ZAPPOS50K_PARTIAL_INDEX_TRAIN = WORKING_DIR+'/zappos50k-partial-index-train.csv'\n", "ZAPPOS50K_PARTIAL_INDEX_TEST = WORKING_DIR+'/zappos50k-partial-index-test.csv'\n", "\n", "generate_sample_index(ZAPPOS50K_PARTIAL_INDEX_TRAIN, TRAIN_IMG_PATHS)\n", "generate_sample_index(ZAPPOS50K_PARTIAL_INDEX_TEST, TEST_IMG_PATHS)\n", "generate_sample_index(ZAPPOS50K_PARTIAL_INDEX, TRAIN_IMG_PATHS+TEST_IMG_PATHS)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def generate_tuples_sample_index(idxFile, tuplesIdxFile) :\n", " \n", " indexDF = pd.read_csv(idxFile, header=None, names=['img1','cat','sub_cat'])\n", " tuplesDF = None\n", "\n", " for (idx, row) in indexDF.iterrows() :\n", "\n", " df = (indexDF[idx:]).copy().reset_index(drop=True)\n", " \n", " sim_cat_weight = WEIGHT_DIFF_IMG-(((row['cat'] == df['cat']) * PARAM_SAME_CATEGORY_WEIGHTING) + \\\n", " ((row['sub_cat'] == df['sub_cat']) * PARAM_SAME_SUBCATEGORY_WEIGHTING))\n", " \n", " sim_cat_weight[0] = WEIGHT_SAME_IMG\n", " df['img2'] = pd.Series((row['img1'] for x in range(idx, indexDF.shape[0])))\n", " df['label'] = sim_cat_weight \n", " \n", " df= df.drop(columns= ['cat','sub_cat']) \n", " tuplesDF = df if (tuplesDF is None) else tuplesDF.append(df)\n", " \n", " tuplesDF.to_csv(tuplesIdxFile, sep=',', index=False, header=None)\n", " \n", " return tuplesDF.reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | img1 | \n", "img2 | \n", "label | \n", "
---|---|---|---|
0 | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "0.00 | \n", "
1 | \n", "Boots/Knee High/Tommy Hilfiger Kids/8047638.3.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "0.94 | \n", "
2 | \n", "Boots/Over the Knee/Calvin Klein Collection/80... | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "0.95 | \n", "
3 | \n", "Shoes/Oxfords/Bass/7563706.226012.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
4 | \n", "Shoes/Oxfords/Bass/7563706.371938.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
5 | \n", "Shoes/Oxfords/Bass/7616146.278640.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
6 | \n", "Shoes/Oxfords/Bass/7616146.372724.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
7 | \n", "Shoes/Oxfords/Bass/7616146.372725.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
8 | \n", "Shoes/Oxfords/Bass/8028830.372729.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
9 | \n", "Shoes/Oxfords/Bass/7616146.244.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
10 | \n", "Shoes/Oxfords/Bass/7956255.10788.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
11 | \n", "Shoes/Oxfords/Bass/8026675.3241.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
12 | \n", "Shoes/Oxfords/Bass/8028830.372728.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
13 | \n", "Shoes/Oxfords/Bass/8098601.36035.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
14 | \n", "Shoes/Oxfords/Bass/7616146.337753.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
15 | \n", "Shoes/Oxfords/Bass/7505665.585.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
16 | \n", "Shoes/Oxfords/Bass/7698965.278640.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
17 | \n", "Shoes/Oxfords/Bass/7976075.9041.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
18 | \n", "Shoes/Oxfords/Bass/7616146.128.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
19 | \n", "Shoes/Oxfords/Bass/7635940.691.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
20 | \n", "Shoes/Oxfords/Bass/7563706.184651.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
21 | \n", "Shoes/Oxfords/Bass/7976075.59601.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
22 | \n", "Shoes/Oxfords/Bass/7505581.4082.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
23 | \n", "Shoes/Oxfords/Bass/8098601.4082.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
24 | \n", "Shoes/Oxfords/Bass/7505665.876.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
25 | \n", "Shoes/Oxfords/Bass/7505665.260224.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
26 | \n", "Shoes/Oxfords/Bass/8026675.4418.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
27 | \n", "Shoes/Oxfords/Bass/7626932.1184.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
28 | \n", "Shoes/Oxfords/Bass/7563706.226011.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
29 | \n", "Shoes/Oxfords/Bass/7505665.401.jpg | \n", "Boots/Knee High/Tommy Hilfiger Kids/8027756.40... | \n", "1.00 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
1146 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/7563706.310217.jpg | \n", "0.94 | \n", "
1147 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7563706.310217.jpg | \n", "0.94 | \n", "
1148 | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.00 | \n", "
1149 | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1150 | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1151 | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1152 | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1153 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1154 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7587764.3.jpg | \n", "0.94 | \n", "
1155 | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.00 | \n", "
1156 | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.94 | \n", "
1157 | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.94 | \n", "
1158 | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.94 | \n", "
1159 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.94 | \n", "
1160 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7976075.86183.jpg | \n", "0.94 | \n", "
1161 | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "0.00 | \n", "
1162 | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "0.94 | \n", "
1163 | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "0.94 | \n", "
1164 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "0.94 | \n", "
1165 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/8046243.7492.jpg | \n", "0.94 | \n", "
1166 | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "0.00 | \n", "
1167 | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "0.94 | \n", "
1168 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "0.94 | \n", "
1169 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7505581.16583.jpg | \n", "0.94 | \n", "
1170 | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "0.00 | \n", "
1171 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "0.94 | \n", "
1172 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/8125248.43856.jpg | \n", "0.94 | \n", "
1173 | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "0.00 | \n", "
1174 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7670028.21224.jpg | \n", "0.94 | \n", "
1175 | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "Shoes/Oxfords/Bass/7505665.691.jpg | \n", "0.00 | \n", "
1176 rows × 3 columns
\n", "