{ "cells": [ { "cell_type": "markdown", "id": "b3779c45", "metadata": {}, "source": [ "### Main Steps:\n", "\n", "1. Train and host a Sagemaker model - sagemaker\n", "2. Import the sagemaker endpoint to AFD and set up the detector - AFD\n", "3. Test the detector - GEP/Batch Prediction - AFD\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "24537cae", "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "\n", "sess = sagemaker.Session()\n", "bucket = sess.default_bucket()\n", "s3_prefix = \"sagemaker/demo-afd-sagemaker-endpoint\"\n", "version_prefix = 'v1'\n", "\n", "# Define IAM role\n", "import boto3\n", "import re\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "code", "execution_count": 2, "id": "7293d855", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import io\n", "import os\n", "import sys\n", "import time\n", "import json\n", "from IPython.display import display\n", "from time import strftime, gmtime\n", "from sagemaker.inputs import TrainingInput\n", "from sagemaker.serializers import CSVSerializer" ] }, { "cell_type": "markdown", "id": "481ec66b", "metadata": {}, "source": [ "### Step 1: Train and Host a Sagemaker model\n", "\n", "Code Reference: https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_applying_machine_learning/xgboost_customer_churn/xgboost_customer_churn.ipynb" ] }, { "cell_type": "code", "execution_count": 3, "id": "06ee4b95", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "legit 18996\n", "fraud 1004\n", "Name: EVENT_LABEL, dtype: int64" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv(\"fraud_data_20K_sample.csv\")\n", "data['EVENT_LABEL'].value_counts()" ] }, { "cell_type": "code", "execution_count": 4, "id": "e0e81c2d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EVENT_LABELEVENT_TIMESTAMPip_addressemail_addressorder_amtprev_amt
0legit10/8/2019 20:4446.41.252.160fake_acostasusan@example.org153.7158.30
1legit5/23/2020 19:44152.58.247.12fake_christopheryoung@example.com2.5711.63
2legit4/24/2020 18:2612.252.206.222fake_jeffrey09@example.org30.9652.41
3legit4/22/2020 19:07170.81.164.240fake_ncastro@example.org63.8734.21
4legit12/31/2019 17:08165.182.68.217fake_charles99@example.org70.3666.58
\n", "
" ], "text/plain": [ " EVENT_LABEL EVENT_TIMESTAMP ip_address \\\n", "0 legit 10/8/2019 20:44 46.41.252.160 \n", "1 legit 5/23/2020 19:44 152.58.247.12 \n", "2 legit 4/24/2020 18:26 12.252.206.222 \n", "3 legit 4/22/2020 19:07 170.81.164.240 \n", "4 legit 12/31/2019 17:08 165.182.68.217 \n", "\n", " email_address order_amt prev_amt \n", "0 fake_acostasusan@example.org 153.71 58.30 \n", "1 fake_christopheryoung@example.com 2.57 11.63 \n", "2 fake_jeffrey09@example.org 30.96 52.41 \n", "3 fake_ncastro@example.org 63.87 34.21 \n", "4 fake_charles99@example.org 70.36 66.58 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "ac9292c3", "metadata": {}, "outputs": [], "source": [ "# prepare data for sagemaker model training\n", "model_data = pd.get_dummies(data[['order_amt', 'prev_amt', 'EVENT_LABEL']])\n", "model_data = pd.concat([model_data[\"EVENT_LABEL_fraud\"], model_data.drop([\"EVENT_LABEL_fraud\", \"EVENT_LABEL_legit\"], axis=1)], axis=1)" ] }, { "cell_type": "code", "execution_count": 6, "id": "17c0b9aa", "metadata": {}, "outputs": [], "source": [ "# split to train valid and test data\n", "train_data, validation_data, test_data = np.split(\n", " model_data.sample(frac=1, random_state=1729),\n", " [int(0.7 * len(model_data)), int(0.9 * len(model_data))],\n", ")\n", "train_data.to_csv(\"train.csv\", header=False, index=False)\n", "validation_data.to_csv(\"validation.csv\", header=False, index=False)" ] }, { "cell_type": "code", "execution_count": 7, "id": "5cda3f0d", "metadata": {}, "outputs": [], "source": [ "# upload to s3\n", "boto3.Session().resource(\"s3\").Bucket(bucket).Object(\n", " os.path.join(s3_prefix, \"train/train.csv\")\n", ").upload_file(\"train.csv\")\n", "boto3.Session().resource(\"s3\").Bucket(bucket).Object(\n", " os.path.join(s3_prefix, \"validation/validation.csv\")\n", ").upload_file(\"validation.csv\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "1f0f161f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# specify the locations of the XGBoost algorithm containers - \n", "container = sagemaker.image_uris.retrieve(\"xgboost\", sess.boto_region_name, \"1.5-1\")\n", "display(container)" ] }, { "cell_type": "code", "execution_count": 9, "id": "0f7a2238", "metadata": {}, "outputs": [], "source": [ "s3_input_train = TrainingInput(\n", " s3_data=\"s3://{}/{}/train\".format(bucket, s3_prefix), content_type=\"csv\"\n", ")\n", "s3_input_validation = TrainingInput(\n", " s3_data=\"s3://{}/{}/validation/\".format(bucket, s3_prefix), content_type=\"csv\"\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "887481bb", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-03-23 16:27:01 Starting - Starting the training job...ProfilerReport-1679588820: InProgress\n", "...\n", "2023-03-23 16:28:00 Starting - Preparing the instances for training.........\n", "2023-03-23 16:29:20 Downloading - Downloading input data...\n", "2023-03-23 16:30:00 Training - Downloading the training image...\n", "2023-03-23 16:30:26 Training - Training image download completed. Training in progress...\u001b[34m[2023-03-23 16:30:36.896 ip-10-0-252-124.us-west-2.compute.internal:7 INFO utils.py:28] RULE_JOB_STOP_SIGNAL_FILENAME: None\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:36.977 ip-10-0-252-124.us-west-2.compute.internal:7 INFO profiler_config_parser.py:111] User has disabled profiler.\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Imported framework sagemaker_xgboost_container.training\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Failed to parse hyperparameter objective value binary:logistic to Json.\u001b[0m\n", "\u001b[34mReturning the value itself\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] No GPUs detected (normal if no gpus installed)\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Running XGBoost Sagemaker in algorithm mode\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Determined 0 GPU(s) available on the instance.\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Determined delimiter of CSV input is ','\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Determined delimiter of CSV input is ','\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] files path: /opt/ml/input/data/train\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Determined delimiter of CSV input is ','\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] files path: /opt/ml/input/data/validation\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Determined delimiter of CSV input is ','\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Single node training.\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Train matrix has 14000 rows and 2 columns\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Validation matrix has 4000 rows\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.445 ip-10-0-252-124.us-west-2.compute.internal:7 INFO json_config.py:92] Creating hook from json_config at /opt/ml/input/config/debughookconfig.json.\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.445 ip-10-0-252-124.us-west-2.compute.internal:7 INFO hook.py:206] tensorboard_dir has not been set for the hook. SMDebug will not be exporting tensorboard summaries.\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.446 ip-10-0-252-124.us-west-2.compute.internal:7 INFO hook.py:259] Saving to /opt/ml/output/tensors\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.446 ip-10-0-252-124.us-west-2.compute.internal:7 INFO state_store.py:77] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.\u001b[0m\n", "\u001b[34m[2023-03-23:16:30:37:INFO] Debug hook created from config\u001b[0m\n", "\u001b[34m[0]#011train-logloss:0.54690#011validation-logloss:0.54770\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.471 ip-10-0-252-124.us-west-2.compute.internal:7 INFO hook.py:427] Monitoring the collections: metrics\u001b[0m\n", "\u001b[34m[2023-03-23 16:30:37.474 ip-10-0-252-124.us-west-2.compute.internal:7 INFO hook.py:491] Hook is writing from the hook with pid: 7\u001b[0m\n", "\u001b[34m[1]#011train-logloss:0.44941#011validation-logloss:0.45117\u001b[0m\n", "\u001b[34m[2]#011train-logloss:0.38077#011validation-logloss:0.38320\u001b[0m\n", "\u001b[34m[3]#011train-logloss:0.33148#011validation-logloss:0.33463\u001b[0m\n", "\u001b[34m[4]#011train-logloss:0.29524#011validation-logloss:0.29907\u001b[0m\n", "\u001b[34m[5]#011train-logloss:0.26829#011validation-logloss:0.27269\u001b[0m\n", "\u001b[34m[6]#011train-logloss:0.24807#011validation-logloss:0.25303\u001b[0m\n", "\u001b[34m[7]#011train-logloss:0.23308#011validation-logloss:0.23870\u001b[0m\n", "\u001b[34m[8]#011train-logloss:0.22192#011validation-logloss:0.22808\u001b[0m\n", "\u001b[34m[9]#011train-logloss:0.21362#011validation-logloss:0.22051\u001b[0m\n", "\u001b[34m[10]#011train-logloss:0.20752#011validation-logloss:0.21513\u001b[0m\n", "\u001b[34m[11]#011train-logloss:0.20303#011validation-logloss:0.21114\u001b[0m\n", "\u001b[34m[12]#011train-logloss:0.19992#011validation-logloss:0.20827\u001b[0m\n", "\u001b[34m[13]#011train-logloss:0.19732#011validation-logloss:0.20650\u001b[0m\n", "\u001b[34m[14]#011train-logloss:0.19554#011validation-logloss:0.20526\u001b[0m\n", "\u001b[34m[15]#011train-logloss:0.19377#011validation-logloss:0.20488\u001b[0m\n", "\u001b[34m[16]#011train-logloss:0.19267#011validation-logloss:0.20420\u001b[0m\n", "\u001b[34m[17]#011train-logloss:0.19186#011validation-logloss:0.20393\u001b[0m\n", "\u001b[34m[18]#011train-logloss:0.19143#011validation-logloss:0.20379\u001b[0m\n", "\u001b[34m[19]#011train-logloss:0.19076#011validation-logloss:0.20357\u001b[0m\n", "\u001b[34m[20]#011train-logloss:0.19033#011validation-logloss:0.20346\u001b[0m\n", "\u001b[34m[21]#011train-logloss:0.18995#011validation-logloss:0.20350\u001b[0m\n", "\u001b[34m[22]#011train-logloss:0.18977#011validation-logloss:0.20361\u001b[0m\n", "\u001b[34m[23]#011train-logloss:0.18925#011validation-logloss:0.20404\u001b[0m\n", "\u001b[34m[24]#011train-logloss:0.18906#011validation-logloss:0.20441\u001b[0m\n", "\u001b[34m[25]#011train-logloss:0.18904#011validation-logloss:0.20442\u001b[0m\n", "\u001b[34m[26]#011train-logloss:0.18874#011validation-logloss:0.20472\u001b[0m\n", "\u001b[34m[27]#011train-logloss:0.18865#011validation-logloss:0.20484\u001b[0m\n", "\u001b[34m[28]#011train-logloss:0.18838#011validation-logloss:0.20501\u001b[0m\n", "\u001b[34m[29]#011train-logloss:0.18830#011validation-logloss:0.20511\u001b[0m\n", "\u001b[34m[30]#011train-logloss:0.18797#011validation-logloss:0.20530\u001b[0m\n", "\u001b[34m[31]#011train-logloss:0.18783#011validation-logloss:0.20555\u001b[0m\n", "\u001b[34m[32]#011train-logloss:0.18768#011validation-logloss:0.20561\u001b[0m\n", "\u001b[34m[33]#011train-logloss:0.18749#011validation-logloss:0.20567\u001b[0m\n", "\u001b[34m[34]#011train-logloss:0.18722#011validation-logloss:0.20578\u001b[0m\n", "\u001b[34m[35]#011train-logloss:0.18707#011validation-logloss:0.20598\u001b[0m\n", "\u001b[34m[36]#011train-logloss:0.18707#011validation-logloss:0.20598\u001b[0m\n", "\u001b[34m[37]#011train-logloss:0.18707#011validation-logloss:0.20601\u001b[0m\n", "\u001b[34m[38]#011train-logloss:0.18707#011validation-logloss:0.20601\u001b[0m\n", "\u001b[34m[39]#011train-logloss:0.18707#011validation-logloss:0.20600\u001b[0m\n", "\u001b[34m[40]#011train-logloss:0.18707#011validation-logloss:0.20601\u001b[0m\n", "\u001b[34m[41]#011train-logloss:0.18694#011validation-logloss:0.20598\u001b[0m\n", "\u001b[34m[42]#011train-logloss:0.18685#011validation-logloss:0.20609\u001b[0m\n", "\u001b[34m[43]#011train-logloss:0.18676#011validation-logloss:0.20610\u001b[0m\n", "\u001b[34m[44]#011train-logloss:0.18661#011validation-logloss:0.20611\u001b[0m\n", "\u001b[34m[45]#011train-logloss:0.18654#011validation-logloss:0.20615\u001b[0m\n", "\u001b[34m[46]#011train-logloss:0.18655#011validation-logloss:0.20613\u001b[0m\n", "\u001b[34m[47]#011train-logloss:0.18638#011validation-logloss:0.20648\u001b[0m\n", "\u001b[34m[48]#011train-logloss:0.18625#011validation-logloss:0.20667\u001b[0m\n", "\u001b[34m[49]#011train-logloss:0.18625#011validation-logloss:0.20667\u001b[0m\n", "\u001b[34m[50]#011train-logloss:0.18602#011validation-logloss:0.20652\u001b[0m\n", "\u001b[34m[51]#011train-logloss:0.18596#011validation-logloss:0.20665\u001b[0m\n", "\u001b[34m[52]#011train-logloss:0.18580#011validation-logloss:0.20665\u001b[0m\n", "\u001b[34m[53]#011train-logloss:0.18580#011validation-logloss:0.20663\u001b[0m\n", "\u001b[34m[54]#011train-logloss:0.18580#011validation-logloss:0.20663\u001b[0m\n", "\u001b[34m[55]#011train-logloss:0.18543#011validation-logloss:0.20673\u001b[0m\n", "\u001b[34m[56]#011train-logloss:0.18530#011validation-logloss:0.20685\u001b[0m\n", "\u001b[34m[57]#011train-logloss:0.18518#011validation-logloss:0.20693\u001b[0m\n", "\u001b[34m[58]#011train-logloss:0.18502#011validation-logloss:0.20696\u001b[0m\n", "\u001b[34m[59]#011train-logloss:0.18483#011validation-logloss:0.20715\u001b[0m\n", "\u001b[34m[60]#011train-logloss:0.18471#011validation-logloss:0.20702\u001b[0m\n", "\u001b[34m[61]#011train-logloss:0.18471#011validation-logloss:0.20703\u001b[0m\n", "\u001b[34m[62]#011train-logloss:0.18470#011validation-logloss:0.20708\u001b[0m\n", "\u001b[34m[63]#011train-logloss:0.18458#011validation-logloss:0.20708\u001b[0m\n", "\u001b[34m[64]#011train-logloss:0.18448#011validation-logloss:0.20685\u001b[0m\n", "\u001b[34m[65]#011train-logloss:0.18441#011validation-logloss:0.20684\u001b[0m\n", "\u001b[34m[66]#011train-logloss:0.18441#011validation-logloss:0.20688\u001b[0m\n", "\u001b[34m[67]#011train-logloss:0.18433#011validation-logloss:0.20700\u001b[0m\n", "\u001b[34m[68]#011train-logloss:0.18433#011validation-logloss:0.20701\u001b[0m\n", "\u001b[34m[69]#011train-logloss:0.18425#011validation-logloss:0.20708\u001b[0m\n", "\u001b[34m[70]#011train-logloss:0.18425#011validation-logloss:0.20705\u001b[0m\n", "\u001b[34m[71]#011train-logloss:0.18419#011validation-logloss:0.20719\u001b[0m\n", "\u001b[34m[72]#011train-logloss:0.18412#011validation-logloss:0.20721\u001b[0m\n", "\u001b[34m[73]#011train-logloss:0.18412#011validation-logloss:0.20717\u001b[0m\n", "\u001b[34m[74]#011train-logloss:0.18402#011validation-logloss:0.20728\u001b[0m\n", "\u001b[34m[75]#011train-logloss:0.18402#011validation-logloss:0.20727\u001b[0m\n", "\u001b[34m[76]#011train-logloss:0.18402#011validation-logloss:0.20728\u001b[0m\n", "\u001b[34m[77]#011train-logloss:0.18391#011validation-logloss:0.20730\u001b[0m\n", "\u001b[34m[78]#011train-logloss:0.18392#011validation-logloss:0.20728\u001b[0m\n", "\u001b[34m[79]#011train-logloss:0.18384#011validation-logloss:0.20729\u001b[0m\n", "\u001b[34m[80]#011train-logloss:0.18384#011validation-logloss:0.20728\u001b[0m\n", "\u001b[34m[81]#011train-logloss:0.18384#011validation-logloss:0.20729\u001b[0m\n", "\u001b[34m[82]#011train-logloss:0.18384#011validation-logloss:0.20730\u001b[0m\n", "\u001b[34m[83]#011train-logloss:0.18384#011validation-logloss:0.20732\u001b[0m\n", "\u001b[34m[84]#011train-logloss:0.18379#011validation-logloss:0.20731\u001b[0m\n", "\u001b[34m[85]#011train-logloss:0.18379#011validation-logloss:0.20729\u001b[0m\n", "\u001b[34m[86]#011train-logloss:0.18379#011validation-logloss:0.20731\u001b[0m\n", "\u001b[34m[87]#011train-logloss:0.18379#011validation-logloss:0.20731\u001b[0m\n", "\u001b[34m[88]#011train-logloss:0.18379#011validation-logloss:0.20734\u001b[0m\n", "\u001b[34m[89]#011train-logloss:0.18372#011validation-logloss:0.20759\u001b[0m\n", "\u001b[34m[90]#011train-logloss:0.18364#011validation-logloss:0.20743\u001b[0m\n", "\u001b[34m[91]#011train-logloss:0.18358#011validation-logloss:0.20747\u001b[0m\n", "\u001b[34m[92]#011train-logloss:0.18356#011validation-logloss:0.20751\u001b[0m\n", "\u001b[34m[93]#011train-logloss:0.18356#011validation-logloss:0.20748\u001b[0m\n", "\u001b[34m[94]#011train-logloss:0.18346#011validation-logloss:0.20749\u001b[0m\n", "\u001b[34m[95]#011train-logloss:0.18323#011validation-logloss:0.20743\u001b[0m\n", "\u001b[34m[96]#011train-logloss:0.18323#011validation-logloss:0.20742\u001b[0m\n", "\u001b[34m[97]#011train-logloss:0.18323#011validation-logloss:0.20742\u001b[0m\n", "\u001b[34m[98]#011train-logloss:0.18323#011validation-logloss:0.20742\u001b[0m\n", "\u001b[34m[99]#011train-logloss:0.18323#011validation-logloss:0.20746\u001b[0m\n", "\n", "2023-03-23 16:31:01 Uploading - Uploading generated training model\n", "2023-03-23 16:31:01 Completed - Training job completed\n", "Training seconds: 107\n", "Billable seconds: 107\n" ] } ], "source": [ "sess = sagemaker.Session()\n", "\n", "xgb = sagemaker.estimator.Estimator(\n", " container,\n", " role,\n", " instance_count=1,\n", " instance_type=\"ml.m4.xlarge\",\n", " output_path=\"s3://{}/{}/output\".format(bucket, s3_prefix),\n", " sagemaker_session=sess,\n", ")\n", "xgb.set_hyperparameters(\n", " max_depth=5,\n", " eta=0.2,\n", " gamma=4,\n", " min_child_weight=6,\n", " subsample=0.8,\n", " verbosity=0,\n", " objective=\"binary:logistic\",\n", " num_round=100,\n", ")\n", "\n", "xgb.fit({\"train\": s3_input_train, \"validation\": s3_input_validation})" ] }, { "cell_type": "code", "execution_count": 11, "id": "b9ef4409", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------!" ] } ], "source": [ "# deploy sagemaker endpoint\n", "xgb_predictor = xgb.deploy(\n", " initial_instance_count=1, instance_type=\"ml.m4.xlarge\", serializer=CSVSerializer(),\n", " endpoint_name = f\"sagemaker-xgb-endpoint-{version_prefix}\"\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "80ee7ec2", "metadata": {}, "outputs": [], "source": [ "def predict(data, rows=500):\n", " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", " predictions = \"\"\n", " for array in split_array:\n", " predictions = \"\".join([predictions, xgb_predictor.predict(array).decode(\"utf-8\")])\n", "\n", " return predictions.split(\"\\n\")[:-1]\n", "\n", "\n", "predictions = predict(test_data.to_numpy()[:, 1:])" ] }, { "cell_type": "code", "execution_count": 13, "id": "52731445", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2000 [0.0716714 0.03765393 0.02415792 ... 0.05634578 0.06239426 0.03940216]\n" ] } ], "source": [ "predictions = np.array([float(num) for num in predictions])\n", "print(len(predictions), predictions)" ] }, { "cell_type": "markdown", "id": "41b2d1dd", "metadata": {}, "source": [ "### Step 2: Import the SageMaker model to AFD and set up the detector" ] }, { "cell_type": "code", "execution_count": 14, "id": "be654be3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/boto3/compat.py:88: PythonDeprecationWarning: Boto3 will no longer support Python 3.6 starting May 30, 2022. To continue receiving service updates, bug fixes, and security updates please upgrade to Python 3.7 or later. More information can be found here: https://aws.amazon.com/blogs/developer/python-support-policy-updates-for-aws-sdks-and-tools/\n", " warnings.warn(warning, PythonDeprecationWarning)\n" ] } ], "source": [ "fraudDetector = boto3.client('frauddetector')" ] }, { "cell_type": "code", "execution_count": 15, "id": "54455e52", "metadata": {}, "outputs": [], "source": [ "### create afd variables, entity and event type\n", "try:\n", " resp = fraudDetector.get_variables(name = 'order_amt')\n", "except:\n", " resp = fraudDetector.create_variable(name = 'order_amt', dataType = 'FLOAT', dataSource ='EVENT', defaultValue = '0.0')\n", "\n", "try:\n", " resp = fraudDetector.get_variables(name = 'prev_amt')\n", "except: \n", " resp = fraudDetector.create_variable(name = 'prev_amt', dataType = 'FLOAT', dataSource ='EVENT', defaultValue = '0.0')\n", "\n", "response = fraudDetector.put_entity_type(name = f'sagemaker-xgb-entity-{version_prefix}')\n", "\n", "response = fraudDetector.put_event_type (\n", " name = f'sagemaker-xgb-transaction-{version_prefix}',\n", " eventVariables = ['order_amt', 'prev_amt'],\n", " entityTypes = [f'sagemaker-xgb-entity-{version_prefix}'])" ] }, { "cell_type": "code", "execution_count": 16, "id": "bc438d29", "metadata": {}, "outputs": [], "source": [ "### create external model score variable\n", "resp = fraudDetector.create_variable(name = f'sagemaker_xgb_score_{version_prefix}', dataType = 'FLOAT', dataSource ='EXTERNAL_MODEL_SCORE', defaultValue = '0.0')\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "ae636488", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ResponseMetadata': {'RequestId': '45e3a103-ab54-4ce9-9656-8f34665d5bba',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'date': 'Thu, 23 Mar 2023 16:35:21 GMT',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '2',\n", " 'connection': 'keep-alive',\n", " 'x-amzn-requestid': '45e3a103-ab54-4ce9-9656-8f34665d5bba'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "### put external model\n", "# https://docs.aws.amazon.com/frauddetector/latest/ug/import-an-amazon-sagemaker-model.html\n", "fraudDetector.put_external_model(\n", " modelSource = 'SAGEMAKER',\n", " modelEndpoint = f'sagemaker-xgb-endpoint-{version_prefix}',\n", " invokeModelEndpointRoleArn = role, #'your_SagemakerExecutionRole_arn',\n", " inputConfiguration = {\n", " 'useEventVariables' : True,\n", " 'eventTypeName' : f'sagemaker-xgb-transaction-{version_prefix}',\n", " 'format' : 'TEXT_CSV',\n", " 'csvInputTemplate' : '{{order_amt}}, {{prev_amt}}' # add afd enrichment, how the config works\n", " },\n", " outputConfiguration = {\n", " 'format' : 'TEXT_CSV',\n", " 'csvIndexToVariableMap' : {\n", " '0' : f'sagemaker_xgb_score_{version_prefix}'\n", " }\n", " },\n", " modelEndpointStatus = 'ASSOCIATED'\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "id": "57949ae1", "metadata": {}, "outputs": [], "source": [ "### create a detector\n", "DETECTOR_NAME = f\"afd-with-sagemaker-model-{version_prefix}\"\n", "response = fraudDetector.put_detector(\n", " detectorId = DETECTOR_NAME, \n", " eventTypeName = f'sagemaker-xgb-transaction-{version_prefix}' )" ] }, { "cell_type": "code", "execution_count": 19, "id": "fa063f79", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "creating outcome variable: fraud \n", "creating outcome variable: investigate \n", "creating outcome variable: approve \n", "creating rule: rules_0_sagemaker_xgb_score_v1: IF $sagemaker_xgb_score_v1 > 0.9 THEN fraud\n", "creating rule: rules_1_sagemaker_xgb_score_v1: IF $sagemaker_xgb_score_v1 > 0.5 THEN investigate\n", "creating rule: rules_2_sagemaker_xgb_score_v1: IF $sagemaker_xgb_score_v1 <= 0.5 THEN approve\n" ] } ], "source": [ "### Create rules\n", "\n", "def create_outcomes(outcomes):\n", " \"\"\" \n", " Create Fraud Detector Outcomes \n", " \"\"\" \n", " for outcome in outcomes:\n", " print(\"creating outcome variable: {0} \".format(outcome))\n", " response = fraudDetector.put_outcome(name = outcome, description = outcome)\n", "\n", "def create_rules(score_cuts, outcomes, MODEL_SCORE_NAME, DETECTOR_NAME):\n", " \"\"\"\n", " Creating rules \n", " \n", " Arguments:\n", " score_cuts - list of score cuts to create rules\n", " outcomes - list of outcomes associated with the rules\n", " \n", " Returns:\n", " a rule list to used when create detector\n", " \"\"\"\n", " \n", " if len(score_cuts)+1 != len(outcomes):\n", " logging.error('Your socre cuts and outcomes are not matched.')\n", " \n", " rule_list = []\n", " for i in range(len(outcomes)):\n", " # rule expression\n", " if i < (len(outcomes)-1):\n", " rule = \"${0} > {1}\".format(MODEL_SCORE_NAME,score_cuts[i])\n", " else:\n", " rule = \"${0} <= {1}\".format(MODEL_SCORE_NAME,score_cuts[i-1])\n", " \n", " # append to rule_list (used when create detector)\n", " rule_id = \"rules_{0}_{1}\".format(i, MODEL_SCORE_NAME)\n", " \n", " rule_list.append({\n", " \"ruleId\": rule_id, \n", " \"ruleVersion\" : '1',\n", " \"detectorId\" : DETECTOR_NAME\n", " })\n", " \n", " # create rules\n", " print(\"creating rule: {0}: IF {1} THEN {2}\".format(rule_id, rule, outcomes[i]))\n", " try:\n", " response = fraudDetector.create_rule(\n", " ruleId = rule_id,\n", " detectorId = DETECTOR_NAME,\n", " expression = rule,\n", " language = 'DETECTORPL',\n", " outcomes = [outcomes[i]]\n", " )\n", " except:\n", " print(\"this rule already exists in this detector\")\n", " \n", " return rule_list\n", "\n", "score_cuts = [0.9, 0.5] \n", "outcomes = ['fraud', 'investigate', 'approve'] \n", "create_outcomes(outcomes)\n", "rule_list = create_rules(score_cuts, outcomes, f'sagemaker_xgb_score_{version_prefix}', DETECTOR_NAME)" ] }, { "cell_type": "code", "execution_count": 20, "id": "d954d86a", "metadata": {}, "outputs": [], "source": [ "# -- create detector version --\n", "response =fraudDetector.create_detector_version(\n", " detectorId = DETECTOR_NAME ,\n", " rules = rule_list,\n", " externalModelEndpoints = [f'sagemaker-xgb-endpoint-{version_prefix}'],\n", " ruleExecutionMode = 'FIRST_MATCHED'\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "id": "01f34935", "metadata": {}, "outputs": [], "source": [ "response = fraudDetector.update_detector_version_status(\n", " detectorId = DETECTOR_NAME,\n", " detectorVersionId = '1',\n", " status = 'ACTIVE'\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "9644444a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EVENT_LABEL_fraudorder_amtprev_amt
14590156.00135.04
13935041.5889.56
6843021.40404.08
17103035.17135.47
2286091.72122.84
\n", "
" ], "text/plain": [ " EVENT_LABEL_fraud order_amt prev_amt\n", "1459 0 156.00 135.04\n", "13935 0 41.58 89.56\n", "6843 0 21.40 404.08\n", "17103 0 35.17 135.47\n", "2286 0 91.72 122.84" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data.head()" ] }, { "cell_type": "markdown", "id": "3e60450b", "metadata": {}, "source": [ "### Step 3: Test the detector using boto3 SDK" ] }, { "cell_type": "code", "execution_count": 23, "id": "a60b2e46", "metadata": {}, "outputs": [], "source": [ "pred = fraudDetector.get_event_prediction(\n", " detectorId = f\"afd-with-sagemaker-model-{version_prefix}\",\n", " detectorVersionId = '1',\n", " eventId = '1459',\n", " eventTypeName = f'sagemaker-xgb-transaction-{version_prefix}',\n", " eventTimestamp = '2019-10-05T22:50:48Z',\n", " entities = [{\n", " 'entityType': f'sagemaker-xgb-entity-{version_prefix}', \n", " 'entityId':\"UNKNOWN\"\n", " }],\n", " eventVariables = {\n", " 'order_amt': '156',\n", " 'prev_amt':'135.04'\n", " }) " ] }, { "cell_type": "code", "execution_count": 24, "id": "41c313ad", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'modelScores': [],\n", " 'ruleResults': [{'ruleId': 'rules_2_sagemaker_xgb_score_v1',\n", " 'outcomes': ['approve']}],\n", " 'externalModelOutputs': [{'externalModel': {'modelEndpoint': 'sagemaker-xgb-endpoint-v1',\n", " 'modelSource': 'SAGEMAKER'},\n", " 'outputs': {'sagemaker_xgb_score_v1': '0.07167139649391174\\n'}}],\n", " 'ResponseMetadata': {'RequestId': '34c934b6-0030-4159-a7ff-06b7b3e5fa4d',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'date': 'Thu, 23 Mar 2023 16:35:28 GMT',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '277',\n", " 'connection': 'keep-alive',\n", " 'x-amzn-requestid': '34c934b6-0030-4159-a7ff-06b7b3e5fa4d'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred" ] }, { "cell_type": "code", "execution_count": null, "id": "c5e4d5ad", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "68a6c60a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }