{
 "cells": [
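  {
   "cell_type": "markdown",
   "id": "a7eeb41e",
   "metadata": {},
   "source": [
    "# mRNA renal cell carcinoma prediction\n",
    "\n",
    "Classifies samples (TCGA KIRC tumours plus GTEx normal tissue, per the input file name) into normal / stage I-II / stage III-IV from mRNA expression. Genes are first filtered by univariate Cox proportional-hazards significance (p < 0.05), then XGBoost and AutoGluon models are trained on the selected genes.\n",
    "\n",
    "**Note:** the input must be the cleaned data produced by the preprocessing script."
   ]
  },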
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4b5fd799",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scipy version:  1.5.4\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import boto3\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import preprocessing\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from cox_functions import categorize_expression_levels, cox_ph_pipeline, normalize_gene_expression\n",
    "from copy import deepcopy\n",
    "\n",
    "from numpy.random import seed\n",
    "seed(1)\n",
    "\n",
    "import scipy\n",
    "print('scipy version: ', scipy.__version__)\n",
    "\n",
    "\n",
    "# !pip install lifelines\n",
    "\n",
    "from lifelines.utils import concordance_index\n",
    "from lifelines import CoxPHFitter\n",
    "\n",
    "import pandas as pd\n",
    "import json\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "# from XGBoostPipeline import XGBoostPipeline\n",
    "from XGBoostPipelineLatest import XGBoostPipeline\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c2d9ea16",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(filename):\n",
    "    \"\"\"Load the expression + clinical table and split it into train/val/test sets.\n",
    "\n",
    "    Steps:\n",
    "      1. Read `filename` and drop the patient-ID, demographic and DFS columns.\n",
    "      2. Encode AJCC tumour stage as 0 (normal), 1 (stage I/II) or 2 (stage III/IV).\n",
    "      3. Stratified split: 30% test, then 20% of the remainder as validation.\n",
    "      4. Give normal samples OS_STATUS \"0:LIVING\" and OS_MONTHS equal to the\n",
    "         training split's median OS_MONTHS (computed once, before imputation),\n",
    "         presumably so they can enter the Cox PH models as censored samples.\n",
    "      5. Encode OS_STATUS as integer codes using the training split's categories,\n",
    "         so every split shares the same mapping.\n",
    "      6. Drop 'hsa-mir-4296' and every column whose training-set sum is 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    filename : str\n",
    "        Path to the cleaned CSV produced by the preprocessing script.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dfph, dfphval, dfphtest : pandas.DataFrame\n",
    "        Train, validation and test splits.\n",
    "    genes : list of str\n",
    "        All columns after the first three (assumed to be the label, OS_STATUS\n",
    "        and OS_MONTHS -- TODO confirm against the preprocessed file's column order).\n",
    "    label : str\n",
    "        Name of the target column.\n",
    "    \"\"\"\n",
    "    df = pd.read_csv(filename)\n",
    "    if 'Unnamed: 0' in df.columns:\n",
    "        df = df.drop(columns=['Unnamed: 0'])\n",
    "    # DO NOT SET THE INDEX.\n",
    "    # You will run into an issue where every CoxPH model will return a\n",
    "    # concordance of 0.5 no matter what subset of features you use if you set the index.\n",
    "\n",
    "    label = \"AJCC_PATHOLOGIC_TUMOR_STAGE\"\n",
    "\n",
    "    # Drop columns that are not the target, OS survival information or expression.\n",
    "    cols_to_drop = [\n",
    "        'PATIENT_ID',\n",
    "        \"SEX_male_female\",\n",
    "        \"RACE\",\n",
    "        \"ETHNICITY\",\n",
    "        \"DFS_STATUS\",\n",
    "        \"DFS_MONTHS\",\n",
    "    ]\n",
    "    df = df.drop(columns=[ci for ci in cols_to_drop if ci in df.columns])\n",
    "\n",
    "    # From dataset joining, the NaN values for tumor stage have been verified to be normal patients.\n",
    "    # Assign the result: chained `df[label].fillna(..., inplace=True)` is not\n",
    "    # guaranteed to update `df` (and does not under pandas copy-on-write).\n",
    "    df[label] = df[label].fillna('normal')\n",
    "\n",
    "    # Map tumor stage categories to numeric classes (0 normal, 1 early, 2 late).\n",
    "    tumor_stage_map = {\n",
    "        \"normal\": 0,\n",
    "        \"STAGE I\": 1,\n",
    "        \"STAGE II\": 1,\n",
    "        \"STAGE III\": 2,\n",
    "        \"STAGE IV\": 2,\n",
    "    }\n",
    "    df[label] = df[label].map(tumor_stage_map)\n",
    "\n",
    "    train, test = train_test_split(df, test_size=0.3, random_state=42, stratify=df[label].values)\n",
    "    train, val = train_test_split(train, test_size=0.2, random_state=32, stratify=train[label].values)\n",
    "\n",
    "    dfph = train.copy()\n",
    "    dfphval = val.copy()\n",
    "    dfphtest = test.copy()\n",
    "\n",
    "    # Imputation value for normal samples: computed once from the training split,\n",
    "    # before any imputation, so train/val/test all receive the same value.\n",
    "    train_median_os_months = dfph['OS_MONTHS'].median()\n",
    "    for split in (dfph, dfphval, dfphtest):\n",
    "        is_normal = split[label] == 0\n",
    "        split.loc[is_normal, \"OS_STATUS\"] = \"0:LIVING\"\n",
    "        split.loc[is_normal, \"OS_MONTHS\"] = train_median_os_months\n",
    "\n",
    "    # Encode OS_STATUS with the training split's categories so that a split missing\n",
    "    # one of the statuses cannot end up with a different code mapping.\n",
    "    status_categories = dfph['OS_STATUS'].astype('category').cat.categories\n",
    "    for split in (dfph, dfphval, dfphtest):\n",
    "        split['OS_STATUS'] = pd.Categorical(split['OS_STATUS'], categories=status_categories).codes\n",
    "\n",
    "    if 'hsa-mir-4296' in dfph.columns:\n",
    "        dfph, dfphval, dfphtest = (s.drop(columns=['hsa-mir-4296']) for s in (dfph, dfphval, dfphtest))\n",
    "\n",
    "    # Drop features that are empty (sum to 0) in the training split, from every split.\n",
    "    empty_features = [col for col in dfph.columns if dfph[col].sum() == 0]\n",
    "    dfph, dfphval, dfphtest = (s.drop(columns=empty_features) for s in (dfph, dfphval, dfphtest))\n",
    "\n",
    "    genes = list(dfph.columns)[3:]\n",
    "\n",
    "    return dfph, dfphval, dfphtest, genes, label\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "42b089d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): unused. `run_cox` already applies `normalize_gene_expression`\n",
    "# to each split itself. As written this helper also returned nothing, so the\n",
    "# normalised frames would have been discarded. Kept for reference only.\n",
    "\n",
    "# def normalize_expression(dfph, dfphval, dfphtest, genes):\n",
    "#\n",
    "#     dfph = normalize_gene_expression(dfph, genes);\n",
    "#     dfphval = normalize_gene_expression(dfphval, genes);\n",
    "#     dfphtest = normalize_gene_expression(dfphtest, genes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8dee3dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_cox(dfph, dfphval, dfphtest, genes, pvalue=.05, dataset_name=\"final_microRNA_plus_normal_trainvaltest\"):\n",
    "    \"\"\"Normalise and categorise every split, then run per-gene Cox PH on train.\n",
    "\n",
    "    Each split goes through `normalize_gene_expression` and then\n",
    "    `categorize_expression_levels`; only the training split is passed to\n",
    "    `cox_ph_pipeline` (duration OS_MONTHS, event OS_STATUS).\n",
    "\n",
    "    Returns the transformed (train, val, test) frames followed by the two values\n",
    "    returned by `cox_ph_pipeline`: the per-gene info map and the significant genes.\n",
    "    \"\"\"\n",
    "    transformed = [normalize_gene_expression(frame, genes) for frame in (dfph, dfphval, dfphtest)]\n",
    "    transformed = [categorize_expression_levels(frame, genes) for frame in transformed]\n",
    "    train_cat, val_cat, test_cat = transformed\n",
    "\n",
    "    print('Running Cox PH on train set\\n')\n",
    "    info_map, significant_genes = cox_ph_pipeline(\n",
    "        train_cat,\n",
    "        genes,\n",
    "        dataset_name=dataset_name,\n",
    "        duration=\"OS_MONTHS\",\n",
    "        event=\"OS_STATUS\",\n",
    "        pvalue=pvalue,\n",
    "    )\n",
    "    print('number of significant genes: ', len(significant_genes))\n",
    "    return train_cat, val_cat, test_cat, info_map, significant_genes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b6531e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_format_xgbpipeline(significant_genes, label = \"AJCC_PATHOLOGIC_TUMOR_STAGE\",\n",
    "                            train_df=None, val_df=None, test_df=None):\n",
    "    \"\"\"Build label-first frames with float features for the XGBoost pipeline.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    significant_genes : iterable of str\n",
    "        Feature columns to keep, in this order (a list or a numpy array).\n",
    "    label : str\n",
    "        Target column; placed first in every returned frame.\n",
    "    train_df, val_df, test_df : pandas.DataFrame, optional\n",
    "        Splits to format. When omitted they fall back to the notebook-level\n",
    "        `dfph`, `dfphval` and `dfphtest` (the previous, implicit behaviour);\n",
    "        pass them explicitly to avoid depending on hidden global state.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of pandas.DataFrame\n",
    "        (train, val, test) copies with columns `[label] + significant_genes`,\n",
    "        every feature column cast to float.\n",
    "    \"\"\"\n",
    "    # Only touch the notebook globals when a split was not supplied.\n",
    "    train_df = dfph if train_df is None else train_df\n",
    "    val_df = dfphval if val_df is None else val_df\n",
    "    test_df = dfphtest if test_df is None else test_df\n",
    "\n",
    "    # list() so a numpy array of gene names also works: `list + ndarray`\n",
    "    # would otherwise be evaluated element-wise by numpy.\n",
    "    columns = [label] + list(significant_genes)\n",
    "    feature_columns = columns[1:]\n",
    "\n",
    "    def _format(split):\n",
    "        # Copy the selected columns and cast the features (everything but the label) to float.\n",
    "        formatted = split[columns].copy()\n",
    "        formatted[feature_columns] = formatted[feature_columns].astype('float')\n",
    "        return formatted\n",
    "\n",
    "    return _format(train_df), _format(val_df), _format(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "db190d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_xgbpipeline(input_df_test, pipeline):\n",
    "    \"\"\"Evaluate the model saved by `pipeline` on a label-first test frame.\n",
    "\n",
    "    Loads an XGBClassifier from `pipeline.model_filepath`, prints a\n",
    "    classification report and the one-vs-one / one-vs-rest ROC AUC, and\n",
    "    returns `(roc_auc_ovo, roc_auc_ovr)`.\n",
    "\n",
    "    `input_df_test` must hold the label in its first column and the model's\n",
    "    features, in training order, in the remaining columns.\n",
    "    \"\"\"\n",
    "    xgb_model = XGBClassifier()\n",
    "    xgb_model.load_model(pipeline.model_filepath)\n",
    "\n",
    "    X_test = input_df_test[input_df_test.columns[1:]]\n",
    "    y_test = input_df_test[input_df_test.columns[0]].values\n",
    "\n",
    "    hr_pred = xgb_model.predict(X_test)\n",
    "    # roc_auc_score expects multiclass scores shaped (n_samples, n_classes). The\n",
    "    # previous `proba[:, k//2:].T` dropped classes and transposed the matrix,\n",
    "    # which made roc_auc_score raise on inconsistent lengths.\n",
    "    hr_pred_proba = xgb_model.predict_proba(X_test)\n",
    "    if hr_pred_proba.shape[1] == 2:\n",
    "        # Binary case: sklearn wants the positive-class score as a 1-D array,\n",
    "        # and OVO / OVR reduce to the same AUC.\n",
    "        roc_auc_score_ovo = roc_auc_score_ovr = roc_auc_score(y_test, hr_pred_proba[:, 1])\n",
    "    else:\n",
    "        roc_auc_score_ovo = roc_auc_score(y_test, hr_pred_proba, multi_class='ovo')\n",
    "        roc_auc_score_ovr = roc_auc_score(y_test, hr_pred_proba, multi_class='ovr')\n",
    "\n",
    "    # classification_report's signature is (y_true, y_pred).\n",
    "    print(classification_report(y_test, hr_pred))\n",
    "    print('roc_auc_ovo: ', roc_auc_score_ovo, 'roc_auc_ovr: ', roc_auc_score_ovr)\n",
    "    return roc_auc_score_ovo, roc_auc_score_ovr\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b46a193f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_importance_matrix(model, dataset_name='final_microRNA_james', save=False):\n",
    "    \"\"\"Collect XGBoost feature importances for several importance types.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : object with a `get_booster()` method (e.g. xgboost.XGBClassifier)\n",
    "    dataset_name : str\n",
    "        Prefix of the output CSV file name.\n",
    "    save : bool, default False\n",
    "        If True, also write the matrix to\n",
    "        ../final_results/XGBoost/{dataset_name}_xgboost_feature_importance_latest.csv.\n",
    "        Defaults to False so existing calls keep their previous non-writing behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        Indexed by feature name, one column per importance type.\n",
    "    \"\"\"\n",
    "    print(\"Creating importance matrix\")\n",
    "\n",
    "    importance_types = ['gain', 'cover', 'weight', 'total_gain', 'total_cover']\n",
    "    booster = model.get_booster()\n",
    "    feature_importance_dict = {\n",
    "        metric: booster.get_score(importance_type=metric) for metric in importance_types\n",
    "    }\n",
    "    importance_matrix = pd.DataFrame(feature_importance_dict)\n",
    "\n",
    "    if save:\n",
    "        importance_matrix_filepath = \"../final_results/XGBoost/{}_xgboost_feature_importance_latest.csv\".format(dataset_name)\n",
    "        print(\"Saving importance matrix to:\", importance_matrix_filepath)\n",
    "        importance_matrix.to_csv(importance_matrix_filepath)\n",
    "\n",
    "    return importance_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c782391",
   "metadata": {},
   "source": [
    "# Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2dc94e5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3263: DtypeWarning: Columns (1,2,3,4,5,7) have mixed types.Specify dtype option on import or set low_memory=False.\n",
      "  if (await self.run_code(code, result,  async_=asy)):\n"
     ]
    }
   ],
   "source": [
    "# Cleaned expression + clinical table (normalised, per the file name) -> train / val / test.\n",
    "filename = \"../data/KIRC_TCGA_GTEX_mRNA_expression_and_clinical_normalized.csv\"\n",
    "\n",
    "dfph, dfphval, dfphtest, genes, label = load_data(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d4f64302",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5569"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-gene univariate Cox PH results (k: gene name, v: model info incl. 'p-value'),\n",
    "# presumably written by cox_ph_pipeline with dataset_name='final_mRNA_plus_normal_trainvaltest'.\n",
    "filepath = '../final_results/final_mRNA_plus_normal_trainvaltest_individual_cox_results.json'\n",
    "with open(filepath, \"r\") as f:\n",
    "    cph_map = json.load(f)\n",
    "\n",
    "all_genes_df = pd.DataFrame(cph_map).T\n",
    "all_genes_df['p-value'] = pd.to_numeric(all_genes_df['p-value'], errors='coerce')\n",
    "\n",
    "# Sort first, then filter the *sorted* frame. The previous version built the mask\n",
    "# from the sorted p-values but applied it to the unsorted frame, so it picked the\n",
    "# first N genes in file order rather than the significant ones.\n",
    "genes_by_pvalue = all_genes_df.sort_values('p-value')\n",
    "significant_genes = list(genes_by_pvalue.index[genes_by_pvalue['p-value'] < .05])\n",
    "\n",
    "len(significant_genes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7003c580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label-first train / validation / test frames restricted to the Cox-selected genes.\n",
    "input_df, input_df_val, input_df_test = data_format_xgbpipeline(significant_genes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "04eb0f84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing pipeline:\n",
      "Initializing test and train data:\n",
      "Running XGBoost pipeline\n",
      "Beginning Bayesian Optimization:\n",
      "\n",
      "|   iter    |  target   |    eta    |   gamma   | max_de... | max_depth |\n",
      "-------------------------------------------------------------------------\n",
      "| \u001b[0m 1       \u001b[0m | \u001b[0m 0.8453  \u001b[0m | \u001b[0m 0.3079  \u001b[0m | \u001b[0m 0.1869  \u001b[0m | \u001b[0m 8.756   \u001b[0m | \u001b[0m 7.66    \u001b[0m |\n",
      "| \u001b[0m 2       \u001b[0m | \u001b[0m 0.8408  \u001b[0m | \u001b[0m 0.5713  \u001b[0m | \u001b[0m 0.3983  \u001b[0m | \u001b[0m 10.11   \u001b[0m | \u001b[0m 3.074   \u001b[0m |\n",
      "| \u001b[95m 3       \u001b[0m | \u001b[95m 0.8454  \u001b[0m | \u001b[95m 0.1787  \u001b[0m | \u001b[95m 0.1234  \u001b[0m | \u001b[95m 17.62   \u001b[0m | \u001b[95m 9.121   \u001b[0m |\n",
      "| \u001b[95m 4       \u001b[0m | \u001b[95m 0.8458  \u001b[0m | \u001b[95m 0.277   \u001b[0m | \u001b[95m 0.1405  \u001b[0m | \u001b[95m 17.65   \u001b[0m | \u001b[95m 9.128   \u001b[0m |\n",
      "| \u001b[95m 5       \u001b[0m | \u001b[95m 0.8487  \u001b[0m | \u001b[95m 0.7922  \u001b[0m | \u001b[95m 0.2426  \u001b[0m | \u001b[95m 17.85   \u001b[0m | \u001b[95m 9.166   \u001b[0m |\n",
      "=========================================================================\n",
      "Best AUC: 0.848742724867725\n",
      "Best parameters: {'eta': 0.7922203053182407, 'gamma': 0.24264067575499107, 'max_delta_step': 17.853711554895643, 'max_depth': 9.166231294401388}\n",
      "{'eta': 0.7922203053182407, 'gamma': 0.24264067575499107, 'max_delta_step': 17.853711554895643, 'max_depth': 9, 'eval_metric': 'auc', 'objective': 'multi:softprob', 'num_class': 3, 'min_child_weight': 1, 'subsample': 1, 'colsample_bytree': 1}\n",
      "Training XGBoost model\n",
      "[0]\tTest-auc:0.76933\n",
      "[1]\tTest-auc:0.79571\n",
      "[2]\tTest-auc:0.80572\n",
      "[3]\tTest-auc:0.82041\n",
      "[4]\tTest-auc:0.82220\n",
      "[5]\tTest-auc:0.82459\n",
      "[6]\tTest-auc:0.82832\n",
      "[7]\tTest-auc:0.83424\n",
      "[8]\tTest-auc:0.83204\n",
      "[9]\tTest-auc:0.83136\n",
      "[10]\tTest-auc:0.83187\n",
      "[11]\tTest-auc:0.83187\n",
      "[12]\tTest-auc:0.83187\n",
      "[13]\tTest-auc:0.83187\n",
      "[14]\tTest-auc:0.83187\n",
      "[15]\tTest-auc:0.83187\n",
      "[16]\tTest-auc:0.83187\n",
      "[17]\tTest-auc:0.83187\n",
      "[18]\tTest-auc:0.83187\n",
      "[19]\tTest-auc:0.83187\n",
      "[20]\tTest-auc:0.83187\n",
      "[21]\tTest-auc:0.83187\n",
      "[22]\tTest-auc:0.83187\n",
      "[23]\tTest-auc:0.83187\n",
      "[24]\tTest-auc:0.83187\n",
      "[25]\tTest-auc:0.83187\n",
      "[26]\tTest-auc:0.83187\n",
      "[27]\tTest-auc:0.83187\n",
      "[28]\tTest-auc:0.83187\n",
      "[29]\tTest-auc:0.83187\n",
      "[30]\tTest-auc:0.83187\n",
      "[31]\tTest-auc:0.83187\n",
      "[32]\tTest-auc:0.83187\n",
      "[33]\tTest-auc:0.83187\n",
      "[34]\tTest-auc:0.83187\n",
      "[35]\tTest-auc:0.83187\n",
      "[36]\tTest-auc:0.83187\n",
      "[37]\tTest-auc:0.83187\n",
      "[38]\tTest-auc:0.83187\n",
      "[39]\tTest-auc:0.83187\n",
      "[40]\tTest-auc:0.83187\n",
      "[41]\tTest-auc:0.83187\n",
      "[42]\tTest-auc:0.83187\n",
      "[43]\tTest-auc:0.83187\n",
      "[44]\tTest-auc:0.83187\n",
      "[45]\tTest-auc:0.83187\n",
      "[46]\tTest-auc:0.83187\n",
      "[47]\tTest-auc:0.83187\n",
      "[48]\tTest-auc:0.83187\n",
      "[49]\tTest-auc:0.83187\n",
      "[50]\tTest-auc:0.83187\n",
      "[51]\tTest-auc:0.83187\n",
      "[52]\tTest-auc:0.83187\n",
      "[53]\tTest-auc:0.83187\n",
      "[54]\tTest-auc:0.83187\n",
      "[55]\tTest-auc:0.83187\n",
      "[56]\tTest-auc:0.83187\n",
      "[57]\tTest-auc:0.83187\n",
      "[58]\tTest-auc:0.83187\n",
      "[59]\tTest-auc:0.83187\n",
      "[60]\tTest-auc:0.83187\n",
      "[61]\tTest-auc:0.83187\n",
      "[62]\tTest-auc:0.83187\n",
      "[63]\tTest-auc:0.83187\n",
      "[64]\tTest-auc:0.83187\n",
      "[65]\tTest-auc:0.83187\n",
      "[66]\tTest-auc:0.83187\n",
      "[67]\tTest-auc:0.83187\n",
      "[68]\tTest-auc:0.83187\n",
      "[69]\tTest-auc:0.83187\n",
      "[70]\tTest-auc:0.83187\n",
      "[71]\tTest-auc:0.83187\n",
      "[72]\tTest-auc:0.83187\n",
      "[73]\tTest-auc:0.83187\n",
      "[74]\tTest-auc:0.83187\n",
      "[75]\tTest-auc:0.83187\n",
      "[76]\tTest-auc:0.83187\n",
      "[77]\tTest-auc:0.83187\n",
      "[78]\tTest-auc:0.83187\n",
      "[79]\tTest-auc:0.83187\n",
      "[80]\tTest-auc:0.83187\n",
      "[81]\tTest-auc:0.83187\n",
      "[82]\tTest-auc:0.83187\n",
      "[83]\tTest-auc:0.83187\n",
      "[84]\tTest-auc:0.83187\n",
      "[85]\tTest-auc:0.83187\n",
      "[86]\tTest-auc:0.83187\n",
      "[87]\tTest-auc:0.83187\n",
      "[88]\tTest-auc:0.83187\n",
      "[89]\tTest-auc:0.83187\n",
      "[90]\tTest-auc:0.83187\n",
      "[91]\tTest-auc:0.83187\n",
      "[92]\tTest-auc:0.83187\n",
      "[93]\tTest-auc:0.83187\n",
      "[94]\tTest-auc:0.83187\n",
      "[95]\tTest-auc:0.83187\n",
      "[96]\tTest-auc:0.83187\n",
      "[97]\tTest-auc:0.83187\n",
      "[98]\tTest-auc:0.83187\n",
      "[99]\tTest-auc:0.83187\n",
      "[100]\tTest-auc:0.83187\n",
      "[101]\tTest-auc:0.83187\n",
      "[102]\tTest-auc:0.83187\n",
      "[103]\tTest-auc:0.83187\n",
      "[104]\tTest-auc:0.83187\n",
      "[105]\tTest-auc:0.83187\n",
      "[106]\tTest-auc:0.83187\n",
      "[107]\tTest-auc:0.83187\n",
      "Saving model to:  ../final_results/XGBoost/final_mRNA_multiclass_post_cox_bayes_opt_xgboost_best_james_ttv2.json\n",
      "Saving XGBoost JSON results to: ../final_results/XGBoost/final_mRNA_multiclass_post_cox_bayes_opt_xgboost_best_output_james_ttv2.json\n",
      "Creating importance matrix\n",
      "Saving importance matrix to: ../final_results/XGBoost/final_mRNA_xgboost_feature_importance.csv\n",
      "Results summary:\n",
      "Parameter bounds: {'max_depth': (3, 10), 'eta': (0.01, 1), 'gamma': (0.0, 1), 'max_delta_step': (1, 25)}\n",
      "Number of boosting rounds: 1000\n",
      "Early stopping rounds: 100\n",
      "Accuracy = 0.691358024691358\n",
      "ROC AUC OVO = 0.8749703228869895\n",
      "ROC AUC OVR = 0.8497237276477043\n",
      "Model filepath = ../final_results/XGBoost/final_mRNA_multiclass_post_cox_bayes_opt_xgboost_best_james_ttv2.json\n",
      "Importance matrix filepath = ../final_results/XGBoost/final_mRNA_xgboost_feature_importance.csv\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<xgboost.core.Booster at 0x7f539a99afd0>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bayesian-optimised XGBoost on train + validation. Judging by its log, the\n",
    "# pipeline appears to split this data internally for its \"Test-auc\";\n",
    "# input_df_test stays held out.\n",
    "pipeline = XGBoostPipeline(\n",
    "    pd.concat([input_df, input_df_val]),\n",
    "    random_state=60,\n",
    "    label_column=label,\n",
    "    num_classes=3,\n",
    "    weighted=False,\n",
    "    n_iter=50,\n",
    "    model_name=\"final_mRNA_multiclass_post_cox_bayes_opt_xgboost_best_james_ttv2.json\",\n",
    "    json_filepath=\"final_mRNA_multiclass_post_cox_bayes_opt_xgboost_best_output_james_ttv2.json\",\n",
    "    dataset_name=\"final_mRNA\",\n",
    ")\n",
    "\n",
    "pipeline.run_workflow()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa9cb0a",
   "metadata": {},
   "source": [
    "# Final fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d8cc13a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit the final classifier on train + validation with the tuned hyper-parameters.\n",
    "input_df_final = pd.concat([input_df, input_df_val])\n",
    "\n",
    "# `best_params` is a dict and must be unpacked as keyword arguments: passing it\n",
    "# positionally either raises a TypeError or binds the whole dict to the first\n",
    "# positional parameter, depending on the xgboost version.\n",
    "final_params = dict(pipeline.best_params)\n",
    "final_params['objective'] = 'multi:softprob'\n",
    "# Bayesian optimisation proposes continuous values (e.g. max_depth 9.17);\n",
    "# XGBoost needs an integer depth.\n",
    "if 'max_depth' in final_params:\n",
    "    final_params['max_depth'] = int(final_params['max_depth'])\n",
    "\n",
    "xgb_model_final = XGBClassifier(**final_params)\n",
    "\n",
    "xgb_model_final.fit(X=input_df_final[input_df_final.columns[1:]], y=input_df_final[input_df_final.columns[0]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f3488155",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from xgboost import plot_importance\n",
    "\n",
    "# Gain-based importance from the final model, rounded to 2 decimals so the\n",
    "# value labels on the bar chart stay readable.\n",
    "booster = xgb_model_final.get_booster()\n",
    "importance = {\n",
    "    feature: round(score, 2)\n",
    "    for feature, score in booster.get_score(importance_type=\"gain\").items()\n",
    "}\n",
    "\n",
    "# Top 20 features; plot_importance accepts a plain dict of scores.\n",
    "ax = plot_importance(importance, max_num_features=20, importance_type='gain', show_values=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "820971b9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
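  {
   "cell_type": "markdown",
   "id": "81d78d55",
   "metadata": {},
   "source": [
    "# AutoGluon\n",
    "\n",
    "Comparison: AutoGluon `TabularPredictor` trained on the same train + validation data (Cox-selected genes)."
   ]
  },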
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "aa8cff7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from autogluon.tabular import TabularPredictor as task\n",
    "from sklearn.model_selection import train_test_split\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cc1b93fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same train + validation frame as the XGBoost final fit, now for AutoGluon.\n",
    "input_df_final = pd.concat([input_df, input_df_val], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6c20a62e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Beginning AutoGluon training ... Time limit = 180s\n",
      "AutoGluon will save models to \"../final_results/AutoGluon/models/mRNA/\"\n",
      "AutoGluon Version:  0.3.1\n",
      "Train Data Rows:    405\n",
      "Train Data Columns: 5569\n",
      "Preprocessing data ...\n",
      "AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).\n",
      "\t3 unique label values:  [0, 1, 2]\n",
      "\tIf 'multiclass' is not the correct problem_type, please manually specify the problem_type argument in fit() (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n",
      "Train Data Class Count: 3\n",
      "Using Feature Generators to preprocess the data ...\n",
      "Fitting AutoMLPipelineFeatureGenerator...\n",
      "\tAvailable Memory:                    126346.32 MB\n",
      "\tTrain Data (Original)  Memory Usage: 18.04 MB (0.0% of available memory)\n",
      "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
      "\tStage 1 Generators:\n",
      "\t\tFitting AsTypeFeatureGenerator...\n",
      "\t\t\tNote: Converting 11 features to boolean dtype as they only contain 2 unique values.\n",
      "\tStage 2 Generators:\n",
      "\t\tFitting FillNaFeatureGenerator...\n",
      "\tStage 3 Generators:\n",
      "\t\tFitting IdentityFeatureGenerator...\n",
      "\tStage 4 Generators:\n",
      "\t\tFitting DropUniqueFeatureGenerator...\n",
      "\tTypes of features in original data (raw dtype, special dtypes):\n",
      "\t\t('float', []) : 5569 | ['BLOC1S4', 'TBC1D17', 'CARD8', 'KIF12', 'SPECC1L', ...]\n",
      "\tTypes of features in processed data (raw dtype, special dtypes):\n",
      "\t\t('float', [])     : 5558 | ['BLOC1S4', 'TBC1D17', 'CARD8', 'KIF12', 'SPECC1L', ...]\n",
      "\t\t('int', ['bool']) :   11 | ['DUX4L3', 'CT47A7', 'GAGE13', 'DUX4L6', 'CT47A8', ...]\n",
      "\t8.7s = Fit runtime\n",
      "\t5569 features in original data used to generate 5569 features in processed data.\n",
      "\tTrain Data (Processed) Memory Usage: 18.01 MB (0.0% of available memory)\n",
      "Data preprocessing and feature engineering runtime = 9.39s ...\n",
      "AutoGluon will gauge predictive performance using evaluation metric: 'roc_auc_ovo_macro'\n",
      "\tThis metric expects predicted probabilities rather than predicted class labels, so you'll need to use predict_proba() instead of predict()\n",
      "\tTo change this, specify the eval_metric argument of fit()\n",
      "Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 324, Val Rows: 81\n",
      "Fitting 13 L1 models ...\n",
      "Fitting model: KNeighborsUnif ... Training model for up to 170.61s of the 170.43s of remaining time.\n",
      "\t0.7639\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t1.44s\t = Training   runtime\n",
      "\t0.12s\t = Validation runtime\n",
      "Fitting model: KNeighborsDist ... Training model for up to 168.85s of the 168.66s of remaining time.\n",
      "\t0.773\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t1.43s\t = Training   runtime\n",
      "\t0.12s\t = Validation runtime\n",
      "Fitting model: NeuralNetFastAI ... Training model for up to 167.09s of the 166.9s of remaining time.\n",
      "\tWarning: Exception caused NeuralNetFastAI to fail during training... Skipping this model.\n",
      "\t\tfuture feature annotations is not defined (dispatch.py, line 4)\n",
      "Detailed Traceback:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/tabular/trainer/abstract_trainer.py\", line 962, in _train_and_save\n",
      "    model = self._train_single(X, y, model, X_val, y_val, **model_fit_kwargs)\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/tabular/trainer/abstract_trainer.py\", line 934, in _train_single\n",
      "    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, **model_fit_kwargs)\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 522, in fit\n",
      "    out = self._fit(**kwargs)\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/tabular/models/fastainn/tabular_nn_fastai.py\", line 163, in _fit\n",
      "    try_import_fastai()\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/core/utils/try_import.py\", line 107, in try_import_fastai\n",
      "    import autogluon.tabular.models.fastainn.imports_helper\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/autogluon/tabular/models/fastainn/imports_helper.py\", line 1, in <module>\n",
      "    from fastai.tabular.all import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastai/tabular/all.py\", line 1, in <module>\n",
      "    from ..basics import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastai/basics.py\", line 1, in <module>\n",
      "    from .data.all import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastai/data/all.py\", line 1, in <module>\n",
      "    from ..torch_basics import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastai/torch_basics.py\", line 9, in <module>\n",
      "    from .imports import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastai/imports.py\", line 30, in <module>\n",
      "    from fastcore.all import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastcore/all.py\", line 3, in <module>\n",
      "    from .dispatch import *\n",
      "  File \"/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/fastcore/dispatch.py\", line 4\n",
      "    from __future__ import annotations\n",
      "                                     ^\n",
      "SyntaxError: future feature annotations is not defined\n",
      "Fitting model: LightGBMXT ... Training model for up to 164.99s of the 164.81s of remaining time.\n",
      "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "  _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
      "\t0.8521\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t9.55s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: LightGBM ... Training model for up to 155.22s of the 155.04s of remaining time.\n",
      "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "  _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
      "\t0.8472\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t10.64s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: RandomForestGini ... Training model for up to 144.37s of the 144.19s of remaining time.\n",
      "\t0.8311\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t3.46s\t = Training   runtime\n",
      "\t0.11s\t = Validation runtime\n",
      "Fitting model: RandomForestEntr ... Training model for up to 140.61s of the 140.43s of remaining time.\n",
      "\t0.844\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t3.79s\t = Training   runtime\n",
      "\t0.11s\t = Validation runtime\n",
      "Fitting model: CatBoost ... Training model for up to 136.52s of the 136.34s of remaining time.\n",
      "Metric roc_auc_ovo_macro is not supported by this model - using AUC:type=Mu instead\n",
      "\tMany features detected (5569), dynamically setting 'colsample_bylevel' to 0.1795654516071108 to speed up training (Default = 1).\n",
      "\tTo disable this functionality, explicitly specify 'colsample_bylevel' in the model hyperparameters.\n",
      "\t0.8562\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t55.52s\t = Training   runtime\n",
      "\t0.08s\t = Validation runtime\n",
      "Fitting model: ExtraTreesGini ... Training model for up to 80.72s of the 80.54s of remaining time.\n",
      "\t0.818\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t3.45s\t = Training   runtime\n",
      "\t0.11s\t = Validation runtime\n",
      "Fitting model: ExtraTreesEntr ... Training model for up to 76.97s of the 76.79s of remaining time.\n",
      "\t0.8356\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t3.54s\t = Training   runtime\n",
      "\t0.11s\t = Validation runtime\n",
      "Fitting model: XGBoost ... Training model for up to 73.13s of the 72.95s of remaining time.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[20:12:48] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'multi:softprob' was changed from 'merror' to 'mlogloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\t0.8291\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t46.16s\t = Training   runtime\n",
      "\t0.04s\t = Validation runtime\n",
      "Fitting model: NeuralNetMXNet ... Training model for up to 26.72s of the 26.54s of remaining time.\n",
      "\t0.8911\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t13.25s\t = Training   runtime\n",
      "\t2.69s\t = Validation runtime\n",
      "Fitting model: LightGBMLarge ... Training model for up to 10.56s of the 10.38s of remaining time.\n",
      "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "  _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
      "\tRan out of time, early stopping on iteration 45. Best iteration is:\n",
      "\t[21]\ttrain_set's multi_logloss: 0.390576\ttrain_set's roc_auc_ovo_macro: 1\tvalid_set's multi_logloss: 0.809444\tvalid_set's roc_auc_ovo_macro: 0.805259\n",
      "\t0.8053\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t12.15s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: WeightedEnsemble_L2 ... Training model for up to 170.61s of the -6.05s of remaining time.\n",
      "\t0.8962\t = Validation score   (roc_auc_ovo_macro)\n",
      "\t3.39s\t = Training   runtime\n",
      "\t0.0s\t = Validation runtime\n",
      "AutoGluon training complete, total runtime = 189.52s ...\n",
      "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"../final_results/AutoGluon/models/mRNA/\")\n"
     ]
    }
   ],
   "source": [
    "time_limit = 180*1\n",
    "metric = 'roc_auc_ovo_macro' \n",
    "save_path = '../final_results/AutoGluon/models/mRNA'  # specifies folder to store trained models\n",
    "predictor = task(\n",
    "    label=label, \n",
    "    path=save_path, \n",
    "    eval_metric=metric).fit(input_df_final, time_limit=time_limit)#,                           \n",
    "#         hyperparameters=hyperparameters, hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,)#, presets='best_quality')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "09398471",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: roc_auc_ovo_macro on test data: 0.8759131103421761\n",
      "Evaluations on test data:\n",
      "{\n",
      "    \"roc_auc_ovo_macro\": 0.8759131103421761,\n",
      "    \"accuracy\": 0.735632183908046,\n",
      "    \"balanced_accuracy\": 0.7572755417956657,\n",
      "    \"mcc\": 0.6009272578354864\n",
      "}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'roc_auc_ovo_macro': 0.8759131103421761,\n",
       " 'accuracy': 0.735632183908046,\n",
       " 'balanced_accuracy': 0.7572755417956657,\n",
       " 'mcc': 0.6009272578354864}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.evaluate(input_df_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "511d2656",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>score_val</th>\n",
       "      <th>pred_time_val</th>\n",
       "      <th>fit_time</th>\n",
       "      <th>pred_time_val_marginal</th>\n",
       "      <th>fit_time_marginal</th>\n",
       "      <th>stack_level</th>\n",
       "      <th>can_infer</th>\n",
       "      <th>fit_order</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>WeightedEnsemble_L2</td>\n",
       "      <td>0.896249</td>\n",
       "      <td>2.724961</td>\n",
       "      <td>26.194055</td>\n",
       "      <td>0.003094</td>\n",
       "      <td>3.394044</td>\n",
       "      <td>2</td>\n",
       "      <td>True</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NeuralNetMXNet</td>\n",
       "      <td>0.891115</td>\n",
       "      <td>2.692266</td>\n",
       "      <td>13.250540</td>\n",
       "      <td>2.692266</td>\n",
       "      <td>13.250540</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CatBoost</td>\n",
       "      <td>0.856244</td>\n",
       "      <td>0.079102</td>\n",
       "      <td>55.515267</td>\n",
       "      <td>0.079102</td>\n",
       "      <td>55.515267</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>LightGBMXT</td>\n",
       "      <td>0.852060</td>\n",
       "      <td>0.029601</td>\n",
       "      <td>9.549471</td>\n",
       "      <td>0.029601</td>\n",
       "      <td>9.549471</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>LightGBM</td>\n",
       "      <td>0.847163</td>\n",
       "      <td>0.028573</td>\n",
       "      <td>10.635069</td>\n",
       "      <td>0.028573</td>\n",
       "      <td>10.635069</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>RandomForestEntr</td>\n",
       "      <td>0.844047</td>\n",
       "      <td>0.108880</td>\n",
       "      <td>3.789447</td>\n",
       "      <td>0.108880</td>\n",
       "      <td>3.789447</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>ExtraTreesEntr</td>\n",
       "      <td>0.835574</td>\n",
       "      <td>0.108996</td>\n",
       "      <td>3.537303</td>\n",
       "      <td>0.108996</td>\n",
       "      <td>3.537303</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>RandomForestGini</td>\n",
       "      <td>0.831093</td>\n",
       "      <td>0.109040</td>\n",
       "      <td>3.459631</td>\n",
       "      <td>0.109040</td>\n",
       "      <td>3.459631</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>XGBoost</td>\n",
       "      <td>0.829060</td>\n",
       "      <td>0.039434</td>\n",
       "      <td>46.158945</td>\n",
       "      <td>0.039434</td>\n",
       "      <td>46.158945</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ExtraTreesGini</td>\n",
       "      <td>0.818020</td>\n",
       "      <td>0.108950</td>\n",
       "      <td>3.449311</td>\n",
       "      <td>0.108950</td>\n",
       "      <td>3.449311</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>LightGBMLarge</td>\n",
       "      <td>0.805259</td>\n",
       "      <td>0.028821</td>\n",
       "      <td>12.145864</td>\n",
       "      <td>0.028821</td>\n",
       "      <td>12.145864</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>KNeighborsDist</td>\n",
       "      <td>0.772985</td>\n",
       "      <td>0.116131</td>\n",
       "      <td>1.433261</td>\n",
       "      <td>0.116131</td>\n",
       "      <td>1.433261</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>KNeighborsUnif</td>\n",
       "      <td>0.763874</td>\n",
       "      <td>0.118167</td>\n",
       "      <td>1.438675</td>\n",
       "      <td>0.118167</td>\n",
       "      <td>1.438675</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  model  score_val  pred_time_val   fit_time  \\\n",
       "0   WeightedEnsemble_L2   0.896249       2.724961  26.194055   \n",
       "1        NeuralNetMXNet   0.891115       2.692266  13.250540   \n",
       "2              CatBoost   0.856244       0.079102  55.515267   \n",
       "3            LightGBMXT   0.852060       0.029601   9.549471   \n",
       "4              LightGBM   0.847163       0.028573  10.635069   \n",
       "5      RandomForestEntr   0.844047       0.108880   3.789447   \n",
       "6        ExtraTreesEntr   0.835574       0.108996   3.537303   \n",
       "7      RandomForestGini   0.831093       0.109040   3.459631   \n",
       "8               XGBoost   0.829060       0.039434  46.158945   \n",
       "9        ExtraTreesGini   0.818020       0.108950   3.449311   \n",
       "10        LightGBMLarge   0.805259       0.028821  12.145864   \n",
       "11       KNeighborsDist   0.772985       0.116131   1.433261   \n",
       "12       KNeighborsUnif   0.763874       0.118167   1.438675   \n",
       "\n",
       "    pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  \\\n",
       "0                 0.003094           3.394044            2       True   \n",
       "1                 2.692266          13.250540            1       True   \n",
       "2                 0.079102          55.515267            1       True   \n",
       "3                 0.029601           9.549471            1       True   \n",
       "4                 0.028573          10.635069            1       True   \n",
       "5                 0.108880           3.789447            1       True   \n",
       "6                 0.108996           3.537303            1       True   \n",
       "7                 0.109040           3.459631            1       True   \n",
       "8                 0.039434          46.158945            1       True   \n",
       "9                 0.108950           3.449311            1       True   \n",
       "10                0.028821          12.145864            1       True   \n",
       "11                0.116131           1.433261            1       True   \n",
       "12                0.118167           1.438675            1       True   \n",
       "\n",
       "    fit_order  \n",
       "0          13  \n",
       "1          11  \n",
       "2           7  \n",
       "3           3  \n",
       "4           4  \n",
       "5           6  \n",
       "6           9  \n",
       "7           5  \n",
       "8          10  \n",
       "9           8  \n",
       "10         12  \n",
       "11          2  \n",
       "12          1  "
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.leaderboard(silent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a883e033",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing feature importance via permutation shuffling for 5569 features using 174 rows with 3 shuffle sets...\n",
      "\t50177.81s\t= Expected runtime (16725.94s per shuffle set)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "## Feature importance by autogluon\n",
    "\n",
    "ag_feature_importance_test = predictor.feature_importance(input_df_test)\n",
    "ag_feature_importance_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "f25001d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>importance</th>\n",
       "      <th>stddev</th>\n",
       "      <th>p_value</th>\n",
       "      <th>n</th>\n",
       "      <th>p99_high</th>\n",
       "      <th>p99_low</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>GNG7</th>\n",
       "      <td>0.002113</td>\n",
       "      <td>0.000841</td>\n",
       "      <td>0.024476</td>\n",
       "      <td>3</td>\n",
       "      <td>0.006931</td>\n",
       "      <td>-0.002705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BCR</th>\n",
       "      <td>0.002016</td>\n",
       "      <td>0.001607</td>\n",
       "      <td>0.080939</td>\n",
       "      <td>3</td>\n",
       "      <td>0.011225</td>\n",
       "      <td>-0.007192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>B3GALT1</th>\n",
       "      <td>0.001641</td>\n",
       "      <td>0.000733</td>\n",
       "      <td>0.030286</td>\n",
       "      <td>3</td>\n",
       "      <td>0.005842</td>\n",
       "      <td>-0.002560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SPTSSB</th>\n",
       "      <td>0.001590</td>\n",
       "      <td>0.000669</td>\n",
       "      <td>0.027155</td>\n",
       "      <td>3</td>\n",
       "      <td>0.005426</td>\n",
       "      <td>-0.002246</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GLS2</th>\n",
       "      <td>0.001577</td>\n",
       "      <td>0.000893</td>\n",
       "      <td>0.046195</td>\n",
       "      <td>3</td>\n",
       "      <td>0.006695</td>\n",
       "      <td>-0.003541</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         importance    stddev   p_value  n  p99_high   p99_low\n",
       "GNG7       0.002113  0.000841  0.024476  3  0.006931 -0.002705\n",
       "BCR        0.002016  0.001607  0.080939  3  0.011225 -0.007192\n",
       "B3GALT1    0.001641  0.000733  0.030286  3  0.005842 -0.002560\n",
       "SPTSSB     0.001590  0.000669  0.027155  3  0.005426 -0.002246\n",
       "GLS2       0.001577  0.000893  0.046195  3  0.006695 -0.003541"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ag_feature_importance_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f260929b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing feature importance via permutation shuffling for 5569 features using 405 rows with 3 shuffle sets...\n",
      "\t55531.05s\t= Expected runtime (18510.35s per shuffle set)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "## Feature importance by autogluon\n",
    "\n",
    "ag_feature_importance = predictor.feature_importance(input_df_final)\n",
    "ag_feature_importance.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aedfaccd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>importance</th>\n",
       "      <th>stddev</th>\n",
       "      <th>p_value</th>\n",
       "      <th>n</th>\n",
       "      <th>p99_high</th>\n",
       "      <th>p99_low</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>TMCO2</th>\n",
       "      <td>0.001268</td>\n",
       "      <td>0.000169</td>\n",
       "      <td>0.002920</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002234</td>\n",
       "      <td>0.000302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NKX2-3</th>\n",
       "      <td>0.000731</td>\n",
       "      <td>0.000261</td>\n",
       "      <td>0.020054</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002229</td>\n",
       "      <td>-0.000767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SNAP91</th>\n",
       "      <td>0.000665</td>\n",
       "      <td>0.000097</td>\n",
       "      <td>0.003536</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001224</td>\n",
       "      <td>0.000107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CYP1A2</th>\n",
       "      <td>0.000642</td>\n",
       "      <td>0.000243</td>\n",
       "      <td>0.022298</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002036</td>\n",
       "      <td>-0.000751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CTAG1A</th>\n",
       "      <td>0.000542</td>\n",
       "      <td>0.000219</td>\n",
       "      <td>0.025103</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001795</td>\n",
       "      <td>-0.000711</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        importance    stddev   p_value  n  p99_high   p99_low\n",
       "TMCO2     0.001268  0.000169  0.002920  3  0.002234  0.000302\n",
       "NKX2-3    0.000731  0.000261  0.020054  3  0.002229 -0.000767\n",
       "SNAP91    0.000665  0.000097  0.003536  3  0.001224  0.000107\n",
       "CYP1A2    0.000642  0.000243  0.022298  3  0.002036 -0.000751\n",
       "CTAG1A    0.000542  0.000219  0.025103  3  0.001795 -0.000711"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ag_feature_importance.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fed8babd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ag_feature_importance.to_csv('/home/ec2-user/SageMaker/daniel/mRNA_train_james.csv')\n",
    "\n",
    "# ag_feature_importance.to_csv('/home/ec2-user/SageMaker/daniel/final_results/AutoGluon/models/mRNA/mRNA_james.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "574b6f8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>importance</th>\n",
       "      <th>stddev</th>\n",
       "      <th>p_value</th>\n",
       "      <th>n</th>\n",
       "      <th>p99_high</th>\n",
       "      <th>p99_low</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>TMCO2</th>\n",
       "      <td>0.001268</td>\n",
       "      <td>0.000169</td>\n",
       "      <td>0.002920</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002234</td>\n",
       "      <td>0.000302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NKX2-3</th>\n",
       "      <td>0.000731</td>\n",
       "      <td>0.000261</td>\n",
       "      <td>0.020054</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002229</td>\n",
       "      <td>-0.000767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SNAP91</th>\n",
       "      <td>0.000665</td>\n",
       "      <td>0.000097</td>\n",
       "      <td>0.003536</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001224</td>\n",
       "      <td>0.000107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CYP1A2</th>\n",
       "      <td>0.000642</td>\n",
       "      <td>0.000243</td>\n",
       "      <td>0.022298</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002036</td>\n",
       "      <td>-0.000751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CTAG1A</th>\n",
       "      <td>0.000542</td>\n",
       "      <td>0.000219</td>\n",
       "      <td>0.025103</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001795</td>\n",
       "      <td>-0.000711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LHFPL5</th>\n",
       "      <td>0.000512</td>\n",
       "      <td>0.000438</td>\n",
       "      <td>0.090087</td>\n",
       "      <td>3</td>\n",
       "      <td>0.003020</td>\n",
       "      <td>-0.001996</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RNASE13</th>\n",
       "      <td>0.000505</td>\n",
       "      <td>0.000180</td>\n",
       "      <td>0.019996</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001539</td>\n",
       "      <td>-0.000529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WIF1</th>\n",
       "      <td>0.000505</td>\n",
       "      <td>0.000452</td>\n",
       "      <td>0.096271</td>\n",
       "      <td>3</td>\n",
       "      <td>0.003095</td>\n",
       "      <td>-0.002085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SPACA5B</th>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.000113</td>\n",
       "      <td>0.008351</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001149</td>\n",
       "      <td>-0.000149</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RXFP4</th>\n",
       "      <td>0.000494</td>\n",
       "      <td>0.000058</td>\n",
       "      <td>0.002314</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000829</td>\n",
       "      <td>0.000159</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TTLL2</th>\n",
       "      <td>0.000477</td>\n",
       "      <td>0.000257</td>\n",
       "      <td>0.042331</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001949</td>\n",
       "      <td>-0.000995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMELY</th>\n",
       "      <td>0.000461</td>\n",
       "      <td>0.000103</td>\n",
       "      <td>0.008057</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001050</td>\n",
       "      <td>-0.000127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RAD21L1</th>\n",
       "      <td>0.000459</td>\n",
       "      <td>0.000085</td>\n",
       "      <td>0.005630</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000947</td>\n",
       "      <td>-0.000029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AC007919.2</th>\n",
       "      <td>0.000443</td>\n",
       "      <td>0.000144</td>\n",
       "      <td>0.016700</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001269</td>\n",
       "      <td>-0.000382</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MC4R</th>\n",
       "      <td>0.000441</td>\n",
       "      <td>0.000070</td>\n",
       "      <td>0.004101</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000840</td>\n",
       "      <td>0.000042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GOLGA8G</th>\n",
       "      <td>0.000432</td>\n",
       "      <td>0.000048</td>\n",
       "      <td>0.002017</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000705</td>\n",
       "      <td>0.000159</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SMIM18</th>\n",
       "      <td>0.000432</td>\n",
       "      <td>0.000264</td>\n",
       "      <td>0.052641</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001945</td>\n",
       "      <td>-0.001081</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LSMEM2</th>\n",
       "      <td>0.000429</td>\n",
       "      <td>0.000114</td>\n",
       "      <td>0.011357</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001082</td>\n",
       "      <td>-0.000224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RBP2</th>\n",
       "      <td>0.000421</td>\n",
       "      <td>0.000240</td>\n",
       "      <td>0.046576</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001794</td>\n",
       "      <td>-0.000952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MYOD1</th>\n",
       "      <td>0.000419</td>\n",
       "      <td>0.000308</td>\n",
       "      <td>0.071522</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002185</td>\n",
       "      <td>-0.001348</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HIST1H2AB</th>\n",
       "      <td>0.000416</td>\n",
       "      <td>0.000121</td>\n",
       "      <td>0.013549</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001111</td>\n",
       "      <td>-0.000278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CCDC166</th>\n",
       "      <td>0.000415</td>\n",
       "      <td>0.000075</td>\n",
       "      <td>0.005370</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000845</td>\n",
       "      <td>-0.000015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>XAGE1B</th>\n",
       "      <td>0.000411</td>\n",
       "      <td>0.000128</td>\n",
       "      <td>0.015470</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001145</td>\n",
       "      <td>-0.000324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>EFCAB8</th>\n",
       "      <td>0.000410</td>\n",
       "      <td>0.000192</td>\n",
       "      <td>0.032787</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001508</td>\n",
       "      <td>-0.000687</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NLRP4</th>\n",
       "      <td>0.000403</td>\n",
       "      <td>0.000142</td>\n",
       "      <td>0.019443</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001216</td>\n",
       "      <td>-0.000410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AC104057.1</th>\n",
       "      <td>0.000403</td>\n",
       "      <td>0.000165</td>\n",
       "      <td>0.025844</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001348</td>\n",
       "      <td>-0.000543</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FABP2</th>\n",
       "      <td>0.000397</td>\n",
       "      <td>0.000135</td>\n",
       "      <td>0.018113</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001169</td>\n",
       "      <td>-0.000374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PLK5</th>\n",
       "      <td>0.000391</td>\n",
       "      <td>0.000203</td>\n",
       "      <td>0.039790</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001557</td>\n",
       "      <td>-0.000775</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LCT</th>\n",
       "      <td>0.000384</td>\n",
       "      <td>0.000248</td>\n",
       "      <td>0.057471</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001803</td>\n",
       "      <td>-0.001035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>INSM2</th>\n",
       "      <td>0.000377</td>\n",
       "      <td>0.000255</td>\n",
       "      <td>0.062174</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001839</td>\n",
       "      <td>-0.001084</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RP11-293I14.2</th>\n",
       "      <td>0.000375</td>\n",
       "      <td>0.000257</td>\n",
       "      <td>0.063754</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001851</td>\n",
       "      <td>-0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>S100G</th>\n",
       "      <td>0.000375</td>\n",
       "      <td>0.000055</td>\n",
       "      <td>0.003504</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000689</td>\n",
       "      <td>0.000062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL117190.3</th>\n",
       "      <td>0.000374</td>\n",
       "      <td>0.000043</td>\n",
       "      <td>0.002161</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000619</td>\n",
       "      <td>0.000129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CT45A3</th>\n",
       "      <td>0.000373</td>\n",
       "      <td>0.000073</td>\n",
       "      <td>0.006187</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000790</td>\n",
       "      <td>-0.000043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GJA8</th>\n",
       "      <td>0.000372</td>\n",
       "      <td>0.000078</td>\n",
       "      <td>0.007125</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000818</td>\n",
       "      <td>-0.000074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SLC39A12</th>\n",
       "      <td>0.000370</td>\n",
       "      <td>0.000206</td>\n",
       "      <td>0.044851</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001552</td>\n",
       "      <td>-0.000811</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OR4F17</th>\n",
       "      <td>0.000368</td>\n",
       "      <td>0.000142</td>\n",
       "      <td>0.023006</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001180</td>\n",
       "      <td>-0.000444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TBATA</th>\n",
       "      <td>0.000363</td>\n",
       "      <td>0.000036</td>\n",
       "      <td>0.001606</td>\n",
       "      <td>3</td>\n",
       "      <td>0.000568</td>\n",
       "      <td>0.000158</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CDH19</th>\n",
       "      <td>0.000360</td>\n",
       "      <td>0.000382</td>\n",
       "      <td>0.122104</td>\n",
       "      <td>3</td>\n",
       "      <td>0.002552</td>\n",
       "      <td>-0.001831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TCEAL6</th>\n",
       "      <td>0.000360</td>\n",
       "      <td>0.000167</td>\n",
       "      <td>0.032492</td>\n",
       "      <td>3</td>\n",
       "      <td>0.001317</td>\n",
       "      <td>-0.000598</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               importance    stddev   p_value  n  p99_high   p99_low\n",
       "TMCO2            0.001268  0.000169  0.002920  3  0.002234  0.000302\n",
       "NKX2-3           0.000731  0.000261  0.020054  3  0.002229 -0.000767\n",
       "SNAP91           0.000665  0.000097  0.003536  3  0.001224  0.000107\n",
       "CYP1A2           0.000642  0.000243  0.022298  3  0.002036 -0.000751\n",
       "CTAG1A           0.000542  0.000219  0.025103  3  0.001795 -0.000711\n",
       "LHFPL5           0.000512  0.000438  0.090087  3  0.003020 -0.001996\n",
       "RNASE13          0.000505  0.000180  0.019996  3  0.001539 -0.000529\n",
       "WIF1             0.000505  0.000452  0.096271  3  0.003095 -0.002085\n",
       "SPACA5B          0.000500  0.000113  0.008351  3  0.001149 -0.000149\n",
       "RXFP4            0.000494  0.000058  0.002314  3  0.000829  0.000159\n",
       "TTLL2            0.000477  0.000257  0.042331  3  0.001949 -0.000995\n",
       "AMELY            0.000461  0.000103  0.008057  3  0.001050 -0.000127\n",
       "RAD21L1          0.000459  0.000085  0.005630  3  0.000947 -0.000029\n",
       "AC007919.2       0.000443  0.000144  0.016700  3  0.001269 -0.000382\n",
       "MC4R             0.000441  0.000070  0.004101  3  0.000840  0.000042\n",
       "GOLGA8G          0.000432  0.000048  0.002017  3  0.000705  0.000159\n",
       "SMIM18           0.000432  0.000264  0.052641  3  0.001945 -0.001081\n",
       "LSMEM2           0.000429  0.000114  0.011357  3  0.001082 -0.000224\n",
       "RBP2             0.000421  0.000240  0.046576  3  0.001794 -0.000952\n",
       "MYOD1            0.000419  0.000308  0.071522  3  0.002185 -0.001348\n",
       "HIST1H2AB        0.000416  0.000121  0.013549  3  0.001111 -0.000278\n",
       "CCDC166          0.000415  0.000075  0.005370  3  0.000845 -0.000015\n",
       "XAGE1B           0.000411  0.000128  0.015470  3  0.001145 -0.000324\n",
       "EFCAB8           0.000410  0.000192  0.032787  3  0.001508 -0.000687\n",
       "NLRP4            0.000403  0.000142  0.019443  3  0.001216 -0.000410\n",
       "AC104057.1       0.000403  0.000165  0.025844  3  0.001348 -0.000543\n",
       "FABP2            0.000397  0.000135  0.018113  3  0.001169 -0.000374\n",
       "PLK5             0.000391  0.000203  0.039790  3  0.001557 -0.000775\n",
       "LCT              0.000384  0.000248  0.057471  3  0.001803 -0.001035\n",
       "INSM2            0.000377  0.000255  0.062174  3  0.001839 -0.001084\n",
       "RP11-293I14.2    0.000375  0.000257  0.063754  3  0.001851 -0.001100\n",
       "S100G            0.000375  0.000055  0.003504  3  0.000689  0.000062\n",
       "AL117190.3       0.000374  0.000043  0.002161  3  0.000619  0.000129\n",
       "CT45A3           0.000373  0.000073  0.006187  3  0.000790 -0.000043\n",
       "GJA8             0.000372  0.000078  0.007125  3  0.000818 -0.000074\n",
       "SLC39A12         0.000370  0.000206  0.044851  3  0.001552 -0.000811\n",
       "OR4F17           0.000368  0.000142  0.023006  3  0.001180 -0.000444\n",
       "TBATA            0.000363  0.000036  0.001606  3  0.000568  0.000158\n",
       "CDH19            0.000360  0.000382  0.122104  3  0.002552 -0.001831\n",
       "TCEAL6           0.000360  0.000167  0.032492  3  0.001317 -0.000598"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ag_feature_importance.sort_values(\"importance\")[::-1][:40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a986a9a5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}