{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Use Amazon SageMaker and Amazon OpenSearch Service to implement unified text and image search with a CLIP model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook builds a prototype machine learning (ML) powered search engine that retrieves and recommends products based on text or image queries. It is a step-by-step guide to creating SageMaker models with [Contrastive Language-Image Pre-Training (CLIP)](https://openai.com/blog/clip/), using the models to encode images and text into embeddings, ingesting the embeddings into an [Amazon OpenSearch Service](https://aws.amazon.com/opensearch-service/) index, and querying the index with the OpenSearch Service [k-nearest neighbors (KNN) functionality](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/knn.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data overview and preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We recommend running this notebook in SageMaker Studio with the `Python 3 (Data Science)` kernel on an `ml.t3.medium` instance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Warning** ⚠️ If you don't have an OpenSearch cluster running, use the CloudFormation template to create one." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Python packages used in this proof of concept (POC)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install -qU aiobotocore \n", "%pip install -q jsonlines\n", "%pip install -q requests_aws4auth\n", "%pip install -q elasticsearch==7.13.4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this notebook, you are using the Amazon Berkeley Objects Dataset. The dataset is a collection of 147,702 product listings with multilingual metadata and 398,212 unique catalog images. 
8,222 listings come with turntable photography. You will use only the item images and the item names in US English (which we treat as the product’s short description). For demo purposes, you are going to use about 1,600 products.\n", "The README file of the dataset is available at `s3://amazon-berkeley-objects/README.md`. You can download it to your SageMaker Studio environment by running the following command in a notebook with the Python 3 (Data Science) kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p data\n", "!aws s3 cp --no-sign-request s3://amazon-berkeley-objects/README.md data/README.md " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The product descriptions and metadata are stored in 16 files named `listings/metadata/listings_.json.gz`. This demo uses only the first metadata file, `listings_0.json.gz`.\n", "\n", "You can use pandas to load the metadata and then select the products that have titles in US English. Later, you will use the `main_image_id` column to merge item names with item images." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from PIL.Image import Image as PilImage\n", "import textwrap, os\n", "import sagemaker\n", "from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig\n", "from sagemaker.serializers import JSONSerializer, IdentitySerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3_bucket_name = sagemaker.session.Session().default_bucket()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "meta = pd.read_json(\"s3://amazon-berkeley-objects/listings/metadata/listings_0.json.gz\", lines=True)\n", "\n", "def func_(x):\n", "    # Return the first US English item name, or None if the listing has none\n", "    us_texts = [item[\"value\"] for item in x if item[\"language_tag\"] == \"en_US\"]\n", "    return us_texts[0] if us_texts else None\n", "\n", "meta = meta.assign(item_name_in_en_us=meta.item_name.apply(func_))\n", "meta = meta[~meta.item_name_in_en_us.isna()][[\"item_id\", \"item_name_in_en_us\", \"main_image_id\"]]\n", "print(f\"#products with US English title: {len(meta)}\")\n", "meta.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see over 1,600 products in the data frame.\n", "Next, link the item names with the item images. `images/metadata/images.csv.gz` contains the image metadata. This gzip-compressed comma-separated values (CSV) file has the following columns: `image_id`, `height`, `width`, and `path`. Read the image metadata file and merge it with the item metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image_meta = pd.read_csv(\"s3://amazon-berkeley-objects/images/metadata/images.csv.gz\")\n", "dataset = meta.merge(image_meta, left_on=\"main_image_id\", right_on=\"image_id\")\n", "dataset.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can look at one sample image from the dataset by running the following code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.s3 import S3Downloader as s3down\n", "from pathlib import Path\n", "from PIL import Image\n", "\n", "s3_data_root = \"s3://amazon-berkeley-objects/images/small/\"\n", "\n", "def get_image_from_item_id(item_id=\"B0896LJNLH\", return_image=True):\n", "    # Take the first row that matches the item ID (index label, so use .loc)\n", "    item_idx = dataset.query(f\"item_id == '{item_id}'\").index[0]\n", "    record = dataset.loc[item_idx]\n", "    s3_path = record.path\n", "    local_data_root = \"./data/images\"\n", "    local_file_name = Path(s3_path).name\n", "\n", "    # Download the image from the public dataset bucket\n", "    s3down.download(f\"{s3_data_root}{s3_path}\", local_data_root)\n", "\n", "    local_image_path = f\"{local_data_root}/{local_file_name}\"\n", "    if return_image:\n", "        img = Image.open(local_image_path)\n", "        return img, record.item_name_in_en_us\n", "    else:\n", "        return local_image_path, record.item_name_in_en_us\n", "\n", "image, item_name = get_image_from_item_id()\n", "print(item_name)\n", "image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You are going to create a SageMaker model from a pretrained CLIP model. The first step is to download the pretrained model weights file, package it into a `model.tar.gz` file, and upload that file to Amazon S3. The download URLs of the pretrained models are listed in the [CLIP repo](https://github.com/openai/CLIP/blob/main/clip/clip.py#L30). This POC uses the pretrained RN50 model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile build_model_tar.sh\n", "#!/bin/bash\n", "BUCKET_NAME=\"$1\"\n", "MODEL_NAME=RN50.pt\n", "MODEL_NAME_URL=https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\n", "\n", "BUILD_ROOT=/tmp/model_path\n", "S3_PATH=s3://${BUCKET_NAME}/models/clip/model.tar.gz\n", "\n", "\n", "rm -rf \"$BUILD_ROOT\"\n", "mkdir -p \"$BUILD_ROOT\"\n", "# Download the pretrained weights\n", "curl -o \"$BUILD_ROOT/$MODEL_NAME\" \"$MODEL_NAME_URL\"\n", "# Archive only the weights file so that tar does not try to include the archive itself\n", "cd \"$BUILD_ROOT\" && tar -czvf model.tar.gz \"$MODEL_NAME\"\n", "aws s3 cp \"$BUILD_ROOT/model.tar.gz\" \"$S3_PATH\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running the following cell takes about 2 minutes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!bash build_model_tar.sh {s3_bucket_name}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify that the model tar file exists in Amazon S3." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 ls s3://{s3_bucket_name}/models/clip/model.tar.gz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to provide inference code for the CLIP model. CLIP is implemented in PyTorch, so we use the [SageMaker PyTorch](https://docs.aws.amazon.com/sagemaker/latest/dg/pytorch.html) framework. For more information about deploying a PyTorch model with SageMaker, see [the SageMaker Python SDK documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#deploy-pytorch-models). The inference code reads two environment variables, `MODEL_NAME` and `ENCODE_TYPE`. `MODEL_NAME` makes it easy to switch between different CLIP models. `ENCODE_TYPE` specifies whether to encode an image or a piece of text. 
Here, we implement the `model_fn`, `input_fn`, `predict_fn`, and `output_fn` functions to override the [default PyTorch inference handler](https://github.com/aws/sagemaker-pytorch-inference-toolkit/blob/master/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p code" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile code/clip_inference.py\n", "\n", "import io\n", "import json\n", "import logging\n", "import os\n", "import sys\n", "\n", "import clip\n", "import torch\n", "from PIL import Image\n", "\n", "logger = logging.getLogger(__name__)\n", "logger.setLevel(logging.DEBUG)\n", "logger.addHandler(logging.StreamHandler(sys.stdout))\n", "\n", "MODEL_NAME = os.environ.get(\"MODEL_NAME\", \"RN50.pt\")\n", "# ENCODE_TYPE could be IMAGE or TEXT\n", "ENCODE_TYPE = os.environ.get(\"ENCODE_TYPE\", \"TEXT\")\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", "# defining model and loading weights to it.\n", "def model_fn(model_dir):\n", "    model, preprocess = clip.load(os.path.join(model_dir, MODEL_NAME), device=device)\n", "    return {\"model_obj\": model, \"preprocess_fn\": preprocess}\n", "\n", "\n", "# data loading\n", "def input_fn(request_body, request_content_type):\n", "    assert request_content_type in (\n", "        \"application/json\",\n", "        \"application/x-image\",\n", "    ), f\"{request_content_type} is an unknown type.\"\n", "    if request_content_type == \"application/json\":\n", "        # text input: a JSON object with an \"inputs\" field\n", "        data = json.loads(request_body)[\"inputs\"]\n", "    elif request_content_type == \"application/x-image\":\n", "        # image input: raw image bytes\n", "        image_as_bytes = io.BytesIO(request_body)\n", "        data = Image.open(image_as_bytes)\n", "    return data\n", "\n", "\n", "# inference\n", "def predict_fn(input_object, model):\n", "    model_obj = model[\"model_obj\"]\n", "    # for image preprocessing\n", "    preprocess_fn = model[\"preprocess_fn\"]\n", "    assert ENCODE_TYPE in (\"TEXT\", \"IMAGE\"), f\"{ENCODE_TYPE} is an unknown encode type.\"\n", "\n", "    # preprocessing\n", "    if ENCODE_TYPE == \"TEXT\":\n", "        input_ = clip.tokenize(input_object).to(device)\n", "    elif ENCODE_TYPE == \"IMAGE\":\n", "        input_ = preprocess_fn(input_object).unsqueeze(0).to(device)\n", "\n", "    # inference\n", "    with torch.no_grad():\n", "        if ENCODE_TYPE == \"TEXT\":\n", "            prediction = model_obj.encode_text(input_)\n", "        elif ENCODE_TYPE == \"IMAGE\":\n", "            prediction = model_obj.encode_image(input_)\n", "    return prediction\n", "\n", "\n", "# Serialize the prediction result into the desired response content type\n", "def output_fn(predictions, content_type):\n", "    assert content_type == \"application/json\"\n", "    res = predictions.cpu().numpy().tolist()\n", "    return json.dumps(res)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition, you need to create a `requirements.txt` file that lists the dependencies of your inference code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile code/requirements.txt\n", "ftfy\n", "regex\n", "tqdm\n", "git+https://github.com/openai/CLIP.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create `PyTorchModel` objects from the saved model artifact and the inference entry point. `clip_image_model` is used later to encode item images and generate embeddings through a batch transform job. Both `clip_image_model` and `clip_text_model` are deployed later as serverless endpoints for online inference." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "from sagemaker import get_execution_role, Session\n", "\n", "\n", "role = get_execution_role()\n", "shared_params = dict(\n", "    entry_point=\"clip_inference.py\",\n", "    source_dir=\"code\",\n", "    role=role,\n", "    model_data=f\"s3://{s3_bucket_name}/models/clip/model.tar.gz\",\n", "    framework_version=\"1.9.0\",\n", "    py_version=\"py38\",\n", ")\n", "\n", "clip_image_model = PyTorchModel(\n", "    env={'MODEL_NAME': 'RN50.pt', \"ENCODE_TYPE\": \"IMAGE\"},\n", "    name=\"clip-image-model\",\n", "    **shared_params\n", ")\n", "\n", "clip_text_model = PyTorchModel(\n", "    env={'MODEL_NAME': 'RN50.pt', \"ENCODE_TYPE\": \"TEXT\"},\n", "    name=\"clip-text-model\",\n", "    **shared_params\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch transform to generate embedding from item images" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CLIP projects both images and text into the same latent space, so you only need to encode either the item images or the item texts into that embedding space. In this exercise, you are going to encode the item images through a SageMaker Batch Transform job. Before creating the job, copy the item images from the public Amazon Berkeley Objects Dataset S3 bucket to your own S3 bucket. The copy should take less than 10 minutes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from multiprocessing.pool import ThreadPool\n", "import boto3\n", "import tqdm.notebook as tq\n", "from urllib.parse import urlparse\n", "\n", "s3_sample_image_root = f\"s3://{s3_bucket_name}/sample-images\"\n", "\n", "client = boto3.client('s3')\n", "\n", "def upload_(args):\n", "    # Server-side copy from the public dataset bucket to your own bucket\n", "    client.copy_object(CopySource=args[\"source\"], Bucket=args[\"target_bucket\"], Key=args[\"target_key\"])\n", "\n", "arguments = []\n", "for idx, record in dataset.iterrows():\n", "    argument = {}\n", "    # Strip the leading 's3://' to get the 'bucket/key' form that CopySource expects\n", "    argument[\"source\"] = (s3_data_root + record.path)[5:]\n", "    argument[\"target_bucket\"] = urlparse(s3_sample_image_root).netloc\n", "    argument[\"target_key\"] = urlparse(s3_sample_image_root).path[1:] + '/' + record.path\n", "    arguments.append(argument)\n", "\n", "with ThreadPool(4) as p:\n", "    r = list(tq.tqdm(p.imap(upload_, arguments), total=len(dataset)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, run batch inference on the item images. The job takes around 10 minutes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "batch_input = s3_sample_image_root + \"/\"\n", "output_path = f\"s3://{s3_bucket_name}/inference/clip-search/output/\"\n", "\n", "clip_image_transformer = clip_image_model.transformer(\n", "    instance_count=1,\n", "    instance_type=\"ml.c5.xlarge\",\n", "    strategy=\"SingleRecord\",\n", "    assemble_with=\"Line\",\n", "    output_path=output_path,\n", ")\n", "\n", "clip_image_transformer.transform(\n", "    batch_input,\n", "    data_type=\"S3Prefix\",\n", "    content_type=\"application/x-image\",\n", "    wait=True,\n", "    logs=False\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To check the job execution status and logs, go to the SageMaker console and choose Inference (left sidebar), then Batch transform jobs. 
You can also set `logs=True` in the `transform` call to display the logs in the notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the embeddings from S3 into a notebook variable so that you can ingest them into OpenSearch Service later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.s3 import S3Downloader\n", "import json\n", "\n", "embedding_root_path = \"./data/embedding\"\n", "S3Downloader.download(output_path, embedding_root_path)\n", "\n", "# The batch transform job writes one `.out` file per input image\n", "embeddings = []\n", "for idx, record in dataset.iterrows():\n", "    embedding_file = f\"{embedding_root_path}/{record.path}.out\"\n", "    with open(embedding_file) as f:\n", "        embeddings.append(json.load(f)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load OpenSearch Service" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, define `opensearch_region` and `opensearch_url`. The following cell uses the CloudFormation stack name to fetch the OpenSearch domain URL. \n", "\n", "**Note 1:** If you didn't use the CloudFormation template provided in the repo to create the cluster, skip the following cell and define your own `opensearch_region` and `opensearch_url`. \n", "**Note 2:** If your SageMaker execution role doesn't have permission to describe CloudFormation stacks, open the CloudFormation service in the AWS console and use the outputs of `clip-poc-stack` to set these two parameters." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "\n", "# Use a separate name so that the S3 client defined earlier is not overwritten\n", "cfn_client = boto3.client('cloudformation')\n", "\n", "response = cfn_client.describe_stacks(\n", "    StackName='clip-poc-stack',\n", ")\n", "\n", "my_session = boto3.session.Session()\n", "\n", "opensearch_region = my_session.region_name\n", "opensearch_url = [\n", "    output_[\"OutputValue\"]\n", "    for output_ in response[\"Stacks\"][0][\"Outputs\"]\n", "    if output_[\"OutputKey\"] == \"DomainEndpoint\"\n", "][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setting up the Amazon OpenSearch Service domain using KNN settings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, set up an OpenSearch Service cluster. For instructions, see [Creating and Managing Amazon OpenSearch Service Domains](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/createupdatedomains.html). \n", "Once the OpenSearch cluster is set up, create an index to store the item metadata and the embeddings. To enable the KNN functionality, the index settings must be configured when the index is created, using the following configuration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "index_settings = {\n", "    \"settings\": {\n", "        \"index.knn\": True,\n", "        \"index.knn.space_type\": \"cosinesimil\"\n", "    },\n", "    \"mappings\": {\n", "        \"properties\": {\n", "            \"embeddings\": {\n", "                \"type\": \"knn_vector\",\n", "                \"dimension\": 1024  # Must match the size of the embeddings you generated; for RN50, it is 1024\n", "            }\n", "        }\n", "    }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example uses the Python Elasticsearch client to communicate with the OpenSearch cluster and create an index to host our data." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "import json\n", "from requests_aws4auth import AWS4Auth\n", "from elasticsearch import Elasticsearch, RequestsHttpConnection\n", "\n", "\n", "index_name = \"clip-index\"\n", "\n", "def get_es_client(\n", "    host=opensearch_url,\n", "    port=443,\n", "    region=opensearch_region,\n", "    index_name=index_name,\n", "):\n", "    # Sign requests with the current AWS credentials for the 'es' service\n", "    credentials = boto3.Session().get_credentials()\n", "    awsauth = AWS4Auth(\n", "        credentials.access_key,\n", "        credentials.secret_key,\n", "        region,\n", "        'es',\n", "        session_token=credentials.token,\n", "    )\n", "\n", "    es = Elasticsearch(\n", "        hosts=[{'host': host, 'port': port}],\n", "        http_auth=awsauth,\n", "        use_ssl=True,\n", "        verify_certs=True,\n", "        connection_class=RequestsHttpConnection,\n", "        timeout=60,  # for connection timeout errors\n", "    )\n", "    return es\n", "\n", "es = get_es_client()\n", "es.indices.create(index=index_name, body=json.dumps(index_settings))\n", "\n", "# You can check if the index is created within your es cluster\n", "print(es.indices.get_alias(\"*\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, loop through your dataset and ingest the item data into the cluster. A more robust and scalable solution for embedding ingestion is described in [Ingesting enriched data into Amazon ES](https://aws.amazon.com/blogs/industries/novartis-ag-uses-amazon-elasticsearch-k-nearest-neighbor-knn-and-amazon-sagemaker-to-power-search-and-recommendation/). The data ingestion for this POC should finish within 60 seconds. The cell also runs a simple query to verify that the data has been ingested into the index." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tqdm.notebook as tq\n", "\n", "# ingest_data_into_es\n", "\n", "for idx, record in tq.tqdm(dataset.iterrows(), total=len(dataset)):\n", "    body = record[['item_name_in_en_us']].to_dict()\n", "    body['embeddings'] = embeddings[idx]\n", "    es.index(index=index_name, id=record.item_id, doc_type='_doc', body=body)\n", "\n", "# Check that data is indeed in ES\n", "res = es.search(\n", "    index=index_name, body={\n", "        \"query\": {\n", "            \"match_all\": {}\n", "        }},\n", "    size=2)\n", "assert len(res[\"hits\"][\"hits\"]) > 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate embeddings from the query" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that you have a working OpenSearch index that contains the embeddings for your inventory, let's look at how you can generate embeddings for new items. You need to create two Amazon SageMaker endpoints, one for extracting text features and one for extracting image features. [Amazon SageMaker Serverless Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints.html) is a purpose-built inference option that makes it easy for you to deploy and scale ML models. Serverless Inference is ideal for workloads that have idle periods between traffic spurts and can tolerate cold starts. Because you are creating a POC search engine, Serverless Inference helps you avoid extra cost. \n", "\n", "You can create two functions that use the endpoints to encode images and texts. For CLIP, you need some `prompt engineering` to turn an item name into a sentence that describes the item. Here, prefixing the item name with `this is a ` is a simple prompt engineering method." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_predictor = clip_text_model.deploy(\n", "    instance_type='ml.c5.xlarge',\n", "    initial_instance_count=1,\n", "    serverless_inference_config=ServerlessInferenceConfig(memory_size_in_mb=6144),\n", "    serializer=JSONSerializer(),\n", "    deserializer=JSONDeserializer(),\n", "    wait=True\n", ")\n", "\n", "image_predictor = clip_image_model.deploy(\n", "    instance_type='ml.c5.xlarge',\n", "    initial_instance_count=1,\n", "    serverless_inference_config=ServerlessInferenceConfig(memory_size_in_mb=6144),\n", "    serializer=IdentitySerializer(content_type=\"application/x-image\"),\n", "    deserializer=JSONDeserializer(),\n", "    wait=True\n", ")\n", "\n", "def encode_image(file_name=\"./data/images/0e9420c6.jpg\"):\n", "    # Send the raw image bytes to the image endpoint\n", "    with open(file_name, \"rb\") as f:\n", "        payload = f.read()\n", "        payload = bytearray(payload)\n", "    res = image_predictor.predict(payload)\n", "    return res[0]\n", "\n", "def encode_name(item_name):\n", "    # Wrap the item name in a simple prompt before encoding\n", "    res = text_predictor.predict({\"inputs\": [f\"this is a {item_name}\"]})\n", "    return res[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Look at the picture you want to encode first, so that it is easier to evaluate the results later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "item_image_path, item_name = get_image_from_item_id(item_id = \"B0896LJNLH\", return_image=False)\n", "feature_vector = encode_image(file_name=item_image_path)\n", "\n", "print(len(feature_vector))\n", "Image.open(item_image_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make a k-nn based query" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let’s take a look at the results of a simple query. After retrieving results from OpenSearch Service, we get the item names and images from `dataset`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "es = get_es_client()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def search_products(embedding, k = 3):\n", "    body = {\n", "        \"size\": k,\n", "        \"_source\": {\n", "            \"excludes\": [\"embeddings\"],\n", "        },\n", "        \"query\": {\n", "            \"knn\": {\n", "                \"embeddings\": {\n", "                    \"vector\": embedding,\n", "                    \"k\": k,\n", "                }\n", "            }\n", "        },\n", "    }\n", "    res = es.search(index=index_name, body=body)\n", "    images = []\n", "    for hit in res[\"hits\"][\"hits\"]:\n", "        id_ = hit[\"_id\"]\n", "        image, item_name = get_image_from_item_id(id_)\n", "        # Attach the score and item name so they can be shown as the plot title\n", "        image.name_and_score = f'{hit[\"_score\"]}:{item_name}'\n", "        images.append(image)\n", "    return images\n", "\n", "def display_images(\n", "    images: [PilImage],\n", "    columns=2, width=20, height=8, max_images=15,\n", "    label_wrap_length=50, label_font_size=8):\n", "\n", "    if not images:\n", "        print(\"No images to display.\")\n", "        return\n", "\n", "    if len(images) > max_images:\n", "        print(f\"Showing {max_images} images of {len(images)}:\")\n", "        images = images[0:max_images]\n", "\n", "    height = max(height, int(len(images) / columns) * height)\n", "    plt.figure(figsize=(width, height))\n", "    for i, image in enumerate(images):\n", "\n", "        plt.subplot(int(len(images) / columns + 1), columns, i + 1)\n", "        plt.imshow(image)\n", "\n", "        if hasattr(image, 'name_and_score'):\n", "            # Wrap long titles so they fit above the image\n", "            title = textwrap.fill(image.name_and_score, label_wrap_length)\n", "            plt.title(title, fontsize=label_font_size)\n", "\n", "images = search_products(feature_vector)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results are easier to interpret when you display them with `matplotlib`. The first item, with a score of `1.0`, is the query image itself, as expected. The other items are different products that nevertheless resemble the input. For example, the `Push Broom Kit` is black and tall." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "display_images(images)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that we found glass drinkware in our `dataset` without providing any textual information. That's what we want to achieve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This time, describe the item you want to retrieve in text." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_vector = encode_name(\"drinkware glass\")\n", "images = search_products(feature_vector)\n", "display_images(images)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now you can try other item descriptions and images as well." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_vector = encode_name(\"pizza\")\n", "images = search_products(feature_vector)\n", "display_images(images)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cleaning up" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With its pay-per-use model, Serverless Inference is a cost-effective option if you have an infrequent or unpredictable traffic pattern. Idle serverless endpoints do not incur charges, but if you do not plan to use the endpoints again, you should delete them.\n", "\n", "When you finish this exercise, remove your resources with the following steps:\n", "\n", "Delete the SageMaker Studio user profile and domain. \n", "Optionally, delete registered models. \n", "Optionally, delete the SageMaker execution role. \n", "Optionally, empty and delete the S3 bucket, or keep whatever you want. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "es.indices.delete(index=index_name, ignore=[400, 404])\n", "!rm -rf data/\n", "text_predictor.delete_endpoint()\n", "image_predictor.delete_endpoint()\n", "clip_text_model.delete_model()\n", "clip_image_model.delete_model()" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3.9.10 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10 (main, Jan 15 2022, 11:40:53) \n[Clang 13.0.0 (clang-1300.0.29.3)]" }, "vscode": { "interpreter": { "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" } } }, "nbformat": 4, "nbformat_minor": 4 }