
🔥 Korean GPT-2, KoGPT2 FineTuning cased. Trained on Korean lyrics data. 🔥

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

KoGPT2-FineTuning


This project uses KoGPT2, which SKT-AI pre-trained on about 20 GB of Korean data. As a first application, lyric writing, the model was fine-tuned on cleaned lyrics whose copyright has expired, novels, articles, and similar text, with a different weight assigned to each data source. The genre is also taken as input, so you can see lyric training results per music genre.

For smooth training on Colab, Google Drive and Dropbox are linked. Intermediate training results are moved from Google Drive to Dropbox, and then deleted from Google Drive. See the related code.
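The move-then-delete step described above could be sketched roughly as follows. This is only an illustration, not the repository's code: `offload_checkpoints` and the `upload` callable are hypothetical names, and `upload` stands in for whatever actually transfers a file to Dropbox.

```python
from pathlib import Path
from typing import Callable, List


def offload_checkpoints(src_dir: str,
                        upload: Callable[[Path], None],
                        pattern: str = "*.tar") -> List[str]:
    """Upload every matching file in src_dir, then delete the local copy.

    `upload` is a hypothetical stand-in for the Dropbox transfer. It is
    assumed to raise an exception on failure, so a file is deleted only
    after its upload has returned normally.
    """
    moved = []
    for path in sorted(Path(src_dir).glob(pattern)):
        upload(path)   # copy the checkpoint to remote storage first
        path.unlink()  # then free the space on the source drive
        moved.append(path.name)
    return moved
```

Deleting only after `upload` returns means a failed transfer leaves the checkpoint in place, at the cost of stopping the loop at the first error.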

Version 2 of the code was changed to take a CSV-format dataset organized by music genre. If you find it difficult to do KoGPT2 fine-tuning with Version 2, please use Version 1.1.

Below, you can see the results of training on a variety of Korean lyrics. We will also carry out various other projects.

Sample

Data structure

weight | Genre | lyrics
1100.0 | 발라드 | '내 맘을 알잖아요\n\n\n바로처럼 멍하니 서 있는 모습만\n\n\n바라보다\n\n\n포기할 수 밖에 없어서...'
... | ... | ...
Shape: 3 columns x 200000 rows
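As a rough illustration of this layout, the sketch below reads a CSV with the three columns above and draws rows in proportion to their weight. How the repository actually uses the weight column is not shown here, so treating it as a sampling weight is an assumption. The sample rows and the helper names `load_rows` and `sample_rows` are invented for the example.

```python
import csv
import io
import random

# Tiny in-memory stand-in for the dataset; the rows are invented for illustration.
SAMPLE_CSV = (
    "weight,Genre,lyrics\n"
    "1100.0,ballad,first example lyric\n"
    "100.0,rock,second example lyric\n"
)


def load_rows(fileobj):
    """Parse the three-column CSV into dicts with a numeric weight."""
    reader = csv.DictReader(fileobj)
    return [
        {"weight": float(r["weight"]), "genre": r["Genre"], "lyrics": r["lyrics"]}
        for r in reader
    ]


def sample_rows(rows, k, rng):
    """Draw k rows with replacement, proportionally to each row's weight.

    Assumption: weight acts as a relative sampling frequency.
    """
    weights = [r["weight"] for r in rows]
    return rng.choices(rows, weights=weights, k=k)


rows = load_rows(io.StringIO(SAMPLE_CSV))
batch = sample_rows(rows, k=1000, rng=random.Random(0))
```

With these weights, the first row is expected to appear about eleven times as often as the second.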

Fine Tuning

python main.py --epoch=200 --data_file_path=./dataset/lyrics_dataset.csv --save_path=./checkpoint/ --load_path=./checkpoint/genre/KoGPT2_checkpoint_296000.tar --batch_size=1

parser

parser.add_argument('--epoch', type=int, default=200,
                    help="Number of epochs; controls how long training runs.")
parser.add_argument('--save_path', type=str, default='./checkpoint/',
                    help="Path where training results are saved.")
parser.add_argument('--load_path', type=str, default='./checkpoint/Alls/KoGPT2_checkpoint_296000.tar',
                    help="Path from which trained results are loaded.")
parser.add_argument('--samples', type=str, default="samples/",
                    help="Path where generated results are saved.")
parser.add_argument('--data_file_path', type=str, default='dataset/lyrics_dataset.txt',
                    help="Path from which the training data is loaded.")
parser.add_argument('--batch_size', type=int, default=8,
                    help="Specifies the batch size.")

Use Colab


You can run the fine-tuning code using Colab.

Runtime Disconnection Prevention

function ClickConnect() {
    // Colab dialog: "Failed to assign a backend."
    // "Backends with GPU are unavailable. Would you like to use a runtime with no accelerator?"
    // Find the Cancel button and click it.
    var buttons = document.querySelectorAll("colab-dialog.yes-no-dialog paper-button#cancel");
    buttons.forEach(function(btn) {
        btn.click();
    });
    console.log("Reconnecting every minute");
    document.querySelector("#top-toolbar > colab-connect-button").click();
}
setInterval(ClickConnect, 1000*60);

Clear the screen every 10 minutes

function CleanCurrentOutput() {
    // The title suffix '현재 실행 중...' ("Currently running...") matches the Korean Colab UI.
    var btn = document.querySelector(".output-icon.clear_outputs_enabled.output-icon-selected[title$='현재 실행 중...'] iron-icon[command=clear-focused-or-selected-outputs]");
    if (btn) {
        console.log("Clearing output every 10 minutes");
        btn.click();
    }
}
setInterval(CleanCurrentOutput, 1000*60*10);

GPU Memory Check

nvidia-smi.exe

generator

python generator.py --temperature=1.0 --text_size=1000 --tmp_sent=""

No plagiarism

python generator.py --temperature=5.0 --text_size=500 --tmp_sent=""

parser

parser.add_argument('--temperature', type=float, default=0.7,
                    help="Controls the creativity of the text via temperature.")
parser.add_argument('--top_p', type=float, default=0.9,
                    help="Controls the range of expression in the text via top_p.")
parser.add_argument('--top_k', type=int, default=40,
                    help="Controls the range of expression in the text via top_k.")
parser.add_argument('--text_size', type=int, default=250,
                    help="Adjusts the length of the output.")
parser.add_argument('--loops', type=int, default=-1,
                    help="Number of times to repeat generation. -1 means repeat forever.")
parser.add_argument('--tmp_sent', type=str, default="사랑",
                    help="Opening sentence of the text.")
parser.add_argument('--load_path', type=str, default="./checkpoint/Alls/KoGPT2_checkpoint_296000.tar",
                    help="Path of the trained checkpoint to load.")
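To make the roles of temperature, top_k and top_p more concrete, here is a minimal pure-Python sketch of how such settings are commonly combined when picking the next token. It is not taken from generator.py: the function name `sample_next_token` and the toy logits are assumptions, and a real implementation would operate on model tensors.

```python
import math
import random
from typing import List, Optional


def sample_next_token(logits: List[float], temperature: float = 0.7,
                      top_k: int = 40, top_p: float = 0.9,
                      rng: Optional[random.Random] = None) -> int:
    """Pick a token index from raw logits (temperature must be > 0)."""
    if rng is None:
        rng = random.Random()
    # Temperature: values above 1 flatten the distribution, below 1 sharpen it.
    scaled = [x / temperature for x in logits]
    # Softmax, subtracting the maximum for numerical stability.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token indices from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep only the k most likely tokens (0 disables this filter).
    if top_k > 0:
        order = order[:top_k]
    # Top-p: keep the smallest prefix whose renormalised mass reaches top_p.
    mass = sum(probs[i] for i in order)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break
    # Sample among the survivors; choices() renormalises the weights.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


toy_logits = [5.0, 1.0, 0.0, -1.0]  # invented values for illustration
greedy = sample_next_token(toy_logits, top_k=1, rng=random.Random(0))
```

With `top_k=1` only the most likely token survives, so the call is deterministic. A high temperature such as the 5.0 used in the command above spreads probability over many more tokens.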

Use Colab


You can run the generator using Colab.

tensorboard

To see how the model changes as training progresses, open tensorboard and check the loss and text.

tensorboard --logdir=runs

loss

text

Citation

@misc{KoGPT2-FineTuning,
  author = {gyung},
  title = {KoGPT2-FineTuning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gyunggyung/KoGPT2-FineTuning}},
}

Output

Detailed outputs can be found in samples. Details about the training can be found in the related post.

Reference

https://github.com/openai/gpt-2
https://github.com/nshepperd/gpt-2
https://github.com/SKT-AI/KoGPT2
https://github.com/asyml/texar-pytorch/tree/master/examples/gpt-2
https://github.com/graykode/gpt-2-Pytorch
https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
https://github.com/shbictai/narrativeKoGPT2
https://github.com/ssut/py-hanspell
https://github.com/likejazz/korean-sentence-splitter