# Demo -- Causal Inference Engine

**Jupyter Kernel**:


* If you are in SageMaker Studio, make sure that you use the **PyTorch 1.10 Python 3.8 CPU Optimized** environment.
* Make sure that you are using one of the following instance types: `ml.m5.large`, `ml.c5.large`, or `ml.g4dn.xlarge`.

**Run All**: 

* If you are in SageMaker Studio, you can choose **Run All Cells** from the **Run** tab dropdown menu to run the entire notebook at once.

In [None]:
# Install dependencies for this notebook.
# `-r` installs every requirement listed in ./utils/requirements.in; `-q` keeps pip's output quiet.
!pip3 install -r ./utils/requirements.in -q

This solution relies on a config file to use the provisioned AWS resources. Run the following cells to generate that file.

In [None]:
import boto3
import os
import json

In [None]:
# Build stack_outputs.json from the outputs of the Service Catalog provisioned
# product whose name is encoded in the current working directory.
client = boto3.client('servicecatalog')

# The provisioned product name is the path component that immediately follows
# 'S3Downloads' in the working directory (ValueError if it is not present).
cwd = os.getcwd().split('/')
i = cwd.index('S3Downloads')
pp_name = cwd[i + 1]

# Fetch the record of the last successful provisioning run; it carries the outputs.
pp = client.describe_provisioned_product(Name=pp_name)
record_id = pp['ProvisionedProductDetail']['LastSuccessfulProvisioningRecordId']
record = client.describe_record(Id=record_id)

# Keep only outputs that carry BOTH a key and a value. The earlier filter,
# `'OutputKey' and 'OutputValue' in x`, only tested for 'OutputValue' because
# the non-empty string 'OutputKey' is always truthy.
complete_outputs = [
    x for x in record['RecordOutputs']
    if 'OutputKey' in x and 'OutputValue' in x
]
keys = [x['OutputKey'] for x in complete_outputs]
values = [x['OutputValue'] for x in complete_outputs]
stack_output = dict(zip(keys, values))

with open(f'/root/S3Downloads/{pp_name}/stack_outputs.json', 'w') as f:
    json.dump(stack_output, f)

In [None]:
# Load the stack outputs from stack_outputs.json in the working directory.
# A context manager is used so the file handle is closed once the JSON is parsed.
with open("stack_outputs.json") as config_file:
    sagemaker_config = json.load(config_file)

# Values taken from the stack outputs.
SOLUTION_BUCKET = sagemaker_config["SolutionS3Bucket"]
AWS_REGION = sagemaker_config["AWSRegion"]
SOLUTION_NAME = sagemaker_config["SolutionName"]
AWS_S3_BUCKET = sagemaker_config["S3Bucket"]
LIBRARY_VERSION = sagemaker_config["LibraryVersion"]
# Name passed as `EndpointName` when the endpoint is invoked later in the notebook.
ENDPOINT_NAME = sagemaker_config["SolutionPrefix"] + "-demo-endpoint"

# Key prefixes / object keys (relative paths, no bucket name).
KEY_YIELD_CURVE = "data/raw/yield_curve_field_dt.csv"
SPATIAL_FILES_KEY = "data/spatial-files"
FIPS_STATS_KEY = "data/fips-stats/fips_county_stats.csv"
FIPS_POLYGONS_KEY = "data/fips-stats/geojson-counties-fips.json"
SENTINEL_2_SHAPEFILE_KEY = "data/sentinel-2-shapefiles"
CROPS_MASK_KEY = "data/crop_mask/raw"
REQUEST_MANIFESTS_KEY = "request_manifests/"

# Local paths of the model artifacts, all under the ./model directory.
DAG_PATH = 'model/models/bn_structure.gml'
MODEL_PATH = 'model/models/bayesian_model.bif'
STATES_PATH = 'model/models/node_states.json'
NUMERICAL_SPLIT_POINTS_PATH = "model/models/numerical_split_points.json"

# Create ./model if needed; exist_ok replaces the separate existence check
# and makes the cell safe to re-run.
os.makedirs('model', exist_ok=True)

### Copy simulated data to S3

This solution uses both geospatial data and ground-level observations. We use ground-level observations from a publicly available [simulated dataset](https://data.mendeley.com/datasets/xs5nbm4w55/) of corn response to Nitrogen over thousands of fields and multiple years in Illinois.

For ease of access, we made the datasets available in an Amazon S3 bucket. Download the dataset from S3 in the following cells. 

In [None]:
from sagemaker.s3 import S3Downloader

# Root S3 URI of the solution artifacts, assembled from the stack outputs.
original_bucket = f"s3://{SOLUTION_BUCKET}-{AWS_REGION}/{LIBRARY_VERSION}/{SOLUTION_NAME}"
# Source data prefix. Note that it already includes the "s3://" scheme and a trailing slash.
original_data = f"{original_bucket}/artifacts/data/"
# Destination data prefix inside the bucket named by the "S3Bucket" stack output.
current_location = f"s3://{AWS_S3_BUCKET}/data/"
print("original data:")
# Last expression of the cell, so the notebook displays the listing.
S3Downloader.list(original_data)

In [None]:
# Copy the data only when nothing is listed under the destination prefix yet,
# so re-running the cell does not copy it a second time.
if not S3Downloader.list(current_location):
 !aws s3 cp --recursive $original_data $current_location

### Set up the environment

In [None]:
import pandas as pd
import numpy
import json
import datetime
import matplotlib.pyplot as plt
import boto3
import io
import os
import s3fs
import itertools as it
import networkx as nx
from time import time
import geopandas as gpd
import copy
import bisect
from typing import Dict
import warnings
import base64
from PIL import Image
import datetime
from time import gmtime, strftime
import urllib
import time

import sagemaker
import boto3
from botocore.exceptions import ClientError

# from utils.plot_functions imports visualize_structure
from utils.causalnex_helpers import (
 discretiser_inverse_transform,
 format_inference_output
)

from utils.helper_functions import download_s3_folder

# Silence every Python warning for the rest of the session (this also hides
# deprecation warnings raised by the libraries used below).
warnings.simplefilter('ignore')

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Define a few variables to use throughout the notebook
# Coordinate reference system that the cell geometries are re-projected to below.
EPSG = 'epsg:4326'
# NOTE(review): TARGETS is not referenced again in the visible notebook; the
# later cells use the `yield_target` variable instead.
TARGETS = ["Y_corn"]

In [None]:
# Get the SageMaker session, SageMaker execution role, Region name, and S3 resource
boto_session = boto3.session.Session()
sm_session = sagemaker.session.Session()
region = boto_session.region_name
sm_role = sagemaker.get_execution_role()
# Runtime client used to invoke the endpoint. It is created from a fresh
# boto3 session rather than from `boto_session` above.
runtime = boto3.Session().client('sagemaker-runtime')
s3 = boto3.resource('s3')

Download spatial files locally.

In [None]:
# Arguments are presumably (bucket, key prefix, local target directory) —
# confirm against utils.helper_functions.download_s3_folder.
# The shapefile read later in the notebook comes from tmp/spatial-files.
download_s3_folder(AWS_S3_BUCKET,SPATIAL_FILES_KEY, "tmp/spatial-files")
download_s3_folder(AWS_S3_BUCKET,SENTINEL_2_SHAPEFILE_KEY, "tmp/Sentinel-2-Shapefile-Index")

### Read the dataset and crop staging mapping file

**Note**: the following cell loads files produced in notebook 2, `01 Feature Engineering.ipynb`.

In [None]:
# read enhanced dataset
# `original_data` already starts with "s3://" and ends with "/", so it is used
# as-is below. Prefixing it with another "s3://" produced malformed
# "s3://s3://..." URLs.

# read enhanced dataset
df_full = pd.read_csv(
    f"{original_data}enhanced/"
    "enhanced_dataset_filtered_2018_2_Central.csv",
)

# read crop staging mapping file
df_mapping = pd.read_csv(
    f"{original_data}enhanced/"
    "stage_mapping_filtered_2018_2_Central.csv",
)

# read spatial files and re-project them to the notebook's CRS
gpd_cells = gpd.read_file("tmp/spatial-files/cells_sf.shp")
gpd_cells = gpd_cells.to_crs(EPSG)

# for the DAG setup remove the identifiers (and the LAI_max / n_uptake columns)
df = df_full.drop(columns=['FIPS', 'id_field', 'id_10', 'LAI_max', 'n_uptake'])
# keep only the mapping rows whose variable is still a column of `df`
df_mapping = df_mapping[df_mapping.variable.isin(df.columns)]

In [None]:
# S3 URI of the packaged model, stored under the same solution-artifacts root as the data.
model_artifact = f"{original_bucket}/artifacts/models/model.tar.gz"

You can view the inference script by uncommenting the line in the following cell:

In [None]:
#!pygmentize src-inference/inference.py

## Observational and counterfactual inference

In [None]:
# Copy model artifacts locally
!aws s3 cp {model_artifact} ./
!tar -C ./model -zxvf model.tar.gz

In [None]:
# Read the numerical split points
with open(NUMERICAL_SPLIT_POINTS_PATH, 'r') as fp:
 map_thresholds= json.load(fp)

# Load the DAG structure
g = nx.read_gml(DAG_PATH)

### Querying marginal distributions of the target node (yield) given some observations

#### Prepare the request payload

In [None]:
# Sample cell_id / id_field(s)
query_node = 'N_fert'
yield_target = 'Y_corn'
samples_number = 2
requests = []
samples = []

sample_features = list(g.nodes)

df_query = df_full[sample_features + ['id_10','id_field','FIPS']]

for i in range(samples_number):

 sample = df_query.sample(1)
 samples.append(sample)

 # Add all observations
 request_nodes = [(feat, sample[feat].values[0]) for feat in sample_features]

 # Discretise the request
 request = discretiser_inverse_transform(map_thresholds,
 request=True,
 request_nodes=request_nodes,
 response_nodes=[])
 
 request = dict(request)
 
 # Remove target node from the request
 request.pop(yield_target)
 
 requests.append(request)
 
df_samples = pd.concat(samples)
df_samples = df_samples.drop_duplicates()

In [None]:
# Prepare the payload
# One entry of `requests` per sampled row; `target` names the node whose
# marginal distribution is requested.
payload = {
 "method": "query",
 "observations": requests,
 "target": yield_target
}

In [None]:
# Dump the payload into a local JSON file
# The tmp/ directory must already exist (the spatial files were downloaded
# into it earlier in the notebook).
with open("tmp/request_payload_query.json", 'w') as fp:
 json.dump(payload, fp)

#### Upload the request payload

In [None]:
def upload_file(input_location):
    """Upload a local JSON request file through the SageMaker session.

    The file goes to the session's default bucket, under the key prefix
    "<AWS_S3_BUCKET>/inference/input", with content type application/json.

    Args:
        input_location: Local path of the file to upload.

    Returns:
        Whatever `sm_session.upload_data` returns; the notebook passes it on
        as the `InputLocation` of the asynchronous endpoint invocation.
    """
    destination_prefix = f"{AWS_S3_BUCKET}/inference/input"
    upload_options = dict(
        bucket=sm_session.default_bucket(),
        key_prefix=destination_prefix,
        extra_args={"ContentType": "application/json"},
    )
    return sm_session.upload_data(input_location, **upload_options)

In [None]:
# Upload request to S3
# The returned location is passed as `InputLocation` when the endpoint is invoked.
input_s3_location = upload_file("tmp/request_payload_query.json")

#### Invoke endpoint

In [None]:
# Invoke endpoint
response_endpoint = runtime.invoke_endpoint_async(
 EndpointName=ENDPOINT_NAME, 
 InputLocation=input_s3_location,
)

output_location =response_endpoint['OutputLocation']

#### Get inference outputs

In [None]:
def get_output(output_location, poll_seconds=20, max_wait_seconds=None):
    """Poll S3 until the asynchronous inference result exists, then return it.

    Args:
        output_location: URL of the result object; its network location is
            used as the bucket and its path (minus the leading "/") as the key.
        poll_seconds: Seconds to sleep between attempts (default 20).
        max_wait_seconds: Optional upper bound on the total wait. ``None``
            (the default) waits indefinitely.

    Returns:
        The object content as returned by `sm_session.read_s3_file`.

    Raises:
        ClientError: Any S3 error other than "NoSuchKey".
        TimeoutError: If `max_wait_seconds` elapses before the object exists.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` has been loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]

    deadline = None
    if max_wait_seconds is not None:
        deadline = time.monotonic() + max_wait_seconds

    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # A missing key means the result is not written yet; anything
            # else is a real error and is propagated unchanged.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No inference output at {output_location} after {max_wait_seconds} seconds"
            )
        print("waiting for the inference query")
        time.sleep(poll_seconds)

In [None]:
# Get inference outputs
output = json.loads(get_output(output_location))
print(f"\n Output: {output}")

In [None]:
# Format output by converting the marginals probabilities into buckets
resp, _, _ = format_inference_output(output)

# Convert buckets into real number ranges
resp_transformed = discretiser_inverse_transform(map_thresholds,
 request=False,
 request_nodes=[],
 response_nodes=resp)

# Collect marginals from the reponse
marginals = []

for idx, out in enumerate(output):
 marginals_df = pd.DataFrame.from_dict(
 out['marginals'], orient='index', columns=[f'marginals_{idx}'])
 marginals.append(marginals_df)

marginals = pd.concat(marginals, axis=1)
marginals['yield'] = df_full[yield_target].min()

# Note: if target is changed add the corresponding numeric_split_points_target (from the discretiser)
marginals['yield'].loc[1:] = map_thresholds[yield_target]
marginals = marginals.set_index('yield')

#### Plot marginals for the yield node

In [None]:
def plot_marginals(marginals, df_samples, resp_transformed, yield_target):
    """Plot one marginal distribution per column of `marginals`.

    Each column is drawn against the frame's index, with a vertical line at
    the actual `yield_target` value of the matching row of `df_samples`
    (rows and columns are matched by position).

    Args:
        marginals: Frame whose index holds the yield values and whose columns
            hold the probabilities.
        df_samples: Frame of sampled rows; must provide 'FIPS', 'id_10' and
            the `yield_target` column.
        resp_transformed: Accepted for interface compatibility; not read.
        yield_target: Name of the target column / node.
    """
    plt.figure(figsize=(15, 5), dpi=120)
    yield_values = marginals.index

    for position, column in enumerate(marginals):
        probabilities = marginals[column]
        fips = df_samples['FIPS'].iloc[position]
        cell_id = df_samples['id_10'].iloc[position]

        plt.plot(yield_values, probabilities, 'o--', label=f"FIPS:{fips} - CELL ID: {cell_id}")

        # Reuse the colour of the curve just drawn for its vertical marker.
        curve_color = plt.gca().lines[-1].get_color()
        plt.axvline(df_samples[yield_target].iloc[position], color=curve_color)
        plt.fill_between(yield_values, probabilities, alpha=0.1)

    plt.legend()
    plt.title(f"Marginal distributions of {yield_target} target node given the observations")
    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield actual values')
    plt.ylabel('Probability')

In [None]:
# `resp_transformed` is passed to satisfy the signature; plot_marginals does not read it.
plot_marginals(marginals, df_samples, resp_transformed, yield_target)

#### Visualize the geolocation of the selected cell IDs

In [None]:
# Plot the sampled cells geo coordinates
ax = gpd_cells[gpd_cells.region == '2-Central'].plot(cmap='Pastel2', figsize=(15,7))
gpd_cells[gpd_cells.id_10.isin(df_samples['id_10'].unique())].plot(ax=ax, facecolor='none', edgecolor='red')

### Making interventions (Do-calculus)

#### Prepare the Request Payload

In [None]:
# Sample one cell_id / id_field
features = list(g.nodes)

action_node = 'N_fert'
yield_target = 'Y_corn'
sample_features = [action_node]

# Select query nodes

satellite_features = [feat for feat in features if feat.startswith("mean_")]

sample_features.extend([feat for feat in features if 'tmean' in feat or 'rad' in feat or 'rain' in feat])
sample_features.extend(satellite_features)

# Pick a sample
samples = df_full[sample_features + ['id_10','FIPS']]

sample = samples.sample(1)

# Add all observations
request_nodes = [(feat , sample[feat].values[0]) for feat in sample_features]

# Discretise the request
request = discretiser_inverse_transform(map_thresholds,
 request=True,
 request_nodes=request_nodes,
 response_nodes=[])

In [None]:
# Current value
# Value of the action node in the sampled row, before any intervention.
print(f"Current value: {sample[action_node].values[0]} kg/ha")

In [None]:
# Map thresholds action node
# Last expression of the cell: displays the split points loaded for the action node.
map_thresholds[action_node]

> NOTE: select a value that differs significantly from the current value (ideally belonging to a different bucket), in order to observe the effect of the intervention.

In [None]:
# Discretise
value = 20 # ADD VALUE HERE (eg. 20 kg/ha Nitrogen)
action_node_value = (action_node, value)
action_node_bucket = discretiser_inverse_transform(map_thresholds,
 request=True,
 request_nodes=[action_node_value],
 response_nodes=[])

In [None]:
# Remove the node we intervene on
request = dict(request)
action_node_before = (action_node,request.pop(action_node))
action_node_after = action_node_bucket[0]

In [None]:
# Prepare payload
payload = {
 "method": "do_calculus",
 "intervention_query": request,
 "interventions": [action_node_bucket[0]],
 "target": yield_target
}

In [None]:
# Dump the payload into a local JSON file
with open("tmp/request_payload_intervention.json", 'w') as fp:
 json.dump(payload, fp)

#### Upload the request payload

In [None]:
def upload_file(input_location):
    """Upload a local JSON request file through the SageMaker session.

    The target is the session's default bucket, with the key prefix
    "<AWS_S3_BUCKET>/inference/input" and content type application/json.

    Args:
        input_location: Local path of the file to upload.

    Returns:
        The value returned by `sm_session.upload_data`, which the notebook
        uses as the `InputLocation` of the asynchronous invocation.
    """
    key_prefix = f"{AWS_S3_BUCKET}/inference/input"
    target_bucket = sm_session.default_bucket()
    json_content_type = {"ContentType": "application/json"}
    uploaded_location = sm_session.upload_data(
        input_location,
        bucket=target_bucket,
        key_prefix=key_prefix,
        extra_args=json_content_type,
    )
    return uploaded_location

In [None]:
# Upload request to S3
# The returned location is passed as `InputLocation` when the endpoint is invoked.
input_s3_location = upload_file("tmp/request_payload_intervention.json")

#### Invoke endpoint

In [None]:
# Invoke endpoint
response_endpoint = runtime.invoke_endpoint_async(
 EndpointName=ENDPOINT_NAME, 
 InputLocation=input_s3_location,
)

output_location =response_endpoint['OutputLocation']

#### Get inference outputs

In [None]:
def get_output(output_location, poll_seconds=20, max_wait_seconds=None):
    """Poll S3 until the do-calculus inference result exists, then return it.

    Args:
        output_location: URL of the result object; its network location is
            used as the bucket and its path (minus the leading "/") as the key.
        poll_seconds: Seconds to sleep between attempts (default 20).
        max_wait_seconds: Optional upper bound on the total wait. ``None``
            (the default) waits indefinitely.

    Returns:
        The object content as returned by `sm_session.read_s3_file`.

    Raises:
        ClientError: Any S3 error other than "NoSuchKey".
        TimeoutError: If `max_wait_seconds` elapses before the object exists.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` has been loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]

    deadline = None
    if max_wait_seconds is not None:
        deadline = time.monotonic() + max_wait_seconds

    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # A missing key means the result is not written yet; anything
            # else is a real error and is propagated unchanged.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No inference output at {output_location} after {max_wait_seconds} seconds"
            )
        print("waiting for the inference do-calculus")
        time.sleep(poll_seconds)

In [None]:
# Get inference outputs
output = json.loads(get_output(output_location))
print(f"\n Output: {output}")

### Plot counterfactuals

In [None]:
# collect marginals (before and after) into a pandas frame
df_marginals_before = pd.DataFrame.from_dict(output['marginals-before'], orient='index', columns=['before'])
df_marginals_after = pd.DataFrame.from_dict(output['marginals-after'], orient='index', columns=['after'])

counterfactuals = pd.concat([df_marginals_before,df_marginals_after],axis=1)
counterfactuals['yield'] = 0

# Note: if target is changed add the corresponding numeric_split_points_target
counterfactuals['yield'].loc[1:] = map_thresholds[yield_target]
counterfactuals = counterfactuals.set_index('yield')

In [None]:
def plot_counterfactuals(cf, sample, yield_target, action_node_before, action_node_after):
    """Plot the target distribution before and after the intervention.

    Args:
        cf: Frame with 'before' and 'after' probability columns, indexed by
            the yield values.
        sample: Single-row frame of the sampled cell; 'FIPS' and 'id_10' are
            shown in the title.
        yield_target: Name of the target node, shown in the super-title.
        action_node_before: Sequence whose first element labels the 'before' curve.
        action_node_after: Sequence whose first element labels the 'after' curve.
    """
    plt.figure(figsize=(12, 5), dpi=120)
    yield_values = cf.index

    # Draw the 'before' curve first, then the 'after' curve, each with its shading.
    curves = (('before', action_node_before), ('after', action_node_after))
    for column, action_node in curves:
        plt.plot(yield_values, cf[column], 'o--', label=f"Nitrogen (kg/ha): {action_node[0]}")
        plt.fill_between(yield_values, cf[column], alpha=0.1)

    # One dashed grey line per index value marks the discretisation.
    for edge in cf.index.values:
        plt.axvline(x=edge, color='gray', linestyle="--")

    plt.legend()
    plt.title(f"-- FIPS:{sample['FIPS'].values[0]} - CELL ID: {sample['id_10'].values[0]} -- ")
    plt.suptitle(f"Distribution of {yield_target} Yield given Nitrogen added as fertilizer")
    plt.xlabel('Yield (kg/ha) | vertical lines represent the Yield discretisation')
    plt.ylabel('Probability')

In [None]:
# Run the discretised action-node entries (before / after the intervention)
# through the helper in response mode; the first element of each result is
# what plot_counterfactuals shows in the legend.
action_node_before_real = discretiser_inverse_transform(map_thresholds,
 request=False,
 request_nodes=[],
 response_nodes=[action_node_before])

action_node_after_real = discretiser_inverse_transform(map_thresholds,
 request=False,
 request_nodes=[],
 response_nodes=[action_node_after])

plot_counterfactuals(counterfactuals, sample, yield_target, action_node_before_real, action_node_after_real)