{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training and Deploying ML Models using JAX on SageMaker\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Amazon SageMaker provides you the flexibility to train models using our pre-built machine learning containers or your own bespoke container. We'll refer to these strategies as Bring-Your-Own-Script **(BYOS)** and Bring-Your-Own-Container **(BYOC)** in this tutorial. \n", "\n", "### Bring Your Own JAX Script\n", "\n", "In this notebook, we'll show how to extend our optimized TensorFlow containers to train machine learning models using the increasingly popular [JAX library](https://github.com/google/jax). We'll train a fashion MNIST classification model using vanilla JAX, another using `jax.experimental.stax`, and a final model using the [higher level Trax library](https://github.com/google/trax).\n", "\n", "For all three patterns, we'll show how the JAX models can be serialized as standard TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model). This enables us to seamlessly deploy the models using the managed and optimized SageMaker TensorFlow inference containers.\n", "\n", "\n", "### Bring Your Own JAX Container\n", "\n", "We've included a dockerfile in this repo directory to show how you can build your own bespoke JAX container with support for GPUs on SageMaker. Unfortunately, the NVIDIA/CUDA Dockerhub containers have a [deletion policy](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md), so we're unable to assert that the container can be built through time. Nonetheless, you can trivially adapt a newer version of the container if your workload requires a custom container. For more information on running BYOC on SageMaker see the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-training-container.html).\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install --upgrade sagemaker" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker.tensorflow import TensorFlow\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installing JAX in SageMaker TensorFlow Containers\n", "\n", "When using BYOS with managed SageMaker containers, you can trivially install extra dependencies by providing a `requirements.txt` within the `source_dir` that contains your training scripts. At runtime these dependencies will be installed prior to executing the training script, so we can utilize our optimized TensorFlow GPU container to utilize JAX with CUDA support.\n", "\n", "To be specific, any container that has the [sagemaker-training-toolkit](https://github.com/aws/sagemaker-training-toolkit) supports installing additional dependencies from `requirements.txt`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Serializing models as SavedModel format\n", "In the upcoming training jobs we'll be training a vanilla JAX model, a Stax model, and a Trax model on the [fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist).\n", "The full details of the model can be seen in the `training_scripts/` directory, but it is worth calling out the methods for serialization.\n", "\n", "The JAX/Stax models utilize the new jax2tf converter: https://github.com/google/jax/tree/master/jax/experimental/jax2tf\n", "\n", "```python\n", "def save_model_tf(prediction_function, params_to_save):\n", " tf_fun = jax2tf.convert(prediction_function, enable_xla=False)\n", " param_vars = tf.nest.map_structure(lambda param: tf.Variable(param), params_to_save)\n", "\n", " tf_graph = tf.function(\n", " lambda inputs: tf_fun(param_vars, inputs),\n", " autograph=False,\n", " jit_compile=False,\n", " )\n", "\n", "```\n", "\n", "\n", "The Trax model utilizes the new trax2keras functionality: https://github.com/google/trax/blob/master/trax/trax2keras.py\n", "\n", "```python\n", "def save_model_tf(model_to_save):\n", " \"\"\"\n", " Serialize a TensorFlow graph from trained Trax Model\n", " :param model_to_save: Trax Model\n", " \"\"\"\n", " keras_layer = trax.AsKeras(model_to_save, batch_size=1)\n", " inputs = tf.keras.Input(shape=(28, 28, 1))\n", " hidden = keras_layer(inputs)\n", "\n", " keras_model = tf.keras.Model(inputs=inputs, outputs=hidden)\n", " keras_model.save(\"/opt/ml/model/1\", save_format=\"tf\")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train using Vanilla JAX\n", "\n", "Note: Our `source_dir` directory contains a `requirements.txt` that will install JAX with CUDA support" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vanilla_jax_estimator = TensorFlow(\n", " role=role,\n", " instance_count=1,\n", " base_job_name=\"jax\",\n", " framework_version=\"2.10\",\n", " py_version=\"py39\",\n", " source_dir=\"training_scripts\",\n", " entry_point=\"train_jax.py\",\n", " instance_type=\"ml.p3.2xlarge\",\n", " hyperparameters={\"num_epochs\": 3},\n", ")\n", "vanilla_jax_estimator.fit(logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Using JAX Medium-level API Stax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "stax_estimator = TensorFlow(\n", " role=role,\n", " instance_count=1,\n", " base_job_name=\"stax\",\n", " framework_version=\"2.10\",\n", " py_version=\"py39\",\n", " source_dir=\"training_scripts\",\n", " entry_point=\"train_stax.py\",\n", " instance_type=\"ml.p3.2xlarge\",\n", " hyperparameters={\"num_epochs\": 3},\n", ")\n", "\n", "stax_estimator.fit(logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Using JAX High-level API Trax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trax_estimator = TensorFlow(\n", " role=role,\n", " instance_count=1,\n", " base_job_name=\"trax\",\n", " framework_version=\"2.10\",\n", " py_version=\"py39\",\n", " source_dir=\"training_scripts\",\n", " entry_point=\"train_trax.py\",\n", " instance_type=\"ml.p3.2xlarge\",\n", " hyperparameters={\"train_steps\": 1000},\n", ")\n", "\n", "\n", "trax_estimator.fit(logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy Models to managed TF Containers\n", "Since we've serialized the models as TensorFlow SavedModel format, deploying these models as endpoints is just a trivial call to the `estimator.deploy()` method" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vanilla_jax_predictor = vanilla_jax_estimator.deploy(\n", " initial_instance_count=1, instance_type=\"ml.m4.xlarge\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trax_predictor = trax_estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "stax_predictor = stax_estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test Inference Endpoints\n", "This requires TF to be installed on your notebook's kernel as it is used to load testing data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_image(predictor, test_images, test_labels, image_number):\n", " np_img = np.expand_dims(np.expand_dims(test_images[image_number], axis=-1), axis=0)\n", "\n", " result = predictor.predict(np_img)\n", " pred_y = np.argmax(result[\"predictions\"])\n", "\n", " print(\"True Label:\", test_labels[image_number])\n", " print(\"Predicted Label:\", pred_y)\n", " plt.imshow(test_images[image_number])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_image(vanilla_jax_predictor, x_test, y_test, 0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_image(stax_predictor, x_test, y_test, 0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_image(trax_predictor, x_test, y_test, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optional: Delete the running endpoints" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Clean-Up\n", "vanilla_jax_predictor.delete_endpoint()\n", "stax_predictor.delete_endpoint()\n", "trax_predictor.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/advanced_functionality|jax_bring_your_own|train_deploy_jax.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "conda_tensorflow2_p38", "language": "python", "name": "conda_tensorflow2_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 4 }