{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Customer support case classification with SageMaker BlazingText\n",
    "\n",
    "This notebook trains a BlazingText model in `supervised` mode on the training and validation files written by a previous SageMaker Processing job, deploys the trained model to an endpoint, and sends a sample support request to that endpoint for classification.\n",
    "\n",
    "Run the cells top to bottom; later cells rely on variables defined by earlier ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports live here so a missing dependency fails before any training job is started.\n",
    "import json\n",
    "\n",
    "import boto3\n",
    "import nltk\n",
    "import sagemaker\n",
    "from sagemaker import TrainingInput, image_uris\n",
    "from sagemaker.deserializers import JSONDeserializer\n",
    "from sagemaker.predictor import Predictor\n",
    "from sagemaker.serializers import JSONSerializer\n",
    "\n",
    "print(boto3.__version__)\n",
    "print(sagemaker.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Session and data locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "session = sagemaker.Session()\n",
    "bucket = session.default_bucket()\n",
    "print(\"Default bucket is {}\".format(bucket))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### REPLACE INPUT FROM PREVIOUS SAGEMAKER PROCESSING CONFIG OUTPUT #####\n",
    "prefix = \"customer_support_classification\"\n",
    "s3_training_path = \"https://sagemaker-us-east-1-123412341234.s3.amazonaws.com/sagemaker-scikit-learn-2020-11-16-19-27-42-281/output/training/training.txt\"\n",
    "s3_validation_path = \"https://sagemaker-us-east-1-123412341234.s3.amazonaws.com/sagemaker-scikit-learn-2020-11-16-19-27-42-281/output/validation/validation.txt\"\n",
    "\n",
    "# Passed to the estimator as its output_path: the default bucket, under `prefix`.\n",
    "s3_output_path = \"s3://{}/{}\".format(bucket, prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "region_name = boto3.Session().region_name\n",
    "print(\"Training the model in {} region\".format(region_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the BlazingText algorithm image for the current region.\n",
    "container = image_uris.retrieve('blazingtext', region=region_name)\n",
    "print(\"The algo container is {}\".format(container))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blazing_text = sagemaker.estimator.Estimator(\n",
    "    container,\n",
    "    role=sagemaker.get_execution_role(),\n",
    "    instance_count=1,\n",
    "    instance_type='ml.c4.4xlarge',\n",
    "    output_path=s3_output_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blazing_text.set_hyperparameters(mode='supervised')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both channels are configured identically; only the S3 location differs.\n",
    "train_data = TrainingInput(\n",
    "    s3_training_path,\n",
    "    distribution='FullyReplicated',\n",
    "    content_type='text/plain',\n",
    "    s3_data_type='S3Prefix',\n",
    ")\n",
    "validation_data = TrainingInput(\n",
    "    s3_validation_path,\n",
    "    distribution='FullyReplicated',\n",
    "    content_type='text/plain',\n",
    "    s3_data_type='S3Prefix',\n",
    ")\n",
    "\n",
    "s3_channels = {\n",
    "    'train': train_data,\n",
    "    'validation': validation_data,\n",
    "}\n",
    "blazing_text.fit(inputs=s3_channels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blazing_text.latest_training_job.job_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blazing_text_predictor = blazing_text.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type='ml.t2.medium',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classify a sample request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer data used by nltk.word_tokenize in the next cell.\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample request, already split into sentences:\n",
    "#   Hi my outlook app seems to misbehave a lot lately, I cannot sync my emails\n",
    "#   and it often crashes and asks for credentials. Could you help me out?\n",
    "sentences = [\n",
    "    \"Hi my outlook app seems to misbehave a lot lately, I cannot sync my emails and it often crashes and asks for credentials.\",\n",
    "    \"Could you help me out?\",\n",
    "]\n",
    "\n",
    "# Each sentence is sent as a single string of space-separated tokens.\n",
    "tokenized_sentences = [' '.join(nltk.word_tokenize(sent)) for sent in sentences]\n",
    "payload = {\n",
    "    \"instances\": tokenized_sentences,\n",
    "    \"configuration\": {\"k\": 1},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "payload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target the endpoint created by the deploy cell above rather than a hard-coded\n",
    "# name, so the notebook works on a fresh Restart & Run All.\n",
    "# To query a different, already-running endpoint, set endpoint_name to its name.\n",
    "case_classifier = Predictor(\n",
    "    endpoint_name=blazing_text_predictor.endpoint_name,\n",
    "    serializer=JSONSerializer(),\n",
    ")\n",
    "response = case_classifier.predict(payload)\n",
    "\n",
    "# No deserializer is configured on the predictor, so the JSON body is parsed here.\n",
    "predictions = json.loads(response)\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order predictions by their 'prob' entry, highest first.\n",
    "ranked_predictions = sorted(predictions, key=lambda pred: pred['prob'], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(json.dumps(ranked_predictions[0], indent=2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}