{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1f0c3d2-intro-0001",
   "metadata": {},
   "source": [
    "# Serve a RoBERTa text classifier with Ray Serve\n",
    "\n",
    "This notebook:\n",
    "\n",
    "1. installs its dependencies and restarts the kernel,\n",
    "2. defines a Ray Serve deployment that downloads a `TorchTrainer` checkpoint from S3 and loads it into a `roberta-base` sequence classifier with three classes (Negative / Neutral / Positive),\n",
    "3. connects to a Ray cluster and deploys the model under `/predict`,\n",
    "4. sends a test request to the deployed endpoint.\n",
    "\n",
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afc684c7-a3ef-43ec-b1f4-729c8be719b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install torch transformers pandas scikit-learn mlflow tensorboard s3fs \"ray[all]==2.0.0rc0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4784e62c-6205-47c9-929c-22b0e4522b40",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Restart the kernel so the packages installed above are picked up.\n",
    "# After the restart, continue from the next cell.\n",
    "import IPython\n",
    "\n",
    "IPython.Application.instance().kernel.do_shutdown(True)  # True => restart the kernel automatically"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f0c3d2-deploy-0002",
   "metadata": {},
   "source": [
    "## Define the deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c154fead-7a6d-4546-b6a3-fc6f79e51ab0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import ray\n",
    "import requests\n",
    "import s3fs\n",
    "import torch\n",
    "from ray import serve\n",
    "from ray.train.torch import TorchCheckpoint, TorchPredictor\n",
    "from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "# Checkpoint to serve, and the base model architecture its weights are loaded into.\n",
    "CHECKPOINT_S3_URI = \"s3://dsoaws/ray_output/TorchTrainer_2022-08-10_05-31-49/TorchTrainer_bba46_00000_0_2022-08-10_05-31-51/checkpoint_000000/\"\n",
    "BASE_MODEL_NAME = \"roberta-base\"\n",
    "MAX_SEQ_LENGTH = 64  # every input is padded/truncated to this many tokens\n",
    "\n",
    "\n",
    "@serve.deployment(route_prefix=\"/predict\", version=\"0.1.0\")\n",
    "class Predictor:\n",
    "    \"\"\"Ray Serve deployment that maps the `txt` query parameter to a class name.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Build the tokenizer and model, then load the checkpoint weights from S3.\"\"\"\n",
    "        self.classes = [\n",
    "            \"Negative\",\n",
    "            \"Neutral\",\n",
    "            \"Positive\",\n",
    "        ]\n",
    "\n",
    "        # Derive the label count from the class list so the two cannot drift apart.\n",
    "        self.config = AutoConfig.from_pretrained(\n",
    "            BASE_MODEL_NAME,\n",
    "            num_labels=len(self.classes),\n",
    "        )\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, use_fast=True)\n",
    "        self.base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "            BASE_MODEL_NAME,\n",
    "            config=self.config,\n",
    "        )\n",
    "\n",
    "        # The checkpoint files are only needed while they are being loaded, so download\n",
    "        # them into a temporary directory that is removed afterwards instead of leaving\n",
    "        # a new directory behind every time a replica starts.\n",
    "        with tempfile.TemporaryDirectory() as model_path:\n",
    "            s3fs.S3FileSystem().get(CHECKPOINT_S3_URI, model_path, recursive=True)\n",
    "            self.model = TorchCheckpoint(local_path=model_path).get_model(self.base_model)\n",
    "\n",
    "        # The replica only does inference: switch to eval mode once here rather than\n",
    "        # on every request.\n",
    "        self.model.eval()\n",
    "\n",
    "    def __call__(self, request):\n",
    "        \"\"\"Return the predicted class name for the request's `txt` query parameter.\n",
    "\n",
    "        Raises `KeyError` if the request has no `txt` query parameter.\n",
    "        \"\"\"\n",
    "        txt = request.query_params[\"txt\"]\n",
    "        encoded = self.tokenizer(\n",
    "            txt,\n",
    "            padding=\"max_length\",\n",
    "            max_length=MAX_SEQ_LENGTH,\n",
    "            truncation=True,\n",
    "            return_tensors=\"pt\",\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            # Pass the attention mask together with the ids; without it the padding\n",
    "            # tokens added above are attended to and can change the prediction.\n",
    "            # NOTE(review): assumes the checkpoint was trained with attention masks - confirm.\n",
    "            outputs = self.model(\n",
    "                input_ids=encoded[\"input_ids\"],\n",
    "                attention_mask=encoded[\"attention_mask\"],\n",
    "            )\n",
    "        # The first output element holds the logits for the single input text.\n",
    "        predicted_index = outputs[0].argmax(dim=-1).item()\n",
    "        return self.classes[predicted_index]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f0c3d2-cluster-0003",
   "metadata": {},
   "source": [
    "## Connect to the Ray cluster and deploy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90dda05-3647-48b6-8265-c3267e600841",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Drop any existing connection first so this cell can be re-run safely.\n",
    "ray.shutdown()\n",
    "ray.init(\n",
    "    address=\"ray://localhost:10001\",\n",
    "    namespace=\"serve\",\n",
    "    runtime_env={\n",
    "        \"pip\": [\n",
    "            \"torch\",\n",
    "            \"scikit-learn\",\n",
    "            \"transformers\",\n",
    "            \"pandas\",\n",
    "            \"datasets\",\n",
    "            \"accelerate\",\n",
    "            \"mlflow\",\n",
    "            \"tensorboard\",\n",
    "            \"s3fs\",\n",
    "        ]\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9199cbf9-7c87-4de0-b3dd-db92b23afff4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "serve.start(detached=True, http_options={\"host\": \"0.0.0.0\"})\n",
    "\n",
    "Predictor.deploy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f0c3d2-request-0004",
   "metadata": {},
   "source": [
    "## Send a test request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be8fc2bf-b73a-41ef-9b65-9007ebde3535",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "review_text = \"This product is great!\"\n",
    "\n",
    "# Let requests build the query string so that spaces and reserved characters\n",
    "# (&, #, +, ...) in the review text are percent-encoded correctly.\n",
    "http_response = requests.get(\n",
    "    \"http://127.0.0.1:8000/predict\",\n",
    "    params={\"txt\": review_text},\n",
    "    timeout=30,\n",
    ")\n",
    "# Fail loudly on an HTTP error instead of printing the error page as a prediction.\n",
    "http_response.raise_for_status()\n",
    "response = http_response.text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3502902-8185-41db-9e87-316774613d3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}