#!/bin/bash
#
# Smoke test: train the MXNet MNIST image-classification example and verify
# that the final validation accuracy reaches a minimum threshold.
#
# Usage:   <this script> [PYTHON_BIN]
#   PYTHON_BIN  Python interpreter used to run train_mnist.py
#               (defaults to "python" when omitted or empty).
#
# Files written:
#   ${TRAINING_LOG}    stderr of the training run plus status messages
#   ${VALIDATION_LOG}  the last reported validation accuracy (empty if none)
#
# Exit status: 0 when training succeeds and the accuracy is at least
#              ${MIN_ACCURACY}; 1 when training fails, no accuracy is
#              reported, or the accuracy is below the threshold.

HOME_DIR=/test
BIN_DIR=${HOME_DIR}/bin
LOG_DIR=${HOME_DIR}/logs
MXNETDIR=${HOME_DIR}/artifacts/mxnet
TRAINING_LOG=${LOG_DIR}/mxnet_train_mnist.log
VALIDATION_LOG=${LOG_DIR}/mxnet_validate_mnist.log
TRAIN_SCRIPT=./example/image-classification/train_mnist.py
MIN_ACCURACY=0.97

set -e

cd "${MXNETDIR}"

# Export so the Python child process actually sees the MXNet libraries, and
# only add the ':' separator when there is an existing value: a trailing
# empty entry would put the current directory on the loader search path.
export LD_LIBRARY_PATH="${MXNETDIR}/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

echo "Training mnist using MXNet... This may take a few minutes. You can follow progress on the log file : $TRAINING_LOG" | tee -a "${TRAINING_LOG}"

# Train on GPU 0 when nvidia-smi runs successfully, otherwise on CPU.
# Testing the command inside `if` keeps a failure from tripping `set -e`.
GPU_ARGS=()
if nvidia-smi; then
  GPU_ARGS=(--gpus 0)
fi

# ======================================================================
# Temporary fix for 503 error while downloading MNIST dataset.
# See https://github.com/pytorch/vision/issues/3549
# '|' is used as the sed delimiter so the URLs need no slash escaping.
sed -i \
  -e 's|http://yann\.lecun\.com/exdb/mnist/|https://dlinfra-mnist-dataset.s3-us-west-2.amazonaws.com/mnist/|g' \
  "${TRAIN_SCRIPT}"
# ======================================================================

PYTHON_BIN=${1:-python}

# `2>` deliberately truncates the log so that only this run's
# Validation-accuracy lines are present when the log is parsed below.
RETURN_VAL=0
"${PYTHON_BIN}" "${TRAIN_SCRIPT}" "${GPU_ARGS[@]}" 2> "${TRAINING_LOG}" || RETURN_VAL=$?

if [ "${RETURN_VAL}" -ne 0 ]; then
  echo "Training mnist Failed using MXNet. LOG:"
  cat "${TRAINING_LOG}"
  exit 1
fi

echo "Training mnist Complete using MXNet." | tee -a "${TRAINING_LOG}"

# Take the value from the last "Validation-accuracy=<value>" line. The
# pipeline's status is tail's, so "no match" yields an empty string instead
# of aborting under `set -e`.
ACCURACY=$(sed -n 's/.*Validation-accuracy=\([^[:space:]]*\).*/\1/p' "${TRAINING_LOG}" | tail -n 1)

if [ -z "${ACCURACY}" ]; then
  # Truncate so a value left over from an earlier run is never mistaken
  # for this run's result.
  : > "${VALIDATION_LOG}"
  echo "Failed Validation Accuracy using MXNet: no Validation-accuracy found in ${TRAINING_LOG}"
  exit 1
fi

printf '%s\n' "${ACCURACY}" > "${VALIDATION_LOG}"

# Compare numerically. `[[ a < b ]]` would compare the strings
# lexicographically in the current locale, which is wrong for values such
# as "9.8e-01". A non-numeric value evaluates to 0 in awk and is rejected.
if LC_ALL=C awk -v acc="${ACCURACY}" -v min="${MIN_ACCURACY}" \
    'BEGIN { exit !((acc + 0) < (min + 0)) }'; then
  echo "Failed Validation Accuracy using MXNet: ${ACCURACY}"
  exit 1
fi

# Report the accuracy as a whole-number percentage, truncated (not rounded).
ACCURACY=$(LC_ALL=C awk -v acc="${ACCURACY}" 'BEGIN { printf "%d\n", acc * 100 }')
echo "Successful Validation Accuracy using MXNet: $ACCURACY"

# Delete the downloaded mnist database; a cleanup failure must not fail the test.
rm -rf -- mnist/ || true

exit 0