Sample meta training script for sklearn clustering algorithms
# Overview of this example

This is an example of a SageMaker training meta-script: a single script, `train.py`, that can call different sklearn clustering algorithms. `train.py` supports two modes:

1. **Run a clustering algorithm with a specific number of clusters.** The `train-kmeans.sh` quick-start example shows how to use this script to call `sklearn.cluster.KMeans` with the hyperparameter (i.e., kwarg) `n_clusters`. Review `train-kmeans.sh`, and feel free to modify it with additional hyperparameters (i.e., kwargs) supported by `sklearn.cluster.KMeans`. As a further exercise, try modifying this shell script so that it instructs `train.py` to run a GMM using `sklearn.mixture.GaussianMixture`. Tip: recall that sklearn's GMM uses the `n_components` kwarg to denote the number of clusters.

2. **Run a clustering algorithm multiple times, each time with a different number of clusters.** See the `train-*sweep.sh` quick-start examples, which demonstrate k-means, GMM, and agglomerative clustering.

At the end of training, `train.py` also computes several clustering metrics and writes them, together with the cluster assignments, as .csv files. Note that `train.py` requires the estimator to accept either an `n_clusters` or an `n_components` kwarg.

# Final note on the quick-start examples (i.e., `*.sh`)

These are provided so that you can quickly run `train.py` directly in your own Python virtual environment. As you review the `*.sh` files, you'll notice that they load sample input data and save the fitted estimators and evaluation metrics to `/tmp`.

To run `train.py` on SageMaker containers (either in SageMaker *local* mode or as SageMaker training jobs), please refer to the SageMaker SDK documentation and examples.

# Sample invocation directly in a Python virtual environment

This section assumes you have created a Python virtual environment and installed scikit-learn into it. Activate your Python virtual environment, then run the quick starts.
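The two modes described in the overview boil down to three steps: resolve a dotted class name such as `sklearn.cluster.KMeans`, set whichever of `n_clusters` or `n_components` the estimator accepts, and score the resulting labels. The sketch below illustrates that idea in plain scikit-learn. It is not the actual `train.py` code: the helper names (`load_estimator_class`, `size_kwarg`, `fit_and_score`, `sweep`) and the `make_blobs` stand-in for the sample data are assumptions made for illustration. The inclusive `start`..`end` range mirrors the `sweep_start: 2`, `sweep_end: 4` settings that yield `n_components=2`, `3`, and `4` in the GMM log, and leaving `aic`/`bic` empty for estimators without those methods is consistent with the blank `aic,bic` columns in the KMeans `metrics.csv`.

```python
import importlib

from sklearn.datasets import make_blobs
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_samples,
    silhouette_score,
)


def load_estimator_class(dotted_name):
    """Resolve a dotted path such as 'sklearn.cluster.KMeans' to a class."""
    module_name, _, class_name = dotted_name.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def size_kwarg(cls):
    """Return whichever of n_clusters / n_components the estimator accepts."""
    params = cls().get_params()
    for name in ("n_clusters", "n_components"):
        if name in params:
            return name
    raise ValueError(f"{cls.__name__} accepts neither n_clusters nor n_components")


def fit_and_score(cls, n, X, **kwargs):
    """Fit one estimator with n clusters; return it, its labels, per-sample
    silhouettes, and a row of clustering metrics."""
    est = cls(**{size_kwarg(cls): n}, **kwargs)
    # fit_predict works for KMeans, GaussianMixture and AgglomerativeClustering
    # alike (the latter has no separate predict method).
    labels = est.fit_predict(X)
    row = {
        "n_clusters": n,
        "calinski_harabasz_score": calinski_harabasz_score(X, labels),
        "davies_bouldin_score": davies_bouldin_score(X, labels),
        "silhouette_score": silhouette_score(X, labels),
        # Only estimators that define aic()/bic() (e.g. GaussianMixture) get
        # values here; for the others these fields stay empty (None).
        "aic": est.aic(X) if hasattr(est, "aic") else None,
        "bic": est.bic(X) if hasattr(est, "bic") else None,
    }
    return est, labels, silhouette_samples(X, labels), row


def sweep(dotted_name, X, start, end, **kwargs):
    """Fit one estimator per cluster count in [start, end] (inclusive)."""
    cls = load_estimator_class(dotted_name)
    return [fit_and_score(cls, n, X, **kwargs)[3] for n in range(start, end + 1)]


# Stand-in data (an assumption): 20 rows with 4 features, like the sample input.
X, _ = make_blobs(n_samples=20, n_features=4, centers=3, random_state=0)

# Mode 1: a single KMeans fit with n_clusters=3.
kmeans_cls = load_estimator_class("sklearn.cluster.KMeans")
_, labels, per_sample_silhouette, row = fit_and_score(kmeans_cls, 3, X, random_state=0)
print(row)

# Mode 2: a GaussianMixture sweep over n_components = 2, 3, 4.
for sweep_row in sweep("sklearn.mixture.GaussianMixture", X, 2, 4, random_state=0):
    print(sweep_row)
```

Calling `fit_predict` rather than `fit` followed by `predict` keeps one code path for all three estimator families. The per-sample silhouettes are a guess at what the `silhouette` column of `labels.csv` holds; their mean equals `silhouette_score`.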
Once again, to run `train.py` on a SageMaker container (whether in SageMaker local mode or as a SageMaker training job), please refer to the SageMaker SDK documentation and examples.

## Quick-start `train-kmeans.sh`

```bash
$ ./train-kmeans.sh
level: 20, name: __main__, handlers: []
level: 20, name: root, handlers: [ (NOTSET)>]
[2021-07-03 22:49:14] [INFO] __main__ Entrypoint script that uses argparse to digest hyperparameters.
[2021-07-03 22:49:14] [INFO] __main__ Create model & output dirs prior to underlying function.
[2021-07-03 22:49:14] [INFO] __main__ cfg: {'model_dir': PosixPath('/tmp/kmeans/model'), 'output_data_dir': PosixPath('/tmp/kmeans/output/data/algo-1'), 'train': PosixPath('refdata'), 'test': PosixPath('data/test'), 'validation': PosixPath('data/validation'), 'algo': 'sklearn.cluster.KMeans', 'sweep': 0, 'sweep_start': 2, 'sweep_end': 4}
[2021-07-03 22:49:14] [INFO] __main__ hyperparams: ['--n_clusters', '3']
[2021-07-03 22:49:14] [INFO] __main__ Estimator class:
[2021-07-03 22:49:14] [INFO] __main__ estimator: KMeans(n_clusters=3)

$ tree /tmp/kmeans
/tmp/kmeans
├── model
│   └── model.joblib
└── output
    └── data
        └── algo-1
            ├── labels.csv
            └── metrics.csv

4 directories, 3 files

$ cat /tmp/kmeans/output/data/algo-1/metrics.csv
calinski_harabasz_score,davies_bouldin_score,silhouette_score,aic,bic
81.56261314942466,0.4026189432995613,0.657247586099344,,

$ cat /tmp/kmeans/output/data/algo-1/labels.csv
cluster_id,silhouette,id,f1,f2,f3,f4
0,0.6570752828694499,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
0,0.6992985659368288,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
0,0.6764819315883501,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
0,0.7040898704358034,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
0,0.6866471227311532,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
0,0.4424597686744442,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
0,0.45573721799988176,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
0,0.3497685977411139,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
0,0.5250760975795091,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
0,0.3402758177398864,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
2,0.7472586091050486,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
2,0.7415264483782941,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
2,0.7862206729325794,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
2,0.5819851144454348,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
2,0.7683978060346004,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
1,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
1,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
1,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
1,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
1,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
```

## Quick-start `train-gmm-sweep.sh`

```bash
$ ./train-gmm-sweep.sh
level: 20, name: __main__, handlers: []
level: 20, name: root, handlers: [ (NOTSET)>]
[2021-07-03 22:56:04] [INFO] __main__ Entrypoint script that uses argparse to digest hyperparameters.
[2021-07-03 22:56:04] [INFO] __main__ Create model & output dirs prior to underlying function.
[2021-07-03 22:56:04] [INFO] __main__ cfg: {'model_dir': PosixPath('/tmp/gmm/model'), 'output_data_dir': PosixPath('/tmp/gmm/output/data/algo-1'), 'train': PosixPath('refdata'), 'test': PosixPath('data/test'), 'validation': PosixPath('data/validation'), 'algo': 'sklearn.mixture.GaussianMixture', 'sweep': 1, 'sweep_start': 2, 'sweep_end': 4}
[2021-07-03 22:56:04] [INFO] __main__ hyperparams: []
[2021-07-03 22:56:04] [INFO] __main__ Estimator class:
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=2)
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=3)
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=4)

$ tree /tmp/gmm
/tmp/gmm
├── model
│   ├── model-2.joblib
│   ├── model-3.joblib
│   └── model-4.joblib
└── output
    └── data
        └── algo-1
            ├── labels.csv
            └── metrics.csv

4 directories, 5 files

$ cat /tmp/gmm/output/data/algo-1/metrics.csv
n_clusters,calinski_harabasz_score,davies_bouldin_score,silhouette_score,aic,bic
2,33.009505329541774,0.3903993523710882,0.5640846736711732,300.61849882644674,329.4947347595125
3,81.56261314942469,0.4026189432995613,0.657247586099344,288.56169639342045,332.37391642979605
4,211.27821263698516,0.3380563293030403,0.7335492300150311,236.1644615787536,294.91266571843903

$ cat /tmp/gmm/output/data/algo-1/labels.csv
n_clusters,cluster_id,silhouette,id,f1,f2,f3,f4
2,0,0.6372392619721126,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
2,0,0.6720210408622262,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
2,0,0.6492518981485068,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
2,0,0.6668478725509681,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
2,0,0.6514868184142037,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
2,0,0.6552200653202191,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
2,0,0.6652296847658677,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
2,0,0.6430149121725359,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
2,0,0.6770137463150129,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
2,0,0.6421282168070243,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
2,0,0.18556425138319263,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
2,0,0.024439822501673675,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
2,0,0.17731900631890743,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
2,0,-0.0687759304942684,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
2,0,-0.023634512312799446,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
2,1,0.9064400891190797,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
2,1,0.8852167003364897,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
2,1,0.9172096605398171,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
2,1,0.8759419759364964,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
2,1,0.8425188927661992,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
3,0,0.6570752828694499,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
3,0,0.6992985659368288,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
3,0,0.6764819315883501,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
3,0,0.7040898704358034,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
3,0,0.6866471227311532,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
3,0,0.4424597686744442,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
3,0,0.45573721799988176,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
3,0,0.3497685977411139,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
3,0,0.5250760975795091,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
3,0,0.3402758177398864,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
3,1,0.7472586091050486,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
3,1,0.7415264483782941,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
3,1,0.7862206729325794,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
3,1,0.5819851144454348,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
3,1,0.7683978060346004,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
3,2,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
3,2,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
3,2,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
3,2,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
3,2,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
4,2,0.7053059195256683,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
4,2,0.6387329267811005,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
4,2,0.7357478702125237,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
4,2,0.7749771396966622,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
4,2,0.7971016972687374,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
4,0,0.6505926239252693,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
4,0,0.7582508392682082,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
4,0,0.7412670741156423,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
4,0,0.7353208973477902,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
4,0,0.7429806057864378,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
4,3,0.6506048215491645,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
4,3,0.7155128458825588,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
4,3,0.7081763163851525,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
4,3,0.5739241442826469,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
4,3,0.7598360804785561,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
4,1,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
4,1,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
4,1,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
4,1,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
4,1,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
```
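Because the sweep's `metrics.csv` holds one row per `n_clusters`, it can be used to choose a cluster count. The sketch below shows one way to do that with only the standard library. It is not part of `train.py`: the function name `best_n_per_metric` and the `HIGHER_IS_BETTER` table are illustrative assumptions. The directions follow the usual reading of each metric: higher is better for Calinski-Harabasz and silhouette, and lower is better for Davies-Bouldin, AIC, and BIC. The embedded CSV is the GMM `metrics.csv` shown above, copied verbatim. Because the sketch reads the `n_clusters` column, it applies to sweep output. The single-run `metrics.csv` from `train-kmeans.sh` has no such column.

```python
import csv
import io

# Contents of /tmp/gmm/output/data/algo-1/metrics.csv from the GMM sweep above.
METRICS_CSV = """n_clusters,calinski_harabasz_score,davies_bouldin_score,silhouette_score,aic,bic
2,33.009505329541774,0.3903993523710882,0.5640846736711732,300.61849882644674,329.4947347595125
3,81.56261314942469,0.4026189432995613,0.657247586099344,288.56169639342045,332.37391642979605
4,211.27821263698516,0.3380563293030403,0.7335492300150311,236.1644615787536,294.91266571843903
"""

# For each metric: True if a larger value is better, False if smaller is better.
HIGHER_IS_BETTER = {
    "calinski_harabasz_score": True,
    "davies_bouldin_score": False,
    "silhouette_score": True,
    "aic": False,
    "bic": False,
}


def best_n_per_metric(csv_text):
    """Map each metric to the n_clusters value that scores best on it."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = {}
    for metric, higher in HIGHER_IS_BETTER.items():
        # Skip blank cells, e.g. the empty aic/bic columns produced for KMeans.
        scored = [r for r in rows if r[metric] != ""]
        if not scored:
            continue
        pick = max if higher else min
        winner = pick(scored, key=lambda r: float(r[metric]))
        best[metric] = int(winner["n_clusters"])
    return best


print(best_n_per_metric(METRICS_CSV))
# prints {'calinski_harabasz_score': 4, 'davies_bouldin_score': 4, 'silhouette_score': 4, 'aic': 4, 'bic': 4}
```

On this sweep, all five metrics agree on four clusters. This matches the four id groups (000 to 004, 005 to 009, 010 to 014, 015 to 019) in the `n_clusters=4` rows of `labels.csv`. When the metrics disagree, reporting each metric's pick separately leaves the final choice visible rather than hiding it behind a single combined score.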