{ "cells": [ { "cell_type": "markdown", "id": "252de0de", "metadata": {}, "source": [ "### Deploying ChatGLM on a SageMaker Endpoint\n", "\n", "[ChatGLM](https://github.com/THUDM/ChatGLM-6B): ChatGLM-6B is an open-source, Chinese-English bilingual conversational language model based on the General Language Model (GLM) architecture, with 6.2 billion parameters. Combined with model quantization, it can be deployed locally on consumer-grade GPUs (requiring as little as 6 GB of GPU memory at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT and is optimized for Chinese question answering and dialogue. After training on roughly 1T tokens of Chinese and English text, supplemented by supervised fine-tuning, feedback bootstrapping, and reinforcement learning from human feedback, the 6.2-billion-parameter ChatGLM-6B can already generate answers that align well with human preferences.\n", "\n", "#### Preparation\n", "1. Upgrade boto3 and the SageMaker Python SDK\n", "2. Prepare inference-chatglm.py and requirements.txt under ./code" ] }, { "cell_type": "code", "execution_count": null, "id": "b8f2c403", "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade boto3\n", "!pip install --upgrade sagemaker" ] }, { "cell_type": "code", "execution_count": null, "id": "a4a30f3a", "metadata": {}, "outputs": [], "source": [ "import boto3\n", "import sagemaker\n", "\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "region_name = boto3.session.Session().region_name\n", "\n", "sagemaker_session = sagemaker.Session()\n", "bucket = sagemaker_session.default_bucket()\n", "role = sagemaker.get_execution_role()\n", "\n", "print(role)\n", "print(bucket)" ] }, { "cell_type": "code", "execution_count": null, "id": "2a2b0b22", "metadata": {}, "outputs": [], "source": [ "# Package and upload a placeholder model.tar.gz; the ChatGLM weights themselves are fetched by the inference code when the endpoint starts\n", "!touch dummy\n", "!tar czvf model.tar.gz dummy\n", "assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'chatglm')\n", "model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'chatglm')\n", "!aws s3 cp model.tar.gz $assets_dir\n", "!rm -f dummy model.tar.gz" ] }, { "cell_type": "markdown", "id": "70cbcac2", "metadata": {}, "source": [ "### Configure model inference parameters\n", "1. model_name: name of the SageMaker model (None lets SageMaker auto-generate one)\n", "2. model_environment: environment variables for the model server (request timeout and number of workers)" ] }, { "cell_type": "code", "execution_count": null, "id": "8b8a09a3", "metadata": {}, "outputs": [], "source": [ "model_name = None\n", "entry_point = 'inference-chatglm.py'\n", "framework_version = '1.13.1'\n", "py_version = 'py39'\n", "model_environment = {\n", "    'SAGEMAKER_MODEL_SERVER_TIMEOUT': '600',\n", "    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "b7d03288", "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "model = PyTorchModel(\n", "    name = model_name,\n", "    model_data = model_data,\n", "    entry_point = entry_point,\n", "    source_dir = './code',\n", "    role = role,\n", "    framework_version = framework_version,\n", "    py_version = py_version,\n", "    env = model_environment\n", ")" ] }, { "cell_type": "markdown", "id": "e130efd9", "metadata": {}, "source": [ "### Deploy the SageMaker Endpoint\n", "Deploy the real-time inference service" ] }, { "cell_type": "code", "execution_count": null, "id": "4253e7d7", "metadata": {}, "outputs": [], "source": [ "endpoint_name = None\n", "instance_type = 'ml.g4dn.2xlarge'\n", "instance_count = 1\n", "\n", "from sagemaker.serializers import JSONSerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "\n", "predictor = model.deploy(\n", "    endpoint_name = endpoint_name,\n", "    instance_type = instance_type,\n", "    initial_instance_count = instance_count,\n", "    serializer = JSONSerializer(),\n", "    deserializer = JSONDeserializer()\n", ")" ] }, { "cell_type": "markdown", "id": "f6f59e3f", "metadata": {}, "source": [ "### Testing ChatGLM\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ada7d8b9", "metadata": {}, "outputs": [], "source": [ "# Sleep for 2 minutes to make sure the model is fully loaded\n", "import time\n", "time.sleep(120)" ] }, { "cell_type": "code", "execution_count": null, "id": "01eb2968", "metadata": {}, "outputs": [], "source": [ "inputs = {\n", "    \"ask\": \"你好!\"  # \"Hello!\"\n", "}\n", "\n", 
"response = predictor.predict(inputs)\n", "print(response[\"answer\"])\n", "\n", "inputs = {\n", "    \"ask\": \"晚上睡不着应该怎么办\"  # \"What should I do when I can't fall asleep at night?\"\n", "}\n", "\n", "response = predictor.predict(inputs)\n", "print(response[\"answer\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "90edc6ba", "metadata": {}, "outputs": [], "source": [ "inputs = {\n", "    \"ask\": \"列出一些年夜饭好意头的菜肴以及其寓意!\"  # \"List some auspicious New Year's Eve dinner dishes and their meanings!\"\n", "}\n", "\n", "response = predictor.predict(inputs)\n", "print(response[\"answer\"])\n", "\n", "inputs = {\n", "    \"ask\": \"帮我写一篇人工智能课程的教案,1000字\"  # \"Write me a 1000-character lesson plan for an AI course\"\n", "}\n", "\n", "response = predictor.predict(inputs)\n", "print(response[\"answer\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "80082a13", "metadata": {}, "outputs": [], "source": [ "inputs = {\n", "    \"ask\": \"怎么修改huggingface transformers的model cache位置\"  # \"How do I change the model cache location of huggingface transformers?\"\n", "}\n", "\n", "response = predictor.predict(inputs)\n", "print(response[\"answer\"])\n", "\n", "inputs = {\n", "    \"ask\": \"用python3写出快速排序代码\"  # \"Write quicksort code in Python 3\"\n", "}\n", "\n", "response = predictor.predict(inputs)\n", "print(response[\"answer\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "f3ddf260", "metadata": {}, "outputs": [], "source": [ "# Look up the SageMaker inference endpoint that serves the ChatGLM model we just deployed\n", "sagemaker_endpoint_name = predictor.endpoint_name\n", "sagemaker_endpoint_name" ] }, { "cell_type": "markdown", "id": "fa257dda", "metadata": {}, "source": [ "Use the ChatGLM model deployed on the SageMaker real-time inference endpoint to simulate single-turn and multi-turn conversations.\n", "\n", "The code below is best run in a SageMaker notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "078035cb", "metadata": {}, "outputs": [], "source": [ "import json\n", "import boto3\n", "\n", "client = boto3.client('runtime.sagemaker')\n", "\n", "def query_endpoint_with_json_payload(encoded_json):\n", "    response = client.invoke_endpoint(EndpointName=sagemaker_endpoint_name, ContentType='application/json', Body=encoded_json)\n", "    return response\n", "\n", "def parse_response_texts(query_response):\n", "    model_predictions = json.loads(query_response['Body'].read())\n", "    generated_text = model_predictions[\"answer\"]\n", "    return generated_text\n" ] }, { "cell_type": "markdown", "id": "a7922fd8", "metadata": {}, "source": [ "First, a quick test of how ChatGLM answers a single question" ] }, { "cell_type": "code", "execution_count": null, "id": "55a5b2ef", "metadata": {}, "outputs": [], "source": [ "# Information-extraction prompt: \"The champion of the 2022 World Cup was the Argentina team, and Messi was the MVP. Question: country name, person name. Answer:\"\n", "payload = {\"ask\": \"信息抽取:\\n2022年世界杯的冠军是阿根廷队伍,梅西是MVP\\n问题:国家名,人名\\n答案:\"}\n", "query_response = query_endpoint_with_json_payload(json.dumps(payload).encode('utf-8'))\n", "generated_texts = parse_response_texts(query_response)\n", "print(generated_texts)" ] }, { "cell_type": "markdown", "id": "a81504f9", "metadata": {}, "source": [ "Single-turn conversation: every question/query is independent, with no connection between questions." ] }, { "cell_type": "code", "execution_count": null, "id": "82254c0f", "metadata": {}, "outputs": [], "source": [ "# 1. Start with a simple opening line.\n", "print(\"用户:你好!\\nChatGLM:您好!我是ChatGLM。我可以回答您的问题、写文章、写作业、翻译,对于一些法律等领域的问题我也可以给你提供信息。\")\n", "\n", "# 2. Keep chatting in the same session, but with no connection between turns.\n", "while True:\n", "    command = input(\"用户:\")\n", "    if command == 'quit':\n", "        break\n", "\n", "    payload = {\"ask\": \"\\n用户:\" + command + \"\\n\"}\n", "    query_response = query_endpoint_with_json_payload(json.dumps(payload).encode('utf-8'))\n", "    generated_texts = \"ChatGLM:\" + parse_response_texts(query_response)\n", "    print(generated_texts)" ] }, { "cell_type": "markdown", "id": "326def35", "metadata": {}, "source": [ "Multi-turn conversation simulation: here we test ChatGLM's multi-turn dialogue capability." ] }, { "cell_type": "code", "execution_count": null, "id": "b169581e", "metadata": {}, "outputs": [], "source": [ "# 1. First give ChatGLM an opening line, i.e. some initial context that starts the conversation session.\n", "payload = {\"ask\": \"用户:你好!\\nChatGLM:您好!我是ChatGLM。我可以回答您的问题、写文章、写作业、翻译,对于一些法律等领域的问题我也可以给你提供信息。\"}\n", "print(payload[\"ask\"])\n", "generated_texts = \"\"\n", "\n", "# In this simple multi-turn simulation the payload sent to the SageMaker endpoint keeps growing, so roughly cap its size here.\n", "session_len = 10 * 1024 * 1024\n", "\n", 
"# 2. Keep chatting in the same session; to get a multi-turn effect, carry every previous question-answer pair so ChatGLM sees the earlier context.\n", "while True:\n", "    command = input(\"用户:\")\n", "    if command == 'quit':\n", "        break\n", "\n", "    whole_context = payload[\"ask\"] + generated_texts + \"\\n用户:\" + command + \"\\n\"\n", "    payload = {\"ask\": whole_context}\n", "    if len(whole_context) > session_len:\n", "        print(\"Too much context; exiting the current conversation session!\")\n", "        break\n", "\n", "    query_response = query_endpoint_with_json_payload(json.dumps(payload).encode('utf-8'))\n", "    generated_texts = \"ChatGLM:\" + parse_response_texts(query_response)\n", "    print(generated_texts)" ] }, { "cell_type": "markdown", "id": "1c2dcc6a", "metadata": {}, "source": [ "### Delete the SageMaker Endpoint\n", "Delete the inference service" ] }, { "cell_type": "code", "execution_count": null, "id": "c329e2d3", "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "id": "c512b41c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }