#!/usr/bin/env python

# A sample training component that trains a causalnex Bayesian network model.
# This implementation works in File mode and makes no assumptions about the
# input file names. Input is specified as CSV files with a header row; the
# training code addresses columns by name (for example 'age' and the
# 'DEATH_EVENT' target), not by position.

from __future__ import print_function

import json
import os
import pickle
import sys
import traceback
import warnings

import pandas as pd
from causalnex.discretiser import Discretiser
from causalnex.evaluation import classification_report
from causalnex.evaluation import roc_auc
from causalnex.network import BayesianNetwork
from causalnex.structure import StructureModel
from sklearn.model_selection import train_test_split

# These are the paths where SageMaker will take data from, save model to and
# store outputs in your container.
prefix = '/opt/ml/'

input_path = os.path.join(prefix, 'input/data')
output_path = os.path.join(prefix, 'output')
model_path = os.path.join(prefix, 'model')
# Location of the hyperparameters file inside the container.
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')

# This algorithm has a single channel of input data called 'training'. Since
# we run in File mode, the input files are copied to the directory specified
# here.
channel_name = 'training'
training_path = os.path.join(input_path, channel_name)

# The function to execute the training.
def _load_training_data():
    """Read every CSV file of the training channel into a single DataFrame.

    Returns:
        A pandas DataFrame holding the rows of all CSV files found in
        ``training_path``, with a fresh 0..n-1 index.

    Raises:
        ValueError: If the channel directory is empty or contains no CSV file.
    """
    # Sort the listing: os.listdir() returns entries in arbitrary order, which
    # would make the row order (and so the seeded train/test split) vary
    # between runs.
    file_names = sorted(os.listdir(training_path))
    if not file_names:
        raise ValueError(('There are no files in {}.\n' +
                          'This usually indicates that the channel ({}) was incorrectly specified,\n' +
                          'the data specification in S3 was incorrectly specified or the role specified\n' +
                          'does not have permission to access the data.').format(training_path, channel_name))

    # Test the file name rather than the joined path so that a '.csv' in a
    # directory name cannot make every file look like a CSV.
    frames = [
        pd.read_csv(os.path.join(training_path, name), encoding='utf8', engine='python')
        for name in file_names
        if '.csv' in name
    ]
    if not frames:
        raise ValueError('There are no CSV files in {}.'.format(training_path))

    # ignore_index avoids duplicated index labels when several files are merged.
    return pd.concat(frames, ignore_index=True)


def _discretise_features(data):
    """Replace the continuous columns of ``data`` with fixed-split buckets.

    Each listed column is discretised in place from its own values using
    causalnex's ``Discretiser`` with the hard-coded split points below. For
    example, you can derive such split points from a statistical distribution
    of the data, e.g. quartiles.

    Args:
        data: DataFrame containing all of the columns listed below.

    Returns:
        The same DataFrame object, with the listed columns discretised.
    """
    split_points = {
        "age": [60],
        "serum_sodium": [136],
        "serum_creatinine": [1.1, 1.4],
        "ejection_fraction": [30, 38, 42],
        "creatinine_phosphokinase": [120, 540, 670],
        "platelets": [263358],
    }
    for column, points in split_points.items():
        discretiser = Discretiser(method="fixed", numeric_split_points=points)
        # Always transform a column from its own values; the previous code
        # filled 'serum_creatinine' from the already-discretised 'serum_sodium'.
        data[column] = discretiser.transform(data[column].values)
    return data


def _build_structure_model():
    """Return the hand-specified causal graph as a causalnex StructureModel."""
    structure = StructureModel()
    structure.add_edges_from([
        ('ejection_fraction', 'DEATH_EVENT'),
        ('creatinine_phosphokinase', 'DEATH_EVENT'),
        ('age', 'DEATH_EVENT'),
        ('smoking', 'high_blood_pressure'),
        ('age', 'high_blood_pressure'),
        ('serum_sodium', 'DEATH_EVENT'),
        ('high_blood_pressure', 'DEATH_EVENT'),
        ('anaemia', 'DEATH_EVENT'),
        ('smoking', 'DEATH_EVENT'),
    ])
    return structure


def train():
    """Train the Bayesian network on the training channel and pickle it.

    Loads and discretises the CSV data, fits a ``BayesianNetwork`` over the
    hand-specified structure, prints the AUC and classification report for the
    held-out 20% split, and writes the fitted network to
    ``<model_path>/causal_model.pkl``.

    Any exception is reported to ``<output_path>/failure`` and to stderr, and
    the process then exits with status 255.
    """
    print('Starting the training.')

    try:
        train_data = _discretise_features(_load_training_data())

        warnings.filterwarnings("ignore")  # silence warnings

        # Named *_split so the locals do not shadow this function's own name.
        train_split, test_split = train_test_split(
            train_data, train_size=0.8, test_size=0.2, random_state=42)

        bn = BayesianNetwork(_build_structure_model())
        # Node states are learned from the full dataset, while the CPDs are
        # fitted on the training split only (presumably so that states seen
        # only in the test split are still known to the network).
        bn = bn.fit_node_states(train_data)
        bn = bn.fit_cpds(train_split, method="BayesianEstimator", bayes_prior="K2")

        _, auc = roc_auc(bn, test_split, "DEATH_EVENT")
        print("Model AUC: " + str(auc))
        print(classification_report(bn, test_split, "DEATH_EVENT"))

        # save the model
        with open(os.path.join(model_path, 'causal_model.pkl'), 'wb') as out:
            pickle.dump(bn, out)
        print('Training complete.')
    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        print(str(e))
        with open(os.path.join(output_path, 'failure'), 'w') as s:
            s.write('Exception during training: ' + str(e) + '\n' + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print('Exception during training: ' + str(e) + '\n' + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


if __name__ == '__main__':
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    sys.exit(0)