{ "cells": [ { "cell_type": "markdown", "id": "0a6e9ba6", "metadata": {}, "source": [ "## Introduction\n", "This notebook demonstrates how to test whether an action masking model is working as expected. When masking works properly, any action assigned mask=0 is never sampled by the RL trainer. Here we consider the action masking model developed for the portfolio optimization problem. The action vector has three components, each sampled from a discrete action space with 11 possible values." ] }, { "cell_type": "markdown", "id": "c029dbc1", "metadata": {}, "source": [ "### Install packages\n" ] }, { "cell_type": "code", "execution_count": null, "id": "17851682", "metadata": {}, "outputs": [], "source": [ "!pip install ray==1.6.0\n", "!pip install gym==0.15.3\n", "!pip install dm_tree\n", "!pip install lz4" ] }, { "cell_type": "markdown", "id": "656159be", "metadata": {}, "source": [ "### Import libraries" ] }, { "cell_type": "code", "execution_count": 2, "id": "b5ec6b90", "metadata": {}, "outputs": [], "source": [ "import ray\n", "import ray.rllib.agents.ppo as ppo\n", "from ray.tune.registry import register_env\n", "from trading import mytradingenv\n", "from mask_model import register_actor_mask_model\n", "import numpy as np\n" ] }, { "cell_type": "markdown", "id": "367b46ef", "metadata": {}, "source": [ "### Register the action masking model" ] }, { "cell_type": "code", "execution_count": 3, "id": "6a0828ee", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/amazonei_tensorflow2_p36/lib/python3.6/site-packages/ray/_private/services.py:238: UserWarning: Not all Ray Dashboard dependencies were found. To use the dashboard please install Ray using `pip install ray[default]`. 
To disable this message, set RAY_DISABLE_IMPORT_WARNING env var to '1'.\n", " warnings.warn(warning_message)\n" ] } ], "source": [ "# Register the custom action masking model (helper from mask_model.py)\n", "register_actor_mask_model()\n", "\n", "# Restart Ray so that this cell can be re-run safely\n", "ray.shutdown()\n", "ray.init(ignore_reinit_error=True)\n", "\n", "# Register the trading environment under the name \"customtradingmodel\"\n", "env_config = {}\n", "register_env(\"customtradingmodel\", lambda env_config: mytradingenv(env_config))\n" ] }, { "cell_type": "markdown", "id": "b59088a2", "metadata": {}, "source": [ "### Specify the trainer config to include the action masking model" ] }, { "cell_type": "code", "execution_count": 4, "id": "6e923143", "metadata": {}, "outputs": [], "source": [ "TestEnvConfig = {\n", "    \"log_level\": \"WARN\",\n", "    \"model\": {\n", "        # Use the custom action masking model in the trainer config\n", "        \"custom_model\": \"trading_mask\",\n", "    },\n", "}\n" ] }, { "cell_type": "markdown", "id": "7295f682", "metadata": {}, "source": [ "### Initialize a PPO trainer agent and the portfolio trading environment" ] }, { "cell_type": "code", "execution_count": 5, "id": "b46d5602", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-08-31 19:57:33,900\tINFO logger.py:180 -- pip install 'ray[tune]' to see TensorBoard files.\n", "2022-08-31 19:57:33,901\tWARNING logger.py:317 -- Could not instantiate TBXLogger: No module named 'tensorboardX'.\n", "2022-08-31 19:57:33,903\tINFO trainer.py:714 -- Tip: set framework=tfe or the --eager flag to enable TensorFlow eager execution\n", "2022-08-31 19:57:33,904\tINFO ppo.py:159 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.\n", "2022-08-31 19:57:33,905\tINFO trainer.py:728 -- Current log_level is WARN. 
For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.\n", "\u001b[2m\u001b[36m(pid=8885)\u001b[0m 2022-08-31 19:57:37,225\tWARNING deprecation.py:39 -- DeprecationWarning: `TFModelV2.register_variables` has been deprecated. This will raise an error in the future!\n", "2022-08-31 19:57:39,358\tWARNING deprecation.py:39 -- DeprecationWarning: `TFModelV2.register_variables` has been deprecated. This will raise an error in the future!\n", "2022-08-31 19:57:40,909\tWARNING util.py:55 -- Install gputil for GPU system monitoring.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.], dtype=float32)]\n" ] } ], "source": [ "# Build a PPO trainer that uses the masking model, then create the environment\n", "agent1 = ppo.PPOTrainer(config=TestEnvConfig, env=\"customtradingmodel\")\n", "env = agent1.env_creator(\"customtradingmodel\")\n", "\n", "# Reset the environment and inspect the initial action masks\n", "state = env.reset()\n", "print(state[\"action_mask\"])" ] }, { "cell_type": "markdown", "id": "14ce819f", "metadata": {}, "source": [ "### Update the masking values\n", "\n", "The initial mask printed above allows only values 5 to 10 for each of the three action components. Here we overwrite it to mask every action except action[0]=8, action[1]=5, and action[2]=1 or 2, so these are the only values the agent is allowed to sample."
] }, { "cell_type": "code", "execution_count": 6, "id": "07e698ed", "metadata": {}, "outputs": [], "source": [ "# Start from an all-zero mask (every action masked) for each of the 3 components\n", "state[\"action_mask\"] = [np.zeros([11], dtype=np.float32) for _ in range(3)]\n", "\n", "# Unmask only the values the agent is allowed to choose\n", "state[\"action_mask\"][0][8] = 1\n", "state[\"action_mask\"][1][5] = 1\n", "state[\"action_mask\"][2][1:3] = [1, 1]" ] }, { "cell_type": "markdown", "id": "66b09d76", "metadata": {}, "source": [ "### Sample a new action after updating the masks" ] }, { "cell_type": "code", "execution_count": 8, "id": "1287339e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(8, 5, 1)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent1.compute_single_action(state)\n" ] }, { "cell_type": "markdown", "id": "6ba9045f", "metadata": {}, "source": [ "The sampled action (8, 5, 1) contains only unmasked values, which is what we expect when action masking is working. Re-running the cell may also return 2 for the third component, since both 1 and 2 are unmasked." ] }, { "cell_type": "code", "execution_count": null, "id": "0495a384", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_amazonei_tensorflow2_p36", "language": "python", "name": "conda_amazonei_tensorflow2_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }