{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir src" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!echo sagemaker-experiments==0.1.31 > ./src/requirements.txt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile ./src/xgboost_customer_churn.py\n", "import pandas as pd\n", "import argparse\n", "import json\n", "import os\n", "import pickle\n", "import random\n", "import tempfile\n", "import urllib.request\n", "\n", "import xgboost\n", "from smdebug import SaveConfig\n", "from smdebug.xgboost import Hook\n", "from smexperiments.tracker import Tracker\n", "\n", "\n", "def parse_args():\n", "\n", " parser = argparse.ArgumentParser()\n", "\n", " parser.add_argument(\"--max_depth\", type=int, default=5)\n", " parser.add_argument(\"--eta\", type=float, default=0.2)\n", " parser.add_argument(\"--gamma\", type=int, default=4)\n", " parser.add_argument(\"--min_child_weight\", type=int, default=6)\n", " parser.add_argument(\"--subsample\", type=float, default=0.8)\n", " parser.add_argument(\"--verbosity\", type=int, default=0)\n", " parser.add_argument(\"--objective\", type=str, default=\"binary:logistic\")\n", " parser.add_argument(\"--num_round\", type=int, default=50)\n", " parser.add_argument(\"--smdebug_path\", type=str, default=None)\n", " parser.add_argument(\"--smdebug_frequency\", type=int, default=1)\n", " parser.add_argument(\"--smdebug_collections\", type=str, default='metrics')\n", " parser.add_argument(\"--output_uri\", type=str, default=\"/opt/ml/output/tensors\",\n", " help=\"S3 URI of the bucket where tensor data will be stored.\")\n", "\n", " parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))\n", " parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))\n", " parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\n", " \n", " args = parser.parse_args()\n", "\n", " return args\n", "\n", "\n", "def create_smdebug_hook(out_dir, train_data=None, validation_data=None, frequency=1, collections=None,):\n", "\n", " save_config = SaveConfig(save_interval=frequency)\n", " hook = Hook(\n", " out_dir=out_dir,\n", " train_data=train_data,\n", " validation_data=validation_data,\n", " save_config=save_config,\n", " include_collections=collections,\n", " )\n", "\n", " return hook\n", "\n", "\n", "def main(tracker):\n", " \n", " args = parse_args()\n", "\n", " train, validation = args.train, args.validation\n", " parse_csv = \"?format=csv&label_column=0\"\n", " dtrain = xgboost.DMatrix(train+parse_csv)\n", " dval = xgboost.DMatrix(validation+parse_csv)\n", "\n", " watchlist = [(dtrain, \"train\"), (dval, \"validation\")]\n", "\n", " params = {\n", " \"max_depth\": args.max_depth,\n", " \"eta\": args.eta,\n", " \"gamma\": args.gamma,\n", " \"min_child_weight\": args.min_child_weight,\n", " \"subsample\": args.subsample,\n", " \"verbosity\": args.verbosity,\n", " \"objective\": args.objective}\n", "\n", " # The output_uri is a the URI for the s3 bucket where the metrics will be\n", " # saved.\n", " output_uri = (\n", " args.smdebug_path\n", " if args.smdebug_path is not None\n", " else args.output_uri\n", " )\n", "\n", " collections = (\n", " args.smdebug_collections.split(',')\n", " if args.smdebug_collections is not None\n", " else None\n", " )\n", "\n", " hook = create_smdebug_hook(\n", " out_dir=output_uri,\n", " 
frequency=args.smdebug_frequency,\n", " collections=collections,\n", " train_data=dtrain,\n", " validation_data=dval,\n", " )\n", "\n", " bst = xgboost.train(\n", " params=params,\n", " dtrain=dtrain,\n", " evals=watchlist,\n", " num_boost_round=args.num_round,\n", " callbacks=[hook])\n", " \n", " if not os.path.exists(args.model_dir):\n", " os.makedirs(args.model_dir)\n", "\n", " model_location = os.path.join(args.model_dir, 'xgboost-model')\n", " pickle.dump(bst, open(model_location, 'wb'))\n", " \n", " print(\"Performing predictions against test data.\")\n", " predictions_probs = bst.predict(dval)\n", " predictions = predictions_probs.round()\n", " \n", " print(\"Creating and logging plots to Studio\")\n", " val_files = [ os.path.join(validation, file) for file in os.listdir(validation) ]\n", " if len(val_files) == 0:\n", " raise ValueError(('There are no files in {}.\\n' +\n", " 'This usually indicates that the channel ({}) was incorrectly specified,\\n' +\n", " 'the data specification in S3 was incorrectly specified or the role specified\\n' +\n", " 'does not have permission to access the data.').format(val_files, \"validation\"))\n", " raw_data = [ pd.read_csv(file, header=None) for file in val_files ]\n", " df_val = pd.concat(raw_data)\n", " y_val = df_val.iloc[:, 0].to_numpy()\n", " \n", " tracker.log_precision_recall(y_val, predictions_probs, title=\"Precision-recall for predicting Churn\", output_artifact=True)\n", " tracker.log_roc_curve(y_val, predictions_probs, title=\"ROC Curve for predicting Churn\", output_artifact=True)\n", " tracker.log_confusion_matrix(y_val, predictions, title=\"Confusion matrix for predicting Churn\", output_artifact=True)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " # Instantiate SM Experiment Tracker\n", " tracker = Tracker.load()\n", " \n", " main(tracker)\n", "\n", "\n", "def model_fn(model_dir):\n", " \"\"\"Load a model. For XGBoost Framework, a default function to load a model is not provided.\n", " Users should provide customized model_fn() in script.\n", " Args:\n", " model_dir: a directory where model is saved.\n", " Returns:\n", " A XGBoost model.\n", " XGBoost model format type.\n", " \"\"\"\n", " model_files = (file for file in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, file)))\n", " model_file = next(model_files)\n", " try:\n", " booster = pickle.load(open(os.path.join(model_dir, model_file), 'rb'))\n", " format = 'pkl_format'\n", " except Exception as exp_pkl:\n", " try:\n", " booster = xgboost.Booster()\n", " booster.load_model(os.path.join(model_dir, model_file))\n", " format = 'xgb_format'\n", " except Exception as exp_xgb:\n", " raise ModelLoadInferenceError(\"Unable to load model: {} {}\".format(str(exp_pkl), str(exp_xgb)))\n", " booster.set_param('nthread', 1)\n", " return booster, format\n" ] } ], "metadata": { "kernelspec": { "display_name": "", "name": "" }, "language_info": { "name": "" } }, "nbformat": 4, "nbformat_minor": 4 }