{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Develop your model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "from six.moves import xrange # pylint: disable=redefined-builtin\n", "import tensorflow as tf\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "from tensorflow.examples.tutorials.mnist import mnist\n", "\n", "\n", "INPUT_DATA_DIR = '/tmp/tensorflow/mnist/input_data/'\n", "MAX_STEPS = 1000\n", "BATCH_SIZE = 100\n", "LEARNING_RATE = 0.3\n", "HIDDEN_1 = 128\n", "HIDDEN_2 = 32\n", "\n", "# HACK: Ideally we would want to have a unique subpath for each instance of the job, but since we can't\n", "# we are instead appending HOSTNAME to the logdir\n", "LOG_DIR = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),\n", " 'tensorflow/mnist/logs/fully_connected_feed/', os.getenv('HOSTNAME', ''))\n", "\n", "class TensorflowModel():\n", " def train(self, **kwargs):\n", " tf.logging.set_verbosity(tf.logging.ERROR)\n", " self.data_sets = input_data.read_data_sets(INPUT_DATA_DIR)\n", " self.images_placeholder = tf.placeholder(\n", " tf.float32, shape=(BATCH_SIZE, mnist.IMAGE_PIXELS))\n", " self.labels_placeholder = tf.placeholder(tf.int32, shape=(BATCH_SIZE))\n", "\n", " logits = mnist.inference(self.images_placeholder,\n", " HIDDEN_1,\n", " HIDDEN_2)\n", "\n", " self.loss = mnist.loss(logits, self.labels_placeholder)\n", " self.train_op = mnist.training(self.loss, LEARNING_RATE)\n", " self.summary = tf.summary.merge_all()\n", " init = tf.global_variables_initializer()\n", " self.sess = tf.Session()\n", " self.summary_writer = tf.summary.FileWriter(LOG_DIR, self.sess.graph)\n", " self.sess.run(init)\n", "\n", " data_set = self.data_sets.train\n", " for step in xrange(MAX_STEPS):\n", " images_feed, labels_feed = data_set.next_batch(BATCH_SIZE, False)\n", " feed_dict = {\n", " self.images_placeholder: images_feed,\n", " self.labels_placeholder: labels_feed,\n", " }\n", "\n", " _, loss_value = self.sess.run([self.train_op, self.loss],\n", " feed_dict=feed_dict)\n", " if step % 100 == 0:\n", " print(\"At step {}, loss = {}\".format(step, loss_value))\n", " summary_str = self.sess.run(self.summary, feed_dict=feed_dict)\n", " self.summary_writer.add_summary(summary_str, step)\n", " self.summary_writer.flush()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train your model inside notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "local_train = TensorflowModel()\n", "local_train.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Remote train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `__file__` won't work inside notebook, please don't execute follow block. Instead, we create a python file `distributed_training.py` to run.\n", "\n", "\n", "```py\n", "if __name__ == '__main__':\n", " if os.getenv('FAIRING_RUNTIME', None) is None:\n", " from kubeflow import fairing\n", " AWS_ACCOUNT_ID=fairing.cloud.aws.guess_account_id()\n", " AWS_REGION='us-west-2'\n", " DOCKER_REGISTRY = '{}.dkr.ecr.{}.amazonaws.com'.format(AWS_ACCOUNT_ID, AWS_REGION)\n", "\n", " fairing.config.set_preprocessor('python', input_files=[__file__])\n", " fairing.config.set_builder(name='append', registry=DOCKER_REGISTRY,\n", " base_image='tensorflow/tensorflow:1.14.0-py3')\n", " fairing.config.set_deployer(\n", " name='tfjob', worker_count=1, ps_count=1)\n", " fairing.config.run()\n", " else:\n", " remote_train = TensorflowModel()\n", " remote_train.train()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run this comamnd to create a Tensoflow Flow Job (TFJob)\n", "\n", "You can check file `distributed_traing.py` to see details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python ./distributed_training.py" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }