{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
# Amazon SageMaker RL Result Evaluation
\n", "\n", "This notebook evaluates a completed SageMaker training job." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "import os\n", "import sagemaker\n", "\n", "from energy_storage_system.envs import SimpleBattery\n", "from energy_storage_system.utils import evaluate_episode, plot_analysis\n", "from smnb_utils.misc import wait_for_s3_object\n", "\n", "# Global config\n", "sage_session = sagemaker.session.Session()\n", "s3_bucket = sage_session.default_bucket() \n", "s3_output_path = \"s3://{}/\".format(s3_bucket)\n", "print(\"S3 bucket path: {}\".format(s3_output_path))", "\n", "# load the previous job\n", "%store -r job_name" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "try:\n", " job_name\n", "except NameError:\n", " print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n", " print(\"[ERROR] Please run the notebook 01_battery_sm_on_sm before you continue.\")\n", " print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Get Model Checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "print(\"Job name: {}\".format(job_name))\n", "\n", "s3_url = \"s3://{}/{}\".format(s3_bucket,job_name)\n", "\n", "intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n", "intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n", "\n", "print(\"S3 job path: {}\".format(s3_url))\n", "print(\"Intermediate folder path: {}\".format(intermediate_url))\n", " \n", "tmp_dir = \"/tmp/{}\".format(job_name)\n", "os.system(\"mkdir {}\".format(tmp_dir))\n", "print(\"Create local folder {}\".format(tmp_dir))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download model checkpoint from s3 into `/tmp`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_tar_key = \"{}/output/model.tar.gz\".format(job_name)\n", " \n", "local_checkpoint_dir = \"{}/model\".format(tmp_dir)\n", "\n", "wait_for_s3_object(s3_bucket, model_tar_key, tmp_dir, training_job_name=job_name) \n", "\n", "if not os.path.isfile(\"{}/model.tar.gz\".format(tmp_dir)):\n", " raise FileNotFoundError(\"File model.tar.gz not found\")\n", " \n", "os.system(\"mkdir -p {}\".format(local_checkpoint_dir))\n", "os.system(\"tar -xvzf {}/model.tar.gz -C {}\".format(tmp_dir, local_checkpoint_dir))\n", "\n", "print(\"Checkpoint directory {}\".format(local_checkpoint_dir))\n", "\n", "checkpoint_path = f\"{local_checkpoint_dir}/checkpoint\"\n", "print(\"checkpoint_path\",checkpoint_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Deserialize checkpoint to an rllib agent" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "import ray\n", "from ray import tune\n", "from ray.rllib.agents import dqn\n", "import warnings\n", "\n", "def get_agent(checkpoint_path, file_path= \"data/sample-data.csv\"):\n", " \"\"\"Deserialize a checkpoint into an in-memory agent.\"\"\"\n", " def register_env_creator(env_name):\n", " tune.register_env(env_name, lambda env_config: SimpleBattery(env_config))\n", "\n", " # Alternatively to register custom env and pass to trainer, DQNTrainer(config=config, env=env_class)\n", " 
# env_class = \"battery\"\n", " # register_env_creator(env_class)\n", "\n", " config = dqn.DEFAULT_CONFIG.copy()\n", " config[\"num_workers\"] = 1\n", " config[\"explore\"] = False\n", " config[\"evaluation_config\"] = {\"explore\": False}\n", "\n", " ray.shutdown()\n", " ray.init(local_mode=True)\n", "\n", " # Instantiate agent. Agent need env to be registered as it will be using tune behind the scene.\n", " # env: can pass in MyEnv(gym), or a registered environment (e.g. env_class)\n", " # region: HAHA begin HACK\n", " warnings.warn('HAHA: Hack to pass custom env_config to DQNTrainer')\n", " env_config = {\"MAX_STEPS_PER_EPISODE\": 168, \"LOCAL\": True, \"FILEPATH\": file_path}\n", " class _SimpleBattery(SimpleBattery):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(env_config=env_config)\n", " agent = dqn.DQNTrainer(config=config, env=_SimpleBattery)\n", " # endregion: HAHA end HACK\n", "\n", " # Load trained model\n", " agent.restore(checkpoint_path)\n", "\n", " return agent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Policy evaluation and observation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "agent = get_agent(checkpoint_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(2)\n", "env_config = {\"MAX_STEPS_PER_EPISODE\": 168, \"LOCAL\": True, \"FILEPATH\": \"data/sample-data.csv\"}\n", "env = SimpleBattery(env_config)\n", "df_eval = evaluate_episode(agent, env)\n", "fig = plot_analysis(df_eval)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p ../data/streamlit_input\n", "df_eval.to_csv(\"../data/streamlit_input/result_dqn.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.", "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 4 }