{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b819e5dc",
   "metadata": {},
   "source": [
    "# Continuous training of a Chicago taxi tip model\n",
    "\n",
    "This example trains an XGBoost regression model that predicts the tip (`tips` column) of Chicago taxi trips and demonstrates continuous training using a recursive train-eval-check loop.\n",
    "\n",
    "The main pipeline trains the initial model and then gradually trains the model some more until the model evaluation metrics are good enough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae37f600",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Uncomment to install the pinned KFP SDK into the kernel's environment.\n",
    "# %pip install kfp==1.8.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import kfp\n",
    "from kfp import components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2c4e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tunable settings, collected in one place so they are easy to find and change.\n",
    "INITIAL_NUM_ITERATIONS = 100  # num_iterations for the first training run\n",
    "EXTRA_NUM_ITERATIONS = 50  # num_iterations for each additional training pass\n",
    "MAX_MEAN_SQUARED_ERROR = 0.01  # keep training while the error is above this value\n",
    "LABEL_COLUMN = 0  # label column index; `tips` is the first column selected by the pipeline\n",
    "\n",
    "PIPELINE_PACKAGE_PATH = \"train_until_good_pipeline.zip\"\n",
    "EXPERIMENT_NAME = \"kubeflow\"\n",
    "RUN_NAME = \"train_until_good_pipeline\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea6f54b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the reusable pipeline components from pinned revisions of the kubeflow/pipelines repository.\n",
    "chicago_taxi_dataset_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/e3337b8bdcd63636934954e592d4b32c95b49129/components/datasets/Chicago%20Taxi/component.yaml')\n",
    "xgboost_train_on_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/567c04c51ff00a1ee525b3458425b17adbe3df61/components/XGBoost/Train/component.yaml')\n",
    "xgboost_predict_on_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/567c04c51ff00a1ee525b3458425b17adbe3df61/components/XGBoost/Predict/component.yaml')\n",
    "\n",
    "pandas_transform_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/6162d55998b176b50267d351241100bb0ee715bc/components/pandas/Transform_DataFrame/in_CSV_format/component.yaml')\n",
    "drop_header_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/02c9638287468c849632cf9f7885b51de4c66f86/components/tables/Remove_header/component.yaml')\n",
    "calculate_regression_metrics_from_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/616542ac0f789914f4eb53438da713dd3004fba4/components/ml_metrics/Calculate_regression_metrics/from_CSV/component.yaml')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea4e9d1f",
   "metadata": {},
   "source": [
    "## Recursive training sub-pipeline\n",
    "\n",
    "This recursive sub-pipeline trains a model, evaluates it, calculates the metrics and checks them. If the model error is too high, then more training is performed until the model is good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b210c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "@kfp.dsl.graph_component\n",
    "def train_until_low_error(starting_model, training_data, true_values):\n",
    "    \"\"\"Train, evaluate and recurse until the mean squared error is low enough.\n",
    "\n",
    "    Args:\n",
    "        starting_model: Model to continue training from.\n",
    "        training_data: CSV data used both for training and for prediction.\n",
    "        true_values: Values the predictions are compared against when the metrics are calculated.\n",
    "    \"\"\"\n",
    "    # Training\n",
    "    model = xgboost_train_on_csv_op(\n",
    "        training_data=training_data,\n",
    "        starting_model=starting_model,\n",
    "        label_column=LABEL_COLUMN,\n",
    "        objective='reg:squarederror',\n",
    "        num_iterations=EXTRA_NUM_ITERATIONS,\n",
    "    ).outputs['model']\n",
    "\n",
    "    # Predicting (on the same data the model was trained on)\n",
    "    predictions = xgboost_predict_on_csv_op(\n",
    "        data=training_data,\n",
    "        model=model,\n",
    "        label_column=LABEL_COLUMN,\n",
    "    ).output\n",
    "\n",
    "    # Calculating the regression metrics\n",
    "    metrics_task = calculate_regression_metrics_from_csv_op(\n",
    "        true_values=true_values,\n",
    "        predicted_values=predictions,\n",
    "    )\n",
    "\n",
    "    # Checking the metrics: recurse only while the error is still too high\n",
    "    with kfp.dsl.Condition(metrics_task.outputs['mean_squared_error'] > MAX_MEAN_SQUARED_ERROR):\n",
    "        # Training some more\n",
    "        train_until_low_error(\n",
    "            starting_model=model,\n",
    "            training_data=training_data,\n",
    "            true_values=true_values,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f524031",
   "metadata": {},
   "source": [
    "## Main pipeline\n",
    "\n",
    "The main pipeline trains the initial model and then gradually trains the model some more until the model evaluation metrics are good enough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aaf016a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@kfp.dsl.pipeline()\n",
    "def train_until_good_pipeline():\n",
    "    \"\"\"Train an initial model, then keep training it until the error is low.\"\"\"\n",
    "    # Preparing the training data\n",
    "    training_data = chicago_taxi_dataset_op(\n",
    "        where='trip_start_timestamp >= \"2019-01-01\" AND trip_start_timestamp < \"2019-02-01\"',\n",
    "        select='tips,trip_seconds,trip_miles,pickup_community_area,dropoff_community_area,fare,tolls,extras,trip_total',\n",
    "        limit=10000,\n",
    "    ).output\n",
    "\n",
    "    # Preparing the true values: keep only the label column, then drop the CSV header\n",
    "    true_values_table = pandas_transform_csv_op(\n",
    "        table=training_data,\n",
    "        transform_code='df = df[[\"tips\"]]',\n",
    "    ).output\n",
    "\n",
    "    true_values = drop_header_op(true_values_table).output\n",
    "\n",
    "    # Initial model training\n",
    "    first_model = xgboost_train_on_csv_op(\n",
    "        training_data=training_data,\n",
    "        label_column=LABEL_COLUMN,\n",
    "        objective='reg:squarederror',\n",
    "        num_iterations=INITIAL_NUM_ITERATIONS,\n",
    "    ).outputs['model']\n",
    "\n",
    "    # Recursively training until the error becomes low\n",
    "    train_until_low_error(\n",
    "        starting_model=first_model,\n",
    "        training_data=training_data,\n",
    "        true_values=true_values,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04bb072d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the pipeline function into a package that can be submitted to Kubeflow Pipelines.\n",
    "kfp.compiler.Compiler().compile(train_until_good_pipeline, PIPELINE_PACKAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e53d67f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Submit the compiled package as a run in the chosen experiment.\n",
    "client = kfp.Client()\n",
    "\n",
    "experiment = client.create_experiment(name=EXPERIMENT_NAME)\n",
    "\n",
    "my_run = client.run_pipeline(experiment.id, RUN_NAME, PIPELINE_PACKAGE_PATH)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}