In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

import my_nb_path # isort: split
import a2rl
import numpy as np
import pandas as pd
import torch

from flight_sales.flight_sales_gym import flight_sales_gym

# Load the serialized model and tokenizer from the "model-profit" results directory.
# NOTE(review): torch.load unpickles the file (and pickle_load presumably does too),
# so only load artifacts that come from a trusted source.
model = torch.load("results/model-profit/model.pt")
tokenizer = a2rl.utils.pickle_load("results/model-profit/tokenizer.pt")
# Build the simulator from the tokenizer and model.
# NOTE(review): max_steps=365 presumably corresponds to one year of daily steps; confirm.
simulator = a2rl.Simulator(tokenizer, model, max_steps=365)
# Environment used in the cells below to produce example contexts.
env = flight_sales_gym()


def history_2_context(df):
    """Flatten a history frame into a 1-D context with the trailing actions/rewards removed.

    The frame's values are flattened row by row, then one entry per column named in
    ``df.actions`` and ``df.rewards`` is dropped from the end of the flattened array.

    Args:
        df: Frame-like object exposing ``values`` and the ``actions`` / ``rewards``
            column-name sequences (the notebook passes a2rl dataframes).
            Assumes the action and reward columns are the last columns of each
            row, so that the dropped entries belong to the final row; TODO confirm.

    Returns:
        A 1-D ``numpy.ndarray`` holding the flattened values without the trailing
        action/reward entries.
    """
    flat = df.values.ravel()
    n_trailing = len(df.actions) + len(df.rewards)
    # Use an explicit end index: ``flat[:-n]`` with ``n == 0`` would yield an empty
    # array instead of the full sequence. ``max`` keeps the result empty when the
    # frame has fewer values than there are entries to drop.
    custom_context = flat[: max(flat.size - n_trailing, 0)]
    return custom_context

In [None]:
# Reset the environment and take one step with action 0.5 before reading its context.
env.reset()
env.step(0.5)

# Flatten the environment context (trailing action/reward entries removed by
# history_2_context) and tile the resulting 1-D array into 4 identical rows.
# NOTE(review): `tail=tokenizer.block_size_row` presumably limits the context to the
# last `block_size_row` rows; confirm against the env implementation.
ctx = np.tile(history_2_context(env.context(tail=tokenizer.block_size_row, fillna=True)), (4, 1))
ctx.shape

In [None]:
# Sample from the simulator using the current `ctx`.
# NOTE(review): `as_token=False` presumably returns values in the original
# (untokenized) space; confirm against the a2rl Simulator documentation.
simulator.sample(ctx, max_size=2, as_token=False)

In [None]:
# Show the history rows before `env.day` next to the output of `env.context()`.
display(
 # These are equivalent
 env.history.iloc[: env.day],
 env.context(),
)

In [None]:
# Reset the environment, take one step, then inspect the context before and after tokenization.
env.reset()
env.step(0.5)
# NOTE: `ctx` is rebound here to the object returned by `env.context()`; in the
# earlier cell it held the tiled numpy array.
ctx = env.context(fillna=True)
display(
 # Tokenizer block sizes, shown for reference.
 (tokenizer.block_size_row, tokenizer.block_size),
 ctx,
 # The same context after passing through the field tokenizer.
 tokenizer.field_tokenizer.transform(ctx),
)

In [None]:
# Compare two ways of taking the last `block_size_row` rows of the tokenizer's
# dataframe (`tail` vs. negative `iloc` slicing) and the flattened context built
# from each. For pandas frames the two selections match whenever block_size_row > 0.
display(
 tokenizer.df.tail(tokenizer.block_size_row),
 tokenizer.df.iloc[-tokenizer.block_size_row :],
 history_2_context(tokenizer.df.iloc[-tokenizer.block_size_row :]),
 history_2_context(tokenizer.df.tail(tokenizer.block_size_row)),
)

In [None]:
# ctx = (s, a, r, ..., s), where len([first_s, ..., last_s]) == block_size
# NOTE(review): the context below is built from `block_size_row` rows, so the count
# above may be meant as `block_size_row`; confirm.
ctx = history_2_context(tokenizer.df.tail(tokenizer.block_size_row))
batch_ctx = np.asarray([ctx, ctx, ctx]) # A batch of 3 trajectories
# Sample with a single context, then with the batch, then with a larger max_size
# (only the shape of the last result is displayed).
display(
 simulator.sample(ctx, max_size=5),
 simulator.sample(batch_ctx, max_size=5),
 simulator.sample(batch_ctx, max_size=500).shape,
)

In [None]:
# Stack three copies of the same context into a batch; its leading dimension is the
# number of trajectories.
batch_ctx = np.asarray([ctx] * 3)  # A batch of 3 trajectories
trajectories_cnt = len(batch_ctx)
batch_results = simulator.sample(batch_ctx, max_size=500)
display(batch_ctx.shape, batch_results.shape)

# Let's try a few different ways to cut the batch dataframe into per-trajectory objects.

#### 01: numpy manipulation ####
# Reshape the flat result table into (trajectory, sample, column).
# NOTE(review): assumes the rows are ordered trajectory by trajectory and that the
# columns are exactly the action and reward columns; confirm.
trajectory_results_a = batch_results.values.reshape(
    (trajectories_cnt, -1, len(batch_results.actions) + len(batch_results.rewards))
)
display(
    trajectory_results_a.shape,
    # Verify first two rows and last two rows are the same.
    pd.concat([batch_results.head(2), batch_results.tail(2)]),
    [trajectory_results_a[0, :2], trajectory_results_a[-1, -2:]],
)

#### 02: pandas manipulation ####
# Split by row position with `.iloc` instead of passing the dataframe itself to
# `np.array_split`: NumPy's implementation calls `swapaxes` on its input, which
# pandas has deprecated for DataFrames (FutureWarning since pandas 2.1). Splitting a
# positional index range with `np.array_split` keeps the same chunk boundaries,
# including the uneven case.
trajectory_results_df: list[a2rl.WiDataFrame] = [
    batch_results.iloc[rows]
    for rows in np.array_split(np.arange(len(batch_results)), trajectories_cnt)
]
display(
    [tdf.shape for tdf in trajectory_results_df],
    # Verify first two rows and last two rows are the same.
    pd.concat([batch_results.head(2), batch_results.tail(2)]),
    pd.concat([trajectory_results_df[0].head(2), trajectory_results_df[-1].tail(2)]),
)