{ "cells": [ { "cell_type": "markdown", "id": "7f035fb5-8e89-47cc-ab5d-4c2c43ab279b", "metadata": {}, "source": [ "# Reward curve" ] }, { "cell_type": "code", "execution_count": null, "id": "f3db2c79-0a3f-462a-a314-b91301d6dec0", "metadata": { "tags": [] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format = \"retina\"\n", "\n", "from __future__ import annotations\n", "\n", "import my_nb_path # isort: split\n", "import warnings\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from smallmatter.ds import SimpleMatrixPlotter\n", "\n", "from flight_sales.flight_sales_gym import reward_functions\n", "from flight_sales.plot import plot_contours\n", "\n", "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" ] }, { "cell_type": "markdown", "id": "f8c18475-da97-412e-aced-d83967c22ce7", "metadata": {}, "source": [ "## Contour plots" ] }, { "cell_type": "code", "execution_count": null, "id": "ec6b885d-d328-4ad4-8dc1-b2c3e43e6723", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_contours()" ] }, { "cell_type": "markdown", "id": "cafdd24d-f8e2-46af-8574-d60e5d67e51b", "metadata": { "tags": [] }, "source": [ "## Plots by dimension\n", "\n", "Here are the helper functions that plot the rewards by price, or by tickets sold."
] }, { "cell_type": "code", "execution_count": null, "id": "0f6dbd74-6136-455f-817a-03e2a5fb4f26", "metadata": { "tags": [] }, "outputs": [], "source": [
 "def plot_hspace():\n",
 "    \"\"\"Draw an invisible, almost zero-height figure.\n",
 "\n",
 "    The plotting helpers below call this after each figure, apparently as a\n",
 "    vertical spacer between consecutive inline figures (TODO confirm intent).\n",
 "    \"\"\"\n",
 "    fig, _ = plt.subplots()\n",
 "    fig.set_visible(False)\n",
 "    fig.set_figheight(1e-4)  # figure height in inches\n",
 "\n",
 "\n",
 "def _price_grid(max_fare, n_prices=11):\n",
 "    \"\"\"Return ``n_prices`` evenly spaced prices from 0 to ``max_fare``, rounded to 1 decimal.\"\"\"\n",
 "    return [round(i * max_fare, 1) for i in np.linspace(0.0, 1.0, n_prices)]\n",
 "\n",
 "\n",
 "def _finish_figure(smp, f):\n",
 "    \"\"\"Trim the plotter, title the figure after ``f``, tighten the layout, add a spacer.\"\"\"\n",
 "    smp.trim()\n",
 "    smp.fig.suptitle(f\"Reward function: {f.__name__}()\", fontweight=\"bold\")\n",
 "    smp.fig.tight_layout()\n",
 "    plot_hspace()\n",
 "\n",
 "\n",
 "def plot_prices(f, max_fare=20, daily_seats_quota=20):\n",
 "    \"\"\"Plot reward vs. tickets sold for ``f()``, one subplot per price.\n",
 "\n",
 "    Arguments:\n",
 "        f: the reward function, called as ``f(sold, price)``; element ``[0]`` of\n",
 "            its return value is plotted as the reward.\n",
 "        max_fare: upper end of the price grid (11 evenly spaced prices from 0).\n",
 "        daily_seats_quota: maximum number of tickets sold to evaluate (inclusive).\n",
 "    \"\"\"\n",
 "    prices = _price_grid(max_fare)\n",
 "    solds = list(range(daily_seats_quota + 1))\n",
 "\n",
 "    smp = SimpleMatrixPlotter(4, init_figcount=len(prices), figsize=(4, 2.5))\n",
 "    for price in prices:\n",
 "        rewards = [f(sold, price)[0] for sold in solds]\n",
 "        ax = smp.pop()\n",
 "        # Explicit x values; same curve as the implicit 0..N index.\n",
 "        ax.plot(solds, rewards)\n",
 "        ax.set_title(f\"{price=}\")\n",
 "        ax.set_xlabel(\"Tickets sold\")\n",
 "        ax.set_ylabel(\"Rewards\")\n",
 "\n",
 "    _finish_figure(smp, f)\n",
 "\n",
 "\n",
 "def plot_solds(f, max_fare=20, daily_seats_quota=20):\n",
 "    \"\"\"Plot reward vs. price for ``f()``, one subplot per number of tickets sold.\n",
 "\n",
 "    Arguments:\n",
 "        f: the reward function, called as ``f(sold, price)``; element ``[0]`` of\n",
 "            its return value is plotted as the reward.\n",
 "        max_fare: upper end of the price grid (11 evenly spaced prices from 0).\n",
 "        daily_seats_quota: maximum number of tickets sold to evaluate (inclusive).\n",
 "    \"\"\"\n",
 "    prices = _price_grid(max_fare)\n",
 "    solds = list(range(daily_seats_quota + 1))\n",
 "\n",
 "    smp = SimpleMatrixPlotter(4, init_figcount=len(solds), figsize=(4, 2))\n",
 "    for sold in solds:\n",
 "        rewards = [f(sold, price)[0] for price in prices]\n",
 "        ax = smp.pop()\n",
 "        ax.plot(prices, rewards)\n",
 "        ax.set_title(f\"{sold=}\")\n",
 "        ax.set_xlabel(\"Price\")\n",
 "        ax.set_ylabel(\"Rewards\")\n",
 "\n",
 "    _finish_figure(smp, f)"
 ] }, { "cell_type": "markdown", "id": "ea1606de-51e2-444b-8584-5431f71f1b53", "metadata": {}, "source": [ "### Without jitter" ] }, { "cell_type": "code", "execution_count": null, "id": "6d3f1189-81e7-4487-83cd-de33f39ac811", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_prices(reward_functions[\"revenue_0_20_no_jitter\"])\n", "plot_prices(reward_functions[\"revenue_0_05_no_jitter\"]);" ] }, { "cell_type": "code", "execution_count": null, "id": "1a86161d-98c1-4fca-b2e9-5d4b67970ab7", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_solds(reward_functions[\"revenue_0_20_no_jitter\"])\n", "plot_solds(reward_functions[\"revenue_0_05_no_jitter\"]);" ] }, { "cell_type": "code", "execution_count": null, "id": "546f12da-2425-4bdb-bc5b-ddd2adbecd9e", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_prices(reward_functions[\"profit_no_jitter\"])\n", "plot_solds(reward_functions[\"profit_no_jitter\"]);" ] }, { "cell_type": "markdown", "id": "a6336d87-1cae-44ba-bd08-dbad91f6defe", "metadata": {}, "source": [ "### With jitter" ] }, { "cell_type": "code", "execution_count": null, "id": "cb082e46-6e09-456e-83e8-4221ae8b35b1", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_prices(reward_functions[\"revenue_0_02\"])\n", "plot_solds(reward_functions[\"revenue_0_02\"]);" ] }, { "cell_type": "code", "execution_count": null, "id": "dc037552-48e3-40c8-9400-3b06035e8453", "metadata": { "tags": [] }, "outputs": [], "source": [ "plot_prices(reward_functions[\"profit\"])\n", "plot_solds(reward_functions[\"profit\"]);" ] }, { "cell_type": "code", "execution_count": null, "id": "dffaad91-e38e-49b1-a95f-a4c6dd0266a9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": {
"name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "toc-autonumbering": true }, "nbformat": 4, "nbformat_minor": 5 }