{ "cells": [ { "cell_type": "markdown", "id": "9e9ea171-e1c5-4877-a71d-4adeac97b89e", "metadata": {}, "source": [ "# Build and Deploy Many Models Leveraging Cancer Gene Expression Data With SageMaker Pipelines and SageMaker Multi-Model Endpoints\n", "\n", "When building machine learning models that leverage genomic data, a key problem is how to allow users to select which features should be used when querying models. To address this, data scientists sometimes build multiple models, each handling a specific sub-problem within the dataset. In survival analysis for cancer, a common approach is to analyze gene signatures and to predict patient survival from those gene expression signatures. See [here](https://www.nature.com/articles/s41598-021-84787-5) for an example of such an approach across a number of different cancer types. See also [this](https://pubmed.ncbi.nlm.nih.gov/31296308/) review, which discusses different techniques for performing survival analysis.\n", "\n", "A problem arises when an application must publish models based on many hundreds or thousands of gene signatures: managing and deploying all of those models becomes difficult to maintain and unwieldy. In this blog post, we show how you can use SageMaker Pipelines and SageMaker Multi-Model Endpoints to build and deploy many such models. \n", "\n", "To give a specific example, we use the sample cancer RNA expression dataset discussed in the paper [Non-Small Cell Lung Cancer Radiogenomics Map Identifies Relationships between Molecular and Imaging Phenotypes with Prognostic Implications](https://pubmed.ncbi.nlm.nih.gov/28727543/). To simplify the use case, we focus on 21 co-expressed groups that this paper found to be clinically significant in NSCLC (see Table 2 of that paper). These groups of genes, which the authors term metagenes, are annotated in different cellular pathways. 
For example, the first group of genes, LRIG1, HPGD and GDF15, relates to the EGFR signaling pathway, while VIM, LMO2 and EGR2 are all involved in cell hypoxia/inflammation. Each of the 199 patients is also annotated with their survival status (1 for deceased; 0 for alive at the time the dataset was collected). We followed the preprocessing steps in [this blog post](https://aws.amazon.com/blogs/industries/building-scalable-machine-learning-pipelines-for-multimodal-health-data-on-aws/) to preprocess the data. As described more fully in that blog post, the final dataset contains 119 patients, where each cancer patient (row) has gene expression values (columns). If you run the pipeline described in that blog post, you will get the entire gene expression profile from the raw FASTQ files. Alternatively, you can access the entire gene expression data at [GEO](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE103584). \n", "\n", "The architecture for this approach is as follows:\n", "\n", "\n", "\n", "As can be seen in the diagram, we start with data located in S3. We then create a [SageMaker Pipeline](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-pipelines/index.html). SageMaker Pipelines lets data scientists wrap the different components of their workload as a pipeline, so that each step of the analysis is kicked off automatically after the previous job finishes. See the associated code repository ?? for the specific syntax for creating a SageMaker Pipeline.\n", "The pipeline consists of:\n", "\n", "* A SageMaker Processing job for preprocessing the data\n", "\n", "* A SageMaker Training job for training the model. 
\n", "\n", "* A SageMaker Processing job for evaluating and registering the model in SageMaker Model Registry.\n", "\n", "* A separate SageMaker Processing job for deploying the model on a SageMaker Multi-Model Endpoint (MME)\n", "\n", "\n", "\n", "\n", "Before we begin, let's verify the SageMaker version." ] }, { "cell_type": "code", "execution_count": 3, "id": "cdf9dfab-fc7f-438c-9212-49d299274320", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.109.0'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sagemaker\n", "sagemaker.__version__" ] }, { "cell_type": "code", "execution_count": 2, "id": "445a345a-cf12-47e9-9ffe-9490309ad20a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Keyring is skipped due to an exception: 'keyring.backends'\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.\n", "pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.\n", "docker-compose 1.29.2 requires PyYAML<6,>=3.10, but you have pyyaml 6.0 which is incompatible.\n", "aiobotocore 2.4.1 requires botocore<1.27.60,>=1.27.59, but you have botocore 1.29.24 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install --upgrade --quiet sagemaker==2.109.0" ] }, { "cell_type": "markdown", "id": "f97fd49d-5a55-4246-a7b5-edaa09b357e5", "metadata": {}, "source": [ "* Please restart the kernel after the SageMaker update. You can do this from the menu via Kernel->Restart Kernel.\n", "* After restarting, run the cell below. Make sure that the SageMaker version has been updated to '>=2.94.0'." ] }, { "cell_type": "code", "execution_count": 3, "id": "22d4d10f-725b-4b3a-b52c-900ed7a2c8bb", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'2.109.0'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sagemaker\n", "sagemaker.__version__" ] }, { "cell_type": "markdown", "id": "a2e8125b-bd7d-4be7-877b-58c8e5445849", "metadata": {}, "source": [ "Then let's import the rest of the packages we need." 
] }, { "cell_type": "code", "execution_count": 4, "id": "1925c3c0-3193-4536-8e8d-f51e2d80e2cd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import time\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "from sklearn.model_selection import train_test_split\n", "from sagemaker import get_execution_role\n", "\n", "from sagemaker.multidatamodel import MultiDataModel\n", "\n", "from sagemaker.pytorch import PyTorch\n", "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "from sagemaker.workflow.pipeline_context import PipelineSession\n", "from sagemaker.workflow.fail_step import FailStep\n", "from sagemaker.workflow.functions import Join, JsonGet\n", "from sagemaker.model_metrics import MetricsSource, ModelMetrics\n", "from sagemaker.workflow.model_step import ModelStep\n", "from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo\n", "from sagemaker.workflow.condition_step import ConditionStep\n", "\n", "from sagemaker.predictor import Predictor\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "ce9b88dd-56db-4c23-baa3-6ae83fc8aa90", "metadata": {}, "source": [ "### Read the data \n", "\n", "Data for the project is available in the `data` folder. Let's read it, do some exploratory analysis, and apply basic preprocessing." ] }, { "cell_type": "code", "execution_count": 5, "id": "0cf392c9-76d1-4db1-ad4a-2c85083c6c48", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "

| | Case_ID | LRIG1 | HPGD | GDF15 | CDH2 | POSTN | VCAN | PDGFRA | VCAM1 | CD44 | ... | CD37 | VIM | LMO2 | EGR2 | BGN | COL4A1 | COL5A1 | COL5A2 | SurvivalStatus | PathologicalMstage |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | R01-005 | 25.438505 | 2.434890 | 4.83280 | 0.000000 | 14.80001 | 13.816834 | 3.791663 | 3.903575 | 44.62658 | ... | 1.590920 | 70.972300 | 0.00000 | 0.000000 | 6.22329 | 3.268430 | 9.394910 | 11.402356 | 1 | 0 |
| 1 | R01-012 | 23.192370 | 8.997890 | 43.86000 | 2.576560 | 57.67060 | 22.893300 | 6.526380 | 8.768820 | 32.54150 | ... | 6.934750 | 158.976000 | 3.14894 | 2.828400 | 28.85350 | 22.221980 | 12.066000 | 22.500950 | 0 | 0 |
| 2 | R01-013 | 77.553001 | 190.529190 | 7.76279 | 0.000000 | 11.79248 | 26.137328 | 6.983622 | 3.591960 | 152.81587 | ... | 5.048800 | 115.247000 | 2.26644 | 3.317050 | 5.77029 | 5.490720 | 7.749760 | 3.532330 | 0 | 0 |
| 3 | R01-014 | 63.164522 | 21.455770 | 16.58980 | 0.000000 | 20.94288 | 7.123850 | 18.156200 | 2.076650 | 57.05985 | ... | 4.651934 | 101.178000 | 1.32881 | 1.462162 | 21.01470 | 6.839590 | 2.572220 | 36.834306 | 0 | 0 |
| 4 | R01-017 | 82.543949 | 66.726003 | 15.27520 | 1.817656 | 17.53631 | 5.260631 | 7.652980 | 2.341880 | 227.60731 | ... | 2.631933 | 206.286100 | 1.39513 | 1.568070 | 21.19790 | 16.872651 | 4.875322 | 9.060370 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 114 | R01-156 | 23.141790 | 8.542742 | 0.00000 | 0.684995 | 119.58400 | 25.027800 | 3.129100 | 2.322870 | 143.23510 | ... | 4.717760 | 58.163200 | 2.62162 | 0.976505 | 27.42970 | 29.316800 | 35.525100 | 31.737300 | 1 | 0 |
| 115 | R01-157 | 6.565445 | 0.829140 | 4.94241 | 0.000000 | 13.29370 | 26.306770 | 3.463640 | 4.649160 | 187.18417 | ... | 1.333220 | 49.906300 | 1.87628 | 1.833738 | 15.31520 | 17.690250 | 21.050000 | 21.977010 | 0 | 0 |
| 116 | R01-158 | 26.088220 | 21.237700 | 4.81202 | 0.000000 | 16.14340 | 10.719810 | 4.399728 | 1.937700 | 126.17160 | ... | 1.286370 | 57.849300 | 1.97280 | 0.659567 | 7.73836 | 29.922400 | 7.174080 | 9.495350 | 0 | 0 |
| 117 | R01-159 | 20.813240 | 3.629507 | 0.00000 | 0.000000 | 19.68492 | 24.987579 | 2.346060 | 9.746220 | 371.70458 | ... | 22.198750 | 102.767831 | 3.39946 | 2.585660 | 8.66216 | 85.256870 | 6.234190 | 15.785505 | 0 | 0 |
| 118 | R01-160 | 18.136591 | 2.894350 | 0.00000 | 0.000000 | 35.57201 | 29.780610 | 6.185470 | 3.071740 | 108.12872 | ... | 1.865994 | 74.261921 | 0.00000 | 2.334370 | 25.06410 | 23.004220 | 10.038100 | 26.491740 | 0 | 0 |
119 rows × 24 columns
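The code that turns the raw table into the scaled table that follows is not shown in this section. The scaled values are consistent with per-column min-max scaling to [0, 1]. For example, LRIG1 values of 25.44 and 77.55 in rows 0 and 2 become about 0.186 and 0.602, which also fits the `MinMaxScaler` import earlier in the notebook. The sketch below illustrates that transformation with pandas on a small hypothetical frame; the values are made up, and this is not the notebook's actual preprocessing code.

```python
import pandas as pd

# Hypothetical toy frame with the same layout as the raw table:
# an ID column, a few gene-expression columns, and two label columns.
df = pd.DataFrame({
    'Case_ID': ['A', 'B', 'C'],
    'LRIG1': [2.0, 12.0, 7.0],
    'HPGD': [0.0, 5.0, 10.0],
    'SurvivalStatus': [1, 0, 0],
    'PathologicalMstage': [0, 0, 0],
})

# Keep only the gene-expression columns; keep the survival label separately.
features = df.drop(columns=['Case_ID', 'SurvivalStatus', 'PathologicalMstage'])
labels = df['SurvivalStatus']

# Per-column min-max scaling to [0, 1]: (x - min) / (max - min).
# A zero range (constant column) is replaced by 1 so that column maps to 0,
# mirroring how scikit-learn's MinMaxScaler treats constant features.
col_min = features.min()
col_range = (features.max() - col_min).replace(0, 1)
scaled = (features - col_min) / col_range
print(scaled)
```

If scikit-learn is available, `MinMaxScaler().fit_transform(features)` should produce the same numbers as a NumPy array. In a modeling workflow the scaler is usually fit on the training split only and then applied to the held-out split, so that test-set statistics do not leak into training.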
\n", "\n", "

| | LRIG1 | HPGD | GDF15 | CDH2 | POSTN | VCAN | PDGFRA | VCAM1 | CD44 | CD48 | ... | LYL1 | SPI1 | CD37 | VIM | LMO2 | EGR2 | BGN | COL4A1 | COL5A1 | COL5A2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.186033 | 0.011379 | 0.058076 | 0.000000 | 0.044392 | 0.043852 | 0.031756 | 0.077312 | 0.096727 | 0.000000 | ... | 0.000000 | 0.041144 | 0.071667 | 0.161186 | 0.000000 | 0.000000 | 0.036592 | 0.000000 | 0.037805 | 0.048075 |
| 1 | 0.168126 | 0.042051 | 0.527072 | 0.014627 | 0.188186 | 0.080102 | 0.058723 | 0.173671 | 0.068072 | 0.208182 | ... | 0.111333 | 0.264893 | 0.312394 | 0.384688 | 0.132416 | 0.254566 | 0.260631 | 0.098762 | 0.050754 | 0.099377 |
| 2 | 0.601524 | 0.890415 | 0.093287 | 0.000000 | 0.034304 | 0.093057 | 0.063232 | 0.071140 | 0.353251 | 0.058441 | ... | 0.061315 | 0.425206 | 0.227436 | 0.273630 | 0.095306 | 0.298546 | 0.032107 | 0.011580 | 0.029830 | 0.011697 |
| 3 | 0.486809 | 0.100271 | 0.199362 | 0.000000 | 0.064996 | 0.017122 | 0.173404 | 0.041129 | 0.126207 | 0.089822 | ... | 0.064753 | 0.177943 | 0.209558 | 0.237899 | 0.055878 | 0.131600 | 0.183027 | 0.018608 | 0.004730 | 0.165631 |
| 4 | 0.641315 | 0.311836 | 0.183564 | 0.010318 | 0.053570 | 0.009681 | 0.069832 | 0.046382 | 0.530587 | 0.075721 | ... | 0.000000 | 0.109945 | 0.118562 | 0.504841 | 0.058666 | 0.141132 | 0.184841 | 0.070888 | 0.015895 | 0.037249 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 114 | 0.167722 | 0.039923 | 0.000000 | 0.003889 | 0.395853 | 0.088626 | 0.025222 | 0.046006 | 0.330534 | 0.023121 | ... | 0.196905 | 0.061541 | 0.212524 | 0.128655 | 0.110241 | 0.087889 | 0.246535 | 0.135732 | 0.164479 | 0.142071 |
| 115 | 0.035565 | 0.003875 | 0.059394 | 0.000000 | 0.039339 | 0.093734 | 0.028521 | 0.092079 | 0.434741 | 0.036051 | ... | 0.088405 | 0.078723 | 0.060058 | 0.107685 | 0.078899 | 0.165043 | 0.126602 | 0.075148 | 0.094307 | 0.096955 |
| 116 | 0.191213 | 0.099252 | 0.057827 | 0.000000 | 0.048898 | 0.031484 | 0.037752 | 0.038377 | 0.290076 | 0.026981 | ... | 0.208866 | 0.090146 | 0.057948 | 0.127858 | 0.082958 | 0.059363 | 0.051591 | 0.138887 | 0.027039 | 0.039260 |
| 117 | 0.149158 | 0.016962 | 0.000000 | 0.000000 | 0.060777 | 0.088466 | 0.017501 | 0.193029 | 0.872251 | 0.424727 | ... | 0.343994 | 0.162622 | 1.000000 | 0.241937 | 0.142950 | 0.232718 | 0.060737 | 0.427221 | 0.022482 | 0.068335 |
| 118 | 0.127818 | 0.013526 | 0.000000 | 0.000000 | 0.114064 | 0.107608 | 0.055361 | 0.060837 | 0.247295 | 0.024145 | ... | 0.150692 | 0.158571 | 0.084059 | 0.169541 | 0.000000 | 0.210101 | 0.223116 | 0.102838 | 0.040923 | 0.117824 |
119 rows × 21 columns
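The aim of this notebook is to publish one model per gene signature. With that in mind, the scaled table can be sliced into one training set per signature, each feeding its own training job and, later, its own artifact on the multi-model endpoint. The sketch below assumes a hypothetical `SIGNATURES` mapping. In it, `egfr_signaling` uses the EGFR-pathway genes named in the introduction, while `example_collagen` is an illustrative placeholder, not a grouping taken from the paper. The S3 prefix and the `name.tar.gz` naming scheme are also assumptions. The only endpoint property relied on is that a SageMaker multi-model endpoint loads model archives on demand from a single S3 prefix, with each request naming its model by the path relative to that prefix.

```python
import pandas as pd

# Hypothetical signature-to-genes mapping. 'egfr_signaling' uses the genes named in
# the introduction; 'example_collagen' is an illustrative placeholder only.
SIGNATURES = {
    'egfr_signaling': ['LRIG1', 'HPGD', 'GDF15'],
    'example_collagen': ['COL4A1', 'COL5A1', 'COL5A2'],
}


def datasets_by_signature(scaled, labels, signatures):
    '''Return {signature name: DataFrame of that signature's genes plus the label}.'''
    datasets = {}
    for name, genes in signatures.items():
        subset = scaled[genes].copy()
        # Attach the label by position; assumes rows of scaled and labels line up.
        subset['SurvivalStatus'] = labels.to_numpy()
        datasets[name] = subset
    return datasets


def artifact_keys(signatures, prefix):
    '''Hypothetical naming scheme: one name.tar.gz archive per signature under a
    shared S3 prefix (the prefix a multi-model endpoint would load models from).'''
    return {name: prefix.rstrip('/') + '/' + name + '.tar.gz' for name in signatures}


# Toy usage with made-up scaled values for three patients.
genes = [g for group in SIGNATURES.values() for g in group]
scaled = pd.DataFrame({g: [0.1, 0.5, 0.9] for g in genes})
labels = pd.Series([1, 0, 0])

datasets = datasets_by_signature(scaled, labels, SIGNATURES)
# Hypothetical bucket and prefix, for illustration only.
keys = artifact_keys(SIGNATURES, 's3://example-bucket/mme-models/')
for name, frame in datasets.items():
    print(name, list(frame.columns), keys[name])
```

At inference time, a client could then choose a signature by passing its relative archive name (for example `egfr_signaling.tar.gz`) as the `target_model` argument of the SageMaker Python SDK's `Predictor.predict`.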
\n", "