{ "cells": [ { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Fine-Tune a Generative AI Model for Dialogue Summarization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will see how to fine tune an existing LLM from HuggingFace for enhanced dialogue summarization. We will be using the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model as it provides a high quality instruction tuned model at various sizes. Flan-T5 can summarize text out of the box, but in this notebook we will see how fine-tuning on a high quality dataset can improve its performance for a specific task. Specifically, we will be using the [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) dataset from HuggingFace which contains chunks of dialogue and associated summarizations of the dialogue." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Setup\n", "\n", "First up, lets make sure we install some libraries which are needed for this notebook. After the installation, we will import the necessary packages for the notebook" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.\n", "pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.\n", "sagemaker 2.144.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 6.1.0 which is incompatible.\n", "sagemaker 2.144.0 requires PyYAML==5.4.1, but you have pyyaml 6.0 which is incompatible.\n", "docker-compose 1.29.2 requires PyYAML<6,>=3.10, but you have pyyaml 6.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. 
If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", "\u001b[33mWARNING: The directory '/root/.cache/pip' or its parent directory is not owned or is not writable by the current user. The cache has been disabled. Check the permissions and owner of that directory. If executing pip with sudo, you should use sudo's -H flag.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install transformers==4.27.2 --quiet\n", "%pip install torch==1.13.1 --quiet\n", "%pip install py7zr --quiet\n", "%pip install datasets --quiet\n", "%pip install sentencepiece --quiet\n", "%pip install evaluate --quiet\n", "%pip install rouge_score --quiet" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [], "source": [ "from transformers import AutoTokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, GenerationConfig\n", "from datasets import load_dataset\n", "import datasets\n", "import torch\n", "import time\n", "import evaluate\n", "import numpy as np\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load Flan-T5 Model\n", "\n", "We can load the pre-trained Flan-T5 model directly from HuggingFace. Notice that we will be using the [base version](https://huggingface.co/google/flan-t5-base) of flan. This model version has ~247 million model parameters which makes it small compared to other LLMs. For higher quality results, we recommend looking into the larger versions of this model." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c671e91bb5ff42e78021315f86d5520a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)okenizer_config.json: 0%| | 0.00/2.54k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "97babcc3a18d48468a8a064f89ef949d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "44405f2d472b492a8a5e533e80f844db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)/main/tokenizer.json: 0%| | 0.00/2.42M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b8e7da5661094d54906618dc00ae3311", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/2.20k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "62d4cccb4c8243e597bb45a3ecb76237", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)lve/main/config.json: 0%| | 0.00/1.40k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cd0819cd2c6f4a4ba23972d969c761d0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading pytorch_model.bin: 0%| | 0.00/990M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aea7f5553f4c4912b4d989c586f8b77b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)neration_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", "model = 
T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-base\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total Number of Model Parameters: 247577856\n" ] } ], "source": [ "params = sum(p.numel() for p in model.parameters())\n", "print(f'Total Number of Model Parameters: {params}')" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Load Dataset\n", "\n", "The DialogSum dataset can also be loaded directly from HuggingFace. There are ~15k examples of dialogue in this dataset, each with an associated human-written summary." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d8155b5aa4f64fbeae82fe57508be3f3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/4.58k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset csv/knkarthick--dialogsum to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "66474d04ea5d4f6abbc9467fa7fe18cf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "66079d80a74b4c07969ab9abe4fcb473", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/11.3M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0262c5259f7d4e1c802355700e8669dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/1.35M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79ac8bd008c14f5c8acf9539bda6c2b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/442k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "45856718f7d94249801c8f290c161cb6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Extracting data files: 0%| | 0/3 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating test split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating validation split: 0 examples [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "152b145cafb74038b1893486c3c745ed", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = load_dataset(\"knkarthick/dialogsum\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 12460\n", " })\n", " test: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 1500\n", " })\n", " validation: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 500\n", " })\n", "})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Test the Model with Zero-Shot Prompts Before Tuning\n", "\n", "In the example below, we highlight how the summarization capability of the model is lacking compared to the baseline summary provided in the dataset. You can see that the model struggles to summarize the dialogue compared to the baseline summary, but it does pull out some important information from the text which indicates the model can be fine tuned to the task at hand." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "--------------------------\n", "Summarize the following conversation.\n", "\n", "#Person1#: What time is it, Tom?\n", "#Person2#: Just a minute. It's ten to nine by my watch.\n", "#Person1#: Is it? I had no idea it was so late. I must be off now.\n", "#Person2#: What's the hurry?\n", "#Person1#: I must catch the nine-thirty train.\n", "#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.\n", "\n", "Summary:\n", "--------------------------\n", "\n", "Original Response: The train is about to leave.\n", "Baseline Summary : #Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.\n" ] } ], "source": [ "ind = 40\n", "diag = dataset['test'][ind]['dialogue']\n", "summary = dataset['test'][ind]['summary']\n", "\n", "prompt = f'Summarize the following conversation.\\n\\n{diag}\\n\\nSummary:'\n", "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "\n", "original_outputs = model.to('cpu').generate(input_ids, GenerationConfig(max_new_tokens=200))\n", "original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)\n", "\n", "print(f'Prompt:\\n--------------------------\\n{prompt}\\n--------------------------')\n", "print(f'\\nOriginal Response: {original_text_output}')\n", "print(f'Baseline Summary : {summary}')" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Preprocessing\n", "\n", "To preprocess the dataset, we need to append a useful prompt to the start and end of each dialogue set then tokenize the words with HuggingFace. 
, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Preprocessing\n", "\n", "To preprocess the dataset, we need to wrap each dialogue in a prompt (an instruction at the start and a \"Summary:\" cue at the end) and then tokenize the text with the HuggingFace tokenizer. The output dataset will be ready for fine-tuning in the next step." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/12460 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/1500 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/500 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def tokenize_function(example):\n", "    prompt = 'Summarize the following conversation.\\n\\n'\n", "    end_prompt = '\\n\\nSummary: '\n", "    inp = [prompt + i + end_prompt for i in example[\"dialogue\"]]\n", "    example['input_ids'] = tokenizer(inp, padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n", "    example['labels'] = tokenizer(example[\"summary\"], padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n", "    return example\n", "\n", "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['input_ids', 'labels'],\n", "        num_rows: 12460\n", "    })\n", "    test: Dataset({\n", "        features: ['input_ids', 'labels'],\n", "        num_rows: 1500\n", "    })\n", "    validation: Dataset({\n", "        features: ['input_ids', 'labels'],\n", "        num_rows: 500\n", "    })\n", "})" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized_datasets" ] }
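, { "cell_type": "markdown", "metadata": {}, "source": [ "One caveat worth knowing about (our note; the provided checkpoint was trained without this refinement): because we pad the labels to `max_length`, pad tokens are currently included in the training loss. A common improvement is to replace pad token ids in the labels with `-100`, which the cross-entropy loss ignores. A minimal sketch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# optional refinement: mask label padding so it is ignored by the loss\n", "# (left commented out to stay consistent with the provided checkpoint)\n", "def mask_label_padding(example):\n", "    example['labels'] = [\n", "        [(token if token != tokenizer.pad_token_id else -100) for token in label]\n", "        for label in example['labels']\n", "    ]\n", "    return example\n", "\n", "# tokenized_datasets = tokenized_datasets.map(mask_label_padding, batched=True)" ] }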
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# for the sake of time in the lab, we will subsample our dataset\n", "# if you want to take the time to train a model fully, feel free to alter this subsampling to create a larger dataset\n", "# the line below can be completely removed to remove the subsampling\n", "tokenized_datasets = tokenized_datasets.filter(lambda example, indice: indice % 100 == 0, with_indices=True)\n", "\n", "output_dir = f'./diag-summary-training-{str(int(time.time()))}'\n", "training_args = TrainingArguments(\n", " output_dir=output_dir,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=1e-5,\n", " num_train_epochs=1,\n", " # num_train_epochs=10, # Use a higher number of epochs when you are not in the lab and have more time to experiment\n", " per_device_train_batch_size=4,\n", " per_device_eval_batch_size=4,\n", " weight_decay=0.01,\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_datasets['train'],\n", " eval_dataset=tokenized_datasets['validation']\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " FutureWarning,\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='2' max='32' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [ 2/32 : < :, Epoch 0.03/1]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "OutOfMemoryError", "evalue": "CUDA out of memory. Tried to allocate 252.00 MiB (GPU 0; 14.76 GiB total capacity; 13.16 GiB already allocated; 213.75 MiB free; 13.78 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m<ipython-input-19-3435b262f1ae>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1635\u001b[0m \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1636\u001b[0m \u001b[0mtrial\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrial\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1637\u001b[0;31m \u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1638\u001b[0m )\n\u001b[1;32m 1639\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1900\u001b[0m \u001b[0mtr_loss_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1901\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1902\u001b[0;31m \u001b[0mtr_loss_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1903\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1904\u001b[0m if (\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2643\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2644\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss_context_manager\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2645\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2646\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2647\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_gpu\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2675\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2676\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2677\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2678\u001b[0m \u001b[0;31m# Save past state if it exists\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2679\u001b[0m \u001b[0;31m# TODO: this needs to be fixed and made cleaner later.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1736\u001b[0m \u001b[0mloss_fct\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mignore_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1737\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mloss_fct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlm_logits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlm_logits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1738\u001b[0m \u001b[0;31m# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1739\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 1174\u001b[0m return F.cross_entropy(input, target, weight=self.weight,\n\u001b[1;32m 1175\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1176\u001b[0;31m label_smoothing=self.label_smoothing)\n\u001b[0m\u001b[1;32m 1177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[1;32m 3024\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3025\u001b[0m 
\u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3026\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_smoothing\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3027\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3028\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 252.00 MiB (GPU 0; 14.76 GiB total capacity; 13.16 GiB already allocated; 213.75 MiB free; 13.78 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" ] } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the Trained Model and Original Model\n", "\n", "Once the model has finished training, we will load both the original model from HuggingFace and the fine-tuned model to do some qualitative and quantitative comparisions." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/config.json to flan-dialogue-summary-checkpoint/config.json\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/generation_config.json to flan-dialogue-summary-checkpoint/generation_config.json\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/trainer_state.json to flan-dialogue-summary-checkpoint/trainer_state.json\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/rng_state.pth to flan-dialogue-summary-checkpoint/rng_state.pth\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/training_args.bin to flan-dialogue-summary-checkpoint/training_args.bin\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/scheduler.pt to flan-dialogue-summary-checkpoint/scheduler.pt\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/pytorch_model.bin to flan-dialogue-summary-checkpoint/pytorch_model.bin\n", "download: s3://dsoaws/models/flan-dialogue-summary-checkpoint/optimizer.pt to flan-dialogue-summary-checkpoint/optimizer.pt\n" ] } ], "source": [ "!aws s3 cp --recursive s3://dsoaws/models/flan-dialogue-summary-checkpoint/ ./flan-dialogue-summary-checkpoint/" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "tags": [] }, "outputs": [], "source": [ "# if you have trained your own model and want to check it out compared to ours, change the line of code\n", "# below to contain your checkpoint directory\n", "\n", "tuned_model = T5ForConditionalGeneration.from_pretrained(\"./flan-dialogue-summary-checkpoint\")\n", "#tuned_model = T5ForConditionalGeneration.from_pretrained(f\"./{output_dir}/<put-your-checkpoint-dir-here>\")\n", "\n", "model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-small\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Qualitative Results\n", "\n", "As with many GenAI applications, a qualitative approach where you ask yourself the question \"is my model behaving the way it is supposed to?\" is usually a good starting point. In the example below (the same one we started this notebook with), you can see how the fine-tuned model is able to create a reasonable summary of the dialogue compared to the original inability to understand what is being asked of the model." ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "--------------------------\n", "Summarize the following conversation.\n", "\n", "#Person1#: What time is it, Tom?\n", "#Person2#: Just a minute. It's ten to nine by my watch.\n", "#Person1#: Is it? I had no idea it was so late. I must be off now.\n", "#Person2#: What's the hurry?\n", "#Person1#: I must catch the nine-thirty train.\n", "#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.\n", "\n", "Summary:\n", "--------------------------\n", "Flan-T5 response: How long is it?\n", "Our instruct-tuned response (on top of Flan-T5): Tom tells #Person1# it's ten to nine and #Person1# must catch the nine-thirty train. #Person1# has plenty of time to catch the train.\n", "Baseline summary from original dataset: #Person1# is in a hurry to catch a train. 
, { "cell_type": "markdown", "metadata": {}, "source": [ "# Qualitative Results\n", "\n", "As with many GenAI applications, a qualitative approach where you ask yourself the question \"is my model behaving the way it is supposed to?\" is usually a good starting point. In the example below (the same one we started this notebook with), you can see how the fine-tuned model is able to create a reasonable summary of the dialogue, compared with the original model's inability to understand what is being asked of it." ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "--------------------------\n", "Summarize the following conversation.\n", "\n", "#Person1#: What time is it, Tom?\n", "#Person2#: Just a minute. It's ten to nine by my watch.\n", "#Person1#: Is it? I had no idea it was so late. I must be off now.\n", "#Person2#: What's the hurry?\n", "#Person1#: I must catch the nine-thirty train.\n", "#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.\n", "\n", "Summary:\n", "--------------------------\n", "Flan-T5 response: How long is it?\n", "Our instruct-tuned response (on top of Flan-T5): Tom tells #Person1# it's ten to nine and #Person1# must catch the nine-thirty train. #Person1# has plenty of time to catch the train.\n", "Baseline summary from original dataset: #Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.\n" ] } ], "source": [ "ind = 40\n", "diag = dataset['test'][ind]['dialogue']\n", "summary = dataset['test'][ind]['summary']\n", "\n", "prompt = f'Summarize the following conversation.\\n\\n{diag}\\n\\nSummary:'\n", "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "\n", "original_outputs = model.to('cpu').generate(\n", "    input_ids,\n", "    GenerationConfig(max_new_tokens=200, num_beams=1),\n", ")\n", "original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)\n", "\n", "outputs = tuned_model.to('cpu').generate(\n", "    input_ids,\n", "    GenerationConfig(max_new_tokens=200, num_beams=1),\n", ")\n", "text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", "print(f'Prompt:\\n--------------------------\\n{prompt}\\n--------------------------')\n", "print(f'Flan-T5 response: {original_text_output}')\n", "print(f'Our instruct-tuned response (on top of Flan-T5): {text_output}')\n", "print(f'Baseline summary from original dataset: {summary}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Quantitative Results with ROUGE Metric\n", "\n", "The [ROUGE metric](https://en.wikipedia.org/wiki/ROUGE_(metric)) helps quantify the validity of summarizations produced by models. It compares model summarizations to a \"baseline\" summary, which is usually created by a human. While not perfect, it does give an indication of the overall increase in summarization effectiveness that we have accomplished by fine-tuning." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "tags": [] }, "outputs": [], "source": [ "rouge = evaluate.load('rouge')" ] }
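, { "cell_type": "markdown", "metadata": {}, "source": [ "To build some intuition for the metric before using it, here is a small, hypothetical example (our addition): a candidate summary is scored against a reference. Overlapping unigrams drive `rouge1`, overlapping bigrams drive `rouge2`, and the longest common subsequence drives `rougeL`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# toy example: high unigram overlap with the reference, lower bigram overlap\n", "toy_scores = rouge.compute(\n", "    predictions=['Tom says there is plenty of time to catch the train.'],\n", "    references=['#Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.'],\n", ")\n", "print(toy_scores)" ] }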
, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate a Subsection of Summaries" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "tags": [] }, "outputs": [], "source": [ "# again, for the sake of time, we will only be generating 10 summarizations with each model\n", "# outside of the lab, a good exercise is to increase the number of validation summaries generated\n", "dialogues = dataset['test'][0:10]['dialogue']\n", "human_baseline_summaries = dataset['test'][0:10]['summary']\n", "\n", "original_model_summaries = []\n", "tuned_model_summaries = []\n", "\n", "for ind, diag in enumerate(dialogues):\n", "    prompt = f'Summarize the following conversation.\\n\\nConversation:\\n{diag}\\n\\nSummary:'\n", "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "\n", "    original_outputs = model.generate(input_ids, GenerationConfig(max_new_tokens=200))\n", "    original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)\n", "\n", "    outputs = tuned_model.generate(input_ids, GenerationConfig(max_new_tokens=200))\n", "    text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", "    original_model_summaries.append(original_text_output)\n", "    tuned_model_summaries.append(text_output)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "tags": [] }, "outputs": [], "source": [ "original_results = rouge.compute(\n", "    predictions=original_model_summaries,\n", "    references=human_baseline_summaries[0:len(original_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "tags": [] }, "outputs": [], "source": [ "tuned_results = rouge.compute(\n", "    predictions=tuned_model_summaries,\n", "    references=human_baseline_summaries[0:len(tuned_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.07518181818181817,\n", " 'rouge2': 0.033229813664596264,\n", " 'rougeL': 0.06836363636363635,\n", " 'rougeLsum': 0.06836363636363635}" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_results" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.4015906463624618,\n", " 'rouge2': 0.17568542724181807,\n", " 'rougeL': 0.2874569966059625,\n", " 'rougeLsum': 0.2886327613084294}" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the Full Dataset\n", "\n", "The file \"diag-summary-training-results.csv\" contains a pre-populated list of model results, which we can use to evaluate on a larger section of data. The results show substantial improvement in all ROUGE metrics!" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "results = pd.read_csv(\"diag-summary-training-results.csv\")\n", "original_model_summaries = results['original_model_summaries'].values\n", "tuned_model_summaries = results['tuned_model_summaries'].values\n", "human_baseline_summaries = results['human_baseline_summaries'].values" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "tags": [] }, "outputs": [], "source": [ "original_results = rouge.compute(\n", "    predictions=original_model_summaries,\n", "    references=human_baseline_summaries[0:len(original_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "tags": [] }, "outputs": [], "source": [ "tuned_results = rouge.compute(\n", "    predictions=tuned_model_summaries,\n", "    references=human_baseline_summaries[0:len(tuned_model_summaries)],\n", "    use_aggregator=True,\n", "    use_stemmer=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.2334158581572823,\n", " 'rouge2': 0.07603964187010573,\n", " 'rougeL': 0.20145520923859048,\n", " 'rougeLsum': 0.20145899339006135}" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_results" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'rouge1': 0.42161291557556113,\n", " 'rouge2': 0.18035380596301792,\n", " 'rougeL': 0.3384439349963909,\n", " 'rougeLsum': 0.33835653595561666}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_results" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rouge1 absolute percentage improvement after instruct fine-tuning: 34.64%\n", "rouge2 absolute percentage improvement after instruct fine-tuning: 14.71%\n", "rougeL absolute percentage improvement after instruct fine-tuning: 27.01%\n", "rougeLsum absolute percentage improvement after instruct fine-tuning: 27.00%\n" ] } ], "source": [ "improvement = (np.array(list(tuned_results.values())) - np.array(list(original_results.values())))\n", "for key, value in 
zip(tuned_results.keys(), improvement):\n", " print(f'{key} absolute percentage improvement after instruct fine-tuning: {value*100:.2f}%')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 46, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/huggingface/peft.git\n", " Cloning https://github.com/huggingface/peft.git to /tmp/pip-req-build-zpmb2y2h\n", " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/peft.git /tmp/pip-req-build-zpmb2y2h\n", " Resolved https://github.com/huggingface/peft.git to commit 202f1d7c3d3f8dc468a989f9fb757e319be73d6f\n", " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (1.13.1)\n", "Collecting accelerate\n", " Using cached accelerate-0.18.0-py3-none-any.whl (215 kB)\n", "Requirement already satisfied: transformers in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (4.27.2)\n", "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (6.0)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (23.0)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (5.6.7)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.7/site-packages (from peft==0.3.0.dev0) (1.21.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.7.99)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (4.5.0)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.7.99)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (8.5.0.96)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch>=1.13.0->peft==0.3.0.dev0) (11.10.3.66)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->peft==0.3.0.dev0) (0.40.0)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->peft==0.3.0.dev0) (59.3.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (2022.10.31)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (0.13.3)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (0.13.4)\n", "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (4.65.0)\n", "Requirement already satisfied: requests in 
/opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (2.28.2)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (3.0.12)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from transformers->peft==0.3.0.dev0) (6.1.0)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->transformers->peft==0.3.0.dev0) (3.15.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (1.26.15)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers->peft==0.3.0.dev0) (2022.12.7)\n", "Building wheels for collected packages: peft\n", " Building wheel for peft (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for peft: filename=peft-0.3.0.dev0-py3-none-any.whl size=49840 sha256=7ef844fb1349f407fa5051d183ab73a902d81690480b9ea8788ac08cc0209166\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-9ayt0sc8/wheels/13/73/6d/ff27a3703d8bad21d7e0c24cbd9dde5d7ae78f756405707a0c\n", "Successfully built peft\n", "Installing collected packages: accelerate, peft\n", "Successfully installed accelerate-0.18.0 peft-0.3.0.dev0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install git+https://github.com/huggingface/peft.git" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers==4.27.2 in /opt/conda/lib/python3.7/site-packages (4.27.2)\n", "Collecting datasets==2.9.0\n", " Using cached datasets-2.9.0-py3-none-any.whl (462 kB)\n", "Collecting accelerate==0.17.1\n", " Using cached accelerate-0.17.1-py3-none-any.whl (212 kB)\n", "Requirement already satisfied: evaluate==0.4.0 in /opt/conda/lib/python3.7/site-packages (0.4.0)\n", "Collecting bitsandbytes==0.37.1\n", " Using cached bitsandbytes-0.37.1-py3-none-any.whl (76.3 MB)\n", "Collecting loralib\n", " Using cached loralib-0.1.1-py3-none-any.whl (8.8 kB)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (3.0.12)\n", "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (4.65.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (2022.10.31)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (23.0)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (6.1.0)\n", "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (2.28.2)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (0.13.3)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (1.21.6)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (6.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.7/site-packages (from transformers==4.27.2) (0.13.4)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (11.0.0)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (2023.1.0)\n", "Requirement already satisfied: dill<0.3.7 in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (0.3.6)\n", "Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (0.18.0)\n", "Requirement already satisfied: xxhash in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (3.2.0)\n", "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (0.70.14)\n", "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (3.8.4)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from datasets==2.9.0) (1.3.5)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.7/site-packages (from accelerate==0.17.1) (5.6.7)\n", "Requirement already satisfied: torch>=1.4.0 in /opt/conda/lib/python3.7/site-packages (from accelerate==0.17.1) (1.13.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4 in 
/opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (4.5.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (1.8.2)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (4.0.2)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (6.0.4)\n", "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (22.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (1.3.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (1.3.1)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (2.0.4)\n", "Requirement already satisfied: asynctest==0.13.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets==2.9.0) (0.13.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->transformers==4.27.2) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers==4.27.2) (2022.12.7)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers==4.27.2) (1.26.15)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->accelerate==0.17.1) (11.7.99)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->accelerate==0.17.1) (11.10.3.66)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->accelerate==0.17.1) (11.7.99)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->accelerate==0.17.1) (8.5.0.96)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->accelerate==0.17.1) (0.40.0)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->accelerate==0.17.1) (59.3.0)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->transformers==4.27.2) (3.15.0)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets==2.9.0) (2.8.2)\n", "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets==2.9.0) (2019.3)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->datasets==2.9.0) (1.14.0)\n", "Installing collected packages: bitsandbytes, loralib, datasets, accelerate\n", " Attempting uninstall: datasets\n", " Found existing installation: datasets 2.11.0\n", " Uninstalling datasets-2.11.0:\n", " Successfully uninstalled datasets-2.11.0\n", " Attempting uninstall: accelerate\n", " Found existing installation: accelerate 0.18.0\n", " Uninstalling accelerate-0.18.0:\n", " 
Successfully uninstalled accelerate-0.18.0\n", "Successfully installed accelerate-0.17.1 bitsandbytes-0.37.1 datasets-2.9.0 loralib-0.1.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install \"transformers==4.27.2\" \"datasets==2.9.0\" \"accelerate==0.17.1\" \"evaluate==0.4.0\" \"bitsandbytes==0.37.1\" loralib" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "91ad81d364be4aa0a0aedb414322447a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import AutoModelForSeq2SeqLM\n", "\n", "# huggingface hub model id\n", "model_id = \"philschmid/flan-t5-xxl-sharded-fp16\"\n", "\n", "# load the model from the hub in 8-bit: the int-8 training preparation in the\n", "# next cell assumes a quantized model (this is what bitsandbytes is installed for)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_id,\n", "                                              load_in_8bit=True,\n", "                                              device_map=\"auto\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 18874368 || all params: 11154206720 || trainable%: 0.16921300163961817\n" ] } ], "source": [ "from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType\n", "\n", "# Define LoRA config: rank-16 adapters on the attention query and value projections\n", "lora_config = LoraConfig(\n", "    r=16,\n", "    lora_alpha=32,\n", "    target_modules=[\"q\", \"v\"],\n", "    lora_dropout=0.05,\n", "    bias=\"none\",\n", "    task_type=TaskType.SEQ_2_SEQ_LM\n", ")\n", "# prepare the int-8 model for training (freezes base weights, casts norms to fp32)\n", "model = prepare_model_for_int8_training(model)\n", "\n", "# add LoRA adapter\n", "model = get_peft_model(model, lora_config)\n", "\n", "# Sanity check on the printed count: Flan-T5-XXL (d_model=4096) has 24 encoder\n", "# blocks with self-attention and 24 decoder blocks with self- plus cross-attention,\n", "# i.e. 72 attention modules. Targeting q and v gives 144 matrices, each with LoRA\n", "# factors A (16 x 4096) and B (4096 x 16): 144 * 2 * 16 * 4096 = 18,874,368 (~0.17%)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, DataCollatorForSeq2Seq\n", "\n", "# the tokenizer is not loaded anywhere else in this notebook, so load the one\n", "# matching the base model here\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "\n", "# we want to ignore the tokenizer pad token in the loss\n", "label_pad_token_id = -100\n", "# Data collator: dynamically pads inputs and labels per batch\n", "# (padding to a multiple of 8 is friendlier to tensor cores)\n", "data_collator = DataCollatorForSeq2Seq(\n", "    tokenizer,\n", "    model=model,\n", "    label_pad_token_id=label_pad_token_id,\n", "    pad_to_multiple_of=8\n", ")" ] },
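{ "cell_type": "markdown", "metadata": {}, "source": [ "The collator only pads features that have already been tokenized, and this notebook does not show that preprocessing step. The next cell is a minimal sketch of what it could look like: the `summarize: ` prompt prefix and the `max_length` caps are illustrative assumptions, not values taken from this notebook, while the `id`, `dialogue`, `summary`, and `topic` fields come from the DialogSum dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "# load the DialogSum dataset from the HuggingFace hub\n", "dataset = load_dataset(\"knkarthick/dialogsum\")\n", "\n", "# illustrative instruction prefix; Flan-T5 is instruction tuned\n", "prefix = \"summarize: \"\n", "\n", "def preprocess(sample):\n", "    # tokenize the dialogue as the model input (the length cap is an assumption)\n", "    model_inputs = tokenizer(prefix + sample[\"dialogue\"], max_length=512, truncation=True)\n", "    # tokenize the reference summary as the labels\n", "    labels = tokenizer(text_target=sample[\"summary\"], max_length=128, truncation=True)\n", "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n", "    return model_inputs\n", "\n", "# no padding here; the data collator pads each batch dynamically\n", "tokenized_dataset = dataset.map(preprocess, remove_columns=[\"id\", \"dialogue\", \"summary\", \"topic\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The collator and the tokenized dataset would then be handed to a trainer. The cell below is again only a sketch: the hyperparameters and output directory are placeholders, not values from this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", "    output_dir=\"lora-flan-t5-xxl-dialogsum\",  # placeholder output directory\n", "    per_device_train_batch_size=8,            # assumed; tune to GPU memory\n", "    learning_rate=1e-3,                       # LoRA tolerates higher LRs than full fine-tuning\n", "    num_train_epochs=3,                       # assumed\n", "    logging_steps=100,\n", ")\n", "\n", "trainer = Seq2SeqTrainer(\n", "    model=model,\n", "    args=training_args,\n", "    data_collator=data_collator,\n", "    train_dataset=tokenized_dataset[\"train\"],\n", ")\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Release Resources" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n", "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n", "        \n", "<script>\n", "try {\n", "    els = document.getElementsByClassName(\"sm-command-button\");\n", "    els[0].click();\n", "}\n", "catch(err) {\n", "    // NoOp\n", "}    \n", "</script>\n" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%html\n", "\n", "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n", "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n", "        \n", "<script>\n", "try {\n", "    els = document.getElementsByClassName(\"sm-command-button\");\n", "    els[0].click();\n", "}\n",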
"catch(err) {\n", " // NoOp\n", "} \n", "</script>" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 
}, { "_defaultOrder": 19, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": true, "memoryGiB": 0, "name": "ml.geospatial.interactive", "supportedImageNames": [ "sagemaker-geospatial-v1-0" ], "vcpuNum": 0 }, { "_defaultOrder": 21, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 28, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 29, "_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g4dn.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 
38, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 54, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.g5.48xlarge", "vcpuNum": 192 } ], "colab": { "name": "Fine-tune a language model", "provenance": [] }, "instance_type": "ml.m5.2xlarge", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.7.10" }, "vscode": { "interpreter": { "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" } } }, "nbformat": 4, "nbformat_minor": 4 }