/level2-dkt-level2-recsys-14

level2-dkt-level2-recsys-14 created by GitHub Classroom

Primary LanguageJupyter Notebook

[Recsys-14] Deep Knowledge Tracing


๐Ÿ“š ํ”„๋กœ์ ํŠธ ๊ฐœ์š”

๐Ÿ“‹ ํ”„๋กœ์ ํŠธ ์ฃผ์ œ

โ€˜์ง€์‹ ์ƒํƒœโ€™๋ฅผ ์ถ”์ ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹ ๋ฐฉ๋ฒ•๋ก ์ธ DKT(Deep Knowledge Tracing)๋ฅผ ํ†ตํ•ด ์‚ฌ์šฉ์ž๊ฐ€ ํ’€์–ด์™”๋˜ ๊ณผ๋ชฉ์— ๋Œ€ํ•ด ์–ผ๋งŒํผ ์ดํ•ดํ•˜๊ณ  ์žˆ๋Š”์ง€๋ฅผ ์ธก์ •ํ•˜์—ฌ ์•„์ง ํ’€์ง€ ์•Š์€ ๋ฏธ๋ž˜์˜ ๋ฌธ์ œ์— ๋Œ€ํ•ด ๋งž์„์ง€ ํ‹€๋ฆด์ง€ ์˜ˆ์ธกํ•˜๋Š” ๋ชจ๋ธ ๊ฐœ๋ฐœํ•ฉ๋‹ˆ๋‹ค.



๐Ÿ“ˆ ํ”„๋กœ์ ํŠธ ๋ชฉํ‘œ

์ฃผ์–ด์ง„ ๋ฌธ์ œ๋Š” ํ•™์ƒ ๊ฐœ๊ฐœ์ธ์˜ ์ดํ•ด๋„์ธ ์ง€์‹ ์ƒํƒœ๋ฅผ ์˜ˆ์ธก๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ๋ฌธ์ œ๋ฅผ ๋งž์ถœ ์—ฌ๋ถ€์— ๋Œ€ํ•œ ํ™•๋ฅ ๋กœ๋„ ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋•Œ๋ฌธ์— ์ฃผ์–ด์ง„ ์—ฌ๋Ÿฌ ์ถ”์ฒœ ์‹œ์Šคํ…œ ๋ชจ๋ธ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ML ๊ณ„์—ด์˜ ๋ชจ๋ธ๊ณผ๋„ ๋น„๊ตํ•ด๋ณด๊ณ , ์ด๋ฅผ ํ†ตํ•ด Sequential data์™€ static data ๊ฐ๊ฐ์— ๊ด€ํ•œ ์ถ”์ฒœ ์‹œ์Šคํ…œ์˜ ์ฐจ์ด๋ฅผ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ๋Œ€ํšŒ์—์„œ ์ฃผ์–ด์ง€๋Š” ๋ฐ์ดํ„ฐ ๋ถ„์„์„ ํ†ตํ•ด ์ตœ์ ์˜ ๋ชจ๋ธ๊ณผ ์ถ”์ฒœ ๋ฐฉ์‹์„ ์ฐพ๊ณ , ์ด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ๊ณผ auroc, acc๋ฅผ ๋†’์ด๋Š” ๊ฒƒ์„ ๋ชฉํ‘œ๋กœ ํ•ฉ๋‹ˆ๋‹ค.



๐Ÿ—ž ์‚ฌ์šฉ ๋ฐ์ดํ„ฐ

train_data.csv

USER_ID assessmentItemID testID answerCode Timestamp KnowledgeTag
์‚ฌ์šฉ์ž์˜ ๊ณ ์œ  ๋ฒˆํ˜ธ ๋ฌธํ•ญ์˜ ๊ณ ์œ ๋ฒˆํ˜ธ ์‚ฌํ—˜์ง€์˜ ๊ณ ์œ ๋ฒˆํ˜ธ ์‚ฌ์šฉ์ž๊ฐ€ ํ•ด๋‹น ๋ฌธํ•ญ์„ ๋งž์ท„๋Š”์ง€ ์—ฌ๋ถ€(์ •๋‹ต 1, ์˜ค๋‹ต 0 ์‚ฌ์šฉ์ž๊ฐ€ ํ•ด๋‹น ๋ฌธํ•ญ์„ ํ’€๊ธฐ ์‹œ์ž‘ํ•œ ์‹œ์  ๋ฌธํ•ญ ๋‹น ํ•˜๋‚˜์”ฉ ๋ฐฐ์ •๋˜๋Š” ํƒœ๊ทธ(์ค‘๋ถ„๋ฅ˜)
  • 2,266,588๊ฐœ์˜ ํ–‰์œผ๋กœ ๊ตฌ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
  • ์ด 9,454๊ฐœ์˜ ๊ณ ์œ  ๋ฌธํ•ญ, 1,537๊ฐœ์˜ ์‹œํ—˜์ง€, 912๊ฐœ์˜ ํƒœ๊ทธ(์ค‘๋ถ„๋ฅ˜)๊ฐ€ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.


๐Ÿ”‘ ํ”„๋กœ์ ํŠธ ํŠน์ง•

  • DKT๋ฅผ ํ™œ์šฉํ•˜๋ฉด ํ•™์ƒ ๊ฐœ๊ฐœ์ธ์—๊ฒŒ ๋ฌธ์ œ์— ๋Œ€ํ•œ ์ดํ•ด๋„์™€ ์ทจ์•ฝํ•œ ๋ถ€๋ถ„์„ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด ์–ด๋–ค ๋ฌธ์ œ๋“ค์„ ํ’€๋ฉด ์ข‹์„์ง€ ์ถ”์ฒœ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.
  • ํ•™์ƒ ๊ฐœ๊ฐœ์ธ์˜ ์ดํ•ด๋„๋ฅผ ๊ฐ€๋ฆฌํ‚ค๋Š” ์ง€์‹ ์ƒํƒœ๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ์ผ๋ณด๋‹ค๋Š”, ์ฃผ์–ด์ง„ ๋ฌธ์ œ๋ฅผ ๋งž์ถœ์ง€ ํ‹€๋ฆด์ง€ ์˜ˆ์ธกํ•˜๋Š” ๊ฒƒ์— ์ง‘์ค‘ํ•˜์—ฌ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


๐Ÿ›Ž ์š”๊ตฌ์‚ฌํ•ญ

์‚ฌ์šฉ์ž์˜ ๋ฌธํ•ญ์˜ ๊ณ ์œ ๋ฒˆํ˜ธ, ์‹œํ—˜์ง€์˜ ๊ณ ์œ ๋ฒˆํ˜ธ, ๋ฌธ์ œ๋ฅผ ํ’€๊ธฐ ์‹œ์ž‘ํ•œ ์‹œ์ ๊ณผ ํƒœ๊ทธ ๋“ฑ์˜ side-information์„ ํ™œ์šฉํ•˜์—ฌ ๊ฐ ํ•™์ƒ์ด ํ‘ผ ๋ฌธ์ œ ๋ฆฌ์ŠคํŠธ์™€ ์ •๋‹ต ์—ฌ๋ถ€๊ฐ€ ๋‹ด๊ธด ๋ฐ์ดํ„ฐ๋ฅผ ๋ฐ›์•„ ์ตœ์ข… ๋ฌธ์ œ๋ฅผ ๋งž์ถœ์ง€ ํ‹€๋ฆด์ง€ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค.


๐Ÿ“Œ ํ‰๊ฐ€ ๋ฐฉ๋ฒ•

DKT๋Š” ์ฃผ์–ด์ง„ ๋งˆ์ง€๋ง‰ ๋ฌธ์ œ๋ฅผ ๋งž์•˜๋Š”์ง€ ํ‹€๋ ธ๋Š”์ง€๋กœ ๋ถ„๋ฅ˜ํ•˜๋Š” ์ด์ง„ ๋ถ„๋ฅ˜ ๋ฌธ์ œ์ด๊ธฐ์—, AUROC(Area Under the ROC curve)์™€ Accuracy๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค



ํ”„๋กœ์ ํŠธ ๊ตฌ์กฐ

Flow Chart

์Šคํฌ๋ฆฐ์ƒท 2022-05-21 ์˜คํ›„ 2 30 08



๋””๋ ‰ํ† ๋ฆฌ ๊ตฌ์กฐ

//input/data
โ”œโ”€ sample_submission.csv          # ๐Ÿ“ƒ output sample 
โ”œโ”€ train_data.csv                 # ๐Ÿ“ train data set
โ””โ”€ test_data.csv                  # ๐Ÿ“ test data set

/code
โ”œโ”€โ”€ README.md
โ”œโ”€โ”€ CatBoost
โ”‚   โ”œโ”€โ”€ config.py
โ”‚   โ”œโ”€โ”€ dataset.py
โ”‚   โ””โ”€โ”€ train_inference.py
โ”œโ”€โ”€ LGBM
โ”‚   โ”œโ”€โ”€ args.py
โ”‚		โ”œโ”€โ”€ asset
โ”‚		โ”œโ”€โ”€ inference.py
โ”‚		โ”œโ”€โ”€ lgbm
โ”‚		โ”‚   โ”œโ”€โ”€ model.py
โ”‚		โ”‚   โ”œโ”€โ”€ preprocess.py
โ”‚		โ”‚   โ””โ”€โ”€ trainer.py
โ”‚		โ”œโ”€โ”€ model
โ”‚		โ”œโ”€โ”€ output
โ”‚		โ”œโ”€โ”€ save_pic
โ”‚		โ”‚   โ””โ”€โ”€ lgbm_feature_importance.png
โ”‚		โ””โ”€โ”€ train.py
โ”œโ”€โ”€ LSTMAttn_with_LGCN
โ”‚   โ”œโ”€โ”€ criterion.py
โ”‚   โ”œโ”€โ”€ dataloader.py
โ”‚   โ”œโ”€โ”€ lightgcn
โ”‚   โ”‚   โ”œโ”€โ”€ config.py
โ”‚   โ”‚   โ”œโ”€โ”€ id2index.pickle
โ”‚   โ”‚   โ”œโ”€โ”€ inference.py
โ”‚   โ”‚   โ”œโ”€โ”€ install.sh
โ”‚   โ”‚   โ”œโ”€โ”€ lightgcn
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ datasets.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ models.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ optimizer.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ scheduler.py
โ”‚   โ”‚   โ”‚   โ””โ”€โ”€ utils.py
โ”‚   โ”‚   โ”œโ”€โ”€ train.py
โ”‚   โ”‚   โ””โ”€โ”€ weight
โ”‚   โ”‚       โ””โ”€โ”€ best_auc_model2.pt
โ”‚   โ”œโ”€โ”€ lightgcn_for_tag
โ”‚   โ”‚   โ”œโ”€โ”€ config.py
โ”‚   โ”‚   โ”œโ”€โ”€ inference.py
โ”‚   โ”‚   โ”œโ”€โ”€ install.sh
โ”‚   โ”‚   โ”œโ”€โ”€ lightgcn_for_tag
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ datasets.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ models.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ optimizer.py
โ”‚   โ”‚   โ”‚   โ”œโ”€โ”€ scheduler.py
โ”‚   โ”‚   โ”‚   โ””โ”€โ”€ utils.py
โ”‚   โ”‚   โ”œโ”€โ”€ output
โ”‚   โ”‚   โ”‚   โ””โ”€โ”€ best_auc_submission.csv
โ”‚   โ”‚   โ”œโ”€โ”€ run.log
โ”‚   โ”‚   โ”œโ”€โ”€ tag2index.pickle
โ”‚   โ”‚   โ”œโ”€โ”€ temp.ipynb
โ”‚   โ”‚   โ”œโ”€โ”€ train.py
โ”‚   โ”‚   โ””โ”€โ”€ weight
โ”‚   โ”‚       โ””โ”€โ”€ best_auc_model_tag.pt
โ”‚   โ”œโ”€โ”€ metric.py
โ”‚   โ”œโ”€โ”€ model.py
โ”‚   โ”œโ”€โ”€ optimizer.py
โ”‚   โ”œโ”€โ”€ scheduler.py
โ”‚   โ”œโ”€โ”€ trainer.py
โ”‚   โ””โ”€โ”€ utils.py
โ”œโ”€โ”€ README.md
โ”œโ”€โ”€ args.py
โ”œโ”€โ”€ hyper_run.py
โ”œโ”€โ”€ id2index.pickle
โ”œโ”€โ”€ inference.py
โ”œโ”€โ”€ requirements.txt
โ”œโ”€โ”€ tag2index.pickle
โ””โ”€โ”€ train.py


๋ชจ๋ธ๊ณผ ์‹คํ—˜ ๋‚ด์šฉ

  • ์ด๋ฒˆ ํ”„๋กœ์ ํŠธ์—์„œ ์‹คํ—˜ํ•œ ๋‹จ์ผ ๋ชจ๋ธ ์ข…๋ฅ˜๋Š” Boosting ๊ธฐ๋ฒ•์„ ์‚ฌ์šฉํ•œ CatBoost, LightGBM, XGBoost ๋ชจ๋ธ, RNN ๊ณ„์—ด์˜ Attention Mechanism์„ ์‚ฌ์šฉํ•œ LSTM ๋ชจ๋ธ, GNN ๊ณ„์—ด์˜ LightGCN ๋ชจ๋ธ์ด ์žˆ์œผ๋ฉฐ, ๋‹จ์ผ ๋ชจ๋ธ ์„ฑ๋Šฅ์„ ๋น„๊ตํ•˜์—ฌ ์ตœ์ข… ๊ฒฐ๊ณผ์— ์‚ฌ์šฉํ•œ ๋ชจ๋ธ์€ CatBoost, LightGBM, LSTM Attention(LSTM with attention mechanism)์ž…๋‹ˆ๋‹ค.
    • ๋‹จ์ผ ๋ชจ๋ธ ๊ฐ„ ์„ฑ๋Šฅ ์‹คํ—˜ ์‹œ CatBoost, LightGBM, LSTM with attention mechanism ์ˆœ์œผ๋กœ ๋†’์•˜์Šต๋‹ˆ๋‹ค.
      • Public AUC ๊ธฐ์ค€ CatBoost 0.8139, LightGBM 0.7723, LSTM Attention 0.751
    • ๊ธฐ์กด XGBoost(AUC ๊ธฐ์ค€ 0.6171)๋ณด๋‹ค LightGBM์˜ ์„ฑ๋Šฅ์ด ๋” ๋†’์•˜๋˜ ๊ฒƒ์€ LightGCN์—์„œ ์‚ฌ์šฉํ•˜๋Š” Gradient-based One Side Sampling ๊ธฐ๋ฒ•์œผ๋กœ ๊ฒฐ๊ณผ์ ์œผ๋กœ ์ƒ๋Œ€์ ์œผ๋กœ ์ž‘์€ Gradient๋ณด๋‹ค ํฐ Gradient๋ฅผ ์ง€๋‹ˆ๋Š” Instance์— ์ดˆ์ ์„ ๋‘์–ด Underfitting์„ ๋ง‰๋Š” ํšจ๊ณผ์ธ ๊ฒƒ์œผ๋กœ ๋ณด์ž…๋‹ˆ๋‹ค.
    • LightGCN์˜ ๊ทธ๋ž˜ํ”„์—์„œ ํ•™์Šตํ•œ ์œ ์ €์™€ ํƒœ๊ทธ์— ๋Œ€ํ•œ ์ž„๋ฒ ๋”ฉ์„ LSTM Attention์˜ Feature๋กœ ์ถ”๊ฐ€ํ•˜์—ฌ ์ตœ์ข…์ ์œผ๋กœ CatBoost์™€ LightGBM์„ ์•™์ƒ๋ธ” ํ•  ๋•Œ ์ ์ˆ˜๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.
      • Public AUC ๊ธฐ์ค€ 0.8168 โ†’ 0.8173
  • CatBoost, LightGBM, LSTM Attention ์„ธ ๋ชจ๋ธ์„ ์•™์ƒ๋ธ”ํ•˜๊ธฐ ์œ„ํ•ด Hard Voting์„ ์ง„ํ–‰ํ–ˆ์Šต๋‹ˆ๋‹ค.
    • ์„ฑ๋Šฅ์ด ๊ฐ€์žฅ ์ข‹์•˜๋˜ ์•™์ƒ๋ธ” ๋ชจ๋ธ ๊ฒฐ๊ณผ์ธ CatBoost์— ๊ฐ€์žฅ ๋งŽ์€ ๊ฐ€์ค‘์น˜๋ฅผ ๋ถ€์—ฌํ•œ ๊ฒƒ์ด Public AUC ๊ธฐ์ค€์œผ๋กœ ๊ฐ€์žฅ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์˜€์Šต๋‹ˆ๋‹ค.
      • CatBoost : LightGBM : LSTM Attention = 3.5 : 1.5 : 1 (๊ฐ€์ค‘์น˜)
    • LightGBM์—์„œ ์˜ˆ์ธกํ•œ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ€์ง€๊ณ  OOF(Out Of Fold) Stacking ๊ธฐ๋ฒ•์„ ์ ์šฉํ–ˆ์„ ๋•Œ Public AUC ๊ธฐ์ค€ 0.7881๋กœ ๋ฏธ์ ์šฉํ•œ ์‹คํ—˜๋ณด๋‹ค ์ƒ๋Œ€์ ์œผ๋กœ ๋‚ฎ๊ฒŒ ๋‚˜์™”์ง€๋งŒ, Private์—์„œ๋Š” 0.8369๋กœ ์„ฑ๋Šฅ์ด ์˜ฌ๋ž์Šต๋‹ˆ๋‹ค.
      • OOF Stacking ์ ์šฉ ์‹œ ๋ฉ”ํƒ€ ๋ชจ๋ธ๋กœ๋Š” XGBoost์™€ LightGBM์„ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.
      • OOF Stacking ๊ธฐ๋ฒ•์—์„œ inferenceํ•œ ํ™•๋ฅ ์ด 0 ๋˜๋Š” 1์— ๊ฐ€๊น๊ฒŒ ๊ทน๋‹จ์ ์œผ๋กœ ๋‚˜์™€์„œ Overfitting์ด ๋˜์—ˆ๋‹ค๊ณ  ์˜ˆ์ƒํ•ด์„œ ์ตœ์ข… ์ œ์ถœ ํŒŒ์ผ์— ํฌํ•จ์‹œํ‚ค์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.


๐Ÿ’ป ํ™œ์šฉ ๋„๊ตฌ ๋ฐ ํ™˜๊ฒฝ

  • ์ฝ”๋“œ ๊ณต์œ 
    • GitHub
  • ๊ฐœ๋ฐœ ํ™˜๊ฒฝ
    • JupyterLab, VS Code, PyCharm
  • ๋ชจ๋ธ ์„ฑ๋Šฅ ๋ถ„์„
    • Wandb


๐Ÿ‘ฉ๐Ÿปโ€๐Ÿ’ป๐Ÿ‘จ๐Ÿปโ€๐Ÿ’ป ํŒ€์› ์†Œ๊ฐœ

์ง„์™„ํ˜ ๋ฐ•์ •๊ทœ ๊น€์€์„  ์ด์„ ํ˜ธ ์ด์„œํฌ
์Šคํฌ๋ฆฐ์ƒท 2022-04-19 ์˜คํ›„ 5 44 23 ์Šคํฌ๋ฆฐ์ƒท 2022-04-19 ์˜คํ›„ 5 47 38 ์Šคํฌ๋ฆฐ์ƒท 2022-04-19 ์˜คํ›„ 5 48 35 ์Šคํฌ๋ฆฐ์ƒท 2022-04-19 ์˜คํ›„ 5 49 21 ์Šคํฌ๋ฆฐ์ƒท 2022-04-19 ์˜คํ›„ 5 49 30
Truth Juke Sunny Glaneyes Brill

ํŒ€์› ์—ญํ• 
๊น€์€์„  ML๊ณ„์—ด ๊ธฐ๋ฐ˜ ๋ชจ๋ธ ์‹คํ—˜๊ณผ feature engineering์„ ํ†ตํ•œ LGBM ์„ฑ๋Šฅ ํ–ฅ์ƒ, LightGCN validation set ๋ณ€๊ฒฝ๊ณผ ํŠœ๋‹์„ ํ†ตํ•œ ์„ฑ๋Šฅ ํ–ฅ์ƒ
๋ฐ•์ •๊ทœ EDA๋ฅผ ํ†ตํ•ด ํ›„์— ๊ณ ๋ฏผํ•ด๋ณผ ๋งŒํ•œ feature ์ œ์‹œ, CatBoost๋ฅผ ํ™œ์šฉํ•ด feature๋ฅผ ํ†ตํ•œ Ensemble ์‹คํ—˜, FE ์‹คํ—˜
์ด์„œํฌ EDA์™€ ํ›„์‹œ๋ถ„์„์— ๊ธฐ๋ฐ˜ํ•œ feature engineering / LSTM attention๋“ฑ ์‹œํ€€์Šค ๊ธฐ๋ฐ˜ ๋ชจ๋ธ ์‹คํ—˜. LGCN์œผ๋กœ ์‚ฌ์ „ํ•™์Šตํ•œ ์ž„๋ฒ ๋”ฉ์„ lstm์— ๋ฐ˜์˜ํ•˜์—ฌ ์„ฑ๋Šฅ ๊ฐœ์„ 
์ด์„ ํ˜ธ ์ตœ๊ทผ ์‚ฌ์šฉ์ž๋ณ„ ๋ฌธ์ œ ํ’€์ด ์ด๋ ฅ, ํƒœ๊ทธ ์—ฐ์† ์ถœํ˜„ ํšŸ์ˆ˜ ๋“ฑ FE๋ฅผ ํ†ตํ•ด CatBoost์™€ LightGBM ์„ฑ๋Šฅ ํ–ฅ์ƒ. OOF Stacking ๊ธฐ๋ฒ• ์‹คํ—˜๊ณผ ์•™์ƒ๋ธ” ๊ตฌํ˜„
์ง„์™„ํ˜ ๋ฒ ์ด์Šค๋ผ์ธ ๋ชจ๋ธ ์„ฑ๋Šฅ ๊ฐœ์„ , ML๊ณ„์—ด ๋ชจ๋ธ ์‹คํ—˜์„ ํ†ตํ•ด CatBoost ๋ชจ๋ธ ์ ์šฉ. ํ•˜์ดํผ ํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹๊ณผ FE๋ฅผ ํ†ตํ•ด ๋‹จ์ผ ๋ชจ๋ธ ์„ฑ๋Šฅ ๊ฐœ์„