{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "notebook-overview",
   "metadata": {},
   "source": [
    "# Fine-tune DistilBERT for extractive question answering and deploy it on SageMaker\n",
    "\n",
    "This notebook:\n",
    "\n",
    "1. loads the `distilbert-base-uncased-distilled-squad` checkpoint and sanity-checks it with a `question-answering` pipeline;\n",
    "2. downloads SQuAD 2.0 and converts each answer's character span into token start/end positions;\n",
    "3. fine-tunes the model for one epoch with the Hugging Face `Trainer`;\n",
    "4. saves the weights, packages them as `model.tar.gz`, uploads the archive to S3 and deploys a SageMaker PyTorch endpoint served by `scripts/inference.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "arctic-admission",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install transformers, the library for Natural Language Understanding (NLU).\n",
    "# `%pip` installs into the environment of the running kernel; a bare `pip install` is a SyntaxError in a code cell.\n",
    "# TODO: pin the transformers version the recorded results were produced with.\n",
    "%pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "psychological-institution",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import requests\n",
    "import sagemaker\n",
    "import torch\n",
    "from sagemaker.pytorch.model import PyTorchModel\n",
    "from sagemaker.s3 import S3Uploader\n",
    "from transformers import DistilBertForQuestionAnswering\n",
    "from transformers import DistilBertTokenizerFast\n",
    "from transformers import Trainer\n",
    "from transformers import TrainingArguments\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-constants",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: every name/path a reader may want to change.\n",
    "MODEL_NAME = 'distilbert-base-uncased-distilled-squad'\n",
    "SQUAD_URL = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/'\n",
    "SQUAD_TRAIN_FILE = 'train-v2.0.json'\n",
    "SQUAD_DEV_FILE = 'dev-v2.0.json'\n",
    "# Single dataset location: the download cell writes here and the parsing cell reads from here.\n",
    "SQUAD_DIR = 'squad'\n",
    "# S3 prefix that model.tar.gz is uploaded under.\n",
    "S3_MODEL_PREFIX = 's3://gkrish-sagemaker/model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "illegal-sitting",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SQuAD-distilled DistilBERT QA checkpoint and its fast tokenizer.\n",
    "model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME)\n",
    "tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-sanity-check",
   "metadata": {},
   "source": [
    "## Sanity-check the pretrained model\n",
    "\n",
    "Run one question against a short passage through the `question-answering` pipeline before any fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "overall-ideal",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 0.56325364112854,\n",
       " 'start': 118,\n",
       " 'end': 132,\n",
       " 'answer': 'United Nations'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#test the model using nlp pipeline\n",
    "context = \"The Intergovernmental Panel on Climate Change (IPCC) is a scientific intergovernmental body under the auspices of the United Nations, set up at the request of member governments. It was first established in 1988 by two United Nations organizations, the World Meteorological Organization (WMO) and the United Nations Environment Programme (UNEP), and later endorsed by the United Nations General Assembly through Resolution 43/53. Membership of the IPCC is open to all members of the WMO and UNEP. The IPCC produces reports that support the United Nations Framework Convention on Climate Change (UNFCCC), which is the main international treaty on climate change. The ultimate objective of the UNFCCC is to \\\"stabilize greenhouse gas concentrations in the atmosphere at a level that would prevent dangerous anthropogenic [i.e., human-induced] interference with the climate system\\\". IPCC reports cover \\\"the scientific, technical and socio-economic information relevant to understanding the scientific basis of risk of human-induced climate change, its potential impacts and options for adaptation and mitigation.\\\"\"\n",
    "question = \"What organization is the IPCC a part of?\"\n",
    "\n",
    "nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)\n",
    "\n",
    "nlp({\n",
    "    'question': question,\n",
    "    'context': context\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-data",
   "metadata": {},
   "source": [
    "## Download and parse SQuAD 2.0\n",
    "\n",
    "The JSON files are flattened into parallel lists of contexts, questions and answers. Each answer's character span is then aligned and converted to token positions for the QA head."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "unexpected-ensemble",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the SQuAD 2.0 train/dev files into SQUAD_DIR; files already on disk are skipped.\n",
    "os.makedirs(SQUAD_DIR, exist_ok=True)\n",
    "\n",
    "for file in [SQUAD_TRAIN_FILE, SQUAD_DEV_FILE]:\n",
    "    path = os.path.join(SQUAD_DIR, file)\n",
    "    if os.path.exists(path):\n",
    "        continue\n",
    "    # stream the body instead of loading it into memory in one go\n",
    "    res = requests.get(f'{SQUAD_URL}{file}', stream=True)\n",
    "    # stop on HTTP errors rather than saving an error page as the dataset\n",
    "    res.raise_for_status()\n",
    "    # write to a temporary name and rename at the end, so an interrupted download\n",
    "    # never leaves a truncated file that the existence check above would accept\n",
    "    part_path = path + '.part'\n",
    "    with open(part_path, 'wb') as f:\n",
    "        for chunk in res.iter_content(chunk_size=1024 * 1024):\n",
    "            f.write(chunk)\n",
    "    os.replace(part_path, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "unnecessary-gazette",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_squad(path):\n",
    "    \"\"\"Flatten a SQuAD-format JSON file into parallel lists of contexts, questions and answers.\n",
    "\n",
    "    One row is emitted per answer, so a question with several answers appears several times.\n",
    "    Questions that have a 'plausible_answers' key take their answers from that list instead\n",
    "    of 'answers' (presumably SQuAD 2.0's unanswerable questions).\n",
    "\n",
    "    Args:\n",
    "        path: path to a SQuAD JSON file.\n",
    "\n",
    "    Returns:\n",
    "        (contexts, questions, answers): three equal-length lists; each answer is the raw\n",
    "        answer dict from the file.\n",
    "    \"\"\"\n",
    "    # open JSON file and load into a dictionary\n",
    "    with open(path, 'rb') as f:\n",
    "        squad_dict = json.load(f)\n",
    "\n",
    "    contexts = []\n",
    "    questions = []\n",
    "    answers = []\n",
    "    for group in squad_dict['data']:\n",
    "        for passage in group['paragraphs']:\n",
    "            context = passage['context']\n",
    "            for qa in passage['qas']:\n",
    "                question = qa['question']\n",
    "                # extract from 'plausible_answers' when present, otherwise from 'answers'\n",
    "                access = 'plausible_answers' if 'plausible_answers' in qa else 'answers'\n",
    "                for answer in qa[access]:\n",
    "                    contexts.append(context)\n",
    "                    questions.append(question)\n",
    "                    answers.append(answer)\n",
    "    return contexts, questions, answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "secure-minimum",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parse the training and validation sets downloaded above\n",
    "train_contexts, train_questions, train_answers = read_squad(os.path.join(SQUAD_DIR, SQUAD_TRAIN_FILE))\n",
    "val_contexts, val_questions, val_answers = read_squad(os.path.join(SQUAD_DIR, SQUAD_DEV_FILE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "welsh-correspondence",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_end_idx(answers, contexts):\n",
    "    \"\"\"Add an exclusive character end index, 'answer_end', to every answer dict (in place).\n",
    "\n",
    "    SQuAD's 'answer_start' is sometimes off by a character or two. If the context slice\n",
    "    [answer_start, answer_start + len(text)) does not equal the answer text, spans shifted\n",
    "    left by 1 and then 2 characters are tried, and the first match also corrects\n",
    "    'answer_start'. If nothing matches, the unshifted end is kept, so every answer always\n",
    "    has an 'answer_end' (add_token_positions reads it unconditionally).\n",
    "\n",
    "    Args:\n",
    "        answers: answer dicts with 'text' and 'answer_start'; mutated in place.\n",
    "        contexts: passages aligned index-by-index with `answers`.\n",
    "    \"\"\"\n",
    "    for answer, context in zip(answers, contexts):\n",
    "        # gold_text is the answer we expect to find in the context\n",
    "        gold_text = answer['text']\n",
    "        start_idx = answer['answer_start']\n",
    "        end_idx = start_idx + len(gold_text)\n",
    "        answer['answer_end'] = end_idx\n",
    "\n",
    "        if context[start_idx:end_idx] == gold_text:\n",
    "            continue\n",
    "        for n in [1, 2]:\n",
    "            # a negative start would wrap around to the end of the string, so skip it\n",
    "            if start_idx - n >= 0 and context[start_idx-n:end_idx-n] == gold_text:\n",
    "                answer['answer_start'] = start_idx - n\n",
    "                answer['answer_end'] = end_idx - n\n",
    "                break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "mediterranean-persian",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add 'answer_end' (and fix misaligned 'answer_start') for every answer\n",
    "add_end_idx(train_answers, train_contexts)\n",
    "add_end_idx(val_answers, val_contexts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "atomic-zimbabwe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize the context/question pairs (context first, question second)\n",
    "train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)\n",
    "val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "lightweight-carroll",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_token_positions(encodings, answers):\n",
    "    \"\"\"Convert character answer spans to token positions and add them to `encodings`.\n",
    "\n",
    "    Stores 'start_positions' and 'end_positions' (one entry per example) via\n",
    "    `encodings.update`. 'answer_end' is exclusive, so the last answer character is\n",
    "    `answer_end - 1`. Uses the global `tokenizer` for the fallback position.\n",
    "    \"\"\"\n",
    "    start_positions = []\n",
    "    end_positions = []\n",
    "    for i in range(len(answers)):\n",
    "        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))\n",
    "        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))\n",
    "\n",
    "        # None means no token covers that character, e.g. when truncation cut the answer off;\n",
    "        # use model_max_length as an out-of-range position.\n",
    "        # NOTE(review): presumably relies on the QA head ignoring out-of-range positions in its loss - confirm.\n",
    "        if start_positions[-1] is None:\n",
    "            start_positions[-1] = tokenizer.model_max_length\n",
    "        if end_positions[-1] is None:\n",
    "            end_positions[-1] = tokenizer.model_max_length\n",
    "\n",
    "    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "deadly-prediction",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert our character start/end positions to token start/end positions\n",
    "add_token_positions(train_encodings, train_answers)\n",
    "add_token_positions(val_encodings, val_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "heard-halifax",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SquadDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset over tokenizer encodings.\n",
    "\n",
    "    Item `idx` is a dict mapping every encoding key (e.g. input_ids, start_positions,\n",
    "    end_positions) to a tensor built from that key's `idx`-th entry.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encodings):\n",
    "        self.encodings = encodings\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.encodings.input_ids)\n",
    "\n",
    "train_dataset = SquadDataset(train_encodings)\n",
    "val_dataset = SquadDataset(val_encodings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-fine-tune",
   "metadata": {},
   "source": [
    "## Fine-tune\n",
    "\n",
    "One epoch with the Hugging Face `Trainer` on the SQuAD 2.0 training split; the dev split is passed as the evaluation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "legendary-employment",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The steps above prepared the datasets in the format the Trainer expects.\n",
    "# Now define the TrainingArguments and instantiate a Trainer around the model loaded above.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',          # output directory\n",
    "    num_train_epochs=1,              # total number of training epochs\n",
    "    per_device_train_batch_size=16,  # batch size per device during training\n",
    "    per_device_eval_batch_size=64,   # batch size for evaluation\n",
    "    warmup_steps=10,                 # number of warmup steps for learning rate scheduler\n",
    "    weight_decay=0.01,               # strength of weight decay\n",
    "    logging_dir='./logs',            # directory for storing logs\n",
    "    logging_steps=10\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,                     # the instantiated transformers model to be trained\n",
    "    args=training_args,              # training arguments, defined above\n",
    "    train_dataset=train_dataset,     # training dataset\n",
    "    eval_dataset=val_dataset         # evaluation dataset\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "entertaining-stuff",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-deploy",
   "metadata": {},
   "source": [
    "## Package, upload and deploy\n",
    "\n",
    "The fine-tuned `state_dict` is saved to `model.pth`, archived as `model.tar.gz`, uploaded under `S3_MODEL_PREFIX` and served from a SageMaker PyTorch endpoint using `scripts/inference.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "talented-screening",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the trained model locally\n",
    "torch.save(model.state_dict(), 'model.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fancy-sampling",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.pth\n"
     ]
    }
   ],
   "source": [
    "# package model.pth as model.tar.gz for SageMaker\n",
    "!tar -cvzf model.tar.gz model.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "demanding-construction",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "s3://gkrish-sagemaker/model/model.tar.gz\n"
     ]
    }
   ],
   "source": [
    "# Upload model.tar.gz under S3_MODEL_PREFIX; the returned S3 URI is reused when creating the model below\n",
    "model_artifact = S3Uploader.upload('model.tar.gz', S3_MODEL_PREFIX)\n",
    "print(model_artifact)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "worst-float",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a PyTorch model from the artifact uploaded above\n",
    "bertmodel = PyTorchModel(entry_point='inference.py',\n",
    "                         source_dir='scripts',\n",
    "                         model_data=model_artifact,\n",
    "                         role=sagemaker.get_execution_role(),\n",
    "                         framework_version='1.5',\n",
    "                         py_version='py3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "advised-command",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------!"
     ]
    }
   ],
   "source": [
    "# Deploy the model and create endpoint\n",
    "predictor = bertmodel.deploy(initial_instance_count=1,\n",
    "                             instance_type='ml.m5.xlarge')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p36",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}