{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Demo -- Causal Inference Engine" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "**Jupyter Kernel**:\n", "\n", "\n", "* If you are in SageMaker Studio, make sure that you use the **PyTorch 1.10 Python 3.8 CPU Optimized** environment.\n", "* Make sure that you are using one of the following instance types: `ml.m5.large`, `ml.c5.large`, or `ml.g4dn.xlarge`.\n", "\n", "**Run All**: \n", "\n", "* If you are in SageMaker Studio, you can choose **Run All Cells** from the **Run** menu to run the entire notebook at once." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Install dependencies for this notebook.\n", "!pip3 install -r ./utils/requirements.in -q" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "This solution relies on a config file that describes the provisioned AWS resources. Run the following cells to generate that file."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "import boto3\n",
 "import os\n",
 "import json"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# The provisioned product name is the path component that follows 'S3Downloads'\n",
 "# in the current working directory.\n",
 "client = boto3.client('servicecatalog')\n",
 "cwd = os.getcwd().split('/')\n",
 "i = cwd.index('S3Downloads')\n",
 "pp_name = cwd[i + 1]\n",
 "pp = client.describe_provisioned_product(Name=pp_name)\n",
 "record_id = pp['ProvisionedProductDetail']['LastSuccessfulProvisioningRecordId']\n",
 "record = client.describe_record(Id=record_id)\n",
 "\n",
 "# Keep only the record outputs that carry both a key and a value.\n",
 "# Each membership test is spelled out: `'OutputKey' and 'OutputValue' in x`\n",
 "# only checks for 'OutputValue', because a non-empty string is always truthy.\n",
 "stack_output = {\n",
 "    x['OutputKey']: x['OutputValue']\n",
 "    for x in record['RecordOutputs']\n",
 "    if 'OutputKey' in x and 'OutputValue' in x\n",
 "}\n",
 "\n",
 "with open(f'/root/S3Downloads/{pp_name}/stack_outputs.json', 'w') as f:\n",
 "    json.dump(stack_output, f)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Read the stack outputs; `with` closes the file handle once the JSON is parsed.\n",
 "with open(\"stack_outputs.json\") as config_file:\n",
 "    sagemaker_config = json.load(config_file)\n",
 "\n",
 "SOLUTION_BUCKET = sagemaker_config[\"SolutionS3Bucket\"]\n",
 "AWS_REGION = sagemaker_config[\"AWSRegion\"]\n",
 "SOLUTION_NAME = sagemaker_config[\"SolutionName\"]\n",
 "AWS_S3_BUCKET = sagemaker_config[\"S3Bucket\"]\n",
 "LIBRARY_VERSION = sagemaker_config[\"LibraryVersion\"]\n",
 "ENDPOINT_NAME = sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\"\n",
 "\n",
 "KEY_YIELD_CURVE = \"data/raw/yield_curve_field_dt.csv\"\n",
 "SPATIAL_FILES_KEY = \"data/spatial-files\"\n",
 "FIPS_STATS_KEY = \"data/fips-stats/fips_county_stats.csv\"\n",
 "FIPS_POLYGONS_KEY = \"data/fips-stats/geojson-counties-fips.json\"\n",
 "SENTINEL_2_SHAPEFILE_KEY = \"data/sentinel-2-shapefiles\"\n",
 "CROPS_MASK_KEY = \"data/crop_mask/raw\"\n",
 "REQUEST_MANIFESTS_KEY = \"request_manifests/\"\n",
 "\n",
 "DAG_PATH = 'model/models/bn_structure.gml'\n",
 "MODEL_PATH = 'model/models/bayesian_model.bif'\n",
 "STATES_PATH = 'model/models/node_states.json'\n",
 "NUMERICAL_SPLIT_POINTS_PATH = \"model/models/numerical_split_points.json\"\n",
 "\n",
 "# Local directory the model archive is extracted into later in the notebook.\n",
 "# exist_ok=True makes the cell safe to re-run.\n",
 "os.makedirs('model', exist_ok=True)"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "### Copy simulated data to S3"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "This solution uses both geospatial data and ground-level observations. We use ground-level observations from a publicly available [simulated dataset](https://data.mendeley.com/datasets/xs5nbm4w55/) of corn response to Nitrogen over thousands of fields and multiple years in Illinois.\n",
 "\n",
 "For ease of access, we made the datasets available in an Amazon S3 bucket. Download the dataset from S3 in the following cells. "
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "from sagemaker.s3 import S3Downloader\n",
 "\n",
 "original_bucket = f\"s3://{SOLUTION_BUCKET}-{AWS_REGION}/{LIBRARY_VERSION}/{SOLUTION_NAME}\"\n",
 "original_data = f\"{original_bucket}/artifacts/data/\"\n",
 "current_location = f\"s3://{AWS_S3_BUCKET}/data/\"\n",
 "print(\"original data:\")\n",
 "S3Downloader.list(original_data)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Copy the dataset only when the destination prefix is still empty\n",
 "if not S3Downloader.list(current_location):\n",
 "    !aws s3 cp --recursive $original_data $current_location"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "### Set up the environment"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "import pandas as pd\n",
 "import numpy\n",
 "import json\n",
 "import datetime\n",
 "import matplotlib.pyplot as plt\n",
 "import boto3\n",
 "import io\n",
 "import os\n",
 "import s3fs\n",
 "import itertools as it\n",
 "import networkx as nx\n",
 "from time import time\n",
 "import geopandas as gpd\n",
 "import copy\n",
 "import bisect\n",
 "from typing import Dict\n",
 "import warnings\n",
 "import base64\n",
 "from PIL import Image\n",
 "import datetime\n",
 "from time import gmtime, strftime\n",
 "import urllib\n",
 "import time\n",
 "\n",
 "import sagemaker\n",
 "import boto3\n",
 "from botocore.exceptions import ClientError\n",
 "\n",
 "# from utils.plot_functions imports visualize_structure\n",
 "from utils.causalnex_helpers import (\n", " 
discretiser_inverse_transform,\n",
 "    format_inference_output\n",
 ")\n",
 "\n",
 "from utils.helper_functions import download_s3_folder\n",
 "\n",
 "warnings.simplefilter('ignore')\n",
 "\n",
 "%matplotlib inline"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Notebook-wide constants\n",
 "EPSG = 'epsg:4326'  # coordinate reference system the spatial files are converted to\n",
 "TARGETS = [\"Y_corn\"]"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# AWS handles: boto3 session and Region name, S3 resource, SageMaker runtime client\n",
 "boto_session = boto3.session.Session()\n",
 "region = boto_session.region_name\n",
 "s3 = boto3.resource('s3')\n",
 "runtime = boto_session.client('sagemaker-runtime')\n",
 "\n",
 "# SageMaker session and execution role\n",
 "sm_session = sagemaker.session.Session()\n",
 "sm_role = sagemaker.get_execution_role()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "Download spatial files locally."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "download_s3_folder(AWS_S3_BUCKET, SPATIAL_FILES_KEY, \"tmp/spatial-files\")\n",
 "download_s3_folder(AWS_S3_BUCKET, SENTINEL_2_SHAPEFILE_KEY, \"tmp/Sentinel-2-Shapefile-Index\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "### Read the dataset and crop staging mapping file\n",
 "\n",
 "**Note:** Load the files produced in notebook 2, `01 Feature Engineering.ipynb`."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# `original_data` already starts with \"s3://\" and ends with \"/\", so it is used as-is;\n",
 "# prefixing it again produced the malformed URI \"s3://s3://...\".\n",
 "\n",
 "# read enhanced dataset\n",
 "df_full = pd.read_csv(\n",
 "    f\"{original_data}enhanced/\"\n",
 "    \"enhanced_dataset_filtered_2018_2_Central.csv\",\n",
 ")\n",
 "\n",
 "# read crop staging mapping file\n",
 "df_mapping = pd.read_csv(\n",
 "    f\"{original_data}enhanced/\"\n",
 "    \"stage_mapping_filtered_2018_2_Central.csv\",\n",
 ")\n",
 "\n",
 "# read spatial files\n",
 "gpd_cells = gpd.read_file(\"tmp/spatial-files/cells_sf.shp\")\n",
 "gpd_cells = gpd_cells.to_crs(EPSG)\n",
 "\n",
 "# for the DAG setup remove the identifiers\n",
 "df = df_full.drop(columns=['FIPS','id_field','id_10','LAI_max','n_uptake'])\n",
 "df_mapping = df_mapping[df_mapping.variable.isin(df.columns)]"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "model_artifact = f\"{original_bucket}/artifacts/models/model.tar.gz\""
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "You can view the inference script by uncommenting the line in the following cell:"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "#!pygmentize src-inference/inference.py"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "## Observational and counterfactual inference"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Copy model artifacts locally\n",
 "!aws s3 cp {model_artifact} ./\n",
 "!tar -C ./model -zxvf model.tar.gz"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Read the numerical split points\n",
 "with open(NUMERICAL_SPLIT_POINTS_PATH, 'r') as fp:\n",
 "    map_thresholds = json.load(fp)\n",
 "\n",
 "# Load the DAG structure\n",
 "g = nx.read_gml(DAG_PATH)"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "### Querying marginal distributions of the target node (yield) given some observations"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Prepare the request payload"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Sample cell_id / id_field(s)\n",
 "query_node = 'N_fert'\n",
 "yield_target = 'Y_corn'\n",
 "samples_number = 2\n",
 "\n",
 "sample_features = list(g.nodes)\n",
 "\n",
 "df_query = df_full[sample_features + ['id_10','id_field','FIPS']]\n",
 "\n",
 "# Draw every record in a single call so that no row is picked twice. Row i of\n",
 "# `df_samples` then always matches `requests[i]`; the plotting code further down\n",
 "# looks the sampled records up by position, so the two must stay the same length.\n",
 "df_samples = df_query.sample(n=samples_number)\n",
 "\n",
 "requests = []\n",
 "samples = []\n",
 "\n",
 "for i in range(samples_number):\n",
 "\n",
 "    # One-row frame, so the column dtypes are preserved\n",
 "    sample = df_samples.iloc[[i]]\n",
 "    samples.append(sample)\n",
 "\n",
 "    # Add all observations\n",
 "    request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]\n",
 "\n",
 "    # Discretise the request\n",
 "    request = discretiser_inverse_transform(\n",
 "        map_thresholds,\n",
 "        request=True,\n",
 "        request_nodes=request_nodes,\n",
 "        response_nodes=[],\n",
 "    )\n",
 "\n",
 "    request = dict(request)\n",
 "\n",
 "    # Remove target node from the request\n",
 "    request.pop(yield_target)\n",
 "\n",
 "    requests.append(request)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Prepare the payload\n",
 "payload = {\n",
 "    \"method\": \"query\",\n",
 "    \"observations\": requests,\n",
 "    \"target\": yield_target\n",
 "}"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Dump the payload into a local JSON file\n",
 "with open(\"tmp/request_payload_query.json\", 'w') as fp:\n",
 "    json.dump(payload, fp)"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Upload the request payload"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def upload_file(input_location):\n",
 "    \"\"\"Upload a local request payload with the SageMaker session.\n",
 "\n",
 "    The file goes to the session's default bucket under the key prefix\n",
 "    ``<AWS_S3_BUCKET>/inference/input`` with a JSON content type.\n",
 "\n",
 "    Args:\n",
 "        input_location: path of the local file to upload.\n",
 "\n",
 "    Returns:\n",
 "        The value returned by ``sm_session.upload_data``; it is passed to the\n",
 "        endpoint as ``InputLocation``.\n",
 "    \"\"\"\n",
 "    prefix = f\"{AWS_S3_BUCKET}/inference/input\"\n",
 "    return sm_session.upload_data(\n",
 "        input_location,\n",
 "        bucket=sm_session.default_bucket(),\n",
 "        key_prefix=prefix,\n",
 "        extra_args={\"ContentType\": \"application/json\"},\n",
 "    )"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Upload request to S3\n",
 "input_s3_location = upload_file(\"tmp/request_payload_query.json\")"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Invoke endpoint"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Invoke endpoint\n",
 "response_endpoint = runtime.invoke_endpoint_async(\n",
 "    EndpointName=ENDPOINT_NAME,\n",
 "    InputLocation=input_s3_location,\n",
 ")\n",
 "\n",
 "output_location = response_endpoint['OutputLocation']"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Get inference outputs"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def get_output(output_location, poll_seconds=20, timeout_seconds=None):\n",
 "    \"\"\"Poll S3 until the asynchronous inference result exists, then return it.\n",
 "\n",
 "    Args:\n",
 "        output_location: S3 URI of the result (the ``OutputLocation`` value\n",
 "            returned by ``invoke_endpoint_async``).\n",
 "        poll_seconds: pause between two attempts while the object is missing.\n",
 "        timeout_seconds: give up after this many seconds; ``None`` (default)\n",
 "            keeps waiting indefinitely.\n",
 "\n",
 "    Returns:\n",
 "        The value returned by ``sm_session.read_s3_file`` for the result object.\n",
 "\n",
 "    Raises:\n",
 "        TimeoutError: the object did not appear within ``timeout_seconds``.\n",
 "        ClientError: any S3 error other than ``NoSuchKey``.\n",
 "    \"\"\"\n",
 "    output_url = urllib.parse.urlparse(output_location)\n",
 "    bucket = output_url.netloc\n",
 "    key = output_url.path[1:]  # drop the leading '/'\n",
 "    started = time.monotonic()\n",
 "    while True:\n",
 "        try:\n",
 "            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)\n",
 "        except ClientError as e:\n",
 "            # Only a missing object means the result is still pending\n",
 "            if e.response[\"Error\"][\"Code\"] != \"NoSuchKey\":\n",
 "                raise\n",
 "        if timeout_seconds is not None and time.monotonic() - started >= timeout_seconds:\n",
 "            raise TimeoutError(f\"no inference output at {output_location} after {timeout_seconds} seconds\")\n",
 "        print(\"waiting for the inference query\")\n",
 "        time.sleep(poll_seconds)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Get inference outputs\n",
 "output = json.loads(get_output(output_location))\n",
 "print(f\"\\n Output: {output}\")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Format output by converting the marginals probabilities into buckets\n",
 "resp, _, _ = format_inference_output(output)\n",
 "\n",
 "# Convert buckets into real number ranges\n",
 "resp_transformed = discretiser_inverse_transform(\n",
 "    map_thresholds,\n",
 "    request=False,\n",
 "    request_nodes=[],\n",
 "    response_nodes=resp,\n",
 ")\n",
 "\n",
 "# Collect marginals from 
the response\n",
 "marginal_frames = []\n",
 "\n",
 "for idx, out in enumerate(output):\n",
 "    marginals_df = pd.DataFrame.from_dict(\n",
 "        out['marginals'], orient='index', columns=[f'marginals_{idx}'])\n",
 "    marginal_frames.append(marginals_df)\n",
 "\n",
 "marginals = pd.concat(marginal_frames, axis=1)\n",
 "\n",
 "# Note: if target is changed add the corresponding numeric_split_points_target (from the discretiser)\n",
 "# The first row gets the observed minimum of the target, the remaining rows get the split points.\n",
 "# The column is assigned in one statement: chained assignment (`df[col].loc[...] = ...`)\n",
 "# is not guaranteed to write back to the frame.\n",
 "marginals['yield'] = [df_full[yield_target].min()] + list(map_thresholds[yield_target])\n",
 "marginals = marginals.set_index('yield')"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "#### Plot marginals for the yield node"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def plot_marginals(marginals, df_samples, resp_transformed, yield_target):\n",
 "    \"\"\"Plot the marginal distribution of the target node for every sampled record.\n",
 "\n",
 "    Args:\n",
 "        marginals: frame whose index holds the yield values and whose columns\n",
 "            hold one probability series per sample.\n",
 "        df_samples: sampled records; row ``idx`` (by position) is matched with\n",
 "            column ``idx`` of ``marginals`` and must provide 'FIPS', 'id_10'\n",
 "            and the ``yield_target`` column.\n",
 "        resp_transformed: not used by this function.\n",
 "        yield_target: name of the target column in ``df_samples``.\n",
 "    \"\"\"\n",
 "    plt.figure(figsize=(15, 5), dpi=120)\n",
 "\n",
 "    for idx, col in enumerate(marginals):\n",
 "        (line,) = plt.plot(\n",
 "            marginals.index, marginals[col], 'o--',\n",
 "            label=f\"FIPS:{df_samples['FIPS'].iloc[idx]} - CELL ID: {df_samples['id_10'].iloc[idx]}\",\n",
 "        )\n",
 "        # Observed value of the target, drawn in the same colour as its curve\n",
 "        plt.axvline(df_samples[yield_target].iloc[idx], color=line.get_color())\n",
 "        plt.fill_between(marginals.index, marginals[col], alpha=0.1)\n",
 "\n",
 "    plt.legend()\n",
 "    plt.title(f\"Marginal distributions of {yield_target} target node given the observations\")\n",
 "    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield actual values')\n",
 "    plt.ylabel('Probability')"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "plot_marginals(marginals, df_samples, resp_transformed, yield_target)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "#### Visualize the geolocation for the selected cell IDs"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Plot the sampled cells geo coordinates\n",
 "ax = gpd_cells[gpd_cells.region == '2-Central'].plot(cmap='Pastel2', figsize=(15,7))\n",
 "gpd_cells[gpd_cells.id_10.isin(df_samples['id_10'].unique())].plot(ax=ax, facecolor='none', edgecolor='red')"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "### Making interventions (Do-calculus)"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Prepare the Request Payload"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Sample one cell_id / id_field\n",
 "features = list(g.nodes)\n",
 "\n",
 "action_node = 'N_fert'\n",
 "yield_target = 'Y_corn'\n",
 "sample_features = [action_node]\n",
 "\n",
 "# Select query nodes\n",
 "\n",
 "satellite_features = [feat for feat in features if feat.startswith(\"mean_\")]\n",
 "\n",
 "sample_features.extend([feat for feat in features if 'tmean' in feat or 'rad' in feat or 'rain' in feat])\n",
 "sample_features.extend(satellite_features)\n",
 "\n",
 "# A feature name can match more than one of the filters above; keep each name once\n",
 "# (in first-seen order) so that the selection below has no duplicated columns.\n",
 "sample_features = list(dict.fromkeys(sample_features))\n",
 "\n",
 "# Pick a sample\n",
 "samples = df_full[sample_features + ['id_10','FIPS']]\n",
 "\n",
 "sample = samples.sample(1)\n",
 "\n",
 "# Add all observations\n",
 "request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]\n",
 "\n",
 "# Discretise the request\n",
 "request = discretiser_inverse_transform(\n",
 "    map_thresholds,\n",
 "    request=True,\n",
 "    request_nodes=request_nodes,\n",
 "    response_nodes=[],\n",
 ")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Current value\n",
 "print(f\"Current value: {sample[action_node].values[0]} kg/ha\")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Map thresholds action node\n",
 "map_thresholds[action_node]"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
 "> NOTE: select a value that differs significantly from the current value (ideally belonging to a different bucket), in order to observe the effect of the intervention."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Discretise\n",
 "value = 20 # ADD VALUE HERE (eg. 20 kg/ha Nitrogen)\n",
 "action_node_value = (action_node, value)\n",
 "action_node_bucket = discretiser_inverse_transform(\n",
 "    map_thresholds,\n",
 "    request=True,\n",
 "    request_nodes=[action_node_value],\n",
 "    response_nodes=[],\n",
 ")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Remove the node we intervene on.\n",
 "# The pop is done on a copy so that this cell can be re-run: popping from\n",
 "# `request` itself raises KeyError the second time.\n",
 "intervention_query = dict(request)\n",
 "action_node_before = (action_node, intervention_query.pop(action_node))\n",
 "action_node_after = action_node_bucket[0]"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Prepare payload\n",
 "payload = {\n",
 "    \"method\": \"do_calculus\",\n",
 "    \"intervention_query\": intervention_query,\n",
 "    \"interventions\": [action_node_bucket[0]],\n",
 "    \"target\": yield_target\n",
 "}"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Dump the payload into a local JSON file\n",
 "with open(\"tmp/request_payload_intervention.json\", 'w') as fp:\n",
 "    json.dump(payload, fp)"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Upload the request payload"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def upload_file(input_location):\n",
 "    \"\"\"Upload a local request payload with the SageMaker session.\n",
 "\n",
 "    The file goes to the session's default bucket under the key prefix\n",
 "    ``<AWS_S3_BUCKET>/inference/input`` with a JSON content type.\n",
 "\n",
 "    Args:\n",
 "        input_location: path of the local file to upload.\n",
 "\n",
 "    Returns:\n",
 "        The value returned by ``sm_session.upload_data``; it is passed to the\n",
 "        endpoint as ``InputLocation``.\n",
 "    \"\"\"\n",
 "    prefix = f\"{AWS_S3_BUCKET}/inference/input\"\n",
 "    return sm_session.upload_data(\n",
 "        input_location,\n",
 "        bucket=sm_session.default_bucket(),\n",
 "        key_prefix=prefix,\n",
 "        extra_args={\"ContentType\": \"application/json\"},\n",
 "    )"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Upload request to S3\n",
 "input_s3_location = upload_file(\"tmp/request_payload_intervention.json\")"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Invoke endpoint"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Invoke endpoint\n",
 "response_endpoint = runtime.invoke_endpoint_async(\n",
 "    EndpointName=ENDPOINT_NAME,\n",
 "    InputLocation=input_s3_location,\n",
 ")\n",
 "\n",
 "output_location = response_endpoint['OutputLocation']"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "#### Get inference outputs"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def get_output(output_location, poll_seconds=20, timeout_seconds=None):\n",
 "    \"\"\"Poll S3 until the asynchronous do-calculus result exists, then return it.\n",
 "\n",
 "    Args:\n",
 "        output_location: S3 URI of the result (the ``OutputLocation`` value\n",
 "            returned by ``invoke_endpoint_async``).\n",
 "        poll_seconds: pause between two attempts while the object is missing.\n",
 "        timeout_seconds: give up after this many seconds; ``None`` (default)\n",
 "            keeps waiting indefinitely.\n",
 "\n",
 "    Returns:\n",
 "        The value returned by ``sm_session.read_s3_file`` for the result object.\n",
 "\n",
 "    Raises:\n",
 "        TimeoutError: the object did not appear within ``timeout_seconds``.\n",
 "        ClientError: any S3 error other than ``NoSuchKey``.\n",
 "    \"\"\"\n",
 "    output_url = urllib.parse.urlparse(output_location)\n",
 "    bucket = output_url.netloc\n",
 "    key = output_url.path[1:]  # drop the leading '/'\n",
 "    started = time.monotonic()\n",
 "    while True:\n",
 "        try:\n",
 "            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)\n",
 "        except ClientError as e:\n",
 "            # Only a missing object means the result is still pending\n",
 "            if e.response[\"Error\"][\"Code\"] != \"NoSuchKey\":\n",
 "                raise\n",
 "        if timeout_seconds is not None and time.monotonic() - started >= timeout_seconds:\n",
 "            raise TimeoutError(f\"no inference output at {output_location} after {timeout_seconds} seconds\")\n",
 "        print(\"waiting for the inference do-calculus\")\n",
 "        time.sleep(poll_seconds)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Get inference outputs\n",
 "output = json.loads(get_output(output_location))\n",
 "print(f\"\\n Output: {output}\")"
] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
 "### Plot counterfactuals"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# collect marginals (before and after) into a pandas frame\n",
 "df_marginals_before = pd.DataFrame.from_dict(output['marginals-before'], orient='index', columns=['before'])\n",
 "df_marginals_after = pd.DataFrame.from_dict(output['marginals-after'], orient='index', columns=['after'])\n",
 "\n",
 "counterfactuals = pd.concat([df_marginals_before, df_marginals_after], axis=1)\n",
 "\n",
 "# Note: if target is changed add the corresponding numeric_split_points_target\n",
 "# The first row gets 0, the remaining rows get the split points. The column is assigned\n",
 "# in one statement: chained assignment (`df[col].loc[...] = ...`) is not guaranteed\n",
 "# to write back to the frame.\n",
 "counterfactuals['yield'] = [0] + list(map_thresholds[yield_target])\n",
 "counterfactuals = counterfactuals.set_index('yield')"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "def plot_counterfactuals(cf, sample, yield_target, action_node_before, action_node_after):\n",
 "    \"\"\"Plot the target distribution before and after the intervention.\n",
 "\n",
 "    Args:\n",
 "        cf: frame indexed by yield value with the columns 'before' and 'after'.\n",
 "        sample: the sampled record; its 'FIPS' and 'id_10' values go into the title.\n",
 "        yield_target: name of the target node, shown in the figure title.\n",
 "        action_node_before: its first element labels the 'before' curve.\n",
 "        action_node_after: its first element labels the 'after' curve.\n",
 "    \"\"\"\n",
 "    plt.figure(figsize=(12, 5), dpi=120)\n",
 "\n",
 "    for column, action_node_state in (('before', action_node_before), ('after', action_node_after)):\n",
 "        plt.plot(cf.index, cf[column], 'o--', label=f\"Nitrogen (kg/ha): {action_node_state[0]}\")\n",
 "        plt.fill_between(cf.index, cf[column], alpha=0.1)\n",
 "\n",
 "    # One dashed line per row of the index\n",
 "    for x_value in cf.index.values:\n",
 "        plt.axvline(x=x_value, color='gray', linestyle=\"--\")\n",
 "\n",
 "    plt.legend()\n",
 "    plt.title(f\"-- FIPS:{sample['FIPS'].values[0]} - CELL ID: {sample['id_10'].values[0]} -- \")\n",
 "    plt.suptitle(f\"Distribution of {yield_target} Yield given Nitrogen added as fertilizer\")\n",
 "    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield discretisation')\n",
 "    plt.ylabel('Probability')"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" }, "vscode": { "languageId": "python" } }, "outputs": [], "source": [
 "# Map the action-node buckets (before / after) back through the discretiser for the legend\n",
 "action_node_before_real = discretiser_inverse_transform(\n",
 "    map_thresholds,\n",
 "    request=False,\n",
 "    request_nodes=[],\n",
 "    response_nodes=[action_node_before],\n",
 ")\n",
 "\n",
 "action_node_after_real = discretiser_inverse_transform(\n",
 "    map_thresholds,\n",
 "    request=False,\n",
 "    request_nodes=[],\n",
 "    response_nodes=[action_node_after],\n",
 ")\n",
 "\n",
 "plot_counterfactuals(counterfactuals, sample, yield_target, action_node_before_real, action_node_after_real)"
] }
], "metadata": { "instance_type": "ml.m5.large", "kernelspec": { "display_name": "Python 3 (PyTorch 1.10 Python 3.8 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.10-cpu-py38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }