{ "cells": [ { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Fine-Tune a Generative AI Model for Dialogue Summarization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will see how to fine tune an existing LLM from HuggingFace for enhanced dialogue summarization. We will be using the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model as it provides a high quality instruction tuned model at various sizes. Flan-T5 can summarize text out of the box, but in this notebook we will see how fine-tuning on a high quality dataset can improve its performance for a specific task. Specifically, we will be using the [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) dataset from HuggingFace which contains chunks of dialogue and associated summarizations of the dialogue." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Setup\n", "\n", "First up, lets make sure we install some libraries which are needed for this notebook. After the installation, we will import the necessary packages for the notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch==1.13.1 in /opt/conda/lib/python3.7/site-packages (1.13.1)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1) (4.5.0)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1) (11.7.99)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1) (11.7.99)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1) (8.5.0.96)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1) (11.10.3.66)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1) (59.3.0)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1) (0.40.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "Collecting torchdata\n", " Using cached torchdata-0.5.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n", "Requirement already satisfied: torch==1.13.1 in /opt/conda/lib/python3.7/site-packages (from torchdata) (1.13.1)\n", "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from torchdata) (2.28.2)\n", "Collecting portalocker>=2.0.0\n", " Using cached portalocker-2.7.0-py2.py3-none-any.whl (15 kB)\n", "Requirement already satisfied: urllib3>=1.25 in /opt/conda/lib/python3.7/site-packages (from torchdata) (1.26.15)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1->torchdata) (11.10.3.66)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1->torchdata) (8.5.0.96)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1->torchdata) (4.5.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1->torchdata) (11.7.99)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch==1.13.1->torchdata) (11.7.99)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1->torchdata) (59.3.0)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==1.13.1->torchdata) (0.40.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.7/site-packages (from requests->torchdata) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->torchdata) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->torchdata) (2022.12.7)\n", "Installing collected packages: portalocker, torchdata\n", "Successfully installed portalocker-2.7.0 torchdata-0.5.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install torch==1.13.1\n", "%pip install torchdata\n", "%pip install transformers==4.27.2 --quiet\n", "%pip install py7zr==0.20.4 --quiet\n", "%pip install datasets==2.9.0 --quiet\n", "%pip install sentencepiece==0.1.97 --quiet\n", "%pip install evaluate==0.4.0 --quiet\n", "%pip install accelerate==0.17.0\n", "%pip install rouge_score==0.1.2 --quiet\n", "%pip install loralib==0.1.1 --quiet" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, GenerationConfig\n", "from datasets import load_dataset\n", "import datasets\n", "import torch\n", "import time\n", "import evaluate\n", "import numpy as np\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load Flan-T5 Model\n", "\n", "We can load the pre-trained Flan-T5 model directly from HuggingFace. Notice that we will be using the [base version](https://huggingface.co/google/flan-t5-base) of flan." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", "original_model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-base\", device_map=\"auto\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of model parameters: 247577856\n" ] } ], "source": [ "params = sum(p.numel() for p in original_model.parameters())\n", "print(f'Total number of model parameters: {params}')" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load Dataset\n", "\n", "The DialogSum dataset can also be loaded directly from HuggingFace. There are ~15k examples of dialogue in this dataset, each with an associated human-written summary." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration knkarthick--dialogsum-6d41e9a7b96e340e\n", "Found cached dataset csv (/root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "85bf17721e8c410f999844a8f113e3de", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\")\n", "\n", "instruct_model = AutoModelForSeq2SeqLM.from_pretrained(\"./flan-dialogue-summary-checkpoint\", device_map=\"auto\")\n", "\n", "original_model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-base\", device_map=\"auto\")" ] },
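{ "cell_type": "markdown", "metadata": {}, "source": [ "Before comparing models, let's look at one raw test example to see the `dialogue`/`summary` structure of DialogSum. This is a quick sketch; index 20 is the same example used in the qualitative comparison below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Peek at one raw test example: the dialogue text and its human-written summary.\n", "sample = dataset['test'][20]\n", "print('DIALOGUE:\\n')\n", "print(sample['dialogue'])\n", "print('\\nHUMAN SUMMARY:\\n')\n", "print(sample['summary'])" ] },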
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Qualitative Results\n", "\n", "As with many GenAI applications, a qualitative approach, where you ask yourself the question \"is my model behaving the way it is supposed to?\", is usually a good starting point. In the example below (the same one we started this notebook with), you can see how the fine-tuned model produces a reasonable summary of the dialogue, while the original model fails to understand what is being asked of it." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "--------------------------\n", "Summarize the following conversation.\n", "\n", "#Person1#: What's wrong with you? Why are you scratching so much?\n", "#Person2#: I feel itchy! I can't stand it anymore! I think I may be coming down with something. I feel lightheaded and weak.\n", "#Person1#: Let me have a look. Whoa! Get away from me!\n", "#Person2#: What's wrong?\n", "#Person1#: I think you have chicken pox! You are contagious! Get away! Don't breathe on me!\n", "#Person2#: Maybe it's just a rash or an allergy! We can't be sure until I see a doctor.\n", "#Person1#: Well in the meantime you are a biohazard! I didn't get it when I was a kid and I've heard that you can even die if you get it as an adult!\n", "#Person2#: Are you serious? You always blow things out of proportion. In any case, I think I'll go take an oatmeal bath.\n", "\n", "Summary:\n", "--------------------------\n", "Baseline human summary from original dataset: #Person1# thinks #Person2# has chicken pox and warns #Person2# about the possible hazards but #Person2# thinks it will be fine.\n", "Original Flan-T5 summary: Talk to your doctor.\n", "Instruct model summary: #Person2# feels itchy and feels lightheaded. #Person1# thinks #Person2# has chicken pox and #Person2# is a biohazard. #Person1# thinks #Person2# is a biohazard and #Person2# will go take an oatmeal bath.\n" ] } ], "source": [ "idx = 20\n", "diag = dataset['test'][idx]['dialogue']\n", "baseline_human_summary = dataset['test'][idx]['summary']\n", "\n", "prompt = f'Summarize the following conversation.\\n\\n{diag}\\n\\nSummary:'\n", "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "\n", "original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", "original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", "\n", "instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))\n", "instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", "\n", "print(f'Prompt:\\n--------------------------\\n{prompt}\\n--------------------------')\n", "print(f'Baseline human summary from original dataset: {baseline_human_summary}')\n", "print(f'Original Flan-T5 summary: {original_model_text_output}')\n", "print(f'Instruct model summary: {instruct_model_text_output}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Quantitative Results with ROUGE Metric\n", "\n", "The [ROUGE metric](https://en.wikipedia.org/wiki/ROUGE_(metric)) helps quantify the quality of machine-generated summaries. It compares each model summary to a \"baseline\" summary, which is usually written by a human. While not perfect, it gives an indication of the overall improvement in summarization effectiveness that we have achieved by fine-tuning." ] },
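{ "cell_type": "markdown", "metadata": {}, "source": [ "For intuition: ROUGE-1 measures unigram (single-word) overlap with the reference, ROUGE-2 measures bigram overlap, and ROUGE-L is based on the longest common subsequence. The tiny sketch below uses made-up strings (not from the dataset) to show the mechanics; swapping a single word lowers the unigram overlap slightly and breaks two bigrams." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Toy ROUGE sketch on made-up strings: 'sat' vs. 'lay' is the only difference.\n", "import evaluate\n", "\n", "toy_rouge = evaluate.load('rouge')\n", "print(toy_rouge.compute(\n", "    predictions=['the cat sat on the mat'],\n", "    references=['the cat lay on the mat'],\n", "))" ] },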
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "tags": [] }, "outputs": [], "source": [ "rouge = evaluate.load('rouge')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate a Subsection of Summaries" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "tags": [] }, "outputs": [], "source": [ "# again, for the sake of time, we will only be generating 10 summarizations with each model\n", "# outside of the lab, a good exercise is to increase the number of validation summaries generated\n", "dialogues = dataset['test'][0:10]['dialogue']\n", "human_baseline_summaries = dataset['test'][0:10]['summary']\n", "\n", "original_model_summaries = []\n", "instruct_model_summaries = []\n", "\n", "for ind, diag in enumerate(dialogues):\n", " prompt = f'Summarize the following conversation.\\n\\nConversation:\\n{diag}\\n\\nSummary:'\n", " input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "\n", " original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", " original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n", " original_model_summaries.append(original_model_text_output)\n", "\n", " instruct_model_outputs = instruct_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))\n", " instruct_model_text_output = tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)\n", " instruct_model_summaries.append(instruct_model_text_output)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "tags": [] }, "outputs": [], "source": [ "original_model_results = rouge.compute(\n", " predictions=original_model_summaries,\n", " references=human_baseline_summaries[0:len(original_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "tags": [] }, "outputs": [], "source": [ "instruct_model_results = rouge.compute(\n", " predictions=instruct_model_summaries,\n", " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.07518181818181817,\n", " 'rouge2': 0.033229813664596264,\n", " 'rougeL': 0.06836363636363635,\n", " 'rougeLsum': 0.06836363636363635}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_model_results" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.4015906463624618,\n", " 'rouge2': 0.17568542724181807,\n", " 'rougeL': 0.2874569966059625,\n", " 'rougeLsum': 0.2886327613084294}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "instruct_model_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evalute the Full Dataset\n", "\n", "The file called \"diag-summary-training-results.csv\" contains a pre-populated list of all model results which we can use to evaluate on a larger section of data. The results show substantial improvement in all ROUGE metrics!" 
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "results = pd.read_csv(\"diag-summary-training-results-with-peft.csv\")\n", "\n", "human_baseline_summaries = results['human_baseline_summaries'].values\n", "original_model_summaries = results['original_model_summaries'].values\n", "instruct_model_summaries = results['instruct_model_summaries'].values\n", "#peft_model_summaries = results['peft_model_summaries'].values" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "tags": [] }, "outputs": [], "source": [ "original_model_results = rouge.compute(\n", " predictions=original_model_summaries,\n", " references=human_baseline_summaries[0:len(original_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "tags": [] }, "outputs": [], "source": [ "instruct_model_results = rouge.compute(\n", " predictions=instruct_model_summaries,\n", " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.2334158581572823,\n", " 'rouge2': 0.07603964187010573,\n", " 'rougeL': 0.20145520923859048,\n", " 'rougeLsum': 0.20145899339006135}" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_model_results" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.42161291557556113,\n", " 'rouge2': 0.18035380596301792,\n", " 'rougeL': 0.3384439349963909,\n", " 'rougeLsum': 0.33835653595561666}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "instruct_model_results" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rouge1 absolute percentage improvement of instruct model over human baseline: 18.82%\n", "rouge2 absolute percentage improvement of instruct model over human baseline: 10.43%\n", "rougeL absolute percentage improvement of instruct model over human baseline: 13.70%\n", "rougeLsum absolute percentage improvement of instruct model over human baseline: 13.69%\n" ] } ], "source": [ "improvement = (np.array(list(instruct_model_results.values())) - np.array(list(original_model_results.values())))\n", "for key, value in zip(instruct_model_results.keys(), improvement):\n", " print(f'{key} absolute percentage improvement of instruct model over human baseline: {value*100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# PEFT" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/huggingface/peft.git\n", " Cloning https://github.com/huggingface/peft.git to /tmp/pip-req-build-onzurk0v\n", " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/peft.git /tmp/pip-req-build-onzurk0v\n", " Resolved https://github.com/huggingface/peft.git to commit cc82b674b5db38b9a393463d38afe66e8f48ac1c\n", " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Preparing metadata (pyproject.toml) ... 
\u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: transformers in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (4.27.2)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (23.0)\n", "Requirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (1.13.1)\n", "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (6.0)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (5.6.7)\n", "Requirement already satisfied: accelerate in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (0.17.0)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (1.21.6)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (8.5.0.96)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (4.5.0)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.10.3.66)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.7.99)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.7.99)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->peft==0.3.0.dev0) (59.3.0)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->peft==0.3.0.dev0) (0.40.0)\n", "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (4.65.0)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (6.1.0)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (0.13.3)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (2022.10.31)\n", "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (2.28.2)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (3.0.12)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (0.13.4)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->transformers->peft==0.3.0.dev0) (3.15.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2.0.4)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (1.26.15)\n", "Requirement already satisfied: idna<4,>=2.5 in 
/opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2022.12.7)\n", "Building wheels for collected packages: peft\n", " Building wheel for peft (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for peft: filename=peft-0.3.0.dev0-py3-none-any.whl size=49836 sha256=3382c9f07fbb4a4f5501d9eb9efe9fbabb265330ce34b610d6cba99ffcb16312\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-eych99wd/wheels/13/73/6d/ff27a3703d8bad21d7e0c24cbd9dde5d7ae78f756405707a0c\n", "Successfully built peft\n", "Installing collected packages: peft\n", "Successfully installed peft-0.3.0.dev0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "%pip install git+https://github.com/huggingface/peft.git" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "tags": [] }, "outputs": [], "source": [ "# re-importing as the rest of this notebook will likely move to a new notebook. you're welcome!\n", "\n", "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, GenerationConfig\n", "from datasets import load_dataset\n", "import datasets\n", "import torch\n", "import time\n", "import evaluate\n", "import numpy as np\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load base Flan-T5 model and tokenizer\n", "\n", "We can load the pre-trained Flan-T5 model directly from HuggingFace. Notice that we will be using the [base version](https://huggingface.co/google/flan-t5-base) of flan to create the PEFT adapter. We will compare this to a fully fine-tuned instruct model." ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "tags": [] }, "outputs": [], "source": [ "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "original_model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-base\", device_map=\"auto\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Add PEFT layer/parameter adapters\n", "Note the rank (`r`) hyper-parameter below, which sets the dimension of the low-rank update matrices." ] },
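{ "cell_type": "markdown", "metadata": {}, "source": [ "As a sanity check on the trainable-parameter count printed below: LoRA freezes each targeted weight matrix $W$ and learns a low-rank update $\\Delta W = \\frac{\\alpha}{r} B A$, where $B \\in \\mathbb{R}^{d \\times r}$ and $A \\in \\mathbb{R}^{r \\times k}$ are the only matrices that get trained. In flan-t5-base the query and value projections are 768 × 768, so with $r=32$ each adapted matrix adds 32 × (768 + 768) = 49,152 trainable parameters. Targeting `q` and `v` across 12 encoder self-attention blocks plus 12 decoder self-attention and 12 decoder cross-attention blocks gives 36 × 2 = 72 adapted matrices, i.e. 72 × 49,152 = 3,538,944 trainable parameters, exactly the count reported by `print_trainable_parameters()`." ] },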
] }, { "cell_type": "code", "execution_count": 40, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 3538944 || all params: 251116800 || trainable%: 1.4092820552029972\n" ] } ], "source": [ "from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType\n", "\n", "# Define LoRA Config\n", "lora_config = LoraConfig(\n", " r=32, # rank\n", " lora_alpha=32,\n", " target_modules=[\"q\", \"v\"],\n", " lora_dropout=0.05,\n", " bias=\"none\",\n", " task_type=TaskType.SEQ_2_SEQ_LM\n", ")\n", "\n", "# Add LoRA adapter layers/parameters \n", "peft_model = get_peft_model(original_model, lora_config)\n", "peft_model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorForSeq2Seq\n", "from transformers import AutoTokenizer\n", "\n", "# we want to ignore tokenizer pad token in the loss\n", "label_pad_token_id = -100\n", "# Data collator\n", "data_collator = DataCollatorForSeq2Seq(\n", " tokenizer,\n", " model=peft_model,\n", " label_pad_token_id=label_pad_token_id,\n", " pad_to_multiple_of=8\n", ")" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load Dataset\n", "\n", "The DialogSum dataset can also be loaded directly from HuggingFace. There are ~15k examples of dialogue in this dataset with associated human summarizations of these datasets" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Using custom data configuration knkarthick--dialogsum-6d41e9a7b96e340e\n", "WARNING:datasets.builder:Found cached dataset csv (/root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "59d64ef3561d4b2d8e4c8bba288d6d17", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
"<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>human_baseline_summaries</th>\n", "      <th>original_model_summaries</th>\n", "      <th>instruct_model_summaries</th>\n", "      <th>peft_model_summaries</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>Ms. Dawson helps #Person1# to write a memo to ...</td><td>The memo is to be distributed to all employees...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td></tr>\n", "    <tr><th>1</th><td>In order to prevent employees from wasting tim...</td><td>The memo is to be distributed to all employees...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td></tr>\n", "    <tr><th>2</th><td>Ms. Dawson takes a dictation for #Person1# abo...</td><td>The memo is to be distributed to all employees...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td><td>#Person1# asks Ms. Dawson to take a dictation ...</td></tr>\n", "    <tr><th>3</th><td>#Person2# arrives late because of traffic jam....</td><td>The traffic jam at the Carrefour intersection ...</td><td>#Person2# got stuck in traffic again. #Person1...</td><td>#Person2# got stuck in traffic and got stuck i...</td></tr>\n", "    <tr><th>4</th><td>#Person2# decides to follow #Person1#'s sugges...</td><td>The traffic jam at the Carrefour intersection ...</td><td>#Person2# got stuck in traffic again. #Person1...</td><td>#Person2# got stuck in traffic and got stuck i...</td></tr>\n", "    <tr><th>5</th><td>#Person2# complains to #Person1# about the tra...</td><td>The traffic jam at the Carrefour intersection ...</td><td>#Person2# got stuck in traffic again. #Person1...</td><td>#Person2# got stuck in traffic and got stuck i...</td></tr>\n", "    <tr><th>6</th><td>#Person1# tells Kate that Masha and Hero get d...</td><td>Masha and Hero are getting divorced.</td><td>Masha and Hero are getting divorced. Kate can'...</td><td>Kate tells #Person2# Masha and Hero are gettin...</td></tr>\n", "    <tr><th>7</th><td>#Person1# tells Kate that Masha and Hero are g...</td><td>Masha and Hero are getting divorced.</td><td>Masha and Hero are getting divorced. Kate can'...</td><td>Kate tells #Person2# Masha and Hero are gettin...</td></tr>\n", "    <tr><th>8</th><td>#Person1# and Kate talk about the divorce betw...</td><td>Masha and Hero are getting divorced.</td><td>Masha and Hero are getting divorced. Kate can'...</td><td>Kate tells #Person2# Masha and Hero are gettin...</td></tr>\n", "    <tr><th>9</th><td>#Person1# and Brian are at the birthday party ...</td><td>Brian's birthday is coming up.</td><td>Brian's birthday is coming. #Person1# invites ...</td><td>Brian remembers his birthday and invites #Pers...</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
\n", "" ], "text/plain": [ " human_baseline_summaries \\\n", "0 Ms. Dawson helps #Person1# to write a memo to ... \n", "1 In order to prevent employees from wasting tim... \n", "2 Ms. Dawson takes a dictation for #Person1# abo... \n", "3 #Person2# arrives late because of traffic jam.... \n", "4 #Person2# decides to follow #Person1#'s sugges... \n", "5 #Person2# complains to #Person1# about the tra... \n", "6 #Person1# tells Kate that Masha and Hero get d... \n", "7 #Person1# tells Kate that Masha and Hero are g... \n", "8 #Person1# and Kate talk about the divorce betw... \n", "9 #Person1# and Brian are at the birthday party ... \n", "\n", " original_model_summaries \\\n", "0 The memo is to be distributed to all employees... \n", "1 The memo is to be distributed to all employees... \n", "2 The memo is to be distributed to all employees... \n", "3 The traffic jam at the Carrefour intersection ... \n", "4 The traffic jam at the Carrefour intersection ... \n", "5 The traffic jam at the Carrefour intersection ... \n", "6 Masha and Hero are getting divorced. \n", "7 Masha and Hero are getting divorced. \n", "8 Masha and Hero are getting divorced. \n", "9 Brian's birthday is coming up. \n", "\n", " instruct_model_summaries \\\n", "0 #Person1# asks Ms. Dawson to take a dictation ... \n", "1 #Person1# asks Ms. Dawson to take a dictation ... \n", "2 #Person1# asks Ms. Dawson to take a dictation ... \n", "3 #Person2# got stuck in traffic again. #Person1... \n", "4 #Person2# got stuck in traffic again. #Person1... \n", "5 #Person2# got stuck in traffic again. #Person1... \n", "6 Masha and Hero are getting divorced. Kate can'... \n", "7 Masha and Hero are getting divorced. Kate can'... \n", "8 Masha and Hero are getting divorced. Kate can'... \n", "9 Brian's birthday is coming. #Person1# invites ... \n", "\n", " peft_model_summaries \n", "0 #Person1# asks Ms. Dawson to take a dictation ... \n", "1 #Person1# asks Ms. Dawson to take a dictation ... \n", "2 #Person1# asks Ms. Dawson to take a dictation ... \n", "3 #Person2# got stuck in traffic and got stuck i... \n", "4 #Person2# got stuck in traffic and got stuck i... \n", "5 #Person2# got stuck in traffic and got stuck i... \n", "6 Kate tells #Person2# Masha and Hero are gettin... \n", "7 Kate tells #Person2# Masha and Hero are gettin... \n", "8 Kate tells #Person2# Masha and Hero are gettin... \n", "9 Brian remembers his birthday and invites #Pers... 
" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# import pandas as pd\n", " \n", "# zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries, peft_model_summaries))\n", " \n", "# df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries', 'peft_model_summaries'])\n", "# df.to_csv(\"diag-summary-training-results-with-peft.csv\")\n", "# df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Compute ROUGE score for subset of data" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "tags": [] }, "outputs": [], "source": [ "rouge = evaluate.load('rouge')" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.2641800976800977,\n", " 'rouge2': 0.10762431771127423,\n", " 'rougeL': 0.22565120065120065,\n", " 'rougeLsum': 0.2302477614977615}" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_model_results = rouge.compute(\n", " predictions=original_model_summaries,\n", " references=human_baseline_summaries[0:len(original_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")\n", "original_model_results" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.4015906463624618,\n", " 'rouge2': 0.17568542724181807,\n", " 'rougeL': 0.2874569966059625,\n", " 'rougeLsum': 0.2886327613084294}" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "instruct_model_results = rouge.compute(\n", " predictions=instruct_model_summaries,\n", " references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")\n", "instruct_model_results" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.3672148825926539,\n", " 'rouge2': 0.13054070743902557,\n", " 'rougeL': 0.2742648318851858,\n", " 'rougeLsum': 0.2740649551203669}" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "peft_model_results = rouge.compute(\n", " predictions=peft_model_summaries,\n", " references=human_baseline_summaries[0:len(peft_model_summaries)],\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")\n", "peft_model_results" ] }, { "cell_type": "code", "execution_count": 86, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.2334158581572823,\n", " 'rouge2': 0.07603964187010573,\n", " 'rougeL': 0.20145520923859048,\n", " 'rougeLsum': 0.20145899339006135}" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flan_t5_base" ] }, { "cell_type": "code", "execution_count": 90, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.42161291557556113,\n", " 'rouge2': 0.18035380596301792,\n", " 'rougeL': 0.3384439349963909,\n", " 'rougeLsum': 0.33835653595561666}" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flan_t5_base_instruct_full" ] }, { "cell_type": "code", "execution_count": 88, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.40810631575616746,\n", " 'rouge2': 0.1633255794568712,\n", " 'rougeL': 0.32507074586565354,\n", " 
'rougeLsum': 0.3248950182867091}" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flan_t5_base_instruct_peft" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the Full Dataset\n", "\n", "The file \"diag-summary-training-results-with-peft.csv\" contains a pre-populated list of results from all three models, which we can use to evaluate on a larger section of data. The results show substantial improvement in all ROUGE metrics!" ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "results = pd.read_csv(\"diag-summary-training-results-with-peft.csv\")\n", "human_baseline_summaries = results['human_baseline_summaries'].values\n", "original_model_summaries = results['original_model_summaries'].values\n", "instruct_model_summaries = results['instruct_model_summaries'].values\n", "peft_model_summaries = results['peft_model_summaries'].values" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.2334158581572823,\n", " 'rouge2': 0.07603964187010573,\n", " 'rougeL': 0.20145520923859048,\n", " 'rougeLsum': 0.20145899339006135}" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_model_results = rouge.compute(\n", "    predictions=original_model_summaries,\n", "    references=human_baseline_summaries[0:len(original_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")\n", "original_model_results" ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.42161291557556113,\n", " 'rouge2': 0.18035380596301792,\n", " 'rougeL': 0.3384439349963909,\n", " 'rougeLsum': 0.33835653595561666}" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "instruct_model_results = rouge.compute(\n", "    predictions=instruct_model_summaries,\n", "    references=human_baseline_summaries[0:len(instruct_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")\n", "instruct_model_results" ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.40810631575616746,\n", " 'rouge2': 0.1633255794568712,\n", " 'rougeL': 0.32507074586565354,\n", " 'rougeLsum': 0.3248950182867091}" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "peft_model_results = rouge.compute(\n", "    predictions=peft_model_summaries,\n", "    references=human_baseline_summaries[0:len(peft_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")\n", "peft_model_results" ] },
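{ "cell_type": "markdown", "metadata": {}, "source": [ "To see the trade-off at a glance, here is a small sketch that puts the three aggregated ROUGE results side by side in a single `pandas` table:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Sketch: tabulate the aggregated ROUGE results of all three models.\n", "import pandas as pd\n", "\n", "comparison = pd.DataFrame(\n", "    [original_model_results, instruct_model_results, peft_model_results],\n", "    index=['original (zero shot)', 'instruct (full fine-tune)', 'peft (LoRA adapter)'],\n", ")\n", "print(comparison)" ] },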
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Calculate improvement of PEFT over original" ] }, { "cell_type": "code", "execution_count": 80, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rouge1 absolute percentage improvement of peft model over original model: 17.47%\n", "rouge2 absolute percentage improvement of peft model over original model: 8.73%\n", "rougeL absolute percentage improvement of peft model over original model: 12.36%\n", "rougeLsum absolute percentage improvement of peft model over original model: 12.34%\n" ] } ], "source": [ "improvement = (np.array(list(peft_model_results.values())) - np.array(list(original_model_results.values())))\n", "for key, value in zip(peft_model_results.keys(), improvement):\n", "    print(f'{key} absolute percentage improvement of peft model over original model: {value*100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Calculate improvement of PEFT over Instruct" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rouge1 absolute percentage improvement of peft model over instruct model: -1.35%\n", "rouge2 absolute percentage improvement of peft model over instruct model: -1.70%\n", "rougeL absolute percentage improvement of peft model over instruct model: -1.34%\n", "rougeLsum absolute percentage improvement of peft model over instruct model: -1.35%\n" ] } ], "source": [ "improvement = (np.array(list(peft_model_results.values())) - np.array(list(instruct_model_results.values())))\n", "for key, value in zip(peft_model_results.keys(), improvement):\n", "    print(f'{key} absolute percentage improvement of peft model over instruct model: {value*100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Release Resources" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%html\n", "\n", "

<p><b>Shutting down your kernel for this notebook to release resources.</b></p>

\n", "\n", " \n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 19, 
"_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": true, "memoryGiB": 0, "name": "ml.geospatial.interactive", "supportedImageNames": [ "sagemaker-geospatial-v1-0" ], "vcpuNum": 0 }, { "_defaultOrder": 21, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 28, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 29, "_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g4dn.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 38, "_isFastLaunch": false, 
"category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 54, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.g5.48xlarge", "vcpuNum": 192 } ], "colab": { "name": "Fine-tune a language model", "provenance": [] }, "instance_type": "ml.g5.xlarge", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", 
"version": "3.7.10" }, "vscode": { "interpreter": { "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" } } }, "nbformat": 4, "nbformat_minor": 4 }