# Notebook: Forecast Visualization

This notebook shows an example visualization of a demand dataset, its related time series data, and one or more forecast exports.

In [None]:
# Date window (YYYY-MM-DD, both ends inclusive) applied to every dataset that is loaded
start_date = '2014-09-01'
end_date = '2015-02-01'

# Full file names of the CSVs uploaded to the train/ folder in S3
demand_dataset_name = 'DATASET_NAME.csv'
related_dataset_name = 'DATASET_NAME.related.csv'

# Forecast exports (created by Amazon Forecast) to show, one subplot each.
# 'path' is relative to the exports/ prefix of the bucket; 'name' is the subplot title.
forecast_exports = [
    {'path': 'export_date/export_name.csv', 'name': 'Forecast 1 - Name'},
    {'path': 'export_date/export_name.csv', 'name': 'Forecast 2 - Name'},
]

In [None]:
import os 
import pandas as pd

# S3 bucket that holds the train/ and exports/ data; None when the variable is unset
forecast_bucket = os.environ.get('FORECAST_BUCKET')

def download_data(path, date_column, header=None):
    """Load a CSV from the forecast bucket and keep only rows inside the date window.

    Args:
        path: Object key of the CSV, relative to the root of the bucket.
        date_column: Column (position or label) to parse as dates and filter on.
        header: Passed through to ``pandas.read_csv``. ``None`` (the default)
            reads the file as having no header row; ``0`` takes the column
            names from the first row.

    Returns:
        A DataFrame limited to rows where ``start_date <= date_column <= end_date``
        (both bounds inclusive), using the notebook-level ``start_date`` and
        ``end_date`` settings.

    Raises:
        RuntimeError: If no bucket name is available because the
            ``FORECAST_BUCKET`` environment variable was not set.
    """
    if not forecast_bucket:
        # Without this check the URL silently becomes "s3://None/..." and the
        # problem only surfaces later as a misleading S3 error.
        raise RuntimeError(
            "The FORECAST_BUCKET environment variable is not set; "
            "cannot determine which S3 bucket to read from."
        )

    data = pd.read_csv(f"s3://{forecast_bucket}/{path}", header=header, parse_dates=[date_column])

    # keep only the rows that fall inside the configured date window
    in_window = (data[date_column] >= start_date) & (data[date_column] <= end_date)
    return data.loc[in_window]

def get_exports(exports_list, date_column):
    """Download every listed forecast export and attach it to its entry.

    Each dict in ``exports_list`` is updated in place: the CSV found under
    ``exports/<path>`` is loaded (first row used as the header) and stored
    under the ``'data'`` key. Nothing is returned.
    """
    for entry in exports_list:
        export_key = f"exports/{entry.get('path')}"
        entry['data'] = download_data(export_key, date_column, header=0)


# The training CSVs are read without a header row, so their columns are
# addressed by position; column 1 is the one parsed as dates and filtered on.
demand = download_data(f"train/{demand_dataset_name}", 1)
relate = download_data(f"train/{related_dataset_name}", 1)

# Exports are read with a header row; each entry gains a 'data' key holding its frame
get_exports(exports_list=forecast_exports, date_column='date')

In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import re 

# Tick placement: a major tick per year (labelled with the year) and a minor tick per month
years = mdates.YearLocator()
months = mdates.MonthLocator()
years_fmt = mdates.DateFormatter("%Y")


# There is one subplot for every export, plus an extra one for the related time series.
# squeeze=False makes subplots() always return a 2-D array, so `axes` stays an
# indexable array of Axes even when there are no exports (a single subplot);
# ravel() flattens the (rows, 1) grid into the 1-D sequence used below.
fig, axes = plt.subplots(len(forecast_exports) + 1, sharex=True, squeeze=False)
axes = axes.ravel()
fig.set_size_inches(18.5, 10.5)

# Date ticks are configured on the first subplot (the x axis is shared via sharex=True)
axes[0].xaxis.set_major_locator(years)
axes[0].xaxis.set_major_formatter(years_fmt)
axes[0].xaxis.set_minor_locator(months)
# full-date format for the x value displayed by interactive backends
axes[0].format_xdata = mdates.DateFormatter("%Y-%m-%d")

# Every subplot gets a grid, including the final one reserved for the related time series.
for ax in axes:
    ax.grid(True)

# One subplot per forecast export; zip() stops before the extra, final subplot.
for ax, export in zip(axes, forecast_exports):
    title = export.get('name')
    data = export.get('data')

    ax.set_title(title)
    ax.set_xlabel("Date")
    ax.set_ylabel("Demand")

    # plot the demand
    ax.plot(demand[1], demand[2], linestyle='solid', color='DodgerBlue', label="Demand")

    # Plot a dashed line from the end of the demand data to the start of the forecast.
    # Skipped when either frame is empty after date filtering, because
    # .iloc[-1] / .iloc[0] would raise IndexError.
    if not demand.empty and not data.empty:
        xs = [demand[1].iloc[-1], data.date.iloc[0]]
        ys = [demand[2].iloc[-1], data.p50.iloc[0]]
        ax.plot(xs, ys, linestyle='dashed', color='DodgerBlue')

    # plot the median forecast
    ax.plot(data.date, data.p50, linestyle='dashed', color='DodgerBlue', label='P50')

    # Shade the area between p50 and every other quantile column (p10, p90, ...).
    # The pattern is a raw string so that \d is not treated as an invalid string escape.
    quantile_columns = [col for col in data.columns if re.match(r'^p\d+$', col) and col != 'p50']
    colors = ['LightSteelBlue', 'LightBlue']  # alternated from one band to the next
    for band, pnn in enumerate(quantile_columns):
        color = colors[band % len(colors)]
        ax.fill_between(data.date, data.p50, data[pnn], color=color, label=pnn)

    ax.legend(loc="lower left")

 
# Final subplot: the related time series, drawn against the same shared x axis
axes[-1].set(title="Item Price", xlabel="Date", ylabel="Price ($)")
axes[-1].plot(relate[1], relate[2])

# tidy up the date tick labels, then render the figure
fig.autofmt_xdate()
plt.show()