{ "cells": [ { "cell_type": "markdown", "id": "f0b7d27a-0296-41c9-b53d-8ac2683e976b", "metadata": {}, "source": [ "# A closer inspection of `flight_sales/run_exp.py`" ] }, { "cell_type": "code", "execution_count": null, "id": "4ff02c87-f961-4365-b1c4-6b58dda1fa90", "metadata": { "tags": [] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "from __future__ import annotations\n", "\n", "import my_nb_path # isort: split\n", "import math\n", "from pathlib import Path\n", "\n", "import a2rl as wi\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from smallmatter.ds import SimpleMatrixPlotter\n", "from stable_baselines3 import A2C\n", "\n", "from flight_sales.flight_sales_gym import flight_sales_gym, fsigmoid, parameters\n", "from flight_sales.run_exp import (\n", " ExperimentRewards,\n", " plot_results,\n", " run_ppo_agent,\n", " run_random_agent,\n", " run_ucb_agent,\n", " run_whatif,\n", ")" ] }, { "cell_type": "markdown", "id": "00b431e2-7984-4f52-b643-b3f870d81d96", "metadata": {}, "source": [ "## Simple propensity model" ] }, { "cell_type": "markdown", "id": "2b9c47b9-519a-4148-8498-399acffd5653", "metadata": {}, "source": [ "### Parameters of logistic function\n", "\n", "The logistic function's parameters are seasonal." 
] },
{ "cell_type": "code", "execution_count": null, "id": "9e038d0e-4044-40af-8c56-5f1b6ff37351", "metadata": { "tags": [] }, "outputs": [], "source": [
 "def plot_parameters(f):\n",
 "    \"\"\"Plot how the parameters returned by ``f`` vary over one year.\n",
 "\n",
 "    ``f`` is called once per day-of-year (1 to 365). Each of the three values it returns is\n",
 "    drawn in its own panel; a fourth panel shows the cosine seasonality curve computed below\n",
 "    from the day-of-year.\n",
 "\n",
 "    Arguments:\n",
 "        f: the parameters function under inspection. It is called as ``f(day)`` and must\n",
 "            return three values, unpacked here as ``(smoothness, mid_price, conversion)``.\n",
 "            Its ``__name__`` becomes the figure title.\n",
 "    \"\"\"\n",
 "    days = np.arange(1, 366)\n",
 "    season = 0.5 * (np.cos(days * (2.0 * np.pi / 365)) + 1)\n",
 "\n",
 "    # Evaluate the function passed in rather than the global ``parameters``; otherwise the\n",
 "    # curves would not change with ``f`` even though the title (``f.__name__``) does.\n",
 "    daily_params = [f(day) for day in days]\n",
 "    smoothness, mid_price, conversion = zip(*daily_params)\n",
 "\n",
 "    fig, axes = plt.subplots(1, 4, figsize=(10, 2))\n",
 "    for ax, a, title in zip(\n",
 "        axes.flatten(),\n",
 "        (smoothness, mid_price, conversion, season),\n",
 "        (\"k (smoothness)\", \"x0 (mid_price)\", \"L (conversion)\", 'State \"seasonality\"'),\n",
 "    ):\n",
 "        # Index by ``days`` so that the x-axis shows the day-of-year (1..365), not 0..364.\n",
 "        pd.DataFrame(a, index=days).plot(title=title, ax=ax, legend=False)\n",
 "        ax.set_xlabel(\"Day-of-year\")\n",
 "        ax.set_ylabel(\"Seasonal Factor\")\n",
 "\n",
 "    fig.suptitle(f.__name__)\n",
 "    plt.tight_layout()\n",
 "    fig.subplots_adjust(top=0.75)\n",
 "    plt.show()\n",
 "\n",
 "\n",
 "plot_parameters(parameters)"
] },
{ "cell_type": "markdown", "id": "65164d60-20e6-4631-b283-d1ea3f24d717", "metadata": {}, "source": [ "### Propensity to buy" ] },
{ "cell_type": "code", "execution_count": null, "id": "36608932-b8be-4063-a10c-e9303e0d04b4", "metadata": { "tags": [] }, "outputs": [], "source": [ "params_peak_season = np.array([0.5, 10, 0.4])\n", "params_off_season = np.array([0.3, 10, 0.2])\n", "prices = np.linspace(0, 20, 10)\n", "\n", "plt.plot(prices, [fsigmoid(price, day=2) for price in prices], label=\"peak season\")\n", "plt.plot(prices, [fsigmoid(price, day=180) for price in prices], label=\"off season\")\n", "\n", "plt.title(\"Conversion rate vs fare\")\n", "plt.legend()\n", "\n", "plt.ylabel(\"Conversion rate\")\n", "plt.xlabel(\"Fare\");" ] },
{ "cell_type": "markdown", "id": "5f370437-d718-4d29-9012-77b13aadb868", "metadata": {}, "source": [ "## Generate offline data\n", "\n", "In RL 
parlance, offline data is historical data collected by another policy." ] }, { "cell_type": "code", "execution_count": null, "id": "81b3239e-1a63-43e7-a92d-f982c6fef022", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "env = flight_sales_gym(f_reward=\"revenue_0_05\")\n", "print(f\"Reward function: {env.f_reward.__name__}()\")\n", "\n", "previous_model = A2C(policy=\"MlpPolicy\", env=env, verbose=False) # type: ignore[call-arg,arg-type]\n", "previous_model.learn(total_timesteps=1000)\n", "\n", "cap_env = wi.TransitionRecorder(env)\n", "previous_model.set_env(cap_env)\n", "previous_model.learn(total_timesteps=10000)\n", "\n", "wi_df = wi.WiDataFrame(\n", " cap_env.df.values,\n", " columns=[\"season\", \"freight_price\", \"ticket_price\", \"reward\"],\n", " states=[\"season\", \"freight_price\"],\n", " actions=[\"ticket_price\"],\n", " rewards=[\"reward\"],\n", ")\n", "wi_df.describe()" ] }, { "cell_type": "markdown", "id": "d5ea6d2f-2d9c-4ff3-9582-7ccd57bff17e", "metadata": { "tags": [] }, "source": [ "## Train an A2RL simulator backbone" ] }, { "cell_type": "code", "execution_count": null, "id": "1e837f0b-1c71-4437-9272-46d21121a45c", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "model_dir = \"model-dyn-pricing\"\n", "tokenizer = wi.AutoTokenizer(wi_df, block_size_row=2)\n", "model_builder = wi.GPTBuilder(tokenizer, model_dir)\n", "model_builder.fit() # ~1.1m on MBP M1" ] }, { "cell_type": "markdown", "id": "f0a06b88-ff32-428d-9fd7-b77a9426e70d", "metadata": {}, "source": [ "## Evaluation\n", "\n", "Run each type of agent, then plot the results."
] }, { "cell_type": "markdown", "id": "8e677011-6df3-4b6d-875f-f96b86f53d03", "metadata": {}, "source": [ "### Run experiments" ] }, { "cell_type": "code", "execution_count": null, "id": "684ad64f-bde7-4d63-8807-9a215af7fcf2", "metadata": { "tags": [] }, "outputs": [], "source": [ "# MBP M1, sample_size = 500\n", "# ep=1 => 1m:10s\n", "# ep=5 => 4m:50s\n", "whatif_rewards = run_whatif(env, model_builder, ep=5, sample_size=500)" ] }, { "cell_type": "code", "execution_count": null, "id": "1395871c-71a9-469f-b5dc-08fbd53714dc", "metadata": { "tags": [] }, "outputs": [], "source": [ "random_agent_rewards = run_random_agent(env)" ] }, { "cell_type": "code", "execution_count": null, "id": "9b79699b-254b-4e47-be65-af401990d44e", "metadata": { "tags": [] }, "outputs": [], "source": [ "ucb_agent_rewards = run_ucb_agent(env)" ] }, { "cell_type": "code", "execution_count": null, "id": "5e26fcea-edb7-4e78-b196-cb8b0cccb973", "metadata": { "tags": [] }, "outputs": [], "source": [ "ppo_agent_rewards = run_ppo_agent(env, previous_model)" ] }, { "cell_type": "markdown", "id": "fc58b965-4677-45dd-9081-d0d03440fdb3", "metadata": {}, "source": [ "### Compile results" ] }, { "cell_type": "code", "execution_count": null, "id": "0c650c7a-e4c1-4e39-bc45-9758ae1c01c0", "metadata": { "tags": [] }, "outputs": [], "source": [ "rewards: dict[str, ExperimentRewards] = {\n", " \"whatif\": whatif_rewards,\n", " \"random_agent\": random_agent_rewards,\n", " \"ucb_agent\": ucb_agent_rewards,\n", " \"ppo_agent\": ppo_agent_rewards,\n", "}\n", "\n", "smp = plot_results(rewards, suptitle=f\"Reward function: {env.f_reward.__name__}()\")" ] }, { "cell_type": "code", "execution_count": null, "id": "f486ceef-f112-48b9-9dff-b997bc4cf289", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "toc-autonumbering": true }, "nbformat": 4, "nbformat_minor": 5 }