"""SageMaker Serving entrypoint"""
import json
from io import StringIO
import pandas as pd
from autogluon.tabular import TabularPredictor


def model_fn(model_dir):
    """Restore the trained predictor and remember the training column names.

    Besides returning the predictor, this publishes the column index of
    ``<model_dir>/headers.csv`` as the module-level name ``raw_columns`` so
    that `transform_fn` can label header-less CSV requests.

    Args:
        model_dir: Directory holding the saved predictor artifact and a
            ``headers.csv`` file.

    Returns:
        The object returned by ``TabularPredictor.load(model_dir)``.
    """
    global raw_columns  # pylint: disable=global-variable-undefined,invalid-name

    predictor = TabularPredictor.load(model_dir)

    # Only the column index of the file is kept; any rows are discarded.
    header_frame = pd.read_csv(f"{model_dir}/headers.csv")
    raw_columns = header_frame.columns

    return predictor


def transform_fn(
    model, request_body, input_content_type, output_content_type="application/json"
):
    """Parse a CSV request, run inference and return the result as JSON.

    Args:
        model: Predictor returned by `model_fn`; it must provide ``predict``
            and ``predict_proba``.
        request_body: Header-less CSV payload, either ``str`` or UTF-8
            encoded ``bytes``/``bytearray``.
        input_content_type: Media type of the request. Only ``text/csv`` is
            accepted; letter case and parameters (``; charset=utf-8``) are
            ignored when matching.
        output_content_type: Content type returned next to the payload. The
            payload itself is always JSON, whatever value is passed here.

    Returns:
        A ``(body, output_content_type)`` tuple. ``body`` is a JSON array with
        one entry per input row: the prediction followed by the values
        produced by ``predict_proba``.

    Raises:
        ValueError: If the content type is not ``text/csv`` or the number of
            columns in the request differs from the stored header.
    """
    # Media types are case-insensitive and may carry parameters, so compare
    # only the normalised "type/subtype" part.
    media_type = str(input_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "text/csv":
        raise ValueError(f"{input_content_type} content type not supported")

    # StringIO only accepts text; decode a raw byte payload first.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")

    data = pd.read_csv(StringIO(request_body), header=None)

    # ! raw_columns is defined in global scope (see `model_fn` function)
    expected_columns = raw_columns  # pylint: disable=undefined-variable
    if len(data.columns) != len(expected_columns):
        raise ValueError(
            f"request has {len(data.columns)} columns, "
            f"expected {len(expected_columns)}"
        )
    data.columns = expected_columns

    pred = model.predict(data)
    pred_proba = model.predict_proba(data)
    # Place the prediction in the first column, followed by the probabilities.
    prediction = pd.concat([pred, pred_proba], axis=1).values

    return json.dumps(prediction.tolist()), output_content_type