โ์ง์ ์ํโ๋ฅผ ์ถ์ ํ๋ ๋ฅ๋ฌ๋ ๋ฐฉ๋ฒ๋ก ์ธ DKT(Deep Knowledge Tracing)๋ฅผ ํตํด ์ฌ์ฉ์๊ฐ ํ์ด์๋ ๊ณผ๋ชฉ์ ๋ํด ์ผ๋งํผ ์ดํดํ๊ณ ์๋์ง๋ฅผ ์ธก์ ํ์ฌ ์์ง ํ์ง ์์ ๋ฏธ๋์ ๋ฌธ์ ์ ๋ํด ๋ง์์ง ํ๋ฆด์ง ์์ธกํ๋ ๋ชจ๋ธ ๊ฐ๋ฐํฉ๋๋ค.
์ฃผ์ด์ง ๋ฌธ์ ๋ ํ์ ๊ฐ๊ฐ์ธ์ ์ดํด๋์ธ ์ง์ ์ํ๋ฅผ ์์ธก๋ฟ๋ง ์๋๋ผ ๋ฌธ์ ๋ฅผ ๋ง์ถ ์ฌ๋ถ์ ๋ํ ํ๋ฅ ๋ก๋ ์ ๊ทผํ ์ ์์ต๋๋ค. ๋๋ฌธ์ ์ฃผ์ด์ง ์ฌ๋ฌ ์ถ์ฒ ์์คํ ๋ชจ๋ธ ๋ฟ๋ง ์๋๋ผ ML ๊ณ์ด์ ๋ชจ๋ธ๊ณผ๋ ๋น๊ตํด๋ณด๊ณ , ์ด๋ฅผ ํตํด Sequential data์ static data ๊ฐ๊ฐ์ ๊ดํ ์ถ์ฒ ์์คํ ์ ์ฐจ์ด๋ฅผ ์ดํดํฉ๋๋ค. ๋ํ ๋ํ์์ ์ฃผ์ด์ง๋ ๋ฐ์ดํฐ ๋ถ์์ ํตํด ์ต์ ์ ๋ชจ๋ธ๊ณผ ์ถ์ฒ ๋ฐฉ์์ ์ฐพ๊ณ , ์ด๋ฅผ ๋ฐํ์ผ๋ก ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ auroc, acc๋ฅผ ๋์ด๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค.
train_data.csv
USER_ID | assessmentItemID | testID | answerCode | Timestamp | KnowledgeTag |
---|---|---|---|---|---|
์ฌ์ฉ์์ ๊ณ ์ ๋ฒํธ | ๋ฌธํญ์ ๊ณ ์ ๋ฒํธ | ์ฌํ์ง์ ๊ณ ์ ๋ฒํธ | ์ฌ์ฉ์๊ฐ ํด๋น ๋ฌธํญ์ ๋ง์ท๋์ง ์ฌ๋ถ(์ ๋ต 1, ์ค๋ต 0 | ์ฌ์ฉ์๊ฐ ํด๋น ๋ฌธํญ์ ํ๊ธฐ ์์ํ ์์ | ๋ฌธํญ ๋น ํ๋์ฉ ๋ฐฐ์ ๋๋ ํ๊ทธ(์ค๋ถ๋ฅ) |
- 2,266,588๊ฐ์ ํ์ผ๋ก ๊ตฌ์ฑ๋์์ต๋๋ค.
- ์ด 9,454๊ฐ์ ๊ณ ์ ๋ฌธํญ, 1,537๊ฐ์ ์ํ์ง, 912๊ฐ์ ํ๊ทธ(์ค๋ถ๋ฅ)๊ฐ ์กด์ฌํฉ๋๋ค.
- DKT๋ฅผ ํ์ฉํ๋ฉด ํ์ ๊ฐ๊ฐ์ธ์๊ฒ ๋ฌธ์ ์ ๋ํ ์ดํด๋์ ์ทจ์ฝํ ๋ถ๋ถ์ ๊ทน๋ณตํ๊ธฐ ์ํด ์ด๋ค ๋ฌธ์ ๋ค์ ํ๋ฉด ์ข์์ง ์ถ์ฒ์ด ๊ฐ๋ฅํฉ๋๋ค.
- ํ์ ๊ฐ๊ฐ์ธ์ ์ดํด๋๋ฅผ ๊ฐ๋ฆฌํค๋ ์ง์ ์ํ๋ฅผ ์์ธกํ๋ ์ผ๋ณด๋ค๋, ์ฃผ์ด์ง ๋ฌธ์ ๋ฅผ ๋ง์ถ์ง ํ๋ฆด์ง ์์ธกํ๋ ๊ฒ์ ์ง์คํ์ฌ ํด๊ฒฐํ ์ ์์ต๋๋ค.
์ฌ์ฉ์์ ๋ฌธํญ์ ๊ณ ์ ๋ฒํธ, ์ํ์ง์ ๊ณ ์ ๋ฒํธ, ๋ฌธ์ ๋ฅผ ํ๊ธฐ ์์ํ ์์ ๊ณผ ํ๊ทธ ๋ฑ์ side-information์ ํ์ฉํ์ฌ ๊ฐ ํ์์ด ํผ ๋ฌธ์ ๋ฆฌ์คํธ์ ์ ๋ต ์ฌ๋ถ๊ฐ ๋ด๊ธด ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ ์ต์ข
๋ฌธ์ ๋ฅผ ๋ง์ถ์ง ํ๋ฆด์ง ์์ธกํฉ๋๋ค.
DKT๋ ์ฃผ์ด์ง ๋ง์ง๋ง ๋ฌธ์ ๋ฅผ ๋ง์๋์ง ํ๋ ธ๋์ง๋ก ๋ถ๋ฅํ๋ ์ด์ง ๋ถ๋ฅ ๋ฌธ์ ์ด๊ธฐ์, AUROC(Area Under the ROC curve)์ Accuracy๋ฅผ ์ฌ์ฉํฉ๋๋ค
//input/data
โโ sample_submission.csv # ๐ output sample
โโ train_data.csv # ๐ train data set
โโ test_data.csv # ๐ test data set
/code
โโโ README.md
โโโ CatBoost
โ โโโ config.py
โ โโโ dataset.py
โ โโโ train_inference.py
โโโ LGBM
โ โโโ args.py
โ โโโ asset
โ โโโ inference.py
โ โโโ lgbm
โ โ โโโ model.py
โ โ โโโ preprocess.py
โ โ โโโ trainer.py
โ โโโ model
โ โโโ output
โ โโโ save_pic
โ โ โโโ lgbm_feature_importance.png
โ โโโ train.py
โโโ LSTMAttn_with_LGCN
โ โโโ criterion.py
โ โโโ dataloader.py
โ โโโ lightgcn
โ โ โโโ config.py
โ โ โโโ id2index.pickle
โ โ โโโ inference.py
โ โ โโโ install.sh
โ โ โโโ lightgcn
โ โ โ โโโ datasets.py
โ โ โ โโโ models.py
โ โ โ โโโ optimizer.py
โ โ โ โโโ scheduler.py
โ โ โ โโโ utils.py
โ โ โโโ train.py
โ โ โโโ weight
โ โ โโโ best_auc_model2.pt
โ โโโ lightgcn_for_tag
โ โ โโโ config.py
โ โ โโโ inference.py
โ โ โโโ install.sh
โ โ โโโ lightgcn_for_tag
โ โ โ โโโ datasets.py
โ โ โ โโโ models.py
โ โ โ โโโ optimizer.py
โ โ โ โโโ scheduler.py
โ โ โ โโโ utils.py
โ โ โโโ output
โ โ โ โโโ best_auc_submission.csv
โ โ โโโ run.log
โ โ โโโ tag2index.pickle
โ โ โโโ temp.ipynb
โ โ โโโ train.py
โ โ โโโ weight
โ โ โโโ best_auc_model_tag.pt
โ โโโ metric.py
โ โโโ model.py
โ โโโ optimizer.py
โ โโโ scheduler.py
โ โโโ trainer.py
โ โโโ utils.py
โโโ README.md
โโโ args.py
โโโ hyper_run.py
โโโ id2index.pickle
โโโ inference.py
โโโ requirements.txt
โโโ tag2index.pickle
โโโ train.py
- ์ด๋ฒ ํ๋ก์ ํธ์์ ์คํํ ๋จ์ผ ๋ชจ๋ธ ์ข
๋ฅ๋ Boosting ๊ธฐ๋ฒ์ ์ฌ์ฉํ CatBoost, LightGBM, XGBoost ๋ชจ๋ธ, RNN ๊ณ์ด์ Attention Mechanism์ ์ฌ์ฉํ LSTM ๋ชจ๋ธ, GNN ๊ณ์ด์ LightGCN ๋ชจ๋ธ์ด ์์ผ๋ฉฐ, ๋จ์ผ ๋ชจ๋ธ ์ฑ๋ฅ์ ๋น๊ตํ์ฌ ์ต์ข
๊ฒฐ๊ณผ์ ์ฌ์ฉํ ๋ชจ๋ธ์ CatBoost, LightGBM, LSTM Attention(LSTM with attention mechanism)์
๋๋ค.
- ๋จ์ผ ๋ชจ๋ธ ๊ฐ ์ฑ๋ฅ ์คํ ์ CatBoost, LightGBM, LSTM with attention mechanism ์์ผ๋ก ๋์์ต๋๋ค.
- Public AUC ๊ธฐ์ค CatBoost 0.8139, LightGBM 0.7723, LSTM Attention 0.751
- ๊ธฐ์กด XGBoost(AUC ๊ธฐ์ค 0.6171)๋ณด๋ค LightGBM์ ์ฑ๋ฅ์ด ๋ ๋์๋ ๊ฒ์ LightGCN์์ ์ฌ์ฉํ๋ Gradient-based One Side Sampling ๊ธฐ๋ฒ์ผ๋ก ๊ฒฐ๊ณผ์ ์ผ๋ก ์๋์ ์ผ๋ก ์์ Gradient๋ณด๋ค ํฐ Gradient๋ฅผ ์ง๋๋ Instance์ ์ด์ ์ ๋์ด Underfitting์ ๋ง๋ ํจ๊ณผ์ธ ๊ฒ์ผ๋ก ๋ณด์ ๋๋ค.
- LightGCN์ ๊ทธ๋ํ์์ ํ์ตํ ์ ์ ์ ํ๊ทธ์ ๋ํ ์๋ฒ ๋ฉ์ LSTM Attention์ Feature๋ก ์ถ๊ฐํ์ฌ ์ต์ข
์ ์ผ๋ก CatBoost์ LightGBM์ ์์๋ธ ํ ๋ ์ ์๋ฅผ ํฅ์์ํฌ ์ ์์์ต๋๋ค.
- Public AUC ๊ธฐ์ค 0.8168 โ 0.8173
- ๋จ์ผ ๋ชจ๋ธ ๊ฐ ์ฑ๋ฅ ์คํ ์ CatBoost, LightGBM, LSTM with attention mechanism ์์ผ๋ก ๋์์ต๋๋ค.
- CatBoost, LightGBM, LSTM Attention ์ธ ๋ชจ๋ธ์ ์์๋ธํ๊ธฐ ์ํด Hard Voting์ ์งํํ์ต๋๋ค.
- ์ฑ๋ฅ์ด ๊ฐ์ฅ ์ข์๋ ์์๋ธ ๋ชจ๋ธ ๊ฒฐ๊ณผ์ธ CatBoost์ ๊ฐ์ฅ ๋ง์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ ๊ฒ์ด Public AUC ๊ธฐ์ค์ผ๋ก ๊ฐ์ฅ ์ข์ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
- CatBoost : LightGBM : LSTM Attention = 3.5 : 1.5 : 1 (๊ฐ์ค์น)
- LightGBM์์ ์์ธกํ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ์ง๊ณ OOF(Out Of Fold) Stacking ๊ธฐ๋ฒ์ ์ ์ฉํ์ ๋ Public AUC ๊ธฐ์ค 0.7881๋ก ๋ฏธ์ ์ฉํ ์คํ๋ณด๋ค ์๋์ ์ผ๋ก ๋ฎ๊ฒ ๋์์ง๋ง, Private์์๋ 0.8369๋ก ์ฑ๋ฅ์ด ์ฌ๋์ต๋๋ค.
- OOF Stacking ์ ์ฉ ์ ๋ฉํ ๋ชจ๋ธ๋ก๋ XGBoost์ LightGBM์ ์ฌ์ฉํ์ต๋๋ค.
- OOF Stacking ๊ธฐ๋ฒ์์ inferenceํ ํ๋ฅ ์ด 0 ๋๋ 1์ ๊ฐ๊น๊ฒ ๊ทน๋จ์ ์ผ๋ก ๋์์ Overfitting์ด ๋์๋ค๊ณ ์์ํด์ ์ต์ข ์ ์ถ ํ์ผ์ ํฌํจ์ํค์ง ์์์ต๋๋ค.
- ์ฑ๋ฅ์ด ๊ฐ์ฅ ์ข์๋ ์์๋ธ ๋ชจ๋ธ ๊ฒฐ๊ณผ์ธ CatBoost์ ๊ฐ์ฅ ๋ง์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ ๊ฒ์ด Public AUC ๊ธฐ์ค์ผ๋ก ๊ฐ์ฅ ์ข์ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
- ์ฝ๋ ๊ณต์
- GitHub
- ๊ฐ๋ฐ ํ๊ฒฝ
- JupyterLab, VS Code, PyCharm
- ๋ชจ๋ธ ์ฑ๋ฅ ๋ถ์
- Wandb
์ง์ํ | ๋ฐ์ ๊ท | ๊น์์ | ์ด์ ํธ | ์ด์ํฌ |
Truth | Juke | Sunny | Glaneyes | Brill |
ํ์ | ์ญํ |
---|---|
๊น์์ | ML๊ณ์ด ๊ธฐ๋ฐ ๋ชจ๋ธ ์คํ๊ณผ feature engineering์ ํตํ LGBM ์ฑ๋ฅ ํฅ์, LightGCN validation set ๋ณ๊ฒฝ๊ณผ ํ๋์ ํตํ ์ฑ๋ฅ ํฅ์ |
๋ฐ์ ๊ท | EDA๋ฅผ ํตํด ํ์ ๊ณ ๋ฏผํด๋ณผ ๋งํ feature ์ ์, CatBoost๋ฅผ ํ์ฉํด feature๋ฅผ ํตํ Ensemble ์คํ, FE ์คํ |
์ด์ํฌ | EDA์ ํ์๋ถ์์ ๊ธฐ๋ฐํ feature engineering / LSTM attention๋ฑ ์ํ์ค ๊ธฐ๋ฐ ๋ชจ๋ธ ์คํ. LGCN์ผ๋ก ์ฌ์ ํ์ตํ ์๋ฒ ๋ฉ์ lstm์ ๋ฐ์ํ์ฌ ์ฑ๋ฅ ๊ฐ์ |
์ด์ ํธ | ์ต๊ทผ ์ฌ์ฉ์๋ณ ๋ฌธ์ ํ์ด ์ด๋ ฅ, ํ๊ทธ ์ฐ์ ์ถํ ํ์ ๋ฑ FE๋ฅผ ํตํด CatBoost์ LightGBM ์ฑ๋ฅ ํฅ์. OOF Stacking ๊ธฐ๋ฒ ์คํ๊ณผ ์์๋ธ ๊ตฌํ |
์ง์ํ | ๋ฒ ์ด์ค๋ผ์ธ ๋ชจ๋ธ ์ฑ๋ฅ ๊ฐ์ , ML๊ณ์ด ๋ชจ๋ธ ์คํ์ ํตํด CatBoost ๋ชจ๋ธ ์ ์ฉ. ํ์ดํผ ํ๋ผ๋ฏธํฐ ํ๋๊ณผ FE๋ฅผ ํตํด ๋จ์ผ ๋ชจ๋ธ ์ฑ๋ฅ ๊ฐ์ |