Sample meta training script for
sklearn clustering algorithms
# Overview of this example
This is an example of a SageMaker training meta-script: a single script that
can call different sklearn clustering algorithms.
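One way a single script could dispatch to different estimators is to accept the estimator class as a dotted import path and forward any unrecognized command-line flags as constructor kwargs. The sketch below illustrates that idea only; it is not the actual `train.py`. The names `load_class`, `parse_kwargs`, and `build_estimator` are hypothetical, and the `--algo` flag is an assumption based on the `algo` key that appears in the logged configuration of the quick-start runs.

```python
import argparse
import ast
import importlib


def load_class(dotted_path):
    """Import 'package.module.ClassName' and return the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def parse_kwargs(extra_args):
    """Turn ['--n_clusters', '3'] into {'n_clusters': 3}."""
    kwargs = {}
    for flag, raw in zip(extra_args[::2], extra_args[1::2]):
        try:
            value = ast.literal_eval(raw)  # '3' -> 3, '0.5' -> 0.5
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as 'k-means++'
        kwargs[flag.lstrip("-")] = value
    return kwargs


def build_estimator(argv):
    """Instantiate the class named by --algo (assumed flag) with leftover flags."""
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--algo", default="sklearn.cluster.KMeans")
    args, extra = parser.parse_known_args(argv)
    return load_class(args.algo)(**parse_kwargs(extra))
```

With scikit-learn installed, `build_estimator(["--algo", "sklearn.cluster.KMeans", "--n_clusters", "3"])` would return `KMeans(n_clusters=3)`.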
The `train.py` script supports two modes:
1. **Run a clustering algorithm with a specific number of components.**
The `train-kmeans.sh` quick-start example shows how to use this script to
call `sklearn.cluster.KMeans` with the hyperparameter (i.e., kwarg)
`n_clusters`. Review `train-kmeans.sh` and feel free to modify it with
additional hyperparameters (i.e., kwargs) supported by
`sklearn.cluster.KMeans`.
As a further exercise, try modifying this shell script so that it instructs
`train.py` to run GMM using `sklearn.mixture.GaussianMixture`. Tip: recall
that sklearn's GMM uses the `n_components` kwarg to denote the number of
clusters.
2. **Run a clustering algorithm multiple times, each time with a different
number of clusters.**
See the `train-*sweep.sh` quick-start examples, which demonstrate kmeans,
gmm, and agglomerative clustering. At the end of training, `train.py` also
computes several clustering metrics and writes them to .csv files.
Note that `train.py` requires the estimator to accept either an
`n_clusters` kwarg or an `n_components` kwarg.
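The requirement above suggests how a sweep could be implemented: inspect the estimator's constructor to find which of the two kwargs it accepts, then build one estimator per cluster count. The following is a minimal sketch under that assumption, not the code in `train.py`; `cluster_count_kwarg` and `sweep` are hypothetical names, and the inclusive end point is inferred from the quick-start logs, where `sweep_start` 2 and `sweep_end` 4 produce models for 2, 3, and 4 components.

```python
import inspect


def cluster_count_kwarg(estimator_cls):
    """Return the constructor kwarg that controls the number of clusters."""
    params = inspect.signature(estimator_cls).parameters
    for name in ("n_clusters", "n_components"):
        if name in params:
            return name
    raise ValueError(
        f"{estimator_cls.__name__} accepts neither n_clusters nor n_components"
    )


def sweep(estimator_cls, start, end, **kwargs):
    """Build one unfitted estimator per cluster count in [start, end]."""
    name = cluster_count_kwarg(estimator_cls)
    # The swept value overrides any value for the same kwarg passed by the caller.
    return [estimator_cls(**{**kwargs, name: k}) for k in range(start, end + 1)]
```

For example, with scikit-learn installed, `sweep(GaussianMixture, 2, 4)` would yield three `GaussianMixture` instances with `n_components` set to 2, 3, and 4.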
# Final note on the quick-start examples (i.e., `*.sh`)
These are provided so that you can quickly run `train.py` directly in your own
Python virtual environment. As you review the `*.sh` files, you'll notice that
they load sample input data and save the fitted estimators and evaluation
metrics to `/tmp`.
To run `train.py` on SageMaker containers (either in SageMaker *local* mode
or as SageMaker training jobs), please refer to the SageMaker SDK documentation
and examples.
# Sample invocation directly on a Python virtual environment
This section assumes you have created a Python virtual environment and installed
scikit-learn in it.
Activate your Python virtual environment, then run the quick-start scripts.
Once again, to run `train.py` on a SageMaker container (whether in SageMaker
local mode or as a SageMaker training job), please refer to the SageMaker
SDK documentation and examples.
## Quick-start `train-kmeans.sh`
```bash
$ ./train-kmeans.sh
level: 20, name: __main__, handlers: []
level: 20, name: root, handlers: [ (NOTSET)>]
[2021-07-03 22:49:14] [INFO] __main__ Entrypoint script that uses argparse to digest hyperparameters.
[2021-07-03 22:49:14] [INFO] __main__ Create model & output dirs prior to underlying function.
[2021-07-03 22:49:14] [INFO] __main__ cfg: {'model_dir': PosixPath('/tmp/kmeans/model'), 'output_data_dir': PosixPath('/tmp/kmeans/output/data/algo-1'), 'train': PosixPath('refdata'), 'test': PosixPath('data/test'), 'validation': PosixPath('data/validation'), 'algo': 'sklearn.cluster.KMeans', 'sweep': 0, 'sweep_start': 2, 'sweep_end': 4}
[2021-07-03 22:49:14] [INFO] __main__ hyperparams: ['--n_clusters', '3']
[2021-07-03 22:49:14] [INFO] __main__ Estimator class:
[2021-07-03 22:49:14] [INFO] __main__ estimator: KMeans(n_clusters=3)
$ tree /tmp/kmeans
/tmp/kmeans
├── model
│   └── model.joblib
└── output
    └── data
        └── algo-1
            ├── labels.csv
            └── metrics.csv

4 directories, 3 files
$ cat /tmp/kmeans/output/data/algo-1/metrics.csv
calinski_harabasz_score,davies_bouldin_score,silhouette_score,aic,bic
81.56261314942466,0.4026189432995613,0.657247586099344,,
$ cat /tmp/kmeans/output/data/algo-1/labels.csv
cluster_id,silhouette,id,f1,f2,f3,f4
0,0.6570752828694499,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
0,0.6992985659368288,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
0,0.6764819315883501,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
0,0.7040898704358034,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
0,0.6866471227311532,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
0,0.4424597686744442,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
0,0.45573721799988176,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
0,0.3497685977411139,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
0,0.5250760975795091,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
0,0.3402758177398864,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
2,0.7472586091050486,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
2,0.7415264483782941,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
2,0.7862206729325794,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
2,0.5819851144454348,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
2,0.7683978060346004,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
1,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
1,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
1,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
1,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
1,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
```
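In the `metrics.csv` above, the `aic` and `bic` columns are empty. A plausible explanation is that `KMeans` has no `aic()` or `bic()` method, whereas `GaussianMixture` does. The sketch below shows one way such a metrics row could be computed with scikit-learn; it is an illustration under that assumption, not the code in `train.py`, and `clustering_metrics` is a hypothetical helper name.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_score,
)
from sklearn.mixture import GaussianMixture


def clustering_metrics(estimator, X):
    """Fit `estimator` on X and return one row of metrics as a dict."""
    labels = estimator.fit_predict(X)
    row = {
        "calinski_harabasz_score": calinski_harabasz_score(X, labels),
        "davies_bouldin_score": davies_bouldin_score(X, labels),
        "silhouette_score": silhouette_score(X, labels),
    }
    # Only some estimators (e.g., GaussianMixture) expose aic() and bic();
    # for the others the value is left as None (an empty CSV cell).
    for name in ("aic", "bic"):
        method = getattr(estimator, name, None)
        row[name] = method(X) if callable(method) else None
    return row


if __name__ == "__main__":
    # Synthetic data for illustration only: two well-separated blobs.
    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal(loc=c, size=(20, 4)) for c in (0, 10)])
    print(clustering_metrics(KMeans(n_clusters=2, n_init=10), X))
    print(clustering_metrics(GaussianMixture(n_components=2), X))
```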
## Quick-start `train-gmm-sweep.sh`
```bash
$ ./train-gmm-sweep.sh
level: 20, name: __main__, handlers: []
level: 20, name: root, handlers: [ (NOTSET)>]
[2021-07-03 22:56:04] [INFO] __main__ Entrypoint script that uses argparse to digest hyperparameters.
[2021-07-03 22:56:04] [INFO] __main__ Create model & output dirs prior to underlying function.
[2021-07-03 22:56:04] [INFO] __main__ cfg: {'model_dir': PosixPath('/tmp/gmm/model'), 'output_data_dir': PosixPath('/tmp/gmm/output/data/algo-1'), 'train': PosixPath('refdata'), 'test': PosixPath('data/test'), 'validation': PosixPath('data/validation'), 'algo': 'sklearn.mixture.GaussianMixture', 'sweep': 1, 'sweep_start': 2, 'sweep_end': 4}
[2021-07-03 22:56:04] [INFO] __main__ hyperparams: []
[2021-07-03 22:56:04] [INFO] __main__ Estimator class:
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=2)
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=3)
[2021-07-03 22:56:04] [INFO] __main__ estimator: GaussianMixture(n_components=4)
$ tree /tmp/gmm
/tmp/gmm
├── model
│   ├── model-2.joblib
│   ├── model-3.joblib
│   └── model-4.joblib
└── output
    └── data
        └── algo-1
            ├── labels.csv
            └── metrics.csv

4 directories, 5 files
$ cat /tmp/gmm/output/data/algo-1/metrics.csv
n_clusters,calinski_harabasz_score,davies_bouldin_score,silhouette_score,aic,bic
2,33.009505329541774,0.3903993523710882,0.5640846736711732,300.61849882644674,329.4947347595125
3,81.56261314942469,0.4026189432995613,0.657247586099344,288.56169639342045,332.37391642979605
4,211.27821263698516,0.3380563293030403,0.7335492300150311,236.1644615787536,294.91266571843903
$ cat /tmp/gmm/output/data/algo-1/labels.csv
n_clusters,cluster_id,silhouette,id,f1,f2,f3,f4
2,0,0.6372392619721126,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
2,0,0.6720210408622262,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
2,0,0.6492518981485068,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
2,0,0.6668478725509681,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
2,0,0.6514868184142037,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
2,0,0.6552200653202191,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
2,0,0.6652296847658677,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
2,0,0.6430149121725359,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
2,0,0.6770137463150129,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
2,0,0.6421282168070243,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
2,0,0.18556425138319263,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
2,0,0.024439822501673675,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
2,0,0.17731900631890743,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
2,0,-0.0687759304942684,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
2,0,-0.023634512312799446,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
2,1,0.9064400891190797,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
2,1,0.8852167003364897,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
2,1,0.9172096605398171,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
2,1,0.8759419759364964,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
2,1,0.8425188927661992,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
3,0,0.6570752828694499,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
3,0,0.6992985659368288,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
3,0,0.6764819315883501,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
3,0,0.7040898704358034,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
3,0,0.6866471227311532,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
3,0,0.4424597686744442,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
3,0,0.45573721799988176,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
3,0,0.3497685977411139,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
3,0,0.5250760975795091,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
3,0,0.3402758177398864,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
3,1,0.7472586091050486,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
3,1,0.7415264483782941,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
3,1,0.7862206729325794,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
3,1,0.5819851144454348,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
3,1,0.7683978060346004,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
3,2,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
3,2,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
3,2,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
3,2,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
3,2,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
4,2,0.7053059195256683,000,-0.7374963321337971,1.0840239414000163,1.6451554963402992,0.7761506975640019
4,2,0.6387329267811005,001,1.955620861696934,0.6019614154917433,2.134264266881833,2.5491831625152246
4,2,0.7357478702125237,002,2.30440629618236,0.1456371317487245,0.9669082309978988,0.329342760390853
4,2,0.7749771396966622,003,1.1547453489536674,0.9708258725908092,1.9356001590052525,1.3351244318155466
4,2,0.7971016972687374,004,1.177464368722279,0.1733676070545393,1.0269543726346029,0.6800324789482652
4,0,0.6505926239252693,005,5.403346854962121,5.615296764169435,4.433699147869191,4.057217331198596
4,0,0.7582508392682082,006,4.583628789281211,4.551442349431361,5.806121322733757,5.195322263844137
4,0,0.7412670741156423,007,6.307907999140955,3.476697309364682,6.064731317309211,5.801523803066705
4,0,0.7353208973477902,008,4.640353525876743,4.292542968233425,4.669704844461966,5.143879113105918
4,0,0.7429806057864378,009,5.753944605987746,3.7082092884236895,5.519315978296398,6.776496281720625
4,3,0.6506048215491645,010,8.256110511377253,10.259658030738004,8.475414141555964,10.093923501450144
4,3,0.7155128458825588,011,8.982919186532929,10.999366565589888,9.779032153138644,9.738187243522107
4,3,0.7081763163851525,012,9.724997769250589,8.642439095042015,9.332783861851016,9.780272307104973
4,3,0.5739241442826469,013,12.929282379380805,7.936651790361941,10.637334096190829,8.929354298627915
4,3,0.7598360804785561,014,10.513461771066424,9.944910874544462,9.47949852211006,10.293440663310554
4,1,0.8321402655386333,015,15.290848516062876,15.709579261886091,14.706704179531869,14.840269010487168
4,1,0.8015344910648762,016,15.86455082303423,14.893539692210656,15.230729389193948,16.491568626929325
4,1,0.8539840201453357,017,15.892095242344787,15.375064672515576,15.210671630287813,15.03413966697942
4,1,0.7749556776901024,018,15.069287808964468,13.843813503244991,15.737106065986024,15.078723399664812
4,1,0.720038343355556,019,16.35702992963304,16.173214788599783,14.81299265761452,12.774061390628177
```
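After a sweep, the `metrics.csv` file can be used to compare cluster counts. The sketch below is one possible way to do that with the standard library; `best_row` is a hypothetical helper, and which metric to prefer is left to you. It skips rows where the chosen column is empty, as happens for `aic` and `bic` with estimators that do not provide them.

```python
import csv
import io


def best_row(csv_text, column, higher_is_better=True):
    """Return the metrics row with the best value in `column`."""
    rows = [
        row
        for row in csv.DictReader(io.StringIO(csv_text))
        if row[column] != ""  # ignore estimators without this metric
    ]
    pick = max if higher_is_better else min
    return pick(rows, key=lambda row: float(row[column]))


if __name__ == "__main__":
    # Path taken from the train-gmm-sweep.sh quick-start output above.
    with open("/tmp/gmm/output/data/algo-1/metrics.csv") as f:
        text = f.read()
    print(best_row(text, "silhouette_score")["n_clusters"])
    print(best_row(text, "bic", higher_is_better=False)["n_clusters"])
```

For the sample output shown above, both the highest `silhouette_score` and the lowest `bic` fall on the row with `n_clusters` equal to 4.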