# Causal Inference with Bayesian Networks

**Jupyter Kernel**:


* If you are in SageMaker Studio, make sure that you use the **PyTorch 1.10 Python 3.8 CPU Optimized** environment.
* Make sure that you are using `ml.g4dn.xlarge` or `ml.m5.large` as an instance type.

**Run All**: 

* If you are in SageMaker Studio, you can choose **Run All Cells** from the **Run** menu to run the entire notebook at once.

In [None]:
# Install dependencies that will be used in this notebook.
!pip3 install -r ./utils/requirements.in -q

In [None]:
!conda install -c conda-forge pygraphviz -y

This solution relies on a config file that describes the provisioned AWS resources. Run the cells below to generate that file.

In [None]:
import boto3
import os
import json

In [None]:
# Find the Service Catalog provisioned product that launched this solution.
# Its name is the directory that follows 'S3Downloads' in the current working
# directory (raises ValueError when the notebook is run from elsewhere).
client = boto3.client('servicecatalog')
cwd = os.getcwd().split('/')
i = cwd.index('S3Downloads')
pp_name = cwd[i + 1]
pp = client.describe_provisioned_product(Name=pp_name)
record_id = pp['ProvisionedProductDetail']['LastSuccessfulProvisioningRecordId']
record = client.describe_record(Id=record_id)

# Keep only the outputs that carry BOTH a key and a value.
# The previous filter, `'OutputKey' and 'OutputValue' in x`, parsed as
# `'OutputKey' and ('OutputValue' in x)` and therefore never checked for
# 'OutputKey' at all.
complete_outputs = [
    x for x in record['RecordOutputs']
    if 'OutputKey' in x and 'OutputValue' in x
]
keys = [x['OutputKey'] for x in complete_outputs]
values = [x['OutputValue'] for x in complete_outputs]
stack_output = dict(zip(keys, values))

# Persist the stack outputs so that the next cell can load them.
with open(f'/root/S3Downloads/{pp_name}/stack_outputs.json', 'w') as f:
    json.dump(stack_output, f)

In [None]:
# Load the stack outputs written by the previous cell. The context manager
# closes the file handle (the bare `json.load(open(...))` left it open).
with open("stack_outputs.json") as f:
    sagemaker_config = json.load(f)

# Resource names published by the provisioned stack.
SOLUTION_BUCKET = sagemaker_config["SolutionS3Bucket"]
AWS_REGION = sagemaker_config["AWSRegion"]
SOLUTION_NAME = sagemaker_config["SolutionName"]
SOLUTION_PREFIX = sagemaker_config["SolutionPrefix"]
AWS_S3_BUCKET = sagemaker_config["S3Bucket"]

# S3 keys / prefixes of the input data.
KEY_YIELD_CURVE = "data/raw/yield_curve_field_dt.csv"
SPATIAL_FILES_KEY = "data/spatial-files"
FIPS_STATS_KEY = "data/fips-stats/fips_county_stats.csv"
FIPS_POLYGONS_KEY = "data/fips-stats/geojson-counties-fips.json"
SENTINEL_2_SHAPEFILE_KEY = "data/sentinel-2-shapefiles"
CROPS_MASK_KEY = "data/crop_mask/raw"
REQUEST_MANIFESTS_KEY = "request_manifests/"

# Local paths of the model artifacts produced later in this notebook.
DAG_PATH = 'models/bn_structure.gml'
MODEL_PATH = 'models/bayesian_model.bif'
STATES_PATH = 'models/node_states.json'
NUMERICAL_SPLIT_POINTS_PATH = "models/numerical_split_points.json"

# Create the artifact folder; exist_ok avoids the check-then-create race.
os.makedirs('models', exist_ok=True)

### Set up the environment

In [None]:
import pandas as pd
import numpy as np
import json
import datetime
import matplotlib.pyplot as plt
import boto3
import io
import os
import s3fs
import itertools as it
import networkx as nx
from time import time
import geopandas as gpd
import copy
import bisect
from typing import Dict
import warnings
import base64
from PIL import Image
import datetime
from time import gmtime, strftime
import urllib

import sagemaker
import boto3
from botocore.exceptions import ClientError

# from utils.plot_functions import visualize_structure
from utils.causalnex_helpers import (
 quantile_discretiser,
 generate_dag_constraints,
 discretiser_inverse_transform,
 format_inference_output
)

from utils.plot_functions import (
 plot_pretty_structure
)

from utils.helper_functions import download_s3_folder

warnings.simplefilter('ignore')

%matplotlib inline

In [None]:
# Define a few variables to use throughout the notebook.
# CROP_REGION and YEAR select which enhanced dataset / stage-mapping files
# are read from S3 further below.

EPSG = 'epsg:4326' # using the WGS84 latitude-longitude projection: "EPSG:4326"
CROP_REGION = '2-Central' # Illinois region
YEAR = 2018 # crop year

In [None]:
# AWS handles used throughout the notebook.

# boto3 session and the Region it resolves to
boto_session = boto3.session.Session()
region = boto_session.region_name

# SageMaker session and the execution role attached to this notebook
sm_session = sagemaker.session.Session()
sm_role = sagemaker.get_execution_role()

# Service handles: SageMaker runtime client (endpoint invocation) and S3 resource
runtime = boto3.Session().client('sagemaker-runtime')
s3 = boto3.resource('s3')

Download spatial files locally.

In [None]:
# Copy both spatial inputs from the solution bucket into local folders.
for s3_prefix, local_dir in (
    (SPATIAL_FILES_KEY, "tmp/spatial-files"),
    (SENTINEL_2_SHAPEFILE_KEY, "tmp/Sentinel-2-Shapefile-Index"),
):
    download_s3_folder(AWS_S3_BUCKET, s3_prefix, local_dir)

### Read dataset and crop staging mapping file

> **Note**: The files read below are produced by the `01 Feature Engineering.ipynb` notebook, so run that notebook first.

In [None]:
# File names use an underscore where the region label has a dash.
REGION = CROP_REGION.replace("-","_")

# Both CSV inputs share the same S3 folder and file-name suffix.
enhanced_folder = f"s3://{AWS_S3_BUCKET}/data/enhanced/"
file_suffix = f"filtered_{YEAR}_{REGION}.csv"

# Enhanced dataset (one row per field) and the crop staging mapping file.
df_full = pd.read_csv(enhanced_folder + "enhanced_dataset_" + file_suffix)
df_mapping = pd.read_csv(enhanced_folder + "stage_mapping_" + file_suffix)

# Spatial cells, re-projected to the notebook-wide CRS.
gpd_cells = gpd.read_file("tmp/spatial-files/cells_sf.shp").to_crs(EPSG)

# For the DAG setup, drop the identifiers and the out-of-scope variables,
# then keep only the mapping rows that describe a remaining column.
out_of_scope = ['FIPS','id_field','id_10','LAI_max','n_uptake','P']
df = df_full.drop(columns=out_of_scope)
in_scope_rows = df_mapping.variable.isin(df.columns)
df_mapping = df_mapping[in_scope_rows]

Select target(s) from the following:

* Corn Yield: `"Y_corn"`
* Soybeans Yield: `"Y_soy"`
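* Total N taken up by the corn crop during the season: `"n_uptake"` (note: the previous cell drops `n_uptake` from `df`, so remove it from that `drop` list before selecting it as a target)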
* Total 2-years N leaching during corn and soybean, from April 1st year (x) to March 31st year (x+2): `"L"`

In [None]:
TARGETS = ["Y_corn"]

`Setting`:
 * The crop phenology graph (DAG) is a collection of nodes and edges, where the `nodes` are indicators of crop growth, soil characteristics, atmospheric conditions, and the `edges` between them represent temporal-causal relationships. `Parent nodes` are the field-related parameters (incl. the day of sowing and area planted), whereas the `child nodes` are the yield, nitrogen uptake and nitrogen leaching targets.
 * A `crop phenology DAG (Directed Acyclic Graph)` structure is learned from data (with domain knowledge assisted constraints) and human inputs:
 * The graphical model incorporates crop phenology dynamics extracted from ground-level indicators and spectral vegetation indices 
 * Continuous features are discretised based on the split thresholds of a decision tree regressor (crop yield is used as a target)
 * Once the graph has been determined, the conditional probability distributions of the variables are learned from the data, using Bayesian parameter estimation.

 * Please find the [vocabulary](https://www.sciencedirect.com/science/article/pii/S2352340921010283#tbl0001) for the ground-level variables, and the [guide](https://crops.extension.iastate.edu/encyclopedia/corn-growth-stages) for identifying the corn growth stages.
 * Nodes starting with `mean_{spectral vegetation indices}_corn_{isoweek}` are corn growth indicators extracted from the satellite multi-spectral imagery, representing the 10 x 10 km cell mean value of the following spectral vegetation indices (for each satellite visit):
 * `EVI2` : Two-Band Enhanced Vegetation Index
 * `GDVI` : Generalized Difference Vegetation Index
 * `NDMI` : Normalized Difference Moisture Index
 * `NDVI` : Normalized Difference Vegetation Index
 * `NDWI` : Normalized Difference Water Index
 
 * `Corn response to nitrogen` is studied by querying the model and making interventions.
 * Firstly, undertake inference in order to gain insights about different response curves.
 * Secondly, use the inference insights together with the observed evidence to decide how much nitrogen to add as fertilizer, and observe the effect of that action on the crop yield, the nitrogen leaching and the total nitrogen uptake.

## Prepare constraints for the DAG learning

Use the mapping file with the crop phenology staging and return constraints for the NOTEARS algorithm.


1. list of nodes banned from being a child of any other nodes
2. list of nodes banned from being a parent of any other nodes
3. list of edges `(from, to)` that must not be included in the graph.


In [None]:
# Leave the satellite indicators out for now (they are added to the DAG
# later, with domain-knowledge assistance).
sattelite_images = [feat for feat in df.columns if feat.startswith('mean_')]
is_satellite = df_mapping.variable.isin(sattelite_images)
mapping = df_mapping[~is_satellite]

# Learn the DAG structure for levels 0..n_stage only.
n_stage = 4
allowed_levels = list(range(n_stage + 1))
mapping = mapping[mapping.level.isin(allowed_levels)]

# Drop the atmospheric nodes (tmean / rain / rad) that sit on level 0.
is_atmospheric = mapping.variable.str.startswith(("tmean","rain","rad"))
is_level_zero = mapping.level == 0
mapping = mapping[~(is_atmospheric & is_level_zero)]

# Translate the staged mapping into constraints for structure learning.
(
    tabu_edges,
    tabu_child,
    tabu_parents,
    nodes_list,
    nodes_matrix,
) = generate_dag_constraints(mapping)

### causalnex imports

In [None]:
from causalnex.structure import StructureModel
from causalnex.structure.notears import from_pandas
from causalnex.network import BayesianNetwork
from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE
from causalnex.discretiser.discretiser_strategy import (
 DecisionTreeSupervisedDiscretiserMethod,
 MDLPSupervisedDiscretiserMethod
)
from causalnex.discretiser import Discretiser
from causalnex.network import BayesianNetwork
from causalnex.evaluation import classification_report
from causalnex.inference import InferenceEngine

from sklearn.model_selection import train_test_split
from causalnex.evaluation import roc_auc


import warnings
from IPython.display import Image

warnings.filterwarnings("ignore") # silence warnings

## DAG learning from structure

Reference: [DAGs with NO TEARS: Continuous Optimization for Structure Learning](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf)

1. Imposing edges that are not allowed in the causal model
2. Imposing parent nodes that are not allowed in the causal model
3. Imposing child nodes that are not allowed in the causal model

In [None]:
# Use perf_counter rather than `from time import time`: that import rebinds
# the global name `time` to a function, which breaks the `time.sleep(...)`
# calls made later in the notebook (after `import time`) whenever this cell
# is re-run. perf_counter is also the appropriate clock for elapsed time.
from time import perf_counter

t0 = perf_counter()

# Learn the structure with NOTEARS, constrained by the banned edges, banned
# parent nodes and banned child nodes prepared above.
g_learned = from_pandas(
    df[nodes_list],
    tabu_edges=tabu_edges,
    tabu_parent_nodes=tabu_parents,
    tabu_child_nodes=tabu_child,
    max_iter=100,
)

print(f'Running NOTEARS algorithm takes {perf_counter() - t0} seconds')

In [None]:
g = g_learned.copy()

In [None]:
# Keep only the largest subgraph of the learned structure, then report its
# size before wrapping it in a Bayesian Network.
g = g.get_largest_subgraph()

print(f"Learned DAG Edges: {len(g.edges)}")
print(f"Learned DAG Nodes: {len(g.nodes)}")
print(f"Learned DAG Degree View \n: {g.degree} \n")

# Build the Bayesian Network on the pruned structure.
bn = BayesianNetwork(g)

In [None]:
viz = plot_pretty_structure(bn.structure, edges_to_highlight=[])
Image(viz.draw(format='png'))

## DAG knowledge assisted

Next, we will enhance the learned DAG structure with domain knowledge extracted from the [Simulated dataset of corn response to nitrogen over thousands of fields and multiple years in Illinois](https://www.sciencedirect.com/science/article/pii/S2352340921010283) paper.

In [None]:
# All mapped variables, and a level-sorted list of (level, variable) pairs.
nodes_list_all = list(df_mapping.variable.unique())

nodes_matrix_all = []
for node in nodes_list_all:
    # first level listed for this variable in the mapping file
    node_level = df_mapping[df_mapping.variable == node]['level'].values[0]
    nodes_matrix_all.append((node_level, node))
nodes_matrix_all.sort()

### N fertilizer edges

Add edges between the N fertilizer and the target nodes.

In [None]:
# Add direct links to the targets
g.add_edges_from([("N_fert", node, {"weight": 1.0}) for node in TARGETS])

### Water stress indicators

1. Add edges between Mean water stress indicators and the parent nodes
2. Add edges between Mean water stress indicators and the soil indicators

In [None]:
# Columns whose name contains '_fw' are the water stress indicators.
water_stress_features = df.columns[df.columns.str.contains('_fw')]

# (level, variable) pairs for those indicators, sorted by level.
water_stress_matrix = sorted(
    (df_mapping[df_mapping.variable == node]['level'].values[0], node)
    for node in water_stress_features
)

# One expert edge from every water stress indicator to every target.
water_stress_edges = [
    (indicator, target_node, {"weight": 1.0})
    for _, indicator in water_stress_matrix
    for target_node in TARGETS
]

# Nodes with more than two incoming edges, highest in-degree first.
# NOTE(review): this list is not used anywhere in the visible notebook.
g_in_degree = [
    name
    for name, in_deg in sorted(g.in_degree, key=lambda pair: pair[1], reverse=True)
    if name in g.nodes and in_deg > 2
]

g.add_edges_from(water_stress_edges, origin="expert")

water_stress_edges

### Geospatial indicators

1. Add edges between geospatial data and the targets (level 5 variables)
2. Add edges between consecutive geospatial observations (consecutive isoweeks, i.e. consecutive satellite visits)


In [None]:
# Satellite indicators are the NDVI / NDMI / EVI2 columns.
satellite_features = df.columns[df.columns.str.contains('_NDVI|_NDMI|_EVI2')]

# (level, variable) pairs for the satellite indicators, sorted by level.
satellite_matrix = sorted(
    (df_mapping[df_mapping.variable == node]['level'].values[0], node)
    for node in satellite_features
)

# 1. Level-4 satellite indicators point to every level-5 node.
level_4_satellite = [node for level, node in satellite_matrix if level == 4]
level_5_nodes = [node for level, node in nodes_matrix_all if level == 5]
satellite_edges = [
    (source, sink, {"weight": 1.0})
    for source in level_4_satellite
    for sink in level_5_nodes
]

# 2. Link observations on consecutive levels when the first two
#    '_'-separated name tokens are the same.
consecutive_satellite_edges_target = []
for level_i, node_i in satellite_matrix:
    for level_j, node_j in satellite_matrix:
        if level_i != level_j - 1:
            continue
        if node_i.split("_")[0:2] == node_j.split("_")[0:2]:
            consecutive_satellite_edges_target.append(
                (node_i, node_j, {"weight": 1.0})
            )

satellite_edges += consecutive_satellite_edges_target

g.add_edges_from(satellite_edges, origin="expert")

satellite_edges

### Soil indicators

Add soil nitrogen, biomass and water content links with the targets

In [None]:
# Link the v5 soil nitrogen, biomass and soil water nodes that are present
# in the graph to every target.
v5_sources = ('n_deep_v5', 'biomass_v5', 'sw_dep_v5')
v5_edges = [
    (node_i, node_j, {"weight": 1.0})
    for node_i in g.nodes
    if node_i in v5_sources
    for node_j in TARGETS
]

g.add_edges_from(v5_edges, origin="expert")

v5_edges

### Rebase the graph

Overwrite the learned weights for the edges in order to maintain consistency.

In [None]:
# Snapshot the current edges with a uniform weight of 1.0 ...
g_edges = [(edge[0], edge[1], {"weight": 1.0}) for edge in g.edges]

# ... and rebuild the structure from scratch so that every edge carries the
# same weight and the same "expert" origin.
g = StructureModel()
g.add_edges_from(g_edges, origin="expert")

In [None]:
# Get the largest subgraph of the Structure Model
g = g.get_largest_subgraph()

# Base class for Bayesian Network (BN), a probabilistic weighted DAG
# Nodes represent variables, 
# Edges represent the causal relationships between variables.
bn = BayesianNetwork(g)

In [None]:
viz = plot_pretty_structure(bn.structure, edges_to_highlight=[])
Image(viz.draw(format='png'))

## Discretise the data

In [None]:
# Every node of the final graph is a feature; the target is removed below so
# that it is discretised separately (quantiles) from the other nodes (tree).
features = list(g.nodes)

# You can use the Decision Tree Supervised Discretiser with the Corn Yield or the Soy Yield
# Note: Use Corn if the subsequent studies concern corn, and Soybeans otherwise

target = 'Y_corn'

# ====================================================================
# Decision Tree Supervised Discretiser Method
# ====================================================================
 
features.remove(target)

# Discretisation of continuous features based on the split thresholds of a Decision Tree Regressor
discretiser = DecisionTreeSupervisedDiscretiserMethod(
 mode="single", 
 tree_params={"max_depth": 2, "random_state": 2022},
)
discretiser.fit(
 feat_names=features, 
 dataframe=df, 
 target_continuous=True,
 target=target,
)

# Transform the features, then attach the still-continuous target column.
discretised_data = discretiser.transform(df[features])
discretised_data.loc[:,target] = df[target].values

print(f"discretiser map thresholds: {discretiser.map_thresholds}")

# Discretisation of target (quantiles-based); the split points are kept
# because they are saved with the model artifacts and reused for plotting.
discretised_data[target], numeric_split_points_target = quantile_discretiser(discretised_data[target], num_buckets=4)

# Fixed seed so that the 80/20 train/test split is reproducible.
train, test = train_test_split(discretised_data, train_size=0.8, random_state=42)

## Fitting and evaluating the Bayesian Network

In [None]:
bn = BayesianNetwork(g)

In [None]:
bn = bn.fit_node_states(discretised_data)
bn = bn.fit_cpds(
 train, 
 method="BayesianEstimator",
 bayes_prior="K2",
)

In [None]:
classification_report(bn, test, 'Y_corn')

In [None]:
# It is recommended to update the model using the complete dataset for the following type of queries
bn = bn.fit_cpds(
 discretised_data, 
 method="BayesianEstimator", 
 bayes_prior="K2",
)

## Save model artifacts

Upload model artifacts to Amazon S3. This is where the inference endpoint will collect them later.

In [None]:
# Numerical split points: one list per discretised feature, plus the
# quantile split points of the target (added last).
map_thresholds = {
    f"{var}": list(discretiser.map_thresholds[var])
    for var in discretiser.map_thresholds
}
map_thresholds[f"{target}"] = list(numeric_split_points_target)

with open(NUMERICAL_SPLIT_POINTS_PATH, 'w') as fp:
    json.dump(map_thresholds, fp)

# Graph structure
nx.write_gml(g, DAG_PATH)

# Fitted model (structure + CPDs)
bn._model.save(MODEL_PATH, filetype='bif')

# Node states: for every column, map each observed state to itself
node_states_dict = {}
for c in discretised_data.columns:
    observed_states = sorted(discretised_data[c].unique())
    node_states_dict[c] = {int(el): int(el) for el in observed_states}

with open(STATES_PATH, 'w') as fp:
    json.dump(node_states_dict, fp)

In [None]:
import tarfile

# Bundle the four model artifacts into model.tar.gz. The context manager
# closes (and thereby finalises) the archive even when adding a file fails.
with tarfile.open("model.tar.gz", "w:gz") as tar:
    for file in [DAG_PATH, MODEL_PATH, STATES_PATH, NUMERICAL_SPLIT_POINTS_PATH]:
        tar.add(file)

In [None]:
!aws s3 cp model.tar.gz s3://{AWS_S3_BUCKET}/models/

## SageMaker asynchronous inference

In [None]:
instance_type = "ml.m5.2xlarge"

model_artifact = f"s3://{AWS_S3_BUCKET}/models/model.tar.gz"

In [None]:
# We use a PyTorch inference DLC image that ships with sagemaker-pytorch-inference-toolkit 
image_uri = sagemaker.image_uris.retrieve(
 framework="pytorch",
 region=region,
 py_version="py38",
 image_scope="inference",
 version="1.10",
 instance_type=instance_type,
)

In [None]:
#!pygmentize src-inference/inference.py

### Create a SageMaker model

In [None]:
# SAGEMAKER_TS_BATCH_SIZE (int): This is the maximum batch size in ms that a model is expected to handle
# SAGEMAKER_TS_STARTUP_TIMEOUT (int): Time delay after which inference will timeout if model initialization fails
# SAGEMAKER_TS_RESPONSE_TIMEOUT (int): Time delay after which inference will timeout in absence of a response
# NOTE(review): the units above (and the meaning of the TS_* size/timeout
# settings below) are not established by this notebook -- confirm against the
# inference toolkit / TorchServe documentation.

# Environment variables passed to the model container (all values are strings).
env_variables_dict = {
 "SAGEMAKER_TS_BATCH_SIZE": "10000000",
 "SAGEMAKER_TS_STARTUP_TIMEOUT": "1200",
 "SAGEMAKER_TS_RESPONSE_TIMEOUT": "600",
 'TS_MAX_REQUEST_SIZE': '655350000',
 'TS_MAX_RESPONSE_SIZE': '655350000',
 'TS_DEFAULT_RESPONSE_TIMEOUT': '2000',
 
}

In [None]:
from sagemaker.model import Model
from sagemaker.predictor import Predictor

model_name = f"{SOLUTION_PREFIX}-bn-model"

model_predictor = Model(
 name=model_name,
 image_uri=image_uri,
 model_data=model_artifact,
 role=sm_role,
 source_dir="src-inference",
 entry_point="inference.py",
 predictor_cls=Predictor,
 env=env_variables_dict,
)
model_name

### Create AsyncInferenceConfig

In [None]:
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig

async_config = AsyncInferenceConfig(
 output_path=f"s3://{AWS_S3_BUCKET}/models/output",
 max_concurrent_invocations_per_instance=4,
)

### Create endpoint

In [None]:
import time

# Name of the asynchronous endpoint; reused by every invocation below.
ENDPOINT_NAME = f"{SOLUTION_PREFIX}-bn-endpoint"

# Deploy the model behind an asynchronous endpoint that exchanges JSON.
async_predictor = model_predictor.deploy(
 async_inference_config=async_config,
 instance_type=instance_type,
 initial_instance_count=1,
 endpoint_name=ENDPOINT_NAME,
 serializer=sagemaker.serializers.JSONSerializer(),
 deserializer=sagemaker.deserializers.JSONDeserializer(),
)

# Waiting for the inference engine to be initialized
# (`time` must be the module here, i.e. the `import time` at the top of this
# cell has to run after any earlier `from time import time`).
time.sleep(90)

## Observational and counterfactual inference

### Querying marginal distributions of the target node (yield) given some observations

In [None]:
# Sample cell_id / id_field(s)
query_node = 'N_fert'
yield_target = 'Y_corn'
samples_number = 4
requests = []
samples = []

sample_features = list(g.nodes)

df_query = df_full[sample_features + ['id_10','id_field','FIPS']]

# Draw all rows in ONE call (i.e. without replacement). The previous code drew
# `sample(1)` repeatedly and then de-duplicated `df_samples`; when the same row
# was drawn twice, `df_samples` ended up shorter than `requests`, so the i-th
# marginal no longer matched the i-th sampled row in `plot_marginals`.
df_samples = df_query.sample(n=min(samples_number, len(df_query)))

for position in range(len(df_samples)):

    # single-row frame, same shape as the former `df_query.sample(1)`
    sample = df_samples.iloc[[position]]
    samples.append(sample)

    # Add all observations
    request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]

    # Discretise the request
    request = discretiser_inverse_transform(
        map_thresholds,
        request=True,
        request_nodes=request_nodes,
        response_nodes=[],
    )

    request = dict(request)

    # Remove the target node from the request: it is what the model predicts
    request.pop(yield_target)

    requests.append(request)

In [None]:
# Prepare the payload: one "query" call carrying every sampled observation
# set, asking for the marginals of the yield target.
payload = {
 "method": "query",
 "observations": requests,
 "target": yield_target
}

In [None]:
# Dump the payload into a local JSON file.
# Create the scratch folder first so the cell also works when `tmp/` has not
# been created by an earlier step.
os.makedirs("tmp", exist_ok=True)
with open("tmp/request_payload_query.json", 'w') as fp:
    json.dump(payload, fp)

#### Upload the request payload

In [None]:
def upload_file(input_location):
    """Upload a local request payload to S3 as JSON.

    Args:
        input_location: Path of the local file to upload.

    Returns:
        The value returned by ``sm_session.upload_data``; the notebook passes
        it on as ``InputLocation`` when invoking the endpoint.
    """
    # NOTE(review): the object goes to the SageMaker *default* bucket, and the
    # solution bucket name is only used as the key prefix -- confirm intended.
    key_prefix = f"{AWS_S3_BUCKET}/inference/input"
    destination_bucket = sm_session.default_bucket()
    uploaded_location = sm_session.upload_data(
        input_location,
        bucket=destination_bucket,
        key_prefix=key_prefix,
        extra_args={"ContentType": "application/json"},
    )
    return uploaded_location

In [None]:
# upload request to S3
input_s3_location = upload_file("tmp/request_payload_query.json")

#### Invoke endpoint

In [None]:
# Invoke endpoint
response_endpoint = runtime.invoke_endpoint_async(
 EndpointName=ENDPOINT_NAME, 
 InputLocation=input_s3_location,
)

output_location =response_endpoint['OutputLocation']

#### Get inference outputs

In [None]:
def get_output(output_location):
    """Wait for an asynchronous inference result in S3 and return it.

    Polls every 20 seconds for as long as S3 reports ``NoSuchKey`` for the
    result object.

    Args:
        output_location: S3 URI of the result, i.e. the ``OutputLocation``
            returned by ``invoke_endpoint_async``.

    Returns:
        Whatever ``sm_session.read_s3_file`` returns for that object.

    Raises:
        ClientError: for any S3 client error other than ``NoSuchKey``.
    """
    # Imported locally so the function does not depend on what the global
    # name `time` is bound to (module vs. `from time import time`).
    from time import sleep

    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading '/'
    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # Anything other than "result not written yet" is a real error.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
            print("waiting for the inference query")
            sleep(20)

In [None]:
# get inference outputs
output = json.loads(get_output(output_location))
print(f"\n Output: {output}")

In [None]:
# Format the output by converting the marginal probabilities into buckets
resp, _, _ = format_inference_output(output)

# Convert buckets into real number ranges
resp_transformed = discretiser_inverse_transform(
    map_thresholds,
    request=False,
    request_nodes=[],
    response_nodes=resp,
)

# Collect the marginals from the response: one column per observation set
marginals = []

for idx, out in enumerate(output):
    marginals_df = pd.DataFrame.from_dict(
        out['marginals'], orient='index', columns=[f'marginals_{idx}'])
    marginals.append(marginals_df)

marginals = pd.concat(marginals, axis=1)

# x-axis value of every bucket: the observed minimum for the first bucket,
# then the split points of the target.
# Built in a single assignment; the former chained assignment
# `marginals['yield'].loc[1:] = ...` may write to a temporary copy and does
# nothing under pandas copy-on-write.
# Note: if target is changed add the corresponding numeric_split_points_target (from the discretiser)
marginals['yield'] = [df_full[yield_target].min(), *map_thresholds[yield_target]]
marginals = marginals.set_index('yield')

#### Plot marginals for the yield node

In [None]:
def plot_marginals(marginals, df_samples, resp_transformed, yield_target):
    """Plot one marginal distribution curve per sampled row.

    Args:
        marginals: DataFrame indexed by yield value, one column per sample.
        df_samples: Sampled rows; row ``i`` labels column ``i`` of
            ``marginals`` and supplies the actual yield for its vertical line.
        resp_transformed: Not used in the body; kept in the signature.
        yield_target: Name of the yield column in ``df_samples``.
    """
    plt.figure(figsize=(15, 5), dpi=120)
    x_values = marginals.index

    for idx, col in enumerate(marginals):
        y_values = marginals[col]
        plt.plot(
            x_values,
            y_values,
            'o--',
            label=f"FIPS:{df_samples['FIPS'].iloc[idx]} - CELL ID: {df_samples['id_10'].iloc[idx]}",
        )
        # vertical line at the actual yield, in the colour of the curve just drawn
        curve_colour = plt.gca().lines[-1].get_color()
        plt.axvline(df_samples[yield_target].iloc[idx], color=curve_colour)
        plt.fill_between(x_values, y_values, alpha=0.1)

    plt.legend()
    plt.title(f"Marginal distributions of {yield_target} target node given the observations")
    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield actual values')
    plt.ylabel('Probability')

In [None]:
plot_marginals(marginals, df_samples, resp_transformed, yield_target)

#### Visualize the geolocation for the selected cell IDs

In [None]:
# Plot the sampled cell geo coordinates
ax = gpd_cells[gpd_cells.region == CROP_REGION].plot(cmap='Pastel2', figsize=(15,7))
gpd_cells[gpd_cells.id_10.isin(df_samples['id_10'].unique())].plot(ax=ax, facecolor='none', edgecolor='red')

### Making interventions (Do-calculus)

In [None]:
# Sample one cell_id / id_field
features = list(g.nodes)

action_node = 'N_fert'
yield_target = 'Y_corn'

# Query nodes: the action node, the weather nodes of the graph
# (tmean / rad / rain) and the satellite indicators.
weather_tags = ('tmean', 'rad', 'rain')
weather_features = [
    feat for feat in features if any(tag in feat for tag in weather_tags)
]
sample_features = [action_node, *weather_features, *satellite_features]

# Pick a sample
samples = df_full[sample_features + ['id_10','FIPS']]
sample = samples.sample(1)

# Add all observations
request_nodes = []
for feat in sample_features:
    request_nodes.append((feat, sample[feat].values[0]))

# Discretise the request
request = discretiser_inverse_transform(
    map_thresholds,
    request=True,
    request_nodes=request_nodes,
    response_nodes=[],
)

In [None]:
# Map thresholds action node
map_thresholds[action_node]

In [None]:
# Current value
print(f"Current value: {sample[action_node].values[0]} kg/ha")

* NOTE: select a value that differs significantly from the current value (ideally one that falls into a different bucket), so that the effect of the intervention can be observed.

In [None]:
# Discretise
value = 30 # ADD VALUE HERE (eg. X kg/ha Nitrogen)
action_node_value = (action_node, value)
action_node_bucket = discretiser_inverse_transform(map_thresholds,
 request=True,
 request_nodes=[action_node_value],
 response_nodes=[])

In [None]:
# Remove the node we intervene on
request = dict(request)
action_node_before = (action_node,request.pop(action_node))
action_node_after = action_node_bucket[0]

In [None]:
# Prepare payload
payload = {
 "method": "do_calculus",
 "intervention_query": request,
 "interventions": [action_node_bucket[0]],
 "target": yield_target
}

In [None]:
# Dump the payload into a local JSON file
# Create the scratch folder first so the cell also works when `tmp/` has not
# been created by an earlier step.
os.makedirs("tmp", exist_ok=True)
with open("tmp/request_payload_intervention.json", 'w') as fp:
    json.dump(payload, fp)

#### Uploading the Request Payload

In [None]:
def upload_file(input_location):
    """Upload a local request payload to S3 as JSON.

    Same definition as in the query section; repeated so that this section
    can be run on its own.

    Args:
        input_location: Path of the local file to upload.

    Returns:
        The value returned by ``sm_session.upload_data``; the notebook passes
        it on as ``InputLocation`` when invoking the endpoint.
    """
    key_prefix = f"{AWS_S3_BUCKET}/inference/input"
    destination_bucket = sm_session.default_bucket()
    uploaded_location = sm_session.upload_data(
        input_location,
        bucket=destination_bucket,
        key_prefix=key_prefix,
        extra_args={"ContentType": "application/json"},
    )
    return uploaded_location

In [None]:
# Upload request to S3
input_s3_location = upload_file("tmp/request_payload_intervention.json")

#### Invoke endpoint

In [None]:
# Invoke endpoint
response_endpoint = runtime.invoke_endpoint_async(
 EndpointName=ENDPOINT_NAME, 
 InputLocation=input_s3_location,
)

output_location =response_endpoint['OutputLocation']

#### Get inference outputs

In [None]:
def get_output(output_location):
    """Wait for the asynchronous do-calculus result in S3 and return it.

    Polls every 20 seconds for as long as S3 reports ``NoSuchKey`` for the
    result object.

    Args:
        output_location: S3 URI of the result, i.e. the ``OutputLocation``
            returned by ``invoke_endpoint_async``.

    Returns:
        Whatever ``sm_session.read_s3_file`` returns for that object.

    Raises:
        ClientError: for any S3 client error other than ``NoSuchKey``.
    """
    # Imported locally so the function does not depend on what the global
    # name `time` is bound to (module vs. `from time import time`).
    from time import sleep

    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading '/'
    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # Anything other than "result not written yet" is a real error.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
            print("waiting for the inference do-calculus")
            sleep(20)

In [None]:
# Get inference outputs
output = json.loads(get_output(output_location))
print(f"\n Output: {output}")

### Plot counterfactuals

In [None]:
# Collect marginals (before and after) into a pandas frame
# Marginals of the target before and after the intervention, side by side.
df_marginals_before = pd.DataFrame.from_dict(output['marginals-before'], orient='index', columns=['before'])
df_marginals_after = pd.DataFrame.from_dict(output['marginals-after'], orient='index', columns=['after'])

counterfactuals = pd.concat([df_marginals_before,df_marginals_after],axis=1)

# x-axis value of every bucket: 0 for the first bucket, then the split points
# of the target.
# Built in a single assignment; the former chained assignment
# `counterfactuals['yield'].loc[1:] = ...` may write to a temporary copy (a
# no-op under pandas copy-on-write) and wrote floats into an int column.
# Note: if target is changed add the corresponding numeric_split_points_target
counterfactuals['yield'] = [0, *map_thresholds[yield_target]]
counterfactuals = counterfactuals.set_index('yield')

In [None]:
def plot_counterfactuals(cf, sample, yield_target, action_node_before, action_node_after):
    """Plot the target distribution before and after the intervention.

    Args:
        cf: DataFrame indexed by yield value with 'before' and 'after' columns.
        sample: Single-row frame providing 'FIPS' and 'id_10' for the title.
        yield_target: Name of the target, used in the figure title.
        action_node_before: Indexable; element 0 labels the 'before' curve.
        action_node_after: Indexable; element 0 labels the 'after' curve.
    """
    plt.figure(figsize=(12, 5), dpi=120)
    x_values = cf.index

    curves = (('before', action_node_before), ('after', action_node_after))
    for column, node in curves:
        plt.plot(x_values, cf[column], 'o--', label=f"Nitrogen (kg/ha): {node[0]}")
        plt.fill_between(x_values, cf[column], alpha=0.1)

    # one dashed line per bucket boundary of the target
    for boundary in cf.index.values:
        plt.axvline(x=boundary, color='gray', linestyle="--")

    plt.legend()
    plt.title(f"-- FIPS:{sample['FIPS'].values[0]} - CELL ID: {sample['id_10'].values[0]} -- ")
    plt.suptitle(f"Distribution of {yield_target} Yield given Nitrogen added as fertilizer")
    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield discretisation')
    plt.ylabel('Probability')

In [None]:
# Convert the action-node buckets (before / after the intervention) back with
# the inverse discretiser; the results are passed to the plot, which reads
# element 0 of each for the legend labels.
action_node_before_real = discretiser_inverse_transform(map_thresholds,
 request=False,
 request_nodes=[],
 response_nodes=[action_node_before])

action_node_after_real = discretiser_inverse_transform(map_thresholds,
 request=False,
 request_nodes=[],
 response_nodes=[action_node_after])

# Draw the before/after distributions of the yield target.
plot_counterfactuals(counterfactuals, sample, yield_target, action_node_before_real, action_node_after_real)

### Clean Up

In [None]:
# Delete the SageMaker endpoint
async_predictor.delete_endpoint()