{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "beae7bb1-f4e8-4e9a-80ce-c22f16124638", "metadata": { "tags": [] }, "outputs": [], "source": [ "# torchdata 0.5.1 is the release paired with torch 1.13.1; pin both so a re-run resolves the same versions.\n", "%pip install torch==1.13.1 torchdata==0.5.1" ] }, { "cell_type": "code", "execution_count": 1, "id": "8edb5aba-15bc-446d-874f-8795a436be6a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install --disable-pip-version-check -q \\\n", " transformers==4.27.2 \\\n", " datasets==2.9.0 \\\n", " accelerate==0.17.0 \\\n", " evaluate==0.4.0 \\\n", " trl==0.4.1 \\\n", " rouge_score==0.1.2 \\\n", " loralib==0.1.1" ] }, { "cell_type": "code", "execution_count": 2, "id": "039fb142-c9a5-4ca6-aaf7-f8c4c65743b3", "metadata": { "tags": [] }, "outputs": [], "source": [ "#!pip install git+https://github.com/huggingface/peft.git" ] }, { "cell_type": "code", "execution_count": 3, "id": "442332d6-8048-4e38-aa97-d0ac35cb0f3e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/lvwerra/trl.git\n", " Cloning https://github.com/lvwerra/trl.git to /tmp/pip-req-build-cpoujzci\n", " Running command git clone --filter=blob:none --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-cpoujzci\n", " Resolved https://github.com/lvwerra/trl.git to commit ce37eadcfa22f2a3c25422411a586b8f593e3e6e\n", " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: torch>=1.4.0 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (1.13.1)\n", "Requirement already satisfied: transformers>=4.18.0 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (4.27.2)\n", "Requirement already satisfied: numpy>=1.18.2 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (1.21.6)\n", "Requirement already satisfied: accelerate in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (0.17.0)\n", "Requirement already satisfied: datasets in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (2.9.0)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (4.5.0)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.7.99)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.10.3.66)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.7.99)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (8.5.0.96)\n", "Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->trl==0.4.2.dev0) (0.40.0)\n", "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->trl==0.4.2.dev0) (59.3.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (0.13.4)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages 
(from transformers>=4.18.0->trl==0.4.2.dev0) (2022.10.31)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (6.3.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (6.0)\n", "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (4.65.0)\n", "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (2.28.2)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (23.1)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (3.0.12)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (0.13.3)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.7/site-packages (from accelerate->trl==0.4.2.dev0) (5.6.7)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (11.0.0)\n", "Requirement already satisfied: xxhash in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (3.2.0)\n", "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (3.8.4)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (1.3.5)\n", "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (0.70.14)\n", "Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) 
(0.18.0)\n", "Requirement already satisfied: dill<0.3.7 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (0.3.6)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (2023.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (6.0.4)\n", "Requirement already satisfied: asynctest==0.13.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (0.13.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.3.3)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.8.2)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.3.1)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (2.0.4)\n", "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (22.2.0)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (4.0.2)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (1.26.15)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (2022.12.7)\n", "Requirement already satisfied: zipp>=0.5 in 
/opt/conda/lib/python3.7/site-packages (from importlib-metadata->transformers>=4.18.0->trl==0.4.2.dev0) (3.15.0)\n", "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets->trl==0.4.2.dev0) (2019.3)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets->trl==0.4.2.dev0) (2.8.2)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->datasets->trl==0.4.2.dev0) (1.14.0)\n", "Building wheels for collected packages: trl\n", " Building wheel for trl (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for trl: filename=trl-0.4.2.dev0-py3-none-any.whl size=54216 sha256=01482a1b7edbf841545e9be34a23a25646b053744c4d2bf27bb772a41e059d2e\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-w6prql5u/wheels/ca/6e/f4/b183ecbed483efdcd2041a8021ce7bcb9f7b09c74bff5bb00a\n", "Successfully built trl\n", "Installing collected packages: trl\n", " Attempting uninstall: trl\n", " Found existing installation: trl 0.4.1\n", " Uninstalling trl-0.4.1:\n", " Successfully uninstalled trl-0.4.1\n", "Successfully installed trl-0.4.2.dev0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "# Install TRL from source, pinned to the commit recorded in this cell's output (trl 0.4.2.dev0),\n", "# so a re-run does not silently pick up a newer default branch. %pip targets the kernel's environment.\n", "%pip install git+https://github.com/lvwerra/trl.git@ce37eadcfa22f2a3c25422411a586b8f593e3e6e" ] }, { "cell_type": "code", "execution_count": 4, "id": "37a78f85-cb66-4059-b4fc-437db14e9684", "metadata": { "tags": [] }, "outputs": [], "source": [ "%store -r ranking_reward_model_custom_checkpoint" ] }, { "cell_type": "code", "execution_count": 5, "id": "948d777a-2ed7-4d9b-933e-49cd60e576f8", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./ranking_reward_model_custom/\n" ] } ], "source": [ "print(ranking_reward_model_custom_checkpoint)" ] }, { "cell_type": "code", "execution_count": 6, "id": "107e192d-ec64-41a8-bf2f-674c4ebb816a", "metadata": { "tags": [] }, "outputs": [], "source": [ "# %store -r supervised_fine_tuned_model_path" ] }, { "cell_type": "code", "execution_count": 7, "id": "4347957b-85d3-4bdf-bba7-39512f08be55", "metadata": { "tags": [] }, "outputs": [], "source": [ "# print(supervised_fine_tuned_model_path)" ] }, { "cell_type": "markdown", "id": "4f421a5b-2720-4430-aae5-b25cbbef1a40", "metadata": {}, "source": [ "# Load dataset" ] }, { "cell_type": "code", "execution_count": 8, "id": "b696f76c-2aac-4c04-ad0f-82e8a16d3bad", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration data-summarization-46d1b2508a766ab7\n", "Found cached dataset csv 
(/root/.cache/huggingface/datasets/csv/data-summarization-46d1b2508a766ab7/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "94db57ad7a4546759c9f8ed2a7419797", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00