{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Causal Inference with Bayesian Networks" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "**Jupyter Kernel**:\n", "\n", "\n", "* If you are in SageMaker Studio, make sure that you use the **PyTorch 1.10 Python 3.8 CPU Optimized** environment.\n", "* Make sure that you are using `ml.g4dn.xlarge` or `ml.m5.large` as an instance type.\n", "\n", "**Run All**: \n", "\n", "* If you are in SageMaker Studio, you can choose the **Run All Cells** from the **Run** tab dropdown menu to run the entire notebook at once." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Install dependencies that will be used in this notebook.\n", "!pip3 install -r ./utils/requirements.in -q" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [], "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!conda install -c conda-forge pygraphviz -y" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "This solution relies on a config file to run the provisioned AWS resources. Run the cells below to generate that file." 
] }, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "tags": [], "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "import boto3\n",
  "import os\n",
  "import json"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [], "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "client = boto3.client('servicecatalog')\n",
  "\n",
  "# The notebook runs from <...>/S3Downloads/<provisioned product name>/...\n",
  "cwd = os.getcwd().split('/')\n",
  "s3_downloads_idx = cwd.index('S3Downloads')\n",
  "pp_name = cwd[s3_downloads_idx + 1]\n",
  "solution_dir = '/'.join(cwd[:s3_downloads_idx + 2])\n",
  "\n",
  "pp = client.describe_provisioned_product(Name=pp_name)\n",
  "record_id = pp['ProvisionedProductDetail']['LastSuccessfulProvisioningRecordId']\n",
  "record = client.describe_record(Id=record_id)\n",
  "\n",
  "# Keep only the record outputs that carry both a key and a value\n",
  "stack_output = {\n",
  "    x['OutputKey']: x['OutputValue']\n",
  "    for x in record['RecordOutputs']\n",
  "    if 'OutputKey' in x and 'OutputValue' in x\n",
  "}\n",
  "\n",
  "with open(f'{solution_dir}/stack_outputs.json', 'w') as f:\n",
  "    json.dump(stack_output, f)"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "tags": [] },
 "outputs": [],
 "source": [
  "with open(\"stack_outputs.json\") as f:\n",
  "    sagemaker_config = json.load(f)\n",
  "\n",
  "SOLUTION_BUCKET = sagemaker_config[\"SolutionS3Bucket\"]\n",
  "AWS_REGION = sagemaker_config[\"AWSRegion\"]\n",
  "SOLUTION_NAME = sagemaker_config[\"SolutionName\"]\n",
  "SOLUTION_PREFIX = sagemaker_config[\"SolutionPrefix\"]\n",
  "AWS_S3_BUCKET = sagemaker_config[\"S3Bucket\"]\n",
  "\n",
  "KEY_YIELD_CURVE = \"data/raw/yield_curve_field_dt.csv\"\n",
  "SPATIAL_FILES_KEY = \"data/spatial-files\"\n",
  "FIPS_STATS_KEY = \"data/fips-stats/fips_county_stats.csv\"\n",
  "FIPS_POLYGONS_KEY = \"data/fips-stats/geojson-counties-fips.json\"\n",
  "SENTINEL_2_SHAPEFILE_KEY = \"data/sentinel-2-shapefiles\"\n",
  "CROPS_MASK_KEY = \"data/crop_mask/raw\"\n",
  "REQUEST_MANIFESTS_KEY = \"request_manifests/\"\n",
  "\n",
  "DAG_PATH = 'models/bn_structure.gml'\n",
  "MODEL_PATH = 'models/bayesian_model.bif'\n",
  "STATES_PATH = 'models/node_states.json'\n",
  "NUMERICAL_SPLIT_POINTS_PATH = \"models/numerical_split_points.json\"\n",
  "\n",
  "# Local folder for the model artifacts (no-op if it already exists)\n",
  "os.makedirs('models', exist_ok=True)"
 ]
}, {
 "cell_type": "markdown",
 "metadata": {},
 "source": [ "### Set up the environment" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Standard library\n",
  "import base64\n",
  "import bisect\n",
  "import copy\n",
  "import datetime\n",
  "import io\n",
  "import itertools as it\n",
  "import json\n",
  "import os\n",
  "import urllib.parse\n",
  "import warnings\n",
  "from time import gmtime, strftime, time\n",
  "from typing import Dict\n",
  "\n",
  "# Third-party\n",
  "import boto3\n",
  "import geopandas as gpd\n",
  "import matplotlib.pyplot as plt\n",
  "import networkx as nx\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import s3fs\n",
  "import sagemaker\n",
  "from botocore.exceptions import ClientError\n",
  "\n",
  "# Local helpers\n",
  "from utils.causalnex_helpers import (\n",
  "    quantile_discretiser,\n",
  "    generate_dag_constraints,\n",
  "    discretiser_inverse_transform,\n",
  "    format_inference_output\n",
  ")\n",
  "from utils.plot_functions import plot_pretty_structure\n",
  "from utils.helper_functions import download_s3_folder\n",
  "\n",
  "# NOTE: no `from PIL import Image` here: it would shadow IPython.display.Image,\n",
  "# which renders the DAG plots, whenever this cell is re-run.\n",
  "warnings.simplefilter('ignore')\n",
  "\n",
  "%matplotlib inline"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Define a few variables to use throughout the notebook\n",
  "\n",
  "EPSG = 'epsg:4326' # using the WGS84 latitude-longitude projection: \"EPSG:4326\"\n",
  "CROP_REGION = 
'2-Central' # Illinois region\n", "YEAR = 2018 # crop year" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Get the SageMaker session, SageMaker execution role, Region name, and S3 resource\n", "boto_session = boto3.session.Session()\n", "sm_session = sagemaker.session.Session()\n", "region = boto_session.region_name\n", "sm_role = sagemaker.get_execution_role()\n", "runtime = boto3.Session().client('sagemaker-runtime')\n", "s3 = boto3.resource('s3')" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Download spatial files locally." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "download_s3_folder(AWS_S3_BUCKET,SPATIAL_FILES_KEY, \"tmp/spatial-files\")\n", "download_s3_folder(AWS_S3_BUCKET,SENTINEL_2_SHAPEFILE_KEY, \"tmp/Sentinel-2-Shapefile-Index\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Read dataset and crop staging mapping file\n", "\n", "> **Note**: Files produced in the `01 Feature Engineering.ipynb` notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "REGION = CROP_REGION.replace(\"-\",\"_\")\n", "\n", "# read enhanced dataset\n", "df_full = pd.read_csv(\n", " f\"s3://{AWS_S3_BUCKET}/data/enhanced/\"\n", " f\"enhanced_dataset_filtered_{YEAR}_{REGION}.csv\",\n", ")\n", "\n", "# read crop staging mapping file\n", "df_mapping = pd.read_csv(\n", " f\"s3://{AWS_S3_BUCKET}/data/enhanced/\"\n", " f\"stage_mapping_filtered_{YEAR}_{REGION}.csv\",\n", ")\n", "\n", "# read spatial files\n", "gpd_cells = gpd.read_file(\"tmp/spatial-files/cells_sf.shp\")\n", "gpd_cells = gpd_cells.to_crs(EPSG)\n", "\n", "# for the DAG setup remove the 
identifiers and variables that are out of scope\n", "df = df_full.drop(columns=['FIPS','id_field','id_10','LAI_max','n_uptake','P'])\n", "df_mapping = df_mapping[df_mapping.variable.isin(df.columns)]" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Select target(s) from the following:\n", "\n", "* Corn Yield: `\"Y_corn\"`\n", "* Soybeans Yield: `\"Y_soy\"`\n", "* Total N taken up by the corn crop during the season: `\"n_uptake\"`\n", "* Total 2-years N leaching during corn and soybean, from April 1st year (x) to March 31st year (x+2): `\"L\"`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "TARGETS = [\"Y_corn\"]" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "`Setting`:\n", " * The crop phenology graph (DAG) is a collection of nodes and edges, where the `nodes` are indicators of crop growth, soil characteristics, atmospheric conditions, and the `edges` between them represent temporal-causal relationships. `Parent nodes` are the field-related parameters (incl. 
the day of sowing and area planted), whereas the `child nodes` are the yield, nitrogen uptake and nitrogen leaching targets.\n", " * A `crop phenology DAG (Directed Acyclic Graph)` structure is learned from data (with domain knowledge assisted constraints) and human inputs:\n", " * The graphical model incorporates crop phenology dynamics extracted from ground-level indicators and spectral vegetation indices \n", " * Continuous features are discretised based on the split thresholds of a decision tree regressor (crop yield is used as a target)\n", " * Once the graph has been determined, the conditional probability distributions of the variables are learned from the data, using Bayesian parameter estimation.\n", "\n", " * Please find the [vocabulary](https://www.sciencedirect.com/science/article/pii/S2352340921010283#tbl0001) for the ground-level variables, and the [guide](https://crops.extension.iastate.edu/encyclopedia/corn-growth-stages) for identifying the corn growth stages.\n", " * Nodes starting with `mean_{spectral vegetation indices}_corn_{isoweek}` are corn growth indicators extracted from the satellite multi-spectral imagery, representing the 10 x 10 km cell mean value of the following spectral vegetation indices (for each satellite visit):\n", " * `EVI2` : Two-Band Enhanced Vegetation Index\n", " * `GDVI` : Generalized Difference Vegetation Index\n", " * `NDMI` : Normalized Difference Moisture Index\n", " * `NDVI` : Normalized Difference Vegetation Index\n", " * `NDWI` : Normalized Difference Water Index\n", " \n", " * `Corn response to nitrogen` is studied by querying the model and making interventions.\n", " * Firstly, undertake inference in order to gain insights about different response curves.\n", " * Secondly, use the inference insights and observation of evidence, in order to take actions for the amount of Nitrogen added as fertilizer, while observing the effect of these actions on the crop yield, the Nitrogen leaching and the total Nitrogen 
uptake." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Prepare constraints for the DAG learning\n", "\n", "Use the mapping file with the crop phenology staging and return constraints for the NOTEARS algorithm.\n", "\n", "\n", "1. list of nodes banned from being a child of any other nodes\n", "2. list of nodes banned from being a parent of any other nodes\n", "3. list of edges(from, to) not to be included in the graph.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Remove satellite indicators for now (they will be added later to the DAG wit assistance)\n", "sattelite_images = [feat for feat in df.columns if feat.startswith('mean_')]\n", "mapping = df_mapping[~df_mapping.variable.isin(sattelite_images)]\n", "\n", "# learn the DAG structure up to layer 4\n", "n_stage = 4\n", "\n", "mapping = mapping[mapping.level.isin([i for i in range(n_stage + 1)])]\n", "\n", "# Eliminate atmospheric nodes from level 0\n", "mapping = mapping[~(mapping.variable.str.startswith((\"tmean\",\"rain\",\"rad\")) & (mapping.level == 0))]\n", "\n", "tabu_edges, tabu_child, tabu_parents, nodes_list, nodes_matrix = generate_dag_constraints(mapping)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### causalnex imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from causalnex.structure import StructureModel\n", "from causalnex.structure.notears import from_pandas\n", "from causalnex.network import BayesianNetwork\n", "from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE\n", "from causalnex.discretiser.discretiser_strategy import (\n", " DecisionTreeSupervisedDiscretiserMethod,\n", " MDLPSupervisedDiscretiserMethod\n", ")\n", "from 
causalnex.discretiser import Discretiser\n", "from causalnex.network import BayesianNetwork\n", "from causalnex.evaluation import classification_report\n", "from causalnex.inference import InferenceEngine\n", "\n", "from sklearn.model_selection import train_test_split\n", "from causalnex.evaluation import roc_auc\n", "\n", "\n", "import warnings\n", "from IPython.display import Image\n", "\n", "warnings.filterwarnings(\"ignore\") # silence warnings" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## DAG learning from structure\n", "\n", "https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf\n", "\n", "1. Imposing edges that are not allowed in the causal model\n", "2. Imposing parent nodes that are not allowed in the causal model\n", "3. Imposing child nodes that are not allowed in the causal model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from time import time\n", "\n", "t0 = time()\n", "\n", "g_learned = from_pandas(df[nodes_list],\n", " tabu_edges=tabu_edges,\n", " tabu_parent_nodes=tabu_parents,\n", " tabu_child_nodes=tabu_child,\n", " max_iter=100\n", " )\n", "\n", "\n", "print(f'Running NOTEARS algorithm takes {time() - t0} seconds')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "g = g_learned.copy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "g = g.get_largest_subgraph()\n", "\n", "print(f\"Learned DAG Edges: {len(g.edges)}\")\n", "print(f\"Learned DAG Nodes: {len(g.nodes)}\")\n", "print(f\"Learned DAG Degree View \\n: {g.degree} \\n\")\n", "\n", "bn = BayesianNetwork(g)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "viz = plot_pretty_structure(bn.structure, edges_to_highlight=[])\n", "Image(viz.draw(format='png'))" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## DAG knowledge assisted" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Next, we will enhance the learned DAG structure with domain knowledge extracted from the [Simulated dataset of corn response to nitrogen over thousands of fields and multiple years in Illinois](https://www.sciencedirect.com/science/article/pii/S2352340921010283) paper." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "nodes_list_all = list(df_mapping.variable.unique())\n", "nodes_matrix_all = sorted([(df_mapping[df_mapping.variable == node]\n", " ['level'].values[0], node) for node in nodes_list_all])" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### N fertilizer edges\n", "\n", "Add edges between the N fertilizer and the target nodes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Add direct links to the targets\n", "g.add_edges_from([(\"N_fert\", node, {\"weight\": 1.0}) for node in TARGETS])" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" }, "tags": [] }, "source": [ "### Water stress indicators\n", "\n", "1. Add edges between Mean water stress indicators and the parent nodes\n", "2. 
Add edges between Mean water stress indicators and the soil indicators" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "water_stress_features = df.columns[df.columns.str.contains('_fw')]\n", "\n", "water_stress_matrix = sorted([(df_mapping[df_mapping.variable == node]\n", " ['level'].values[0], node) for node in water_stress_features])\n", "\n", "water_stress_edges = [\n", " (node_i, node_j, {\"weight\": 1.0}) for idx, node_i in water_stress_matrix for node_j in TARGETS]\n", "\n", "# add edges between Mean water stress indicators and the soil indicators\n", "\n", "g_in_degree = [node[0] for node in sorted(\n", " g.in_degree, key=lambda x: x[1], reverse=True) if node[0] in g.nodes and node[1] > 2]\n", "\n", "\n", "g.add_edges_from(water_stress_edges, origin=\"expert\")\n", "\n", "water_stress_edges" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Geospatial indicators\n", "\n", "1. Add edges between geospatial data and the targets (level 5 variables)\n", "2. 
Add edges between geospatial consecutive observations (consecutive isoweeks, aka satellite visits)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "satellite_features = df.columns[df.columns.str.contains('_NDVI|_NDMI|_EVI2')]\n", "\n", "satellite_matrix = sorted([(df_mapping[df_mapping.variable == node]\n", " ['level'].values[0], node) for node in satellite_features])\n", "\n", "satellite_edges = [\n", " (node_i, node_j, {\"weight\": 1.0}) for idx, node_i in satellite_matrix for idy, node_j in nodes_matrix_all\n", " if idy == 5 and idx == 4]\n", "\n", "\n", "consecutive_satellite_edges_target = [\n", " (node_i, node_j, {\"weight\": 1.0}) for idx, node_i in satellite_matrix for idy, node_j in satellite_matrix\n", " if idx == idy - 1 and node_i.split(\"_\")[0:2] == node_j.split(\"_\")[0:2]]\n", "\n", "\n", "satellite_edges.extend(consecutive_satellite_edges_target)\n", "\n", "g.add_edges_from(satellite_edges, origin=\"expert\")\n", "\n", "satellite_edges" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Soil indicators\n", "\n", "Add soil nitrogen, biomass and water content links with the targets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Add v5 Soil and Water content links with the targets\n", "v5_edges = [ (node_i, node_j, {\"weight\": 1.0}) for node_i in g.nodes for node_j in TARGETS if node_i in ['n_deep_v5','biomass_v5','sw_dep_v5']]\n", "\n", "g.add_edges_from(v5_edges, origin=\"expert\")\n", "\n", "v5_edges" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Rebase the graph\n", "\n", "Overwrite the learned weights for the edges in order to maintain consistency." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "g_edges = [(edge[0],edge[1], {\"weight\": 1.0}) for edge in list(g.edges)] \n", "\n", "g = StructureModel()\n", "g.add_edges_from(\n", " g_edges,\n", " origin=\"expert\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Get the largest subgraph of the Structure Model\n", "g = g.get_largest_subgraph()\n", "\n", "# Base class for Bayesian Network (BN), a probabilistic weighted DAG\n", "# Nodes represent variables, \n", "# Edges represent the causal relationships between variables.\n", "bn = BayesianNetwork(g)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "viz = plot_pretty_structure(bn.structure, edges_to_highlight=[])\n", "Image(viz.draw(format='png'))" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Discretise the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "features = list(g.nodes)\n", "\n", "# You can use the Decision Tree Supervised Discretiser with the Corn Yield or the Soy Yield \n", "# Note: Use Corn if the subsequent studies are concering the Corn, and otherwise for Soybeans\n", "\n", "target = 'Y_corn'\n", "\n", "# ====================================================================\n", "# Decision Tree Supervised Discretiser Method\n", "# ====================================================================\n", " \n", "features.remove(target)\n", "\n", "# Discretisation of continuous features based on the split thresholds of a Decision Tree Regressor\n", "discretiser = 
DecisionTreeSupervisedDiscretiserMethod(\n",
  "    mode=\"single\",\n",
  "    tree_params={\"max_depth\": 2, \"random_state\": 2022},\n",
  ")\n",
  "discretiser.fit(\n",
  "    feat_names=features,\n",
  "    dataframe=df,\n",
  "    target_continuous=True,\n",
  "    target=target,\n",
  ")\n",
  "\n",
  "discretised_data = discretiser.transform(df[features])\n",
  "discretised_data.loc[:, target] = df[target].values\n",
  "\n",
  "print(f\"discretiser map thresholds: {discretiser.map_thresholds}\")\n",
  "\n",
  "# Discretisation of target (quantiles-based)\n",
  "discretised_data[target], numeric_split_points_target = quantile_discretiser(discretised_data[target], num_buckets=4)\n",
  "\n",
  "train, test = train_test_split(discretised_data, train_size=0.8, random_state=42)"
 ]
}, {
 "cell_type": "markdown",
 "metadata": { "pycharm": { "name": "#%% md\n" } },
 "source": [ "## Fitting and evaluating the Bayesian Network" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [ "bn = BayesianNetwork(g)" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "bn = bn.fit_node_states(discretised_data)\n",
  "bn = bn.fit_cpds(\n",
  "    train,\n",
  "    method=\"BayesianEstimator\",\n",
  "    bayes_prior=\"K2\",\n",
  ")"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Evaluate on the same target that drove the discretisation above\n",
  "classification_report(bn, test, target)"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# It is recommended to update the model using the complete dataset for the following types of queries\n",
  "bn = bn.fit_cpds(\n",
  "    discretised_data,\n",
  "    method=\"BayesianEstimator\",\n",
  "    bayes_prior=\"K2\",\n",
  ")"
 ]
}, {
 "cell_type": "markdown",
 "metadata": { "pycharm": { "name": "#%% md\n" } },
 "source": [
  "## Save model artifacts\n",
  "\n",
  "Upload model artifacts to Amazon S3. This is where the inference endpoint will collect them later."
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Save the numerical split points: one list of thresholds per feature, plus the target's quantile split points\n",
  "map_thresholds = {str(var): list(thresholds) for var, thresholds in discretiser.map_thresholds.items()}\n",
  "map_thresholds[str(target)] = list(numeric_split_points_target)\n",
  "\n",
  "with open(NUMERICAL_SPLIT_POINTS_PATH, 'w') as fp:\n",
  "    json.dump(map_thresholds, fp)\n",
  "\n",
  "# Save structure\n",
  "nx.write_gml(g, DAG_PATH)\n",
  "\n",
  "# Save model artifact after fitting the cpds\n",
  "bn._model.save(MODEL_PATH, filetype='bif')\n",
  "\n",
  "# Save the node states: identity mapping of every discrete state observed per column\n",
  "node_states_dict = {\n",
  "    c: {int(el): int(el) for el in sorted(discretised_data[c].unique())}\n",
  "    for c in discretised_data.columns\n",
  "}\n",
  "with open(STATES_PATH, 'w') as fp:\n",
  "    json.dump(node_states_dict, fp)"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "import tarfile\n",
  "\n",
  "# The context manager closes (and flushes) the archive even if adding a file fails\n",
  "with tarfile.open(\"model.tar.gz\", \"w:gz\") as tar:\n",
  "    for file in [DAG_PATH, MODEL_PATH, STATES_PATH, NUMERICAL_SPLIT_POINTS_PATH]:\n",
  "        tar.add(file)"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [ "!aws s3 cp model.tar.gz s3://{AWS_S3_BUCKET}/models/" ]
}, {
 "cell_type": "markdown",
 "metadata": { "pycharm": { "name": "#%% md\n" } },
 "source": [ "## SageMaker asynchronous inference" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": 
[], "source": [ "instance_type = \"ml.m5.2xlarge\"\n", "\n", "model_artifact = f\"s3://{AWS_S3_BUCKET}/models/model.tar.gz\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# We use a PyTorch inference DLC image that ships with sagemaker-pytorch-inference-toolkit \n", "image_uri = sagemaker.image_uris.retrieve(\n", " framework=\"pytorch\",\n", " region=region,\n", " py_version=\"py38\",\n", " image_scope=\"inference\",\n", " version=\"1.10\",\n", " instance_type=instance_type,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#!pygmentize src-inference/inference.py" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Create a SageMaker model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# SAGEMAKER_TS_BATCH_SIZE (int): This is the maximum batch size in ms that a model is expected to handle\n", "# SAGEMAKER_TS_STARTUP_TIMEOUT (int): Time delay after which inference will timeout if model initialization fails\n", "# SAGEMAKER_TS_RESPONSE_TIMEOUT (int): Time delay after which inference will timeout in absence of a response\n", "\n", "env_variables_dict = {\n", " \"SAGEMAKER_TS_BATCH_SIZE\": \"10000000\",\n", " \"SAGEMAKER_TS_STARTUP_TIMEOUT\": \"1200\",\n", " \"SAGEMAKER_TS_RESPONSE_TIMEOUT\": \"600\",\n", " 'TS_MAX_REQUEST_SIZE': '655350000',\n", " 'TS_MAX_RESPONSE_SIZE': '655350000',\n", " 'TS_DEFAULT_RESPONSE_TIMEOUT': '2000',\n", " \n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from sagemaker.model import Model\n", "from 
sagemaker.predictor import Predictor\n", "\n", "model_name = f\"{SOLUTION_PREFIX}-bn-model\"\n", "\n", "model_predictor = Model(\n", " name=model_name,\n", " image_uri=image_uri,\n", " model_data=model_artifact,\n", " role=sm_role,\n", " source_dir=\"src-inference\",\n", " entry_point=\"inference.py\",\n", " predictor_cls=Predictor,\n", " env=env_variables_dict,\n", ")\n", "model_name" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Create AsyncInferenceConfig" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig\n", "\n", "async_config = AsyncInferenceConfig(\n", " output_path=f\"s3://{AWS_S3_BUCKET}/models/output\",\n", " max_concurrent_invocations_per_instance=4,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Create endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import time\n", "\n", "ENDPOINT_NAME = f\"{SOLUTION_PREFIX}-bn-endpoint\"\n", "\n", "async_predictor = model_predictor.deploy(\n", " async_inference_config=async_config,\n", " instance_type=instance_type,\n", " initial_instance_count=1,\n", " endpoint_name=ENDPOINT_NAME,\n", " serializer=sagemaker.serializers.JSONSerializer(),\n", " deserializer=sagemaker.deserializers.JSONDeserializer(),\n", ")\n", "\n", "# Waiting for the inference engine to be initialized\n", "time.sleep(90)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Observational and counterfactuals inference" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Querying marginal distributions of the target node 
(yield) given some observations" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Sample distinct cell_id / id_field(s)\n",
  "query_node = 'N_fert'\n",
  "yield_target = 'Y_corn'\n",
  "samples_number = 4\n",
  "requests = []\n",
  "\n",
  "sample_features = list(g.nodes)\n",
  "\n",
  "df_query = df_full[sample_features + ['id_10','id_field','FIPS']]\n",
  "\n",
  "# Draw all rows at once (without replacement) so that row i of df_samples always\n",
  "# matches request i and, later, marginals column i returned by the endpoint.\n",
  "df_samples = df_query.sample(samples_number)\n",
  "\n",
  "for row_idx in range(len(df_samples)):\n",
  "    sample = df_samples.iloc[[row_idx]]\n",
  "\n",
  "    # Add all observations\n",
  "    request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]\n",
  "\n",
  "    # Discretise the request\n",
  "    request = discretiser_inverse_transform(map_thresholds,\n",
  "                                            request=True,\n",
  "                                            request_nodes=request_nodes,\n",
  "                                            response_nodes=[])\n",
  "\n",
  "    request = dict(request)\n",
  "\n",
  "    # Remove the target node from the request\n",
  "    request.pop(yield_target)\n",
  "\n",
  "    requests.append(request)"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Prepare the payload\n",
  "payload = {\n",
  "    \"method\": \"query\",\n",
  "    \"observations\": requests,\n",
  "    \"target\": yield_target\n",
  "}"
 ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
 "outputs": [],
 "source": [
  "# Dump the payload into a local JSON file\n",
  "with open(\"tmp/request_payload_query.json\", 'w') as fp:\n",
  "    json.dump(payload, fp)"
 ]
}, {
 "cell_type": "markdown",
 "metadata": { "pycharm": { "name": "#%% md\n" } },
 "source": [ "#### Upload the request payload" ]
}, {
 "cell_type": "code",
 "execution_count": null,
 "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python"
} }, "outputs": [], "source": [ "def upload_file(input_location):\n", " prefix = f\"{AWS_S3_BUCKET}/inference/input\"\n", " return sm_session.upload_data(\n", " input_location,\n", " bucket=sm_session.default_bucket(),\n", " key_prefix=prefix,\n", " extra_args={\"ContentType\": \"application/json\"},\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# upload request to S3\n", "input_s3_location = upload_file(\"tmp/request_payload_query.json\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Invoke endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Invoke endpoint\n", "response_endpoint = runtime.invoke_endpoint_async(\n", " EndpointName=ENDPOINT_NAME, \n", " InputLocation=input_s3_location,\n", ")\n", "\n", "output_location =response_endpoint['OutputLocation']" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Get inference outputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_output(output_location):\n", " output_url = urllib.parse.urlparse(output_location)\n", " bucket = output_url.netloc\n", " key = output_url.path[1:]\n", " while True:\n", " try:\n", " return sm_session.read_s3_file(bucket=output_url.netloc, key_prefix=output_url.path[1:])\n", " except ClientError as e:\n", " if e.response[\"Error\"][\"Code\"] == \"NoSuchKey\":\n", " print(\"waiting for the inference query\")\n", " time.sleep(20)\n", " continue\n", " raise" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], 
"source": [
   "# get inference outputs\n",
   "output = json.loads(get_output(output_location))\n",
   "print(f\"\\n Output: {output}\")"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Format output by converting the marginal probabilities into buckets\n",
   "resp, _, _ = format_inference_output(output)\n",
   "\n",
   "# Convert buckets into real number ranges\n",
   "resp_transformed = discretiser_inverse_transform(map_thresholds,\n",
   "                                                 request=False,\n",
   "                                                 request_nodes=[],\n",
   "                                                 response_nodes=resp)\n",
   "\n",
   "# Collect the marginals from the response: one column per sample\n",
   "# (assumes the endpoint returns results in the same order as the requests)\n",
   "marginals = pd.concat(\n",
   "    [pd.DataFrame.from_dict(out['marginals'], orient='index', columns=[f'marginals_{idx}'])\n",
   "     for idx, out in enumerate(output)],\n",
   "    axis=1,\n",
   ")\n",
   "\n",
   "# Yield value used as the x position of each row: the observed minimum for the first\n",
   "# row, then the discretiser split points.\n",
   "# Note: if target is changed add the corresponding numeric_split_points_target (from the discretiser)\n",
   "yield_index = [df_full[yield_target].min(), *map_thresholds[yield_target]]\n",
   "assert len(yield_index) == len(marginals), \"yield buckets and marginals rows differ in number\"\n",
   "# Set the index directly instead of `marginals['yield'].loc[1:] = ...`, which relied on\n",
   "# chained assignment and on deprecated positional slicing through .loc.\n",
   "marginals.index = pd.Index(yield_index, name='yield')"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [ "#### Plot marginals for the yield node" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "def plot_marginals(marginals, df_samples, resp_transformed, yield_target):\n",
   "    \"\"\"Plot, for every sample, the marginal distribution of the target node.\n",
   "\n",
   "    Args:\n",
   "        marginals: DataFrame indexed by yield value with one probability column per\n",
   "            sample; column i is assumed to belong to row i of df_samples.\n",
   "        df_samples: the sampled rows; must contain 'FIPS', 'id_10' and yield_target.\n",
   "        resp_transformed: not used by the plot; kept so existing calls keep working.\n",
   "        yield_target: name of the target node / column.\n",
   "    \"\"\"\n",
   "    fig, ax = plt.subplots(figsize=(15, 5), dpi=120)\n",
   "\n",
   "    for idx, col in enumerate(marginals):\n",
   "        line, = ax.plot(marginals.index, marginals[col], 'o--',\n",
   "                        label=f\"FIPS:{df_samples['FIPS'].iloc[idx]} - CELL ID: {df_samples['id_10'].iloc[idx]}\")\n",
   "        # Vertical line at the observed yield, in the same colour as its distribution\n",
   "        ax.axvline(df_samples[yield_target].iloc[idx], color=line.get_color())\n",
   "        ax.fill_between(marginals.index, marginals[col], alpha=0.1)\n",
   "\n",
   "    ax.legend()\n",
   "    ax.set_title(f\"Marginal distributions of {yield_target} target node given the observations\")\n",
   "    ax.set_xlabel('Yield (kg/ha) | vertical lines represent the Yield actual values')\n",
   "    ax.set_ylabel('Probability')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [ "plot_marginals(marginals, df_samples, resp_transformed, yield_target)" ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "#### Visualize the geolocation for the selected cell IDs" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Plot the sampled cell geo coordinates\n",
   "ax = gpd_cells[gpd_cells.region == CROP_REGION].plot(cmap='Pastel2', figsize=(15,7))\n",
   "gpd_cells[gpd_cells.id_10.isin(df_samples['id_10'].unique())].plot(ax=ax, facecolor='none', edgecolor='red')"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "### Making interventions (Do-calculus)" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Sample one cell_id / id_field\n",
   "features = list(g.nodes)\n",
   "\n",
   "action_node = 'N_fert'\n",
   "yield_target = 'Y_corn'\n",
   "sample_features = [action_node]\n",
   "\n",
   "# Select query nodes: features whose names contain 'tmean', 'rad' or 'rain',\n",
   "# plus the satellite features\n",
   "sample_features.extend([feat for feat in features if 'tmean' in feat or 'rad' in feat or 'rain' in feat])\n",
   "sample_features.extend(satellite_features)\n",
   "\n",
   "# Pick a sample (own name, so the list of query samples above is not shadowed)\n",
   "df_candidates = df_full[sample_features + ['id_10','FIPS']]\n",
   "\n",
   "sample = df_candidates.sample(1)\n",
   "\n",
   "# Add all observations\n",
   "request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]\n",
   "\n",
   "# Discretise the request\n",
   "request = 
discretiser_inverse_transform(map_thresholds,\n",
   "                                        request=True,\n",
   "                                        request_nodes=request_nodes,\n",
   "                                        response_nodes=[])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Map thresholds action node\n",
   "map_thresholds[action_node]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Current value\n",
   "print(f\"Current value: {sample[action_node].values[0]} kg/ha\")"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "* NOTE: select a value which differs significantly from the current value (ideally one belonging to a different bucket), in order to observe the effect of the intervention." ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Discretise the value chosen for the intervention\n",
   "value = 30  # ADD VALUE HERE (eg. X kg/ha Nitrogen)\n",
   "action_node_value = (action_node, value)\n",
   "action_node_bucket = discretiser_inverse_transform(map_thresholds,\n",
   "                                                   request=True,\n",
   "                                                   request_nodes=[action_node_value],\n",
   "                                                   response_nodes=[])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Remove the node we intervene on. Build a new dict instead of rebinding `request`,\n",
   "# so re-running this cell does not fail with a KeyError on the already-removed node.\n",
   "intervention_query = dict(request)\n",
   "action_node_before = (action_node, intervention_query.pop(action_node))\n",
   "action_node_after = action_node_bucket[0]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Prepare payload: observations without the action node, plus the intervened bucket\n",
   "payload = {\n",
   "    \"method\": \"do_calculus\",\n",
   "    \"intervention_query\": intervention_query,\n",
   "    \"interventions\": [action_node_after],\n",
   "    \"target\": yield_target\n",
   "}"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Dump the payload into a local JSON file (create tmp/ if it does not exist yet)\n",
   "os.makedirs(\"tmp\", exist_ok=True)\n",
   "with open(\"tmp/request_payload_intervention.json\", 'w') as fp:\n",
   "    json.dump(payload, fp)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "#### Uploading the Request Payload" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# upload_file is defined once, in the query section above\n",
   "# Upload 
request to S3\n",
   "input_s3_location = upload_file(\"tmp/request_payload_intervention.json\")"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "#### Invoke endpoint" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Invoke the asynchronous endpoint; the result appears later at OutputLocation\n",
   "response_endpoint = runtime.invoke_endpoint_async(\n",
   "    EndpointName=ENDPOINT_NAME,\n",
   "    InputLocation=input_s3_location,\n",
   ")\n",
   "\n",
   "output_location = response_endpoint['OutputLocation']"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "#### Get inference outputs" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "def get_output(output_location, poll_interval=20, timeout=3600):\n",
   "    \"\"\"Wait until the async do-calculus result exists in S3, then return it.\n",
   "\n",
   "    Args:\n",
   "        output_location: S3 URI (s3://bucket/key) taken from the endpoint\n",
   "            response's 'OutputLocation'.\n",
   "        poll_interval: seconds to sleep between polls.\n",
   "        timeout: maximum number of seconds to wait; None waits indefinitely.\n",
   "\n",
   "    Returns:\n",
   "        The object contents, as returned by sm_session.read_s3_file.\n",
   "\n",
   "    Raises:\n",
   "        TimeoutError: if the object still does not exist after `timeout` seconds.\n",
   "        ClientError: for any S3 error other than NoSuchKey.\n",
   "    \"\"\"\n",
   "    output_url = urllib.parse.urlparse(output_location)\n",
   "    bucket = output_url.netloc\n",
   "    key = output_url.path[1:]  # strip the leading '/'\n",
   "    deadline = None if timeout is None else time.monotonic() + timeout\n",
   "    while True:\n",
   "        try:\n",
   "            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)\n",
   "        except ClientError as e:\n",
   "            # NoSuchKey only means the result has not been written yet\n",
   "            if e.response[\"Error\"][\"Code\"] != \"NoSuchKey\":\n",
   "                raise\n",
   "        # Without a deadline a failed request (which never writes OutputLocation) would hang the notebook\n",
   "        if deadline is not None and time.monotonic() >= deadline:\n",
   "            raise TimeoutError(\n",
   "                f\"No inference output at {output_location} after {timeout} s; \"\n",
   "                \"the request may still be queued or may have failed.\"\n",
   "            )\n",
   "        print(\"waiting for the inference do-calculus\")\n",
   "        time.sleep(poll_interval)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Get inference outputs\n",
   "output = json.loads(get_output(output_location))\n",
   "print(f\"\\n Output: {output}\")"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": { "pycharm": { "name": "#%% md\n" } },
  "source": [ "### Plot counterfactuals" ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Collect marginals (before and after) into a pandas 
frame\n",
   "df_marginals_before = pd.DataFrame.from_dict(output['marginals-before'], orient='index', columns=['before'])\n",
   "df_marginals_after = pd.DataFrame.from_dict(output['marginals-after'], orient='index', columns=['after'])\n",
   "\n",
   "counterfactuals = pd.concat([df_marginals_before, df_marginals_after], axis=1)\n",
   "\n",
   "# Yield value used as the x position of each row: 0 for the first row, then the\n",
   "# discretiser split points.\n",
   "# Note: if target is changed add the corresponding numeric_split_points_target\n",
   "yield_index = [0, *map_thresholds[yield_target]]\n",
   "assert len(yield_index) == len(counterfactuals), \"yield buckets and marginals rows differ in number\"\n",
   "# Set the index directly instead of `counterfactuals['yield'].loc[1:] = ...`, which relied\n",
   "# on chained assignment and on deprecated positional slicing through .loc.\n",
   "counterfactuals.index = pd.Index(yield_index, name='yield')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "def plot_counterfactuals(cf, sample, yield_target, action_node_before, action_node_after):\n",
   "    \"\"\"Overlay the target-node distribution before and after the intervention.\n",
   "\n",
   "    Args:\n",
   "        cf: DataFrame indexed by yield value with 'before' and 'after' probability columns.\n",
   "        sample: one-row DataFrame with 'FIPS' and 'id_10' columns (used in the title).\n",
   "        yield_target: name of the target node.\n",
   "        action_node_before, action_node_after: intervened-node values before/after the\n",
   "            intervention; element [0] of each is shown in the legend.\n",
   "    \"\"\"\n",
   "    fig, ax = plt.subplots(figsize=(12, 5), dpi=120)\n",
   "\n",
   "    for column, node_value in (('before', action_node_before), ('after', action_node_after)):\n",
   "        ax.plot(cf.index, cf[column], 'o--', label=f\"Nitrogen (kg/ha): {node_value[0]}\")\n",
   "        ax.fill_between(cf.index, cf[column], alpha=0.1)\n",
   "\n",
   "    # Dashed grey line at every yield value of the index\n",
   "    for x in cf.index:\n",
   "        ax.axvline(x=x, color='gray', linestyle=\"--\")\n",
   "\n",
   "    ax.legend()\n",
   "    ax.set_title(f\"-- FIPS:{sample['FIPS'].values[0]} - CELL ID: {sample['id_10'].values[0]} -- \")\n",
   "    fig.suptitle(f\"Distribution of {yield_target} Yield given Nitrogen added as fertilizer\")\n",
   "    ax.set_xlabel('Yield (kg/ha) | vertical lines represent the Yield discretisation')\n",
   "    ax.set_ylabel('Probability')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } },
  "outputs": [],
  "source": [
   "# Convert the action-node bucket before the intervention back into a real-valued range\n",
   "action_node_before_real = discretiser_inverse_transform(map_thresholds,\n",
   "                                                        request=False,\n",
   "                                                        request_nodes=[],\n",
   "                                                        response_nodes=[action_node_before])\n",
   "\n", 
"# Same conversion for the bucket set by the intervention\n", "action_node_after_real = discretiser_inverse_transform(map_thresholds,\n", " request=False,\n", " request_nodes=[],\n", " response_nodes=[action_node_after])\n", "\n", "# Overlay the yield distribution before and after the intervention\n", "plot_counterfactuals(counterfactuals, sample, yield_target, action_node_before_real, action_node_after_real)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Clean Up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Delete the SageMaker endpoint\n", "# Run this only when finished: the inference cells above invoke this endpoint\n", "async_predictor.delete_endpoint()" ] } ], "metadata": { "instance_type": "ml.m5.large", "kernelspec": { "display_name": "Python 3 (PyTorch 1.10 Python 3.8 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.10-cpu-py38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }