{
"cells": [
{
"cell_type": "markdown",
"id": "f933a167-9e20-465d-8714-52c4af80bf51",
"metadata": {},
"source": [
"# (Optional) 0. TrOCR 훈련 텍스트 데이터 생성\n",
"\n",
"--- \n",
"\n",
"TrOCR 훈련 데이터를 만들기 위한 데이터 가공 과정입니다. 이미 ocr_dataset_poc.csv로 저장되어 있기 때문에 이 모듈을 실습하실 필요는 없지만, 데이터셋 가공 과정을 파악하거나 이에 대한 영감을 얻고 싶으신 분들은 코드를 한 줄씩 수행해 보세요. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c1372df-c71e-4774-be58-8ef6affcb0e4",
"metadata": {},
"outputs": [],
"source": [
"!pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e79bc8-945e-41d8-8f31-84056fc59f24",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import random\n",
"import pandas as pd\n",
"import multiprocessing\n",
"from tqdm import tqdm\n",
"from datasets import load_dataset, load_metric\n",
"from IPython.display import display, HTML\n",
"from collections.abc import Iterable\n",
"from joblib import Parallel, delayed\n",
"from kiwipiepy import Kiwi\n",
"kiwi = Kiwi()\n",
"num_cores = multiprocessing.cpu_count()"
]
},
{
"cell_type": "markdown",
"id": "b2c869ea-7670-4468-8694-dd4c6db652de",
"metadata": {},
"source": [
"
\n",
"\n",
"## 뉴스 데이터셋 가공\n",
"---\n",
"저자가 가공한 뉴스 데이터셋을 가공하여 훈련 데이터를 생성합니다. 데이터셋의 샘플 개수는 약 2만여건에 불과하지만, 각 레코드를 문장 분리하여 신규 데이터셋을 생성하면 샘플 개수가 증가합니다.\n",
"- 데이터셋 출처: https://huggingface.co/datasets/daekeun-ml/naver-news-summarization-ko"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60f33e53-df27-41d1-a94b-9df16ba1bd84",
"metadata": {},
"outputs": [],
"source": [
"news_datasets = load_dataset('daekeun-ml/naver-news-summarization-ko')\n",
"news_df = pd.DataFrame(news_datasets[\"train\"]['document'], columns=['document'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1579a7f4-8140-4a4a-81cc-8429777ec911",
"metadata": {},
"outputs": [],
"source": [
"def split_sentences(datasets,idx):\n",
" document = datasets[idx]['document']\n",
" splits = kiwi.split_into_sents(document, return_tokens=False)\n",
" return [s.text for s in splits] \n",
"\n",
"def flatten(lis):\n",
" for item in lis:\n",
" if isinstance(item, Iterable) and not isinstance(item, str):\n",
" for x in flatten(item):\n",
" yield x\n",
" else: \n",
" yield item"
]
},
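{
"cell_type": "markdown",
"id": "3f9a1c20-1e0b-4c6d-9a4e-8b2f5d7c0a11",
"metadata": {},
"source": [
"The `flatten` helper above recursively unrolls nested iterables into a single flat stream while yielding strings whole rather than character by character. A minimal sanity check with a made-up nested list:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9a1c20-1e0b-4c6d-9a4e-8b2f5d7c0a12",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative input only: nested lists of sentences collapse to one flat list,\n",
"# and strings are kept intact instead of being split into characters.\n",
"nested = [['문장 하나.', '문장 둘.'], ['문장 셋.']]\n",
"list(flatten(nested))"
]
},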
{
"cell_type": "markdown",
"id": "2e91a941-76e0-43f0-bc1f-48a96b90ba45",
"metadata": {},
"source": [
"### 문장 분리\n",
"\n",
"Kiwi 파이썬 래퍼 (https://github.com/bab2min/kiwipiepy) 를 사용하여 문장을 분리합니다. 문장 분리에 많은 시간이 소요되는데, 병렬 처리를 통해 처리 시간을 단축할 수 있습니다."
]
},
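{
"cell_type": "markdown",
"id": "7c4d9e02-2b3a-4f1e-8c5d-0a9b8e7f6011",
"metadata": {},
"source": [
"Before launching the full parallel job, a quick check of `kiwi.split_into_sents` on a single illustrative string; it returns `Sentence` objects whose `.text` attribute holds the raw sentence, which is what `split_sentences` collects."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c4d9e02-2b3a-4f1e-8c5d-0a9b8e7f6012",
"metadata": {},
"outputs": [],
"source": [
"# Quick single-document check of the Kiwi API (the sample string is made up).\n",
"sample = '오늘은 날씨가 맑다. 내일은 비가 온다고 한다.'\n",
"[s.text for s in kiwi.split_into_sents(sample, return_tokens=False)]"
]
},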
{
"cell_type": "code",
"execution_count": null,
"id": "93cca20a-0352-45f6-b5ef-6316b5f7742a",
"metadata": {},
"outputs": [],
"source": [
"num_samples = len(news_df)\n",
"#num_samples = 200\n",
"out = Parallel(n_jobs=num_cores, backend='threading')(\n",
" delayed(split_sentences)(datasets=news_datasets['train'],idx=idx) for idx in tqdm(range(0, num_samples), miniters=1000)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57347484-e019-4e65-8492-1c3aa873693a",
"metadata": {},
"outputs": [],
"source": [
"def preprocessing_news(df):\n",
" import re\n",
" \n",
" # Remove punctuations\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[.,!?:;-=...@#_]\", \" \", str(s)).split()))\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[ᅳㅡ]\", \"\", str(s)).split()))\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[--]\", \"\", str(s)).split()))\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[\\.\\,\\(\\)\\{\\}\\[\\]\\`\\'\\!\\?\\:\\;\\-\\=]\", \" \", str(s)).split()))\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n",
"\n",
" # Remove links\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))\n",
" \n",
" return df"
]
},
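{
"cell_type": "markdown",
"id": "9e8f7a60-4c5d-4b2a-9f1e-3d2c1b0a9021",
"metadata": {},
"source": [
"A quick before/after check of the cleaning rules on a made-up string; the URL and punctuation below are illustrative only:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e8f7a60-4c5d-4b2a-9f1e-3d2c1b0a9022",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: the link is removed first, then punctuation is stripped\n",
"# and whitespace is re-normalized by the split()/' '.join pattern.\n",
"demo_df = pd.DataFrame({'document': ['속보!! 자세한 내용은 https://example.com/news?id=1 참고... (기자)']})\n",
"preprocessing_news(demo_df)['document'].iloc[0]"
]
},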
{
"cell_type": "code",
"execution_count": null,
"id": "18676849-6546-47e7-91cd-67dbef24bff2",
"metadata": {},
"outputs": [],
"source": [
"news_texts = list(flatten(out))\n",
"news_df = pd.DataFrame(news_texts, columns=['document'])\n",
"news_df = preprocessing_news(news_df)\n",
"# news_df[\"document\"] = news_df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n",
"# news_df[\"document\"] = news_df[\"document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "775b8c38-e876-4597-b98b-5c57a28d4e26",
"metadata": {},
"outputs": [],
"source": [
"news_df.head()"
]
},
{
"cell_type": "markdown",
"id": "8c16d4a3-d3f0-46a2-8fbb-041386f6b4ff",
"metadata": {},
"source": [
"
\n",
"\n",
"## 네이버 영화 리뷰 데이터셋 가공 \n",
"---\n",
"\n",
"네이버 영화 리뷰 데이터셋을 가공합니다. 더 많은 훈련 데이터를 확보하기 위해, 각 샘플의 문장 길이가 일정 이상일 때 문장 분리를 수행합니다.\n",
"- 데이터셋 출처: https://github.com/e9t/nsmc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90d0a60c-3129-4186-b1e5-26bdfa27b787",
"metadata": {},
"outputs": [],
"source": [
"!curl -O https://raw.githubusercontent.com/e9t/nsmc/master/ratings_train.txt\n",
"!curl -O https://raw.githubusercontent.com/e9t/nsmc/master/ratings_test.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d6408a2-d41e-4133-8dd9-51af386d23b2",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pda\n",
"import numpy as np\n",
"\n",
"train_df = pd.read_csv('ratings_train.txt', header=0, delimiter='\\t')\n",
"test_df = pd.read_csv('ratings_test.txt', header=0, delimiter='\\t')\n",
"df = pd.concat([train_df, test_df], ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90ba9b71-7f2a-4378-a576-b5bd1570d5d0",
"metadata": {},
"outputs": [],
"source": [
"def preprocessing_nsmc(df):\n",
" import re\n",
" \n",
" # Remove consonant & vowel for Korean language\n",
" df[\"clean_document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"([ㄱ-ㅎㅏ-ㅣ]+)\", \"\", str(s)).split()))\n",
" \n",
" # Remove punctuations\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[.,!?:;-=...@#_]\", \" \", str(s)).split()))\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[ᅳㅡ]\", \"\", str(s)).split()))\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[--]\", \"\", str(s)).split()))\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[\\.\\,\\(\\)\\{\\}\\[\\]\\`\\'\\!\\?\\:\\;\\-\\=]\", \" \", str(s)).split()))\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n",
"\n",
" # Remove non-korean characters\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"[^가-힣ㄱ-하-ㅣ\\\\s]\", \"\", str(s)).split()))\n",
"\n",
" # Remove links\n",
" df[\"clean_document\"] = df[\"clean_document\"].apply(lambda s: ' '.join(re.sub(\"(w+://S+)\", \" \", s).split()))\n",
" \n",
" return df\n",
"\n",
"df = preprocessing_nsmc(df)"
]
},
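{
"cell_type": "markdown",
"id": "5b6c7d80-8e9f-4a1b-9c2d-4e5f6a7b8031",
"metadata": {},
"source": [
"A quick check that standalone jamo such as 'ㅋㅋㅋ' and 'ㅠㅠ' are stripped while complete syllables survive; the review text is made up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b6c7d80-8e9f-4a1b-9c2d-4e5f6a7b8032",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the jamo-removal rule in preprocessing_nsmc().\n",
"demo_df = pd.DataFrame({'document': ['ㅋㅋㅋ 진짜 재밌어요 ㅠㅠ 최고!!']})\n",
"preprocessing_nsmc(demo_df)['clean_document'].iloc[0]"
]
},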
{
"cell_type": "code",
"execution_count": null,
"id": "13a45101-d293-4045-9932-00d231ac08ef",
"metadata": {},
"outputs": [],
"source": [
"min_chars = 5\n",
"max_chars = 32\n",
"nsmc_short_df = df[(df[\"clean_document\"].str.len() >= min_chars) & (df[\"clean_document\"].str.len() < max_chars)]\n",
"nsmc_long_df = df[df[\"clean_document\"].str.len() >= max_chars]"
]
},
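{
"cell_type": "markdown",
"id": "1a2b3c40-5d6e-4f7a-8b9c-0d1e2f3a4041",
"metadata": {},
"source": [
"A suggested quick look (not part of the original pipeline) at how the length threshold partitions the corpus:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c40-5d6e-4f7a-8b9c-0d1e2f3a4042",
"metadata": {},
"outputs": [],
"source": [
"# Suggested check: sample counts in each length bucket.\n",
"print(f'short (5-31 chars): {len(nsmc_short_df)}, long (>= 32 chars): {len(nsmc_long_df)}')"
]
},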
{
"cell_type": "code",
"execution_count": null,
"id": "9bb2f079-1fb5-4521-8abb-63ea46b2422a",
"metadata": {},
"outputs": [],
"source": [
"def split_sentences(df,idx):\n",
" document = df['document'].iloc[idx]\n",
" splits = kiwi.split_into_sents(document, return_tokens=False)\n",
" return [s.text for s in splits] \n",
"\n",
"num_samples = len(nsmc_long_df)\n",
"out = Parallel(n_jobs=num_cores, backend='threading')(\n",
" delayed(split_sentences)(df=nsmc_long_df,idx=idx) for idx in tqdm(range(0, num_samples), miniters=10000)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "030477eb-ac49-4f34-aed1-064f7cbbc1f6",
"metadata": {},
"outputs": [],
"source": [
"nsmc_texts = list(flatten(out))\n",
"nsmc_long_df = pd.DataFrame(nsmc_texts, columns=['document'])\n",
"nsmc_long_df = preprocessing_nsmc(nsmc_long_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc5df08d-99a2-443b-b776-dba48201c3e4",
"metadata": {},
"outputs": [],
"source": [
"nsmc_short_df = nsmc_short_df[\"clean_document\"].to_frame(name=\"document\")\n",
"nsmc_long_df = nsmc_long_df[\"clean_document\"].to_frame(name=\"document\")"
]
},
{
"cell_type": "markdown",
"id": "d4e4cd9d-ad7c-4b96-b92d-acbdf53c5543",
"metadata": {},
"source": [
"
\n",
"\n",
"## 챗봇 데이터셋 가공\n",
"---\n",
"\n",
"챗봇 데이터셋을 가공합니다.\n",
"- 데이터셋 출처: https://github.com/songys/Chatbot_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "349f2f36-5ca7-4018-8079-71239cb3f630",
"metadata": {},
"outputs": [],
"source": [
"import urllib\n",
"urllib.request.urlretrieve(\"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv\", \n",
" filename=\"chatbot_train.csv\")\n",
"chatbot_df = pd.read_csv('chatbot_train.csv')\n",
"chatbot_df.head()"
]
},
{
"cell_type": "markdown",
"id": "8650f856-3d43-4b71-b8b0-77534920066f",
"metadata": {},
"source": [
"질문 문장과 응답 문장을 분리하여 개별 데이터프레임을 생성합니다."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e605a9a-8b0c-42e6-9446-30cbb5d651ed",
"metadata": {},
"outputs": [],
"source": [
"chatbot_q_df = chatbot_df['Q'].to_frame()\n",
"chatbot_q_df.columns = ['document']\n",
"chatbot_q_df = chatbot_q_df.drop_duplicates()\n",
"\n",
"chatbot_a_df = chatbot_df['A'].to_frame()\n",
"chatbot_a_df.columns = ['document']\n",
"chatbot_a_df = chatbot_a_df.drop_duplicates()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c85c6d94-c009-47c6-9e42-b26e88666659",
"metadata": {},
"outputs": [],
"source": [
"def preprocessing_chatbot(df, min_chars=4):\n",
" df[\"document\"] = df[\"document\"].apply(lambda s: ' '.join(re.sub(\"[-=+,#/\\?:^$.@*\\\"※~&%ㆍ!』\\\\‘|\\(\\)\\[\\]\\<\\>`\\'…》]\", \"\", s).split()))\n",
" # Remove rows if text has less than min characters\n",
" df = df[df[\"document\"].str.len() >= min_chars]\n",
"\n",
" return df\n",
"\n",
"chatbot_q_df = preprocessing_chatbot(chatbot_q_df)\n",
"chatbot_a_df = preprocessing_chatbot(chatbot_a_df)"
]
},
{
"cell_type": "markdown",
"id": "76ec2ff8-1b19-41a7-a128-e9f623657f9c",
"metadata": {},
"source": [
"
\n",
"\n",
"## 최종 데이터셋 취합\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9de74c99-ef88-47eb-bb34-7931647ce0af",
"metadata": {},
"outputs": [],
"source": [
"nsmc_short_df['category'] = 'nsmc'\n",
"nsmc_long_df['category'] = 'nsmc'\n",
"news_df['category'] = 'news'\n",
"chatbot_q_df['category'] = 'chatbot'\n",
"chatbot_a_df['category'] = 'chatbot'\n",
"\n",
"final_df = pd.concat(\n",
" [nsmc_short_df, nsmc_long_df, news_df, chatbot_q_df, chatbot_a_df], \n",
" ignore_index=True\n",
")\n",
"final_df = final_df[final_df[\"document\"].str.len() >= 5]\n",
"final_df['document'] = final_df['document'].str.strip()"
]
},
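{
"cell_type": "markdown",
"id": "6d7e8f90-1a2b-4c3d-8e9f-5a6b7c8d9051",
"metadata": {},
"source": [
"A suggested sanity check on the assembled dataset: the length statistics should confirm the 5-character minimum filter applied above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d7e8f90-1a2b-4c3d-8e9f-5a6b7c8d9052",
"metadata": {},
"outputs": [],
"source": [
"# Suggested check: document length distribution after filtering.\n",
"final_df['document'].str.len().describe()"
]
},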
{
"cell_type": "code",
"execution_count": null,
"id": "931c82f5-09f9-4299-84e4-7611a558e67a",
"metadata": {},
"outputs": [],
"source": [
"final_df['category'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4048d75e-a809-4848-949d-3f888d18e3d8",
"metadata": {},
"outputs": [],
"source": [
"final_df.to_csv('ocr_dataset_poc.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "e60fa963-67a7-4533-873e-cfeefc584591",
"metadata": {},
"source": [
"
\n",
"\n",
"## Clean up\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b27db6a-e3a1-49a3-a7ce-23d219107597",
"metadata": {},
"outputs": [],
"source": [
"!rm ratings_train.txt ratings_test.txt chatbot_train.csv"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}