{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20207a4e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "import pandas as pd\n",
    "from io import StringIO, BytesIO\n",
    "from sagemaker.pytorch import PyTorch\n",
    "from sagemaker.pytorch.model import PyTorchModel\n",
    "from sagemaker.serializers import CSVSerializer, IdentitySerializer\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "957e581f",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Train RGCN model using SageMaker\n",
    "\n",
    "Please make sure you download the data by running [`01-Prepare-Data.ipynb`](01-Prepare-Data.ipynb) notebook first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "31d2f961",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = \"sagemaker/ieee-fraud-detection-train\"\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "### settings shared by the training job and the inference endpoint;\n",
    "### defined once so that both always use the same container version\n",
    "PY_VERSION = \"py38\"\n",
    "FRAMEWORK_VERSION = \"1.11.0\"\n",
    "INSTANCE_TYPE = \"ml.m5.4xlarge\"\n",
    "\n",
    "### size and random seed of the test batch that is sent to the endpoint\n",
    "BATCH_SIZE = 1000\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "052bc551",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### upload training data to S3\n",
    "inputs = sagemaker_session.upload_data(path=\"./data/train.parquet\", bucket=bucket, key_prefix=prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e0eae030",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### create SageMaker's PyTorch estimator with custom training script\n",
    "estimator = PyTorch(\n",
    "    entry_point=\"smtrain.py\",\n",
    "    role=role,\n",
    "    py_version=PY_VERSION,\n",
    "    framework_version=FRAMEWORK_VERSION,\n",
    "    instance_count=1,\n",
    "    instance_type=INSTANCE_TYPE,\n",
    "    source_dir='fgnn',\n",
    "    volume_size=100,\n",
    "    hyperparameters={\n",
    "        'embedding_size': 64,\n",
    "        'n_layers': 2,\n",
    "        'n_epochs': 150,\n",
    "        'n_hidden': 16,\n",
    "        'dropout': 0.2,\n",
    "        'weight_decay': 5e-05,\n",
    "        'lr': 0.01,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53eb410b",
   "metadata": {
    "scrolled": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### fit SM estimator\n",
    "estimator.fit({\"training\": inputs})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "857ee78d",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Deploy trained RGCN model to SageMaker endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "24a61053",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "## create sm model from model data + source code\n",
    "## (same entry point, source dir and container version as the training job)\n",
    "model = PyTorchModel(model_data=estimator.model_data,\n",
    "                     role=role,\n",
    "                     entry_point='smtrain.py',\n",
    "                     source_dir='fgnn',\n",
    "                     py_version=PY_VERSION,\n",
    "                     framework_version=FRAMEWORK_VERSION,\n",
    "                     model_server_workers=2)\n",
    "\n",
    "## alternatively, use fitted estimator object to create sm model\n",
    "# model = estimator.create_model(model_server_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e336595c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------!"
     ]
    }
   ],
   "source": [
    "## deploy sm model to an endpoint that will accept payload in (serialized) parquet format\n",
    "predictor = model.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE,\n",
    "                         serializer=IdentitySerializer(content_type='application/x-parquet'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9c36eb1f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### alternatively, deploy sm model to an endpoint that will accept payload in CSV format\n",
    "# predictor_csv = model.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE,\n",
    "#                              serializer=IdentitySerializer(content_type='text/csv'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a64e398e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Invoke endpoint with test transactions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fffbe2ea",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### load test transactions\n",
    "df_test = pd.read_parquet('./data/test.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "119b8af5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### sample a batch of BATCH_SIZE transactions; the fixed seed makes the batch reproducible\n",
    "df_batch = df_test.sample(n=BATCH_SIZE, random_state=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1564b72c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 67.5 ms, sys: 201 µs, total: 67.7 ms\n",
      "Wall time: 7.98 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "### serialize parquet table with test transactions (label column removed)\n",
    "buffer = BytesIO()\n",
    "df_batch.drop(columns=['isFraud']).to_parquet(buffer)\n",
    "### invoke model endpoint with serialized parquet payload\n",
    "response = predictor.predict(buffer.getvalue())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8df8af0b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### invoke model endpoint with CSV payload\n",
    "### note that using CSV format may result in a prediction error because CSV serialization will lose column type information:\n",
    "### e.g., when all rows of a string/object column have only NaN values in a batch, this column will be deserialized as type float on the endpoint side,\n",
    "### and one-hot-encoding of this column will fail inside fraud_detector.py (line `self._cat_transformer.transform(test_transactions)`)\n",
    "#\n",
    "# response = predictor_csv.predict(df_batch.drop(columns=['isFraud']).to_csv(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f7db2e30",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8737550277724574"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### compute roc-auc score for the batch\n",
    "roc_auc_score(df_batch.isFraud, response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b56edc2",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Delete SageMaker endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5855599d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "### delete the SageMaker model first, while the endpoint can still be described\n",
    "predictor.delete_model()\n",
    "### delete the endpoint; by default this also deletes its endpoint configuration\n",
    "predictor.delete_endpoint()\n",
    "# predictor_csv.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p38",
   "language": "python",
   "name": "conda_pytorch_p38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}