{ "cells": [ { "cell_type": "markdown", "id": "44a069b0", "metadata": {}, "source": [ "```\n", "Authors: Ehsan Kamalinejad (EK), Emily Webber\n", "Created: 2023-02-27\n", "```" ] }, { "cell_type": "markdown", "id": "757937bc", "metadata": {}, "source": [ "# RLHF\n", "\n", "Reinforcement learning with human feedback (RLHF) is a technique for aggregating human preferences at scale. In particular, we first train a reward model (a language model with a regression head) on human-designated rankings of candidate completions for each prompt. Then, we use that reward model as the `reward signal` to fine-tune another LLM. \n", "\n", "Here, we present a training pipeline to fine-tune a generative model to produce IMDb reviews with positive sentiment, following the [OpenAI RLHF paper](https://arxiv.org/abs/1909.08593) (please see section 3).\n", "\n", "Please note, all of the training in this notebook currently happens locally, so you'll want to use an instance with enough accelerator memory. I'm running on an `ml.g4dn.xlarge`." ] }, { "cell_type": "markdown", "id": "b798a315-b134-4607-8a38-3af59329f2e7", "metadata": {}, "source": [ "---\n", "### 1. Install requirements" ] }, { "cell_type": "code", "execution_count": 16, "id": "1f023c01-279f-483e-8964-b0525bb33200", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting requirements.txt\n" ] } ], "source": [ "%%writefile requirements.txt\n", "tqdm\n", "omegaconf\n", "dataclasses\n", "torchtyping\n", "datasets\n", "transformers\n", "torch\n", "xformers" ] }, { "cell_type": "code", "execution_count": 32, "id": "21bc78ef-fdce-4f01-90d2-536ce6dd8df2", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install -r requirements.txt" ] }, { "cell_type": "code", "execution_count": 4, "id": "806219f3", "metadata": {}, "outputs": [], "source": [ "import random\n", "import numpy as np\n", "from tqdm.notebook import tqdm\n", "from omegaconf import DictConfig\n", "from dataclasses import dataclass\n", "from typing import Optional, Tuple, Union\n", "from typing import Iterable, Sequence, List\n", "\n", "from torchtyping import TensorType\n", "\n", "import transformers\n", "from transformers import DataCollatorWithPadding\n", "from transformers import pipeline, AutoTokenizer\n", "\n", "from datasets import load_dataset\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.nn.utils.rnn import pad_sequence\n", "from torch.utils.data import DataLoader, Dataset\n", "from torch.optim.lr_scheduler import CosineAnnealingLR" ] },
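{ "cell_type": "markdown", "id": "gpu-check-md", "metadata": {}, "source": [ "Optional: before configuring anything, it can help to confirm that PyTorch can see a CUDA device, since the settings below place the policy model on `cuda:0`. This quick check can be skipped." ] }, { "cell_type": "code", "execution_count": null, "id": "gpu-check", "metadata": {}, "outputs": [], "source": [ "# Quick optional check of the torch build and CUDA visibility; the policy model below runs on cuda:0.\n", "print(torch.__version__)\n", "print(torch.cuda.is_available())" ] },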
{ "cell_type": "markdown", "id": "eaf75c75-d19d-474c-aa0c-5292c5d7b530", "metadata": {}, "source": [ "---\n", "### 2. Define model parameters" ] }, { "cell_type": "code", "execution_count": 5, "id": "f63f4904", "metadata": {}, "outputs": [], "source": [ "config = {\n", "    'train': {\n", "        'seed': 2023,\n", "        'seq_length': 1024,\n", "        'epochs': 50,\n", "        'total_steps': 5000,\n", "        'batch_size': 64,\n", "        'eval_interval': 100,\n", "        'model_device': 'cuda:0',\n", "        'ref_model_device': 'cpu',\n", "        'reward_model_device': 'cpu'},\n", "    'model': {\n", "        'model_path': 'lvwerra/gpt2-imdb', #'edbeeching/gpt-neo-1.3B-imdb',\n", "        'tokenizer_path': 'lvwerra/gpt2-imdb', #'edbeeching/gpt-neo-1.3B-imdb',\n", "        'num_layers_unfrozen': 1},\n", "    'optimizer': {\n", "        'name': 'adamw',\n", "        'kwargs': {'lr': 0.0001,\n", "            'betas': [0.9, 0.95],\n", "            'eps': 1e-08,\n", "            'weight_decay': 1e-06}},\n", "    'scheduler': {\n", "        'name': 'cosine_annealing',\n", "        'kwargs': {\n", "            'T_max': 10000, 'eta_min': 0.0001}},\n", "    'method': {\n", "        'use_whitening': True,\n", "        'prompt_size': 10,\n", "        'num_rollouts': 128,\n", "        'chunk_size': 128,\n", "        'ppo_epochs': 4,\n", "        'kl_coef': 0.05,\n", "        'horizon': 10000,\n", "        'gamma': 1,\n", "        'lam': 0.95,\n", "        'cliprange': 0.2,\n", "        'cliprange_value': 0.2,\n", "        'vf_coef': 1,\n", "        'scale_reward': False,\n", "        'ref_mean': None,\n", "        'ref_std': None,\n", "        'cliprange_reward': 10,\n", "        'gen_kwargs': {\n", "            'max_new_tokens': 60,\n", "            'top_k': 0,\n", "            'top_p': 1.0,\n", "            'do_sample': True}}}" ] }, { "cell_type": "code", "execution_count": 6, "id": "cf8d10c7", "metadata": { "scrolled": true }, "outputs": [], "source": [ "config = DictConfig(config)" ] }, { "cell_type": "code", "execution_count": 7, "id": "a90ff010", "metadata": {}, "outputs": [], "source": [ "random.seed(config.train.seed)\n", "np.random.seed(config.train.seed)\n", "torch.manual_seed(config.train.seed)\n", "torch.cuda.manual_seed(config.train.seed)" ] },
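{ "cell_type": "markdown", "id": "config-access-md", "metadata": {}, "source": [ "The seeding cell above already uses OmegaConf's attribute access (`config.train.seed`), which is how the rest of the notebook reads these settings. The optional cell below just prints a few of the values we rely on later, as a quick sanity check of the `DictConfig`." ] }, { "cell_type": "code", "execution_count": null, "id": "config-access", "metadata": {}, "outputs": [], "source": [ "# Attribute-style access into the nested config; these are the knobs used most below.\n", "print(config.model.model_path)\n", "print(config.method.kl_coef, config.method.gamma, config.method.lam)\n", "print(config.method.gen_kwargs.max_new_tokens)" ] },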
{ "cell_type": "markdown", "id": "4c99cfbc-acb7-4b82-abd9-cb9bfe56e7a6", "metadata": {}, "source": [ "---\n", "### 3. Define PyTorch objects\n", "Here, we'll show you how to implement the following PPO data objects in PyTorch:\n", "1. Prompt Pipeline\n", "2. PPO RL Element\n", "3. PPO RL Batch\n", "4. PPO Rollout Storage, including a data loader\n", "\n", "A short, optional sanity check of the prompt pipeline follows these definitions." ] }, { "cell_type": "code", "execution_count": 8, "id": "2080a70b", "metadata": {}, "outputs": [], "source": [ "class PromptPipeline():\n", "    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer):\n", "        super().__init__()\n", "\n", "        prompts = tokenizer(prompts).input_ids\n", "\n", "        self.tokenizer = tokenizer\n", "        # keep only the last max_prompt_length tokens of each prompt\n", "        self.prompts = [prompt[-max_prompt_length:] for prompt in prompts]\n", "        self.prompts = [{\"input_ids\": prompt, \"attention_mask\": [1] * len(prompt)} for prompt in self.prompts]\n", "\n", "    def __getitem__(self, ix: int):\n", "        return self.prompts[ix]\n", "\n", "    def __len__(self) -> int:\n", "        return len(self.prompts)\n", "\n", "    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:\n", "        collate_fn = DataCollatorWithPadding(self.tokenizer)\n", "        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)" ] }, { "cell_type": "code", "execution_count": 9, "id": "e497c3d8", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class PPORLElement:\n", "    query_tensor: TensorType[\"query_size\"]\n", "    response_tensor: TensorType[\"response_size\"]\n", "    logprobs: TensorType[\"response_size\"]  # log-probs of the sampled tokens only\n", "    values: TensorType[\"response_size\"]\n", "    rewards: TensorType[\"response_size\"]\n", "\n", "\n", "@dataclass\n", "class PPORLBatch:\n", "    query_tensors: TensorType[\"batch_size\", \"query_size\"]\n", "    response_tensors: TensorType[\"batch_size\", \"response_size\"]\n", "    logprobs: TensorType[\"batch_size\", \"response_size\"]\n", "    values: TensorType[\"batch_size\", \"response_size\"]\n", "    rewards: TensorType[\"batch_size\", \"response_size\"]\n", "\n", "\n", "class PPORolloutStorage():\n", "    def __init__(self, pad_token_id):\n", "        super().__init__()\n", "        self.pad_token_id = pad_token_id\n", "        self.history: Iterable[PPORLElement] = []\n", "\n", "    def push(self, exps: Iterable[PPORLElement]):\n", "        self.history += exps\n", "\n", "    def clear_history(self):\n", "        self.history = []\n", "\n", "    def __getitem__(self, index: int) -> PPORLElement:\n", "        return self.history[index]\n", "\n", "    def __len__(self) -> int:\n", "        return len(self.history)\n", "\n", "    def create_loader(self, batch_size: int, shuffle: bool) -> DataLoader:\n", "        def collate_fn(elems: Iterable[PPORLElement]):\n", "            return PPORLBatch(\n", "                # queries are left-padded: flip, pad on the right, then flip back\n", "                pad_sequence(\n", "                    [elem.query_tensor.flip(0) for elem in elems],\n", "                    padding_value=self.pad_token_id,\n", "                    batch_first=True,\n", "                ).flip(1),\n", "                pad_sequence(\n", "                    [elem.response_tensor for elem in elems],\n", "                    padding_value=self.pad_token_id,\n", "                    batch_first=True,\n", "                ),\n", "                pad_sequence(\n", "                    [elem.logprobs for elem in elems],\n", "                    padding_value=0.0,\n", "                    batch_first=True,\n", "                ),\n", "                pad_sequence(\n", "                    [elem.values for elem in elems],\n", "                    padding_value=0.0,\n", "                    batch_first=True\n", "                ),\n", "                pad_sequence(\n", "                    [elem.rewards for elem in elems],\n", "                    padding_value=0.0,\n", "                    batch_first=True,\n", "                ),\n", "            )\n", "\n", "        return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)" ] },
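{ "cell_type": "markdown", "id": "promptpipeline-check-md", "metadata": {}, "source": [ "Here is the quick, optional sanity check mentioned above: a minimal sketch that tokenizes two toy prompts (made up purely for illustration) with the tokenizer from `config.model.tokenizer_path` and inspects one collated batch. The GPT-2 tokenizer has no pad token, so we reuse the EOS token for the padding collator." ] }, { "cell_type": "code", "execution_count": null, "id": "promptpipeline-check", "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: build a tiny PromptPipeline and look at one padded batch.\n", "demo_tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path)\n", "demo_tokenizer.pad_token = demo_tokenizer.eos_token  # GPT-2 has no dedicated pad token\n", "\n", "demo_pipeline = PromptPipeline(\n", "    [\"This movie was\", \"A surprisingly touching story about two friends\"],\n", "    max_prompt_length=config.method.prompt_size,\n", "    tokenizer=demo_tokenizer)\n", "\n", "demo_batch = next(iter(demo_pipeline.create_loader(batch_size=2)))\n", "print(demo_batch.input_ids.shape, demo_batch.attention_mask.shape)" ] },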
{ "cell_type": "markdown", "id": "d37fedbc-d186-42fc-b20a-ab0947c49783", "metadata": {}, "source": [ "---\n", "### 4. Define more PyTorch functions\n", "Next, we'll implement more capabilities in PyTorch. This includes:\n", "1. A whiten utility\n", "2. A GAE utility\n", "3. The loss to update the PPO LLM you want to fine-tune\n", "4. An Actor class to take steps and experience rewards\n", "5. An Agent class that wraps the base LLM with a value head and is used by the Actor" ] }, { "cell_type": "code", "execution_count": 10, "id": "d28d570e", "metadata": {}, "outputs": [], "source": [ "def whiten(x):\n", "    var, mean = torch.var_mean(x)\n", "    return (x - mean) * torch.rsqrt(var + 1e-8)\n", "\n", "\n", "def gae(\n", "    values,\n", "    rewards,\n", "):\n", "    advantages = torch.zeros_like(rewards, device=rewards.device)\n", "    last_advantage = 0\n", "    last_value = 0\n", "    \n", "    with torch.no_grad():\n", "        for t in reversed(range(rewards.shape[1])):\n", "            # delta_t = r_t + gamma * V(t+1) - V(t)\n", "            delta = rewards[:, t] + config.method.gamma * last_value - values[:, t]\n", "            last_advantage = delta + config.method.gamma * config.method.lam * last_advantage\n", "            advantages[:, t] = last_advantage\n", "            last_value = values[:, t]\n", "\n", "    returns = advantages + values\n", "    \n", "    if config.method.use_whitening:\n", "        advantages = whiten(advantages)\n", "    \n", "    return advantages, returns" ] }, { "cell_type": "code", "execution_count": 11, "id": "5575874a", "metadata": {}, "outputs": [], "source": [ "def ppo_loss(\n", "    logprobs, \n", "    values, \n", "    old_logprobs, \n", "    old_values, \n", "    advantages, \n", "    returns, \n", "    mask, \n", "):\n", "\n", "    values_clipped = torch.clamp(\n", "        values,\n", "        old_values - config.method.cliprange_value,\n", "        old_values + config.method.cliprange_value,\n", "    )\n", "    \n", "    n = mask.sum()\n", "    \n", "    vf_loss1 = (values - returns) ** 2\n", "    vf_loss2 = (values_clipped - returns) ** 2\n", "    vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n\n", "\n", "    log_ratio = (logprobs - old_logprobs) * mask\n", "    ratio = torch.exp(log_ratio)\n", "    pg_loss1 = -advantages * ratio\n", "    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - config.method.cliprange, 1.0 + config.method.cliprange)\n", "    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n\n", "    pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n\n", "\n", "    loss = pg_loss + config.method.vf_coef * vf_loss\n", "    \n", "    return loss" ] }, { "cell_type": "code", "execution_count": 12, "id": "811ddf6c", "metadata": {}, "outputs": [], "source": [ "def loss_fn(batch):\n", "    model_device = next(model.parameters()).device\n", "    query_tensors = batch.query_tensors.to(model_device)\n", "    response_tensors = batch.response_tensors.to(model_device)\n", "    old_logprobs = batch.logprobs.to(model_device)\n", "    old_values = batch.values.to(model_device)\n", "    old_rewards = batch.rewards.to(model_device)\n", "    \n", "    response_length = old_rewards.shape[1]\n", "\n", "    advantages, returns = gae(old_values, old_rewards)\n", "\n", "    tokens, attention_mask, position_ids = get_model_inputs(query_tensors, response_tensors, tokenizer.pad_token_id)\n", "\n", "    logits, values_pred = model(tokens,\n", "                                attention_mask=attention_mask,\n", "                                position_ids=position_ids)\n", "    values_pred = values_pred[:, :-1]\n", "    logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:])\n", "    attention_mask = attention_mask[:, :-1]\n", "\n", "    start = query_tensors.shape[1] - 1\n", "    end = start + response_length\n", "    logprobs, values_pred, mask = (\n", "        logprobs[:, start:end],\n", "        values_pred[:, start:end],\n", "        attention_mask[:, start:end],\n", "    )\n", "\n", "    loss = ppo_loss(\n", "        logprobs=logprobs,\n", "        values=values_pred,\n", "        old_logprobs=old_logprobs,\n", "        old_values=old_values,\n", "        advantages=advantages,\n", "        returns=returns,\n", "        mask=mask,\n", "    )\n", "\n", "    return loss, old_rewards[:,-1].mean().item()" ] }, { "cell_type": "code", "execution_count": 13, "id": 
"5f8c337f", "metadata": {}, "outputs": [], "source": [ "class Actor():\n", "\n", " def __init__(\n", " self,\n", " prompt_pipeline,\n", " tokenizer,\n", " chunk_size = 128):\n", " \n", " self.prompt_pipeline = prompt_pipeline\n", " self.chunk_size = chunk_size\n", "\n", " self.prompt_pipeline_loader = self.prompt_pipeline.create_loader(self.chunk_size, shuffle=True)\n", " self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)\n", "\n", " self.ref_model = Agent(config.model.model_path)\n", " self.ref_model_device = config.train.ref_model_device\n", " self.ref_model = self.ref_model.to(self.ref_model_device)\n", " \n", " self.tokenizer = tokenizer \n", " \n", "\n", " def make_experience(self, model, num_rollouts = 128):\n", " model_device = next(model.parameters()).device\n", " \n", " ppo_rl_elements = []\n", " while len(ppo_rl_elements) < num_rollouts:\n", " try:\n", " batch = next(self.prompt_pipeline_iterator)\n", " except StopIteration:\n", " self.pipeline_iterator = iter(self.prompt_pipeline_loader)\n", " batch = next(self.prompt_pipeline_iterator)\n", " \n", " trajectories = generate(model, self.tokenizer, **batch.to(model_device))\n", "\n", " query_tensors = batch.input_ids\n", " response_tensors = trajectories[:, query_tensors.shape[1] :]\n", "\n", " all_tokens, attention_mask, position_ids = get_model_inputs(\n", " query_tensors.to(response_tensors.device), response_tensors, self.tokenizer.pad_token_id)\n", " with torch.no_grad():\n", " logits, values = model(\n", " all_tokens, \n", " attention_mask=attention_mask, \n", " position_ids=position_ids)\n", " ref_logits, _ = self.ref_model(\n", " all_tokens.to(self.ref_model_device),\n", " attention_mask=attention_mask.to(self.ref_model_device),\n", " position_ids=position_ids.to(self.ref_model_device))\n", " \n", " all_tokens = all_tokens.cpu()\n", " logits = logits.cpu()\n", " ref_logits = ref_logits.cpu()\n", "\n", " logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])\n", " ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], all_tokens[:, 1:])\n", " \n", " n = trajectories.shape[0]\n", " values = values.cpu()[:, :-1]\n", " query_tensors = query_tensors.cpu()\n", " response_tensors = response_tensors.cpu()\n", " \n", " start = query_tensors.shape[1] - 1\n", " ends = start + attention_mask[:, start:].sum(1)\n", " all_values = [values[i, start : ends[i]] for i in range(n)]\n", " all_logprobs = [logprobs[i, start : ends[i]] for i in range(n)]\n", " \n", " texts = self.tokenizer.batch_decode(trajectories, skip_special_tokens=True)\n", " scores = torch.tensor(reward_fn(texts), device='cpu', dtype=torch.float)\n", "\n", " rewards = -config.method.kl_coef * (logprobs - ref_logprobs)\n", " all_rewards = [None] * n\n", " for i in range(n):\n", " rs = rewards[i][start : ends[i]]\n", " rs[-1] = scores[i]\n", " all_rewards[i] = rs\n", " \n", " new_ppo_rl_elements = [\n", " PPORLElement(\n", " query_tensor=query_tensors[i],\n", " response_tensor=response_tensors[i],\n", " logprobs=all_logprobs[i],\n", " values=all_values[i],\n", " rewards=all_rewards[i],\n", " )\n", " for i in range(n)\n", " ]\n", "\n", " ppo_rl_elements += new_ppo_rl_elements\n", "\n", " return ppo_rl_elements, scores.mean().item()" ] }, { "cell_type": "code", "execution_count": 14, "id": "4f50e1af", "metadata": {}, "outputs": [], "source": [ "def generate(model, tokenizer, input_ids, attention_mask=None, **kwargs):\n", " \n", " generate_kwargs = dict(\n", " config.method.gen_kwargs,\n", " eos_token_id=tokenizer.eos_token_id,\n", " 
pad_token_id=tokenizer.eos_token_id)\n", "\n", " kwargs = dict(generate_kwargs, **kwargs)\n", "\n", " with torch.no_grad():\n", " generated_results = model.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)\n", "\n", " return generated_results\n", "\n", "\n", "def get_model_inputs(query_tensors, response_tensors, pad_token_id):\n", " tokens = torch.cat((query_tensors, response_tensors), dim=1)[:, -config.train.seq_length :]\n", " attention_mask = (tokens.not_equal(pad_token_id).long().to(tokens.device))\n", " position_ids = attention_mask.cumsum(-1) - 1\n", " position_ids.masked_fill_(attention_mask.eq(0), 0)\n", " return tokens, attention_mask, position_ids\n", "\n", "\n", "def logprobs_from_logits(logits, labels):\n", " logprobs = F.log_softmax(logits, dim=-1)\n", " logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))\n", " return logprobs_labels.squeeze(-1)\n", "\n", "\n", "def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):\n", " hidden_layers = model.transformer.h\n", " if num_layers_unfrozen == 0:\n", " hidden_layers_to_freeze = list(hidden_layers)\n", " elif num_layers_unfrozen > 0:\n", " hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]\n", " else:\n", " hidden_layers_to_freeze = []\n", " for layer in hidden_layers_to_freeze:\n", " layer.requires_grad_(False)\n", "\n", " \n", "class Agent(nn.Module):\n", " def __init__(self, model_path, num_layers_unfrozen=0):\n", " super().__init__()\n", "\n", " self.base_model = transformers.AutoModelForCausalLM.from_pretrained(model_path, cache_dir=\"./models\")\n", "\n", " self.logit_head = self.base_model.get_output_embeddings()\n", " \n", " n_embd = self.base_model.lm_head.in_features\n", " self.value_head = nn.Sequential(\n", " nn.Linear(n_embd, n_embd*2),\n", " nn.ReLU(),\n", " nn.Linear(n_embd*2, 1))\n", " \n", " freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen)\n", " \n", " \n", " def generate(self, input_ids, **x):\n", " return self.base_model.generate(input_ids, **x)\n", "\n", " def forward(self, input_ids, attention_mask, position_ids):\n", "\n", " transformer_outputs = self.base_model.transformer(input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " position_ids=position_ids)\n", " \n", " last_hidden_state = transformer_outputs.last_hidden_state\n", " lm_logits = self.logit_head(last_hidden_state)\n", " value = self.value_head(last_hidden_state).squeeze(-1)\n", " \n", " return lm_logits, value" ] }, { "cell_type": "markdown", "id": "ccb66aec-62a7-4e1b-9619-a09ecc804b8e", "metadata": {}, "source": [ "---\n", "### 5. 
Define the pipeline, download model and data artifacts" ] }, { "cell_type": "code", "execution_count": 18, "id": "79835e3a", "metadata": {}, "outputs": [], "source": [ "sentiment_fn = pipeline(\n", " model = \"lvwerra/distilbert-imdb\",\n", " top_k=2,\n", " batch_size=config.method.num_rollouts,\n", " device=config.train.reward_model_device,\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "id": "e696db6a", "metadata": {}, "outputs": [], "source": [ "def get_positive_score(scores):\n", " return dict(map(lambda x: tuple(x.values()), scores))[\"POSITIVE\"]\n", "\n", "def reward_fn(samples: List[str]) -> List[float]:\n", " sentiments = list(map(get_positive_score, sentiment_fn(samples)))\n", " return sentiments" ] }, { "cell_type": "code", "execution_count": 20, "id": "180883d6", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f695ffa4091d43658c4af2e08674e0fb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading builder script: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a95f80fa55014b1bb59f1b63b75150a8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading metadata: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5de1e6a5f9fc49e481e9673559e543f4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n" ] } ], "source": [ "imdb = load_dataset(\"imdb\", split=\"train+test\")" ] }, { "cell_type": "code", "execution_count": 21, "id": "51ad4f7e", "metadata": {}, "outputs": [], "source": [ "prompts = [\" \".join(review.split()[:config.method.prompt_size]) for review in imdb[\"text\"]]" ] }, { "cell_type": "code", "execution_count": 22, "id": "e1554f68", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9c8e37f13aa04406bf940dbd61af004f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)olve/main/vocab.json: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b737cb2ca66d47a693edacdc966aae45", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)olve/main/merges.txt: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f232d3bf7a2d4903a6db9580f8335eb6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/90.0 [00:00