<div style="font-size:200%;font-weight:bold">Amazon SageMaker RL Result Evaluation</div>

This notebook evaluates a completed SageMaker training job.

In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import os
import subprocess

import sagemaker

from energy_storage_system.envs import SimpleBattery
from energy_storage_system.utils import evaluate_episode, plot_analysis
from smnb_utils.misc import wait_for_s3_object

# Global config
sage_session = sagemaker.session.Session()
s3_bucket = sage_session.default_bucket()  
s3_output_path = "s3://{}/".format(s3_bucket)
print("S3 bucket path: {}".format(s3_output_path))
# load the previous job
%store -r job_name

In [21]:
# `job_name` is expected to exist in the notebook namespace at this point.
# If it does not, tell the user which notebook has to be run first.
if "job_name" not in globals():
    _banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    print(
        _banner,
        "[ERROR] Please run the notebook 01_battery_sm_on_sm before you continue.",
        _banner,
        sep="\n",
    )

# Get Model Checkpoint

In [None]:
# Derive the S3 locations of the training job's artifacts from its name.
print("Job name: {}".format(job_name))

s3_url = "s3://{}/{}".format(s3_bucket, job_name)

intermediate_folder_key = "{}/output/intermediate/".format(job_name)
intermediate_url = "s3://{}/{}".format(s3_bucket, intermediate_folder_key)

print("S3 job path: {}".format(s3_url))
print("Intermediate folder path: {}".format(intermediate_url))

# Local scratch directory for this job's artifacts.
# os.makedirs(exist_ok=True) keeps the cell idempotent on re-run, raises on a
# real failure, and never passes the job name through a shell.
tmp_dir = "/tmp/{}".format(job_name)
os.makedirs(tmp_dir, exist_ok=True)
print("Create local folder {}".format(tmp_dir))

Download the model checkpoint from S3 into `/tmp`.

In [None]:
# S3 key of the model archive produced by the training job, and the local
# paths it is downloaded to / extracted into.
model_tar_key = "{}/output/model.tar.gz".format(job_name)

local_model_tar = os.path.join(tmp_dir, "model.tar.gz")
local_checkpoint_dir = "{}/model".format(tmp_dir)

# NOTE(review): presumably waits for the object to appear in S3 and downloads
# it into `tmp_dir` -- confirm against smnb_utils.misc.wait_for_s3_object.
wait_for_s3_object(s3_bucket, model_tar_key, tmp_dir, training_job_name=job_name)

if not os.path.isfile(local_model_tar):
    raise FileNotFoundError("File model.tar.gz not found")

os.makedirs(local_checkpoint_dir, exist_ok=True)

# Argument list (no shell) + check=True: a failed or partial extraction raises
# CalledProcessError here instead of surfacing later as a missing checkpoint.
subprocess.run(
    ["tar", "-xvzf", local_model_tar, "-C", local_checkpoint_dir],
    check=True,
)

print("Checkpoint directory {}".format(local_checkpoint_dir))

checkpoint_path = f"{local_checkpoint_dir}/checkpoint"
print("checkpoint_path",checkpoint_path)

# Deserialize the checkpoint into an RLlib agent

In [32]:
import ray
from ray import tune
from ray.rllib.agents import dqn
import warnings

def get_agent(checkpoint_path, file_path="data/sample-data.csv", max_steps_per_episode=168):
    """Deserialize a checkpoint into an in-memory DQN agent.

    Args:
        checkpoint_path: Checkpoint location handed to ``agent.restore``.
        file_path: Value stored under the ``FILEPATH`` key of the environment
            config.
        max_steps_per_episode: Value stored under the ``MAX_STEPS_PER_EPISODE``
            key of the environment config.

    Returns:
        A ``dqn.DQNTrainer`` on which ``restore(checkpoint_path)`` has been
        called.

    Note:
        Every call shuts Ray down and re-initialises it with
        ``local_mode=True``, and emits a warning about the env-config hack.
    """
    config = dqn.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1
    # The agent is only used for evaluation here, so exploration is switched off.
    config["explore"] = False
    config["evaluation_config"] = {"explore": False}

    # Restart Ray so the function can be called repeatedly from the same kernel.
    ray.shutdown()
    ray.init(local_mode=True)

    # Instantiate agent. Agent need env to be registered as it will be using tune behind the scene.
    # env: can pass in MyEnv(gym), or a registered environment name.
    # region: HAHA begin HACK
    warnings.warn('HAHA: Hack to pass custom env_config to DQNTrainer')
    env_config = {
        "MAX_STEPS_PER_EPISODE": max_steps_per_episode,
        "LOCAL": True,
        "FILEPATH": file_path,
    }

    class _SimpleBattery(SimpleBattery):
        """SimpleBattery that ignores its constructor arguments and always uses `env_config`."""

        def __init__(self, *args, **kwargs):
            super().__init__(env_config=env_config)

    agent = dqn.DQNTrainer(config=config, env=_SimpleBattery)
    # endregion: HAHA end HACK

    # Load trained model
    agent.restore(checkpoint_path)

    return agent

# Policy evaluation and observation

In [None]:
# Build the agent from the extracted checkpoint, using get_agent's default data file.
agent = get_agent(checkpoint_path)

In [None]:
import numpy as np

# Seed NumPy's global RNG before the rollout (presumably to keep the episode repeatable).
np.random.seed(2)

# Evaluation environment, configured with the same values get_agent() uses by default.
env_config = dict(MAX_STEPS_PER_EPISODE=168, LOCAL=True, FILEPATH="data/sample-data.csv")
env = SimpleBattery(env_config)

# Run the agent in the environment, then plot the returned results.
df_eval = evaluate_episode(agent, env)
fig = plot_analysis(df_eval)

In [None]:
# Write the evaluation results to CSV (presumably consumed by a Streamlit app,
# judging by the folder name -- confirm).
# os.makedirs(exist_ok=True) is the portable equivalent of `mkdir -p` and
# raises if the directory cannot be created.
output_dir = "../data/streamlit_input"
os.makedirs(output_dir, exist_ok=True)
df_eval.to_csv(os.path.join(output_dir, "result_dqn.csv"), index=False)