{ "cells": [ { "cell_type": "markdown", "id": "f933a167-9e20-465d-8714-52c4af80bf51", "metadata": {}, "source": [ "# (Optional) 0. TrOCR 훈련 텍스트 데이터 생성\n", "\n", "--- \n", "\n", "TrOCR 훈련 데이터를 만들기 위한 데이터 가공 과정입니다. 이미 ocr_dataset_poc.csv로 저장되어 있기 때문에 이 모듈을 실습하실 필요는 없지만, 데이터셋 가공 과정을 파악하거나 이에 대한 영감을 얻고 싶으신 분들은 코드를 한 줄씩 수행해 보세요. " ] }, { "cell_type": "code", "execution_count": null, "id": "6c1372df-c71e-4774-be58-8ef6affcb0e4", "metadata": {}, "outputs": [], "source": [ "!pip install -r requirements.txt" ] }, { "cell_type": "code", "execution_count": null, "id": "27e79bc8-945e-41d8-8f31-84056fc59f24", "metadata": {}, "outputs": [], "source": [ "import re\n", "import random\n", "import pandas as pd\n", "import multiprocessing\n", "from tqdm import tqdm\n", "from datasets import load_dataset, load_metric\n", "from IPython.display import display, HTML\n", "from collections.abc import Iterable\n", "from joblib import Parallel, delayed\n", "from kiwipiepy import Kiwi\n", "kiwi = Kiwi()\n", "num_cores = multiprocessing.cpu_count()" ] }, { "cell_type": "markdown", "id": "b2c869ea-7670-4468-8694-dd4c6db652de", "metadata": {}, "source": [ "
\n", "\n", "## 뉴스 데이터셋 가공\n", "---\n", "저자가 가공한 뉴스 데이터셋을 가공하여 훈련 데이터를 생성합니다. 데이터셋의 샘플 개수는 약 2만여건에 불과하지만, 각 레코드를 문장 분리하여 신규 데이터셋을 생성하면 샘플 개수가 증가합니다.\n", "- 데이터셋 출처: https://huggingface.co/datasets/daekeun-ml/naver-news-summarization-ko" ] }, { "cell_type": "code", "execution_count": null, "id": "60f33e53-df27-41d1-a94b-9df16ba1bd84", "metadata": {}, "outputs": [], "source": [ "news_datasets = load_dataset('daekeun-ml/naver-news-summarization-ko')\n", "news_df = pd.DataFrame(news_datasets[\"train\"]['document'], columns=['document'])" ] }, { "cell_type": "code", "execution_count": null, "id": "1579a7f4-8140-4a4a-81cc-8429777ec911", "metadata": {}, "outputs": [], "source": [ "def split_sentences(datasets,idx):\n", " document = datasets[idx]['document']\n", " splits = kiwi.split_into_sents(document, return_tokens=False)\n", " return [s.text for s in splits] \n", "\n", "def flatten(lis):\n", " for item in lis:\n", " if isinstance(item, Iterable) and not isinstance(item, str):\n", " for x in flatten(item):\n", " yield x\n", " else: \n", " yield item" ] }, { "cell_type": "markdown", "id": "2e91a941-76e0-43f0-bc1f-48a96b90ba45", "metadata": {}, "source": [ "### 문장 분리\n", "\n", "Kiwi 파이썬 래퍼 (https://github.com/bab2min/kiwipiepy) 를 사용하여 문장을 분리합니다. 문장 분리에 많은 시간이 소요되는데, 병렬 처리를 통해 처리 시간을 단축할 수 있습니다." ] }, { "cell_type": "code", "execution_count": null, "id": "93cca20a-0352-45f6-b5ef-6316b5f7742a", "metadata": {}, "outputs": [], "source": [ "num_samples = len(news_df)\n", "#num_samples = 200\n", "out = Parallel(n_jobs=num_cores, backend='threading')(\n", " delayed(split_sentences)(datasets=news_datasets['train'],idx=idx) for idx in tqdm(range(0, num_samples), miniters=1000)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "57347484-e019-4e65-8492-1c3aa873693a", "metadata": {}, "outputs": [], "source": [ "def preprocessing_news(df):\n", " import re\n", " \n", " # Remove punctuations\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[.,!?:;-=...@#_]\", \" \", str(s)).split()))\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[ᅳㅡ]\", \"\", str(s)).split()))\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[--]\", \"\", str(s)).split()))\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[\\.\\,\\(\\)\\{\\}\\[\\]\\`\\'\\!\\?\\:\\;\\-\\=]\", \" \", str(s)).split()))\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n", "\n", " # Remove links\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))\n", " \n", " return df" ] }, { "cell_type": "code", "execution_count": null, "id": "18676849-6546-47e7-91cd-67dbef24bff2", "metadata": {}, "outputs": [], "source": [ "news_texts = list(flatten(out))\n", "news_df = pd.DataFrame(news_texts, columns=['document'])\n", "news_df = preprocessing_news(news_df)\n", "# news_df[\"document\"] = news_df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n", "# news_df[\"document\"] = news_df[\"document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))" ] }, { "cell_type": "code", "execution_count": null, "id": "775b8c38-e876-4597-b98b-5c57a28d4e26", "metadata": {}, "outputs": [], "source": [ "news_df.head()" ] }, { "cell_type": "markdown", "id": "8c16d4a3-d3f0-46a2-8fbb-041386f6b4ff", "metadata": {}, "source": [ "
\n", "\n", "## 네이버 영화 리뷰 데이터셋 가공 \n", "---\n", "\n", "네이버 영화 리뷰 데이터셋을 가공합니다. 더 많은 훈련 데이터를 확보하기 위해, 각 샘플의 문장 길이가 일정 이상일 때 문장 분리를 수행합니다.\n", "- 데이터셋 출처: https://github.com/e9t/nsmc" ] }, { "cell_type": "code", "execution_count": null, "id": "90d0a60c-3129-4186-b1e5-26bdfa27b787", "metadata": {}, "outputs": [], "source": [ "!curl -O https://raw.githubusercontent.com/e9t/nsmc/master/ratings_train.txt\n", "!curl -O https://raw.githubusercontent.com/e9t/nsmc/master/ratings_test.txt" ] }, { "cell_type": "code", "execution_count": null, "id": "8d6408a2-d41e-4133-8dd9-51af386d23b2", "metadata": {}, "outputs": [], "source": [ "import pandas as pda\n", "import numpy as np\n", "\n", "train_df = pd.read_csv('ratings_train.txt', header=0, delimiter='\\t')\n", "test_df = pd.read_csv('ratings_test.txt', header=0, delimiter='\\t')\n", "df = pd.concat([train_df, test_df], ignore_index=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "90ba9b71-7f2a-4378-a576-b5bd1570d5d0", "metadata": {}, "outputs": [], "source": [ "def preprocessing_nsmc(df):\n", " import re\n", " \n", " # Remove consonant & vowel for Korean language\n", " df[\"clean_document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"([ㄱ-ㅎㅏ-ㅣ]+)\", \"\", str(s)).split()))\n", " \n", " # Remove punctuations\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[.,!?:;-=...@#_]\", \" \", str(s)).split()))\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[ᅳㅡ]\", \"\", str(s)).split()))\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[--]\", \"\", str(s)).split()))\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[\\.\\,\\(\\)\\{\\}\\[\\]\\`\\'\\!\\?\\:\\;\\-\\=]\", \" \", str(s)).split()))\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n", "\n", " # Remove non-korean characters\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[^가-힣ㄱ-하-ㅣ\\\\s]\", \"\", str(s)).split()))\n", "\n", " # Remove links\n", " df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))\n", " \n", " return df\n", "\n", "df = preprocessing_nsmc(df)" ] }, { "cell_type": "code", "execution_count": null, "id": "13a45101-d293-4045-9932-00d231ac08ef", "metadata": {}, "outputs": [], "source": [ "min_chars = 5\n", "max_chars = 32\n", "nsmc_short_df = df[(df[\"clean_document\"].str.len() >= min_chars) & (df[\"clean_document\"].str.len() < max_chars)]\n", "nsmc_long_df = df[df[\"clean_document\"].str.len() >= max_chars]" ] }, { "cell_type": "code", "execution_count": null, "id": "9bb2f079-1fb5-4521-8abb-63ea46b2422a", "metadata": {}, "outputs": [], "source": [ "def split_sentences(df,idx):\n", " document = df['document'].iloc[idx]\n", " splits = kiwi.split_into_sents(document, return_tokens=False)\n", " return [s.text for s in splits] \n", "\n", "num_samples = len(nsmc_long_df)\n", "out = Parallel(n_jobs=num_cores, backend='threading')(\n", " delayed(split_sentences)(df=nsmc_long_df,idx=idx) for idx in tqdm(range(0, num_samples), miniters=10000)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "030477eb-ac49-4f34-aed1-064f7cbbc1f6", "metadata": {}, "outputs": [], "source": [ "nsmc_texts = list(flatten(out))\n", "nsmc_long_df = pd.DataFrame(nsmc_texts, columns=['document'])\n", "nsmc_long_df = preprocessing_nsmc(nsmc_long_df)" ] }, { "cell_type": "code", "execution_count": null, "id": "cc5df08d-99a2-443b-b776-dba48201c3e4", "metadata": {}, "outputs": [], "source": [ "nsmc_short_df = nsmc_short_df[\"clean_document\"].to_frame(name=\"document\")\n", "nsmc_long_df = nsmc_long_df[\"clean_document\"].to_frame(name=\"document\")" ] }, { "cell_type": "markdown", "id": "d4e4cd9d-ad7c-4b96-b92d-acbdf53c5543", "metadata": {}, "source": [ "
\n", "\n", "## 챗봇 데이터셋 가공\n", "---\n", "\n", "챗봇 데이터셋을 가공합니다.\n", "- 데이터셋 출처: https://github.com/songys/Chatbot_data" ] }, { "cell_type": "code", "execution_count": null, "id": "349f2f36-5ca7-4018-8079-71239cb3f630", "metadata": {}, "outputs": [], "source": [ "import urllib\n", "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv\", \n", " filename=\"chatbot_train.csv\")\n", "chatbot_df = pd.read_csv('chatbot_train.csv')\n", "chatbot_df.head()" ] }, { "cell_type": "markdown", "id": "8650f856-3d43-4b71-b8b0-77534920066f", "metadata": {}, "source": [ "질문 문장과 응답 문장을 분리하여 개별 데이터프레임을 생성합니다." ] }, { "cell_type": "code", "execution_count": null, "id": "5e605a9a-8b0c-42e6-9446-30cbb5d651ed", "metadata": {}, "outputs": [], "source": [ "chatbot_q_df = chatbot_df['Q'].to_frame()\n", "chatbot_q_df.columns = ['document']\n", "chatbot_q_df = chatbot_q_df.drop_duplicates()\n", "\n", "chatbot_a_df = chatbot_df['A'].to_frame()\n", "chatbot_a_df.columns = ['document']\n", "chatbot_a_df = chatbot_a_df.drop_duplicates()" ] }, { "cell_type": "code", "execution_count": null, "id": "c85c6d94-c009-47c6-9e42-b26e88666659", "metadata": {}, "outputs": [], "source": [ "def preprocessing_chatbot(df, min_chars=4):\n", " df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n", " # Remove rows if text has less than min characters\n", " df = df[df[\"document\"].str.len() >= min_chars]\n", "\n", " return df\n", "\n", "chatbot_q_df = preprocessing_chatbot(chatbot_q_df)\n", "chatbot_a_df = preprocessing_chatbot(chatbot_a_df)" ] }, { "cell_type": "markdown", "id": "76ec2ff8-1b19-41a7-a128-e9f623657f9c", "metadata": {}, "source": [ "
\n", "\n", "## 최종 데이터셋 취합\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "9de74c99-ef88-47eb-bb34-7931647ce0af", "metadata": {}, "outputs": [], "source": [ "nsmc_short_df['category'] = 'nsmc'\n", "nsmc_long_df['category'] = 'nsmc'\n", "news_df['category'] = 'news'\n", "chatbot_q_df['category'] = 'chatbot'\n", "chatbot_a_df['category'] = 'chatbot'\n", "\n", "final_df = pd.concat(\n", " [nsmc_short_df, nsmc_long_df, news_df, chatbot_q_df, chatbot_a_df], \n", " ignore_index=True\n", ")\n", "final_df = final_df[final_df[\"document\"].str.len() >= 5]\n", "final_df['document'] = final_df['document'].str.strip()" ] }, { "cell_type": "code", "execution_count": null, "id": "931c82f5-09f9-4299-84e4-7611a558e67a", "metadata": {}, "outputs": [], "source": [ "final_df['category'].value_counts()" ] }, { "cell_type": "code", "execution_count": null, "id": "4048d75e-a809-4848-949d-3f888d18e3d8", "metadata": {}, "outputs": [], "source": [ "final_df.to_csv('ocr_dataset_poc.csv', index=False)" ] }, { "cell_type": "markdown", "id": "e60fa963-67a7-4533-873e-cfeefc584591", "metadata": {}, "source": [ "
\n", "\n", "## Clean up\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "3b27db6a-e3a1-49a3-a7ce-23d219107597", "metadata": {}, "outputs": [], "source": [ "!rm ratings_train.txt ratings_test.txt chatbot_train.csv" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }