{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# HIV Inhibitor prediction using GNN on Amazon SageMaker\n", "\n", "## Motivation :\n", "\n", "The human immunodeficiency virus type 1 (HIV-1) is the primary cause of the acquired immunodeficiency syndrome (AIDS), a slow, progressive and degenerative disease of the human immune system. The pathogenesis of HIV-1 is complex and characterized by the interplay of both viral and host factors. An intense global research effort to understand the individual steps of the viral replication cycle and the dynamics of infection has driven the development of a wide spectrum of antiviral strategies.\n", "\n", "![Inhibitor](img/1.jpg)\n", "
\n", " \n", "Image Source : Biological evaluation of molecules of the azaBINOL class as antiviral agents (https://www.sciencedirect.com/science/article/abs/pii/S0968089619306704)\n", " \n", "
\n", "\n", "The goal of this notebook is to predict a target molecular property as accurately as possible. The property is cast as a binary label (here, whether or not a molecule inhibits HIV replication) and is predicted with graph machine learning techniques, which can be leveraged as a virtual screening step.\n", "\n", "\n", "### DGL :\n", "\n", "Deep Graph Library (DGL) is a Python package for easily implementing models from the graph neural network family on top of existing deep learning frameworks (currently PyTorch, MXNet and TensorFlow). It offers versatile control over message passing, speed optimization via auto-batching and highly tuned sparse matrix kernels, and multi-GPU/CPU training that scales to graphs with hundreds of millions of nodes and edges.\n", "\n", "https://www.dgl.ai/\n", "\n", "\n", "### DGL Life Science : \n", "\n", "DGL-LifeSci is a Python package built on top of PyTorch and DGL for applying graph neural networks to various tasks in chemistry and biology.\n", "\n", "It provides:\n", "\n", "* Various utilities for data processing, training and evaluation.\n", "\n", "* Efficient and flexible model implementations.\n", "\n", "* Pre-trained models for use without training from scratch.\n", "\n", "\n", "## Notebook Overview\n", "\n", "This example notebook focuses on training multiple graph neural network models with the Deep Graph Library and deploying them with Amazon SageMaker, a comprehensive and fully managed machine learning service. With SageMaker, data scientists and developers can quickly build and train machine learning models, and then deploy them directly into a production-ready hosted environment. 
\n", "\n", "**Note:** If you are using SageMaker Studio, select the kernel Python 3 (PyTorch 1.8 Python 3.6 CPU Optimized); if you are using a SageMaker notebook instance, set the kernel to conda_pytorch_p36.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by installing the DGL libraries, RDKit and version 2.75.1 of the SageMaker Python SDK, upgrading PyTorch, and importing some Python libraries." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%pip install --quiet dgl\n", "%pip install --quiet dgllife\n", "%pip install -U --quiet sagemaker==\"2.75.1\"\n", "%pip install --quiet rdkit-pypi\n", "%pip install -U --quiet torch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import pandas as pd\n", "import datetime\n", "\n", "import dgl\n", "\n", "from dgllife.model import load_pretrained\n", "from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter, CanonicalAtomFeaturizer, CanonicalBondFeaturizer\n", "from torch.optim import Adam\n", "from torch.utils.data import DataLoader\n", "\n", "from dgllife.data import HIV\n", "\n", "import rdkit\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rdkit.__version__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Atom-level features are stored on each graph node under the key 'feat'\n", "node_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')\n", "# Bond features are not used here; CanonicalBondFeaturizer(bond_data_field='feat1') is an alternative\n", "edge_featurizer = None\n", "num_workers = 1\n", "# Train/validation/test fractions\n", "split_ratio = \"0.7:0.2:0.1\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 
HIV dataset was introduced by the Drug Therapeutics Program (DTP) AIDS Antiviral Screen, which tested the ability of over 40,000 compounds to inhibit HIV replication. Screening results were evaluated and placed into three categories: confirmed inactive (CI), confirmed active (CA) and confirmed moderately active (CM). The latter two labels are combined, making this a binary classification task between inactive (CI) and active (CA and CM).\n", "\n", "While the raw dataset can be downloaded from MoleculeNet (https://moleculenet.org/datasets-1), the DGL-LifeSci library provides a high-level interface to download it as part of its Datasets package (https://lifesci.dgl.ai/api/data.html#hiv). " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Convert each SMILES string into a DGL graph (with a self-loop on every atom) and featurize its atoms\n", "dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),\n", " node_featurizer=node_featurizer,\n", " edge_featurizer=edge_featurizer,\n", " n_jobs=num_workers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df.head(15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* The dataset contains ~41K molecules in SMILES format (`smiles` column).\n", "* The `HIV_active` column (label) indicates whether or not the molecule is an HIV inhibitor. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's check whether there are any missing values in the dataset."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df.info()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df.isnull().values.any()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, there are no missing values in this dataset. Next, let's explore the class distribution." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df['HIV_active'].value_counts()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df['HIV_active'].value_counts().plot(kind='bar')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* The dataset is heavily imbalanced: only ~3% of the compounds were screened as HIV inhibitors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore molecules (with RDKit)\n", "\n", "RDKit is an open-source cheminformatics toolkit. It includes a collection of standard cheminformatics functionality for molecule I/O, substructure searching, chemical reactions, coordinate generation (2D or 3D), fingerprinting, etc.\n", "\n", "We are going to use this library to explore the molecules (represented as SMILES strings) in the dataset. 
Below are a few molecules from the dataset (20 consecutive entries starting at an arbitrary index), visualized using RDKit.\n", "\n", "https://www.rdkit.org/docs/cppapi/index.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random_molecule_start_index = 10\n", "\n", "sample_smiles = dataset.df['smiles'][random_molecule_start_index:random_molecule_start_index + 20].values\n", "sample_molecules = [Chem.MolFromSmiles(smile) for smile in sample_smiles]\n", "Draw.MolsToGridImage(sample_molecules, molsPerRow=4, subImgSize=(600, 600))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Let's look at a single molecule and explore its properties." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sample_smiles[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mol = sample_molecules[0]\n", "mol" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Using the methods of the RDKit `Atom` class, we can further explore the features of each atom in the molecule.\n", "https://www.rdkit.org/docs/cppapi/classRDKit_1_1Atom.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "atoms = mol.GetAtoms()\n", "print(\"Total number of atoms in the molecule : {}\".format(len(atoms)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "molecule_features = []\n", "for atom in atoms:\n", " atom_features = {}\n", " atom_features['atomic_symbol'] = atom.GetSymbol()\n", " atom_features['atomic_number'] = atom.GetAtomicNum()\n", " atom_features['degree'] = atom.GetDegree()\n", " atom_features['formal_charge'] = atom.GetFormalCharge()\n", " atom_features['hybridization'] = atom.GetHybridization()\n", " atom_features['is_aromatic'] = atom.GetIsAromatic()\n", " molecule_features.append(atom_features)\n", "molecule_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* We can apply 
the node featurizer directly to the molecule, as shown below; it encodes atom properties like the ones above into a numeric feature vector per atom." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')\n", "atom_featurizer(mol)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "atom_featurizer.feat_size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These features are already stored in our graphs, because we passed the node featurizer when loading the dataset above. \n", "\n", "https://lifesci.dgl.ai/generated/dgllife.utils.CanonicalAtomFeaturizer.html\n", "\n", "Let's inspect the graph associated with the index used above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "smiles, graphs, labels, masks = map(list, zip(*dataset))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "smiles[random_molecule_start_index]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random_graph = graphs[random_molecule_start_index]\n", "\n", "random_graph.num_nodes()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Each bond is stored as two directed edges, plus one self-loop per atom\n", "random_graph.num_edges()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random_graph.ndata['feat'].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random_graph.ndata['feat'][0, :]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random_graph.ndata['feat'][1, :]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The order of the atoms in the graph can differ from the order in the RDKit molecule above. This is not a problem, because graph neural network layers and the graph-level readout do not depend on the node order. 
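
To see why the node order does not matter, here is a small framework-free sketch in plain Python (not DGL code). It applies one round of sum-aggregation message passing followed by a sum readout, first to a toy 3-atom graph and then to the same graph with its atoms renumbered. The toy graph and the helpers `message_pass`, `readout` and `permute` are illustrative assumptions. DGL performs this kind of neighbor aggregation with optimized kernels and learned weights.

```python
from typing import List, Tuple

Features = List[List[float]]
Edges = List[Tuple[int, int]]


def message_pass(features: Features, edges: Edges) -> Features:
    # New node feature = own feature (self-loop) + sum of neighbor features.
    # No learned weights, to keep the example minimal.
    out = [list(f) for f in features]
    for u, v in edges:  # undirected bonds: messages flow both ways
        for k in range(len(features[u])):
            out[u][k] += features[v][k]
            out[v][k] += features[u][k]
    return out


def readout(features: Features) -> List[float]:
    # Sum pooling over all nodes gives one vector per molecule.
    return [sum(column) for column in zip(*features)]


def permute(features: Features, edges: Edges, order: List[int]) -> Tuple[Features, Edges]:
    # order[new_index] = old_index; relabel nodes and edges consistently.
    new_index = {old: new for new, old in enumerate(order)}
    new_features = [features[old] for old in order]
    new_edges = [(new_index[u], new_index[v]) for u, v in edges]
    return new_features, new_edges


features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy 3-atom molecule, 2 features per atom
edges = [(0, 1), (1, 2)]                          # bonds 0-1 and 1-2
order = [2, 0, 1]                                 # renumber the atoms

h = message_pass(features, edges)
p_features, p_edges = permute(features, edges, order)
h_perm = message_pass(p_features, p_edges)

print(h_perm == [h[old] for old in order])  # True: node outputs are only reordered
print(readout(h_perm) == readout(h))        # True: molecule-level output is unchanged
```

Renumbering the atoms only reorders the rows of the node-level output. An order-independent readout (sum, mean or max pooling) therefore produces the same molecule-level representation either way.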
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split the dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DGL-LifeSci provides interfaces to split your data into training, validation and test sets using the strategy you prefer. \n", "More details: https://lifesci.dgl.ai/api/utils.splitters.html\n", "\n", "We use the scaffold splitter for this project. \n", "\n", "**ScaffoldSplitter**:\n", "\n", "Groups molecules based on their scaffolds and sorts the groups by size. Each group is then assigned as a whole to one subset, so molecules that share a scaffold always end up together in the same subset (training, validation or test). This keeps structurally similar molecules from leaking between subsets, which is generally considered a more realistic test of performance on new chemical series than a random split.\n", "\n", "The same splitter also supports k-fold cross validation, in which each molecule appears in the validation set of exactly one fold and molecules with the same scaffold are always together in either the training set or the validation set of a given fold." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Split the dataset by molecular scaffold using the ratios defined above\n", "from dgllife.utils import ScaffoldSplitter, RandomSplitter\n", "\n", "train_ratio, val_ratio, test_ratio = map(float, split_ratio.split(':'))\n", "\n", "train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(\n", " dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,\n", " scaffold_func='smiles')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Let's check the class distribution of the training, validation and test sets after the split."
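 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The scaffold split above keeps molecules that share a scaffold together. The plain-Python sketch below shows roughly how such a split can be built. It groups molecule indices by a precomputed scaffold key, orders the groups from largest to smallest, and greedily fills the training, validation and test sets with whole groups. The function `scaffold_split` and the toy scaffold keys are hypothetical. The sketch is modeled on the grouping described above but is not copied from DGL-LifeSci, so tie-breaking and cutoff details may differ. Computing real scaffolds (for example Bemis-Murcko scaffolds with RDKit) is left out.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def scaffold_split(
    scaffolds: Sequence[str], frac_train: float, frac_val: float
) -> Tuple[List[int], List[int], List[int]]:
    # scaffolds[i] is the (precomputed) scaffold key of molecule i.
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)

    # Largest groups first; ties broken by the smallest molecule index.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n = len(scaffolds)
    train_cutoff = int(frac_train * n)
    val_cutoff = int((frac_train + frac_val) * n)

    train: List[int] = []
    val: List[int] = []
    test: List[int] = []
    for group in ordered:
        # Each scaffold group goes, as a whole, into exactly one subset.
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        elif len(train) + len(val) + len(group) <= val_cutoff:
            val.extend(group)
        else:
            test.extend(group)
    return train, val, test


scaffolds = ['A', 'A', 'A', 'B', 'B', 'C', 'D', 'D']  # hypothetical scaffold keys
train, val, test = scaffold_split(scaffolds, frac_train=0.5, frac_val=0.25)
print(train, val, test)  # [0, 1, 2, 5] [3, 4] [6, 7]
```

Because whole groups are assigned at once, the subset sizes only approximate the requested fractions and the class balance can drift between subsets. That is why it is worth checking the distributions."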
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(train_set)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_set.dataset.df.loc[train_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(val_set)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "val_set.dataset.df.loc[val_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(test_set)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_set.dataset.df.loc[test_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* All three subsets (train, validation, test) show a similar class balance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using DGL in Amazon SageMaker with the PyTorch backend" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Set up the environment and create the session\n", "\n", "Here we specify the S3 bucket to use and the IAM role that SageMaker will use. The session stores our connection parameters to SageMaker.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()\n", "session = sagemaker.Session()\n", "bucket = session.default_bucket()\n", "\n", "# S3 key prefix (no leading './', which would otherwise become part of the object keys)\n", "s3_prefix = \"hiv_inhibitor_prediction/sagemaker\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload Data to S3\n", "To train the model on SageMaker, we first need to upload the data to an S3 location. 
We use the `sagemaker.Session.upload_data` function to upload the full dataset together with the row indices of the training and validation splits. Each call returns the S3 URI of the uploaded file, which we pass to the training job later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.df.to_csv(\"full.csv\", index=False)\n", "pd.DataFrame(train_set.indices, columns =[\"indices\"]).to_csv(\"train.csv\", index=False)\n", "pd.DataFrame(val_set.indices, columns =[\"indices\"]).to_csv(\"validation.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_full = session.upload_data(\n", " path=\"full.csv\", bucket=bucket, key_prefix=s3_prefix\n", " )\n", "\n", "input_train = session.upload_data(\n", " path=\"train.csv\", bucket=bucket, key_prefix=s3_prefix\n", " )\n", "input_val = session.upload_data(\n", " path=\"validation.csv\", bucket=bucket, key_prefix=s3_prefix\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore the model architectures to be used \n", "\n", "We represent each molecule as a graph in which every atom is a node and every bond is an edge. After featurization, the atom properties become the node features. Using these features, a graph neural network (GNN) classifies the whole graph (molecule) as inhibiting HIV replication or not. In GNN terms, this is a graph classification problem.\n", "\n", "To train the model, we use prebuilt model architectures (such as GCN and GAT) that come with DGL-LifeSci." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "**1. 
GCNPredictor** : \n", "\n", "Documentation : https://lifesci.dgl.ai/_modules/dgllife/model/model_zoo/gcn_predictor.html\n", "\n", "Code : https://github.com/awslabs/dgl-lifesci/blob/master/python/dgllife/model/model_zoo/gcn_predictor.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dgllife.model import GCNPredictor\n", "import torch.nn.functional as F\n", "\n", "# Small toy configuration, only used to inspect the architecture\n", "model = GCNPredictor(\n", " in_feats=10,\n", " hidden_feats=[10, 4],\n", " activation=[F.relu, F.relu],\n", " residual=[False] * 2\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the `GCNPredictor` architecture is composed of multiple `gnn_layers`, each of which wraps a DGL `GraphConv` layer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training on SageMaker\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Training Script\n", "\n", "We use PyTorch as the DGL backend. The training script must save the model artifacts learned during training to the model directory (`model_dir`) expected by the SageMaker PyTorch container. Once training completes, SageMaker uploads the artifacts saved in `model_dir` to S3, where they are used for deployment.\n", "\n", "The script is saved as `train.py` in the `code/` directory; the cell below prints it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pygmentize code/train.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define hyperparameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's define the hyperparameters associated with the model. The training script above accepts the following hyperparameters, which we can use to tune our model(s). 
Note that the model architecture itself is also passed as a hyperparameter (`gnn-model-name`), which makes it possible to switch to other architectures such as GAT or MPNN. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "import time\n", "\n", "hyperparameters={\n", " # Feature Engineering\n", " \"gnn-featurizer-type\": 'canonical',\n", " \n", " # Model Architecture\n", " \"gnn-model-name\" : 'GCN-p',\n", " \"gnn-residuals\" : False,\n", " \"gnn-batchnorm\" : True,\n", " \"gnn-dropout\" : 0.0013086019242321,\n", " \"gnn-predictor-hidden-feats\" : 512,\n", " \n", " # Training\n", " \"batch-size\" : 1024,\n", " \"epochs\" : 10,\n", " \"learning-rate\" : 0.000208635928951698,\n", " \"weight-decay\" : 0.0005253058161908312,\n", " \"patience\" : 20\n", "}\n", "\n", "# SageMaker parses these metrics from the training logs; group 1 of each regex is the metric value\n", "metric_definitions =[\n", " {'Name': 'train:roc_auc_score', 'Regex': 'training:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'validation:roc_auc_score', 'Regex': ',\\svalidation:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'best validation:roc_auc_score', 'Regex': 'best\\svalidation:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'epoch', 'Regex': 'epoch\\s\\[([0-9]+)\\]'},\n", " {'Name': 'train:loss', 'Regex': 'loss\\s\\[([0-9\\\\,.]+)\\]'}\n", " ]\n", "\n", "# Start with local mode (training runs in a container on this notebook instance)\n", "instance_type = \"local\"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### (SageMaker Notebook Instance Local Mode) Docker Environment Preparation\n", "This section introduces SageMaker local mode, which runs the training job on the notebook instance itself instead of launching a remote training instance. Note that this section only works on SageMaker notebook instances. \n", "\n", "Because the container images may exceed the space available in the root volume of the notebook instance, we need to move Docker's data directory to `/home/ec2-user/SageMaker/docker`. 
By default, Docker stores its data in `/var/lib/docker/`; the `prepare-docker.sh` script in the next cell changes this to `/home/ec2-user/SageMaker/docker`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Only run this cell when you are using a SageMaker notebook instance.\n", "# Only run it once; if you rerun the notebook, comment out this cell.\n", "!bash ./prepare-docker.sh" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train\n", "\n", "DGL runs on top of a deep learning framework and supports TensorFlow, PyTorch and MXNet as backends. For this project we use `PyTorch` as the backend.\n", "\n", "The Amazon SageMaker Python SDK makes it easy to run a PyTorch script in Amazon SageMaker using its PyTorch estimator. After training, we can use the same SDK to deploy the trained model and run predictions. For more information on how to use this SDK with PyTorch, see the SageMaker Python SDK documentation.\n", "\n", "To start, we use the PyTorch estimator class to train our model. When creating the estimator, we specify a few things:\n", "\n", "* `entry_point`: the name of our PyTorch training script (`train.py`). It loads data from the input channels, configures training with the hyperparameters, trains a model, and saves it. (The inference code used at deployment lives separately in `inference.py`.)\n", "\n", "* `source_dir`: the location of our training scripts and requirements.txt file. 
\"requirements.txt\" lists the packages your script needs.\n", "\n", "* `framework_version`: the PyTorch version we want to use.\n", "\n", "The PyTorch estimator supports both single-machine and multi-machine (distributed) PyTorch training, for example with SageMaker's distributed data parallel library (SMDataParallel)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Local mode: only run this cell when you are using a SageMaker notebook instance\n", "training_job_name = f\"tr-pytorch-{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}\" \n", "print('Training job name: ', training_job_name)\n", "\n", "estimator_local = PyTorch(\n", " entry_point = \"train.py\",\n", " source_dir = \"code\",\n", " role = role,\n", " framework_version = \"1.9.0\",\n", " py_version=\"py38\",\n", " instance_count=1,\n", " instance_type=instance_type,\n", " debugger_hook_config=False,\n", " disable_profiler=True,\n", " hyperparameters = hyperparameters,\n", " metric_definitions=metric_definitions\n", ")\n", "\n", "estimator_local.fit({\"data_full\" : input_full, \n", " \"data_train\" : input_train, \n", " \"data_val\" : input_val}, \n", " job_name = training_job_name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Challenge 1\n", "Fill in the values in the cell below to run the same training job on a SageMaker training instance instead of in local mode. Compare the logs of these two jobs."
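 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When comparing the logs, keep in mind how SageMaker gets the metrics it reports for a training job. It extracts them from the log output with the `Regex` patterns in `metric_definitions`, and the first capture group of each match becomes a metric value. The sketch below reproduces that extraction locally with the `re` module. The log lines are hypothetical, since the real format is whatever `code/train.py` prints. `metric_regex` and `extract` are illustrative helpers, and a literal space stands in for the whitespace class used in the original patterns.

```python
import re
from typing import List

# Hypothetical log lines; the real format is defined by code/train.py.
LOG_LINES = [
    'epoch [1] loss [0.4512] training:roc_auc_score [0.6811], validation:roc_auc_score [0.6523]',
    'best validation:roc_auc_score [0.6523]',
]


def metric_regex(prefix: str):
    # Same shape as the metric_definitions patterns: the prefix, a space, then a
    # number in square brackets captured by group 1. re.escape adds the escapes
    # needed for the literal brackets.
    return re.compile(re.escape(prefix) + ' ' + re.escape('[') + '([0-9,.]+)' + re.escape(']'))


def extract(prefix: str, lines: List[str]) -> List[float]:
    # Each match in each line yields one metric value.
    pattern = metric_regex(prefix)
    return [float(m.group(1)) for line in lines for m in pattern.finditer(line)]


print(extract('training:roc_auc_score', LOG_LINES))      # [0.6811]
print(extract('validation:roc_auc_score', LOG_LINES))    # [0.6523, 0.6523]
print(extract(', validation:roc_auc_score', LOG_LINES))  # [0.6523]
```

The last two calls show why the validation pattern in `metric_definitions` starts with a comma. Without it, the pattern would also match the `best validation` line and record the same value twice."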
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "instance_type = \"ml.c5.2xlarge\"\n", "hyperparameters[\"epochs\"]=50\n", "\n", "training_job_name = f\"tr-pytorch-{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}\" \n", "print('Training job name: ', training_job_name)\n", "\n", "estimator = PyTorch(\n", " entry_point = \"train.py\",\n", " source_dir = \"code\",\n", " role = role,\n", " framework_version = \"1.9.0\",\n", " py_version=\"py38\",\n", " instance_count=1,\n", " instance_type=instance_type,\n", " debugger_hook_config=False,\n", " disable_profiler=True,\n", " hyperparameters = hyperparameters,\n", " metric_definitions=metric_definitions\n", ")\n", "\n", "estimator.fit({\"data_full\" : input_full, \n", " \"data_train\" : input_train, \n", " \"data_val\" : input_val}, \n", " job_name = training_job_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training results :" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* With this initial training run, we get a ROC AUC of about 0.76. How metrics such as the training loss and the training and validation ROC AUC scores changed over time can be viewed in SageMaker Experiments or in the training job's metrics in Amazon CloudWatch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we deploy the model to an endpoint, let's check where the trained model artifacts are stored in S3.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_data = estimator.model_data\n", "print(f\"Stored {model_data} as model_data\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy the model on Amazon SageMaker\n", "After training our model, we host it on an Amazon SageMaker Endpoint. 
To make the endpoint load the model and serve predictions, we implement a few functions in `inference.py`:\n", "\n", "* `model_fn()`: loads the saved model and returns a model object that can be used for serving. The SageMaker PyTorch model server loads our model by invoking `model_fn`.\n", "* `input_fn()`: deserializes and prepares the prediction input. In this example, the request body is JSON with a `smiles` field containing the SMILES strings of the molecules to predict. Each SMILES string is converted into a DGL graph, and node features are added with the same featurizer used during training. The function returns these featurized graphs, which is the input format the model expects. \n", "* `predict_fn()`: performs the prediction and returns the result.\n", "\n", "To deploy the endpoint, we create a `PyTorchModel` from the trained model artifacts and `inference.py`, then call `deploy()` on it, passing the desired number of instances and the instance type:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "\n", "endpoint_name = f\"HIV-Inhibitor-Prediction-EP-{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "print(f\"Endpoint name: {endpoint_name}\")\n", "\n", "model = PyTorchModel(model_data=model_data, \n", " source_dir='code',\n", " entry_point='inference.py', \n", " role=role, \n", " framework_version=\"1.9.0\", \n", " py_version='py38')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "predictor = model.deploy(initial_instance_count=1, \n", " instance_type=\"ml.c5.xlarge\", \n", " endpoint_name=endpoint_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Challenge 2\n", "Can you test hosting the trained model in local mode? Is it possible to perform prediction against the endpoint deployed locally?"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "local_endpoint_name = f\"HIV-Inhibitor-Prediction-EP-{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "print(f\"Local endpoint name: {local_endpoint_name}\")\n", "model_local = PyTorchModel(model_data=model_data, \n", " source_dir='code',\n", " entry_point='inference.py', \n", " role=role, \n", " framework_version=\"1.9.0\", \n", " py_version='py38')\n", "\n", "# instance_type='local' serves the model in a Docker container on this notebook instance\n", "predictor_local = model_local.deploy(initial_instance_count=1, \n", " instance_type=\"local\", \n", " endpoint_name=local_endpoint_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predicting test_set with the endpoint :\n", "\n", "Because our newly created endpoint expects SMILES strings as input, let's collect all the SMILES strings in the test_set. We can use the following function for that." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def collate_molgraphs(data):\n", " \"\"\"Batching a list of datapoints for dataloader.\n", " Parameters\n", " ----------\n", " data : list of 4-tuples.\n", " Each tuple is for a single datapoint, consisting of\n", " a SMILES, a DGLGraph, all-task labels and optionally a binary\n", " mask indicating the existence of labels.\n", " Returns\n", " -------\n", " smiles : list\n", " List of smiles\n", " bg : DGLGraph\n", " The batched DGLGraph.\n", " labels : Tensor of dtype float32 and shape (B, T)\n", " Batched datapoint labels. 
B is len(data) and\n", " T is the number of total tasks.\n", " masks : Tensor of dtype float32 and shape (B, T)\n", " Batched datapoint binary mask, indicating the\n", " existence of labels.\n", " \"\"\"\n", "\n", " smiles, graphs, labels, masks = map(list, zip(*data))\n", " \n", " bg = dgl.batch(graphs)\n", " bg.set_n_initializer(dgl.init.zero_initializer)\n", " bg.set_e_initializer(dgl.init.zero_initializer)\n", " labels = torch.stack(labels, dim=0)\n", " masks = torch.stack(masks, dim=0)\n", "\n", " return smiles, bg, labels, masks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using that function, let's get all the SMILES strings (and their labels) in the test set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_smiles, bg, test_labels, masks = collate_molgraphs(test_set)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predicting for a single molecule :\n", "\n", "Let's take a single molecule from the test set and use the newly created endpoint to predict whether it is an inhibitor. Here is an arbitrarily selected molecule (index 110). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mol = Chem.MolFromSmiles(test_smiles[110])\n", "mol" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's send this molecule to the endpoint in JSON format and get the result."
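 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The request and response format can be sketched without a live endpoint. The sketch below assumes, as described for `input_fn()` above, that the endpoint accepts a JSON object with a `smiles` list and returns one logit per molecule. In plain Python, it builds such a payload (roughly what `JSONSerializer` sends) and parses it back the way a server-side handler might. It then turns logits into probabilities with a numerically stable sigmoid. `build_payload`, `parse_payload` and the example logits are hypothetical. The actual parsing happens in `code/inference.py`, which is not shown here.

```python
import json
import math
from typing import List


def build_payload(smiles: List[str]) -> str:
    # Client side: serialize the request body as a JSON string.
    return json.dumps({'smiles': smiles})


def parse_payload(body: str) -> List[str]:
    # Hypothetical server side (the role of input_fn): recover the SMILES list.
    return json.loads(body)['smiles']


def sigmoid(logit: float) -> float:
    # Numerically stable logistic function: maps a logit to a probability in [0, 1].
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


body = build_payload(['CCO', 'c1ccccc1'])
print(parse_payload(body))  # ['CCO', 'c1ccccc1']

logits = [-2.0, 0.0, 1.5]  # made-up endpoint outputs
probabilities = [sigmoid(z) for z in logits]
print(probabilities[1])  # 0.5
```

The sigmoid is strictly increasing, so converting logits to probabilities does not change how the molecules are ranked. Ranking metrics such as ROC AUC therefore give the same value on either."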
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# If you already have an endpoint running, uncomment the lines below and fill in the endpoint name to create a new predictor object\n", "# from sagemaker.predictor import Predictor\n", "# predictor = Predictor(endpoint_name=endpoint_name)\n", "\n", "predictor.serializer = sagemaker.serializers.JSONSerializer()\n", "predictor.deserializer = sagemaker.deserializers.JSONDeserializer()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Named payload rather than json to avoid shadowing Python's json module\n", "payload = {\n", " \"smiles\" : \n", " [test_smiles[110]]\n", "}\n", "\n", "prediction_logits = predictor.predict(payload)\n", "prediction_logits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The endpoint returned the logit (the raw score before the sigmoid) for this molecule being an inhibitor; applying a sigmoid converts it into a probability." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predicting for the whole test_set :\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "payload = {\n", " \"smiles\" : \n", " test_smiles\n", "}\n", "\n", "prediction_logits = predictor.predict(payload)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Examine the test results " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_curve, roc_auc_score\n", "import matplotlib.pyplot as plt \n", "\n", "roc_auc_score(test_labels[:,0].numpy(), np.asarray(prediction_logits))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* With the initial model, we get a ROC AUC score of around 0.75 on the test set. ROC AUC depends only on how the scores rank the molecules, so the logits can be used directly without converting them to probabilities.\n", "* Below is the ROC curve." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fpr, tpr, _ = roc_curve(test_labels[:,0].numpy(), np.asarray(prediction_logits))\n", "\n", "plt.plot(fpr, tpr)\n", "plt.title(\"ROC Plot\")\n", "plt.xlabel(\"False 
Positive Rate\")\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## (Optional) Hyperparameter tuning \n", "\n", "So far we have trained a single model with fixed hyperparameters. Next, let's try to further optimize the model by trying out different hyperparameter values. We can use Amazon SageMaker automatic model tuning (the `HyperparameterTuner` class of the SageMaker Python SDK) for this purpose." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "from sagemaker.tuner import (\n", " IntegerParameter,\n", " CategoricalParameter,\n", " ContinuousParameter,\n", " HyperparameterTuner,\n", ")\n", "\n", "# Search space for the GCN model and training hyperparameters\n", "gcn_hyperparameter_ranges = {\n", " \"gnn-dropout\": ContinuousParameter(0.001, 0.003),\n", " \"gnn-predictor-hidden-feats\": CategoricalParameter([128, 256, 512]),\n", " \"batch-size\": CategoricalParameter([256, 512]),\n", " \"learning-rate\": ContinuousParameter(0.0001, 0.001),\n", " \"weight-decay\": ContinuousParameter(0.001, 0.01),\n", "}\n", "\n", "# The objective metric must be one of the names defined in metric_definitions;\n", "# the tuner maximizes it by default\n", "objective_metric_name = \"best validation:roc_auc_score\"\n", "\n", "gcn_estimator = PyTorch(\n", " entry_point=\"train.py\",\n", " source_dir=\"code\",\n", " role=role,\n", " framework_version=\"1.9.0\",\n", " py_version=\"py38\",\n", " instance_count=1,\n", " instance_type=\"ml.c5.2xlarge\",\n", " debugger_hook_config=False,\n", " disable_profiler=True\n", ")\n", "\n", "gcn_tuner = HyperparameterTuner(\n", " gcn_estimator,\n", " objective_metric_name,\n", " gcn_hyperparameter_ranges,\n", " metric_definitions,\n", " max_jobs=6,\n", " max_parallel_jobs=2\n", ")\n", "\n", "hyper_parameter_job_name = \"hpo-hiv-gcn-p-{}\".format(time.strftime(\"%m-%d-%H-%M-%S\"))\n", "print(f'Tuning job name: {hyper_parameter_job_name}')\n", "\n", "gcn_tuner.fit({\"data_full\": input_full, \"data_train\": input_train, \"data_val\": input_val}, job_name=hyper_parameter_job_name)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "# If the notebook loses its connection, you can re-attach to the tuning job by its name, for example:\n", "#hyper_parameter_job_name = \"hpo-hiv-gcn-p-03-16-02-38-07\"\n", "#gcn_tuner = HyperparameterTuner.attach(hyper_parameter_job_name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's examine the best training job and its hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "\n", "smclient = boto3.client(\"sagemaker\")\n", "\n", "# BestTrainingJob holds the winning job's name, tuned hyperparameters and final objective value\n", "tuning_job_description = smclient.describe_hyper_parameter_tuning_job(\n", " HyperParameterTuningJobName=hyper_parameter_job_name\n", ")\n", "\n", "tuning_job_description[\"BestTrainingJob\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_gcn_training_job = sagemaker.estimator.Estimator.attach(tuning_job_description[\"BestTrainingJob\"][\"TrainingJobName\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_gcn_model = PyTorchModel(model_data=best_gcn_training_job.model_data, source_dir='code',\n", " entry_point='inference.py', role=role, framework_version=\"1.9.0\", py_version='py38')\n", "\n", "best_gcn_predictor = best_gcn_model.deploy(initial_instance_count=1, instance_type=\"ml.c5.xlarge\", endpoint_name=\"best-gcn-\" + endpoint_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_gcn_predictor.serializer = sagemaker.serializers.JSONSerializer()\n", "best_gcn_predictor.deserializer = sagemaker.deserializers.JSONDeserializer()\n", "\n", "payload = {\"smiles\": test_smiles}\n", "\n", "prediction_logits = best_gcn_predictor.predict(payload)\n", "roc_auc_score(test_labels[:,0].numpy(), np.asarray(prediction_logits))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clean up \n", "\n", "Lastly, please remember to delete the Amazon SageMaker endpoints created in this notebook to avoid incurring charges:" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_gcn_predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor_local.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }