{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# HIV Inhibitor Prediction Using Graph Neural Networks (GNN) on Amazon SageMaker\n", "\n", "**Note:** This notebook was last tested with the `Python 3 (Pytorch 1.12 Python 3.8 CPU Optimized)` environment image in Amazon SageMaker Studio.\n", "\n", "## Learning Objectives\n", "\n", "- Understand the basics of graph neural networks and how they can be applied to molecular graphs\n", "- Install and use the Deep Graph Library (DGL)\n", "- Build, train, and deploy a DGL model on SageMaker\n", "- Perform hyperparameter tuning of deep learning models\n", "- Use your own scripts to train custom models in SageMaker \n", "- Track model training and other tasks using SageMaker Experiments\n", "\n", "\n", "## Introduction\n", "\n", "Human immunodeficiency virus type 1 (HIV-1) is the most common cause of Acquired Immunodeficiency Syndrome (AIDS). One ongoing area of research is finding compounds that inhibit HIV-1 viral replication. Schematically, this is shown below as:\n", "\n", "![Inhibitor](img/1.jpg)\n", "
\n", " \n", "Image Source : Biological evaluation of molecules of the azaBINOL class as antiviral agents (https://www.sciencedirect.com/science/article/abs/pii/S0968089619306704)\n", " \n", "
\n", "\n", "
\n", "
\n", "
\n", "\n", "### Why is Deep Learning Useful for Analyzing Biological Networks?\n", "\n", "If you are familiar with classical network analysis, you may have encountered concepts such as [betweenness centrality](https://en.wikipedia.org/wiki/Betweenness_centrality), [degree centrality](https://en.wikipedia.org/wiki/Centrality#Degree_centrality), or [random walk with restart](https://towardsdatascience.com/random-walks-with-restart-explained-77c3fe216bca). These methods are useful for calculating properties of nodes, analyzing networks, or grouping disease [genes](https://pubmed.ncbi.nlm.nih.gov/18371930/). However, these methods are **transductive**, which means that they can only generate features for a particular graph. They cannot predict edges, classify graphs, or perform other tasks where multiple graphs are needed. See [this](https://arxiv.org/abs/1706.02216) paper for further discussion of this issue.\n", "\n", "(A quick note on nomenclature: we use the term \"graph\" to refer to biological networks; we reserve the term \"network\" for a neural network. Although it is common in the computational biology field to refer to biological graphs as networks, in the deep learning field, \"network\" refers almost exclusively to a neural network.)\n", "\n", "[Convolutional neural networks](https://www.d2l.ai/chapter_convolutional-neural-networks/index.html), commonly used in computer vision, are also useful for analyzing graphs. Convolutions allow for **inductive** learning: the learned layers can be applied to graphs with different topologies, including graphs that were not seen during training. These convolutions transform the underlying information in the graph nodes and edges. While a single convolutional layer is generally not sufficient for most tasks, deep graph convolutional neural networks can perform graph prediction (i.e., predict the class of a graph), link prediction (predict missing edges in a graph), and other tasks. 
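
As a rough illustration of what a single graph convolution does, the sketch below performs one round of mean-aggregation message passing in plain Python. It is a simplified stand-in rather than any of the layers used later in this notebook; the function name `mean_graph_conv`, the toy adjacency list, and the weight matrix are all hypothetical.

```python
def mean_graph_conv(neighbors, features, weight):
    '''One simplified graph-convolution step (illustrative only).

    neighbors: dict mapping each node id to a list of neighbor ids
    features:  dict mapping each node id to a list of floats
    weight:    matrix (list of rows) of shape in_dim x out_dim
    '''
    updated = {}
    for node, feats in features.items():
        # Aggregate: average the node's own features with its neighbors'
        group = [feats] + [features[n] for n in neighbors[node]]
        mean = [sum(col) / len(group) for col in zip(*group)]
        # Transform: multiply the aggregated vector by the weight matrix
        out = [sum(m * w for m, w in zip(mean, col)) for col in zip(*weight)]
        # Non-linearity: ReLU
        updated[node] = [max(0.0, x) for x in out]
    return updated


# Toy path graph 0 - 1 - 2 with 2-dimensional node features
neighbors = {0: [1], 1: [0, 2], 2: [1]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
weight = [[1.0, 0.0], [0.0, 1.0]]  # identity, so only the aggregation is visible

hidden = mean_graph_conv(neighbors, features, weight)
```

Because the same weights are reused for every node, this function works unchanged on a graph of any size or shape, which is the sense in which such layers are inductive. Stacking several layers lets each node incorporate information from nodes several hops away.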
\n", "\n", "Deep learning models can also incorporate different edge types as well as external information about edges and nodes. This makes deep learning an attractive approach for analyzing and making predictions about complex graphs. Biological graphs are frequently heterogeneous and include diverse data types such as metabolic, biophysical, proteomic, and functional assays, and information about gene regulatory networks. For example, [this](https://www.amazon.science/blog/amazon-web-services-open-sources-biological-knowledge-graph-to-fight-covid-19) blog post shows how a knowledge graph with diverse node and edge types can be used to predict drug repurposing candidates.\n", "\n", "While scientists can create their own convolutional layers, deep learning researchers have already built many convolutions and architectures that have proven useful in many applications. For example, [GraphSAGE](https://arxiv.org/pdf/1706.02216.pdf) can predict protein-protein interactions. Another commonly used approach is [Graph Attention Networks](https://arxiv.org/pdf/1710.10903.pdf) (GAT).\n", "\n", "For more details on deep graph learning and how it can help analyze biological data, see [this](https://academic.oup.com/bib/article/22/2/1515/5964185) review paper. You may also find [this](http://snap.stanford.edu/deepnetbio-ismb/) tutorial useful.\n", "\n", "### What is the Deep Graph Library (DGL) and When Should You Use It?\n", "\n", "Deep Graph Library (DGL) allows researchers and developers to easily and quickly apply deep graph learning approaches to their data by abstracting away much of the difficult deep learning work and code. DGL comes with a number of prebuilt layers, including [GraphSAGE convolutions](https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.SAGEConv.html#dgl.nn.pytorch.conv.SAGEConv), [GATs](https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.GATConv.html#dgl.nn.pytorch.conv.GATConv), and [others](https://docs.dgl.ai/api/python/nn-pytorch.html). 
Users also have the flexibility to create their own layers and architectures. \n", "\n", "The [DGL-LifeScience](https://lifesci.dgl.ai/index.html) Python package provides a further layer of abstraction over DGL, so that computational biologists, biochemists, and bioinformaticians can apply deep graph methods to common use cases and operations when analyzing small and large molecules. If you want to learn more about how to use DGL, we recommend getting started with [this](https://docs.dgl.ai/en/0.6.x/guide/graph.html) tutorial.\n", "\n", "### Notebook Overview\n", "\n", "This example notebook trains multiple graph neural network models using Deep Graph Library and deploys them using Amazon SageMaker, a comprehensive and fully managed machine learning service. With SageMaker, data scientists and developers can quickly and easily build and train machine learning models and then directly deploy them into a production-ready hosted environment. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install Dependencies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by installing the latest version of `dgl` and other Python dependencies." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip install --disable-pip-version-check -U -q -r requirements.txt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'2022.09.3'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import partial\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import pandas as pd\n", "\n", "import dgl\n", "\n", "from dgllife.model import load_pretrained\n", "from dgllife.utils import smiles_to_bigraph, EarlyStopping, Meter, CanonicalAtomFeaturizer, CanonicalBondFeaturizer\n", "from torch.optim import Adam\n", "from torch.utils.data import DataLoader\n", "\n", "from dgllife.data import HIV\n", "\n", "import rdkit\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", "\n", "rdkit.__version__" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Atom features are stored under the 'feat' key of each graph's node data\n", "node_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')\n", "# No bond (edge) features are used in this notebook\n", "edge_featurizer = None\n", "num_workers = 1\n", "# train:validation:test fractions\n", "split_ratio = \"0.7:0.2:0.1\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [Developmental Therapeutics Program (DTP) AIDS Antiviral Screen](https://wiki.nci.nih.gov/display/NCIDTPdata/AIDS+Antiviral+Screen+Data) tested the ability of 43,850 compounds to inhibit viral replication. DGL-LifeSci provides a pre-processed version of this dataset where each compound is classified as either Confirmed Inactive (CI; labeled as 0) or Confirmed Moderately Active/Confirmed Active (CM, CA; labeled as 1). You can download and inspect the raw dataset from [here](https://moleculenet.org/datasets-1). 
Alternatively, you can download a subset of the data focused on HIV [here](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv). This data is in .csv format and provides the structure of the molecule (in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format), the screening result (CI, CM, or CA), and a binary label for activity against HIV.\n", "\n", "\n", "|SMILES string |activity |HIV_active |\n", "|--- |--- |--- |\n", "|CC(C)(CCC(=O)O)CCC(=O)O |CI |0 |\n", "|O=C(O)Cc1ccc(SSc2ccc(CC(=O)O)cc2)cc1 |CM |1 |\n", "|O=C(O)c1ccccc1SSc1ccccc1C(=O)O |CI |0 |\n", "|CCCCCCCCCCCC(=O)Nc1ccc(SSc2ccc(NC(=O)CCCCCCCCCCC)cc2)cc1 |CI |0 |\n", "\n", "Confirmed inactive (CI) compounds are labeled 0, while confirmed moderately active (CM) and confirmed active (CA) compounds are labeled 1. We can use these labels to define our ML task as a **graph classification problem**. We will construct a graph to represent each molecule, treating the atoms as nodes and the chemical bonds between atoms as edges. Then, we will use GNN techniques to classify each molecule as either active or inactive." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),\n", " node_featurizer=node_featurizer,\n", " edge_featurizer=edge_featurizer,\n", " n_jobs=num_workers)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dgllife.data.hiv.HIV" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(dataset)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(41127, 2)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.df.shape" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " smiles HIV_active\n", "0 CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)... 0\n", "1 C(=Cc1ccccc1)C1=[O+][Cu-3]2([O+]=C(C=Cc3ccccc3... 0\n", "2 CC(=O)N1c2ccccc2Sc2c1ccc1ccccc21 0\n", "3 Nc1ccc(C=Cc2ccc(N)cc2S(=O)(=O)O)c(S(=O)(=O)O)c1 0\n", "4 O=S(=O)(O)CCS(=O)(=O)O 0\n", "5 CCOP(=O)(Nc1cccc(Cl)c1)OCC 0\n", "6 O=C(O)c1ccccc1O 0\n", "7 CC1=C2C(=COC(C)C2C)C(O)=C(C(=O)O)C1=O 0\n", "8 O=[N+]([O-])c1ccc(SSc2ccc([N+](=O)[O-])cc2[N+]... 0\n", "9 O=[N+]([O-])c1ccccc1SSc1ccccc1[N+](=O)[O-] 0\n", "10 CC(C)(CCC(=O)O)CCC(=O)O 0\n", "11 O=C(O)Cc1ccc(SSc2ccc(CC(=O)O)cc2)cc1 1\n", "12 O=C(O)c1ccccc1SSc1ccccc1C(=O)O 0\n", "13 CCCCCCCCCCCC(=O)Nc1ccc(SSc2ccc(NC(=O)CCCCCCCCC... 0\n", "14 Sc1cccc2c(S)cccc12 0" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.df.head(15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our dataset contains around 41,000 molecules in `SMILES` format. The `HIV_active` column (label) indicates if the molecule inhibits HIV. Let's verify if there are any missing values in the dataset." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 41127 entries, 0 to 41126\n", "Data columns (total 2 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 smiles 41127 non-null object\n", " 1 HIV_active 41127 non-null int64 \n", "dtypes: int64(1), object(1)\n", "memory usage: 642.7+ KB\n" ] } ], "source": [ "dataset.df.info()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.df.isnull().values.any()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are no missing values in this dataset. Next, let's explore the class distribution of the dataset." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 39684\n", "1 1443\n", "Name: HIV_active, dtype: int64" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.df['HIV_active'].value_counts()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD1CAYAAACyaJl6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAATBUlEQVR4nO3dYYxV533n8e+vYByr2RQczyIW6IJqVhGOVJLM2qyyL7K2CoOzWqiURlirGlkodBUsJVK1a9w3bpMgxS9a71pyLNE1a1x1Q5DbyqOUlEWOoypa2WZcU2zsej3FdgERMzXYbhTVLuS/L+7D5u5khrnMwB3s+X6kqznn/zzPuc+REL+55zx3TqoKSdLc9guzPQFJ0uwzDCRJhoEkyTCQJGEYSJIwDCRJwPzZnsB03XDDDbVixYrZnoYkfaA899xzf19VA+PrH9gwWLFiBSMjI7M9DUn6QEnyxkR1LxNJkgwDSZJhIEniEsIgybwkzyf5bttfmeSZJKNJvpNkQatf2/ZHW/uKrmPc2+qvJFnfVR9qtdEkOy7j+UmSenApnwy+ArzctX8/8EBV3QicBba2+lbgbKs/0PqRZDWwGbgJGAK+1QJmHvAQsAFYDdzR+kqS+qSnMEiyDPg88N/bfoBbgcdblz3Apra9se3T2m9r/TcCe6vqvap6DRgFbm6v0ao6VlXvA3tbX0lSn/T6yeC/Av8F+Gnb/zjwdlWda/sngKVteylwHKC1v9P6/7/6uDGT1SVJfTJlGCT598DpqnquD/OZai7bkowkGRkbG5vt6UjSh0YvXzr7LPAfktwOfAT4GPDfgIVJ5rff/pcBJ1v/k8By4ESS+cAvAW911S/oHjNZ/f9TVbuAXQCDg4MfiKfyrNjx57M9hQ+N17/5+dmegvShNeUng6q6t6qWVdUKOjeAv19V/xF4CvhC67YFeKJtD7d9Wvv3q/M4tWFgc1tttBJYBTwLHAJWtdVJC9p7DF+Ws5Mk9WQmf47iHmBvkm8AzwOPtPojwB8lGQXO0PnPnao6mmQf8BJwDtheVecBktwNHADmAbur6ugM5iVJukSXFAZV9QPgB237GJ2VQOP7/CPwG5OM3wnsnKC+H9h/KXORJF0+fgNZkmQYSJIMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJ9BAGST6S5Nkkf53kaJLfa/VHk7yW5HB7rWn1JHkwyWiSI0k+3XWsLUleba8tXfXPJHmhjXkwSa7AuUqSJtHLYy/fA26tqh8nuQb4YZLvtbb/XFWPj+u/gc7D7lcBtwAPA7ckuR64DxgECnguyXBVnW19vgQ8Q+fxl0PA95Ak9cWUnwyq48dt95r2qosM2Qg81sY9DSxMsgRYDxysqjMtAA4CQ63tY1X1dFUV8BiwafqnJEm6VD3dM0gyL8lh4DSd/9CfaU0726WgB5Jc22p
LgeNdw0+02sXqJyaoS5L6pKcwqKrzVbUGWAbcnOSTwL3AJ4B/DVwP3HOlJnlBkm1JRpKMjI2NXem3k6Q545JWE1XV28BTwFBVnWqXgt4D/gdwc+t2EljeNWxZq12svmyC+kTvv6uqBqtqcGBg4FKmLkm6iF5WEw0kWdi2rwN+Dfibdq2ftvJnE/BiGzIM3NlWFa0F3qmqU8ABYF2SRUkWAeuAA63t3SRr27HuBJ64nCcpSbq4XlYTLQH2JJlHJzz2VdV3k3w/yQAQ4DDwn1r//cDtwCjwE+AugKo6k+TrwKHW72tVdaZtfxl4FLiOzioiVxJJUh9NGQZVdQT41AT1WyfpX8D2Sdp2A7snqI8An5xqLpKkK8NvIEuSDANJkmEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJInenoH8kSTPJvnrJEeT/F6rr0zyTJLRJN9JsqDVr237o619Rdex7m31V5Ks76oPtdpokh1X4DwlSRfRyyeD94Bbq+pXgTXAUHvQ/f3AA1V1I3AW2Nr6bwXOtvoDrR9JVgObgZuAIeBbSea1Zys/BGwAVgN3tL6SpD6ZMgyq48dt95r2KuBW4PFW3wNsatsb2z6t/bYkafW9VfVeVb0GjAI3t9doVR2rqveBva2vJKlPerpn0H6DPwycBg4Cfwu8XVXnWpcTwNK2vRQ4DtDa3wE+3l0fN2ayuiSpT3oKg6o6X1VrgGV0fpP/xJWc1GSSbEsykmRkbGxsNqYgSR9Kl7SaqKreBp4C/g2wMMn81rQMONm2TwLLAVr7LwFvddfHjZmsPtH776qqwaoaHBgYuJSpS5IuopfVRANJFrbt64BfA16mEwpfaN22AE+07eG2T2v/flVVq29uq41WAquAZ4FDwKq2OmkBnZvMw5fh3CRJPZo/dReWAHvaqp9fAPZV1XeTvATsTfIN4Hngkdb/EeCPkowCZ+j8505VHU2yD3gJOAdsr6rzAEnuBg4A84DdVXX0sp2hJGlKU4ZBVR0BPjVB/Rid+wfj6/8I/MYkx9oJ7Jygvh/Y38N8JUlXgN9AliQZBpIkw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEr09A3l5kqeSvJTkaJKvtPrvJjmZ5HB73d415t4ko0leSbK+qz7UaqNJdnTVVyZ5ptW/056FLEnqk14+GZwDfruqVgNrge1JVre2B6pqTXvtB2htm4GbgCHgW0nmtWcoPwRsAFYDd3Qd5/52rBuBs8DWy3R+kqQeTBkGVXWqqv6qbf8D8DKw9CJDNgJ7q+q9qnoNGKXzrOSbgdGqOlZV7wN7gY1JAtwKPN7G7wE2TfN8JEnTcEn3DJKsAD4FPNNKdyc5kmR3kkWtthQ43jXsRKtNVv848HZVnRtXlyT1Sc9hkOSjwJ8AX62qd4GHgV8B1gCngN+/EhMcN4dtSUaSjIyNjV3pt5OkOaOnMEhyDZ0g+OOq+lOAqnqzqs5X1U+BP6RzGQjgJLC8a/iyVpus/hawMMn8cfWfU1W7qmqwqgYHBgZ6mbokqQe9rCYK8AjwclX9QVd9SVe3XwdebNvDwOYk1yZZCawCngUOAavayqEFdG4yD1dVAU8BX2jjtwBPzOy0JEmXYv7UXfgs8JvAC0kOt9rv0FkNtAYo4HXgtwCq6miSfcBLdFYiba+q8wBJ7gYOAPOA3VV1tB3vHmBvkm8Az9MJH0lSn0wZBlX1QyATNO2/yJidwM4J6vsnGldVx/jZZSZJUp/5DWRJkmEgSTIMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCTR2zOQlyd5KslLSY4m+UqrX5/kYJJX289FrZ4kDyYZTXIkyae7jrWl9X81yZau+meSvNDGPNieuyxJ6pN
ePhmcA367qlYDa4HtSVYDO4Anq2oV8GTbB9gArGqvbcDD0AkP4D7gFjqPuLzvQoC0Pl/qGjc081OTJPVqyjCoqlNV9Vdt+x+Al4GlwEZgT+u2B9jUtjcCj1XH08DCJEuA9cDBqjpTVWeBg8BQa/tYVT1dVQU81nUsSVIfXNI9gyQrgE8BzwCLq+pUa/oRsLhtLwWOdw070WoXq5+YoC5J6pOewyDJR4E/Ab5aVe92t7Xf6Osyz22iOWxLMpJkZGxs7Eq/nSTNGT2FQZJr6ATBH1fVn7bym+0SD+3n6VY/CSzvGr6s1S5WXzZB/edU1a6qGqyqwYGBgV6mLknqQS+riQI8ArxcVX/Q1TQMXFgRtAV4oqt+Z1tVtBZ4p11OOgCsS7Ko3TheBxxobe8mWdve686uY0mS+mB+D30+C/wm8EKSw632O8A3gX1JtgJvAF9sbfuB24FR4CfAXQBVdSbJ14FDrd/XqupM2/4y8ChwHfC99pIk9cmUYVBVPwQmW/d/2wT9C9g+ybF2A7snqI8An5xqLpKkK8NvIEuSDANJkmEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJInenoG8O8npJC921X43yckkh9vr9q62e5OMJnklyfqu+lCrjSbZ0VVfmeSZVv9OkgWX8wQlSVPr5ZPBo8DQBPUHqmpNe+0HSLIa2Azc1MZ8K8m8JPOAh4ANwGrgjtYX4P52rBuBs8DWmZyQJOnSTRkGVfWXwJmp+jUbgb1V9V5VvQaMAje312hVHauq94G9wMYkAW4FHm/j9wCbLu0UJEkzNZN7BncnOdIuIy1qtaXA8a4+J1ptsvrHgber6ty4uiSpj6YbBg8DvwKsAU4Bv3+5JnQxSbYlGUkyMjY21o+3lKQ5YVphUFVvVtX5qvop8Id0LgMBnASWd3Vd1mqT1d8CFiaZP64+2fvuqqrBqhocGBiYztQlSROYVhgkWdK1++vAhZVGw8DmJNcmWQmsAp4FDgGr2sqhBXRuMg9XVQFPAV9o47cAT0xnTpKk6Zs/VYck3wY+B9yQ5ARwH/C5JGuAAl4Hfgugqo4m2Qe8BJwDtlfV+Xacu4EDwDxgd1UdbW9xD7A3yTeA54FHLtfJSZJ6M2UYVNUdE5Qn/Q+7qnYCOyeo7wf2T1A/xs8uM0mSZoHfQJYkGQaSJMNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRI9hEGS3UlOJ3mxq3Z9koNJXm0/F7V6kjyYZDTJkSSf7hqzpfV/NcmWrvpnkrzQxjyYJJf7JCVJF9fLJ4NHgaFxtR3Ak1W1Cniy7QNsAFa11zbgYeiEB51nJ99C5xGX910IkNbnS13jxr+XJOkKmzIMquovgTPjyhuBPW17D7Cpq/5YdTwNLEyyBFgPHKyqM1V1FjgIDLW2j1XV01VVwGNdx5Ik9cl07xksrqpTbftHwOK2vRQ43tXvRKtdrH5igrokqY9mfAO5/UZfl2EuU0qyLclIkpGxsbF+vKUkzQnTDYM32yUe2s/TrX4SWN7Vb1mrXay+bIL6hKpqV1UNVtXgwMDANKcuSRpvumEwDFxYEbQFeKKrfmdbVbQWeKddTjoArEuyqN04XgccaG3vJlnbVhHd2XUsSVKfzJ+qQ5JvA58Dbkhygs6qoG8C+5JsBd4Avti67wduB0aBnwB3AVTVmSRfBw61fl+rqgs3pb9MZ8XSdcD32kuS1EdThkFV3TFJ020T9C1g+yTH2Q3snqA+AnxyqnlIkq4cv4EsSTIMJEmGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEjMMgySvJ3khyeEkI612fZKDSV5tPxe1epI8mGQ0yZEkn+46zpbW/9UkW2Z2SpKkS3U
5Phn8u6paU1WDbX8H8GRVrQKebPsAG4BV7bUNeBg64QHcB9wC3AzcdyFAJEn9cSUuE20E9rTtPcCmrvpj1fE0sDDJEmA9cLCqzlTVWeAgMHQF5iVJmsRMw6CA/5XkuSTbWm1xVZ1q2z8CFrftpcDxrrEnWm2y+s9Jsi3JSJKRsbGxGU5dknTB/BmO/7dVdTLJPwcOJvmb7saqqiQ1w/foPt4uYBfA4ODgZTuuJM11M/pkUFUn28/TwJ/Rueb/Zrv8Q/t5unU/CSzvGr6s1SarS5L6ZNphkOQXk/yzC9vAOuBFYBi4sCJoC/BE2x4G7myritYC77TLSQeAdUkWtRvH61pNktQnM7lMtBj4syQXjvM/q+ovkhwC9iXZCrwBfLH13w/cDowCPwHuAqiqM0m+Dhxq/b5WVWdmMC9J0iWadhhU1THgVyeovwXcNkG9gO2THGs3sHu6c5EkzYzfQJYkGQaSJMNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJDHzJ51J+oBasePPZ3sKHyqvf/Pzsz2FGfGTgSTJMJAkGQaSJK6iMEgylOSVJKNJdsz2fCRpLrkqwiDJPOAhYAOwGrgjyerZnZUkzR1XRRgANwOjVXWsqt4H9gIbZ3lOkjRnXC1LS5cCx7v2TwC3jO+UZBuwre3+OMkrfZjbXHAD8PezPYmp5P7ZnoFmif8+L69/OVHxagmDnlTVLmDXbM/jwybJSFUNzvY8pIn477M/rpbLRCeB5V37y1pNktQHV0sYHAJWJVmZZAGwGRie5TlJ0pxxVVwmqqpzSe4GDgDzgN1VdXSWpzWXeOlNVzP/ffZBqmq25yBJmmVXy2UiSdIsMgwkSYaBJOkquYGs/kryCTrf8F7aSieB4ap6efZmJWk2+clgjklyD50/9xHg2fYK8G3/QKCuZknumu05fJi5mmiOSfJ/gJuq6p/G1RcAR6tq1ezMTLq4JH9XVb882/P4sPIy0dzzU+BfAG+Mqy9pbdKsSXJksiZgcT/nMtcYBnPPV4Enk7zKz/444C8DNwJ3z9akpGYxsB44O64e4H/3fzpzh2Ewx1TVXyT5V3T+bHj3DeRDVXV+9mYmAfBd4KNVdXh8Q5If9H02c4j3DCRJriaSJBkGkiQMA0kShoEkCcNAkgT8Xwar25swFNVjAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "dataset.df['HIV_active'].value_counts().plot(kind='bar')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are only ~3% compounds classified as HIV inhibtiors. This means that the dataset is heavily imbalaced." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore Molecular Properties With RDKit\n", "\n", "[RDKit](https://www.rdkit.org/docs/cppapi/index.html) is an open-source cheminformatics toolkit. It includes a collection of standard cheminformatics functions for molecule I/O, substructure searching, chemical reactions, 2D and 3D coordinate generation, fingerprinting, etc.\n", "\n", "We are going to use this library to explore the molecules in the dataset. First, let's randomly select and visualize several molecules." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_molecule_start_index = 10\n", "\n", "sample_smiles = dataset.df['smiles'][random_molecule_start_index:random_molecule_start_index + 8].values\n", "sample_molecules = [Chem.MolFromSmiles(smile) for smile in sample_smiles]\n", "Draw.MolsToGridImage(sample_molecules, molsPerRow=4, subImgSize=(600, 600))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's look at a single molecule and explore its properties." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'CC(C)(CCC(=O)O)CCC(=O)O'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample_smiles[0]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mol = sample_molecules[0]\n", "mol" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the [RDKit::Atom](https://www.rdkit.org/docs/cppapi/classRDKit_1_1Atom.html) class to further explore the features of the molecules." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of atoms in the molecule : 13\n" ] } ], "source": [ "atoms = mol.GetAtoms()\n", "print(\"Total number of atoms in the molecule : {}\".format(len(atoms)))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 4,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 2,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 2,\n", " 'formal_charge': 
0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 3,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'O',\n", " 'atomic_numbers': 8,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'O',\n", " 'atomic_numbers': 8,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 2,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 2,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'C',\n", " 'atomic_numbers': 6,\n", " 'degree': 3,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'O',\n", " 'atomic_numbers': 8,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False},\n", " {'atomic_symbol': 'O',\n", " 'atomic_numbers': 8,\n", " 'degree': 1,\n", " 'formal_charge': 0,\n", " 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2,\n", " 'is_aromatic': False}]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "molecule_features = []\n", "for atom in atoms:\n", " atom_features = {}\n", " atom_features['atomic_symbol'] = atom.GetSymbol()\n", " atom_features['atomic_numbers'] = atom.GetAtomicNum()\n", " atom_features['degree'] = atom.GetDegree()\n", " 
atom_features['formal_charge'] = atom.GetFormalCharge()\n", " atom_features['hybridization'] = atom.GetHybridization()\n", " atom_features['is_aromatic'] = atom.GetIsAromatic()\n", " molecule_features.append(atom_features)\n", "molecule_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we featurize the atoms of our molecule as nodes." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'feat': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", " 1., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", " 1., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 
0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,\n", " 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,\n", " 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,\n", " 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 
0.,\n", " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,\n", " 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,\n", " 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0.]])}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')\n", "atom_featurizer(mol)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "74" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom_featurizer.feat_size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already have these embedded in our graph since we used the [node featurizer](https://lifesci.dgl.ai/generated/dgllife.utils.CanonicalAtomFeaturizer.html) earlier in this notebook. \n", "\n", "Let's decode the graph associated to the above index." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "smiles, graphs, labels, masks = map(list, zip(*dataset))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'CC(C)(CCC(=O)O)CCC(=O)O'" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "smiles[random_molecule_start_index]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_graph = graphs[random_molecule_start_index]\n", "\n", "random_graph.num_nodes()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "37" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_graph.num_edges()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([13, 74])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_graph.ndata['feat'].shape" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", " 1., 0.])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_graph.ndata['feat'][0, :]" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0.])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_graph.ndata['feat'][1, :]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the order of the atoms in the graph is different from the order in the featurizer output above. This will not be a problem for our analysis." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `dgl-lifesci` package provides methods to split data into training, validation, and test sets based on [several strategies](https://lifesci.dgl.ai/api/utils.splitters.html).\n", "\n", "We will use the [`ScaffoldSplitter`](https://lifesci.dgl.ai/api/utils.splitters.html#dgllife.utils.ScaffoldSplitter) for this project. This method groups molecules based on their scaffolds and sorts the groups by size. The splitter supports both k-fold cross-validation and a single train/validation/test split; we use the latter in the next cell.\n", "\n", "In either mode, all molecules that share a scaffold are placed together in the same subset. Scaffold splitting, rather than random splitting, is commonly used in chemoinformatics so that the validation and test sets contain scaffolds the model has not seen during training, which gives a more realistic estimate of how well the model generalizes to new molecules."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dgllife.utils import ScaffoldSplitter, RandomSplitter\n", "\n", "train_ratio, val_ratio, test_ratio = map(float, split_ratio.split(':'))\n", "\n", "train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(\n", " dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,\n", " scaffold_func='smiles')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the class distribution of the training, validation, and test sets after the split." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "28788" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_set)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD1CAYAAACyaJl6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAANyUlEQVR4nO3cX4id9Z3H8fdnk1rKusXYzIZsEjdSZ1liYVMbYqB74VbIH3sRC0X0ogaRptAEKvTCtDcpWkEv2oJghRQHI3RNpX8wtGmzIVhKWdSMbYhG182Q6iYhmqlJtYtQN/a7F/PL9jB7JjOZSc4ZM+8XHOac7/M85/wOBN+e5zwzqSokSXPbX/V7AZKk/jMGkiRjIEkyBpIkjIEkCWMgSQLm93sB07Vw4cJavnx5v5chSR8oL7zwwu+ramD8/AMbg+XLlzM8PNzvZUjSB0qS17vNPU0kSTIGkiRjIEnCGEiSMAaSJIyBJAljIEnCGEiS+AD/0tkHxfJtP+v3Ei4brz342X4vQbps+clAkmQMJEnGQJKEMZAkYQwkSRgDSRLGQJKEMZAkYQwkSRgDSRLGQJKEMZAkYQwkSRgDSRLGQJKEMZAkMYUYJFmW5JkkLyc5nOQrbf6NJCeSHGy3WzqO+VqSkSSvJlnXMV/fZiNJtnXMr03yXJv/IMkVF/uNSpImNpVPBmeBr1bVCmANsCXJirbtO1W1st32ALRttwPXA+uB7yaZl2Qe8AiwAVgB3NHxPA+157oOOAPcfZHenyRpCiaNQVWdrKrftPt/BF4BlpznkI3Arqr6U1X9DhgBVrfbSFUdrar3gF3AxiQBPgP8sB2/E7h1mu9HkjQNF/SdQZLlwCeB59poa5JDSYaSLGizJcCxjsOOt9lE848Bf6iqs+Pm3V5/c5LhJMOjo6MXsnRJ0nlMOQZJrgR+BNxTVe8AjwIfB1YCJ4FvXYoFdqqqHVW1qqpWDQwMXOqXk6Q5Y/5UdkryIcZC8P2q+jFAVb3Zsf17wE/bwxPAso7Dl7YZE8zfAq5KMr99OujcX5LUA1O5mijAY8ArVfXtjvnijt0+B7zU7u8Gbk/y4STXAoPA88ABYLBdOXQFY18y766qAp4BPt+O3wQ8PbO3JUm6EFP5ZPBp4AvAi0kOttnXGbsaaCVQwGvAlwCq6nCSp4CXGbsSaUtVvQ+QZCuwF5gHDFXV4fZ89wK7knwT+C1j8ZEk9cikMaiqXwPpsmnPeY55AHigy3xPt+Oq6ihjVxtJkvrA30CWJBkDSZIxkCRhDCRJGANJEsZAkoQxkCRhDCRJGANJEsZAkoQxkCRhDCRJGANJEsZAkoQxkCRhDCRJGANJEsZAkoQxkCRhDCRJGANJEsZAkoQxkCRhDCRJGANJEsZAkoQxkCRhDCRJTCEGSZYleSbJy0kOJ/lKm1+dZF+SI+3ngjZPkoeTjCQ5lOSGjufa1PY/kmRTx/xTSV5sxzycJJfizUqSupvKJ4OzwFeragWwBtiSZAWwDdhfVYPA/vYYYAMw2G6bgUdhLB7AduBGYDWw/VxA2j5f7Dhu/czfmiRpqiaNQVWdrKrftPt/BF4BlgAbgZ1tt53Are3+RuCJGvMscFWSxcA6YF9Vna6qM8A+YH3b9tGqeraqCnii47kkST1wQd8ZJFkOfBJ4DlhUVSfbpjeARe3+EuBYx2HH2+x88+Nd5pKkHplyDJJcCfwIuKeq3unc1v6Pvi7y2rqtYXOS4STDo6Ojl/rlJGnOmFIMknyIsRB8v6p+3MZvtlM8tJ+n2vwEsKzj8KVtdr750i7z/6eqdlTVqqpaNTAwMJWlS5KmYCpXEwV4DHilqr7dsWk3cO6KoE3A0x3zO9tVRWuAt9vppL3A2iQL2hfHa4G9bds7Sda017qz47kkST0wfwr7fBr4AvBikoNt9nXgQeCpJHcDrwO3tW17gFuAEeBd4C6Aqjqd5H7gQNvvvqo63e5/GXgc+Ajw83aTJPXIpDGoql8DE133f3O
X/QvYMsFzDQFDXebDwCcmW4sk6dLwN5AlScZAkmQMJEkYA0kSxkCShDGQJGEMJEkYA0kSxkCShDGQJGEMJEkYA0kSxkCShDGQJGEMJEkYA0kSxkCShDGQJGEMJEkYA0kSxkCShDGQJGEMJEkYA0kSxkCShDGQJGEMJEkYA0kSU4hBkqEkp5K81DH7RpITSQ622y0d276WZCTJq0nWdczXt9lIkm0d82uTPNfmP0hyxcV8g5KkyU3lk8HjwPou8+9U1cp22wOQZAVwO3B9O+a7SeYlmQc8AmwAVgB3tH0BHmrPdR1wBrh7Jm9IknThJo1BVf0KOD3F59sI7KqqP1XV74ARYHW7jVTV0ap6D9gFbEwS4DPAD9vxO4FbL+wtSJJmaibfGWxNcqidRlrQZkuAYx37HG+zieYfA/5QVWfHzSVJPTTdGDwKfBxYCZwEvnWxFnQ+STYnGU4yPDo62ouXlKQ5YVoxqKo3q+r9qvoz8D3GTgMBnACWdey6tM0mmr8FXJVk/rj5RK+7o6pWVdWqgYGB6SxdktTFtGKQZHHHw88B56402g3cnuTDSa4FBoHngQPAYLty6ArGvmTeXVUFPAN8vh2/CXh6OmuSJE3f/Ml2SPIkcBOwMMlxYDtwU5KVQAGvAV8CqKrDSZ4CXgbOAluq6v32PFuBvcA8YKiqDreXuBfYleSbwG+Bxy7Wm5MkTc2kMaiqO7qMJ/wPdlU9ADzQZb4H2NNlfpS/nGaSJPWBv4EsSTIGkiRjIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJIyBJAljIEnCGEiSMAaSJKYQgyRDSU4lealjdnWSfUmOtJ8L2jxJHk4ykuRQkhs6jtnU9j+SZFPH/FNJXmzHPJwkF/tNSpLObyqfDB4H1o+bbQP2V9UgsL89BtgADLbbZuBRGIsHsB24EVgNbD8XkLbPFzuOG/9akqRLbNIYVNWvgNPjxhuBne3+TuDWjvkTNeZZ4Koki4F1wL6qOl1VZ4B9wPq27aNV9WxVFfBEx3NJknpkut8ZLKqqk+3+G8Cidn8JcKxjv+Ntdr758S5zSVIPzfgL5PZ/9HUR1jKpJJuTDCcZHh0d7cVLStKcMN0YvNlO8dB+nmrzE8Cyjv2Wttn55ku7zLuqqh1VtaqqVg0MDExz6ZKk8aYbg93AuSuCNgFPd8zvbFcVrQHebqeT9gJrkyxoXxyvBfa2be8kWdOuIrqz47kkST0yf7IdkjwJ3AQsTHKcsauCHgSeSnI38DpwW9t9D3ALMAK8C9wFUFWnk9wPHGj73VdV576U/jJjVyx9BPh5u0mSemjSGFTVHRNsurnLvgVsmeB5hoChLvNh4BOTrUOSdOn4G8iSJGMgSTIGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSSJGcYgyWtJXkxyMMlwm12dZF+SI+3ngjZPkoeTjCQ5lOSGjufZ1PY/kmTTzN6SJOlCXYxPBv9SVSuralV7vA3YX1WDwP72GGADMNhum4FHYSwewHbgRmA1sP1cQCRJvXEpThNtBHa2+zuBWzvmT9SYZ4GrkiwG1gH7qup0VZ0B9gHrL8G6JEkTmGkMCvi3JC8k2dxmi6rqZLv/BrCo3V8CHOs49nibTTSXJPXI/Bke/89VdSLJ3wL7kvxH58aqqiQ1w9f4Py04mwGuueaai/W0kjTnzeiTQVWdaD9PAT9h7Jz/m+30D+3nqbb7CWBZx+FL22yiebfX21FVq6pq1cDAwEyWLknqMO0YJPnrJH9z7j6wFngJ2A2cuyJoE/B0u78buLNdVbQGeLudTtoLrE2yoH1xvLbNJEk9MpP
TRIuAnyQ59zz/WlW/SHIAeCrJ3cDrwG1t/z3ALcAI8C5wF0BVnU5yP3Cg7XdfVZ2ewbokSRdo2jGoqqPAP3WZvwXc3GVewJYJnmsIGJruWiRJM+NvIEuSjIEkyRhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiSMgSQJYyBJwhhIkjAGkiRgfr8XIKk/lm/7Wb+XcFl57cHP9nsJM+InA0mSMZAkGQNJEsZAksQsikGS9UleTTKSZFu/1yNJc8msiEGSecAjwAZgBXBHkhX9XZUkzR2zIgbAamCkqo5W1XvALmBjn9ckSXPGbPk9gyXAsY7Hx4Ebx++UZDOwuT387ySv9mBtc8FC4Pf9XsRk8lC/V6A+8d/nxfX33YazJQZTUlU7gB39XsflJslwVa3q9zqkbvz32Ruz5TTRCWBZx+OlbSZJ6oHZEoMDwGCSa5NcAdwO7O7zmiRpzpgVp4mq6mySrcBeYB4wVFWH+7ysucRTb5rN/PfZA6mqfq9BktRns+U0kSSpj4yBJMkYSJJmyRfI6q0k/8jYb3gvaaMTwO6qeqV/q5LUT34ymGOS3MvYn/sI8Hy7BXjSPxCo2SzJXf1ew+XMq4nmmCT/CVxfVf8zbn4FcLiqBvuzMun8kvxXVV3T73VcrjxNNPf8Gfg74PVx88Vtm9Q3SQ5NtAlY1Mu1zDXGYO65B9if5Ah/+eOA1wDXAVv7tSipWQSsA86Mmwf4994vZ+4wBnNMVf0iyT8w9mfDO79APlBV7/dvZRIAPwWurKqD4zck+WXPVzOH+J2BJMmriSRJxkCShDGQJGEMJEkYA0kS8L9upLbe/DmXoQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_set.dataset.df.loc[train_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "8226" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(val_set)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD1CAYAAAC87SVQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR+ElEQVR4nO3df4yd1X3n8fenuKRtuopNmFrUNmtLcRuRSqHsCKiyWu3GW9uQquaPBBGtlhGy5P3D3W1WK23J/mMtBIlIq2WLtEGyindN1A1xaSOsBIWOnERVVfFjCJQGKOspCbEtwFPG0B9s0pp+9497nNy4M547eHyH+Lxf0uie53vO89zzSKPPvTr3ufdJVSFJ6sNPrPYEJEnjY+hLUkcMfUnqiKEvSR0x9CWpI4a+JHVkzWpP4Fwuv/zy2rx582pPQ5J+rDz11FN/WVUTC/W9q0N/8+bNzMzMrPY0JOnHSpKXF+tzeUeSOmLoS1JHDH1J6oihL0kdMfQlqSMjhX6S/5jkuSTfSvKFJD+VZEuSx5PMJvlikkvb2Pe07dnWv3noOJ9u9ReT7LhA5yRJWsSSoZ9kA/AfgMmq+iXgEuAW4LPAPVX1AeAUsLvtshs41er3tHEkuart9yFgJ/C5JJes7OlIks5l1OWdNcBPJ1kD/AzwCvBR4KHWfxC4qbV3tW1a/7YkafUHq+r7VfVtYBa49rzPQJI0siW/nFVVJ5L8N+C7wP8D/hB4Cnijqk63YceBDa29ATjW9j2d5E3g/a3+2NChh/f5sbb59q+s9hQuKt+5+2OrPQXpojXK8s46Bu/StwA/D7yXwfLMBZFkT5KZJDNzc3MX6mkkqUujLO/8a+DbVTVXVX8P/AHwEWBtW+4B2AicaO0TwCaA1v8+4PXh+gL7/EBV7a+qyaqanJhY8KcjJEnv0Cih/13g+iQ/09bmtwHPA18HPt7GTAEPt/bhtk3r/1oNbsR7GLilXd2zBdgKPLEypyFJGsUoa/qPJ3kI+CZwGnga2A98BXgwyWda7f62y/3A55PMAvMMrtihqp5LcojBC8ZpYG9Vvb3C5yNJOoeRfmWzqvYB+84qv8QCV99U1feATyxynLuAu5Y5R0nSCvEbuZLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRJUM/yS8meWbo76+SfCrJZUmmkxxtj+va+CS5N8lskmeTXDN0rKk2/miSqcWfVZJ0ISwZ+lX1YlVdXVVXA/8MeAv4EnA7cKSqtgJH2jbADQxuer4V2APcB5DkMga3XLyOwW0W9515oZAkjcdyl3e2
AX9RVS8Du4CDrX4QuKm1dwEP1MBjwNokVwA7gOmqmq+qU8A0sPN8T0CSNLrlhv4twBdae31VvdLarwLrW3sDcGxon+Ottlj9RyTZk2Qmyczc3NwypydJOpeRQz/JpcCvA793dl9VFVArMaGq2l9Vk1U1OTExsRKHlCQ1y3mnfwPwzap6rW2/1pZtaI8nW/0EsGlov42ttlhdkjQmywn9T/LDpR2Aw8CZK3CmgIeH6re2q3iuB95sy0CPAtuTrGsf4G5vNUnSmKwZZVCS9wK/Cvy7ofLdwKEku4GXgZtb/RHgRmCWwZU+twFU1XySO4En27g7qmr+vM9AkjSykUK/qv4WeP9ZtdcZXM1z9tgC9i5ynAPAgeVPU5K0EvxGriR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpIyOFfpK1SR5K8udJXkjyK0kuSzKd5Gh7XNfGJsm9SWaTPJvkmqHjTLXxR5NMLf6MkqQLYdR3+r8NfLWqPgh8GHgBuB04UlVbgSNtGwY3UN/a/vYA9wEkuQzYB1wHXAvsO/NCIUkajyVDP8n7gH8B3A9QVX9XVW8Au4CDbdhB4KbW3gU8UAOPAWuTXAHsAKarar6qTgHTwM4VPBdJ0hJGeae/BZgD/leSp5P8TrtR+vqqeqWNeRVY39obgGND+x9vtcXqkqQxGSX01wDXAPdV1S8Df8sPl3KAH9wMvVZiQkn2JJlJMjM3N7cSh5QkNaOE/nHgeFU93rYfYvAi8FpbtqE9nmz9J4BNQ/tvbLXF6j+iqvZX1WRVTU5MTCznXCRJS1gy9KvqVeBYkl9spW3A88Bh4MwVOFPAw619GLi1XcVzPfBmWwZ6FNieZF37AHd7q0mSxmTNiOP+PfC7SS4FXgJuY/CCcSjJbuBl4OY29hHgRmAWeKuNparmk9wJPNnG3VFV8ytyFpKkkYwU+lX1DDC5QNe2BcYWsHeR4xwADixjfpKkFeQ3ciWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0JakjI4V+ku8k+bMkzySZabXLkkwnOdoe17V6ktybZDbJs0muGTrOVBt/NMnUYs8nSbowlvNO/19V1dVVdea2ibcDR6pqK3CkbQPcAGxtf3uA+2DwIgHsA64DrgX2nXmhkCSNx/ks7+wCDrb2QeCmofoDNfAYsDbJFcAOYLqq5qvqFDAN7DyP55ckLdOooV/AHyZ5KsmeVltfVa+09qvA+tbeABwb2vd4qy1WlySNyZoRx/3zqjqR5OeA6SR/PtxZVZWkVmJC7UVlD8CVV165EoeUJDUjvdOvqhPt8STwJQZr8q+1ZRva48k2/ASwaWj3ja22WP3s59pfVZNVNTkxMbG8s5EkndOSoZ/kvUn+yZk2sB34FnAYOHMFzhTwcGsfBm5tV/FcD7zZloEeBbYnWdc+wN3eapKkMRlleWc98KUkZ8b/n6r6apIngUNJdgMvAze38Y8ANwKzwFvAbQBVNZ/kTuDJNu6OqppfsTORJC1pydCvqpeADy9Qfx3YtkC9gL2LHOsAcGD505QkrQS/kStJHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdGTn0k1yS5OkkX27bW5I8nmQ2yReTXNrq72nbs61/89AxPt3qLybZseJnI0k6p+W80/9N4IWh7c8C91TVB4BTwO5W3w2cavV72jiSXAXcAnwI2Al8Lskl5zd9SdJyjBT6STYCHwN+p20H+CjwUBtyELiptXe1bVr/tjZ+F/BgVX2/qr7N4Mbp167AOUiSRjTqO/3/Afxn4B/a9vuBN6rqdNs+Dmxo7Q3A
MYDW/2Yb/4P6AvtIksZgydBP8mvAyap6agzzIcmeJDNJZubm5sbxlJLUjVHe6X8E+PUk3wEeZLCs89vA2iRr2piNwInWPgFsAmj97wNeH64vsM8PVNX+qpqsqsmJiYlln5AkaXFLhn5VfbqqNlbVZgYfxH6tqv4N8HXg423YFPBwax9u27T+r1VVtfot7eqeLcBW4IkVOxNJ0pLWLD1kUb8FPJjkM8DTwP2tfj/w+SSzwDyDFwqq6rkkh4DngdPA3qp6+zyeX5K0TMsK/ar6BvCN1n6JBa6+qarvAZ9YZP+7gLuWO0lJ0srwG7mS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUkSVDP8lPJXkiyZ8meS7Jf231LUkeTzKb5ItJLm3197Tt2da/eehYn271F5PsuGBnJUla0Cjv9L8PfLSqPgxcDexMcj3wWeCeqvoAcArY3cbvBk61+j1tHEmuYnC/3A8BO4HPJblkBc9FkrSEJUO/Bv6mbf5k+yvgo8BDrX4QuKm1d7VtWv+2JGn1B6vq+1X1bWCWBe6xK0m6cEZa009ySZJngJPANPAXwBtVdboNOQ5saO0NwDGA1v8m8P7h+gL7SJLGYKTQr6q3q+pqYCODd+cfvFATSrInyUySmbm5uQv1NJLUpWVdvVNVbwBfB34FWJtkTevaCJxo7RPAJoDW/z7g9eH6AvsMP8f+qpqsqsmJiYnlTE+StIRRrt6ZSLK2tX8a+FXgBQbh//E2bAp4uLUPt21a/9eqqlr9lnZ1zxZgK/DECp2HJGkEa5YewhXAwXalzU8Ah6rqy0meBx5M8hngaeD+Nv5+4PNJZoF5BlfsUFXPJTkEPA+cBvZW1dsrezqSpHNZMvSr6lnglxeov8QCV99U1feATyxyrLuAu5Y/TUnSSvAbuZLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRUe6RuynJ15M8n+S5JL/Z6pclmU5ytD2ua/UkuTfJbJJnk1wzdKypNv5okqnFnlOSdGGM8k7/NPCfquoq4Hpgb5KrgNuBI1W1FTjStgFuYHDT863AHuA+GLxIAPuA6xjcZnHfmRcKSdJ4LBn6VfVKVX2ztf8aeAHYAOwCDrZhB4GbWnsX8EANPAasTXIFsAOYrqr5qjoFTAM7V/JkJEnntqw1/SSbGdwk/XFgfVW90rpeBda39gbg2NBux1ttsbokaUxGDv0kPwv8PvCpqvqr4b6qKqBWYkJJ9iSZSTIzNze3EoeUJDUjhX6Sn2QQ+L9bVX/Qyq+1ZRva48lWPwFsGtp9Y6stVv8RVbW/qiaranJiYmI55yJJWsIoV+8EuB94oar++1DXYeDMFThTwMND9VvbVTzXA2+2ZaBHge1J1rUPcLe3miRpTNaMMOYjwL8F/izJM632X4C7gUNJdgMvAze3vkeAG4FZ4C3gNoCqmk9yJ/BkG3dHVc2vxElIkkazZOhX1R8DWaR72wLjC9i7yLEOAAeWM0FJ0srxG7mS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUkVHukXsgyckk3xqqXZZkOsnR9riu1ZPk3iSzSZ5Ncs3QPlNt/NEkUws9lyTpwhrlnf7/BnaeVbsdOFJVW4EjbRvgBmBr+9sD3AeDFwlgH3AdcC2w78wLhSRpfJYM/ar6I+DsG5jvAg629kHgpqH6AzXwGLA2yRXADmC6quar6hQwzT9+IZEkXWDvdE1/fVW90tqvAutbewNwbGjc8VZbrC5JGqPz/iC3qgqoFZgLAEn2JJlJMjM3N7dSh5Uk8c5D/7W2bEN7
PNnqJ4BNQ+M2ttpi9X+kqvZX1WRVTU5MTLzD6UmSFvJOQ/8wcOYKnCng4aH6re0qnuuBN9sy0KPA9iTr2ge421tNkjRGa5YakOQLwL8ELk9ynMFVOHcDh5LsBl4Gbm7DHwFuBGaBt4DbAKpqPsmdwJNt3B1VdfaHw5KkC2zJ0K+qTy7StW2BsQXsXeQ4B4ADy5qdJGlF+Y1cSeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOrLk7+lL+vG2+favrPYULhrfuftjqz2F8+Y7fUnqyNhDP8nOJC8mmU1y+7ifX5J6NtbQT3IJ8D+BG4CrgE8muWqcc5Ckno37nf61wGxVvVRVfwc8COwa8xwkqVvj/iB3A3BsaPs4cN3wgCR7gD1t82+SvDimufXgcuAvV3sSS8lnV3sGWgX+b66sf7pYx7vu6p2q2g/sX+15XIySzFTV5GrPQzqb/5vjM+7lnRPApqHtja0mSRqDcYf+k8DWJFuSXArcAhwe8xwkqVtjXd6pqtNJfgN4FLgEOFBVz41zDp1z2UzvVv5vjkmqarXnIEkaE7+RK0kdMfQlqSOGviR15F13nb5WTpIPMvjG84ZWOgEcrqoXVm9WklaT7/QvUkl+i8HPXAR4ov0F+II/dKd3syS3rfYcLmZevXORSvJ/gQ9V1d+fVb8UeK6qtq7OzKRzS/LdqrpytedxsXJ55+L1D8DPAy+fVb+i9UmrJsmzi3UB68c5l94Y+hevTwFHkhzlhz9ydyXwAeA3VmtSUrMe2AGcOqse4E/GP51+GPoXqar6apJfYPBz1sMf5D5ZVW+v3swkAL4M/GxVPXN2R5JvjH02HXFNX5I64tU7ktQRQ1+SOmLoS1JHDH1J6oihL0kd+f9oZv3abb7VHgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "val_set.dataset.df.loc[val_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4113" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(test_set)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD1CAYAAAC87SVQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAARf0lEQVR4nO3df4yd113n8fcH50fRliUOmbWMbdYWNaocJNxq1smq/NFNROKkaB0kqBIhakWRDJIjtRLabcI/gRZLrQRkt1IbyWy8dRFbYxVQrOAlmDQVqlATT6hx44RsZvNjbcuNhzoNVBXZdfjuH/cYLu6M5459fSf1eb+kq3me7znPc88jWZ95fO6586SqkCT14QeWewCSpMkx9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOnLVcg/gQm644YZav379cg9Dkr6vPPvss39bVVPztb2jQ3/9+vXMzMws9zAk6ftKktcWanN6R5I6YuhLUkcMfUnqiKEvSR0ZOfSTrEjy9SSPt/0NSZ5OMpvkD5Jc0+rXtv3Z1r5+6BwPtvqLSW4f+9VIki5oKXf6HwVeGNr/NPBwVb0HeAO4r9XvA95o9YdbP5JsAu4GbgS2Ap9LsuLShi9JWoqRQj/JWuBDwH9r+wFuAb7UuuwF7mrb29o+rf3W1n8bsK+q3qqqV4BZYMsYrkGSNKJR7/T/C/CfgX9s+z8CfLuqzrb9E8Catr0GOA7Q2t9s/f+pPs8x/yTJjiQzSWbm5uZGvxJJ0qIW/XJWkp8FTlfVs0k+eLkHVFW7gd0A09PT3xdPeFn/wJ8s9xCuKK9+6kPLPQTpijXKN3I/APzHJHcC7wL+NfBfgeuSXNXu5tcCJ1v/k8A64ESSq4AfBr41VD9n+BhJ0gQsOr1TVQ9W1dqqWs/gg9gvV9UvAk8BP9+6bQcea9sH2j6t/cs1eCbjAeDutrpnA7AReGZsVyJJWtSl/O2djwP7kvwm8HXg0VZ/FPi9JLPAGQa/KKiqY0n2A88DZ4GdVfX2Jby/JGmJlhT6VfUV4Ctt+2XmWX1TVf8A/MICx+8Cdi11kJKk8fAbuZLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRRUM/ybuSPJPkr5McS/Ibrf75JK8kOdJem1s9ST6TZDbJ0STvHzrX9iQvtdf2Bd5SknSZjPK4xLeAW6rqO0muBr6a5H+2tv9UVV86r/8dDB56vhG4CXgEuCnJ9cBDwDRQwLNJDlTVG+O4EEnS4ha906+B77Tdq9urLnDINuAL7bivAdclWQ3cDhyqqjMt6A8BWy9t+JK
kpRhpTj/JiiRHgNMMgvvp1rSrTeE8nOTaVlsDHB86/ESrLVQ//712JJlJMjM3N7e0q5EkXdBIoV9Vb1fVZmAtsCXJTwIPAu8F/h1wPfDxcQyoqnZX1XRVTU9NTY3jlJKkZkmrd6rq28BTwNaqOtWmcN4C/juwpXU7CawbOmxtqy1UlyRNyCird6aSXNe2fxD4GeBv2jw9SQLcBTzXDjkAfKSt4rkZeLOqTgFPALclWZlkJXBbq0mSJmSU1Turgb1JVjD4JbG/qh5P8uUkU0CAI8CvtP4HgTuBWeC7wL0AVXUmySeBw63fJ6rqzNiuRJK0qEVDv6qOAu+bp37LAv0L2LlA2x5gzxLHKEkaE7+RK0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0Z5Rm570ryTJK/TnIsyW+0+oYkTyeZTfIHSa5p9Wvb/mxrXz90rgdb/cUkt1+2q5IkzWuUO/23gFuq6qeAzcDW9sDzTwMPV9V7gDeA+1r/+4A3Wv3h1o8km4C7gRuBrcDn2nN3JUkTsmjo18B32u7V7VXALcCXWn0vcFfb3tb2ae23Jkmr76uqt6rqFQYPTt8yjouQJI1mpDn9JCuSHAFOA4eA/w18u6rOti4ngDVtew1wHKC1vwn8yHB9nmOG32tHkpkkM3Nzc0u+IEnSwkYK/ap6u6o2A2sZ3J2/93INqKp2V9V0VU1PTU1drreRpC4tafVOVX0beAr498B1Sa5qTWuBk237JLAOoLX/MPCt4fo8x0iSJmCU1TtTSa5r2z8I/AzwAoPw//nWbTvwWNs+0PZp7V+uqmr1u9vqng3ARuCZMV2HJGkEVy3ehdXA3rbS5geA/VX1eJLngX1JfhP4OvBo6/8o8HtJZoEzDFbsUFXHkuwHngfOAjur6u3xXo4k6UIWDf2qOgq8b576y8yz+qaq/gH4hQXOtQvYtfRhSpLGwW/kSlJHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkdGeUbuuiRPJXk+ybEkH231X09yMsmR9rpz6JgHk8wmeTHJ7UP1ra02m+SBy3NJkqSFjPKM3LPAr1bVXyX5IeDZJIda28NV9VvDnZNsYvBc3BuBHwX+PMlPtObPMniw+gngcJIDVfX8OC5EkrS4UZ6Rewo41bb/PskLwJoLHLIN2FdVbwGvtAekn3uW7mx7ti5J9rW+hr4kTciS5vSTrGfwkPSnW+n+JEeT7EmystXWAMeHDjvRagvVz3+PHUlmkszMzc0tZXiSpEWMHPpJ3g38IfCxqvo74BHgx4HNDP4n8NvjGFBV7a6q6aqanpqaGscpJUnNKHP6JLmaQeD/flX9EUBVvT7U/rvA4233JLBu6PC1rcYF6pKkCRhl9U6AR4EXqup3huqrh7r9HPBc2z4A3J3k2iQbgI3AM8BhYGOSDUmuYfBh74HxXIYkaRSj3Ol/APgl4BtJjrTarwH3JNkMFPAq8MsAVXUsyX4GH9CeBXZW1dsASe4HngBWAHuq6tjYrkSStKhRVu98Fcg8TQcvcMwuYNc89YMXOk6SdHn5jVxJ6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqyCjPyF2X5Kkkzyc5luSjrX59kkNJXmo/V7Z6knwmyWySo0neP3Su7a3/S0m2X77LkiTNZ5Q7/bPAr1bVJuBmYGeSTcADwJNVtRF4su0D3MHgYegbgR3AIzD4JQE8BNwEbAEeOveLQpI0GYuGflWdqqq/att/D7wArAG2AXtbt73AXW17G/CFGvgacF2S1cDtwKGqOlNVbwCHgK3jvBhJ0oU
taU4/yXrgfcDTwKqqOtWavgmsattrgONDh51otYXq57/HjiQzSWbm5uaWMjxJ0iJGDv0k7wb+EPhYVf3dcFtVFVDjGFBV7a6q6aqanpqaGscpJUnNSKGf5GoGgf/7VfVHrfx6m7ah/Tzd6ieBdUOHr221heqSpAkZZfVOgEeBF6rqd4aaDgDnVuBsBx4bqn+kreK5GXizTQM9AdyWZGX7APe2VpMkTchVI/T5APBLwDeSHGm1XwM+BexPch/wGvDh1nYQuBOYBb4L3AtQVWeSfBI43Pp9oqrOjOMiJEmjWTT0q+qrQBZovnWe/gXsXOBce4A9SxmgJGl8/EauJHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdWSUZ+TuSXI6yXNDtV9PcjLJkfa6c6jtwSSzSV5McvtQfWurzSZ5YPyXIklazCh3+p8Hts5Tf7iqNrfXQYAkm4C7gRvbMZ9LsiLJCuCzwB3AJuCe1leSNEGjPCP3L5KsH/F824B9VfUW8EqSWWBLa5utqpcBkuxrfZ9f+pAlSRfrUub0709ytE3/rGy1NcDxoT4nWm2h+vdIsiPJTJKZubm5SxieJOl8Fxv6jwA/DmwGTgG/Pa4BVdXuqpququmpqalxnVaSxAjTO/OpqtfPbSf5XeDxtnsSWDfUdW2rcYG6JGlCLupOP8nqod2fA86t7DkA3J3k2iQbgI3AM8BhYGOSDUmuYfBh74GLH7Yk6WIseqef5IvAB4EbkpwAHgI+mGQzUMCrwC8DVNWxJPsZfEB7FthZVW+389wPPAGsAPZU1bFxX4wk6cJGWb1zzzzlRy/Qfxewa576QeDgkkYnSRorv5ErSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHVk09JPsSXI6yXNDteuTHEryUvu5stWT5DNJZpMcTfL+oWO2t/4vJdl+eS5HknQho9zpfx7Yel7tAeDJqtoIPNn2Ae5g8DD0jcAO4BEY/JJg8Gzdm4AtwEPnflFIkiZn0dCvqr8AzpxX3gbsbdt7gbuG6l+oga8B1yVZDdwOHKqqM1X1BnCI7/1FIkm6zC52Tn9VVZ1q298EVrXtNcDxoX4nWm2h+vdIsiPJTJKZubm5ixyeJGk+l/xBblUVUGMYy7nz7a6q6aqanpqaGtdpJUlcfOi/3qZtaD9Pt/pJYN1Qv7WttlBdkjRBFxv6B4BzK3C2A48N1T/SVvHcDLzZpoGeAG5LsrJ9gHtbq0mSJuiqxTok+SLwQeCGJCcYrML5FLA/yX3Aa8CHW/eDwJ3ALPBd4F6AqjqT5JPA4dbvE1V1/ofDkqTLbNHQr6p7Fmi6dZ6+Bexc4Dx7gD1LGp0kaaz8Rq4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR15JJCP8mrSb6R5EiSmVa7PsmhJC+1nytbPUk+k2Q2ydEk7x/HBUiSRjeOO/3/UFWbq2q67T8APFlVG4En2z7AHcDG9toBPDKG95YkLcHlmN7ZBuxt23uBu4bqX6iBrwHXJVl9Gd5fkrSASw39Av4sybNJdrTaqqo61ba/Caxq22uA40PHnmi1fyHJjiQzSWbm5uYucXiSpGFXXeLxP11VJ5P8G+BQkr8ZbqyqSlJLOWFV7QZ2A0xPTy/pWEnShV3SnX5VnWw/TwN/DGwBXj83bdN+nm7dTwLrhg5f22qSpAm56NBP8q+S/NC5beA24DngALC9ddsOPNa2DwAfaat4bgbeHJoGkiRNwKVM76wC/jjJufP8j6r60ySHgf1J7gNeAz7c+h8E7gRmge8C917Ce0uSLsJ
Fh35VvQz81Dz1bwG3zlMvYOfFvp8k6dL5jVxJ6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1JFLfVyipHe49Q/8yXIP4Yrx6qc+tNxDuGTe6UtSRwx9SerIxEM/ydYkLyaZTfLApN9fkno20dBPsgL4LHAHsAm4J8mmSY5Bkno26Tv9LcBsVb1cVf8X2Adsm/AYJKlbk169swY4PrR/ArhpuEOSHcCOtvudJC9OaGw9uAH42+UexGLy6eUegZbJO/7f5/fRv81/u1DDO27JZlXtBnYv9ziuRElmqmp6ucchzcd/n5Mx6emdk8C6of21rSZJmoBJh/5hYGOSDUmuAe4GDkx4DJLUrYlO71TV2ST3A08AK4A9VXVskmPonNNmeifz3+cEpKqWewySpAnxG7mS1BFDX5I6YuhLUkfecev0NT5J3svgG89rWukkcKCqXli+UUlaTt7pX6GSfJzBn7kI8Ex7Bfiif+hO72RJ7l3uMVzJXL1zhUryv4Abq+r/nVe/BjhWVRuXZ2TShSX5P1X1Y8s9jiuV0ztXrn8EfhR47bz66tYmLZskRxdqAlZNciy9MfSvXB8DnkzyEv/8R+5+DHgPcP9yDUpqVgG3A2+cVw/wl5MfTj8M/StUVf1pkp9g8Oeshz/IPVxVby/fyCQAHgfeXVVHzm9I8pWJj6YjzulLUkdcvSNJHTH0Jakjhr4kdcTQl6SOGPqS1JH/D5bnxSy4KcdIAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "test_set.dataset.df.loc[test_set.indices]['HIV_active'].value_counts().plot(kind=\"bar\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All three datasets (training, validation, and test) have a similar class distribution." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload Data to S3\n", "To train the model on SageMaker, we first need to upload the data to Amazon S3. We use the `sagemaker.Session.upload_data` function to upload our datasets. The return values (`input_full`, `input_train`, and `input_val`) identify the uploaded locations; we will use them later when we start the training job." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()\n", "session = sagemaker.Session()\n", "bucket = session.default_bucket()\n", "\n", "s3_prefix = \"./hiv_inhibitor_prediction/sagemaker\"\n", "\n", "dataset.df.to_csv(\"full.csv\", index=False)\n", "pd.DataFrame(train_set.indices, columns=[\"indices\"]).to_csv(\"train.csv\", index=False)\n", "pd.DataFrame(val_set.indices, columns=[\"indices\"]).to_csv(\"validation.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "input_full = session.upload_data(path=\"full.csv\", bucket=bucket, key_prefix=s3_prefix)\n", "input_train = session.upload_data(path=\"train.csv\", bucket=bucket, key_prefix=s3_prefix)\n", "input_val = session.upload_data(path=\"validation.csv\", bucket=bucket, key_prefix=s3_prefix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model summary below shows, the `GCNPredictor` architecture is composed of multiple `gnn_layers`, each of which wraps a DGL `GraphConv`."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train on SageMaker\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explore the Model Architecture\n", "\n", "We represent each molecule as a graph in which each atom is a node. After the data transformations, the atom properties become the node features. We then use a graph neural network (GNN) on these features to classify each whole graph (molecule) according to whether or not it inhibits HIV replication. In GNN terms, this is a graph classification problem.\n", "\n", "We will use prebuilt GCN and GAT model architectures packaged with DGL-LifeSci to train the model. Please refer to the GCNPredictor [documentation](https://lifesci.dgl.ai/_modules/dgllife/model/model_zoo/gcn_predictor.html) and [code](https://github.com/awslabs/dgl-lifesci/blob/master/python/dgllife/model/model_zoo/gcn_predictor.py) for more information." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GCNPredictor(\n", " (gnn): GCN(\n", " (gnn_layers): ModuleList(\n", " (0): GCNLayer(\n", " (graph_conv): GraphConv(in=10, out=10, normalization=none, activation=)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (bn_layer): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): GCNLayer(\n", " (graph_conv): GraphConv(in=10, out=4, normalization=none, activation=)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (bn_layer): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " (readout): WeightedSumAndMax(\n", " (weight_and_sum): WeightAndSum(\n", " (atom_weighting): Sequential(\n", " (0): Linear(in_features=4, out_features=1, bias=True)\n", " (1): Sigmoid()\n", " )\n", " )\n", " )\n", " (predict): MLPPredictor(\n", " (predict): Sequential(\n", " (0): Dropout(p=0.0, inplace=False)\n", " (1): Linear(in_features=8, 
out_features=128, bias=True)\n", " (2): ReLU()\n", " (3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (4): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dgllife.model import GCNPredictor\n", "import torch.nn.functional as F\n", "\n", "model = GCNPredictor(\n", " in_feats=10,\n", " hidden_feats=[10, 4],\n", " activation=[F.relu, F.relu],\n", " residual=[False] * 2\n", " )\n", "\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### View Training Script\n", "\n", "We are going to use PyTorch as the DGL backend. Our training script saves the model training artifacts to the path given by `model_dir`, as defined by the SageMaker PyTorch image. When training is finished, SageMaker uploads the model artifacts saved in `model_dir` to S3 for later deployment.\n", "\n", "We save this script in a file named `code/train.py`. 
\n", " " ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\u001b[0;32mimport\u001b[0m \u001b[0margparse\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mdgllife\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mEarlyStopping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMeter\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdouble\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAdam\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_featurizers\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmodel_saved_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_params_saved_path\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mfrom\u001b[0m \u001b[0ms3_downloaded_HIV_dataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mS3DownloadedHIVDataset\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mrun_a_train_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_criterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_meter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMeter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_data\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msmiles\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Avoid potential issues with batch normalization\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Mask non-existing labels\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mloss_criterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmasks\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_meter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_id\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_every\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch [{:d}] of [{:d}], batch {:d}/{:d}, loss [{:.4f}]'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_id\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_score\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_meter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch [{:d}] of [{:d}], training:{} [{:.4f}]'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mrun_an_eval_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0meval_meter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMeter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mfor\u001b[0m 
\u001b[0mbatch_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_data\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0meval_meter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_meter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Hello I am training with following args.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda:0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Device : [{}]\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mnode_featurizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_featurizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit_featurizers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgnn_featurizer_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mS3DownloadedHIVDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mnode_featurizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnode_featurizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0medge_featurizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0medge_featurizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msplit_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mcollate_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcollate_molgraphs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mval_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mval_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mcollate_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcollate_molgraphs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode_featurizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mloss_criterion\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBCEWithLogitsLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'none'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mweight_decay\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight_decay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mstopper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEarlyStopping\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpatience\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatience\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_saved_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Train\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mrun_a_train_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_criterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Validation and early stop\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mval_score\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_an_eval_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mearly_stop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstopper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch [{:d}] of [{:d}], validation:{} [{:.4f}], best validation:{} [{:.4f}]'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mval_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstopper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mearly_stop\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msave_model_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0msplit_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSubset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_data\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/train.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mval_set\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSubset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_data\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/validation.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtrain_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_set\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mbg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mnode_feats\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'h'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0;31m#edge_feats = bg.edata.pop('e').to(args['device'])\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0mcollate_molgraphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"Batching a list of datapoints for dataloader.\u001b[0m\n", "\u001b[0;34m Parameters\u001b[0m\n", "\u001b[0;34m ----------\u001b[0m\n", "\u001b[0;34m data : list of 4-tuples.\u001b[0m\n", "\u001b[0;34m Each tuple is for a single datapoint, consisting of\u001b[0m\n", "\u001b[0;34m a SMILES, a DGLGraph, all-task labels and optionally a binary\u001b[0m\n", "\u001b[0;34m mask indicating the existence of labels.\u001b[0m\n", "\u001b[0;34m Returns\u001b[0m\n", "\u001b[0;34m -------\u001b[0m\n", "\u001b[0;34m smiles : list\u001b[0m\n", "\u001b[0;34m List of smiles\u001b[0m\n", "\u001b[0;34m bg : DGLGraph\u001b[0m\n", "\u001b[0;34m The batched DGLGraph.\u001b[0m\n", "\u001b[0;34m labels : Tensor of dtype float32 and shape (B, T)\u001b[0m\n", "\u001b[0;34m Batched datapoint labels. 
B is len(data) and\u001b[0m\n", "\u001b[0;34m T is the number of total tasks.\u001b[0m\n", "\u001b[0;34m masks : Tensor of dtype float32 and shape (B, T)\u001b[0m\n", "\u001b[0;34m Batched datapoint binary mask, indicating the\u001b[0m\n", "\u001b[0;34m existence of labels.\u001b[0m\n", "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraphs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mbg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraphs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_n_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_initializer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_e_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_initializer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmasks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mdef\u001b[0m \u001b[0msave_model_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_params_saved_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mgnn_params_keys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfilter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0marg\u001b[0m \u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"gnn\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mgnn_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m:\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgnn_params_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"w\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgnn_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Loading Parameters\\n\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentParser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'HIV Inhibitor Binary Classification'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Feature engineering hyper-params\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-f'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-featurizer-type'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'canonical'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'attentivefp'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'canonical'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Featurization for atoms (and bonds). This is required for models '\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m'other than gin_supervised_**.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# model evaluation hyper-params\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-me'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--metric'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'roc_auc_score'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'pr_auc_score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'roc_auc_score'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Metric for evaluation (default: roc_auc_score)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# model architecture hyper-params\u001b[0m\u001b[0;34m\u001b[0m\n", 
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-nw'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--num-workers'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Number of processes for data loading (default: 1)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-mn'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-model-name'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'GCN-p'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'GAT-p'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'GCN-p'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'DGL Life model implementation to be used. 
'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-hl'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-hidden-feats'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Number of hidden GCNLayers to be used, i.e. how many neighbors are to be considered.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-res'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-residuals'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Whether to use residual connections in the GCNLayer or not.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-batchnorm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-batchnorm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 
"\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Whether to use batch norm in each GCNLayer or not.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-dropout'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-dropout'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Dropout probability'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-al'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-alphas'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.08\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Alphas'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-nh'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-num-heads'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Number of heads'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-predictor_hidden_feats'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--gnn-predictor-hidden-feats'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m512\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Training hyper-params\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-bs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--batch-size'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m512\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Batch size for the data loaders (default : 32)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--epochs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Maximum number of epochs for training. '\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m'We set a large number by default as early stopping '\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m'will be performed. (default: 3)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-lr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--learning-rate'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Learning rate'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-wd'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--weight-decay'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Weight decay.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-patience'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--patience'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m30\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Monitoring params\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-pe'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--print-every'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Print the training progress every X mini-batches'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-md'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'--mode'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"local\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Mode of running this script [sm, local]'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", 
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m# Container environment\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"--model-dir\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"SM_MODEL_DIR\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"--full-data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"SM_CHANNEL_DATA_FULL\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"--train-data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"SM_CHANNEL_DATA_TRAIN\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"--val-data\"\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"SM_CHANNEL_DATA_VAL\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;31m#parser.add_argument(\"--num-gpus\", type=int, default=os.environ[\"SM_NUM_GPUS\"])\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%pycat code/train.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create an Experiment\n", "\n", "SageMaker Experiments enables us to organize related trials for later comparison." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from botocore.exceptions import ClientError\n", "from smexperiments.experiment import Experiment\n", "\n", "experiment_name = \"HIV-Inhibitor-Prediction-Experiment\"\n", "\n", "try:\n", "    experiment = Experiment.load(experiment_name)\n", "except ClientError as e:\n", "    # Create the experiment only if it does not exist yet; re-raise any other error.\n", "    if e.response['Error']['Code'] == \"ResourceNotFound\":\n",
 "        print(\"Experiment [{}] does not exist; creating it.\".format(experiment_name))\n", "        experiment = Experiment.create(\n", "            experiment_name=experiment_name,\n", "            description=\"Experiment to track the HIV inhibitor prediction trials.\",\n", "        )\n", "    else:\n", "        raise\n", "experiment " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Hyperparameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's define the hyperparameters associated with the model. The training script above exposes several hyperparameters that we can use to tune our model(s). Notably, the model architecture itself (`gnn-model-name`) is a hyperparameter, so we can switch to other architectures such as GAT or MPNN." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "import time\n", "\n", "hyperparameters={\n", " # Feature Engineering\n", " \"gnn-featurizer-type\": 'canonical',\n", " \n", " # Model Architecture\n", " \"gnn-model-name\" : 'GCN-p',\n", " \"gnn-residuals\" : False,\n", " \"gnn-batchnorm\" : True,\n", " \"gnn-dropout\" : 0.0013086019242321,\n", " \"gnn-predictor-hidden-feats\" : 512,\n", " \n", " # Training\n", " \"batch-size\" : 512,\n", " \"epochs\" : 20,\n", " \"learning-rate\" : 0.000508635928951698,\n", " \"weight-decay\" : 0.0013253058161908312,\n", " \"patience\" : 30\n", "}\n", "\n", "metric_definitions =[\n", " {'Name': 'train:roc_auc_score', 'Regex': 'training:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'validation:roc_auc_score', 'Regex': ',\\svalidation:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'best validation:roc_auc_score', 'Regex': 'best\\svalidation:roc_auc_score\\s\\[([0-9\\\\,.]+)\\]'},\n", " {'Name': 'epoch', 'Regex': 'epoch\\s\\[([0-9]+)\\]'},\n", " {'Name': 'train:loss', 'Regex': 'loss\\s\\[([0-9\\\\,.]+)\\]'}\n", " ]\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Trial\n", "\n", "From here 
onwards, each training run is tracked as a separate trial under the experiment we created above." ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "trial_name = \"hiv-inhibitor-prediction-training-{}-{}\".format(hyperparameters[\"gnn-model-name\"], time.strftime(\"%m-%d-%Y-%H-%M-%S\"))\n", "trial = experiment.create_trial(trial_name=trial_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit Training Job\n", "\n", "DGL supports TensorFlow, PyTorch, and MXNet backends. For this project, we use `PyTorch` as the backend.\n", "\n", "The Amazon SageMaker Python SDK makes it easier to run a PyTorch script in Amazon SageMaker using its PyTorch estimator. We can also use the SageMaker Python SDK to deploy the trained model and run predictions. For more information on how to use this SDK with PyTorch, see the [SageMaker Python SDK documentation](https://sagemaker.readthedocs.io/en/stable/).\n", "\n", "To start, we use the PyTorch estimator class to train our model. When creating our estimator, we make sure to specify a few things:\n", "\n", "* `entry_point`: The name of our PyTorch script. It contains our training code, which loads data from the input channels, configures training with hyperparameters, runs the training loop, and saves the model artifacts. It also contains code to load and run the model during inference.\n", "\n", "* `source_dir`: The location of our training scripts and `requirements.txt` file. The `requirements.txt` file lists packages you want to use with your script.\n", "\n", "* `framework_version`: The PyTorch version we want to use. The PyTorch estimator supports both single-machine training and multi-machine distributed training with `SMDataParallel`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_job_name = \"tr-{}\".format(trial_name) \n", "print('Training job name: ', training_job_name)\n", "\n", "estimator = PyTorch(\n", " entry_point = \"train.py\",\n", " source_dir = \"code\",\n", " role = role,\n", " framework_version = \"1.9.0\",\n", " py_version=\"py38\",\n", " instance_count=1,\n", " instance_type=\"ml.p3.2xlarge\",\n", " debugger_hook_config=False,\n", " disable_profiler=True,\n", " hyperparameters = hyperparameters,\n", " metric_definitions=metric_definitions\n", ")\n", "\n", "estimator.fit({\"data_full\" : input_full, \"data_train\" : input_train, \"data_val\" : input_val}, \n", " job_name = training_job_name,\n", " experiment_config = {\n", " \"TrialName\" : trial.trial_name,\n", " \"TrialComponentDisplayName\" : \"TrainingJob\",\n", " })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### View Training Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The initial training run should produce a model with a validation ROC AUC between 0.75 and 0.80. You can view additional metrics, such as the training loss, validation loss, and ROC AUC score over time, in the Experiments view in SageMaker Studio."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we deploy the model to an endpoint, let's see where the trained model artifacts are stored in S3.\n" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Stored s3://sagemaker-us-west-2-167428594774/tr-hiv-inhibitor-prediction-training-GCN-p-12-15-2022-15-37-52/output/model.tar.gz as model_data\n" ] } ], "source": [ "model_data = estimator.model_data\n", "print(\"Stored {} as model_data\".format(model_data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy the Model on Amazon SageMaker\n", "\n", "After training our model, we host it as an Amazon SageMaker Endpoint. We need to implement a few functions in `inference.py` for the endpoint to load the model and serve predictions correctly.\n", "\n", "* `model_fn()`: Loads the saved model and returns a model object that can be used for model serving. The SageMaker PyTorch model server loads our model by invoking `model_fn`.\n", "* `input_fn()`: Deserializes and prepares the prediction input. In this example, our request body is first serialized to a JSON object containing SMILES strings for the target molecules. The `input_fn()` function first constructs a graph for each molecule using DGL. It then adds the features to each node using the same featurizer used during training. Finally, the function returns the graph with features in the format required by the model.\n", "* `predict_fn()`: Performs the prediction and returns the result.\n", "\n", "To deploy our endpoint, we call `deploy()` on a `PyTorchModel` object created from the trained model artifacts, passing in our desired number of instances and instance type."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a Real-Time Inference Endpoint" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Endpoint name: HIV-Inhibitor-Prediction-EP-12-15-2022-15-48-22\n" ] } ], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "\n", "endpoint_name = \"HIV-Inhibitor-Prediction-EP-{}\".format(time.strftime(\"%m-%d-%Y-%H-%M-%S\"))\n", "print(\"Endpoint name: \", endpoint_name)\n", "\n", "model = PyTorchModel(model_data=model_data, source_dir='code',\n", " entry_point='inference.py', role=role, framework_version=\"1.9.0\", py_version='py38')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:sagemaker:Creating model with name: pytorch-inference-2022-12-15-15-48-24-925\n", "INFO:sagemaker:Creating endpoint-config with name HIV-Inhibitor-Prediction-EP-12-15-2022-15-48-22\n", "INFO:sagemaker:Creating endpoint with name HIV-Inhibitor-Prediction-EP-12-15-2022-15-48-22\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "------!" ] } ], "source": [ "predictor = model.deploy(initial_instance_count=1, instance_type=\"ml.c5.xlarge\", endpoint_name=endpoint_name)\n", "\n", "predictor.serializer = sagemaker.serializers.JSONSerializer()\n", "predictor.deserializer = sagemaker.deserializers.JSONDeserializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test the Endpoint With a Single Target Molecule\n", "\n", "First, let's define a function that collates a list of data points into a list of SMILES strings, a batched DGL graph, and label and mask tensors."
] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "def collate_molgraphs(data):\n", " \"\"\"Batching a list of datapoints for dataloader.\n", " Parameters\n", " ----------\n", " data : list of 4-tuples.\n", " Each tuple is for a single datapoint, consisting of\n", " a SMILES, a DGLGraph, all-task labels and optionally a binary\n", " mask indicating the existence of labels.\n", " Returns\n", " -------\n", " smiles : list\n", " List of SMILES strings\n", " bg : DGLGraph\n", " The batched DGLGraph.\n", " labels : Tensor of dtype float32 and shape (B, T)\n", " Batched datapoint labels. B is len(data) and\n", " T is the number of total tasks.\n", " masks : Tensor of dtype float32 and shape (B, T)\n", " Batched datapoint binary mask, indicating the\n", " existence of labels.\n", " \"\"\"\n", "\n", " smiles, graphs, labels, masks = map(list, zip(*data))\n", " \n", " bg = dgl.batch(graphs)\n", " bg.set_n_initializer(dgl.init.zero_initializer)\n", " bg.set_e_initializer(dgl.init.zero_initializer)\n", " labels = torch.stack(labels, dim=0)\n", " masks = torch.stack(masks, dim=0)\n", "\n", " return smiles, bg, labels, masks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use the `collate_molgraphs()` function to process the test data." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "test_smiles, bg, test_labels, masks = collate_molgraphs(test_set)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use our deployed endpoint to classify a single molecule from the test set. First, we select one target by its index."
] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mol = Chem.MolFromSmiles(test_smiles[110])\n", "mol" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's wrap the molecule's SMILES string in a JSON payload and submit it to the endpoint." ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[-7.176329612731934]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "json = {\n", " \"smiles\" : \n", " [test_smiles[110]]\n", "}\n", "\n", "prediction_logits = predictor.predict(json)\n", "prediction_logits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The endpoint returns a logit (an unnormalized score) indicating how likely the target molecule is to be an inhibitor; more negative values indicate a lower likelihood." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test the Endpoint With Multiple Targets\n" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "json = {\n", " \"smiles\" : \n", " test_smiles\n", "}\n", "\n", "prediction_logits = predictor.predict(json)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Examine the Test Results " ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.722756329786207" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import roc_curve, roc_auc_score\n", "import matplotlib.pyplot as plt \n", "\n", "roc_auc_score(test_labels[:,0].numpy(), np.asarray(prediction_logits))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The ROC AUC score on the test set should be around 0.75 (about 0.72 in the run shown above). Let's plot the ROC curve."
] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fpr, tpr, _ = roc_curve(test_labels[:,0].numpy(), np.asarray(prediction_logits))\n", "\n", "plt.plot(fpr, tpr)\n", "plt.title(\"ROC Plot\")\n", "plt.xlabel(\"False Positive Rate\")\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter Tuning \n", "\n", "So far we have trained a single model with fixed hyperparameters. Next let's try to further optimize the model using [Amazon SageMaker Hyperparameter Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit Hyperparameter Tuning Job" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training job name: hpo-hiv-gcn-p-12-15-15-51-40\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:sagemaker.estimator:No finished training job found associated with this estimator. 
Please make sure this estimator is only used for building workflow config\n", "INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.\n", "INFO:sagemaker:Creating hyperparameter tuning job with name: hpo-hiv-gcn-p-12-15-15-51-40\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "..............................................................................................................................................................................................................!\n" ] } ], "source": [ "from sagemaker.tuner import (\n", " IntegerParameter,\n", " CategoricalParameter,\n", " ContinuousParameter,\n", " HyperparameterTuner,\n", ")\n", "gcn_hyperparameter_ranges = {\n", " \n", " \"gnn-dropout\": ContinuousParameter(0.001 , 0.003),\n", " \"gnn-predictor-hidden-feats\" : CategoricalParameter([128, 256, 512]),\n", " \n", " \"batch-size\" : CategoricalParameter([256, 512]),\n", " \"learning-rate\" : ContinuousParameter(0.0001, 0.001),\n", " \"weight-decay\" : ContinuousParameter(0.001, 0.01)\n", " \n", "}\n", "\n", "objective_metric_name = \"best validation:roc_auc_score\"\n", "\n", "gcn_estimator = PyTorch(\n", " entry_point = \"train.py\",\n", " source_dir = \"code\",\n", " role = role,\n", " framework_version = \"1.9.0\",\n", " py_version=\"py38\",\n", " instance_count=1,\n", " instance_type=\"ml.p3.2xlarge\",\n", " debugger_hook_config=False,\n", " disable_profiler=True\n", ")\n", "\n", "gcn_tuner = HyperparameterTuner(\n", " gcn_estimator,\n", " objective_metric_name,\n", " gcn_hyperparameter_ranges,\n", " metric_definitions,\n", " max_jobs=1,\n", " max_parallel_jobs=1\n", ")\n", "\n", "hyper_parameter_job_name = \"hpo-hiv-gcn-p-{}\".format(time.strftime(\"%m-%d-%H-%M-%S\")) \n", "print('Training job name: ', hyper_parameter_job_name)\n", "\n", "gcn_tuner.fit({\"data_full\" : input_full, \"data_train\" : input_train, \"data_val\" : input_val}, job_name = 
hyper_parameter_job_name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If your notebook loses its connection, you can reattach it by specifying the job name, like `gcn_tuner = HyperparameterTuner.attach(\"hpo-hiv-gcn-p-03-16-02-38-07\")`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate Tuned Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's examine the best model and its hyperparameters." ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'TrainingJobName': 'hpo-hiv-gcn-p-12-15-15-51-40-001-bcb06c6b',\n", " 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:167428594774:training-job/hpo-hiv-gcn-p-12-15-15-51-40-001-bcb06c6b',\n", " 'CreationTime': datetime.datetime(2022, 12, 15, 15, 51, 45, tzinfo=tzlocal()),\n", " 'TrainingStartTime': datetime.datetime(2022, 12, 15, 15, 53, 59, tzinfo=tzlocal()),\n", " 'TrainingEndTime': datetime.datetime(2022, 12, 15, 16, 6, 24, tzinfo=tzlocal()),\n", " 'TrainingJobStatus': 'Completed',\n", " 'TunedHyperParameters': {'batch-size': '\"256\"',\n", " 'gnn-dropout': '0.001047059471760684',\n", " 'gnn-predictor-hidden-feats': '\"128\"',\n", " 'learning-rate': '0.0008192232031754262',\n", " 'weight-decay': '0.00485718776241423'},\n", " 'FinalHyperParameterTuningJobObjectiveMetric': {'MetricName': 'best validation:roc_auc_score',\n", " 'Value': 0.7642999887466431},\n", " 'ObjectiveStatus': 'Succeeded'}" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import boto3\n", "\n", "smclient = boto3.client(\"sagemaker\")\n", "\n", "best_overall_training_job = smclient.describe_hyper_parameter_tuning_job(\n", " HyperParameterTuningJobName=hyper_parameter_job_name\n", ")\n", "\n", "best_overall_training_job[\"BestTrainingJob\"]" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "2022-12-15 
16:06:27 Starting - Preparing the instances for training\n", "2022-12-15 16:06:27 Downloading - Downloading input data\n", "2022-12-15 16:06:27 Training - Training image download completed. Training in progress.\n", "2022-12-15 16:06:27 Uploading - Uploading generated training model\n", "2022-12-15 16:06:27 Completed - Resource retained for reuse\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:sagemaker:Creating model with name: pytorch-inference-2022-12-15-16-09-04-003\n", "INFO:sagemaker:Creating endpoint-config with name best-gcn-HIV-Inhibitor-Prediction-EP-12-15-2022-15-48-22\n", "INFO:sagemaker:Creating endpoint with name best-gcn-HIV-Inhibitor-Prediction-EP-12-15-2022-15-48-22\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-----!" ] } ], "source": [ "best_gcn_training_job = sagemaker.estimator.Estimator.attach(best_overall_training_job[\"BestTrainingJob\"][\"TrainingJobName\"])\n", "best_gcn_model = PyTorchModel(model_data=best_gcn_training_job.model_data, source_dir='code',\n", " entry_point='inference.py', role=role, framework_version=\"1.9.0\", py_version='py38')\n", "\n", "best_gcn_predictor = best_gcn_model.deploy(initial_instance_count=1, instance_type=\"ml.c5.xlarge\", endpoint_name=\"best-gcn-\" + endpoint_name)\n", "best_gcn_predictor.serializer = sagemaker.serializers.JSONSerializer()\n", "best_gcn_predictor.deserializer = sagemaker.deserializers.JSONDeserializer()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7037138608316114" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "json = {\n", " \"smiles\" : \n", " test_smiles\n", "}\n", "\n", "prediction_logits = best_gcn_predictor.predict(json)\n", "roc_auc_score(test_labels[:,0].numpy(), np.asarray(prediction_logits))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clean up \n", "\n", "Lastly, please remember to delete the Amazon 
SageMaker endpoints to avoid ongoing charges:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_gcn_predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "for x in ['full.csv', 'train.csv', 'validation.csv', 'hiv_dglgraph.bin']: os.remove(x) " ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (PyTorch 1.12 Python 3.8 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.12-cpu-py38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "a8534c14445fc6cdc3039d8140510d6736e5b4960d89f445a45d8db6afd8452b" } } }, "nbformat": 4, "nbformat_minor": 4 }