{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4a37d83e-7866-4006-b27b-06d82da2cbef",
   "metadata": {},
   "source": [
    "# Optional Example: Convert simulator backend to ONNX\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fd83b2-d529-4ba2-b22f-268e2637ba78",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%config InlineBackend.figure_format = \"retina\"\n",
    "\n",
    "# Prerequisite:\n",
    "#%pip install onnxruntime\n",
    "\n",
    "import my_nb_path  # isort: split\n",
    "import a2rl as wi\n",
    "import numpy as np\n",
    "import torch\n",
    "from stable_baselines3 import A2C\n",
    "\n",
    "from flight_sales.flight_sales_gym import flight_sales_gym\n",
    "\n",
    "# Restore the saved model and tokenizer artifacts from disk.\n",
    "model = torch.load(\"model-dyn-pricing/model.pt\")\n",
    "tokenizer = wi.utils.pickle_load(\"model-dyn-pricing/tokenizer.pt\")\n",
    "\n",
    "# Create the environment, reset it, and take a single step with action 0.5.\n",
    "env = flight_sales_gym(f_reward=\"revenue_0_05\")\n",
    "env.reset()\n",
    "env.step(0.5)\n",
    "\n",
    "# Number of trailing tokens to drop: one per action column plus one per\n",
    "# reward column known to the tokenizer.\n",
    "n_action_reward_cols = len(tokenizer.df.actions) + len(tokenizer.df.rewards)\n",
    "\n",
    "# Tokenize the context requested with tail=2, flatten it to 1-D, and trim\n",
    "# the trailing action/reward tokens.\n",
    "context_frame = env.context(tail=2, fillna=True)\n",
    "token_frame = tokenizer.field_tokenizer.transform(context_frame)\n",
    "flat_tokens = token_frame.values.ravel()\n",
    "trimmed_tokens = flat_tokens[:-n_action_reward_cols]\n",
    "\n",
    "# Tiling a 1-D array once along each of two axes yields a 2-D array with a\n",
    "# single row, which is then wrapped as a tensor to serve as the sample input.\n",
    "sample_input = torch.from_numpy(np.tile(trimmed_tokens, (1, 1)))\n",
    "\n",
    "# Write the model out in ONNX format, driven by the sample input above.\n",
    "export_options = {\n",
    "    \"verbose\": True,\n",
    "    \"input_names\": [\"input\"],\n",
    "    \"output_names\": [\"output\"],\n",
    "    \"export_params\": True,\n",
    "    \"do_constant_folding\": True,\n",
    "}\n",
    "torch.onnx.export(\n",
    "    model,\n",
    "    sample_input,\n",
    "    \"model-dyn-pricing/model.onnx\",\n",
    "    **export_options,\n",
    ")\n",
    "\n",
    "# Optional next step: visualize with https://github.com/lutzroeder/netron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3662804a-485e-4d01-9632-c5b8e3d8e637",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}