{ "cells": [ { "cell_type": "markdown", "id": "9e9ea171-e1c5-4877-a71d-4adeac97b89e", "metadata": {}, "source": [ "# Build and Deploy Many Models Leveraging Cancer Gene Expression Data With SageMaker Pipelines and SageMaker Multi-Model Endpoints\n", "\n", "When building machine learning models that leverage genomic data, a key problem is how to let users select which features are used when querying a model. To address this, data scientists sometimes build multiple models, each handling a specific sub-problem within the dataset. In the context of cancer survival analysis, a common approach is to analyze gene signatures and to predict patient survival from gene expression signatures. See [here](https://www.nature.com/articles/s41598-021-84787-5) for an example of such an approach applied to a number of different cancer types. See also [this](https://pubmed.ncbi.nlm.nih.gov/31296308/) review, which discusses different techniques for performing survival analysis.\n", "\n", "If an application requires publishing models based on many hundreds or thousands of gene signatures, managing and deploying all of these models can become unwieldy. In this blog post, we show how you can leverage SageMaker Pipelines and SageMaker Multi-Model Endpoints to build and deploy many such models.\n", "\n", "To give a specific example, we use the sample cancer RNA expression dataset discussed in the paper [Non-Small Cell Lung Cancer Radiogenomics Map Identifies Relationships between Molecular and Imaging Phenotypes with Prognostic Implications](https://pubmed.ncbi.nlm.nih.gov/28727543/). To simplify the use case, we focus on 21 genes from co-expressed groups that this paper found to be clinically significant in NSCLC (see Table 2 of that paper). These groups of genes, which the authors term metagenes, are annotated in different cellular pathways. 
For example, the first group of genes (LRIG1, HPGD and GDF15) relates to the EGFR signaling pathway, while VIM, LMO2 and EGR2 are all involved in cell hypoxia/inflammation. In the dataset, each cancer patient (row) has gene expression values (columns). In addition, each patient is annotated with a Survival Status (1 for deceased; 0 for alive at the time the dataset was collected). We followed the preprocessing described in [this blog post](https://aws.amazon.com/blogs/industries/building-scalable-machine-learning-pipelines-for-multimodal-health-data-on-aws/). As described more fully in that blog post, the final dataset contains 119 patients. If you run the pipeline described in that blog post, you can obtain the full gene expression profile from the raw FASTQ files; alternatively, you can download the full gene expression data from [GEO](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE103584).\n", "\n", "The architecture for this approach is as follows:\n", "\n", "![](images/Architecture.jpeg)\n", "\n", "As can be seen in the diagram, we start with data located in S3. We then create a [SageMaker Pipeline](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-pipelines/index.html). SageMaker Pipelines lets data scientists wrap the different components of their workload into a pipeline, so that each step of the analysis is automatically kicked off after the previous job finishes. See the associated code repository ?? for the specific syntax for creating a SageMaker Pipeline.\n", "The pipeline consists of:\n", "\n", "* A SageMaker Processing job for preprocessing the data\n", "\n", "* A SageMaker Training job for training the model
\n", "\n", "* A SageMaker Processing job for evaluating and registering the model in SageMaker Model Registry\n", "\n", "* A separate SageMaker Processing job for deploying the model to a SageMaker Multi-Model Endpoint (MME)\n", "\n", "Before we begin, let's verify the SageMaker SDK version." ] }, { "cell_type": "code", "execution_count": 3, "id": "cdf9dfab-fc7f-438c-9212-49d299274320", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.109.0'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sagemaker\n", "sagemaker.__version__" ] }, { "cell_type": "code", "execution_count": 2, "id": "445a345a-cf12-47e9-9ffe-9490309ad20a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Keyring is skipped due to an exception: 'keyring.backends'\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.\n", "pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.\n", "docker-compose 1.29.2 requires PyYAML<6,>=3.10, but you have pyyaml 6.0 which is incompatible.\n", "aiobotocore 2.4.1 requires botocore<1.27.60,>=1.27.59, but you have botocore 1.29.24 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install --upgrade --quiet sagemaker==2.109.0" ] }, { "cell_type": "markdown", "id": "f97fd49d-5a55-4246-a7b5-edaa09b357e5", "metadata": {}, "source": [ "* Please restart the kernel after the SageMaker SDK update. You can do this from the menu: Kernel->Restart Kernel.\n", "* After restarting, run the cell below and make sure that the SageMaker SDK version is 2.94.0 or later." ] }, { "cell_type": "code", "execution_count": 3, "id": "22d4d10f-725b-4b3a-b52c-900ed7a2c8bb", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'2.109.0'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sagemaker\n", "sagemaker.__version__" ] }, { "cell_type": "markdown", "id": "a2e8125b-bd7d-4be7-877b-58c8e5445849", "metadata": {}, "source": [ "Then, let's import the rest of the packages we need."
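, "\n",
"\n",
"Before doing so, you can optionally turn the version check above into an assertion. Comparing version strings directly is error-prone: as plain strings, `'2.109.0' >= '2.94.0'` is `False`, because the characters are compared one at a time. The minimal sketch below compares the numeric parts instead. `version_tuple` and `meets_minimum` are illustrative helpers written for this example, not part of the SageMaker SDK; in the notebook you would pass `sagemaker.__version__` to `meets_minimum`.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def version_tuple(version, width=3):\n",
"    \"\"\"Turn '2.109.0' into (2, 109, 0); missing parts are padded with zeros.\"\"\"\n",
"    parts = [int(p) for p in re.findall(r\"\\d+\", version)][:width]\n",
"    return tuple(parts + [0] * (width - len(parts)))\n",
"\n",
"\n",
"def meets_minimum(installed, minimum=\"2.94.0\"):\n",
"    \"\"\"Return True if `installed` is at least `minimum`, comparing numerically.\"\"\"\n",
"    return version_tuple(installed) >= version_tuple(minimum)\n",
"\n",
"\n",
"print(\"2.109.0\" >= \"2.94.0\")  # False: string comparison, character by character\n",
"print(meets_minimum(\"2.109.0\"))  # True: (2, 109, 0) >= (2, 94, 0)\n",
"\n",
"# In the notebook:\n",
"# assert meets_minimum(sagemaker.__version__), sagemaker.__version__\n",
"```"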
] }, { "cell_type": "code", "execution_count": 4, "id": "1925c3c0-3193-4536-8e8d-f51e2d80e2cd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import time\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "from sklearn.model_selection import train_test_split\n", "from sagemaker import get_execution_role\n", "\n", "from sagemaker.multidatamodel import MultiDataModel\n", "\n", "from sagemaker.pytorch import PyTorch\n", "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "from sagemaker.workflow.pipeline_context import PipelineSession\n", "from sagemaker.workflow.fail_step import FailStep\n", "from sagemaker.workflow.functions import Join, JsonGet\n", "from sagemaker.model_metrics import MetricsSource, ModelMetrics\n", "from sagemaker.workflow.model_step import ModelStep\n", "from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo\n", "from sagemaker.workflow.condition_step import ConditionStep\n", "\n", "from sagemaker.predictor import Predictor\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "ce9b88dd-56db-4c23-baa3-6ae83fc8aa90", "metadata": {}, "source": [ "### Read the data\n", "\n", "Data for this project is available in the `data` folder. Let's read it, do some exploratory analysis, and apply basic preprocessing." ] }, { "cell_type": "code", "execution_count": 5, "id": "0cf392c9-76d1-4db1-ad4a-2c85083c6c48", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ " Case_ID LRIG1 HPGD GDF15 CDH2 POSTN VCAN \\\n", "0 R01-005 25.438505 2.434890 4.83280 0.000000 14.80001 13.816834 \n", "1 R01-012 23.192370 8.997890 43.86000 2.576560 57.67060 22.893300 \n", "2 R01-013 77.553001 190.529190 7.76279 0.000000 11.79248 26.137328 \n", "3 R01-014 63.164522 21.455770 16.58980 0.000000 20.94288 7.123850 \n", "4 R01-017 82.543949 66.726003 15.27520 1.817656 17.53631 5.260631 \n", ".. ... ... ... ... ... ... ... 
\n", "114 R01-156 23.141790 8.542742 0.00000 0.684995 119.58400 25.027800 \n", "115 R01-157 6.565445 0.829140 4.94241 0.000000 13.29370 26.306770 \n", "116 R01-158 26.088220 21.237700 4.81202 0.000000 16.14340 10.719810 \n", "117 R01-159 20.813240 3.629507 0.00000 0.000000 19.68492 24.987579 \n", "118 R01-160 18.136591 2.894350 0.00000 0.000000 35.57201 29.780610 \n", "\n", " PDGFRA VCAM1 CD44 ... CD37 VIM LMO2 \\\n", "0 3.791663 3.903575 44.62658 ... 1.590920 70.972300 0.00000 \n", "1 6.526380 8.768820 32.54150 ... 6.934750 158.976000 3.14894 \n", "2 6.983622 3.591960 152.81587 ... 5.048800 115.247000 2.26644 \n", "3 18.156200 2.076650 57.05985 ... 4.651934 101.178000 1.32881 \n", "4 7.652980 2.341880 227.60731 ... 2.631933 206.286100 1.39513 \n", ".. ... ... ... ... ... ... ... \n", "114 3.129100 2.322870 143.23510 ... 4.717760 58.163200 2.62162 \n", "115 3.463640 4.649160 187.18417 ... 1.333220 49.906300 1.87628 \n", "116 4.399728 1.937700 126.17160 ... 1.286370 57.849300 1.97280 \n", "117 2.346060 9.746220 371.70458 ... 22.198750 102.767831 3.39946 \n", "118 6.185470 3.071740 108.12872 ... 1.865994 74.261921 0.00000 \n", "\n", " EGR2 BGN COL4A1 COL5A1 COL5A2 SurvivalStatus \\\n", "0 0.000000 6.22329 3.268430 9.394910 11.402356 1 \n", "1 2.828400 28.85350 22.221980 12.066000 22.500950 0 \n", "2 3.317050 5.77029 5.490720 7.749760 3.532330 0 \n", "3 1.462162 21.01470 6.839590 2.572220 36.834306 0 \n", "4 1.568070 21.19790 16.872651 4.875322 9.060370 1 \n", ".. ... ... ... ... ... ... 
\n", "114 0.976505 27.42970 29.316800 35.525100 31.737300 1 \n", "115 1.833738 15.31520 17.690250 21.050000 21.977010 0 \n", "116 0.659567 7.73836 29.922400 7.174080 9.495350 0 \n", "117 2.585660 8.66216 85.256870 6.234190 15.785505 0 \n", "118 2.334370 25.06410 23.004220 10.038100 26.491740 0 \n", "\n", " PathologicalMstage \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", ".. ... \n", "114 0 \n", "115 0 \n", "116 0 \n", "117 0 \n", "118 0 \n", "\n", "[119 rows x 24 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "genomic_data_with_label = pd.read_csv(\"data/Genomic-data-119patients.csv\")\n", "genomic_data_with_label" ] }, { "cell_type": "markdown", "id": "17944acb-d627-47bb-9fb6-5014b0f260e9", "metadata": {}, "source": [ "You can see that for each patient (`Case_ID`) we have all gene expression levels, as well as SurvivalStatus. Note that this dataset also contains a pathological label for the patient. We will not be leveraging this column, but you can read more about the histopathology data associated with this dataset [here](https://aws.amazon.com/blogs/industries/building-scalable-machine-learning-pipelines-for-multimodal-health-data-on-aws/). 
We therefore drop `Case_ID` (a patient identifier) and `PathologicalMstage`." ] }, { "cell_type": "code", "execution_count": 6, "id": "e31012ee-788c-409e-a786-971cf7c3c060", "metadata": { "tags": [] }, "outputs": [], "source": [ "genomic_data_with_label.drop(columns=[\"Case_ID\", \"PathologicalMstage\"], inplace=True)" ] }, { "cell_type": "markdown", "id": "fd09bd55-e225-44ca-b1e5-6eede57aceed", "metadata": {}, "source": [ "Next, we check the class balance." ] }, { "cell_type": "code", "execution_count": 7, "id": "94ed0ed8-4c4c-449f-8503-e26d85275f4b", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "genomic_data_with_label.SurvivalStatus.value_counts().plot.bar()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "86932ff9-a393-4443-a014-7c73133b1de9", "metadata": {}, "source": [ "While class `0` makes up a larger proportion of cases, there are enough class `1` cases to proceed without rebalancing the data.\n", "\n", "Next, we rescale the data column by column, mapping each gene's expression values to the [0, 1] range with `MinMaxScaler`." ] }, { "cell_type": "code", "execution_count": 8, "id": "1bafc906-ba74-4cc4-8c48-82a4fba0b925", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ " LRIG1 HPGD GDF15 CDH2 POSTN VCAN PDGFRA \\\n", "0 0.186033 0.011379 0.058076 0.000000 0.044392 0.043852 0.031756 \n", "1 0.168126 0.042051 0.527072 0.014627 0.188186 0.080102 0.058723 \n", "2 0.601524 0.890415 0.093287 0.000000 0.034304 0.093057 0.063232 \n", "3 0.486809 0.100271 0.199362 0.000000 0.064996 0.017122 0.173404 \n", "4 0.641315 0.311836 0.183564 0.010318 0.053570 0.009681 0.069832 \n", ".. ... ... ... ... ... ... ... 
\n", "114 0.167722 0.039923 0.000000 0.003889 0.395853 0.088626 0.025222 \n", "115 0.035565 0.003875 0.059394 0.000000 0.039339 0.093734 0.028521 \n", "116 0.191213 0.099252 0.057827 0.000000 0.048898 0.031484 0.037752 \n", "117 0.149158 0.016962 0.000000 0.000000 0.060777 0.088466 0.017501 \n", "118 0.127818 0.013526 0.000000 0.000000 0.114064 0.107608 0.055361 \n", "\n", " VCAM1 CD44 CD48 ... LYL1 SPI1 CD37 \\\n", "0 0.077312 0.096727 0.000000 ... 0.000000 0.041144 0.071667 \n", "1 0.173671 0.068072 0.208182 ... 0.111333 0.264893 0.312394 \n", "2 0.071140 0.353251 0.058441 ... 0.061315 0.425206 0.227436 \n", "3 0.041129 0.126207 0.089822 ... 0.064753 0.177943 0.209558 \n", "4 0.046382 0.530587 0.075721 ... 0.000000 0.109945 0.118562 \n", ".. ... ... ... ... ... ... ... \n", "114 0.046006 0.330534 0.023121 ... 0.196905 0.061541 0.212524 \n", "115 0.092079 0.434741 0.036051 ... 0.088405 0.078723 0.060058 \n", "116 0.038377 0.290076 0.026981 ... 0.208866 0.090146 0.057948 \n", "117 0.193029 0.872251 0.424727 ... 0.343994 0.162622 1.000000 \n", "118 0.060837 0.247295 0.024145 ... 0.150692 0.158571 0.084059 \n", "\n", " VIM LMO2 EGR2 BGN COL4A1 COL5A1 COL5A2 \n", "0 0.161186 0.000000 0.000000 0.036592 0.000000 0.037805 0.048075 \n", "1 0.384688 0.132416 0.254566 0.260631 0.098762 0.050754 0.099377 \n", "2 0.273630 0.095306 0.298546 0.032107 0.011580 0.029830 0.011697 \n", "3 0.237899 0.055878 0.131600 0.183027 0.018608 0.004730 0.165631 \n", "4 0.504841 0.058666 0.141132 0.184841 0.070888 0.015895 0.037249 \n", ".. ... 
... ... ... ... ... ... \n", "114 0.128655 0.110241 0.087889 0.246535 0.135732 0.164479 0.142071 \n", "115 0.107685 0.078899 0.165043 0.126602 0.075148 0.094307 0.096955 \n", "116 0.127858 0.082958 0.059363 0.051591 0.138887 0.027039 0.039260 \n", "117 0.241937 0.142950 0.232718 0.060737 0.427221 0.022482 0.068335 \n", "118 0.169541 0.000000 0.210101 0.223116 0.102838 0.040923 0.117824 \n", "\n", "[119 rows x 21 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "genomic_data = genomic_data_with_label.drop(columns=[\"SurvivalStatus\"])\n", "labels = genomic_data_with_label[\"SurvivalStatus\"]\n", "\n", "scaler = MinMaxScaler()\n", "genomic_data[genomic_data.columns] = scaler.fit_transform(genomic_data.to_numpy())\n", "genomic_data" ] }, { "cell_type": "markdown", "id": "95de8282-8aa1-4739-87b1-44677da0294a", "metadata": {}, "source": [ "### Split the data into train and validation sets\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "e80edd4d-12a8-43ef-96f0-a544ecd80854", "metadata": { "tags": [] }, "outputs": [], "source": [ "X_train, X_val, y_train, y_val = train_test_split(genomic_data, labels, test_size=0.2)\n" ] }, { "cell_type": "markdown", "id": "8dc3e8d6-8b74-4e11-ae97-d629483273bb", "metadata": {}, "source": [ "After splitting the data, let's visually verify that the class distributions are similar in the `train` and `validation` data."
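, "\n",
"\n",
"Bar plots give a quick visual check, but with only 119 patients a numeric comparison is easier to read. The minimal sketch below is self-contained and runs on synthetic data. `make_toy_data` and its random values are hypothetical stand-ins for the gene expression table, and the `toy_` prefix keeps the sketch from overwriting this notebook's variables. It also illustrates two optional refinements that this notebook does not use. The first is passing `stratify` to `train_test_split`, so that both splits keep nearly the same class proportions. The second is fitting `MinMaxScaler` on the training split only, so that validation rows do not influence the scaling (the cell above fits the scaler on all 119 patients before splitting).\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"\n",
"def make_toy_data(n_patients=119, n_genes=21, seed=0):\n",
"    \"\"\"Hypothetical stand-in for the expression table: random non-negative values.\"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"    features = pd.DataFrame(\n",
"        rng.gamma(shape=2.0, scale=10.0, size=(n_patients, n_genes)),\n",
"        columns=[f\"GENE_{i}\" for i in range(n_genes)],\n",
"    )\n",
"    status = pd.Series(rng.integers(0, 2, size=n_patients), name=\"SurvivalStatus\")\n",
"    return features, status\n",
"\n",
"\n",
"toy_features, toy_labels = make_toy_data()\n",
"\n",
"# stratify=toy_labels keeps the class mix (almost) identical in both splits.\n",
"toy_X_train, toy_X_val, toy_y_train, toy_y_val = train_test_split(\n",
"    toy_features, toy_labels, test_size=0.2, stratify=toy_labels, random_state=42\n",
")\n",
"\n",
"# Class proportions per split, side by side.\n",
"toy_proportions = pd.DataFrame(\n",
"    {\n",
"        \"train\": toy_y_train.value_counts(normalize=True),\n",
"        \"validation\": toy_y_val.value_counts(normalize=True),\n",
"    }\n",
")\n",
"print(toy_proportions)\n",
"\n",
"# Fit the scaler on the training split only, then reuse it for validation.\n",
"toy_scaler = MinMaxScaler().fit(toy_X_train)\n",
"toy_X_train_scaled = pd.DataFrame(\n",
"    toy_scaler.transform(toy_X_train), columns=toy_X_train.columns, index=toy_X_train.index\n",
")\n",
"toy_X_val_scaled = pd.DataFrame(\n",
"    toy_scaler.transform(toy_X_val), columns=toy_X_val.columns, index=toy_X_val.index\n",
")\n",
"# Validation values may fall slightly outside [0, 1]; that is expected.\n",
"```\n",
"\n",
"With about 24 validation patients (20% of 119, rounded up), each validation patient accounts for roughly 4 percentage points of the class proportions. This is why stratifying is worth considering on small cohorts."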
] }, { "cell_type": "code", "execution_count": 10, "id": "3b30c33b-6466-4c9e-84d9-569a18a0b3b7", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train.value_counts().plot.bar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "id": "dcf9096b-008a-40c8-82f2-c701314d1c70", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_val.value_counts().plot.bar()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ac1ed187-56ad-4d6e-9db0-c8de76cb9302", "metadata": {}, "source": [ "### Save data" ] }, { "cell_type": "code", "execution_count": 12, "id": "076a2760-b7be-47dd-9eee-1cf9bdc3c6df", "metadata": { "tags": [] }, "outputs": [], "source": [ "X_train.insert(0, \"SurvivalStatus\", y_train)\n", "X_train.to_csv(\"./data/train_data.csv\", index = False, header=True)" ] }, { "cell_type": "code", "execution_count": 13, "id": "1c05e717-5f08-4e84-946a-16b17b886d17", "metadata": { "tags": [] }, "outputs": [], "source": [ "X_val.insert(0, \"SurvivalStatus\", y_val)\n", "X_val.to_csv(\"./data/validation_data.csv\", index = False, header=True)" ] }, { "cell_type": "markdown", "id": "302812ca-daa7-411b-844a-1202037df084", "metadata": {}, "source": [ "### Prepare for SageMaker Training" ] }, { "cell_type": "code", "execution_count": 14, "id": "96ea6495-a011-4432-81f4-83ac82b7b002", "metadata": { "tags": [] }, "outputs": [], "source": [ "role = get_execution_role()\n", "session = sagemaker.Session()\n", "bucket = session.default_bucket()\n", "\n", "s3_prefix = \"genome-survival-classification/data\"" ] }, { "cell_type": "markdown", "id": "8ec97077-df78-4012-8f5a-ed7fe83c0340", "metadata": {}, "source": [ "### Upload to S3" ] }, { "cell_type": "code", "execution_count": 15, "id": "a6f51cc3-2646-4812-bf62-374217c01a2b", "metadata": { 
"tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train data : [s3://sagemaker-us-east-1-421720712360/genome-survival-classification/data/train/train_data.csv]\n", "Val data : [s3://sagemaker-us-east-1-421720712360/genome-survival-classification/data/validation/validation_data.csv]\n" ] } ], "source": [ "input_train = session.upload_data(\n", " path=\"./data/train_data.csv\", bucket=bucket, key_prefix=\"{}/train\".format(s3_prefix)\n", " )\n", "\n", "input_val = session.upload_data(\n", " path=\"./data/validation_data.csv\", bucket=bucket, key_prefix=\"{}/validation\".format(s3_prefix)\n", " )\n", "\n", "print(\"Train data : [{}]\".format(input_train))\n", "print(\"Val data : [{}]\".format(input_val))" ] }, { "cell_type": "markdown", "id": "b380194b-f835-46ed-8ba4-7b7045a81bbd", "metadata": {}, "source": [ "## Create the Multi-Model Endpoint\n", "\n", "Next, we create the multi-model endpoint (a one-time configuration) that serves the models produced by the SageMaker Pipeline. Note that, for now, we deploy an MME that points to an empty collection of models; the pipeline populates this collection later. We also specify a custom `inference.py` script, which will allow users to choose which model to invoke.\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "8da29b31-7086-4143-97e5-ce388ec0634b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-----!"
] } ], "source": [ "FRAMEWORK_VERSION = \"1.12.0\"\n", "\n", "mme_model_data_location = \"s3://{}/{}/mme-models-location\".format(bucket, s3_prefix)\n", "\n", "endpoint_name = \"Genome-Survival-Prediction-MultiModel-Endpoint-{}\".format(time.strftime(\"%H-%M-%S\"))\n", "\n", "model = PyTorchModel(model_data=\"./model/model.tar.gz\", \n", " source_dir='src', \n", " entry_point='inference.py', \n", " role=role, \n", " framework_version=FRAMEWORK_VERSION,\n", " py_version = \"py38\",\n", " sagemaker_session=session)\n", "\n", " \n", "mme = MultiDataModel(\n", " name = \"Genome-Survival-Prediction-MME-Model-{}\".format(time.strftime(\"%H-%M-%S\")),\n", " model_data_prefix = mme_model_data_location,\n", " model = model, # passing our model\n", " sagemaker_session=session,\n", ")\n", "\n", "mme_predictor = mme.deploy(\n", " initial_instance_count=1, \n", " instance_type=\"ml.m5.large\", \n", " endpoint_name=endpoint_name\n", ")" ] }, { "cell_type": "markdown", "id": "d6f202e8-66b1-4af0-814c-32ff00d2f210", "metadata": {}, "source": [ "#### Check for current models (on the first run, this list should be empty)" ] }, { "cell_type": "code", "execution_count": 18, "id": "267ed5dc-3397-49a1-bf74-1e55a8222d9a", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(mme.list_models())" ] }, { "cell_type": "markdown", "id": "d0bcec3b-94ca-4027-b8d2-c91c53907ad1", "metadata": {}, "source": [ "## Creating the pipeline\n", "\n", "Once the pipeline has run, the trained models are stored in S3, and the Multi-Model Endpoint can dynamically retrieve the requested model for each user request. The user specifies not only the input data, but also which specific model to use. 
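\n",
"\n",
"To make the request side concrete, the minimal sketch below builds the keyword arguments for `invoke_endpoint` on the boto3 `sagemaker-runtime` client. Its `TargetModel` parameter names the model artifact, relative to the MME S3 prefix, that should serve the request. The `build_mme_request` helper, the endpoint name, the `EXAMPLE-GROUP/model.tar.gz` artifact key, and the JSON-list payload are all hypothetical; the payload has to match whatever `src/inference.py` expects. The SageMaker Python SDK exposes the same routing through the `target_model` argument of `Predictor.predict`.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"\n",
"def build_mme_request(endpoint_name, target_model, features):\n",
"    \"\"\"Keyword arguments for sagemaker-runtime invoke_endpoint (payload format assumed).\"\"\"\n",
"    return {\n",
"        \"EndpointName\": endpoint_name,\n",
"        # Artifact key relative to the MME model_data_prefix.\n",
"        \"TargetModel\": target_model,\n",
"        \"ContentType\": \"application/json\",\n",
"        \"Body\": json.dumps([float(x) for x in features]),\n",
"    }\n",
"\n",
"\n",
"request = build_mme_request(\n",
"    endpoint_name=\"Genome-Survival-Prediction-MultiModel-Endpoint-EXAMPLE\",  # hypothetical\n",
"    target_model=\"EXAMPLE-GROUP/model.tar.gz\",  # hypothetical artifact key\n",
"    features=[0.186033, 0.011379, 0.058076],  # scaled LRIG1, HPGD, GDF15 of patient 0\n",
")\n",
"print(request[\"TargetModel\"], request[\"Body\"])\n",
"\n",
"# With AWS credentials and a deployed endpoint, the call itself would be:\n",
"# import boto3\n",
"# runtime = boto3.client(\"sagemaker-runtime\")\n",
"# response = runtime.invoke_endpoint(**request)\n",
"# print(response[\"Body\"].read().decode(\"utf-8\"))\n",
"```\n",
"\n",
"Keeping the request construction in a small function makes it testable without an AWS account; only the commented-out lines contact the endpoint. 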
\n", "\n", "Thinking back to the gene expression data, the following diagram gives an overview of the modeling process:\n", "\n", "![](images/image_2.jpg)\n", "\n", "In this diagram, we start with the original gene expression data (red indicates higher expression; blue, lower expression) and split it into N separate subsets of gene expression data. Model 1, for example, is built on genes 1, 2 and 3; Model 2 on genes 4, 5 and 6; and so on. We then train multiple models, each using one subset of the gene expression data to predict survival. Note that each execution of the SageMaker Pipeline builds one model based on one gene signature.\n", "\n", "As mentioned in the introduction, we use a small dataset of just 21 genes found to be significant in predicting survival in lung cancer. However, you could do a similar analysis with other groups of genes, such as those in the [KEGG pathway database](https://www.genome.jp/kegg/pathway.html) or the [Molecular Signatures Database](http://www.gsea-msigdb.org/gsea/msigdb/index.jsp).\n", "\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "721dcf50-3e73-42dc-9625-29bfd945737d", "metadata": { "tags": [] }, "outputs": [], "source": [ "pipeline_session = PipelineSession()\n", "\n", "from sagemaker.workflow.parameters import (\n", " ParameterInteger,\n", " ParameterString,\n", " ParameterFloat,\n", ")\n", "\n", "input_train_data = ParameterString(\n", " name=\"InputTrainData\",\n", " default_value=input_train,\n", ")\n", "\n", "input_validation_data = ParameterString(\n", " name=\"InputValidationData\",\n", " default_value=input_val,\n", ")\n", "\n", "genome_group = ParameterString(\n", " name=\"genomeGroup\",\n", " default_value=\"ALL\",\n", ")\n", "\n", "training_instance_type = ParameterString(\n", " name=\"TrainingInstanceType\", \n", " default_value=\"ml.m5.large\"\n", ")\n", "\n", "mme_model_location = ParameterString(\n", " name=\"MMEModelsLocation\",\n", " 
default_value=mme_model_data_location,\n", ")\n", "\n", "from sagemaker.workflow.steps import CacheConfig\n", "\n", "cache_config = CacheConfig(enable_caching=True, expire_after=\"PT1H\")\n" ] }, { "cell_type": "markdown", "id": "626e9955-9b59-49bc-933c-78957ca4922f", "metadata": {}, "source": [ "#### Training Step" ] }, { "cell_type": "code", "execution_count": 20, "id": "bf4b182f-67ff-46fe-b8d4-04bf7cd0d7ee", "metadata": { "tags": [] }, "outputs": [], "source": [ "pytorch_estimator = PyTorch(\n", " source_dir=\"src\", \n", " entry_point=\"train.py\",\n", " framework_version = \"1.12.0\",\n", " py_version = \"py38\",\n", " instance_type= training_instance_type,\n", " instance_count=1,\n", " role = role,\n", " hyperparameters = {\n", " \"genome-group\" : genome_group\n", " },\n", " sagemaker_session = pipeline_session\n", ")\n", "\n", "#pytorch_estimator.fit({\"train_data\" : input_train, \"val_data\": input_val})" ] }, { "cell_type": "code", "execution_count": 21, "id": "e188b0fd-3804-47d2-842b-acac0f3e0cb9", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/sagemaker/workflow/steps.py:444: UserWarning: Profiling is enabled on the provided estimator. The default profiler rule includes a timestamp which will change each time the pipeline is upserted, causing cache misses. 
If profiling is not needed, set disable_profiler to True on the estimator.\n", " warnings.warn(msg)\n" ] } ], "source": [ "from sagemaker.inputs import TrainingInput\n", "from sagemaker.workflow.steps import TrainingStep\n", "\n", "step_train = TrainingStep(\n", " name=\"Genome-Survival-Prediction-Training\",\n", " estimator=pytorch_estimator,\n", " inputs={\n", " \"train_data\": TrainingInput(\n", " s3_data=input_train_data,\n", " content_type=\"text/csv\",\n", " ),\n", " \"val_data\": TrainingInput(\n", " s3_data=input_validation_data,\n", " content_type=\"text/csv\",\n", " )\n", " },\n", " cache_config=cache_config\n", ")" ] }, { "cell_type": "markdown", "id": "b96166bd-5187-428b-9eec-33372edcca65", "metadata": {}, "source": [ "#### Model evaluation Step" ] }, { "cell_type": "code", "execution_count": 22, "id": "3c5ebc8a-6135-49ff-b170-74693991c089", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.sklearn.processing import SKLearnProcessor\n", "from sagemaker.workflow.properties import PropertyFile\n", "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", "from sagemaker.workflow.steps import ProcessingStep\n", "\n", "framework_version = \"0.23-1\"\n", "\n", "sklearn_processor = SKLearnProcessor(\n", " framework_version=framework_version,\n", " instance_type=\"ml.m5.large\",\n", " instance_count=1,\n", " base_job_name=\"Genome-Survival-Prediction-Eval\",\n", " role=role,\n", " env = {\n", " \"genomeGroup\" : genome_group\n", " },\n", " sagemaker_session = pipeline_session\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "id": "0ec11859-a2ff-4b0c-97a3-151d66f2764d", "metadata": { "tags": [] }, "outputs": [], "source": [ "evaluation_report = PropertyFile(\n", " name=\"EvaluationReport\", output_name=\"evaluation\", path=\"evaluation.json\"\n", ")\n", "\n", "step_eval = ProcessingStep(\n", " name=\"Genome-Survival-Prediction-Eval\",\n", " processor=sklearn_processor,\n", " inputs=[\n", " ProcessingInput(\n", " 
source=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n", " destination=\"/opt/ml/processing/model\",\n", " ),\n", " ProcessingInput(\n", " source=input_validation_data,\n", " destination=\"/opt/ml/processing/test\",\n", " ),\n", " ProcessingInput(\n", " source=\"./src\",\n", " destination=\"/opt/ml/processing/code\",\n", " )\n", " ],\n", " outputs=[\n", " ProcessingOutput(output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\")\n", " ],\n", " code=\"src/evaluation.py\",\n", " property_files=[evaluation_report],\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "361f41f5-27e4-41d5-8f47-3ce07187c0f1", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Stop the execution with an error if the model does not meet the accuracy threshold\n", "step_fail = FailStep(\n", " name=\"Genome-Survival-Prediction-Fail\",\n", " error_message=\"Execution failed because the objective metric threshold was not met\",\n", ")" ] }, { "cell_type": "markdown", "id": "35bece6c-4327-412d-b3b9-02de4407ebe8", "metadata": {}, "source": [ "#### Define a Register Model Step to Create a Model Package\n" ] }, { "cell_type": "code", "execution_count": 25, "id": "41555a87-62dc-4662-8f6d-4f1e527ddd5b", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/sagemaker/workflow/pipeline_context.py:236: UserWarning: Running within a PipelineSession, there will be No Wait, No Logs, and No Job being started.\n", " UserWarning,\n" ] } ], "source": [ "model_metrics = ModelMetrics(\n", " model_statistics=MetricsSource(\n", " s3_uri=\"{}/evaluation.json\".format(\n", " step_eval.arguments[\"ProcessingOutputConfig\"][\"Outputs\"][0][\"S3Output\"][\"S3Uri\"]\n", " ),\n", " content_type=\"application/json\",\n", " )\n", ")\n", "\n", "model = PyTorchModel(\n", " model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,\n", " role=role,\n", " entry_point=\"inference.py\",\n", " source_dir = \"src\",\n", " framework_version = \"1.12.0\",\n", " py_version = \"py38\",\n", "
sagemaker_session=pipeline_session\n", ")\n", "\n", "# Register the trained model as a model package in the SageMaker Model Registry\n", "register_model_step_args = model.register(\n", " content_types=[\"application/json\"],\n", " response_types=[\"application/json\"],\n", " inference_instances=[\"ml.t2.medium\", \"ml.m5.xlarge\"],\n", " transform_instances=[\"ml.m5.xlarge\"],\n", " model_package_group_name='Genome-Survival-Prediction-Model-Package-Group',\n", " approval_status = \"Approved\"\n", ")\n", "\n", "step_model_registration = ModelStep(\n", " name=\"Genome-Survival-Prediction-Model-Registration\",\n", " step_args=register_model_step_args,\n", ")\n", "\n" ] }, { "cell_type": "markdown", "id": "82182449-a53b-40b7-8843-97ab1f643194", "metadata": {}, "source": [ "#### Define MME Deployment Step\n" ] }, { "cell_type": "code", "execution_count": 42, "id": "8b823816-e9c1-4fde-a9aa-858c5613831c", "metadata": { "tags": [] }, "outputs": [], "source": [ "sklearn_processor_for_mme_deployment = SKLearnProcessor(\n", " framework_version=framework_version,\n", " instance_type=\"ml.m5.xlarge\",\n", " instance_count=1,\n", " base_job_name=\"Genome-Survival-Prediction-Deployment\",\n", " role=role,\n", " env = {\n", " # steps[0] of the ModelStep repacks the model; steps[1] registers it and exposes the model package ARN\n", " \"modelPackageArn\" : step_model_registration.steps[1].properties.ModelPackageArn,\n", " \"mmeModelLocation\" : mme_model_location,\n", " \"genomeGroup\" : genome_group,\n", " \"AWS_DEFAULT_REGION\": session.boto_region_name\n", " }\n", ")\n", "\n", "step_mme_deployment = ProcessingStep(\n", " name=\"Genome-Survival-Prediction-MME-Deployment\",\n", " processor=sklearn_processor_for_mme_deployment,\n", " inputs=[],\n", " outputs=[\n", " ProcessingOutput(output_name=\"mme_model_location\", source=\"/opt/ml/processing/model/mme\")\n", " ],\n", " code=\"src/mme_deployment.py\"\n", ")" ] }, { "cell_type": "markdown", "id": "47577d1d-c3d8-40b2-b3b6-bde34ca59878", "metadata": {}, "source": [ "### Condition Step" ] }, { "cell_type": "code",
"execution_count": 43, "id": "5d854028-1df9-4bae-a72b-52fa94c5e62c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Fail if test accuracy is at or below 0.4; otherwise register the model and deploy it to the MME\n", "cond_lte = ConditionLessThanOrEqualTo(\n", " left=JsonGet(\n", " step_name=step_eval.name,\n", " property_file=evaluation_report,\n", " json_path=\"metrics.test_accuracy.value\",\n", " ),\n", " right=0.4\n", ")\n", "\n", "step_cond = ConditionStep(\n", " name=\"Genome-Survival-Prediction-Condition\",\n", " conditions=[cond_lte],\n", " if_steps=[step_fail],\n", " else_steps=[step_model_registration, step_mme_deployment],\n", ")" ] }, { "cell_type": "markdown", "id": "f0dbe662-e0c3-4b01-add3-c2f840a8e0de", "metadata": {}, "source": [ "### Create the pipeline using all the steps defined above" ] }, { "cell_type": "code", "execution_count": 44, "id": "fca12fef-dd6c-4de6-8ba7-b51fa0148d5a", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.workflow.pipeline import Pipeline\n", "\n", "pipeline_name = \"Genome-Survival-Prediction-Pipeline\"\n", "pipeline = Pipeline(\n", " name=pipeline_name,\n", " parameters=[\n", " input_train_data,\n", " input_validation_data,\n", " training_instance_type,\n", " genome_group,\n", " mme_model_location\n", " ],\n", " # Registration and MME deployment run inside the else branch of step_cond\n", " steps=[step_train, step_eval, step_cond]\n", ")" ] }, { "cell_type": "code", "execution_count": 45, "id": "c25313f1-f76b-4638-9e76-65142ecda773", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The input argument instance_type of function (sagemaker.image_uris.retrieve) is a pipeline variable (), which is not allowed. 
The default_value of this Parameter object will be used to override it.\n" ] }, { "data": { "text/plain": [ "{'Version': '2020-12-01',\n", " 'Metadata': {},\n", " 'Parameters': [{'Name': 'InputTrainData',\n", " 'Type': 'String',\n", " 'DefaultValue': 's3://sagemaker-us-east-1-421720712360/genome-survival-classification/data/train/train_data.csv'},\n", " {'Name': 'InputValidationData',\n", " 'Type': 'String',\n", " 'DefaultValue': 's3://sagemaker-us-east-1-421720712360/genome-survival-classification/data/validation/validation_data.csv'},\n", " {'Name': 'TrainingInstanceType',\n", " 'Type': 'String',\n", " 'DefaultValue': 'ml.m5.large'},\n", " {'Name': 'genomeGroup', 'Type': 'String', 'DefaultValue': 'ALL'},\n", " {'Name': 'MMEModelsLocation',\n", " 'Type': 'String',\n", " 'DefaultValue': 's3://sagemaker-us-east-1-421720712360/genome-survival-classification/data/mme-models-location'}],\n", " 'PipelineExperimentConfig': {'ExperimentName': {'Get': 'Execution.PipelineName'},\n", " 'TrialName': {'Get': 'Execution.PipelineExecutionId'}},\n", " 'Steps': [{'Name': 'Genome-Survival-Prediction-Training',\n", " 'Type': 'Training',\n", " 'Arguments': {'AlgorithmSpecification': {'TrainingInputMode': 'File',\n", " 'TrainingImage': '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.12.0-cpu-py38',\n", " 'EnableSageMakerMetricsTimeSeries': True},\n", " 'OutputDataConfig': {'S3OutputPath': 's3://sagemaker-us-east-1-421720712360/'},\n", " 'StoppingCondition': {'MaxRuntimeInSeconds': 86400},\n", " 'ResourceConfig': {'VolumeSizeInGB': 30,\n", " 'InstanceCount': 1,\n", " 'InstanceType': {'Get': 'Parameters.TrainingInstanceType'}},\n", " 'RoleArn': 'arn:aws:iam::421720712360:role/mod-ad69b4757be44dd4-SageMakerExecutionRole-1HU619OTJ7IXI',\n", " 'InputDataConfig': [{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix',\n", " 'S3Uri': {'Get': 'Parameters.InputTrainData'},\n", " 'S3DataDistributionType': 'FullyReplicated'}},\n", " 'ContentType': 'text/csv',\n", " 
'ChannelName': 'train_data'},\n", " {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix',\n", " 'S3Uri': {'Get': 'Parameters.InputValidationData'},\n", " 'S3DataDistributionType': 'FullyReplicated'}},\n", " 'ContentType': 'text/csv',\n", " 'ChannelName': 'val_data'}],\n", " 'HyperParameters': {'genome-group': {'Get': 'Parameters.genomeGroup'},\n", " 'sagemaker_submit_directory': '\"s3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-Training-f26f9dd2efcc5b30943ede8fa8fc1b51/source/sourcedir.tar.gz\"',\n", " 'sagemaker_program': '\"train.py\"',\n", " 'sagemaker_container_log_level': '20',\n", " 'sagemaker_region': '\"us-east-1\"'},\n", " 'DebugHookConfig': {'S3OutputPath': 's3://sagemaker-us-east-1-421720712360/',\n", " 'CollectionConfigurations': []},\n", " 'ProfilerRuleConfigurations': [{'RuleConfigurationName': 'ProfilerReport-1671204623',\n", " 'RuleEvaluatorImage': '503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest',\n", " 'RuleParameters': {'rule_to_invoke': 'ProfilerReport'}}],\n", " 'ProfilerConfig': {'S3OutputPath': 's3://sagemaker-us-east-1-421720712360/'}},\n", " 'CacheConfig': {'Enabled': True, 'ExpireAfter': 'PT1H'}},\n", " {'Name': 'Genome-Survival-Prediction-Eval',\n", " 'Type': 'Processing',\n", " 'Arguments': {'ProcessingResources': {'ClusterConfig': {'InstanceType': 'ml.m5.large',\n", " 'InstanceCount': 1,\n", " 'VolumeSizeInGB': 30}},\n", " 'AppSpecification': {'ImageUri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3',\n", " 'ContainerEntrypoint': ['python3',\n", " '/opt/ml/processing/input/code/evaluation.py']},\n", " 'RoleArn': 'arn:aws:iam::421720712360:role/mod-ad69b4757be44dd4-SageMakerExecutionRole-1HU619OTJ7IXI',\n", " 'ProcessingInputs': [{'InputName': 'input-1',\n", " 'AppManaged': False,\n", " 'S3Input': {'S3Uri': {'Get': 'Steps.Genome-Survival-Prediction-Training.ModelArtifacts.S3ModelArtifacts'},\n", " 'LocalPath': '/opt/ml/processing/model',\n", " 
'S3DataType': 'S3Prefix',\n", " 'S3InputMode': 'File',\n", " 'S3DataDistributionType': 'FullyReplicated',\n", " 'S3CompressionType': 'None'}},\n", " {'InputName': 'input-2',\n", " 'AppManaged': False,\n", " 'S3Input': {'S3Uri': {'Get': 'Parameters.InputValidationData'},\n", " 'LocalPath': '/opt/ml/processing/test',\n", " 'S3DataType': 'S3Prefix',\n", " 'S3InputMode': 'File',\n", " 'S3DataDistributionType': 'FullyReplicated',\n", " 'S3CompressionType': 'None'}},\n", " {'InputName': 'input-3',\n", " 'AppManaged': False,\n", " 'S3Input': {'S3Uri': 's3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-Eval-6aa4ea5b1214955610bed069e7f7bf56/input/input-3',\n", " 'LocalPath': '/opt/ml/processing/code',\n", " 'S3DataType': 'S3Prefix',\n", " 'S3InputMode': 'File',\n", " 'S3DataDistributionType': 'FullyReplicated',\n", " 'S3CompressionType': 'None'}},\n", " {'InputName': 'code',\n", " 'AppManaged': False,\n", " 'S3Input': {'S3Uri': 's3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-Eval-6aa4ea5b1214955610bed069e7f7bf56/input/code/evaluation.py',\n", " 'LocalPath': '/opt/ml/processing/input/code',\n", " 'S3DataType': 'S3Prefix',\n", " 'S3InputMode': 'File',\n", " 'S3DataDistributionType': 'FullyReplicated',\n", " 'S3CompressionType': 'None'}}],\n", " 'ProcessingOutputConfig': {'Outputs': [{'OutputName': 'evaluation',\n", " 'AppManaged': False,\n", " 'S3Output': {'S3Uri': 's3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-Eval-6aa4ea5b1214955610bed069e7f7bf56/output/evaluation',\n", " 'LocalPath': '/opt/ml/processing/evaluation',\n", " 'S3UploadMode': 'EndOfJob'}}]},\n", " 'Environment': {'genomeGroup': {'Get': 'Parameters.genomeGroup'}}},\n", " 'PropertyFiles': [{'PropertyFileName': 'EvaluationReport',\n", " 'OutputName': 'evaluation',\n", " 'FilePath': 'evaluation.json'}]},\n", " {'Name': 'Genome-Survival-Prediction-Condition',\n", " 'Type': 'Condition',\n", " 'Arguments': {'Conditions': [{'Type': 'LessThanOrEqualTo',\n", " 
'LeftValue': {'Std:JsonGet': {'PropertyFile': {'Get': 'Steps.Genome-Survival-Prediction-Eval.PropertyFiles.EvaluationReport'},\n", " 'Path': 'metrics.test_accuracy.value'}},\n", " 'RightValue': 0.4}],\n", " 'IfSteps': [{'Name': 'Genome-Survival-Prediction-Fail',\n", " 'Type': 'Fail',\n", " 'Arguments': {'ErrorMessage': 'Execution failed due to Obective Metric was not met'}}],\n", " 'ElseSteps': [{'Name': 'Genome-Survival-Prediction-Model-Registration-RepackModel-0',\n", " 'Type': 'Training',\n", " 'Arguments': {'AlgorithmSpecification': {'TrainingInputMode': 'File',\n", " 'TrainingImage': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3'},\n", " 'OutputDataConfig': {'S3OutputPath': 's3://sagemaker-us-east-1-421720712360/pytorch-inference-2022-12-16-14-16-49-980'},\n", " 'StoppingCondition': {'MaxRuntimeInSeconds': 86400},\n", " 'ResourceConfig': {'VolumeSizeInGB': 30,\n", " 'InstanceCount': 1,\n", " 'InstanceType': 'ml.m5.large'},\n", " 'RoleArn': 'arn:aws:iam::421720712360:role/mod-ad69b4757be44dd4-SageMakerExecutionRole-1HU619OTJ7IXI',\n", " 'InputDataConfig': [{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix',\n", " 'S3Uri': {'Get': 'Steps.Genome-Survival-Prediction-Training.ModelArtifacts.S3ModelArtifacts'},\n", " 'S3DataDistributionType': 'FullyReplicated'}},\n", " 'ChannelName': 'training'}],\n", " 'HyperParameters': {'inference_script': '\"inference.py\"',\n", " 'model_archive': {'Std:Join': {'On': '',\n", " 'Values': [{'Get': 'Steps.Genome-Survival-Prediction-Training.ModelArtifacts.S3ModelArtifacts'}]}},\n", " 'dependencies': 'null',\n", " 'source_dir': '\"src\"',\n", " 'sagemaker_submit_directory': '\"s3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-Model-Registration-RepackModel-0-f26f9dd2efcc5b30943ede8fa8fc1b51/source/sourcedir.tar.gz\"',\n", " 'sagemaker_program': '\"_repack_model.py\"',\n", " 'sagemaker_container_log_level': '20',\n", " 'sagemaker_region': '\"us-east-1\"'},\n", " 
'DebugHookConfig': {'S3OutputPath': 's3://sagemaker-us-east-1-421720712360/pytorch-inference-2022-12-16-14-16-49-980',\n", " 'CollectionConfigurations': []}},\n", " 'Description': 'Used to repack a model with customer scripts for a register/create model step'},\n", " {'Name': 'Genome-Survival-Prediction-Model-Registration-RegisterModel',\n", " 'Type': 'RegisterModel',\n", " 'Arguments': {'ModelPackageGroupName': 'Genome-Survival-Prediction-Model-Package-Group',\n", " 'InferenceSpecification': {'Containers': [{'Image': '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.12.0-cpu-py38',\n", " 'Environment': {'SAGEMAKER_PROGRAM': 'inference.py',\n", " 'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code',\n", " 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',\n", " 'SAGEMAKER_REGION': 'us-east-1'},\n", " 'ModelDataUrl': {'Get': 'Steps.Genome-Survival-Prediction-Model-Registration-RepackModel-0.ModelArtifacts.S3ModelArtifacts'},\n", " 'Framework': 'PYTORCH',\n", " 'FrameworkVersion': '1.12.0'}],\n", " 'SupportedContentTypes': ['application/json'],\n", " 'SupportedResponseMIMETypes': ['application/json'],\n", " 'SupportedRealtimeInferenceInstanceTypes': ['ml.t2.medium',\n", " 'ml.m5.xlarge'],\n", " 'SupportedTransformInstanceTypes': ['ml.m5.xlarge']},\n", " 'ModelApprovalStatus': 'Approved'}},\n", " {'Name': 'Genome-Survival-Prediction-MME-Deployment',\n", " 'Type': 'Processing',\n", " 'Arguments': {'ProcessingResources': {'ClusterConfig': {'InstanceType': 'ml.m5.xlarge',\n", " 'InstanceCount': 1,\n", " 'VolumeSizeInGB': 30}},\n", " 'AppSpecification': {'ImageUri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3',\n", " 'ContainerEntrypoint': ['python3',\n", " '/opt/ml/processing/input/code/mme_deployment.py']},\n", " 'RoleArn': 'arn:aws:iam::421720712360:role/mod-ad69b4757be44dd4-SageMakerExecutionRole-1HU619OTJ7IXI',\n", " 'ProcessingInputs': [{'InputName': 'code',\n", " 'AppManaged': False,\n", " 'S3Input': {'S3Uri': 
's3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-MME-Deployment-feacb172574b936d55a6ffe6decc6641/input/code/mme_deployment.py',\n", " 'LocalPath': '/opt/ml/processing/input/code',\n", " 'S3DataType': 'S3Prefix',\n", " 'S3InputMode': 'File',\n", " 'S3DataDistributionType': 'FullyReplicated',\n", " 'S3CompressionType': 'None'}}],\n", " 'ProcessingOutputConfig': {'Outputs': [{'OutputName': 'mme_model_location',\n", " 'AppManaged': False,\n", " 'S3Output': {'S3Uri': 's3://sagemaker-us-east-1-421720712360/Genome-Survival-Prediction-MME-Deployment-feacb172574b936d55a6ffe6decc6641/output/mme_model_location',\n", " 'LocalPath': '/opt/ml/processing/model/mme',\n", " 'S3UploadMode': 'EndOfJob'}}]},\n", " 'Environment': {'modelPackageArn': {'Get': 'Steps.Genome-Survival-Prediction-Model-Registration-RegisterModel.ModelPackageArn'},\n", " 'mmeModelLocation': {'Get': 'Parameters.MMEModelsLocation'},\n", " 'genomeGroup': {'Get': 'Parameters.genomeGroup'},\n", " 'AWS_DEFAULT_REGION': 'us-east-1'}}}]}}]}" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "\n", "definition = json.loads(pipeline.definition())\n", "definition" ] }, { "cell_type": "code", "execution_count": 46, "id": "29fcb44d-ea5d-465e-b1ab-b44cd6024b97", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The input argument instance_type of function (sagemaker.image_uris.retrieve) is a pipeline variable (), which is not allowed. 
The default_value of this Parameter object will be used to override it.\n" ] }, { "data": { "text/plain": [ "{'PipelineArn': 'arn:aws:sagemaker:us-east-1:421720712360:pipeline/genome-survival-prediction-pipeline',\n", " 'ResponseMetadata': {'RequestId': 'd46e69e6-214a-4784-983b-888c2f5a88e3',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'x-amzn-requestid': 'd46e69e6-214a-4784-983b-888c2f5a88e3',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '103',\n", " 'date': 'Fri, 16 Dec 2022 15:30:24 GMT'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.upsert(role_arn=role)" ] }, { "cell_type": "markdown", "id": "c922e1c4-ab01-40a1-a5da-5f4b02e09a61", "metadata": {}, "source": [ "If you are using SageMaker Studio, you can visualize the pipeline and each of its steps:\n", "\n", "![](images/image_3.jpg)" ] }, { "cell_type": "markdown", "id": "fc5254db-1638-43ee-bfb6-a00e7e620673", "metadata": {}, "source": [ "### Start the pipeline using all the gene groups" 
] }, { "cell_type": "code", "execution_count": 47, "id": "3bc37055-f5fc-4f6f-b8b2-87a3026aa4db", "metadata": { "tags": [] }, "outputs": [], "source": [ "execution = pipeline.start({\n", " \"genomeGroup\" : \"ALL\"\n", " }\n", ")" ] }, { "cell_type": "markdown", "id": "8e981eef-04c9-4187-add2-e527486138b8", "metadata": {}, "source": [ "### Pipeline Operations: Examining and Waiting for Pipeline Execution\n", "\n", "Describe the pipeline execution" ] }, { "cell_type": "code", "execution_count": 48, "id": "62ee313b-5b0b-4948-905d-d5d4d8acf4f9", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'PipelineArn': 'arn:aws:sagemaker:us-east-1:421720712360:pipeline/genome-survival-prediction-pipeline',\n", " 'PipelineExecutionArn': 'arn:aws:sagemaker:us-east-1:421720712360:pipeline/genome-survival-prediction-pipeline/execution/lxs3vcvk7d4r',\n", " 'PipelineExecutionDisplayName': 'execution-1671204625908',\n", " 'PipelineExecutionStatus': 'Executing',\n", " 'PipelineExperimentConfig': {'ExperimentName': 'genome-survival-prediction-pipeline',\n", " 'TrialName': 'lxs3vcvk7d4r'},\n", " 'CreationTime': datetime.datetime(2022, 12, 16, 15, 30, 25, 824000, tzinfo=tzlocal()),\n", " 'LastModifiedTime': datetime.datetime(2022, 12, 16, 15, 30, 25, 824000, tzinfo=tzlocal()),\n", " 'CreatedBy': {'UserProfileArn': 'arn:aws:sagemaker:us-east-1:421720712360:user-profile/d-6dy9c2r2izfc/sagemakeruser',\n", " 'UserProfileName': 'sagemakeruser',\n", " 'DomainId': 'd-6dy9c2r2izfc'},\n", " 'LastModifiedBy': {'UserProfileArn': 'arn:aws:sagemaker:us-east-1:421720712360:user-profile/d-6dy9c2r2izfc/sagemakeruser',\n", " 'UserProfileName': 'sagemakeruser',\n", " 'DomainId': 'd-6dy9c2r2izfc'},\n", " 'ResponseMetadata': {'RequestId': '87d3f008-c361-421c-91a5-a0b79cc13584',\n", " 'HTTPStatusCode': 200,\n", " 'HTTPHeaders': {'x-amzn-requestid': '87d3f008-c361-421c-91a5-a0b79cc13584',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '872',\n", " 'date': 'Fri, 
16 Dec 2022 15:30:26 GMT'},\n", " 'RetryAttempts': 0}}" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "execution.describe()" ] }, { "cell_type": "markdown", "id": "093fe1b2-e14e-4528-89e1-e0450c29f73c", "metadata": {}, "source": [ "Wait for the execution to complete.\n" ] }, { "cell_type": "code", "execution_count": 49, "id": "488a4eaa-6249-4e02-8ed4-cc44f51816f9", "metadata": { "tags": [] }, "outputs": [], "source": [ "execution.wait()" ] }, { "cell_type": "markdown", "id": "eee962c0-06e5-4af0-86fb-fd5179bcc24a", "metadata": {}, "source": [ "### Verify how many models are deployed on the MME" ] }, { "cell_type": "code", "execution_count": 50, "id": "a0b06a8b-2ec5-4d23-bd0e-f4084bc367fc", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['/model-ALL.tar.gz']" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(mme.list_models())" ] }, { "cell_type": "markdown", "id": "c031170a-3325-4679-8ecb-29b6c3c2b024", "metadata": {}, "source": [ "We can see that a model suffixed with 'ALL' is now in the MME location. Let's run some predictions with the test dataset.
" ] }, { "cell_type": "markdown", "id": "0b901590-f32b-47ff-ad2f-d6e61d8b51e3", "metadata": {}, "source": [ "### Predict with trained models using test data\n" ] }, { "cell_type": "code", "execution_count": 51, "id": "825d7171-45c7-45e5-b924-08a001e1a2b5", "metadata": { "tags": [] }, "outputs": [], "source": [ "predictor = Predictor(endpoint_name = endpoint_name)\n", "\n", "predictor.serializer = sagemaker.serializers.JSONSerializer()\n", "predictor.deserializer = sagemaker.deserializers.CSVDeserializer()" ] }, { "cell_type": "code", "execution_count": 52, "id": "fc401ebe-fbdf-4ba6-9117-429ed1052be6", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[['0'],\n", " ['1'],\n", " ['1'],\n", " ['0'],\n", " ['0'],\n", " ['0'],\n", " ['1'],\n", " ['1'],\n", " ['0'],\n", " ['1'],\n", " ['0'],\n", " ['1'],\n", " ['0'],\n", " ['1'],\n", " ['0'],\n", " ['0'],\n", " ['0'],\n", " ['1'],\n", " ['0'],\n", " ['0'],\n", " ['1'],\n", " ['0'],\n", " ['0'],\n", " ['0']]" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The ALL model receives every column of X_val except the first\n", "payload = {\n", " \"inputs\" : X_val.iloc[:, 1:].values\n", "}\n", "\n", "# target_model selects which model artifact in the MME location serves this request\n", "predictor.predict(payload, target_model=\"/model-ALL.tar.gz\")" ] }, { "cell_type": "markdown", "id": "c2811b58-f6f0-4a90-a5e9-fc7d67ddf40a", "metadata": {}, "source": [ "### Next, let's train a model on the \"metagene_19\" gene group" ] }, { "cell_type": "code", "execution_count": 53, "id": "60f9b528-8a38-4fcf-810d-725e0934448e", "metadata": { "tags": [] }, "outputs": [], "source": [ "execution = pipeline.start(\n", " parameters=dict(\n", " genomeGroup=\"metagene_19\"\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": 54, "id": "f9c6559e-d8b0-4cdf-97a2-a491fd9422d3", "metadata": { "tags": [] }, "outputs": [], "source": [ "execution.wait()" ] }, { "cell_type": "markdown", "id": "7e71b214-c9be-4558-a9b9-8ebc77419c82", "metadata": {}, "source": [ "### Verify how many models are deployed on the MME" ] }, { "cell_type":
"code", "execution_count": 55, "id": "b4366722-06c0-4583-bce2-4bf33225baab", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['/model-ALL.tar.gz', '/model-metagene_19.tar.gz']" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(mme.list_models())" ] }, { "cell_type": "markdown", "id": "27c45470-b4a4-47d4-a449-7457b67885c4", "metadata": {}, "source": [ "A second model, suffixed with 'metagene_19', is now in the MME location. Let's run some predictions with the test dataset." ] }, { "cell_type": "code", "execution_count": 56, "id": "5890cdf7-e859-429d-a3a2-c0cada84c269", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'inputs': array([[0.07959922, 0.03087836, 0.19559475],\n", " [0.05884028, 0.0156981 , 0.30814739],\n", " [0. , 0.00898249, 0.05668478],\n", " [0.13264582, 0.05260354, 0.08079599],\n", " [0.10139428, 0.24103913, 0.08870376]])}" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Send the metagene_19 model only the genes in that group (LRIG1, HPGD, GDF15), for the first 5 rows\n", "payload = {\n", " \"inputs\" : X_val[['LRIG1', 'HPGD', 'GDF15']].iloc[0:5, :].values\n", "}\n", "payload" ] }, { "cell_type": "code", "execution_count": 57, "id": "3f5ea007-e3fb-43f3-b040-0a2c755a4693", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[['0'], ['1'], ['0'], ['0'], ['0']]" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(payload, target_model=\"/model-metagene_19.tar.gz\")" ] }, { "cell_type": "markdown", "id": "6febc0d7-65dd-440a-9016-18781d9af25c", "metadata": {}, "source": [ "## Clean up\n", "\n", "Once you have finished working with the notebook, delete the endpoint by uncommenting the following code."
] }, { "cell_type": "code", "execution_count": 58, "id": "331c04d8-39da-4734-94f2-bad3579f49e6", "metadata": { "tags": [] }, "outputs": [], "source": [ "#predictor.delete_endpoint()" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }