{ "cells": [ { "cell_type": "code", "execution_count": 5, "id": "afc684c7-a3ef-43ec-b1f4-729c8be719b5", "metadata": { "execution": { "iopub.execute_input": "2022-08-04T22:42:36.971016Z", "iopub.status.busy": "2022-08-04T22:42:36.970782Z", "iopub.status.idle": "2022-08-04T22:42:38.903722Z", "shell.execute_reply": "2022-08-04T22:42:38.903081Z", "shell.execute_reply.started": "2022-08-04T22:42:36.970995Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Defaulting to user installation because normal site-packages is not writeable\n", "Requirement already up-to-date: torch in /home/emr-notebook/.local/lib/python3.7/site-packages (1.12.0)\n", "Requirement already satisfied, skipping upgrade: typing-extensions in /usr/local/lib/python3.7/site-packages (from torch) (4.3.0)\n" ] } ], "source": [ "!pip3 install -U torch" ] }, { "cell_type": "code", "execution_count": 6, "id": "4784e62c-6205-47c9-929c-22b0e4522b40", "metadata": { "execution": { "iopub.execute_input": "2022-08-04T22:42:40.136785Z", "iopub.status.busy": "2022-08-04T22:42:40.136537Z", "iopub.status.idle": "2022-08-04T22:42:40.146559Z", "shell.execute_reply": "2022-08-04T22:42:40.145993Z", "shell.execute_reply.started": "2022-08-04T22:42:40.136757Z" } }, "outputs": [ { "data": { "text/plain": [ "{'status': 'ok', 'restart': True}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# restart kernel to pick up the pip installs above\n", "import IPython\n", "\n", "IPython.Application.instance().kernel.do_shutdown(True) #automatically restarts kernel" ] }, { "cell_type": "code", "execution_count": 26, "id": "1d8da87e-188a-4112-8ff6-4d3c8a3d0abb", "metadata": { "execution": { "iopub.execute_input": "2022-08-06T01:55:15.487357Z", "iopub.status.busy": "2022-08-06T01:55:15.487121Z", "iopub.status.idle": "2022-08-06T01:55:16.288821Z", "shell.execute_reply": "2022-08-06T01:55:16.288145Z", "shell.execute_reply.started": "2022-08-06T01:55:15.487329Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/.is_checkpoint to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/.is_checkpoint\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/events.out.tfevents.1659750574.ip-172-31-20-207 to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/events.out.tfevents.1659750574.ip-172-31-20-207\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/.tune_metadata to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/.tune_metadata\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/basic-variant-state-2022-08-06_01-49-34.json to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/basic-variant-state-2022-08-06_01-49-34.json\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/params.json to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/params.json\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/result.json to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/result.json\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/experiment_state-2022-08-06_01-49-34.json to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/experiment_state-2022-08-06_01-49-34.json\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/trainable.pkl to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/trainable.pkl\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/params.pkl to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/params.pkl\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/tuner.pkl to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/tuner.pkl\n", "download: s3://dsoaws/ray_output/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/dict_checkpoint.pkl to home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000/dict_checkpoint.pkl\n" ] } ], "source": [ "!aws s3 cp --recursive s3://dsoaws/ray_output/ ~/" ] }, { "cell_type": "code", "execution_count": 27, "id": "75cdc301-a4b4-4d79-955d-a4fa7ba905e5", "metadata": { "execution": { "iopub.execute_input": "2022-08-06T01:55:41.379551Z", "iopub.status.busy": "2022-08-06T01:55:41.379300Z", "iopub.status.idle": "2022-08-06T01:55:41.382690Z", "shell.execute_reply": "2022-08-06T01:55:41.382119Z", "shell.execute_reply.started": "2022-08-06T01:55:41.379520Z" }, "tags": [] }, "outputs": [], "source": [ "model_path = '/home/emr-notebook/TorchTrainer_2022-08-06_01-49-34/TorchTrainer_05a2a_00000_0_2022-08-06_01-49-34/checkpoint_000000'" ] }, { "cell_type": "code", "execution_count": 28, "id": "54976207-c0a4-405f-b400-4d79236197e7", "metadata": { "execution": { "iopub.execute_input": "2022-08-06T01:55:42.303386Z", "iopub.status.busy": "2022-08-06T01:55:42.303148Z", "iopub.status.idle": "2022-08-06T01:55:42.308335Z", "shell.execute_reply": "2022-08-06T01:55:42.307736Z", "shell.execute_reply.started": "2022-08-06T01:55:42.303363Z" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "# Define model\n", "class NeuralNetwork(nn.Module):\n", " def __init__(self):\n", " super(NeuralNetwork, self).__init__()\n", " self.flatten = nn.Flatten()\n", " self.linear_relu_stack = nn.Sequential(\n", " nn.Linear(28 * 28, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 10),\n", " nn.ReLU(),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.flatten(x)\n", " logits = self.linear_relu_stack(x)\n", " return logits\n" ] }, { "cell_type": "code", "execution_count": 29, "id": "2edb3c11-4a7a-4a39-9aa5-335db76a016a", "metadata": { "execution": { "iopub.execute_input": "2022-08-06T01:55:44.803543Z", "iopub.status.busy": "2022-08-06T01:55:44.803310Z", "iopub.status.idle": "2022-08-06T01:55:44.815244Z", "shell.execute_reply": "2022-08-06T01:55:44.814618Z", "shell.execute_reply.started": "2022-08-06T01:55:44.803518Z" }, "tags": [] }, "outputs": [], "source": [ "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "\n", "# Download test data from open datasets.\n", "test_data = datasets.FashionMNIST(\n", " root=\"~/data\",\n", " train=False,\n", " download=True,\n", " transform=ToTensor(),\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "id": "d6c5dab6-1fb4-42a9-b003-9fe7898806c9", "metadata": { "execution": { "iopub.execute_input": "2022-08-06T01:55:47.050798Z", "iopub.status.busy": "2022-08-06T01:55:47.050561Z", "iopub.status.idle": "2022-08-06T01:55:47.066303Z", "shell.execute_reply": "2022-08-06T01:55:47.065725Z", "shell.execute_reply.started": "2022-08-06T01:55:47.050774Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "import torch\n", "from torch import nn\n", "\n", "from ray.train.torch import TorchCheckpoint\n", "\n", "def predict_from_model(model):\n", " classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", " ]\n", "\n", " model.eval()\n", " x, y = test_data[0][0], test_data[0][1]\n", " with torch.no_grad():\n", " pred = model(x)\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')\n", " \n", "model_loaded = TorchCheckpoint(local_path=model_path).get_model(NeuralNetwork())\n", "\n", "predict_from_model(model_loaded)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }