{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a2rl-simulator-sampling-intro",
   "metadata": {},
   "source": [
    "# Sampling trajectories with the `a2rl` simulator\n",
    "\n",
    "Scratch notebook that loads a trained model and tokenizer from `results/model-profit/`, builds sampling contexts from the `flight_sales_gym` environment and from the tokenizer's own dataframe, and inspects what `Simulator.sample()` returns for single and batched contexts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b600d-980f-4e4b-87c5-4f0b693d26ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%config InlineBackend.figure_format = \"retina\"\n",
    "\n",
    "import my_nb_path  # isort: split\n",
    "import a2rl\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from flight_sales.flight_sales_gym import flight_sales_gym\n",
    "\n",
    "# NOTE(review): torch.load() unpickles its input (and pickle_load() presumably does\n",
    "# too, judging by its name), so only load artifacts that come from a trusted source.\n",
    "model = torch.load(\"results/model-profit/model.pt\")\n",
    "tokenizer = a2rl.utils.pickle_load(\"results/model-profit/tokenizer.pt\")\n",
    "simulator = a2rl.Simulator(tokenizer, model, max_steps=365)\n",
    "env = flight_sales_gym()\n",
    "\n",
    "\n",
    "def history_2_context(df: a2rl.WiDataFrame) -> np.ndarray:\n",
    "    \"\"\"Flatten ``df`` row by row and drop the trailing action/reward values.\n",
    "\n",
    "    The number of dropped values is ``len(df.actions) + len(df.rewards)``. Assuming\n",
    "    the columns are ordered states, actions, rewards (TODO confirm), the result is\n",
    "    ``(s, a, r, ..., s)`` and ends on the last row's state.\n",
    "\n",
    "    Args:\n",
    "        df: Dataframe exposing ``actions`` and ``rewards`` attributes.\n",
    "\n",
    "    Returns:\n",
    "        1-D array of the flattened values without the trailing action/reward values.\n",
    "    \"\"\"\n",
    "    flat = df.values.ravel()\n",
    "    n_trailing = len(df.actions) + len(df.rewards)\n",
    "    # Use an explicit stop index: ``flat[:-n_trailing]`` would collapse to ``flat[:0]``\n",
    "    # (an empty array) when there are no action/reward columns to drop.\n",
    "    return flat[: max(flat.size - n_trailing, 0)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1255191d-d3b0-47d1-b032-a6e44c6ec8c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "env.reset()\n",
    "env.step(0.5)\n",
    "\n",
    "# Stack the same 1-D context 4 times to get a (4, context_length) batch.\n",
    "env_ctx = history_2_context(env.context(tail=tokenizer.block_size_row, fillna=True))\n",
    "tiled_ctx = np.tile(env_ctx, (4, 1))\n",
    "tiled_ctx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec40b863-0a6c-4ca7-91ee-f8229086d155",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "simulator.sample(tiled_ctx, max_size=2, as_token=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f490fd-ada6-4a99-abf2-24c506e805ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "display(\n",
    "    # These are equivalent\n",
    "    env.history.iloc[: env.day],\n",
    "    env.context(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08bba524-f7c1-4d67-9ae6-786bd755a1f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "env.reset()\n",
    "env.step(0.5)\n",
    "# Distinct name: this is the context dataframe, not a flattened context array.\n",
    "env_ctx_df = env.context(fillna=True)\n",
    "display(\n",
    "    (tokenizer.block_size_row, tokenizer.block_size),\n",
    "    env_ctx_df,\n",
    "    tokenizer.field_tokenizer.transform(env_ctx_df),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f269ddc4-cd9b-4c72-b5c3-4547218815e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# tail(n) and iloc[-n:] select the same rows, so each pair below should match.\n",
    "display(\n",
    "    tokenizer.df.tail(tokenizer.block_size_row),\n",
    "    tokenizer.df.iloc[-tokenizer.block_size_row :],\n",
    "    history_2_context(tokenizer.df.iloc[-tokenizer.block_size_row :]),\n",
    "    history_2_context(tokenizer.df.tail(tokenizer.block_size_row)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9baa1891-d7b9-4670-8db3-5744d2059905",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# tail_ctx = (s, a, r, ..., s), where len([first_s, ..., last_s]) == block_size\n",
    "tail_ctx = history_2_context(tokenizer.df.tail(tokenizer.block_size_row))\n",
    "batch_ctx = np.asarray([tail_ctx, tail_ctx, tail_ctx])  # A batch of 3 trajectories\n",
    "display(\n",
    "    simulator.sample(tail_ctx, max_size=5),\n",
    "    simulator.sample(batch_ctx, max_size=5),\n",
    "    simulator.sample(batch_ctx, max_size=500).shape,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "546bd50f-4592-4ef2-a97c-ce406cac151f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_ctx = np.asarray([tail_ctx, tail_ctx, tail_ctx])  # A batch of 3 trajectories\n",
    "trajectories_cnt = batch_ctx.shape[0]\n",
    "batch_results = simulator.sample(batch_ctx, max_size=500)\n",
    "display(batch_ctx.shape, batch_results.shape)\n",
    "\n",
    "# Let's try a few different ways to cut the batch dataframe into per-trajectory objects.\n",
    "\n",
    "#### 01: numpy manipulation ####\n",
    "# Reshape to (trajectory, step, action+reward columns).\n",
    "trajectory_results_a = batch_results.values.reshape(\n",
    "    trajectories_cnt,\n",
    "    -1,\n",
    "    len(batch_results.actions) + len(batch_results.rewards),\n",
    ")\n",
    "display(\n",
    "    trajectory_results_a.shape,\n",
    "    # Verify first two rows and last two rows are the same.\n",
    "    pd.concat([batch_results.head(2), batch_results.tail(2)]),\n",
    "    [\n",
    "        trajectory_results_a[0, 0:2, :],\n",
    "        trajectory_results_a[-1, -2:, :],\n",
    "    ],\n",
    ")\n",
    "\n",
    "#### 02: pandas manipulation ####\n",
    "trajectory_results_df: list[a2rl.WiDataFrame] = np.array_split(batch_results, trajectories_cnt)\n",
    "display(\n",
    "    [tdf.shape for tdf in trajectory_results_df],\n",
    "    # Verify first two rows and last two rows are the same.\n",
    "    pd.concat([batch_results.head(2), batch_results.tail(2)]),\n",
    "    pd.concat([trajectory_results_df[0].head(2), trajectory_results_df[-1].tail(2)]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ce40a38-51ea-4e4c-847c-f5af4f44aa1f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}