- ๐ Naver Boost camp AI tech 2nd , Team CLUE
- ๐ Wrap-up report , ๐๏ธ presentation slide
โ KLUE MRC(Machine Reading Comprehension) Dataset์ผ๋ก ์ฃผ์ด์ง ์ง๋ฌธ์ ๋ํ ๋ฌธ์ ๊ฒ์ ํ ๋ต๋ณ ์ถ์ถํ๋ Task.
โ Retriver ๋ฅผ ํตํด wikipedia์์ Top-k ๋ฌธ์๋ฅผ ๋ถ๋ฌ์ค๊ณ , Reader๋ฅผ ํตํด ๋ฌธ์ ๋ด ๋ต๋ณ์ ์ถ์ถํ๋ ๋ชจ๋ธ์ ๊ตฌ์ถ, ์คํ ํ์ฌ ์ฃผ์ด์ง ์ง๋ฌธ์ ์ ํํ ๋ต๋ณ์ ์ฐพ์๋ด๋ ๋ชจ๋ธ์ ๋ง๋๋ ๊ฒ.
โ 1์ผ ํ ์ ์ถํ์๋ 10ํ๋ก ์ ํ๋์์ต๋๋ค.
๐ dataset ๋ค์ด๋ก๋
# data (51.2 MB)
tar -xzf data.tar.gz
๐ ํด๋น ๋ ํฌ ๋ค์ด๋ก๋
git clone https://github.com/boostcampaitech2/mrc-level2-nlp-13.git
๐ Poetry๋ฅผ ํตํ ํจํค์ง ๋ฒ์ ๊ด๋ฆฌ
# curl ์ค์น
apt-get install curl #7.58.0
# poetry ์ค์น
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python
# poetry ํญ์์ฑ ํ์ฑํ
~/.bashrc๋ฅผ ์์ ํ์ฌ poetry๋ฅผ shell์์ ์ฌ์ฉ ํ ์ ์๋๋ก ๊ฐ์ํ๊ฒฝ์ ์ถ๊ฐ
poetry use [์ฌ์ฉํ๋ ๊ฐ์ํ๊ฒฝ์ `python path` | ๊ฐ์ํ๊ฒฝ์ด ์คํ์ค์ด๋ผ๋ฉด `python`]
# repo download ํ ๋ฒ์ ์ ์ฉ (poetry.toml์ ๋ฐ๋ผ ์ ์ฉ)
poetry install
mrc-level2-nlp-13
โโโ configs
โ โโโ example.json
โโโ model
โ โโโ Reader
โ โ โโโ RobertaCnn.py
โ โ โโโ trainer_qa.py
โ โโโ Retrieval
โ โโโ retrieval.py
โโโ inference.py
โโโ notebook
โ โโโ post_preprocessing.ipynb
โโโ ensemble
โ โโโ hard_vote.ipynb
โโโ augmentation
โ โโโ quesiton_generate.py
โโโ images
โ โโโ dataset.png
โโโ poetry.lock
โโโ pyproject.toml
โโโ readme.md
โโโ License.md
โโโ dense_retrieval_train.py
โโโ train_reader.py
โโโ utils
โโโ arguments.py
โโโ dense_utils
โ โโโ retrieval_dataset.py
โ โโโ utils.py
โโโ logger.py
โโโ utils_qa.py
์๋๋ ์ ๊ณตํ๋ ๋ฐ์ดํฐ์ ์ ๋ถํฌ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
๋ฐ์ดํฐ์ ์ ํธ์์ฑ์ ์ํด Huggingface ์์ ์ ๊ณตํ๋ datasets๋ฅผ ์ด์ฉํ์ฌ pyarrow ํ์์ ๋ฐ์ดํฐ๋ก ์ ์ฅ๋์ด์์ต๋๋ค. ๋ค์์ ๋ฐ์ดํฐ์ ์ ๊ตฌ์ฑ์ ๋๋ค.
./data/ # ์ ์ฒด ๋ฐ์ดํฐ
./train_dataset/ # ํ์ต์ ์ฌ์ฉํ ๋ฐ์ดํฐ์
. train ๊ณผ validation ์ผ๋ก ๊ตฌ์ฑ
./test_dataset/ # ์ ์ถ์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ์
. validation ์ผ๋ก ๊ตฌ์ฑ
./wikipedia_documents.json # ์ํคํผ๋์ ๋ฌธ์ ์งํฉ. retrieval์ ์ํด ์ฐ์ด๋ corpus.
๋ง์ฝ ๋ฐ์ดํฐ ์ฆ๊ฐ์ ํตํ dataset์ ์ฌ์ฉํ์ ๋ค๋ฉด, ์ด ๋๋ ํ ๋ฆฌ์ ์ถ๊ฐํด์ฃผ์๊ณ config ๋ด "data_args" ๋ฅผ ๋ณ๊ฒฝํด์ฃผ์๋ฉด ๋ฉ๋๋ค.
roberta ๋ชจ๋ธ์ ์ฌ์ฉํ ๊ฒฝ์ฐ, token type ids๋ฅผ ์ฌ์ฉ์ํ๋ฏ๋ก tokenizer ์ฌ์ฉ์ ์๋ ํจ์์ ์ต์ ์ ์์ ํด์ผํฉ๋๋ค. ๋ฒ ์ด์ค๋ผ์ธ์ klue/bert-base๋ก ์งํ๋๋ ์ด ๋ถ๋ถ์ ์ฃผ์์ ํด์ ํ์ฌ ์ฌ์ฉํด์ฃผ์ธ์ ! tokenizer๋ train, validation (train.py), test(inference.py) ์ ์ฒ๋ฆฌ๋ฅผ ์ํด ํธ์ถ๋์ด ์ฌ์ฉ๋ฉ๋๋ค. (tokenizer์ return_token_type_ids=False๋ก ์ค์ ํด์ฃผ์ด์ผ ํจ)
- ํ์ต์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ configs directory ๋ฐ์ .json ํ์ผ๋ก ์์ฑํ์ฌ ์คํ์ ์งํํฉ๋๋ค.
- ํ์ต๋ ๋ชจ๋ธ์ tuned_models/"model_name" directory์ bin file์ ํํ๋ก ์ ์ฅ๋ฉ๋๋ค.
# train_reader.py
def prepare_train_features(examples):
# truncation๊ณผ padding(length๊ฐ ์งง์๋๋ง)์ ํตํด toknization์ ์งํํ๋ฉฐ, stride๋ฅผ ์ด์ฉํ์ฌ overflow๋ฅผ ์ ์งํฉ๋๋ค.
# ๊ฐ example๋ค์ ์ด์ ์ context์ ์กฐ๊ธ์ฉ ๊ฒน์น๊ฒ๋ฉ๋๋ค.
tokenized_examples = tokenizer(
... ...
#return_token_type_ids=False, # roberta๋ชจ๋ธ์ ์ฌ์ฉํ ๊ฒฝ์ฐ False, bert๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ True๋ก ํ๊ธฐํด์ผํฉ๋๋ค.
padding="max_length" if data_args.pad_to_max_length else False,
)
# train_reader argparser
-c, --config_file_path : train config ์ ๋ณด๊ฐ ๋ค์ด์๋ json file์ ์ด๋ฆ
-l ,--log_file_path : train logging์ ํ ํ์ผ ์ด๋ฆ
-n ,--model_name : ๋ชจ๋ธ์ด ์ ์ฅ๋ ๋๋ ํ ๋ฆฌ ์ด๋ฆ
--do_train : Reader๋ชจ๋ธ train flag
--do_eval : Reader๋ชจ๋ธ validation flag
- reader ํ์ต ์์
python train_reader.py -c ./configs/exp1.json -l exp1.log -n experiments1 --do_train
- dense retriver ํ์ต ์์
python train_reader.py -c ./configs/dense_exp1.json -l dense_exp1.log -n dense_experiment1 --do_train
MRC ๋ชจ๋ธ์ ์ฑ๋ฅ ํ๊ฐ(๊ฒ์ฆ)๋ (--do_eval
) ํ๋ ๊ทธ๋ฅผ ๋ฐ๋ก ์ค์ ํด์ผ ํฉ๋๋ค. ์ ํ์ต ์์์ ๋จ์ํ --do_eval
์ ์ถ๊ฐ๋ก ์
๋ ฅํด์ ํ๋ จ ๋ฐ ํ๊ฐ๋ฅผ ๋์์ ์งํํ ์๋ ์์ต๋๋ค.
# mrc ๋ชจ๋ธ ํ๊ฐ (train/validation ์ฌ์ฉ)
python train_reader.py -c ./configs/exp1.json -l exp1.log -n experiments1 --do_train --do_eval
retrieval ๊ณผ mrc ๋ชจ๋ธ์ ํ์ต์ด ์๋ฃ๋๋ฉด inference.py
๋ฅผ ์ด์ฉํด odqa ๋ฅผ ์งํํ ์ ์์ต๋๋ค.
-
ํ์ตํ ๋ชจ๋ธ์ test_dataset์ ๋ํ ๊ฒฐ๊ณผ๋ฅผ ์ ์ถํ๊ธฐ ์ํด์ ์ถ๋ก (
--do_predict
)๋ง ์งํํ๋ฉด ๋ฉ๋๋ค. -
ํ์ตํ ๋ชจ๋ธ์ด train_dataset ๋ํด์ ODQA ์ฑ๋ฅ์ด ์ด๋ป๊ฒ ๋์ค๋์ง ์๊ณ ์ถ๋ค๋ฉด ํ๊ฐ(--do_eval)๋ฅผ ์งํํ๋ฉด ๋ฉ๋๋ค.
# ODQA ์คํ (test_dataset ์ฌ์ฉ)
# wandb ๊ฐ ๋ก๊ทธ์ธ ๋์ด์๋ค๋ฉด ์๋์ผ๋ก ๊ฒฐ๊ณผ๊ฐ wandb ์ ์ ์ฅ๋ฉ๋๋ค. ์๋๋ฉด ๋จ์ํ ์ถ๋ ฅ๋ฉ๋๋ค
# inference argparser
-c, --config_file_path : inference config ์ ๋ณด๊ฐ ๋ค์ด์๋ json file์ ์ด๋ฆ
-l ,--log_file_path : inference logging์ ํ ํ์ผ ์ด๋ฆ
-n ,--inference_name : inference ๊ฒฐ๊ณผ๊ฐ ์ ์ฅ๋ ๋๋ ํ ๋ฆฌ ์ด๋ฆ
-m , --model_name_or_path : inference์ ์ฌ์ฉํ ๋ชจ๋ธ ๋๋ ํ ๋ฆฌ์ ์ด๋ฆ
python inference.py -c infer1.json -l infer1.log --n infer1_result -m ./tuned_models/train_dataset/ --do_predict
inference.py
ํ์ผ์ ์ ์์์ฒ๋ผ --do_predict
์ผ๋ก ์คํํ๋ฉด --inference_name
์์น์ predictions.json
์ด๋ผ๋ ํ์ผ์ด ์์ฑ๋ฉ๋๋ค. ํด๋น ํ์ผ์ ์ ์ถํด์ฃผ์๋ฉด ๋ฉ๋๋ค.
๋ค์์ MRC ๋ชจ๋ธ์ public & private datset์ ๋ํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
-
inference.py
์์ TF-IDF score์ ๊ฒฝ์ฐ sparse embedding ์ ํ๋ จํ๊ณ ์ ์ฅํ๋ ๊ณผ์ ์ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ์ง ์์ ๋ฐ๋ก argument ์ default ๊ฐ True๋ก ์ค์ ๋์ด ์์ต๋๋ค. ์คํ ํ sparse_embedding.bin ๊ณผ tfidfv.bin ์ด ์ ์ฅ์ด ๋ฉ๋๋ค. ๋ง์ฝ sparse retrieval ๊ด๋ จ ์ฝ๋๋ฅผ ์์ ํ๋ค๋ฉด, ๊ผญ ๋ ํ์ผ์ ์ง์ฐ๊ณ ๋ค์ ์คํํด์ฃผ์ธ์! ์๊ทธ๋ฌ๋ฉด ์กด์ฌํ๋ ํ์ผ์ด load ๋ฉ๋๋ค. -
๋ชจ๋ธ์ ๊ฒฝ์ฐ
--overwrite_cache
๋ฅผ ์ถ๊ฐํ์ง ์์ผ๋ฉด ๊ฐ์ ํด๋์ ์ ์ฅ๋์ง ์์ต๋๋ค. -
./predictions/ ํด๋ ๋ํ
--overwrite_output_dir
์ ์ถ๊ฐํ์ง ์์ผ๋ฉด ๊ฐ์ ํด๋์ ์ ์ฅ๋์ง ์์ต๋๋ค.
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.