{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Step 0: Train a sample ML model to monitor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tips for running this notebook:  \n", "- This notebook loads a large volume of raw data, so run it on an instance with at least 8 GB of memory.\n", "- It has been tested with the Python 3 (Data Science) kernel.\n", "- The SageMaker default bucket is used by default. You can change it if needed.\n", "- Cell outputs are kept so that you can check the results without running the notebook. To run it from a clean state, first select \"Clear All Outputs\" from the right-click menu to clear the outputs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install and import the required libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "!pip install xgboost==1.5.1 optuna scikit-learn==1.0.2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import io\n", "import boto3\n", "import time\n", "from datetime import datetime\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import sagemaker\n", "from xgboost import XGBRegressor\n", "from sklearn import metrics\n", "import optuna\n", "\n", "import model_utils" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Set the SageMaker default bucket\n", "# Change this line to use a different bucket\n", "bucket = sagemaker.Session().default_bucket()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fetch the training data from the open data S3 bucket" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On an instance with 4 GB of memory, the kernel dies from running out of memory while loading the raw data.  \n", "Use an instance with at least 8 GB of memory, such as ml.m5.large." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2019 7\n", "2019 8\n", "2019 9\n", "2019 10\n", "2019 11\n", "2019 12\n", "2020 1\n", "2020 2\n", "CPU times: user 2min 46s, sys: 32.7 s, total: 3min 19s\n", "Wall time: 4min 32s\n" ] } ], "source": [ "%%time\n", "\n", "start = '2019-07-01'\n", "end = '2020-02-29'\n", "sampling_rate = 20\n", 
"\n", "df_features = pd.DataFrame()\n", "for target in pd.date_range(start, end, freq='M'):\n", "    print(target.year, target.month)\n", "\n", "    # Get raw data with 20% sampling for the previous month and the target month\n", "    previous_year, previous_month = model_utils.get_previous_year_month(target.year, target.month)\n", "    df_previous_month = model_utils.get_raw_data(previous_year, previous_month, sampling_rate)\n", "    df_current_month = model_utils.get_raw_data(target.year, target.month, sampling_rate)\n", "    df_data = pd.concat([df_previous_month, df_current_month]).reset_index(drop=True)\n", "    del df_previous_month\n", "    del df_current_month\n", "\n", "    # Extract features, then keep only the rows of the target month\n", "    df_features_current_month = model_utils.extract_features(df_data)\n", "    df_features_current_month = model_utils.filter_current_month(df_features_current_month, target.year, target.month)\n", "\n", "    df_features = pd.concat([df_features, df_features_current_month])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split the data by period into training, validation and test sets, and tune the hyperparameters\n", "Hyperparameter tuning can also be built into training on SageMaker, but here it runs inside the notebook to keep model building simple." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Split the data by period: Jul-Dec 2019 for training, Jan 2020 for validation, Feb 2020 for testing\n", "df_train = df_features[('2019-07-01'<=df_features.index) & (df_features.index <= '2019-12-31')].copy()\n", "df_validation = df_features[('2020-01-01'<=df_features.index) & (df_features.index <= '2020-01-31')].copy()\n", "df_test = df_features[('2020-02-01'<=df_features.index) & (df_features.index <= '2020-02-28')].copy()\n", "\n", "# The first column (pickup_count) is the target; the remaining columns are the features\n", "y_train = df_train['pickup_count'].values\n", "X_train = df_train[df_train.columns[1:]].values\n", "y_validation = df_validation['pickup_count'].values\n", "X_validation = df_validation[df_validation.columns[1:]].values\n", "y_test = df_test['pickup_count'].values\n", "X_test = df_test[df_test.columns[1:]].values" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "name": 
"stderr", "output_type": "stream", "text": [ "\u001b[32m[I 2022-12-15 10:37:31,273]\u001b[0m A new study created in memory with name: no-name-df606950-889f-4535-a9f0-e4d1d7a52b19\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:37:49,796]\u001b[0m Trial 0 finished with value: 77.77208361524062 and parameters: {'max_depth': 5, 'alpha': 595.7300746704884, 'min_child_weight': 120, 'subsample': 0.7158474569125317, 'eta': 0.3801220778920714}. Best is trial 0 with value: 77.77208361524062.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:38:08,899]\u001b[0m Trial 1 finished with value: 77.02380578750139 and parameters: {'max_depth': 5, 'alpha': 975.5350306481461, 'min_child_weight': 70, 'subsample': 0.7601596156492616, 'eta': 0.2363846460940136}. Best is trial 1 with value: 77.02380578750139.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:38:25,763]\u001b[0m Trial 2 finished with value: 81.88596838991234 and parameters: {'max_depth': 5, 'alpha': 209.82960645759752, 'min_child_weight': 90, 'subsample': 0.5929469916694439, 'eta': 0.489450686239459}. Best is trial 1 with value: 77.02380578750139.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:38:49,537]\u001b[0m Trial 3 finished with value: 75.84568397248532 and parameters: {'max_depth': 8, 'alpha': 71.27087140564048, 'min_child_weight': 100, 'subsample': 0.5114553701914726, 'eta': 0.10532813405553175}. Best is trial 3 with value: 75.84568397248532.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:39:00,688]\u001b[0m Trial 4 finished with value: 79.24528251039771 and parameters: {'max_depth': 3, 'alpha': 569.4373647459664, 'min_child_weight': 89, 'subsample': 0.7049259419901959, 'eta': 0.1991279931446049}. Best is trial 3 with value: 75.84568397248532.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:39:27,348]\u001b[0m Trial 5 finished with value: 72.87626544251876 and parameters: {'max_depth': 7, 'alpha': 17.446370432440684, 'min_child_weight': 70, 'subsample': 0.8197475403827663, 'eta': 0.11950101013136272}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:40:03,156]\u001b[0m Trial 6 finished with value: 79.67777782906306 and parameters: {'max_depth': 8, 'alpha': 93.95774599445939, 'min_child_weight': 12, 'subsample': 0.8856977253181799, 'eta': 0.3677019348005638}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:40:10,388]\u001b[0m Trial 7 finished with value: 79.83207904514698 and parameters: {'max_depth': 2, 'alpha': 921.3524155540516, 'min_child_weight': 18, 'subsample': 0.6071649012352914, 'eta': 0.1428285881574027}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:40:32,526]\u001b[0m Trial 8 finished with value: 76.11333836402889 and parameters: {'max_depth': 6, 'alpha': 158.46003290587495, 'min_child_weight': 30, 'subsample': 0.6630682352765657, 'eta': 0.18087601086155877}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:41:04,670]\u001b[0m Trial 9 finished with value: 78.71398548817686 and parameters: {'max_depth': 8, 'alpha': 281.3364244988845, 'min_child_weight': 76, 'subsample': 0.8354604589739707, 'eta': 0.29429450448439476}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:41:46,078]\u001b[0m Trial 10 finished with value: 75.38634314536843 and parameters: {'max_depth': 9, 'alpha': 405.2634566881888, 'min_child_weight': 42, 'subsample': 0.9807998055993875, 'eta': 0.27250960263292057}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:42:28,247]\u001b[0m Trial 11 finished with value: 76.97887819901742 and parameters: {'max_depth': 9, 'alpha': 385.05684079071, 'min_child_weight': 44, 'subsample': 0.9902858603660764, 'eta': 0.2741003569734054}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:42:59,947]\u001b[0m Trial 12 finished with value: 79.00590977248038 and parameters: {'max_depth': 7, 'alpha': 771.693293432469, 'min_child_weight': 50, 'subsample': 0.9833736593112152, 'eta': 0.3324708964642379}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:43:34,551]\u001b[0m Trial 13 finished with value: 80.3679712952508 and parameters: {'max_depth': 9, 'alpha': 1.6240307004530337, 'min_child_weight': 59, 'subsample': 0.8878182263994833, 'eta': 0.4565976582960169}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:44:02,830]\u001b[0m Trial 14 finished with value: 75.2419466359473 and parameters: {'max_depth': 7, 'alpha': 414.05936668619256, 'min_child_weight': 36, 'subsample': 0.7956405970152522, 'eta': 0.10721894996527502}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:44:28,865]\u001b[0m Trial 15 finished with value: 75.9878905915093 and parameters: {'max_depth': 6, 'alpha': 720.0681429449196, 'min_child_weight': 0, 'subsample': 0.8142085064321046, 'eta': 0.10205035571523037}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:44:43,840]\u001b[0m Trial 16 finished with value: 75.02692274686075 and parameters: {'max_depth': 4, 'alpha': 335.039022052757, 'min_child_weight': 66, 'subsample': 0.7892541388459663, 'eta': 0.16651605093687322}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:44:55,206]\u001b[0m Trial 17 finished with value: 78.78332304426065 and parameters: {'max_depth': 3, 'alpha': 267.65419982803854, 'min_child_weight': 76, 'subsample': 0.8525717676305501, 'eta': 0.18074039077525406}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:44:59,604]\u001b[0m Trial 18 finished with value: 84.86921367458474 and parameters: {'max_depth': 1, 'alpha': 314.7698321231883, 'min_child_weight': 59, 'subsample': 0.9110096172534815, 'eta': 0.22652031476556614}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:45:14,707]\u001b[0m Trial 19 finished with value: 79.49312262346432 and parameters: {'max_depth': 4, 'alpha': 524.3202611896247, 'min_child_weight': 98, 'subsample': 0.764089704421335, 'eta': 0.16098542745193378}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:45:28,466]\u001b[0m Trial 20 finished with value: 79.57043176781805 and parameters: {'max_depth': 4, 'alpha': 7.992322395294423, 'min_child_weight': 116, 'subsample': 0.6725249120169093, 'eta': 0.2168084162276126}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:45:56,990]\u001b[0m Trial 21 finished with value: 75.77042364098418 and parameters: {'max_depth': 7, 'alpha': 426.7812655177545, 'min_child_weight': 35, 'subsample': 0.7914744501585645, 'eta': 0.1326514754763159}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:46:20,959]\u001b[0m Trial 22 finished with value: 74.71888785625812 and parameters: {'max_depth': 6, 'alpha': 682.7553432388506, 'min_child_weight': 67, 'subsample': 0.7942830128814948, 'eta': 0.12441495376871681}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:46:45,810]\u001b[0m Trial 23 finished with value: 75.9255318000287 and parameters: {'max_depth': 6, 'alpha': 671.4925002263429, 'min_child_weight': 69, 'subsample': 0.9229107105169898, 'eta': 0.14893757813182745}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:47:01,264]\u001b[0m Trial 24 finished with value: 78.65270214212134 and parameters: {'max_depth': 4, 'alpha': 801.4226733305587, 'min_child_weight': 81, 'subsample': 0.8466236452828909, 'eta': 0.13999119678237934}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:47:24,513]\u001b[0m Trial 25 finished with value: 76.44402467401008 and parameters: {'max_depth': 6, 'alpha': 868.1061629286465, 'min_child_weight': 63, 'subsample': 0.7235846042581046, 'eta': 0.18093512240514656}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:47:43,706]\u001b[0m Trial 26 finished with value: 76.81182653068953 and parameters: {'max_depth': 5, 'alpha': 658.4940270296295, 'min_child_weight': 53, 'subsample': 0.7710663436145959, 'eta': 0.12549252971711278}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:47:54,938]\u001b[0m Trial 27 finished with value: 79.14518838608227 and parameters: {'max_depth': 3, 'alpha': 495.04317574647087, 'min_child_weight': 85, 'subsample': 0.8231637577625006, 'eta': 0.2583378910534704}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:48:22,795]\u001b[0m Trial 28 finished with value: 75.76168225553381 and parameters: {'max_depth': 7, 'alpha': 185.46536334443476, 'min_child_weight': 102, 'subsample': 0.9337371439712899, 'eta': 0.16480662926287554}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:48:37,463]\u001b[0m Trial 29 finished with value: 76.35434082092893 and parameters: {'max_depth': 4, 'alpha': 615.7267187397515, 'min_child_weight': 67, 'subsample': 0.7260118646579098, 'eta': 0.20148799910311282}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:49:00,335]\u001b[0m Trial 30 finished with value: 78.4644613428213 and parameters: {'max_depth': 6, 'alpha': 494.53074327361975, 'min_child_weight': 52, 'subsample': 0.6848127335136737, 'eta': 0.33001440154189743}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:49:29,650]\u001b[0m Trial 31 finished with value: 74.62188337883865 and parameters: {'max_depth': 7, 'alpha': 377.90497014312723, 'min_child_weight': 32, 'subsample': 0.8048765831738828, 'eta': 0.10181569642365346}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:49:49,230]\u001b[0m Trial 32 finished with value: 75.79970228735438 and parameters: {'max_depth': 5, 'alpha': 327.88606371357383, 'min_child_weight': 75, 'subsample': 0.8662060524965635, 'eta': 0.12589347616640623}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:50:16,733]\u001b[0m Trial 33 finished with value: 77.08416914023665 and parameters: {'max_depth': 7, 'alpha': 119.79213019037157, 'min_child_weight': 28, 'subsample': 0.7549953548618707, 'eta': 0.12123628290578631}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:50:47,808]\u001b[0m Trial 34 finished with value: 73.8551575631542 and parameters: {'max_depth': 8, 'alpha': 206.15031206517529, 'min_child_weight': 60, 'subsample': 0.7914570912260337, 'eta': 0.1002366586383201}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:51:19,861]\u001b[0m Trial 35 finished with value: 75.09818160789118 and parameters: {'max_depth': 8, 'alpha': 232.51665731171983, 'min_child_weight': 44, 'subsample': 0.8082275508562556, 'eta': 0.10714056911328979}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:51:52,917]\u001b[0m Trial 36 finished with value: 82.75594469195075 and parameters: {'max_depth': 8, 'alpha': 51.83845271824586, 'min_child_weight': 13, 'subsample': 0.7457254081991068, 'eta': 0.44899115259255096}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:52:19,470]\u001b[0m Trial 37 finished with value: 75.2698072489552 and parameters: {'max_depth': 8, 'alpha': 139.59332781175615, 'min_child_weight': 94, 'subsample': 0.6173581926468027, 'eta': 0.1018788907763682}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:52:49,998]\u001b[0m Trial 38 finished with value: 77.25927940702049 and parameters: {'max_depth': 7, 'alpha': 217.71850543354248, 'min_child_weight': 23, 'subsample': 0.8621485786929561, 'eta': 0.20247284636847981}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:53:17,565]\u001b[0m Trial 39 finished with value: 77.78330077844033 and parameters: {'max_depth': 9, 'alpha': 70.82091170012853, 'min_child_weight': 58, 'subsample': 0.501853306757283, 'eta': 0.15084210738850345}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:53:41,713]\u001b[0m Trial 40 finished with value: 77.47778191457628 and parameters: {'max_depth': 6, 'alpha': 561.161276699925, 'min_child_weight': 108, 'subsample': 0.8895420652969035, 'eta': 0.24892863004586316}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:54:00,743]\u001b[0m Trial 41 finished with value: 75.83273990725189 and parameters: {'max_depth': 5, 'alpha': 346.4717007723859, 'min_child_weight': 68, 'subsample': 0.7937461524193514, 'eta': 0.17394834454685995}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:54:27,116]\u001b[0m Trial 42 finished with value: 74.82902125097327 and parameters: {'max_depth': 7, 'alpha': 447.7316670886733, 'min_child_weight': 83, 'subsample': 0.7739922428610406, 'eta': 0.12402984169678255}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:54:53,515]\u001b[0m Trial 43 finished with value: 75.82671540998987 and parameters: {'max_depth': 7, 'alpha': 993.6508635301852, 'min_child_weight': 85, 'subsample': 0.734868908704065, 'eta': 0.12122511226639646}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:55:24,657]\u001b[0m Trial 44 finished with value: 77.23259387459402 and parameters: {'max_depth': 8, 'alpha': 448.57017473721663, 'min_child_weight': 80, 'subsample': 0.8306406840416405, 'eta': 0.1424551313659464}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:55:47,389]\u001b[0m Trial 45 finished with value: 76.52796480472506 and parameters: {'max_depth': 6, 'alpha': 267.5576046327119, 'min_child_weight': 91, 'subsample': 0.7755836652680191, 'eta': 0.11970199388189023}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:56:13,168]\u001b[0m Trial 46 finished with value: 79.94465238072888 and parameters: {'max_depth': 7, 'alpha': 465.56276396640436, 'min_child_weight': 74, 'subsample': 0.6421041852233894, 'eta': 0.39729537142853155}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:56:45,747]\u001b[0m Trial 47 finished with value: 74.30726433007189 and parameters: {'max_depth': 8, 'alpha': 376.26908896101327, 'min_child_weight': 48, 'subsample': 0.8117033908399881, 'eta': 0.10112911773224174}. 
Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:57:17,663]\u001b[0m Trial 48 finished with value: 74.78163715525162 and parameters: {'max_depth': 8, 'alpha': 376.13860681138357, 'min_child_weight': 49, 'subsample': 0.812920710051217, 'eta': 0.10057496160804144}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n", "\u001b[32m[I 2022-12-15 10:57:49,184]\u001b[0m Trial 49 finished with value: 75.03167647867211 and parameters: {'max_depth': 9, 'alpha': 187.88131284150344, 'min_child_weight': 35, 'subsample': 0.5611342416569123, 'eta': 0.15294051734904007}. Best is trial 5 with value: 72.87626544251876.\u001b[0m\n" ] } ], "source": [ "def objective(trial):\n", "\n", "    params = {\n", "        'objective': 'reg:squarederror',\n", "        'max_depth': trial.suggest_int('max_depth', 1, 9),\n", "        'alpha': trial.suggest_float('alpha', 0, 1000),\n", "        'min_child_weight': trial.suggest_int('min_child_weight', 0, 120),\n", "        'subsample': trial.suggest_float('subsample', 0.5, 1),\n", "        'eta': trial.suggest_float('eta', 0.1, 0.5),\n", "    }\n", "\n", "    # Fit the model on the training data\n", "    bst = XGBRegressor(**params)\n", "    bst.fit(X_train, y_train)\n", "\n", "    # Make predictions on the validation data\n", "    preds = bst.predict(X_validation)\n", "\n", "    # Return the validation RMSE; optuna.create_study() minimizes the objective by default\n", "    return np.sqrt(metrics.mean_squared_error(y_validation, preds))\n", "\n", "# Searching for the best parameters takes about 20 minutes.\n", "study = optuna.create_study()\n", "study.optimize(objective, n_trials=50)" ] }, 
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation data\n", "RMSE: 72.87626544251876\n", "MAE: 46.227257232512194\n", "R2: 0.8969308192107929\n", "Test data\n", "RMSE: 50.24950549434309\n", "MAE: 35.37940113033567\n", "R2: 0.9549393007439243\n" ] } ], "source": [ "# Refit with the best parameters and evaluate on the validation and test data\n", "bst = XGBRegressor(**study.best_params)\n", "bst.fit(X_train, y_train)\n", "\n", "print('Validation data')\n", "validation_preds = bst.predict(X_validation)\n", "model_utils.calc_accuracy(y_validation, validation_preds)\n", "\n", "print('Test data')\n", "test_preds = bst.predict(X_test)\n", "model_utils.calc_accuracy(y_test, test_preds)" ] }, 
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'max_depth': '7',\n", " 'alpha': '17.446370432440684',\n", " 'min_child_weight': '70',\n", " 'subsample': '0.8197475403827663',\n", " 'eta': '0.11950101013136272'}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# SageMaker training job hyperparameters must be passed as strings\n", "best_params_str = {k: str(v) for k, v in study.best_params.items()}\n", "best_params_str" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model with the tuned hyperparameters" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model_data_prefix = 'model_monitor/data/nyctaxi/model_training'" ] }, 
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "s3://sagemaker-ap-northeast-1-370828233696/model_monitor/data/nyctaxi/model_training/data.train\n", "s3://sagemaker-ap-northeast-1-370828233696/model_monitor/data/nyctaxi/model_training/data.validate\n" ] } ], "source": [ "# Save the training and validation data as CSV files and upload them to S3 (the test data is not uploaded)\n", "data_path = f's3://{bucket}/{model_data_prefix}'\n", "train_data_file = 'data.train'\n", "validation_data_file = 'data.validate'\n", "\n", "# SageMaker's built-in XGBoost expects CSV without a header, with the target in the first column\n", "df_train.to_csv(f'./{train_data_file}', index=False, header=False)\n", "resp = sagemaker.s3.S3Uploader.upload(f'./{train_data_file}', data_path)\n", "print(resp)\n", "\n", "df_validation.to_csv(f'./{validation_data_file}', index=False, header=False)\n", "resp = sagemaker.s3.S3Uploader.upload(f'./{validation_data_file}', data_path)\n", "print(resp)" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training job nyctaxi-xgboost-regression-2022-12-15-10-58-19\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "InProgress\n", "Completed\n", "arn:aws:sagemaker:ap-northeast-1:370828233696:model/nyctaxi-xgboost-regression-2022-12-15-10-58-19-model\n", "Model creation completed.\n", "\n", "Endpoint deployment started.\n", "Endpoint name: nyctaxi-testing-endpoint-2022-12-15-10-58-19\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Endpoint deployment completed.\n", "Endpoint arn: arn:aws:sagemaker:ap-northeast-1:370828233696:endpoint/nyctaxi-testing-endpoint-2022-12-15-10-58-19\n", "CPU times: user 265 ms, sys: 20.4 ms, total: 286 ms\n", "Wall time: 14min 3s\n" ] } ], "source": [ "%%time\n", "\n", "# This cell trains the model and deploys the endpoint in one go, so it takes about 15 minutes to run\n", "\n", "# Set up the environment for model training\n", "role = sagemaker.get_execution_role()\n", "region = boto3.Session().region_name\n", "container = sagemaker.image_uris.retrieve('xgboost', region, '1.5-1')\n", "\n", "training_start_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n", "\n", "job_name = 'nyctaxi-xgboost-regression-' + training_start_time\n", "model_prefix = 'model/nyctaxi-xgboost-regression'\n", "print(\"Training job\", job_name)\n", "\n", "# Define model training parameters\n", "# Note: the Optuna search above used XGBRegressor's default of 100 boosting rounds, while num_round is 1000 here\n", "create_training_params = \\\n", "{\n", "    \"AlgorithmSpecification\": {\n", "        \"TrainingImage\": container,\n", "        \"TrainingInputMode\": \"File\"\n", "    },\n", "    \"RoleArn\": role,\n", "    \"OutputDataConfig\": {\n", "        \"S3OutputPath\": f's3://{bucket}/{model_prefix}'\n", "    },\n", "    \"ResourceConfig\": {\n", "        \"InstanceCount\": 1,\n", "        \"InstanceType\": \"ml.m5.large\",\n", "        \"VolumeSizeInGB\": 5\n", "    },\n", "    \"TrainingJobName\": job_name,\n", "    \"HyperParameters\": {\n", "        **best_params_str,\n", "        \"num_round\": \"1000\"\n", "    },\n", "    \"StoppingCondition\": {\n", "        \"MaxRuntimeInSeconds\": 3600\n", "    },\n", "    \"InputDataConfig\": [\n", "        {\n", "            \"ChannelName\": \"train\",\n", "            \"DataSource\": {\n", "                \"S3DataSource\": {\n", "                    \"S3DataType\": \"S3Prefix\",\n", "                    \"S3Uri\": f'{data_path}/{train_data_file}',\n", "                    \"S3DataDistributionType\": \"FullyReplicated\"\n", "                }\n", "            },\n", "            \"ContentType\": \"csv\",\n", "            \"CompressionType\": \"None\"\n", "        },\n", "        {\n", "            \"ChannelName\": \"validation\",\n", "            \"DataSource\": {\n", "                \"S3DataSource\": {\n", "                    \"S3DataType\": \"S3Prefix\",\n", "                    \"S3Uri\": f'{data_path}/{validation_data_file}',\n", "                    \"S3DataDistributionType\": \"FullyReplicated\"\n", "                }\n", "            },\n", "            \"ContentType\": \"csv\",\n", "            \"CompressionType\": \"None\"\n", "        }\n", "    ]\n", "}\n", "\n", "# Start model training\n", "client = boto3.client('sagemaker', region_name=region)\n", "client.create_training_job(**create_training_params)\n", "\n", "# Wait for model training to finish (Stopped is also a terminal status)\n", "status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", "print(status)\n", "while status not in ('Completed', 'Failed', 'Stopped'):\n", "    time.sleep(60)\n", "    status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", "    print(status)\n", "if status != 'Completed':\n", "    raise RuntimeError(f'Training job {job_name} ended with status {status}')\n", "\n", "# Create model\n", "model_name = job_name + '-model'\n", "info = client.describe_training_job(TrainingJobName=job_name)\n", "model_data = info['ModelArtifacts']['S3ModelArtifacts']\n", "primary_container = {\n", "    'Image': container,\n", "    'ModelDataUrl': model_data\n", "}\n", "\n", "create_model_response = client.create_model(\n", "    ModelName=model_name,\n", "    ExecutionRoleArn=role,\n", "    PrimaryContainer=primary_container)\n", "\n", "print(create_model_response['ModelArn'])\n", "print('Model creation completed.')\n", "print('')\n", "\n", "# Deploy model to endpoint\n", "endpoint_name = 'nyctaxi-testing-endpoint-' + training_start_time\n", "endpoint_config_name = 'nyctaxi-testing-endpoint-config-' + training_start_time\n", "print('Endpoint deployment started.')\n", "\n", "create_endpoint_config_response = client.create_endpoint_config(\n", "    EndpointConfigName=endpoint_config_name,\n", "    ProductionVariants=[{\n", "        'InstanceType': 'ml.t2.medium',\n", "        'InitialVariantWeight': 1,\n", "        'InitialInstanceCount': 1,\n", "        'ModelName': model_name,\n", "        'VariantName': 'AllTraffic'}])\n", "\n", "print('Endpoint name:', endpoint_name)\n", "create_endpoint_response = client.create_endpoint(\n", "    EndpointName=endpoint_name,\n", "    EndpointConfigName=endpoint_config_name)\n", "\n", "# Check endpoint creation status\n", "resp = client.describe_endpoint(EndpointName=endpoint_name)\n", "status = resp['EndpointStatus']\n", "while status == 'Creating':\n", "    print(\"Status: \" + status)\n", "    time.sleep(60)\n", "    resp = client.describe_endpoint(EndpointName=endpoint_name)\n", "    status = resp['EndpointStatus']\n", "if status != 'InService':\n", "    raise RuntimeError(f'Endpoint {endpoint_name} ended with status {status}')\n", "\n", "print('Endpoint deployment completed.')\n", "print('Endpoint arn:', resp['EndpointArn'])" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Name of the trained model\n", "nyctaxi-xgboost-regression-2022-12-15-10-58-19-model\n" ] } ], "source": [ "print('Name of the trained model')\n", "print(model_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Use the name of the model created here when you set up data quality and model quality monitoring.\n", "- The endpoint is only for testing the trained model and is not used in the notebooks that follow." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test with the deployed inference endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| timestamp | pickup_count | history_12slots | history_16slots | history_20slots | history_24slots | history_28slots | history_32slots | history_36slots | history_40slots | history_44slots | ... | tolls_amount_mean_12slot | tolls_amount_mean_16slot | tolls_amount_mean_20slot | tolls_amount_mean_96slot | tolls_amount_mean_100slot | tolls_amount_mean_104slot | tolls_amount_mean_192slot | tolls_amount_mean_196slot | tolls_amount_mean_200slot | time_slot |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2020-02-01 00:00:00 | 537 | 758.0 | 639.0 | 773.0 | 837.0 | 652.0 | 594.0 | 689.0 | 641.0 | 569.0 | ... | 0.193127 | 0.239437 | 0.351979 | 0.313423 | 0.223082 | 0.242169 | 0.257931 | 0.290571 | 0.242643 | 0 |
| 2020-02-01 00:15:00 | 467 | 686.0 | 635.0 | 773.0 | 803.0 | 712.0 | 636.0 | 664.0 | 589.0 | 592.0 | ... | 0.318950 | 0.240945 | 0.199884 | 0.232559 | 0.315537 | 0.273710 | 0.561222 | 0.251507 | 0.276241 | 1 |
| 2020-02-01 00:30:00 | 461 | 613.0 | 636.0 | 733.0 | 850.0 | 720.0 | 621.0 | 618.0 | 686.0 | 611.0 | ... | 0.219462 | 0.253349 | 0.195689 | 0.479963 | 0.247932 | 0.345783 | 0.175694 | 0.209829 | 0.268718 | 2 |
| 2020-02-01 00:45:00 | 435 | 653.0 | 770.0 | 656.0 | 822.0 | 827.0 | 605.0 | 628.0 | 654.0 | 614.0 | ... | 0.212894 | 0.288740 | 0.179558 | 0.528936 | 0.148938 | 0.314701 | 0.267750 | 0.362395 | 0.206045 | 3 |
| 2020-02-01 01:00:00 | 444 | 608.0 | 758.0 | 639.0 | 773.0 | 837.0 | 652.0 | 594.0 | 689.0 | 641.0 | ... | 0.245378 | 0.193127 | 0.239437 | 0.787042 | 0.313423 | 0.223082 | 0.138045 | 0.257931 | 0.290571 | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2020-02-28 22:45:00 | 727 | 789.0 | 899.0 | 879.0 | 646.0 | 650.0 | 700.0 | 661.0 | 589.0 | 626.0 | ... | 0.174195 | 0.190456 | 0.258225 | 0.405913 | 0.353292 | 0.312892 | 0.296960 | 0.198254 | 0.237768 | 91 |
| 2020-02-28 23:00:00 | 732 | 677.0 | 911.0 | 857.0 | 739.0 | 708.0 | 673.0 | 657.0 | 590.0 | 610.0 | ... | 0.273648 | 0.261844 | 0.218833 | 0.360368 | 0.347987 | 0.210449 | 0.227228 | 0.206366 | 0.269820 | 92 |
| 2020-02-28 23:15:00 | 651 | 629.0 | 879.0 | 882.0 | 763.0 | 636.0 | 713.0 | 603.0 | 618.0 | 571.0 | ... | 0.297361 | 0.158965 | 0.251508 | 0.376248 | 0.401356 | 0.303953 | 0.166909 | 0.228785 | 0.255742 | 93 |
| 2020-02-28 23:30:00 | 660 | 582.0 | 869.0 | 889.0 | 856.0 | 648.0 | 684.0 | 737.0 | 618.0 | 540.0 | ... | 0.209467 | 0.190713 | 0.206524 | 0.331530 | 0.293599 | 0.311977 | 0.306903 | 0.255863 | 0.183760 | 94 |
| 2020-02-28 23:45:00 | 606 | 631.0 | 789.0 | 899.0 | 879.0 | 646.0 | 650.0 | 700.0 | 661.0 | 589.0 | ... | 0.295404 | 0.174195 | 0.190456 | 0.451774 | 0.405913 | 0.353292 | 0.330726 | 0.296960 | 0.198254 | 95 |

2688 rows × 143 columns
\n", "\n", " | VendorID | \n", "tpep_pickup_datetime | \n", "tpep_dropoff_datetime | \n", "passenger_count | \n", "trip_distance | \n", "RatecodeID | \n", "store_and_fwd_flag | \n", "PULocationID | \n", "DOLocationID | \n", "payment_type | \n", "... | \n", "extra | \n", "mta_tax | \n", "tip_amount | \n", "tolls_amount | \n", "improvement_surcharge | \n", "total_amount | \n", "congestion_surcharge | \n", "airport_fee | \n", "pickup_date_hour | \n", "hour | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
118021 | \n", "1 | \n", "2021-08-02 16:51:58 | \n", "2021-08-02 16:57:30 | \n", "1.0 | \n", "1.10 | \n", "1.0 | \n", "N | \n", "229 | \n", "137 | \n", "1 | \n", "... | \n", "3.5 | \n", "0.5 | \n", "2.06 | \n", "0.00 | \n", "0.3 | \n", "12.36 | \n", "2.5 | \n", "0.0 | \n", "2021-08-02 16:00:00 | \n", "16 | \n", "
1873676 | \n", "1 | \n", "2021-08-22 19:22:17 | \n", "2021-08-22 19:50:23 | \n", "1.0 | \n", "18.10 | \n", "2.0 | \n", "N | \n", "132 | \n", "230 | \n", "1 | \n", "... | \n", "2.5 | \n", "0.5 | \n", "12.35 | \n", "6.55 | \n", "0.3 | \n", "74.20 | \n", "2.5 | \n", "0.0 | \n", "2021-08-22 19:00:00 | \n", "19 | \n", "
948739 | \n", "1 | \n", "2021-08-12 04:27:49 | \n", "2021-08-12 04:36:36 | \n", "1.0 | \n", "3.20 | \n", "1.0 | \n", "N | \n", "74 | \n", "126 | \n", "2 | \n", "... | \n", "0.5 | \n", "0.5 | \n", "0.00 | \n", "0.00 | \n", "0.3 | \n", "12.30 | \n", "0.0 | \n", "0.0 | \n", "2021-08-12 04:00:00 | \n", "4 | \n", "
2060902 | \n", "2 | \n", "2021-08-25 09:16:46 | \n", "2021-08-25 09:42:36 | \n", "1.0 | \n", "5.33 | \n", "1.0 | \n", "N | \n", "151 | \n", "164 | \n", "1 | \n", "... | \n", "0.0 | \n", "0.5 | \n", "4.96 | \n", "0.00 | \n", "0.3 | \n", "29.76 | \n", "2.5 | \n", "0.0 | \n", "2021-08-25 09:00:00 | \n", "9 | \n", "
658824 | \n", "1 | \n", "2021-08-08 18:39:17 | \n", "2021-08-08 18:42:05 | \n", "2.0 | \n", "0.40 | \n", "1.0 | \n", "N | \n", "236 | \n", "236 | \n", "1 | \n", "... | \n", "2.5 | \n", "0.5 | \n", "1.45 | \n", "0.00 | \n", "0.3 | \n", "8.75 | \n", "2.5 | \n", "0.0 | \n", "2021-08-08 18:00:00 | \n", "18 | \n", "
5 rows × 21 columns
\n", "\n", " | VendorID | \n", "tpep_pickup_datetime | \n", "tpep_dropoff_datetime | \n", "passenger_count | \n", "trip_distance | \n", "RatecodeID | \n", "store_and_fwd_flag | \n", "PULocationID | \n", "DOLocationID | \n", "payment_type | \n", "... | \n", "extra | \n", "mta_tax | \n", "tip_amount | \n", "tolls_amount | \n", "improvement_surcharge | \n", "total_amount | \n", "congestion_surcharge | \n", "airport_fee | \n", "pickup_date_hour | \n", "hour | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
147820 | \n", "1 | \n", "2021-08-02 23:56:53 | \n", "2021-08-03 00:24:30 | \n", "1.0 | \n", "11.50 | \n", "1.0 | \n", "N | \n", "138 | \n", "48 | \n", "1 | \n", "... | \n", "4.25 | \n", "0.5 | \n", "1.00 | \n", "6.55 | \n", "0.3 | \n", "46.60 | \n", "2.5 | \n", "1.25 | \n", "2021-08-02 23:00:00 | \n", "23 | \n", "
2660769 | \n", "1 | \n", "2021-08-02 23:01:33 | \n", "2021-08-02 23:24:18 | \n", "NaN | \n", "15.10 | \n", "NaN | \n", "None | \n", "132 | \n", "145 | \n", "0 | \n", "... | \n", "1.75 | \n", "0.5 | \n", "8.71 | \n", "0.00 | \n", "0.3 | \n", "52.26 | \n", "NaN | \n", "NaN | \n", "2021-08-02 23:00:00 | \n", "23 | \n", "
2660798 | \n", "2 | \n", "2021-08-02 23:28:23 | \n", "2021-08-02 23:48:50 | \n", "NaN | \n", "5.49 | \n", "NaN | \n", "None | \n", "4 | \n", "265 | \n", "0 | \n", "... | \n", "0.00 | \n", "0.5 | \n", "5.68 | \n", "0.00 | \n", "0.3 | \n", "34.23 | \n", "NaN | \n", "NaN | \n", "2021-08-02 23:00:00 | \n", "23 | \n", "
147845 | \n", "2 | \n", "2021-08-02 23:49:04 | \n", "2021-08-02 23:53:29 | \n", "1.0 | \n", "1.42 | \n", "1.0 | \n", "N | \n", "68 | \n", "48 | \n", "1 | \n", "... | \n", "0.50 | \n", "0.5 | \n", "1.96 | \n", "0.00 | \n", "0.3 | \n", "11.76 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 23:00:00 | \n", "23 | \n", "
147486 | \n", "2 | \n", "2021-08-02 23:38:17 | \n", "2021-08-02 23:44:20 | \n", "1.0 | \n", "0.71 | \n", "1.0 | \n", "N | \n", "237 | \n", "140 | \n", "2 | \n", "... | \n", "0.50 | \n", "0.5 | \n", "0.00 | \n", "0.00 | \n", "0.3 | \n", "9.80 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 23:00:00 | \n", "23 | \n", "
5 rows × 21 columns
\n", "" ], "text/plain": [ " VendorID tpep_pickup_datetime tpep_dropoff_datetime passenger_count \\\n", "147820 1 2021-08-02 23:56:53 2021-08-03 00:24:30 1.0 \n", "2660769 1 2021-08-02 23:01:33 2021-08-02 23:24:18 NaN \n", "2660798 2 2021-08-02 23:28:23 2021-08-02 23:48:50 NaN \n", "147845 2 2021-08-02 23:49:04 2021-08-02 23:53:29 1.0 \n", "147486 2 2021-08-02 23:38:17 2021-08-02 23:44:20 1.0 \n", "\n", " trip_distance RatecodeID store_and_fwd_flag PULocationID \\\n", "147820 11.50 1.0 N 138 \n", "2660769 15.10 NaN None 132 \n", "2660798 5.49 NaN None 4 \n", "147845 1.42 1.0 N 68 \n", "147486 0.71 1.0 N 237 \n", "\n", " DOLocationID payment_type ... extra mta_tax tip_amount \\\n", "147820 48 1 ... 4.25 0.5 1.00 \n", "2660769 145 0 ... 1.75 0.5 8.71 \n", "2660798 265 0 ... 0.00 0.5 5.68 \n", "147845 48 1 ... 0.50 0.5 1.96 \n", "147486 140 2 ... 0.50 0.5 0.00 \n", "\n", " tolls_amount improvement_surcharge total_amount \\\n", "147820 6.55 0.3 46.60 \n", "2660769 0.00 0.3 52.26 \n", "2660798 0.00 0.3 34.23 \n", "147845 0.00 0.3 11.76 \n", "147486 0.00 0.3 9.80 \n", "\n", " congestion_surcharge airport_fee pickup_date_hour hour \n", "147820 2.5 1.25 2021-08-02 23:00:00 23 \n", "2660769 NaN NaN 2021-08-02 23:00:00 23 \n", "2660798 NaN NaN 2021-08-02 23:00:00 23 \n", "147845 2.5 0.00 2021-08-02 23:00:00 23 \n", "147486 2.5 0.00 2021-08-02 23:00:00 23 \n", "\n", "[5 rows x 21 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_1day = df_raw[('2021-8-02' <= df_raw.tpep_pickup_datetime) & (df_raw.tpep_pickup_datetime < '2021-08-03')].sort_values(by='pickup_date_hour').copy()\n", "df_1day.tail(5)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " | VendorID | \n", "tpep_pickup_datetime | \n", "tpep_dropoff_datetime | \n", "passenger_count | \n", "trip_distance | \n", "RatecodeID | \n", "store_and_fwd_flag | \n", "PULocationID | \n", "DOLocationID | \n", 
"payment_type | \n", "... | \n", "extra | \n", "mta_tax | \n", "tip_amount | \n", "tolls_amount | \n", "improvement_surcharge | \n", "total_amount | \n", "congestion_surcharge | \n", "airport_fee | \n", "pickup_date_hour | \n", "hour | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
68575 | \n", "2 | \n", "2021-08-02 00:07:51 | \n", "2021-08-02 00:28:53 | \n", "1.0 | \n", "6.22 | \n", "1.0 | \n", "N | \n", "234 | \n", "49 | \n", "2 | \n", "... | \n", "0.5 | \n", "0.5 | \n", "0.00 | \n", "0.00 | \n", "0.3 | \n", "24.30 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 | \n", "0 | \n", "
70087 | \n", "2 | \n", "2021-08-02 00:09:40 | \n", "2021-08-02 00:42:29 | \n", "1.0 | \n", "20.50 | \n", "2.0 | \n", "N | \n", "132 | \n", "238 | \n", "1 | \n", "... | \n", "0.0 | \n", "0.5 | \n", "12.12 | \n", "6.55 | \n", "0.3 | \n", "72.72 | \n", "0.0 | \n", "1.25 | \n", "2021-08-02 | \n", "0 | \n", "
70270 | \n", "2 | \n", "2021-08-02 00:29:32 | \n", "2021-08-02 00:38:14 | \n", "1.0 | \n", "1.62 | \n", "1.0 | \n", "N | \n", "114 | \n", "246 | \n", "2 | \n", "... | \n", "0.5 | \n", "0.5 | \n", "0.00 | \n", "0.00 | \n", "0.3 | \n", "11.80 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 | \n", "0 | \n", "
69723 | \n", "2 | \n", "2021-08-02 00:09:47 | \n", "2021-08-02 00:16:51 | \n", "1.0 | \n", "2.99 | \n", "1.0 | \n", "N | \n", "170 | \n", "263 | \n", "1 | \n", "... | \n", "0.5 | \n", "0.5 | \n", "5.32 | \n", "0.00 | \n", "0.3 | \n", "18.62 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 | \n", "0 | \n", "
69539 | \n", "1 | \n", "2021-08-02 00:31:53 | \n", "2021-08-02 00:42:49 | \n", "1.0 | \n", "2.50 | \n", "1.0 | \n", "N | \n", "186 | \n", "143 | \n", "1 | \n", "... | \n", "3.0 | \n", "0.5 | \n", "2.85 | \n", "0.00 | \n", "0.3 | \n", "17.15 | \n", "2.5 | \n", "0.00 | \n", "2021-08-02 | \n", "0 | \n", "
5 rows × 21 columns
\n", "