- ์ฑ๋ด ๋น๋๋ ์ฑ์ ์์ฐจ๊ณ , ์์ ๋ง์ ๋ฅ๋ฌ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ง๋์๊ณ ์ถ์ผ์ ๊ฐ์?
- Kochat์ ์ด์ฉํ๋ฉด ์์ฝ๊ฒ ์์ ๋ง์ ๋ฅ๋ฌ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์ ์ ๋น๋ํ ์ ์์ต๋๋ค.
# 1. ๋ฐ์ดํฐ์
๊ฐ์ฒด ์์ฑ
dataset = Dataset(ood=True)
# 2. ์๋ฒ ๋ฉ ํ๋ก์ธ์ ์์ฑ
emb = GensimEmbedder(model=embed.FastText())
# 3. ์๋(Intent) ๋ถ๋ฅ๊ธฐ ์์ฑ
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
# 4. ๊ฐ์ฒด๋ช
(Named Entity) ์ธ์๊ธฐ ์์ฑ
rcn = EntityRecognizer(
model=entity.LSTM(dataset.entity_dict),
loss=CRFLoss(dataset.entity_dict)
)
# 5. ๋ฅ๋ฌ๋ ์ฑ๋ด RESTful API ํ์ต & ๋น๋
kochat = KochatApi(
dataset=dataset,
embed_processor=(emb, True),
intent_classifier=(clf, True),
entity_recognizer=(rcn, True),
scenarios=[
weather, dust, travel, restaurant
]
)
# 6. View ์์คํ์ผ๊ณผ ์ฐ๊ฒฐ
@kochat.app.route('/')
def index():
return render_template("index.html")
# 7. ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์๋ฒ ๊ฐ๋
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
- ํ๊ตญ์ด๋ฅผ ์ง์ํ๋ ์ต์ด์ ์คํ์์ค ๋ฅ๋ฌ๋ ์ฑ๋ด ํ๋ ์์ํฌ์ ๋๋ค. (๋น๋์๋ ๋ค๋ฆ ๋๋ค.)
- ๋ค์ํ Pre built-in ๋ชจ๋ธ๊ณผ Lossํจ์๋ฅผ ์ง์ํฉ๋๋ค. NLP๋ฅผ ์ ๋ชฐ๋ผ๋ ์ฑ๋ด์ ๋ง๋ค ์ ์์ต๋๋ค.
- ์์ ๋ง์ ์ปค์คํ ๋ชจ๋ธ, Lossํจ์๋ฅผ ์ ์ฉํ ์ ์์ต๋๋ค. NLP ์ ๋ฌธ๊ฐ์๊ฒ ๋์ฑ ์ ์ฉํฉ๋๋ค.
- ์ฑ๋ด์ ํ์ํ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ๋ชจ๋ธ, ํ์ต ํ์ดํ๋ผ์ธ, RESTful API๊น์ง ๋ชจ๋ ๋ถ๋ถ์ ์ ๊ณตํฉ๋๋ค.
- ๊ฐ๊ฒฉ ๋ฑ์ ์ ๊ฒฝ์ธ ํ์ ์์ผ๋ฉฐ, ์์ผ๋ก๋ ์ญ ์คํ์์ค ํ๋ก์ ํธ๋ก ์ ๊ณตํ ์์ ์ ๋๋ค.
- ์๋์ ๊ฐ์ ๋ค์ํ ์ฑ๋ฅ ํ๊ฐ ๋ฉํธ๋ฆญ๊ณผ ๊ฐ๋ ฅํ ์๊ฐํ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
- 1. Kochat ์ด๋?
- 2. About Chatbot
- 3. Getting Started
- 4. Usage
- 5. Visualization Support
- 5.1. Train/Test Accuracy
- 5.2. Train/Test Recall (macro average)
- 5.3. Train/Test Precision (macro average)
- 5.4. Train/Test F1-Score (macro average)
- 5.5. Train/Test Confusion Matrix
- 5.6. Train/Test Classification Performance Report
- 5.7. Train/Test Fallback Detection Performance Report
- 5.8. Feature Space Visualization
- 6. Performance Issue
- 7. Demo Application
- 8. Contributor
- 9. TODO List
- 10. Reference
- 11. License
Kochat์ ํ๊ตญ์ด ์ ์ฉ ์ฑ๋ด ๊ฐ๋ฐ ํ๋ ์์ํฌ๋ก, ๋จธ์ ๋ฌ๋ ๊ฐ๋ฐ์๋ผ๋ฉด
๋๊ตฌ๋ ๋ฌด๋ฃ๋ก ์์ฝ๊ฒ ํ๊ตญ์ด ์ฑ๋ด์ ๊ฐ๋ฐ ํ ์ ์๋๋ก ๋๋ ์คํ์์ค ํ๋ ์์ํฌ์
๋๋ค.
๋จ์ Chit-chat์ด ์๋ ์ฌ์ฉ์์๊ฒ ์ฌ๋ฌ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ ์์ฉํ ๋ ๋ฒจ์ ์ฑ๋ด ๊ฐ๋ฐ์
๋จ์ผ ๋ชจ๋ธ๋ง์ผ๋ก ๊ฐ๋ฐ๋๋ ๊ฒฝ์ฐ๋ณด๋ค ๋ค์ํ ๋ฐ์ดํฐ, configuration, ML๋ชจ๋ธ,
Restful Api ๋ฐ ์ ํ๋ฆฌ์ผ์ด์
, ๋ ์ด๋ค์ ์ ๊ธฐ์ ์ผ๋ก ์ฐ๊ฒฐํ ํ์ดํ๋ผ์ธ์ ๊ฐ์ถ์ด์ผ ํ๋๋ฐ
์ด ๊ฒ์ ์ฒ์๋ถํฐ ๊ฐ๋ฐ์๊ฐ ์ค์ค๋ก ๊ตฌํํ๋ ๊ฒ์ ๊ต์ฅํ ๋ฒ๊ฑฐ๋กญ๊ณ ์์ด ๋ง์ด ๊ฐ๋ ์์
์
๋๋ค.
์ค์ ๋ก ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ ์ง์ ๊ตฌํํ๋ค๋ณด๋ฉด ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ์ค์ง์ ์ผ๋ก ๋ชจ๋ธ ๊ฐ๋ฐ๋ณด๋ค๋
์ด๋ฐ ๋ถ๋ถ๋ค์ ํจ์ฌ ์๊ฐ๊ณผ ๋
ธ๋ ฅ์ด ๋ง์ด ํ์ํฉ๋๋ค.
Kochat์ ์ด๋ฌํ ๋ถ๋ถ์ ํด๊ฒฐํ๊ธฐ ์ํด ์ ์๋์์ต๋๋ค.
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ์ํคํ
์ฒ, ๋ชจ๋ธ๊ณผ์ ํ์ดํ๋ผ์ธ, ์คํ ๊ฒฐ๊ณผ ์๊ฐํ,
์ฑ๋ฅํ๊ฐ ๋ฑ์ Kochat์ ๊ตฌ์ฑ์ ์ฌ์ฉํ๋ฉด์ ๊ฐ๋ฐ์๊ฐ ์ํ๋ ๋ชจ๋ธ์ด๋ Lossํจ์,
๋ฐ์ดํฐ ์
๋ฑ๋ง ๊ฐ๋จํ๊ฒ ์์ฑํ์ฌ ๋ด๊ฐ ์ํ๋ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋น ๋ฅด๊ฒ ์คํํ ์ ์๊ฒ ๋์์ค๋๋ค.
๋ํ ํ๋ฆฌ ๋นํธ์ธ ๋ชจ๋ธ๋ค๊ณผ Loss ํจ์๋ฑ์ ์ง์ํ์ฌ ๋ฅ๋ฌ๋์ด๋ ์์ฐ์ด์ฒ๋ฆฌ์ ๋ํด ์ ๋ชจ๋ฅด๋๋ผ๋
ํ๋ก์ ํธ์ ๋ฐ์ดํฐ๋ง ์ถ๊ฐํ๋ฉด ์์ฝ๊ฒ ์๋นํ ๋์ ์ฑ๋ฅ์ ์ฑ๋ด์ ๊ฐ๋ฐํ ์ ์๊ฒ ๋์์ค๋๋ค.
์์ง์ ์ด๊ธฐ๋ ๋ฒจ์ด๊ธฐ ๋๋ฌธ์ ๋ง์ ๋ชจ๋ธ๊ณผ ๊ธฐ๋ฅ์ ์ง์ํ์ง๋ ์์ง๋ง ์ ์ฐจ ๋ชจ๋ธ๊ณผ
๊ธฐ๋ฅ์ ๋๋ ค๋๊ฐ ๊ณํ์
๋๋ค.
-
๊ธฐ์กด์ ์์ฉํ๋ ๋ง์ ์ฑ๋ด ๋น๋์ Kochat์ ํ๊น์ผ๋ก ํ๋ ์ฌ์ฉ์๊ฐ ๋ค๋ฆ ๋๋ค. ์์ฉํ๋ ์ฑ๋ด ๋น๋๋ค์ ๋งค์ฐ ๊ฐํธํ ์น ๊ธฐ๋ฐ์ UX/UI๋ฅผ ์ ๊ณตํ๋ฉฐ ์ผ๋ฐ์ธ์ ํ๊น์ผ๋ก ํฉ๋๋ค. ๊ทธ์ ๋ฐํด Kochat์ ์ฑ๋ด๋น๋ ๋ณด๋ค๋ ๊ฐ๋ฐ์๋ฅผ ํ๊น์ผ๋กํ๋ ํ๋ ์์ํฌ์ ๊ฐ๊น์ต๋๋ค. ๊ฐ๋ฐ์๋ ์์ค์ฝ๋๋ฅผ ์์ฑํจ์ ๋ฐ๋ผ์ ํ๋ ์์ํฌ์ ๋ณธ์ธ๋ง์ ๋ชจ๋ธ์ ์ถ๊ฐํ ์ ์๊ณ , Loss ํจ์๋ฅผ ๋ฐ๊พธ๊ฑฐ๋ ๋ณธ์ธ์ด ์ํ๋ฉด ์์ ์๋ก์ด ๊ธฐ๋ฅ์ ์ฒจ๊ฐํ ์๋ ์์ต๋๋ค.
-
Kochat์ ์คํ์์ค ํ๋ก์ ํธ์ ๋๋ค. ๋ฐ๋ผ์ ๋ง์ ์ฌ๋์ด ์ฐธ์ฌํด์ ํจ๊ป ๊ฐ๋ฐํ ์ ์๊ณ ๋ง์ฝ ์๋ก์ด ๋ชจ๋ธ์ ๊ฐ๋ฐํ๊ฑฐ๋ ์๋ก์ด ๊ธฐ๋ฅ์ ์ถ๊ฐํ๊ณ ์ถ๋ค๋ฉด ์ผ๋ง๋ ์ง ๋ ํฌ์งํ ๋ฆฌ์ ์ปจํธ๋ฆฌ๋ทฐ์ ํ ์ ์์ต๋๋ค.
-
Kochat์ ๋ฌด๋ฃ์ ๋๋ค. ๋งค๋ฌ ์ฌ์ฉ๋ฃ๋ฅผ ๋ด์ผํ๋ ์ฑ๋ด ๋น๋๋ค์ ๋นํด ์์ฒด์ ์ธ ์๋ฒ๋ง ๊ฐ์ง๊ณ ์๋ค๋ฉด ๋น์ฉ์ ์ฝ ์์ด ์ผ๋ง๋ ์ง ์ฑ๋ด์ ๊ฐ๋ฐํ๊ณ ์๋น์ค ํ ์ ์์ต๋๋ค. ์์ง์ ๊ธฐ๋ฅ์ด ๋ฏธ์ฝํ์ง๋ง ์ถํ์๋ ์ ๋ง ์ฌ๋งํ ์ฑ๋ด ๋น๋๋ค ๋ณด๋ค ๋ ๋ค์ํ ๊ธฐ๋ฅ์ ๋ฌด๋ฃ๋ก ์ ๊ณตํ ์์ ์ ๋๋ค.
์ด์ ์ ์ฌ๊ธฐ์ ๊ธฐ์ ์ฝ๋๋ฅผ ๊ธ์ด๋ชจ์์ ๋ง๋ , ์์ค ๋ฎ์ ์ ๋ฅ๋ฌ๋ chatbot ๋ ํฌ์งํ ๋ฆฌ๊ฐ
์๊ฐ๋ณด๋ค ํฐ ๊ด์ฌ์ ๋ฐ์ผ๋ฉด์, ํ๊ตญ์ด๋ก ๋ ๋ฅ๋ฌ๋ ์ฑ๋ด ๊ตฌํ์ฒด๊ฐ ์ ๋ง ๋ง์ด ์๋ค๋ ๊ฒ์ ๋๊ผ์ต๋๋ค.
ํ์ฌ ๋๋ถ๋ถ์ ์ฑ๋ด ๋น๋๋ค์ ๋๋ถ๋ถ ์ผ๋ฐ์ธ์ ๊ฒจ๋ฅํ๊ธฐ ๋๋ฌธ์ ์น์์์ ์์ฌ์ด UX/UI
๊ธฐ๋ฐ์ผ๋ก ์๋น์ค ์ค์
๋๋ค. ์ผ๋ฐ์ธ ์ฌ์ฉ์๋ ์ฌ์ฉํ๊ธฐ ํธ๋ฆฌํ๊ฒ ์ง๋ง, ์ ์ ๊ฐ์ ๊ฐ๋ฐ์๋ค์
๋ชจ๋ธ๋ ์ปค์คํฐ๋ง์ด์ง ํ๊ณ ์ถ๊ณ , ๋ก์คํจ์๋ ๋ฐ๊ฟ๋ณด๊ณ ์ถ๊ณ , ์๊ฐํ๋ ํ๋ฉด์ ๋์ฑ ๋์ ์ฑ๋ฅ์
์ถ๊ตฌํ๊ณ ์ถ์ง๋ง ์์ฝ๊ฒ๋ ํ๊ตญ์ด ์ฑ๋ด ๋น๋ ์ค์์ ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ์ ์๋ ค์ง ๊ฒ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ์ค, ๋ฏธ๊ตญ์ RASA๋ผ๋ ์ฑ๋ด ํ๋ ์์ํฌ๋ฅผ ๋ณด๊ฒ ๋์์ต๋๋ค.
RASA๋ ๊ฐ๋ฐ์๊ฐ ์ง์ ์์ค์ฝ๋๋ฅผ ์์ ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ค์ํ ๋ถ๋ถ์ ์ปค์คํฐ๋ง์ด์ง ํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ํ๊ตญ์ด๋ฅผ ์ ๋๋ก ์ง์ํ์ง ์์์, ์ ์ฉ ํ ํฌ๋์ด์ ๋ฅผ ์ถ๊ฐํ๋ ๋ฑ ๋งค์ฐ ๋ฒ๊ฑฐ๋ก์ด ์์
์ด
ํ์ํ๊ณ ์ค์ ๋ก ๋๋ฌด ๋ค์ํ ์ปดํฌ๋ํธ๊ฐ ์กด์ฌํ์ฌ ์ต์ํด์ง๋๋ฐ ์กฐ๊ธ ์ด๋ ค์ด ํธ์
๋๋ค.
๋๋ฌธ์ ๋๊ตฐ๊ฐ ํ๊ตญ์ด ๊ธฐ๋ฐ์ด๋ฉด์ ์กฐ๊ธ ๋ ์ปดํฉํธํ ์ฑ๋ด ํ๋ ์์ํฌ๋ฅผ ์ ์ํ๋ค๋ฉด
์ฑ๋ด์ ๊ฐ๋ฐํด์ผํ๋ ๊ฐ๋ฐ์๋ค์๊ฒ ์ ๋ง ์ ์ฉํ ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋์๊ณ ์ง์ ์ด๋ฌํ ํ๋ ์์ํฌ๋ฅผ
๋ง๋ค์ด๋ณด์๋ ์๊ฐ์ Kochat์ ์ ์ํ๊ฒ ๋์์ต๋๋ค.
Kochat์ ํ๊ตญ์ด(Korean)์ ์๊ธ์์ธ Ko์ ์ ์ด๋ฆ ์ ๊ธ์์ธ Ko๋ฅผ ๋ฐ์์ ์ง์์ต๋๋ค. Kochat์ ์์ผ๋ก๋ ๊ณ์ ์คํ์์ค ํ๋ก์ ํธ๋ก ์ ์ง๋ ๊ฒ์ด๋ฉฐ, ์ ์ด๋ 1~2๋ฌ์ 1๋ฒ ์ด์์ ์๋ก์ด ๋ชจ๋ธ์ ์ถ๊ฐํ๊ณ , ๊ธฐ์กด ์์ค์ฝ๋์ ๋ฒ๊ทธ๋ฅผ ์์ ํ๋ ๋ฑ ์ ์ง๋ณด์ ์์ ์ ์ด์ด๊ฐ ๊ฒ์ด๋ฉฐ ์ฒ์์๋ ๋ฏธ์ฒํ ์ค๋ ฅ์ธ ์ ๊ฐ ์์ํ์ง๋ง, ๊ทธ ๋์ RASA์ฒ๋ผ ์ ๋ง ์ ์ฉํ๊ณ ๋์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ์์ค๋์ ์คํ์์ค ํ๋ ์์ํฌ๊ฐ ๋์์ผ๋ฉด ์ข๊ฒ ์ต๋๋ค. :)
์ด ์ฑํฐ์์๋ ์ฑ๋ด์ ๋ถ๋ฅ์ ๊ตฌํ๋ฐฉ๋ฒ, Kochat์ ์ด๋ป๊ฒ ์ฑ๋ด์ ๊ตฌํํ๊ณ ์๋์ง์ ๋ํด
๊ฐ๋จํ๊ฒ ์๊ฐํฉ๋๋ค.
์ฑ๋ด์ ํฌ๊ฒ ๋น๋ชฉ์ ๋ํ๋ฅผ ์ํ Open domain ์ฑ๋ด๊ณผ ๋ชฉ์ ๋ํ๋ฅผ ์ํ Close domain ์ฑ๋ด์ผ๋ก ๋๋ฉ๋๋ค.
Open domain ์ฑ๋ด์ ์ฃผ๋ก ์ก๋ด ๋ฑ์ ์ํํ๋ ์ฑ๋ด์ ์๋ฏธํ๋๋ฐ,
์ฌ๋ฌ๋ถ์ด ์ ์๊ณ ์๋ ์ฌ์ฌ์ด ๋ฑ์ด ์ฑ๋ด์ด ๋ํ์ ์ธ Open domain ์ฑ๋ด์ด๋ฉฐ Chit-chat์ด๋ผ๊ณ ๋ ๋ถ๋ฆฝ๋๋ค.
Close domain ์ฑ๋ด์ด๋ ํ์ ๋ ๋ํ ๋ฒ์ ์์์ ์ฌ์ฉ์๊ฐ ์ํ๋ ๋ชฉ์ ์ ๋ฌ์ฑํ๊ธฐ ์ํ ์ฑ๋ด์ผ๋ก
์ฃผ๋ก ๊ธ์ต์๋ด๋ด, ์๋น์์ฝ๋ด ๋ฑ์ด ์ด์ ํด๋นํ๋ฉฐ Goal oriented ์ฑ๋ด์ด๋ผ๊ณ ๋ ๋ถ๋ฆฝ๋๋ค.
์์ฆ ์ถ์๋๋ ์๋ฆฌ๋ ๋น
์ค๋น ๊ฐ์ ์ธ๊ณต์ง๋ฅ ๋น์, ์ธ๊ณต์ง๋ฅ ์คํผ์ปค๋ค์ ํน์ ๊ธฐ๋ฅ๋ ์ํํด์ผํ๊ณ
์ฌ์ฉ์์ ์ก๋ด๋ ์ ํด์ผํ๋ฏ๋ก Open domain ์ฑ๋ด๊ณผ Close domain ์ฑ๋ด์ด ๋ชจ๋ ํฌํจ๋์ด ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค.
์ฑ๋ด์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ํฌ๊ฒ ํต๊ณ๊ธฐ๋ฐ์ ์ฑ๋ด๊ณผ ๋ฅ๋ฌ๋ ๊ธฐ๋ฐ์ ์ฑ๋ด์ผ๋ก ๋๋ฉ๋๋ค.
์ฌ๊ธฐ์์๋ ๋ฅ๋ฌ๋ ๊ธฐ๋ฐ์ ์ฑ๋ด๋ง ์๊ฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
๋จผ์ Open domain ์ฑ๋ด์ ๊ฒฝ์ฐ๋ ๋ฅ๋ฌ๋ ๋ถ์ผ์์๋ ๋๋ถ๋ถ, End to End
์ ๊ฒฝ๋ง ๊ธฐ๊ณ๋ฒ์ญ ๋ฐฉ์(Seq2Seq)์ผ๋ก ๊ตฌํ๋์ด์์ต๋๋ค. Seq2Seq์ ํ ๋ฌธ์ฅ์ ๋ค๋ฅธ ๋ฌธ์ฅ์ผ๋ก
๋ณํ/๋ฒ์ญํ๋ ๋ฐฉ์์
๋๋ค. ๋ฒ์ญ๊ธฐ์๊ฒ "๋๋ ๋ฐฐ๊ณ ํ๋ค"๋ผ๋ ์
๋ ฅ์ด ์ฃผ์ด์ง๋ฉด "I'm Hungry"๋ผ๊ณ
๋ฒ์ญํด๋ด๋ฏ์ด, ์ฑ๋ด Seq2Seq๋ "๋๋ ๋ฐฐ๊ณ ํ๋ค"๋ผ๋ ์
๋ ฅ์ด ์ฃผ์ด์ง ๋, "๋ง์ด ๋ฐฐ๊ณ ํ์ ๊ฐ์?" ๋ฑ์ ๋๋ต์ผ๋ก ๋ฒ์ญํฉ๋๋ค.
์ต๊ทผ์ ๋ฐํ๋ Google์ Meena
๊ฐ์ ๋ชจ๋ธ์ ๋ณด๋ฉด, ๋ณต์กํ ๋ชจ๋ธ ์ํคํ
์ฒ๋ ํ์ต ํ๋ ์์ํฌ ์์ด End to End (Seq2Seq) ๋ชจ๋ธ๋ง์ผ๋ก๋
๋งค์ฐ ๋ฐฉ๋ํ ๋ฐ์ดํฐ์
๊ณผ ๋์ ์ฑ๋ฅ์ ์ปดํจํ
๋ฆฌ์์ค๋ฅผ ํ์ฉํ๋ฉด ์ ๋ง ์ฌ๋๊ณผ ๊ทผ์ ํ ์์ค์ผ๋ก ๋ํํ ์ ์๋ค๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ์์ต๋๋ค.
(๊ทธ๋ฌ๋ ํ์ฌ๋ฒ์ ํ๋ ์์ํฌ์์๋ Close domain ๋ง ์ง์ํฉ๋๋ค. ์ฐจํ ๋ฒ์ ์์ ๋ค์ํ Seq2Seq ๋ชจ๋ธ๋ ์ถ๊ฐํ ์์ ์
๋๋ค.)
Close domain ์ฑ๋ด์ ๋๋ถ๋ถ Slot Filling ๋ฐฉ์์ผ๋ก ๊ตฌํ๋์ด ์์ต๋๋ค. ๋ฌผ๋ก Close domain ์ฑ๋ด๋
Open domain์ฒ๋ผ End to end๋ก ๊ตฌํํ๋ ค๋ ๋ค์ํ
์๋ ๋ค๋
์กด์ฌ ํ์์ผ๋, ๋
ผ๋ฌธ์์ ์ ์ํ๋
๋ฐ์ดํฐ์
์์๋ง ์ ์๋ํ๊ณ , ์ค์ ๋ค๋ฅธ ๋ฐ์ดํฐ ์
(Task6์ DSTC dataset)์ ์ ์ฉํ๋ฉด ๊ทธ ์ ๋์
์ฑ๋ฅ์ด ๋์ค์ง ์์๊ธฐ ๋๋ฌธ์ ํ์
์ ์ ์ฉ๋๊ธฐ๋ ์ด๋ ค์์ด ์์ต๋๋ค. ๋๋ฌธ์ ํ์ฌ๋ ๋๋ถ๋ถ์ ๋ชฉ์ ์งํฅ
์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ด ๊ธฐ์กด ๋ฐฉ์์ธ Slot Filling ๋ฐฉ์์ผ๋ก ๊ตฌํ๋๊ณ ์์ต๋๋ค.
Slot Filling ๋ฐฉ์์ ๋ฏธ๋ฆฌ ๊ธฐ๋ฅ์ ์ํํ ์ ๋ณด๋ฅผ ๋ด๋ '์ฌ๋กฏ'์ ๋จผ์ ์ ์ํ ๋ค์,
์ฌ์ฉ์์ ๋ง์ ๋ฃ๊ณ ์ด๋ค ์ฌ๋กฏ์ ์ ํํ ์ง ์ ํ๊ณ , ํด๋น ์ฌ๋กฏ์ ์ฑ์๋๊ฐ๋ ๋ฐฉ์์
๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ด๋ฌํ Slot Filling ๋ฐฉ์ ์ฑ๋ด์ ๊ตฌํ์ ์ํด '์ธํ
ํธ'์ '์ํฐํฐ'๋ผ๋ ๊ฐ๋
์ด ๋ฑ์ฅํฉ๋๋ค.
๋ง๋ก๋ง ์ค๋ช
ํ๋ฉด ์ด๋ ค์ฐ๋ ์์๋ฅผ ๋ด
์๋ค. ๊ฐ์ฅ ๋จผ์ ์ฐ๋ฆฌ๊ฐ ์ฌํ ์ ๋ณด ์๋ฆผ ์ฑ๋ด์ ๋ง๋ ๋ค๊ณ ๊ฐ์ ํ๊ณ ,
์ฌํ์ ๋ณด ์ ๊ณต์ ์ํด "๋ ์จ ์ ๋ณด์ ๊ณต", "๋ฏธ์ธ๋จผ์ง ์ ๋ณด์ ๊ณต", "๋ง์ง ์ ๋ณด์ ๊ณต", "์ฌํ์ง ์ ๋ณด์ ๊ณต"์ด๋ผ๋ 4๊ฐ์ง
ํต์ฌ ๊ธฐ๋ฅ์ ๊ตฌํํด์ผํ๋ค๊ณ ํฉ์๋ค.
๊ฐ์ฅ ๋จผ์ ์ฌ์ฉ์์๊ฒ ๋ฌธ์ฅ์ ์
๋ ฅ๋ฐ์์ ๋, ์ฐ๋ฆฌ๋ ์ 4๊ฐ์ง ์ ๋ณด์ ๊ณต ๊ธฐ๋ฅ ์ค
์ด๋ค ๊ธฐ๋ฅ์ ์คํํด์ผํ๋์ง ์์์ฑ์ผํฉ๋๋ค. ์ด ๊ฒ์ ์ธํ
ํธ(Intent)๋ถ๋ฅ. ์ฆ, ์๋ ๋ถ๋ฅ๋ผ๊ณ ํฉ๋๋ค.
์ฌ์ฉ์๋ก๋ถํฐ "์์์ผ ๋ถ์ฐ ๋ ์จ ์ด๋ ๋?"๋ผ๋ ๋ฌธ์ฅ์ด ์
๋ ฅ๋๋ฉด 4๊ฐ์ง ๊ธฐ๋ฅ ์ค ๋ ์จ ์ ๋ณด์ ๊ณต ๊ธฐ๋ฅ์
์ํํด์ผ ํ๋ค๋ ๊ฒ์ ์์๋ด์ผํฉ๋๋ค. ๋๋ฌธ์ ๋ฌธ์ฅ ๋ฒกํฐ๊ฐ ์
๋ ฅ๋๋ฉด, Text Classification์ ์ํํ์ฌ
์ด๋ค API๋ฅผ ์ฌ์ฉํด์ผํ ์ง ์์๋
๋๋ค.
๊ทธ๋ฌ๋ ์ฌ๊ธฐ์ ์ ๊ฒฝ์จ์ผํ ๋ถ๋ถ์ด ํ ๋ถ๋ถ ์กด์ฌํฉ๋๋ค. ์ผ๋ฐ์ ์ธ ๋ฅ๋ฌ๋ ๋ถ๋ฅ๋ชจ๋ธ์ ๋ชจ๋ธ์ด ํ์ตํ ํด๋์ค ๋ด์์๋ง ๋ถ๋ฅ๊ฐ ๊ฐ๋ฅํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ฌ์ฉ์๊ฐ 4๊ฐ์ง์ ๋ฐํ์๋ ์์์๋ง ๋งํ ๊ฒ์ด๋ผ๋ ๋ณด์ฅ์ ์์ต๋๋ค. ๋ง์ฝ ์์ฒ๋ผ "๋ ์จ ์ ๋ณด์ ๊ณต", "๋ฏธ์ธ๋จผ์ง ์ ๋ณด์ ๊ณต", "๋ง์ง ์ ๋ณด์ ๊ณต", "์ฌํ์ง ์ ๋ณด์ ๊ณต"์ ๋ฐ์ดํฐ๋ง ํ์ตํ ์ธํ ํธ ๋ถ๋ฅ๋ชจ๋ธ์ "์๋ ๋ฐ๊ฐ๋ค."๋ผ๋ ๋ง์ ํ๊ฒ ๋๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ์ 4๊ฐ์ง์ ์ํ์ง ์์ ๋ฐํ ์๋์ธ "์ธ์ฌ"์ ํด๋นํ์ง๋ง ๋ชจ๋ธ์ ์ธ์ฟ๋ง์ ํ๋ฒ๋ ๋ณธ์ ์ด ์๊ธฐ ๋๋ฌธ์ ์ด๊ฒ๋ ์ญ์ 4๊ฐ์ง ์ค ํ๋๋ก ๋ถ๋ฅํ๊ฒ ๋ฉ๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์๋ ๋ถ๋ฅ๋ชจ๋ธ์๋ ๋ฐ๋์ ํด๋ฐฑ (Fallback) ๊ฒ์ถ ์ ๋ต์ด ํฌํจ๋์ด์ผํฉ๋๋ค.
๋ณดํต์ ์ฑ๋ด๋น๋๋ค์ ์ ๋ ฅ ๋จ์ด๋ค์ ์๋ฒ ๋ฉ์ธ ๋ฌธ์ฅ ๋ฒกํฐ์ ๊ธฐ์กด ๋ฐ์ดํฐ์ ์ ์๋ ๋ฌธ์ฅ ๋ฒกํฐ๋ค์ Cosine ์ ์ฌ๋๋ฅผ ๋น๊ตํฉ๋๋ค. ์ด ๋ ๊ฐ์ฅ ๊ฐ๊น์ด ์ธ์ ํด๋์ค์์ ๊ฐ๋๊ฐ ์๊ณ์น ์ด์์ด๋ฉด Fallback์ด๊ณ , ๊ทธ๋ ์ง ์์ผ๋ฉด ๊ฐ์ฅ ๊ฐ๊น์ด ์ธ์ ํด๋์ค๋ก ๋ฐ์ดํฐ ์ํ์ ๋ถ๋ฅํ๊ฒ ๋ฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ์ผ๋ฐ์ ์ธ ์ฑ๋ด ๋น๋๋ค์ด ์ด๋ค์์ผ๋ก Fallback์ ๊ฒ์ถํ๋์ง ์ ์ ์์ต๋๋ค.
Kochat์ ์ด๋ ๊ฒ ๋จ์ํ ๋ฌธ์ฅ๋ค์ ๋ฒกํฐ Cosine ์ ์ฌ๋๋ฅผ ๋น๊ตํ์ง ์๊ณ
๋์ฑ ๊ณ ์ฐจ์์ ์ธ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ Fallback ๋ํ
์
์ ๋ณด๋ค ๋ ์ ์ํํ๋๋ก
์ค๊ณํ์๋๋ฐ ์ด์ ๋ํ ์์ธํ ๋ด์ฉ์ ์๋์ Usage์์ ์์ธํ ์ธ๊ธํ๋๋ก ํ๊ฒ ์ต๋๋ค.
๊ทธ ๋ค์ ํด์ผํ ์ผ์ ๋ฐ๋ก ๊ฐ์ฒด๋ช
์ธ์ (Named Entity Recognition)์
๋๋ค.
์ด๋ค API๋ฅผ ํธ์ถํ ์ง ์์๋๋ค๋ฉด, ์ด์ ๊ทธ API๋ฅผ ํธ์ถํ๊ธฐ ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฐพ์์ผํฉ๋๋ค.
๋ง์ฝ ๋ ์จ API์ ์คํ์ ์ํ ํ๋ผ๋ฏธํฐ๊ฐ "์ง์ญ"๊ณผ "๋ ์จ"๋ผ๋ฉด ์ฌ์ฉ์์ ์
๋ ฅ ๋ฌธ์ฅ์์ "์ง์ญ"์ ๊ด๋ จ๋ ์ ๋ณด์
"๋ ์จ"์ ๊ด๋ จ๋ ์ ๋ณด๋ฅผ ์ฐพ์๋ด์ ํด๋น ์ฌ๋กฏ์ ์ฑ์๋๋ค. ๋ง์ฝ ์ฌ์ฉ์๊ฐ "์์์ผ ๋ ์จ ์๋ ค์ค"๋ผ๊ณ ๋ง ๋งํ๋ค๋ฉด,
์ง์ญ์ ๊ด๋ จ๋ ์ ๋ณด๋ ์์ง ์ฐพ์๋ด์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ ๋ค์ ๋๋ฌผ์ด์ ์ฐพ์๋ด์ผํฉ๋๋ค.
์ฌ๋กฏ์ด ๋ชจ๋ ์ฑ์์ก๋ค๋ฉด API๋ฅผ ์คํ์์ผ์ ์ธ๋ถ๋ก๋ถํฐ ์ ๋ณด๋ฅผ ์ ๊ณต๋ฐ์ต๋๋ค.
API๋ก๋ถํฐ ๊ฒฐ๊ณผ๊ฐ ๋์ฐฉํ๋ฉด, ๋ฏธ๋ฆฌ ๋ง๋ค์ด๋ ํ
ํ๋ฆฟ ๋ฌธ์ฅ์ ํด๋น ์คํ ๊ฒฐ๊ณผ๋ฅผ ์ฝ์
ํ์ฌ ๋๋ต์ ๋ง๋ค์ด๋ด๊ณ ,
์ด ๋๋ต์ ์ฌ์ฉ์์๊ฒ responseํฉ๋๋ค. ์ด API๋ ์์ ๋กญ๊ฒ ์ํ๋ API๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
์์ ์ ํ๋ฆฌ์ผ์ด์
์์๋ ์ฃผ๋ก ์น ํฌ๋กค๋ง์ ์ด์ฉํ์ฌ API๋ฅผ ๊ตฌ์ฑํ์๊ณ , ํฌ๋กค๋ฌ ๊ตฌํ ์ํคํ
์ฒ์ ๋ํด์๋ ํ์ ํ๋๋ก ํ๊ฒ ์ต๋๋ค.
Slot Filling ๋ฐฉ์์ ์ฑ๋ด์ ์์ ๊ฐ์ ํ๋ฆ์ผ๋ก ์งํ๋ฉ๋๋ค. ๋ฐ๋ผ์ ์ด๋ฌํ ๋ฐฉ์์ ์ฑ๋ด์ ๊ตฌํํ๋ ค๋ฉด ์ต์ํ 3๊ฐ์ง์ ๋ชจ๋์ด ํ์ํฉ๋๋ค. ์ฒซ๋ฒ์งธ๋ก ์ธํ ํธ ๋ถ๋ฅ๋ชจ๋ธ, ์ํฐํฐ ์ธ์๋ชจ๋ธ, ๊ทธ๋ฆฌ๊ณ ๋๋ต ์์ฑ๋ชจ๋(์์ ์์๋ ํฌ๋กค๋ง)์ ๋๋ค. Kochat์ ์ด ์ธ๊ฐ์ง ๋ชจ๋๊ณผ ์ด๋ฅผ ์๋นํ Restful API๊น์ง ๋ชจ๋ ํฌํจํ๊ณ ์์ต๋๋ค. ์ด์ ๋ํด์๋ ์๋์ Usage ์ฑํฐ์์ ๊ฐ๊ฐ ๋ชจ๋ธ์ด ์ด๋ป๊ฒ ๊ตฌํ๋์ด ์๋์ง ์์ธํ ์ค๋ช ํฉ๋๋ค.
Kochat์ ์ด์ฉํ๋ ค๋ฉด ๋ฐ๋์ ๋ณธ์ธ์ OS์ ๋จธ์ ์ ๋ง๋ Pytorch๊ฐ ์ค์น ๋์ด์์ด์ผํฉ๋๋ค. ๋ง์ฝ Pytorch๋ฅผ ์ค์นํ์ง ์์ผ์ จ๋ค๋ฉด ์ฌ๊ธฐ ์์ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์. (Kochat์ ์ค์นํ๋ค๊ณ ํด์ Pytorch๊ฐ ํจ๊ป ์ค์น๋์ง ์์ต๋๋ค. ๋ณธ์ธ ๋ฒ์ ์ ๋ง๋ Pytorch๋ฅผ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์)
pip๋ฅผ ์ด์ฉํด Kochat์ ๊ฐ๋จํ๊ฒ ๋ค์ด๋ก๋ํ๊ณ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์๋ ๋ช ๋ น์ด๋ฅผ ํตํด์ kochat์ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์.
pip install kochat
ํจํค์ง๋ฅผ ๊ตฌํํ๋๋ฐ ์ฌ์ฉ๋ ๋ํ๋์๋ ์๋์ ๊ฐ์ต๋๋ค. (Kochat ์ค์น์ ํจ๊ป ์ค์น๋ฉ๋๋ค.)
matplotlib==3.2.1
pandas==1.0.4
gensim==3.8.3
konlpy==0.5.2
numpy==1.18.5
joblib==0.15.1
scikit-learn==0.23.1
pytorch-crf==0.7.2
requests==2.24.0
flask==1.1.2
pip๋ฅผ ์ด์ฉํด Kochat์ ๋ค์ด๋ก๋ ๋ฐ์๋ค๋ฉด ํ๋ก์ ํธ์, kochat์ configuration ํ์ผ์ ์ถ๊ฐํด์ผํฉ๋๋ค. kochat_config.zip ์ ๋ค์ด๋ก๋ ๋ฐ๊ณ ์์ถ์ ํ์ด์ interpreter์ working directory์ ๋ฃ์ต๋๋ค. (kochat api๋ฅผ ์คํํ๋ ํ์ผ๊ณผ ๋์ผํ ๊ฒฝ๋ก์ ์์ด์ผํฉ๋๋ค. ์์ธํ ์์๋ ์๋ ๋ฐ๋ชจ์์ ํ์ธํ์ค ์ ์์ต๋๋ค.) config ํ์ผ์๋ ๋ค์ํ ์ค์ ๊ฐ๋ค์ด ์กด์ฌํ๋ ํ์ธํ๊ณ ์ ๋ง๋๋ก ๋ณ๊ฒฝํ์๋ฉด ๋ฉ๋๋ค.
์ด์ ์ฌ๋ฌ๋ถ์ด ํ์ต์ํฌ ๋ฐ์ดํฐ์
์ ๋ฃ์ด์ผํฉ๋๋ค.
๊ทธ ์ ์ ๋ฐ์ดํฐ์
์ ํฌ๋งท์ ๋ํด์ ๊ฐ๋จํ๊ฒ ์์๋ด
์๋ค.
Kochat์ ๊ธฐ๋ณธ์ ์ผ๋ก Slot filling์ ๊ธฐ๋ฐ์ผ๋ก
ํ๊ณ ์๊ธฐ ๋๋ฌธ์ Intent์ Entity ๋ฐ์ดํฐ์
์ด ํ์ํฉ๋๋ค.
๊ทธ๋ฌ๋ ์ด ๋๊ฐ์ง ๋ฐ์ดํฐ์
์ ๋ฐ๋ก ๋ง๋ค๋ฉด ์๋นํ ๋ฒ๊ฑฐ๋ก์์ง๊ธฐ ๋๋ฌธ์
ํ๊ฐ์ง ํฌ๋งท์ผ๋ก ๋๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ์๋์ผ๋ก ์์ฑํฉ๋๋ค.
์๋ ๋ฐ์ดํฐ์
๊ท์น๋ค์ ๋ง์ถฐ์ ๋ฐ์ดํฐ๋ฅผ ์์ฑํด์ฃผ์ธ์
๊ธฐ๋ณธ์ ์ผ๋ก intent์ entity๋ฅผ ๋๋๋ ค๋ฉด, ๋๊ฐ์ง๋ฅผ ๋ชจ๋ ๊ตฌ๋ถํ ์ ์์ด์ผํฉ๋๋ค.
๊ทธ๋์ ์ ํํ ๋ฐฉ์์ ์ธํ
ํธ๋ ํ์ผ๋ก ๊ตฌ๋ถ, ์ํฐํฐ๋ ๋ผ๋ฒจ๋ก ๊ตฌ๋ถํ๋ ๊ฒ์ด์์ต๋๋ค.
์ถํ ๋ฆด๋ฆฌ์ฆ ๋ฒ์ ์์๋ Rasa์ฒ๋ผ ํจ์ฌ ์ฌ์ด ๋ฐฉ์์ผ๋ก ๋ณ๊ฒฝํ๋ ค๊ณ ํฉ๋๋ค๋ง, ์ด๊ธฐ๋ฒ์ ์์๋
๋ค์ ๋ถํธํ๋๋ผ๋ ์๋์ ํฌ๋งท์ ๋ฐ๋ผ์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
- weather.csv
question,label
๋ ์จ ์๋ ค์ฃผ์ธ์,O O
์์์ผ ์ธ์ ๋น์ค๋,S-DATE S-LOCATION O
๊ตฐ์ฐ ๋ ์จ ์ถ์ธ๊น ์ ๋ง,S-LOCATION O O O
๊ณก์ฑ ๋น์ฌ๊น,S-LOCATION O
๋ด์ผ ๋จ์ ๋ ์ค๊ฒ ์ง ์๋ง,S-DATE S-LOCATION O O O
๊ฐ์๋ ์ถ์ฒ ๊ฐ๋๋ฐ ์ค๋ ๋ ์จ ์๋ ค์ค,B-LOCATION E-LOCATION O S-DATE O O
์ ๋ถ ๊ตฐ์ฐ ๊ฐ๋๋ฐ ํ์์ผ ๋ ์จ ์๋ ค์ค๋,B-LOCATION E-LOCATION O S-DATE O O
์ ์ฃผ ์๊ทํฌ ๊ฐ๋ ค๋๋ฐ ํ์์ผ ๋ ์จ ์๋ ค์ค,B-LOCATION E-LOCATION O S-DATE O O
์ค๋ ์ ์ฃผ๋ ๋ ์จ ์๋ ค์ค,S-DATE S-LOCATION O O
... (์๋ต)
- travel.csv
question,label
์ด๋ ๊ด๊ด์ง ๊ฐ๊ฒ ๋,O O O
ํ์ฃผ ์ ๋ช
ํ ๊ณต์ฐ์ฅ ์๋ ค์ค,S-LOCATION O S-PLACE O
์ฐฝ์ ์ฌํ ๊ฐ๋งํ ๋ฐ๋ค,S-LOCATION O O S-PLACE
ํํ ๊ฐ๋งํ ์คํค์ฅ ์ฌํ ํด๋ณด๊ณ ์ถ๋ค,S-LOCATION O S-PLACE O O O
์ ์ฃผ๋ ํ
ํ์คํ
์ด ์ฌํ ๊ฐ ๋ฐ ์ถ์ฒํด ์ค,S-LOCATION S-PLACE O O O O O
์ ์ฃผ ๊ฐ๊น์ด ๋ฐ๋ค ๊ด๊ด์ง ๋ณด์ฌ์ค ๋ด์,S-LOCATION O S-PLACE O O O
์ฉ์ธ ๊ฐ๊น์ด ์ถ๊ตฌ์ฅ ์ด๋จ์ด,S-LOCATION O S-PLACE O
๋ถ๋น๋ ๊ด๊ด์ง,O O
์ฒญ์ฃผ ๊ฐ์ ํ๊ฒฝ ์์ ์ฐ ๊ฐ๋ณด๊ณ ์ถ์ด,S-LOCATION S-DATE O O S-PLACE O O
... (์๋ต)
์ ์ฒ๋ผ question,label์ด๋ผ๋ ํค๋(์ปฌ๋ผ๋ช )์ ๊ฐ์ฅ ์์ค์ ์์น์ํค๊ณ , ๊ทธ ์๋๋ก ๋๊ฐ์ ์ปฌ๋ฆผ question๊ณผ label์ ํด๋นํ๋ ๋ด์ฉ์ ์์ฑํฉ๋๋ค. ๊ฐ ๋จ์ด ๋ฐ ์ํฐํฐ๋ ๋์ด์ฐ๊ธฐ๋ก ๊ตฌ๋ถ๋ฉ๋๋ค. ๋ฐ๋ชจ ๋ฐ์ดํฐ๋ BIOํ๊น ์ ๊ฐ์ ํ BIOESํ๊น ์ ์ฌ์ฉํ์ฌ ๋ผ๋ฒจ๋งํ๋๋ฐ, ์ํฐํฐ ํ๊น ๋ฐฉ์์ ์์ ๋กญ๊ฒ ๊ณ ๋ฅด์ ๋ ๋ฉ๋๋ค. (config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.) ์ํฐํฐ ํ๊น ์คํค๋ง์ ๊ด๋ จ๋ ์์ธํ ๋ด์ฉ์ ์ฌ๊ธฐ ๋ฅผ ์ฐธ๊ณ ํ์ธ์.
๋ฐ์ดํฐ์ ์ ์ฅ๊ฒฝ๋ก๋ ๊ธฐ๋ณธ์ ์ผ๋ก configํ์ผ์ด ์๋ ๊ณณ์ root๋ก ์๊ฐํ์ ๋, "root/data/raw"์ ๋๋ค. ์ด ๊ฒฝ๋ก๋ config์ DATA ์ฑํฐ์์ ๋ณ๊ฒฝ ๊ฐ๋ฅํฉ๋๋ค.
root
|_data
|_raw
|_weather.csv
|_dust.csv
|_retaurant.csv
|_...
๊ฐ ์ธํ ํธ ๋จ์๋ก ํ์ผ์ ๋ถํ ํฉ๋๋ค. ์ด ๋, ํ์ผ๋ช ์ด ์ธํ ํธ๋ช ์ด ๋ฉ๋๋ค. ํ์ผ๋ช ์ ํ๊ธ๋ก ํด๋ ์๊ด ์๊ธด ํ์ง๋ง, ๋ฆฌ๋ ์ค ์ด์์ฒด์ ์ ๊ฒฝ์ฐ ์๊ฐํ์ matplotlib์ ํ๊ธํฐํธ๊ฐ ์ค์น๋์ด์์ง ์๋ค๋ฉด ๊ธ์๊ฐ ๊นจ์ง๋, ๊ฐ๊ธ์ ์ด๋ฉด ์๊ฐํ๋ฅผ ์ํด ์์ด๋ก ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. (๋ง์ฝ ๊ธ์๊ฐ ๊นจ์ง์ง ์์ผ๋ฉด ํ๊ธ๋ก ํด๋ ๋ฌด๋ฐฉํ๋, ํ๊ธ๋ก ํ๋ ค๋ฉด ํฐํธ๋ฅผ ์ค์นํด์ฃผ์ธ์.)
root
|_data
|_raw
|_weather.csv โ intent : weather
|_dust.csv โ intent : dust
|_retaurant.csv โ intent : restaurant
|_...
ํ์ผ์ ํค๋(์ปฌ๋ผ๋ช )์ ๋ฐ๋์ question๊ณผ label๋ก ํด์ฃผ์ธ์. ํค๋๋ฅผ config์์ ๋ฐ๊ฟ ์ ์๊ฒ ํ ๊น๋ ์๊ฐํ์ง๋ง, ๋ณ๋ก ํฐ ์๋ฏธ๊ฐ ์๋ ๊ฒ ๊ฐ์์ ์ฐ์ ์ ๊ณ ์ ๋ ๊ฐ์ธ question๊ณผ label๋ก ์ค์ ํ์์ต๋๋ค.
question,label โ ์ค์ !!!
... (์๋ต)
์ํ ๋น question์ ๋จ์ด ๊ฐฏ์์ label์ ์ํฐํฐ ๊ฐฏ์๋ ๋์ผํด์ผํ๋ฉฐ config์ ์ ์ํ ์ํฐํฐ๋ง ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค. ์ด๋ฌํ ๋ผ๋ฒจ๋ง ์ค์๋ Kochat์ด ๋ฐ์ดํฐ๋ฅผ ๋ณํํ ๋ ๊ฒ์ถํด์ ์ด๋๊ฐ ํ๋ ธ๋์ง ์๋ ค์ค๋๋ค.
case 1: ๋ผ๋ฒจ๋ง ๋งค์นญ ์ค์ ๋ฐฉ์ง
question = ์ ์ฃผ ๋ ์ฌ๊น (size : 3)
label = S-LOCATION O O O (size : 4)
โ ์๋ฌ ๋ฐ์! (question๊ณผ label์ ์๊ฐ ๋ค๋ฆ)
case 2: ๋ผ๋ฒจ๋ง ์คํ ๋ฐฉ์ง
(in kochat_config.py)
DATA = {
... (์๋ต)
'NER_categories': ['DATE', 'LOCATION', 'RESTAURANT', 'PLACE'], # ์ฌ์ฉ์ ์ ์ ํ๊ทธ
'NER_tagging': ['B', 'E', 'I', 'S'], # NER์ BEGIN, END, INSIDE, SINGLE ํ๊ทธ
'NER_outside': 'O', # NER์ Oํ๊ทธ (Outside๋ฅผ ์๋ฏธ)
}
question = ์ ์ฃผ ๋ ์ฌ๊น
label = Z-LOC O O
โ ์๋ฌ ๋ฐ์! (์ ์๋์ง ์์ ์ํฐํฐ : Z-LOC)
NER_tagging + '-' + NER_categories์ ํํ๊ฐ ์๋๋ฉด ์๋ฌ๋ฅผ ๋ฐํํฉ๋๋ค.
OOD๋ Out of distribution์ ์ฝ์๋ก, ๋ถํฌ ์ธ ๋ฐ์ดํฐ์ ์ ์๋ฏธํฉ๋๋ค. ์ฆ, ํ์ฌ ์ฑ๋ด์ด ์ง์ํ๋ ๊ธฐ๋ฅ ์ด์ธ์ ๋ฐ์ดํฐ๋ฅผ ์๋ฏธํฉ๋๋ค. OOD ๋ฐ์ดํฐ์ ์ด ์์ด๋ Kochat์ ์ด์ฉํ๋๋ฐ์๋ ์๋ฌด๋ฐ ๋ฌธ์ ๊ฐ ์์ง๋ง, OOD ๋ฐ์ดํฐ์ ์ ๊ฐ์ถ๋ฉด ๋งค์ฐ ๊ท์ฐฎ์ ๋ช๋ช ๋ถ๋ถ๋ค์ ํจ๊ณผ์ ์ผ๋ก ์๋ํ ํ ์ ์์ต๋๋ค. (์ฃผ๋ก Fallback Detection threshold ์ค์ ) OOD ๋ฐ์ดํฐ์ ์ ์๋์ฒ๋ผ "root/data/ood"์ ์ถ๊ฐํฉ๋๋ค.
root
|_data
|_raw
|_weather.csv
|_dust.csv
|_retaurant.csv
|_...
|_ood
|_ood_data_1.csv โ data/oodํด๋์ ์์นํ๊ฒ ํฉ๋๋ค.
|_ood_data_2.csv โ data/oodํด๋์ ์์นํ๊ฒ ํฉ๋๋ค.
OOD ๋ฐ์ดํฐ์ ์ ์๋์ ๊ฐ์ด question๊ณผ OOD์ ์๋๋ก ๋ผ๋ฒจ๋งํฉ๋๋ค. ๋ฐ๋ชจ ๋ฐ์ดํฐ์ ์ ์ ๋ถ ์๋๋๋ก ๋ผ๋ฒจ๋งํ์ง๋ง, ์ด ์๋๊ฐ์ ์ฌ์ฉํ์ง ์๊ธฐ ๋๋ฌธ์ ๊ทธ๋ฅ ์๋ฌด๊ฐ์ผ๋ก๋ ๋ผ๋ฒจ๋งํด๋ ์ฌ์ค ๋ฌด๊ดํฉ๋๋ค.
๋ฐ๋ชจ_ood_๋ฐ์ดํฐ.csv
question,label
์ต๊ทผ ์๋์ผ ์ต๊ทผ ์ด์ ์๋ ค์ค,๋ด์ค์ด์
์ต๊ทผ ํซํ๋ ๊ฒ ์๋ ค์ค,๋ด์ค์ด์
๋ํํ
์ข์ ๋ช
์ธํด์ค ์ ์๋,๋ช
์ธ
๋ ์ข์ ๋ช
์ธ ์ข ๋ค๋ ค์ฃผ๋ผ,๋ช
์ธ
์ข์ ๋ช
์ธ ์ข ํด๋ด,๋ช
์ธ
๋ฐฑ์ฌ๋ฒ ๋
ธ๋ ๋ค์๋์,์์
๋น ๋
ธ๋ ๊นก ๋ฃ๊ณ ์ถ๋ค,์์
์ํ ost ์ถ์ฒํด์ค,์์
์ง๊ธ ์๊ฐ ์ข ์๋ ค๋ฌ๋ผ๊ณ ,๋ ์ง์๊ฐ
์ง๊ธ ์๊ฐ ์ข ์๋ ค์ค,๋ ์ง์๊ฐ
์ง๊ธ ๋ช ์ ๋ช ๋ถ์ธ์ง ์๋,๋ ์ง์๊ฐ
๋ช
์ ์คํธ๋ ์ค ใ
ใ
,์ก๋ด
๋ญํ๊ณ ๋์ง ใ
ใ
,์ก๋ด
๋๋ ๋์์ฃผ๋ผ ์ข,์ก๋ด
๋ญํ๊ณ ์ด์ง,์ก๋ด
... (์๋ต)
์ด๋ ๊ฒ ๋ผ๋ฒจ๋ง ํด๋ ๋์ง๋ง ์ด์ฐจํผ ๋ผ๋ฒจ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ง ์๊ธฐ ๋๋ฌธ์ ์๋์ฒ๋ผ ๋ผ๋ฒจ๋งํด๋ ๋ฌด๊ดํฉ๋๋ค.
๋ฐ๋ชจ_ood_๋ฐ์ดํฐ.csv
question,label
์ต๊ทผ ์๋์ผ ์ต๊ทผ ์ด์ ์๋ ค์ค,OOD
์ต๊ทผ ํซํ๋ ๊ฒ ์๋ ค์ค,OOD
๋ํํ
์ข์ ๋ช
์ธํด์ค ์ ์๋,OOD
๋ ์ข์ ๋ช
์ธ ์ข ๋ค๋ ค์ฃผ๋ผ,OOD
์ข์ ๋ช
์ธ ์ข ํด๋ด,OOD
๋ฐฑ์ฌ๋ฒ ๋
ธ๋ ๋ค์๋์,OOD
๋น ๋
ธ๋ ๊นก ๋ฃ๊ณ ์ถ๋ค,OOD
์ํ ost ์ถ์ฒํด์ค,OOD
์ง๊ธ ์๊ฐ ์ข ์๋ ค๋ฌ๋ผ๊ณ ,OOD
์ง๊ธ ์๊ฐ ์ข ์๋ ค์ค,OOD
์ง๊ธ ๋ช ์ ๋ช ๋ถ์ธ์ง ์๋,OOD
๋ช
์ ์คํธ๋ ์ค ใ
ใ
,OOD
๋ญํ๊ณ ๋์ง ใ
ใ
,OOD
๋๋ ๋์์ฃผ๋ผ ์ข,OOD
๋ญํ๊ณ ์ด์ง,OOD
... (์๋ต)
OOD ๋ฐ์ดํฐ๋ ๋ฌผ๋ก ๋ง์ผ๋ฉด ์ข๊ฒ ์ง๋ง ๋ง๋๋ ๊ฒ ์์ฒด๊ฐ ๋ถ๋ด์ด๊ธฐ ๋๋ฌธ์ ์ ์ ์๋ง ๋ฃ์ด๋ ๋ฉ๋๋ค.
๋ฐ๋ชจ ๋ฐ์ดํฐ์ ๊ฒฝ์ฐ๋ ์ด 3000๋ผ์ธ์ ๋ฐ์ดํฐ ์ค 600๋ผ์ธ์ ๋์ OOD ๋ฐ์ดํฐ๋ฅผ ์ฝ์
ํ์์ต๋๋ค.
๋ฐ์ดํฐ๊น์ง ๋ชจ๋ ์ฝ์
ํ์
จ๋ค๋ฉด kochat์ ์ด์ฉํ ์ค๋น๊ฐ ๋๋ฌ์ต๋๋ค. ์๋ ์ฑํฐ์์๋
์์ธํ ์ฌ์ฉ๋ฒ์ ๋ํด ์๋ ค๋๋ฆฌ๊ฒ ์ต๋๋ค.
kochat.data
ํจํค์ง์๋ Dataset
ํด๋์ค๊ฐ ์์ต๋๋ค. Dataset
ํด๋์ค๋
๋ถ๋ฆฌ๋ raw ๋ฐ์ดํฐ ํ์ผ๋ค์ ํ๋๋ก ํฉ์ณ์ ํตํฉ intentํ์ผ๊ณผ ํตํฉ entityํ์ผ๋ก ๋ง๋ค๊ณ ,
embedding, intent, entity, inference์ ๊ด๋ จ๋ ๋ฐ์ดํฐ์
์ ๋ฏธ๋๋ฐฐ์น๋ก ์๋ผ์
pytorch์ DataLoader
ํํ๋ก ์ ๊ณตํฉ๋๋ค.
๋ํ ๋ชจ๋ธ, Loss ํจ์ ๋ฑ์ ์์ฑํ ๋ ํ๋ผ๋ฏธํฐ๋ก ์
๋ ฅํ๋ label_dict
๋ฅผ ์ ๊ณตํฉ๋๋ค.
Dataset
ํด๋์ค๋ฅผ ์์ฑํ ๋ ํ์ํ ํ๋ผ๋ฏธํฐ์ธ ood
๋ OOD ๋ฐ์ดํฐ์
์ฌ์ฉ ์ฌ๋ถ์
๋๋ค.
True๋ก ์ค์ ํ๋ฉด ood ๋ฐ์ดํฐ์
์ ์ฌ์ฉํฉ๋๋ค.
- Dataset ๊ธฐ๋ฅ 1. ๋ฐ์ดํฐ์ ์์ฑ
from kochat.data import Dataset
# ํด๋์ค ์์ฑ์ rawํ์ผ๋ค์ ๊ฒ์ฆํ๊ณ ํตํฉํฉ๋๋ค.
dataset = Dataset(ood=True, naver_fix=True)
# ์๋ฒ ๋ฉ ๋ฐ์ดํฐ์
์์ฑ
embed_dataset = dataset.load_embed()
# ์ธํ
ํธ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
intent_dataset = dataset.load_intent(emb)
# ์ํฐํฐ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
entity_dataset = dataset.load_entity(emb)
# ์ถ๋ก ์ฉ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
predict_dataset = dataset.load_predict("์์ธ ๋ง์ง ์ถ์ฒํด์ค", emb)
- Dataset ๊ธฐ๋ฅ 2. ๋ผ๋ฒจ ๋์ ๋๋ฆฌ ์์ฑ
from kochat.data import Dataset
# ํด๋์ค ์์ฑ์ rawํ์ผ๋ค์ ๊ฒ์ฆํ๊ณ ํตํฉํฉ๋๋ค.
dataset = Dataset(ood=True, naver_fix=True)
# ์ธํ
ํธ ๋ผ๋ฒจ ๋์
๋๋ฆฌ๋ฅผ ์์ฑํฉ๋๋ค.
intent_dict = dataset.intent_dict
# ์ํฐํฐ ๋ผ๋ฒจ ๋์
๋๋ฆฌ๋ฅผ ์์ฑํฉ๋๋ค.
entity_dict = dataset.entity_dict
Dataset
ํด๋์ค๋ ์ ์ฒ๋ฆฌ์ ํ ํฐํ๋ฅผ ์ํํ ๋,
ํ์ต/ํ
์คํธ ๋ฐ์ดํฐ๋ ๋์ด์ฐ๊ธฐ๋ฅผ ๊ธฐ์ค์ผ๋ก ํ ํฐํ๋ฅผ ์ํํ๊ณ , ์ค์ ์ฌ์ฉ์์ ์
๋ ฅ์
์ถ๋ก ํ ๋๋ ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ์ Konlpy ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํ์ฌ ํ ํฐํ๋ฅผ ์ํํฉ๋๋ค.
๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ๋ฅผ ์ฌ์ฉํ๋ฉด ์ฑ๋ฅ์ ๋์ฑ ํฅ์๋๊ฒ ์ง๋ง, ์์
์ ์ผ๋ก ์ด์ฉ์ ๋ฌธ์ ๊ฐ
๋ฐ์ํ ์ ์๊ณ , ์ด์ ๋ํด ๊ฐ๋ฐ์๋ ์ด๋ ํ ์ฑ
์๋ ์ง์ง ์์ต๋๋ค.
๋ง์ฝ Kochat์ ์์
์ ์ผ๋ก ์ด์ฉํ์๋ ค๋ฉด Dataset
์์ฑ์ naver_fix
ํ๋ผ๋ฏธํฐ๋ฅผ
False
๋ก ์ค์ ํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค. False
์ค์ ์์๋ Konlpy ํ ํฐํ๋ง ์ํํ๋ฉฐ,
์ถํ ๋ฒ์ ์์๋ ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ๋ฅผ ์์ฒด์ ์ธ ๋์ด์ฐ๊ธฐ ๊ฒ์ฌ๋ชจ๋ ๋ฑ์ผ๋ก
๊ต์ฒดํ ์์ ์
๋๋ค.
model
ํจํค์ง๋ ์ฌ์ ์ ์๋ ๋ค์ํ built-in ๋ชจ๋ธ๋ค์ด ์ ์ฅ๋ ํจํค์ง์
๋๋ค.
ํ์ฌ ๋ฒ์ ์์๋ ์๋ ๋ชฉ๋ก์ ํด๋นํ๋ ๋ชจ๋ธ๋ค์ ์ง์ํฉ๋๋ค. ์ถํ ๋ฒ์ ์ด ์
๋ฐ์ดํธ ๋๋ฉด
์ง๊ธ๋ณด๋ค ํจ์ฌ ๋ค์ํ built-in ๋ชจ๋ธ์ ์ง์ํ ์์ ์
๋๋ค. ์๋ ๋ชฉ๋ก์ ์ฐธ๊ณ ํ์ฌ ์ฌ์ฉํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
from kochat.model import embed
# 1. Gensim์ Word2Vec ๋ชจ๋ธ์ Wrapper์
๋๋ค.
# (OOV ํ ํฐ์ ๊ฐ์ config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.)
word2vec = embed.Word2Vec()
# 2. Gensim์ FastText ๋ชจ๋ธ์ Wrapper์
๋๋ค.
fasttext = embed.FastText()
from kochat.model import intent
# 1. Residual Learning์ ์ง์ํ๋ 1D CNN์
๋๋ค.
cnn = intent.CNN(label_dict=dataset.intent_dict, residual=True)
# 2. Bidirectional์ ์ง์ํ๋ LSTM์
๋๋ค.
lstm = intent.LSTM(label_dict=dataset.intent_dict, bidirectional=True)
from kochat.model import entity
# 1. Bidirectional์ ์ง์ํ๋ LSTM์
๋๋ค.
lstm = entity.LSTM(label_dict=dataset.entity_dict, bidirectional=True)
Kochat์ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
Gensim์ด๋ Pytorch๋ก ์์ฑํ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ ํ์ต์ํค๊ธฐ๊ณ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์
์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ฝ ์ปค์คํ
๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด ์๋์ ๋ช๊ฐ์ง ๊ท์น์ ๋ฐ๋์
๋ฐ๋ผ์ผํฉ๋๋ค.
์๋ฒ ๋ฉ์ ๊ฒฝ์ฐ ํ์ฌ๋ Gensim ๋ชจ๋ธ๋ง ์ง์ํฉ๋๋ค. ์ถํ์ Pytorch๋ก ๋
์๋ฒ ๋ฉ ๋ชจ๋ธ(ELMO, BERT)๋ฑ๋ ์ง์ํ ๊ณํ์
๋๋ค.
Gensim Embedding ๋ชจ๋ธ์ ์๋์ ๊ฐ์ ํํ๋ก ๊ตฌํํด์ผํฉ๋๋ค.
@gensim
๋ฐ์ฝ๋ ์ดํฐ ์ค์ BaseWordEmbeddingsModel
๋ชจ๋ธ ์ค ํ ๊ฐ์ง ์์๋ฐ๊ธฐsuper().__init__()
์ ํ๋ผ๋ฏธํฐ ์ฝ์ ํ๊ธฐ (self.XXX๋ก ์ ๊ทผ๊ฐ๋ฅ)
from gensim.models import FastText
from kochat.decorators import gensim
# 1. @gensim ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ GENSIM์ ์๋ ๋ชจ๋ ๋ฐ์ดํฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@gensim
class FastText(FastText):
# 2. BaseWordEmbeddingsModel ๋ชจ๋ธ์ค ํ ๊ฐ์ง๋ฅผ ์์๋ฐ์ต๋๋ค.
def __init__(self):
# 3. `super().__init__()`์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ฃ์ด์ ์ด๊ธฐํํด์ค๋๋ค.
super().__init__(size=self.vector_size,
window=self.window_size,
workers=self.workers,
min_count=self.min_count,
iter=self.iter)
์ธํ
ํธ ๋ชจ๋ธ์ torch๋ก ๊ตฌํํฉ๋๋ค.
์ธํ
ํธ ๋ชจ๋ธ์๋ self.label_dict
๊ฐ ๋ฐ๋์ ์กด์ฌํด์ผํฉ๋๋ค.
๋ํ ์ต์ข
output ๋ ์ด์ด๋ ์๋์์ฑ๋๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํ๋ฉด ๋ฉ๋๋ค.
๋์ฑ ์ธ๋ถ์ ์ธ ๊ท์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
@intent
๋ฐ์ฝ๋ ์ดํฐ ์ค์ torch.nn.Module
์์๋ฐ๊ธฐ- ํ๋ผ๋ฏธํฐ๋ก label_dict๋ฅผ ์
๋ ฅ๋ฐ๊ณ
self.label_dict
์ ํ ๋นํ๊ธฐ forward()
ํจ์์์ feature๋ฅผ [batch_size, -1] ๋ก ๋ง๋ค๊ณ ๋ฆฌํด
from torch import nn
from torch import Tensor
from kochat.decorators import intent
from kochat.model.layers.convolution import Convolution
# 1. @intent ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ INTENT์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@intent
class CNN(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, residual: bool = True):
super(CNN, self).__init__()
self.label_dict = label_dict
# 3. intent๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.stem = Convolution(self.vector_size, self.d_model, kernel_size=1, residual=residual)
self.hidden_layers = nn.Sequential(*[
Convolution(self.d_model, self.d_model, kernel_size=1, residual=residual)
for _ in range(self.layers)])
def forward(self, x: Tensor) -> Tensor:
x = x.permute(0, 2, 1)
x = self.stem(x)
x = self.hidden_layers(x)
return x.view(x.size(0), -1)
# 4. feature๋ฅผ [batch_size, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
import torch
from torch import nn, autograd
from torch import Tensor
from kochat.decorators import intent
# 1. @intent ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ INTENT์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@intent
class LSTM(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, bidirectional: bool = True):
super().__init__()
self.label_dict = label_dict
# 3. intent๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.direction = 2 if bidirectional else 1
self.lstm = nn.LSTM(input_size=self.vector_size,
hidden_size=self.d_model,
num_layers=self.layers,
batch_first=True,
bidirectional=bidirectional)
def init_hidden(self, batch_size: int) -> autograd.Variable:
param1 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
param2 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
return autograd.Variable(param1), autograd.Variable(param2)
def forward(self, x: Tensor) -> Tensor:
b, l, v = x.size()
out, (h_s, c_s) = self.lstm(x, self.init_hidden(b))
# 4. feature๋ฅผ [batch_size, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
return h_s[0]
์ํฐํฐ ๋ชจ๋ธ๋ ์ญ์ torch๋ก ๊ตฌํํฉ๋๋ค.
์ํฐํฐ ๋ชจ๋ธ์๋ ์ญ์ self.label_dict
๊ฐ ๋ฐ๋์ ์กด์ฌํด์ผํ๋ฉฐ,
๋ํ ์ต์ข
output ๋ ์ด์ด๋ ์๋์์ฑ๋๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํ๋ฉด ๋ฉ๋๋ค.
๋์ฑ ์ธ๋ถ์ ์ธ ๊ท์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
@entity
๋ฐ์ฝ๋ ์ดํฐ ์ค์ torch.nn.Module
์์๋ฐ๊ธฐ- ํ๋ผ๋ฏธํฐ๋ก label_dict๋ฅผ ์
๋ ฅ๋ฐ๊ณ
self.label_dict
์ ํ ๋นํ๊ธฐ forward()
ํจ์์์ feature๋ฅผ [batch_size, max_len, -1] ๋ก ๋ง๋ค๊ณ ๋ฆฌํด
import torch
from torch import nn, autograd
from torch import Tensor
from kochat.decorators import entity
# 1. @entity ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ ENTITY์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@entity
class LSTM(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, bidirectional: bool = True):
super().__init__()
self.label_dict = label_dict
# 3. entity๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.direction = 2 if bidirectional else 1
self.lstm = nn.LSTM(input_size=self.vector_size,
hidden_size=self.d_model,
num_layers=self.layers,
batch_first=True,
bidirectional=bidirectional)
def init_hidden(self, batch_size: int) -> autograd.Variable:
param1 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
param2 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
return torch.autograd.Variable(param1), torch.autograd.Variable(param2)
def forward(self, x: Tensor) -> Tensor:
b, l, v = x.size()
out, _ = self.lstm(x, self.init_hidden(b))
# 4. feature๋ฅผ [batch_size, max_len, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
return out
proc
์ Procssor์ ์ค์๋ง๋ก, ๋ค์ํ ๋ชจ๋ธ๋ค์
ํ์ต/ํ
์คํธ์ ์ํํ๋ ํจ์์ธ fit()
๊ณผ
์ถ๋ก ์ ์ํํ๋ ํจ์์ธ predict()
๋ฑ์ ์ํํ๋ ํด๋์ค ์งํฉ์
๋๋ค.
ํ์ฌ ์ง์ํ๋ ํ๋ก์ธ์๋ ์ด 4๊ฐ์ง๋ก ์๋์์ ์์ธํ๊ฒ ์ค๋ช
ํฉ๋๋ค.
GensimEmbedder๋ Gensim์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ์ต์ํค๊ณ , ํ์ต๋ ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ฌธ์ฅ์ ์๋ฒ ๋ฉํ๋ ํด๋์ค์ ๋๋ค. ์์ธํ ์ฌ์ฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import GensimEmbedder
from kochat.model import embed
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
emb = GensimEmbedder(
model=embed.FastText()
)
# ๋ชจ๋ธ ํ์ต
emb.fit(dataset.load_embed())
# ๋ชจ๋ธ ์ถ๋ก (์๋ฒ ๋ฉ)
user_input = emb.predict("์์ธ ํ๋ ๋ง์ง ์๋ ค์ค")
SoftmaxClassifier
๋ ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ๋ถ๋ฅ ํ๋ก์ธ์์
๋๋ค.
์ด๋ฆ์ด SoftmaxClassifier์ธ ์ด์ ๋ Softmax Score๋ฅผ ์ด์ฉํด Fallback Detection์ ์ํํ๊ธฐ ๋๋ฌธ์
์ด๋ ๊ฒ ๋ช
๋ช
ํ๊ฒ ๋์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ช๋ช ๋
ผ๋ฌธ
์์ Calibrate๋์ง ์์ Softmax Score์ ๋ง์น Confidence์ฒ๋ผ
์ฐฉ๊ฐํด์ ์ฌ์ฉํ๋ฉด ์ฌ๊ฐํ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
์์ ๊ทธ๋ฆผ์ MNIST ๋ถ๋ฅ๋ชจ๋ธ์์ 0.999 ์ด์์ Softmax Score๋ฅผ ๊ฐ์ง๋ ์ด๋ฏธ์ง๋ค์
๋๋ค.
์ค์ ๋ก 0 ~ 9๊น์ง์ ์ซ์์๋ ์ ํ ์๊ด์๋ ์ด๋ฏธ์ง๋ค์ด๊ธฐ ๋๋ฌธ์ ๋ฎ์ Softmax Score๋ฅผ
๊ฐ์ง ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋์ง๋ง ์ค์ ๋ก๋ ๊ทธ๋ ์ง ์์ต๋๋ค.
์ฌ์ค SoftmaxClassifier
๋ฅผ ์ค์ ์ฑ๋ด์ Intent Classification ๊ธฐ๋ฅ์ ์ํด
์ฌ์ฉํ๋ ๊ฒ์ ์ ์ ํ์ง ๋ชปํฉ๋๋ค. SoftmaxClassifier
๋ ์๋ ํ์ ํ DistanceClassifier
์์ ์ฑ๋ฅ ๋น๊ต๋ฅผ ์ํด ๊ตฌํํ์์ต๋๋ค. ์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import SoftmaxClassifier
from kochat.model import intent
from kochat.loss import CrossEntropyLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
clf = SoftmaxClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CrossEntropyLoss(dataset.intent_dict)
)
# ๋๋๋ก์ด๋ฉด SoftmaxClassifier๋ CrossEntropyLoss๋ฅผ ์ด์ฉํด์ฃผ์ธ์
# ๋ค๋ฅธ Loss ํจ์๋ค์ ๊ฑฐ๋ฆฌ ๊ธฐ๋ฐ์ Metric Learning์ ์ํํ๊ธฐ ๋๋ฌธ์
# Softmax Classifiaction์ ์ ์ ํ์ง ๋ชปํ ์ ์์ต๋๋ค.
# ๋ชจ๋ธ ํ์ต
clf.fit(dataset.load_intent(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ธํ
ํธ ๋ถ๋ฅ)
clf.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
DistanceClassifier
๋ SoftmaxClassifier
์๋ ๋ค๋ฅด๊ฒ ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ์ผ๋ก ์๋ํ๋ฉฐ,
์ผ์ข
์ Memory Network์
๋๋ค. [batch_size, -1] ์ ์ฌ์ด์ฆ๋ก ์ถ๋ ฅ๋ ์ถ๋ ฅ๋ฒกํฐ์
๊ธฐ์กด ๋ฐ์ดํฐ์
์ ์๋ ๋ฌธ์ฅ ๋ฒกํฐ๋ค ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ๋ฐ์ดํฐ์
์์ ๊ฐ์ฅ ๊ฐ๊น์ด
K๊ฐ์ ์ํ์ ์ฐพ๊ณ ์ต๋ค ์ํ ํด๋์ค๋ก ๋ถ๋ฅํ๋ ์ต๊ทผ์ ์ด์ Retrieval ๊ธฐ๋ฐ์ ๋ถ๋ฅ ๋ชจ๋ธ์
๋๋ค.
์ด ๋ ๋ค๋ฅธ ํด๋์ค๋ค์ ๋ฉ๋ฆฌ, ๊ฐ์ ํด๋์ค๋ผ๋ฆฌ๋ ๊ฐ๊น์ด ์์ด์ผ ๋ถ๋ฅํ๊ธฐ์ ์ข๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์๊ฐ ์ค์ ํ Lossํจ์(์ฃผ๋ก Margin ๊ธฐ๋ฐ Loss)๋ฅผ ์ ์ฉํด Metric Learning์ ์ํํด์ ํด๋์ค ๊ฐ์ Margin์ ์ต๋์น๋ก ๋ฒ๋ฆฌ๋ ๋ฉ์ปค๋์ฆ์ด ๊ตฌํ๋์ด์์ต๋๋ค. ๋ํ ์ต๊ทผ์ ์ด์ ์๊ณ ๋ฆฌ์ฆ์ K๊ฐ์ config์์ ์ง์ ์ง์ ํ ์๋ ์๊ณ GridSearch๋ฅผ ์ ์ฉํ์ฌ ์๋์ผ๋ก ์ต์ ์ K๊ฐ์ ์ฐพ์ ์ ์๊ฒ ์ค๊ณํ์์ต๋๋ค.
์ต๊ทผ์ ์ด์์ ์ฐพ์ ๋ Brute force๋ก ์ง์ ๊ฑฐ๋ฆฌ๋ฅผ ์ผ์ผ์ด ๋ค ๊ตฌํ๋ฉด ๊ต์ฅํ ๋๋ฆฌ๊ธฐ
๋๋ฌธ์ ๋ค์ฐจ์ ๊ฒ์ํธ๋ฆฌ์ธ KDTree
ํน์ BallTree
(KDTree์ ๊ฐ์ ํํ)๋ฅผ ํตํด์
๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ๋ฉฐ ๊ฒฐ๊ณผ๋ก ๋ง๋ค์ด์ง ํธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํฉ๋๋ค. ๊ฒ์ํธ๋ฆฌ์ ์ข
๋ฅ,
๊ฑฐ๋ฆฌ ๋ฉํธ๋ฆญ(์ ํด๋ฆฌ๋์ธ, ๋งจํํผ ๋ฑ..)์ ์ ๋ถ GridSearch๋ก ์๋ํ ์ํฌ ์ ์์ผ๋ฉฐ
์ด์ ๋ํ ์ค์ ์ config์์ ๊ฐ๋ฅํฉ๋๋ค. ํธ๋ฆฌ๊ธฐ๋ฐ์ ๊ฒ์ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์
SoftmaxClassifier
์ ๊ฑฐ์ ๋น์ทํ ์๋๋ก ํ์ต ๋ฐ ์ถ๋ก ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import DistanceClassifier
from kochat.model import intent
from kochat.loss import CenterLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
# ๋๋๋ก์ด๋ฉด DistanceClassifier๋ Margin ๊ธฐ๋ฐ์ Loss ํจ์๋ฅผ ์ด์ฉํด์ฃผ์ธ์
# ํ์ฌ๋ CenterLoss, COCOLoss, Cosface, GausianMixture ๋ฑ์
# ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ Metric Learning ์ ์ฉ Lossํจ์๋ฅผ ์ง์ํฉ๋๋ค.
# ๋ชจ๋ธ ํ์ต
clf.fit(dataset.load_intent(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ธํ
ํธ ๋ถ๋ฅ)
clf.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
SoftmaxClassifier
์ DistanceClassifier
๋ชจ๋ Fallback Detection ๊ธฐ๋ฅ์ ๊ตฌํ๋์ด์์ต๋๋ค.
Fallback Detection ๊ธฐ๋ฅ์ ์ด์ฉํ๋ ๋ฐฉ๋ฒ์ ์๋์ ๊ฐ์ด ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ๊ณตํฉ๋๋ค.
1. OOD ๋ฐ์ดํฐ๊ฐ ์๋ ๊ฒฝ์ฐ : ์ง์ config์ Threshold๋ฅผ ๋ง์ถฐ์ผํฉ๋๋ค.
2. OOD ๋ฐ์ดํฐ๊ฐ ์๋ ๊ฒฝ์ฐ : ๋จธ์ ๋ฌ๋์ ์ด์ฉํ์ฌ Threshold๋ฅผ ์๋ ํ์ตํฉ๋๋ค.
๋ฐ๋ก ์ฌ๊ธฐ์์ OOD ๋ฐ์ดํฐ์
์ด ์ฌ์ฉ๋ฉ๋๋ค.
SoftmaxClassifier
๋ out distribution ์ํ๋ค๊ณผ in distribution ์ํ๊ฐ์
maximum softmax score (size = [batch_size, 1])๋ฅผ feature๋ก ํ์ฌ
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๊ณ ,
DistanceClassifier
๋ out distribution ์ํ๋ค๊ณผ in distribution ์ํ๋ค์
K๊ฐ์ ์ต๊ทผ์ ์ด์์ ๊ฑฐ๋ฆฌ (size = [batch_size, K])๋ฅผ feature๋ก ํ์ฌ
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค.
์ด๋ฌํ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ FallbackDetector
๋ผ๊ณ ํฉ๋๋ค. FallbackDetector
๋ ๊ฐ
Classifier์์ ๋ด์ฅ ๋์ด์๊ธฐ ๋๋ฌธ์ ๋ณ๋ค๋ฅธ ์ถ๊ฐ ์์ค์ฝ๋ ์์ด Dataset
์ ood
ํ๋ผ๋ฏธํฐ๋ง True
๋ก ์ค์ ๋์ด์๋ค๋ฉด Classifierํ์ต์ด ๋๋๊ณ ๋์ ์๋์ผ๋ก ํ์ต๋๊ณ ,
predict()
์ ์ ์ฅ๋ FallbackDetector
๊ฐ ์๋ค๋ฉด ์๋์ผ๋ก ๋์ํฉ๋๋ค.
๋ํ FallbackDetector
๋ก ์ฌ์ฉํ ๋ชจ๋ธ์ ์๋์ฒ๋ผ config์์ ์ฌ์ฉ์๊ฐ ์ง์ ์ค์ ํ ์ ์์ผ๋ฉฐ
GridSearch๋ฅผ ์ง์ํ์ฌ ์ฌ๋ฌ๊ฐ์ ๋ชจ๋ธ์ ๋ฆฌ์คํธ์ ๋ฃ์ด๋๋ฉด Kochat ํ๋ ์์ํฌ๊ฐ
ํ์ฌ ๋ฐ์ดํฐ์
์ ๊ฐ์ฅ ์ ํฉํ FallbackDetector
๋ฅผ ์๋์ผ๋ก ๊ณจ๋ผ์ค๋๋ค.
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
INTENT = {
# ... (์๋ต)
# ํด๋ฐฑ ๋ํ
ํฐ ํ๋ณด (์ ํ ๋ชจ๋ธ์ ์ถ์ฒํฉ๋๋ค)
'fallback_detectors': [
LogisticRegression(max_iter=30000),
LinearSVC(max_iter=30000)
# ๊ฐ๋ฅํ max_iter๋ฅผ ๋๊ฒ ์ค์ ํด์ฃผ์ธ์
# sklearn default๊ฐ max_iter=100์ด๋ผ์ ์๋ ด์ด ์๋ฉ๋๋ค...
]
}
Fallback Detection ๋ฌธ์ ๋ Fallback ๋ฉํธ๋ฆญ(๊ฑฐ๋ฆฌ or score)๊ฐ ์ผ์ ์๊ณ์น๋ฅผ ๋์ด๊ฐ๋ฉด
์ํ์ in / out distribution ์ํ๋ก ๋ถ๋ฅํ๋๋ฐ ๊ทธ ์๊ณ์น๋ฅผ ํ์ฌ ๋ชจ๋ฅด๋ ์ํฉ์ด๋ฏ๋ก
์ ํ ๋ฌธ์ ๋ก ํด์ํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ FallbackDetector๋ก๋ ์ ์ฒ๋ผ ์ ํ ๋ชจ๋ธ์ธ
์ ํ SVM, ๋ก์ง์คํฑ ํ๊ท ๋ฑ์ ์ฃผ๋ก ์ด์ฉํฉ๋๋ค. ๋ฌผ๋ก ์์ ๋ฆฌ์คํธ์
RandomForestClassifier()
๋ BernoulliNB()
, GradientBoostingClassifier()
๋ฑ
๋ค์ํ sklearn ๋ชจ๋ธ์ ์
๋ ฅํด๋ ๋์์ ํ์ง๋ง, ์ผ๋ฐ์ ์ผ๋ก ์ ํ๋ชจ๋ธ์ด ๊ฐ์ฅ ์ฐ์ํ๊ณ
์์ ์ ์ธ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
์ด๋ ๊ฒ Fallback์ ๋ฉํธ๋ฆญ์ผ๋ก ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๋ฉด Threshold๋ฅผ ์ง์ ์ ์ ๊ฐ ์ค์ ํ์ง ์์๋ ๋ฉ๋๋ค. OOD ๋ฐ์ดํฐ์ ์ด ํ์ํ๋ค๋ ์น๋ช ์ ์ธ ๋จ์ ์ด ์์ง๋ง, ์ฐจํ ๋ฒ์ ์์๋ BERT์ Markov Chain์ ์ด์ฉํด OOD ๋ฐ์ดํฐ์ ์ ์๋์ผ๋ก ๋น ๋ฅด๊ฒ ์์ฑํ๋ ๋ชจ๋ธ์ ๊ตฌํํ์ฌ ์ถ๊ฐํ ์์ ์ ๋๋ค. (์ด ์ ๋ฐ์ดํธ ์ดํ๋ถํฐ๋ OOD ๋ฐ์ดํฐ์ ์ด ํ์ ์์ด์ง๋๋ค.)
๊ทธ๋ฌ๋ ์์ง OOD ๋ฐ์ดํฐ์ ์์ฑ๊ธฐ๋ฅ์ ์ง์ํ์ง ์๊ธฐ ๋๋ฌธ์ ํ์ฌ ๋ฒ์ ์์๋ ๋ง์ฝ OOD ๋ฐ์ดํฐ์ ์ด ์๋ค๋ฉด ์ฌ์ฉ์๊ฐ ์ง์ Threshold๋ฅผ ์ค์ ํด์ผ ํ๋ฏ๋ก ๋์ผ๋ก ์ํ๋ค์ด ์ด๋์ ๋ score ํน์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ฐ๋์ง ํ์ธํด์ผํฉ๋๋ค. ๋ฐ๋ผ์ Kochat์ Calibrate ๋ชจ๋๋ฅผ ์ง์ํฉ๋๋ค.
while True:
user_input = dataset.load_predict(input(), emb)
# ํฐ๋ฏธ๋์ ์ง์ ood๋ก ์๊ฐ๋ ๋งํ ์ํ์ ์
๋ ฅํด์
# ๋์ผ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ง์ ํ์ธํ๊ณ , threshold๋ฅผ ์ง์ ์กฐ์ ํฉ๋๋ค.
result = clf.predict(user_input, calibrate=True)
print("classification result : {}".format(result))
# DistanceClassifier
>>> '=====================CALIBRATION_MODE====================='
'ํ์ฌ ์
๋ ฅํ์ ๋ฌธ์ฅ๊ณผ ๊ธฐ์กด ๋ฌธ์ฅ๋ค ์ฌ์ด์ ๊ฑฐ๋ฆฌ ํ๊ท ์ 2.912์ด๊ณ '
'๊ฐ๊น์ด ์ํ๋ค๊ณผ์ ๊ฑฐ๋ฆฌ๋ [2.341, 2.351, 2.412, 2.445 ...]์
๋๋ค.'
'์ด ์์น๋ฅผ ๋ณด๊ณ Config์ fallback_detection_threshold๋ฅผ ๋ง์ถ์ธ์.'
'criteria๋ ๊ฑฐ๋ฆฌํ๊ท (mean) / ์ต์๊ฐ(min)์ผ๋ก ์ค์ ํ ์ ์์ต๋๋ค.'
# SoftmaxClassifier
>>> '=====================CALIBRATION_MODE====================='
'ํ์ฌ ์
๋ ฅํ์ ๋ฌธ์ฅ์ softmax logits์ 0.997์
๋๋ค.'
'์ด ์์น๋ฅผ ๋ณด๊ณ Config์ fallback_detection_threshold๋ฅผ ๋ง์ถ์ธ์.'
์ด๋ ๊ฒ calibrate ๋ชจ๋๋ฅผ ์ฌ๋ฌ๋ฒ ์งํํ์ ์ ์ค์ค๋ก ๊ณ์ฐํ threshold์ ์ํ๋ criteria๋ฅผ ์๋์ฒ๋ผ config์ ์ค์ ํ๋ฉด ood ๋ฐ์ดํฐ์ ์์ด๋ FallbackDetector๋ฅผ ์ด์ฉํ ์ ์์ต๋๋ค.
INTENT = {
'distance_fallback_detection_criteria': 'mean', # or 'min'
# [auto, min, mean], auto๋ OOD ๋ฐ์ดํฐ ์์๋๋ง ๊ฐ๋ฅ
'distance_fallback_detection_threshold': 3.2,
# mean ํน์ min ์ ํ์ ์๊ณ๊ฐ
'softmax_fallback_detection_criteria': 'other',
# [auto, other], auto๋ OOD ๋ฐ์ดํฐ ์์๋๋ง ๊ฐ๋ฅ
'softmax_fallback_detection_threshold': 0.88,
# other ์ ํ์ fallback์ด ๋์ง ์๋ ์ต์ ๊ฐ
}
๊ทธ๋ฌ๋ ์ง๊ธ ๋ฒ์ ์์๋ ๊ฐ๊ธ์ OOD ๋ฐ์ดํฐ์
์ ์ถ๊ฐํด์ ์ด์ฉํด์ฃผ์ธ์.
์ ์์ผ์๋ฉด ์ ๊ฐ ๋ฐ๋ชจ ํด๋์ ๋ฃ์ด๋์ ๋ฐ์ดํฐ๋ผ๋ ๋ฃ์ด์ ์๋ํํด์ ์ฐ๋๊ฒ
ํจ์ฌ ์ฑ๋ฅ์ด ์ข์ต๋๋ค. ๋ช๋ช ๋น๋๋ค์ ์ด ์๊ณ์น๋ฅผ ์ง์ ์ ํ๊ฒ ํ๊ฑฐ๋ ๊ทธ๋ฅ ์์๋ก
fixํด๋๋๋ฐ, ๊ฐ์ธ์ ์ผ๋ก ์ด๊ฑธ ๊ทธ๋ฅ ์์๋ก fix ํด๋๊ฑฐ๋ ์ ์ ๋ณด๊ณ ์ง์ ์ ํ๊ฒ ํ๋๊ฑด
์ฑ๋ด ๋น๋๋ก์, ํน์ ํ๋ ์์ํฌ๋ก์ ๋ฌด์ฑ
์ํ ๊ฒ ์๋๊ฐ ์ถ์ต๋๋ค.
EntityRecongnizer
๋ ์ํฐํฐ ๊ฒ์ถ์ ๋ด๋นํ๋ Entity ๋ชจ๋ธ๋ค์ ํ์ต/ํ
์คํธ ์ํค๊ณ ์ถ๋ก ํ๋
ํด๋์ค์
๋๋ค. Entity ๊ฒ์ฌ์ ๊ฒฝ์ฐ ๋ฌธ์ฅ 1๊ฐ๋น ๋ผ๋ฒจ์ด ์ฌ๋ฌ๊ฐ(๋จ์ด ๊ฐฏ์์ ๋์ผ)์
๋๋ค.
๋ฌธ์ ๋ Outside ํ ํฐ์ธ 'O'๊ฐ ๋๋ถ๋ถ์ด๊ธฐ ๋๋ฌธ์ ์ ๋ถ๋ค 'O'๋ผ๊ณ ๋ง ์์ธกํด๋ ๊ฑฐ์ 90% ์ก๋ฐํ๋
์ ํ๋๊ฐ ๋์ค๊ฒ ๋ฉ๋๋ค. ๋ํ, ํจ๋์ํ์ฑํ ๋ถ๋ถ๋ 'O'๋ก ์ฒ๋ฆฌ ๋์ด์๋๋ฐ, ์ด ๋ถ๋ถ๋ ๋ง์๊ฒ์ผ๋ก
์๊ฐํ๊ณ Loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Kochat์ F1 Score, Recall, Precision ๋ฑ NER์ ์ฑ๋ฅ์ ๋ณด๋ค ์ ํํ๊ฒ ํ๊ฐ ํ ์ ์๋ ๊ฐ๋ ฅํ Validation ๋ฐ ์๊ฐํ ์ง์๊ณผ Loss ํจ์ ๊ณ์ฐ์ PAD๋ถ๋ถ์ masking์ ์ ์ฉํ ์ ์์ต๋๋ค. (mask ์ ์ฉ ์ฌ๋ถ ์ญ์ config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.) ์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import EntityRecognizer
from kochat.model import entity
from kochat.loss import CRFLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
rcn = EntityRecognizer(
model=entity.LSTM(dataset.intent_dict),
loss=CRFLoss(dataset.intent_dict)
# Conditional Random Field๋ฅผ Lossํจ์๋ก ์ง์ํฉ๋๋ค.
)
# ๋ชจ๋ธ ํ์ต
rcn.fit(dataset.load_entity(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ํฐํฐ ๊ฒ์ถ)
rcn.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
loss
ํจํค์ง๋ ์ฌ์ ์ ์๋ ๋ค์ํ built-in Loss ํจ์๋ค์ด ์ ์ฅ๋ ํจํค์ง์
๋๋ค.
ํ์ฌ ๋ฒ์ ์์๋ ์๋ ๋ชฉ๋ก์ ํด๋นํ๋ Loss ํจ์๋ค์ ์ง์ํฉ๋๋ค. ์ถํ ๋ฒ์ ์ด ์
๋ฐ์ดํธ ๋๋ฉด
์ง๊ธ๋ณด๋ค ํจ์ฌ ๋ค์ํ built-in Loss ํจ์๋ฅผ ์ง์ํ ์์ ์
๋๋ค. ์๋ ๋ชฉ๋ก์ ์ฐธ๊ณ ํ์ฌ ์ฌ์ฉํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
Intent Loss ํจ์๋ ๊ธฐ๋ณธ์ ์ธ CrossEntropyLoss์ ๋ค์ํ Distance ๊ธฐ๋ฐ์ Lossํจ์๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค. CrossEntropy๋ ํ์ ํ Softmax ๊ธฐ๋ฐ์ IntentClassifier์ ์ฃผ๋ก ํ์ฉํ๊ณ , Distance ๊ธฐ๋ฐ์ Loss ํจ์๋ค์ Distance ๊ธฐ๋ฐ์ IntentClassifier์ ํ์ฉํ ์ ์์ต๋๋ค. Distance ๊ธฐ๋ฐ์ Lossํจ์๋ค์ ์ปดํจํฐ ๋น์ ์์ญ (์ฃผ๋ก ์ผ๊ตด์ธ์) ๋ถ์ผ์์ ์ ์๋ ํจ์๋ค์ด์ง๋ง Intent ๋ถ๋ฅ์ Fallback ๋ํ ์ ์๋ ๋งค์ฐ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค.
from kochat.loss import CrossEntropyLoss
from kochat.loss import CenterLoss
from kochat.loss import GaussianMixture
from kochat.loss import COCOLoss
from kochat.loss import CosFace
# 1. ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ Cross Entropy Loss ํจ์์
๋๋ค.
cross_entropy = CrossEntropyLoss(label_dict=dataset.intent_dict)
# 2. Intra Class ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ขํ ์ ์๋ Center Loss ํจ์์
๋๋ค.
center_loss = CenterLoss(label_dict=dataset.intent_dict)
# 3. Intra Class ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ขํ ์ ์๋ Large Margin Gaussian Mixture Loss ํจ์์
๋๋ค.
lmgl = GaussianMixture(label_dict=dataset.intent_dict)
# 4. Inter Class ๊ฐ์ Cosine ๋ง์ง์ ํค์ธ ์ ์๋ COCO (Congenerous Cosine) Loss ํจ์์
๋๋ค.
coco_loss = COCOLoss(label_dict=dataset.intent_dict)
# 5. Inter Class ๊ฐ์ Cosine ๋ง์ง์ ํค์ธ ์ ์๋ Cosface (Large Margin Cosine) Lossํจ์์
๋๋ค.
cosface = CosFace(label_dict=dataset.intent_dict)
Entity Loss ํจ์๋ ๊ธฐ๋ณธ์ ์ธ CrossEntropyLoss์ ํ๋ฅ ์ ๋ชจ๋ธ์ธ
Conditional Random Field (์ดํ CRF) Loss๋ฅผ ์ง์ํฉ๋๋ค.
CRF Loss๋ฅผ ์ ์ฉํ๋ฉด, EntityRecognizer์ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ํ๋ฒ ๊ต์ ํ๋
ํจ๊ณผ๋ฅผ ๋ณผ ์ ์์ผ๋ฉฐ CRF Loss๋ฅผ ์ ์ฉํ๋ฉด, ์ถ๋ ฅ ๋์ฝ๋ฉ์ Viterbi ์๊ณ ๋ฆฌ์ฆ์
ํตํด ์ํํฉ๋๋ค.
from kochat.loss import CrossEntropyLoss
from kochat.loss import CRFLoss
# 1. ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ cross entropy ๋ก์ค ํจ์์
๋๋ค.
cross_entropy = CrossEntropyLoss(label_dict=dataset.intent_dict)
# 2. CRF Loss ํจ์์
๋๋ค.
center_loss = CRFLoss(label_dict=dataset.intent_dict)
Kochat์ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
Pytorch๋ก ์์ฑํ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ ํ์ต์ํค๊ธฐ๊ณ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์
์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ฝ ์ปค์คํ
๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด ์๋์ ๋ช๊ฐ์ง ๊ท์น์ ๋ฐ๋์
๋ฐ๋ผ์ผํฉ๋๋ค.
- forward ํจ์์์ ํด๋น loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
- compute_loss ํจ์์์ ๋ผ๋ฒจ๊ณผ ๋น๊ตํ์ฌ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
์๋์ ๊ตฌํ ์์ ๋ฅผ ๋ณด๋ฉด ๋์ฑ ์ฝ๊ฒ ์ดํดํ ์ ์์ต๋๋ค.
@intent
class CosFace(BaseLoss):
def __init__(self, label_dict: dict):
super(CosFace, self).__init__()
self.classes = len(label_dict)
self.centers = nn.Parameter(torch.randn(self.classes, self.d_loss))
def forward(self, feat: Tensor, label: Tensor) -> Tensor:
# 1. forward ํจ์์์ ํ์ฌ lossํจ์์ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
batch_size = feat.shape[0]
norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
nfeat = torch.div(feat, norms)
norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
ncenters = torch.div(self.centers, norms_c)
logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1))
y_onehot = torch.FloatTensor(batch_size, self.classes)
y_onehot.zero_()
y_onehot = Variable(y_onehot).cuda()
y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.cosface_m)
margin_logits = self.cosface_s * (logits - y_onehot)
return margin_logits
def compute_loss(self, label: Tensor, logits: Tensor, feats: Tensor, mask: nn.Module = None) -> Tensor:
# 2. compute loss์์ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
mlogits = self(feats, label)
# ์๊ธฐ ์์ ์ forward ํธ์ถ
return F.cross_entropy(mlogits, label)
@intent
class CenterLoss(BaseLoss):
def __init__(self, label_dict: dict):
super(CenterLoss, self).__init__()
self.classes = len(label_dict)
self.centers = nn.Parameter(torch.randn(self.classes, self.d_loss))
self.center_loss_function = CenterLossFunction.apply
def forward(self, feat: Tensor, label: Tensor) -> Tensor:
# 1. forward ํจ์์์ ํ์ฌ lossํจ์์ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
batch_size = feat.size(0)
feat = feat.view(batch_size, 1, 1, -1).squeeze()
if feat.size(1) != self.d_loss:
raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}"
.format(self.d_loss, feat.size(1)))
return self.center_loss_function(feat, label, self.centers)
def compute_loss(self, label: Tensor, logits: Tensor, feats: Tensor, mask: nn.Module = None) -> Tensor:
# 2. compute loss์์ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
nll_loss = F.cross_entropy(logits, label)
center_loss = self(feats, label)
# ์๊ธฐ ์์ ์ forward ํธ์ถ
return nll_loss + self.center_factor * center_loss
app
ํจํค์ง๋ kochat ๋ชจ๋ธ์ ์ ํ๋ฆฌ์ผ์ด์
์ผ๋ก ๋ฐฐํฌํ ์ ์๊ฒ๋ ํด์ฃผ๋
RESTful API์ธ KochatApi
ํด๋์ค์ API ํธ์ถ์ ๊ด๋ จ๋ ์๋๋ฆฌ์ค๋ฅผ
์์ฑํ ์ ์๊ฒ๋ ํ๋ Scenario
ํด๋์ค๋ฅผ ์ ๊ณตํฉ๋๋ค.
Scenario
ํด๋์ค๋ ์ด๋ค intent์์๋ ์ด๋ค entity๊ฐ ํ์ํ๊ณ ,
์ด๋ค api๋ฅผ ํธ์ถํ๋์ง ์ ์ํ๋ ์ผ์ข
์ ๋ช
์ธ์์ ๊ฐ์ต๋๋ค.
์๋๋ฆฌ์ค ์์ฑ์ ์๋์ ๊ฐ์ ๋ช๊ฐ์ง ์ฃผ์์ฌํญ์ด ์์ต๋๋ค.
- intent๋ ๋ฐ๋์ raw๋ฐ์ดํฐ ํ์ผ ๋ช ๊ณผ ๋์ผํ๊ฒ ์ค์ ํ๊ธฐ
- api๋ ํจ์ ๊ทธ ์์ฒด๋ฅผ ๋ฃ์ต๋๋ค (๋ฐ๋์ callable ํด์ผํฉ๋๋ค.)
- scenario ๋์ ๋๋ฆฌ ์ ์์์ KEY๊ฐ์ api ํจ์์ ์์/์ฒ ์๊ฐ ๋์ผํด์ผํฉ๋๋ค.
- scenario ๋์ ๋๋ฆฌ ์ ์์์ KEY๊ฐ์ config์ NER_categories์ ์ ์๋ ์ํฐํฐ๋ง ํ์ฉ๋ฉ๋๋ค.
- ๊ธฐ๋ณธ๊ฐ(default) ์ค์ ์ ์ํ๋ฉด scenario ๋์
๋๋ฆฌ์ ๋ฆฌ์คํธ์ ๊ฐ์ ์ฒจ๊ฐํฉ๋๋ค.
- kocrawl (๋ ์จ) ์์
from kochat.app import Scenario
from kocrawl.weather import WeatherCrawler
# kocrawl์ kochat์ ๋ง๋ค๋ฉด์ ํจ๊ป ๊ฐ๋ฐํ ํฌ๋กค๋ฌ์
๋๋ค.
# (https://github.com/gusdnd852/kocrawl)
# 'pip install kocrawl'๋ก ์์ฝ๊ฒ ์ค์นํ ์ ์์ต๋๋ค.
weather_scenario = Scenario(
intent='weather', # intent๋ ์ธํ
ํธ ๋ช
์ ์ ์ต๋๋ค (raw ๋ฐ์ดํฐ ํ์ผ๋ช
๊ณผ ๋์ผํด์ผํฉ๋๋ค)
api=WeatherCrawler().request, # API๋ ํจ์ ์ด๋ฆ ์์ฒด๋ฅผ ๋ฃ์ต๋๋ค. (callableํด์ผํฉ๋๋ค)
scenario={
'LOCATION': [],
# ๊ธฐ๋ณธ์ ์ผ๋ก 'KEY' : []์ ํํ๋ก ๋ง๋ญ๋๋ค.
'DATE': ['์ค๋']
# entity๊ฐ ๊ฒ์ถ๋์ง ์์์ ๋ default ๊ฐ์ ์ง์ ํ๊ณ ์ถ์ผ๋ฉด ๋ฆฌ์คํธ ์์ ์ํ๋ ๊ฐ์ ๋ฃ์ต๋๋ค.
# [์ ์ฃผ, ๋ ์จ, ์๋ ค์ค] => [S-LOCATION, O, O] => api('์ค๋', S-LOCATION) call
# ๋ง์ฝ ['์ค๋', 'ํ์ฌ']์ฒ๋ผ 2๊ฐ ์ด์์ default๋ฅผ ๋ฃ์ผ๋ฉด ๋๋ค์ผ๋ก ์ ํํด์ default ๊ฐ์ผ๋ก ์ง์ ํฉ๋๋ค.
}
# ์๋๋ฆฌ์ค ๋์
๋๋ฆฌ๋ฅผ ์ ์ํฉ๋๋ค.
# ์ฃผ์์ 1 : scenario ํค๊ฐ(LOCATION, DATE)์ ์์๋ API ํจ์์ ํ๋ผ๋ฏธํฐ ์์์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฃผ์์ 2 : scenario ํค๊ฐ(LOCATION, DATE)์ ์ฒ ์๋ API ํจ์์ ํ๋ผ๋ฏธํฐ ์ฒ ์์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฃผ์์ 3 : raw ๋ฐ์ดํฐ ํ์ผ์ ๋ผ๋ฒจ๋งํ ์ํฐํฐ๋ช
๊ณผ scenario ํค๊ฐ์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฆ config์ NER_categories์ ๋ฏธ๋ฆฌ ์ ์๋ ์ํฐํฐ๋ง ์ฌ์ฉํ์
์ผํฉ๋๋ค.
# B-, I- ๋ฑ์ BIOํ๊ทธ๋ ์๋ตํฉ๋๋ค. (S-DATE โ DATE๋ก ์๊ฐ)
# ๋/์๋ฌธ์๊น์ง ๋์ผํ ํ์๋ ์๊ณ , ์ฒ ์๋ง ๊ฐ์ผ๋ฉด ๋ฉ๋๋ค. (๋ชจ๋ lowercase ์ํ์์ ๋น๊ต)
# ๋ค์ ๊ท์ฐฎ๋๋ผ๋ ์ ํํ ๊ฐ ์ ๋ฌ์ ์ํด ์ผ๋ถ๋ฌ ๋ง๋ ์ธ ๊ฐ์ง ์ ํ์ฌํญ์ด๋ ๋ฐ๋ผ์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
# WeatherCrawler().request์ ํ๋ผ๋ฏธํฐ๋ WeatherCrawler().request(location, date)์
๋๋ค.
# APIํ๋ผ๋ฏธํฐ์ ์์/์ด๋ฆ์ด ๋์ผํ๋ฉฐ, ๋ฐ๋ชจ ๋ฐ์ดํฐ ํ์ผ์ ์๋ ์ํฐํฐ์ธ LOCATION, DATE์ ๋์ผํฉ๋๋ค.
# ๋ง์ฝ ํ๋ฆฌ๋ฉด ์ด๋์ ํ๋ ธ๋์ง ์๋ฌ ๋ฉ์์ง๋ก ์๋ ค๋๋ฆฝ๋๋ค.
)
- ๋ ์คํ ๋ ์์ฝ ์๋๋ฆฌ์ค
from kochat.app import Scenario
reservation_scenario = Scenario(
intent='reservation',
api=reservation_check,
# reservation_check(num_people, reservation_time)์ ๊ฐ์
# ํจ์๋ฅผ ํธ์ถํ์ง ๋ง๊ณ ๊ทธ ์์ฒด๋ฅผ ํ๋ผ๋ฏธํฐ๋ก ์
๋ ฅํฉ๋๋ค.
# ํจ์๋ฅผ ๋ฐ์์ ์ ์ฅํด๋๋ค๊ฐ ์์ฒญ ๋ฐ์์ Api ๋ด๋ถ์์ call ํฉ๋๋ค
scenario={
'NUM_PEOPLE': [4],
# NUM_PEOPLE์ default๋ฅผ 4๋ช
์ผ๋ก ์ค์ ํ์ต๋๋ค.
'RESERVATION_TIME': []
# API(reservation_check(num_people, reservation_time)์ ํ๋ผ๋ฏธํฐ์ ์์/์ฒ ์๊ฐ ์ผ์นํฉ๋๋ค.
# ์ด ๋, ๋ฐ๋์ NER_categories์ NUM_PEOPLE๊ณผ RESERVATION_TIME์ด ์ ์๋์ด ์์ด์ผํ๋ฉฐ,
# ์ค์ raw๋ฐ์ดํฐ์ ๋ผ๋ฒจ๋ง๋ ๋ ์ด๋ธ๋ ์์ ์ด๋ฆ์ ์ฌ์ฉํด์ผํฉ๋๋ค.
}
)
KochatApi
๋ Flask๋ก ๊ตฌํ๋์์ผ๋ฉฐ restful api๋ฅผ ์ ๊ณตํ๋ ํด๋์ค์
๋๋ค.
์ฌ์ค ์๋ฒ๋ก ๊ตฌ๋ํ ๊ณํ์ด๋ผ๋ฉด ์์์ ์ค๋ช
ํ ๊ฒ ๋ณด๋ค ํจ์ฌ ์ฝ๊ฒ ํ์ตํ ์ ์์ต๋๋ค.
(ํ์ต์ ๋ง์ ๋ถ๋ถ๋ค์ด KochatApi
์์ ์๋ํ ๋๊ธฐ ๋๋ฌธ์ ํ๋ผ๋ฏธํฐ ์ ๋ฌ๋ง์ผ๋ก ํ์ต์ด ๊ฐ๋ฅํฉ๋๋ค.)
KochatApi
ํด๋์ค๋ ์๋์ ๊ฐ์ ๋ฉ์๋๋ค์ ์ง์ํ๋ฉฐ ์ฌ์ฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
from kochat.app import KochatApi
# kochat api ๊ฐ์ฒด๋ฅผ ์์ฑํฉ๋๋ค.
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb, True), # ์๋ฒ ๋ฉ ํ๋ก์ธ์, ํ์ต์ฌ๋ถ
intent_classifier=(clf, True), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ, ํ์ต์ฌ๋ถ
entity_recognizer=(rcn, True), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ, ํ์ต์ฌ๋ถ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# kochat.app์ FLask ๊ฐ์ฒด์
๋๋ค.
# Flask์ ์ฌ์ฉ๋ฒ๊ณผ ๋์ผํ๊ฒ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
@kochat.app.route('/')
def index():
return render_template("index.html")
# ์ ํ๋ฆฌ์ผ์ด์
์๋ฒ๋ฅผ ๊ฐ๋ํฉ๋๋ค.
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
์์ ๊ฐ์ด kochat ์๋ฒ๋ฅผ ์คํ์ํฌ ์ ์์ต๋๋ค. (์ฌ๋งํ๋ฉด ์์ ๊ฐ์ด template๊ณผ static์ ๋ช ์์ ์ผ๋ก ์ ์ด์ฃผ์ธ์.) ์ ์์์ฒ๋ผ ๋ทฐ๋ฅผ ์ง์ ์๋ฒ์ ์ฐ๊ฒฐํด์ ํ๋์ ์๋ฒ์์ ๋ทฐ์ ๋ฅ๋ฌ๋ ์ฝ๋๋ฅผ ๋ชจ๋ ๊ตฌ๋์ํฌ ์๋ ์๊ณ , ๋ง์ฝ Micro Service Architecture๋ฅผ ๊ตฌ์ฑํด์ผํ๋ค๋ฉด, ์ฑ๋ด ์๋ฒ์ index route ('/')๋ฑ์ ์ค์ ํ์ง ์๊ณ ๋ฅ๋ฌ๋ ๋ฐฑ์๋ ์๋ฒ๋ก๋ ์ถฉ๋ถํ ํ์ฉํ ์ ์์ต๋๋ค. ๋ง์ฝ ํ์ต์ ์ํ์ง ์์ ๋๋ ์๋์ฒ๋ผ ๊ตฌํํฉ๋๋ค.
# 1. Tuple์ ๋๋ฒ์งธ ์ธ์์ False ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb, False), # ์๋ฒ ๋ฉ ํ๋ก์ธ์, ํ์ต์ฌ๋ถ
intent_classifier=(clf, False), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ, ํ์ต์ฌ๋ถ
entity_recognizer=(rcn, False), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ, ํ์ต์ฌ๋ถ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# 2. Tuple์ ํ๋ก์ธ์๋ง ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb), # ์๋ฒ ๋ฉ ํ๋ก์ธ์
intent_classifier=(clf), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ
entity_recognizer=(rcn), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# 3. ๊ทธ๋ฅ ํ๋ก์ธ์๋ง ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=emb, # ์๋ฒ ๋ฉ ํ๋ก์ธ์
intent_classifier=clf, # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ
entity_recognizer=rcn, # ์ํฐํฐ ๊ฒ์ถ๊ธฐ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
์๋์์๋ Kochat ์๋ฒ์ url ํจํด์ ๋ํด ์์ธํ๊ฒ ์ค๋ช ํฉ๋๋ค. ํ์ฌ kochat api๋ ๋ค์๊ณผ ๊ฐ์ 4๊ฐ์ url ํจํด์ ์ง์ํ๋ฉฐ, ์ด url ํจํด๋ค์ config์ API ์ฑํฐ์์ ๋ณ๊ฒฝ ๊ฐ๋ฅํฉ๋๋ค.
API = {
'request_chat_url_pattern': 'request_chat', # request_chat ๊ธฐ๋ฅ url pattern
'fill_slot_url_pattern': 'fill_slot', # fill_slot ๊ธฐ๋ฅ url pattern
'get_intent_url_pattern': 'get_intent', # get_intent ๊ธฐ๋ฅ url pattern
'get_entity_url_pattern': 'get_entity' # get_entity ๊ธฐ๋ฅ url pattern
}
๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ํจํด์ธ request_chat์
๋๋ค. intent๋ถ๋ฅ, entity๊ฒ์ถ, api์ฐ๊ฒฐ์ ํ๋ฒ์ ์งํํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://youripaddress/request_chat//
case 1. state SUCCESS
๋ชจ๋ entity๊ฐ ์ ์์ ์ผ๋ก ์
๋ ฅ๋ ๊ฒฝ์ฐ state 'SUCCESS'๋ฅผ ๋ฐํํฉ๋๋ค.
>>> ์ ์ gusdnd852 : ๋ชจ๋ ๋ถ์ฐ ๋ ์จ ์ด๋
https://123.456.789.000:1234/request_chat/gusdnd852/๋ชจ๋ ๋ถ์ฐ ๋ ์จ ์ด๋
โ {
'input': [๋ชจ๋ , ๋ถ์ฐ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-DATE, S-LOCATION, O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
case 2. state REQUIRE_XXX
๋ง์ฝ default๊ฐ์ด ์๋ ์ํฐํฐ๊ฐ ์
๋ ฅ๋์ง ์์ ๊ฒฝ์ฐ state 'REQUIRE_XXX'๋ฅผ ๋ฐํํฉ๋๋ค.
๋๊ฐ ์ด์์ ์ํฐํฐ๊ฐ ๋ชจ์๋ผ๋ฉด state 'REQUIRE_XXX_YYY'๊ฐ ๋ฐํ๋ฉ๋๋ค.
>>> ์ ์ minqukanq : ๋ชฉ์์ผ ๋ ์จ ์ด๋
e.g. https://123.456.789.000:1234/request_chat/minqukanq/๋ชฉ์์ผ ๋ ์จ ์ด๋
โ {
'input': [๋ชฉ์์ผ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-DATE, O, O]
'state': 'REQUIRE_LOCATION',
'answer': None
}
case 3. state FALLBACK
์ธํ
ํธ ๋ถ๋ฅ์ FALLBACK์ด ๋ฐ์ํ๋ฉด FALLBACK์ ๋ฐํํฉ๋๋ค.
>>> ์ ์ sangji11 : ๋ชฉ์์ผ ์น๊ตฌ ์์ผ์ด๋ค
e.g. https://123.456.789.000:1234/request_chat/sangji11/๋ชฉ์์ผ ์น๊ตฌ ์์ผ์ด๋ค
โ {
'input': [๋ชฉ์์ผ, ์น๊ตฌ, ์์ผ์ด๋ค],
'intent': 'FALLBACK',
'entity': [S-DATE, O, O]
'state': 'FALLBACK',
'answer': None
}
๊ฐ์ฅ request์ REQUIRE_XXX๊ฐ ๋์ฌ๋, ์ฌ์ฉ์์๊ฒ ๋๋ฌป๊ณ ๊ธฐ์กด ๋์
๋๋ฆฌ์ ์ถ๊ฐํด์ api๋ฅผ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://youripaddress/fill_slot//
>>> ์ ์ gusdnd852 : ๋ชจ๋ ๋ ์จ ์๋ ค์ค โ REQUIRE_LOCATION
>>> ๋ด : ์ด๋ ์ง์ญ์ ์๋ ค๋๋ฆด๊น์?
>>> ์ ์ gusdnd852 : ๋ถ์ฐ
https://123.456.789.000:1234/fill_slot/gusdnd852/๋ถ์ฐ
โ {
'input': [๋ถ์ฐ] + [๋ชจ๋ , ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-LOCATION] + [S-DATE, O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
>>> ์ ์ gusdnd852 : ๋ ์จ ์๋ ค์ค โ REQUIRE_DATE_LOCATION
>>> ๋ด : ์ธ์ ์ ์ด๋ ์ง์ญ์ ๋ ์จ๋ฅผ ์๋ ค๋๋ฆด๊น์?
>>> ์ ์ gusdnd852 : ๋ถ์ฐ ๋ชจ๋
https://123.456.789.000:1234/fill_slot/gusdnd852/๋ถ์ฐ ๋ชจ๋
โ {
'input': [๋ถ์ฐ, ๋ชจ๋ ] + [๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-LOCATION, S-DATE] + [O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
intent๋ง ์๊ณ ์ถ์๋ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://youripaddress/get_intent/
https://123.456.789.000:1234/get_intent/์ ์ฃผ ๋ ์จ ์ด๋
โ {
'input': [์ ์ฃผ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': None,
'state': 'REQUEST_INTENT',
'answer': None
}
entity๋ง ์๊ณ ์ถ์๋ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://youripaddress/get_entity/
https://123.456.789.000:1234/get_entity/์ ์ฃผ ๋ ์จ ์ด๋
โ {
'input': [์ ์ฃผ, ๋ ์จ, ์ด๋],
'intent': None,
'entity': [S-LOCATION, O, O],
'state': 'REQUEST_ENTITY',
'answer': None
}
Kochat์ ์๋์ ๊ฐ์ด ๋ค์ํ ์๊ฐํ ๊ธฐ๋ฅ์ ์ง์ํฉ๋๋ค.
Feature Space๋ ์ผ์ Epoch๋ง๋ค ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅ๋๊ณ ,
๊ทธ ์ธ์ ์๊ฐํ ์๋ฃ๋ ๋งค Epoch๋ง๋ค ๊ณ์ ์
๋ฐ์ดํธ ๋๋ฉฐ
"root/saved"์ ๋ชจ๋ธ ์ ์ฅํ์ผ๊ณผ ํจ๊ป ์ ์ฅ๋ฉ๋๋ค.
์๊ฐํ ์๋ฃ ๋ฐ ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก๋
config์์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
Confusion Matrix์ ๊ฒฝ์ฐ๋ X์ถ(์๋)๊ฐ Prediction, Y์ถ(์ผ์ชฝ)์ด Label์
๋๋ค.
๋ค์ ๋ฒ์ ์์ xticks์ yticks๋ฅผ ์ถ๊ฐํ ์์ ์
๋๋ค.
Accuracy, Precision, Recall, F1 Score ๋ฑ ๋ชจ๋ธ์ ๋ค์ํ ๋ฉํธ๋ฆญ์ผ๋ก ํ๊ฐํ๊ณ , ํ ํํ๋ก ์ด๋ฏธ์งํ์ผ์ ๋ง๋ค์ด์ค๋๋ค.
์์์ ๋ช๋ฒ์งธ ๊น์ง ๋ฐ์ฌ๋ฆผํด์ ๋ณด์ฌ์ค์ง config์์ ์ค์ ํ ์ ์์ต๋๋ค.
PROC = {
# ...(์๋ต)
'logging_precision': 5, # ๊ฒฐ๊ณผ ์ ์ฅ์ ๋ฐ์ฌ๋ฆผ ์์์ n๋ฒ์งธ์์ ๋ฐ์ฌ๋ฆผ
}
Fallback Detection์ Intent Classification์ ์์ญ์ ๋๋ค. Intent Classification๋ง ์ง์ํฉ๋๋ค. (Fallback Detection ์ฑ๋ฅ ํ๊ฐ๋ฅผ ์ํด์๋ ๋ฐ๋์ ood=True์ฌ์ผํฉ๋๋ค.)
Feature Space๋ Distance ๊ธฐ๋ฐ์ Metric Learning Lossํจ์๊ฐ ์ ์๋ํ๊ณ ์๋์ง ํ์ธํ๊ธฐ ์ํ๊ฒ์ผ๋ก Intent Classification๋ง ์ง์ํฉ๋๋ค. ๋ํ ์๊ฐํ ์ฐจ์์ config์ d_loss์ ๋ฐ๋ผ ๊ฒฐ์ ๋ฉ๋๋ค.
- d_loss = 2์ธ ๊ฒฝ์ฐ : 2์ฐจ์์ผ๋ก ์๊ฐํ
- d_loss = 3์ธ ๊ฒฝ์ฐ : 3์ฐจ์์ผ๋ก ์๊ฐํ
- d_loss > 3์ธ ๊ฒฝ์ฐ : Incremetal PCA๋ฅผ ํตํด 3์ฐจ์์ผ๋ก ์ฐจ์ ๊ฐ์ ํ ์๊ฐํ
Feature Space Visualization์ PCA๋ฅผ ์คํํ๊ธฐ ๋๋ฌธ์ ๋น์ฉ์ด ์๋นํ ํฝ๋๋ค. ๋ค๋ฅธ ์๊ฐํ๋ ๋งค Epoch๋ง๋ค ์ํํ์ง๋ง, Feature Space Visulization์ ๋ช Epoch๋ง๋ค ์ํํ ์ง ๊ฒฐ์ ํ ์ ์์ต๋๋ค.
PROC = {
# ...(์๋ต)
'visualization_epoch': 50, # ์๊ฐํ ๋น๋ (์ ํญ๋ง๋ค ์๊ฐํ ์ํ)
}
์ด ์ฑํฐ๋ Kochat์ ๋ค์ํ ์ฑ๋ฅ ์ด์์ ๋ํด ๊ธฐ๋กํฉ๋๋ค.
์ฌ์ค CenterLoss๋ CosFace ๊ฐ์ Margin Lossํจ์๋ค์ด ์ปดํจํฐ ๋น์ ์ ์ผ๊ตด์ธ์ ์์ญ์์
๋ง์ด ์ฐ์ธ๋ค๊ณ ๋ ํ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ชจ๋ Retrieval ๋ฌธ์ ์ ์ ์ฉํ ์ ์๋ Lossํจ์์
๋๋ค.
Kochat์ DistanceClassifier๋ ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ์ Retrieval์ ์ํํ๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ
Lossํจ์๋ฅผ ๋งค์ฐ ํจ๊ณผ์ ์ผ๋ก ํ์ฉํ ์ ์์ต๋๋ค. ์ค์ ๋ก ๋ฐ๋ชจ ๋ฐ์ดํฐ์
์ ์ ์ฉํ์ ๋
CrossEntropyLoss๋ก๋ 70% ์ธ์ ๋ฆฌ์ธ FallbackDetection ์ฑ๋ฅ์ด CenterLoss, CosFace
๋ฑ์ ์ ์ฉํ๋ฉด 90~95%๊น์ง ํฅ์๋์์ต๋๋ค. (120๊ฐ์ OOD ์ํ ํ
์คํธ)
- SoftmaxClassifier + CrossEntropyLoss + CNN (d_model=512, layers=1)
- DistanceClassifier + CrossEntropyLoss + CNN (d_model=512, layers=1)
- DistanceClassifier + CenterLoss + CNN (d_model=512, layers=1)
Retrieval ๊ธฐ๋ฐ์ Distance Classification์ ๊ฒฝ์ฐ LSTM๋ณด๋ค CNN์ Feature๋ค์ด
ํด๋์ค๋ณ๋ก ํจ์ฌ ์ ๊ตฌ๋ถ๋๋ ๊ฒ์ ํ์ธํ์ต๋๋ค. Feature Extraction ๋ฅ๋ ฅ ์์ฒด๋
CNN์ด ์ข๋ค๊ณ ์๋ ค์ง ๊ฒ์ฒ๋ผ ์๋ฌด๋๋ CNN์ด Feature๋ฅผ ๋ ์ ๋ฝ์๋ด๋ ๊ฒ ๊ฐ์ต๋๋ค.
Feature Space์์ ๊ตฌ๋ถ์ด ์ ๋๋ค๋ ๊ฒ์ OOD ์ฑ๋ฅ์ด ์ฐ์ํ๋ค๋ ๊ฒ๊ณผ ๋์น์ด๋ฏ๋ก,
DistanceClassifier ์ฌ์ฉ์ LSTM๋ณด๋จ CNN์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋์ฑ ๋ฐ๋์งํด๋ณด์
๋๋ค.
- ์ข : LSTM (d_model=512, layers=1) + CosFace, 500 Epoch ํ์ต (์๋ ดํจ)
- ์ฐ : CNN (d_model=512, layers=1) + CosFace, 500 Epoch ํ์ต (์๋ ดํจ)
EntityRecognizer์ ๊ฒฝ์ฐ ๋์ผ ์ฌ์ด์ฆ, ๋์ผ Layer์์ CRF Loss๋ฅผ ์ฌ์ฉํ๋ฉด
ํ์คํ ์ฑ๋ฅ์ ๋์ฑ ์ฐ์ํด์ง๋, ์กฐ๊ธ ๋ ๋ ๋๋ฆฌ๊ฒ ์๋ ดํ๋ ๊ฒ์ ํ์ธํ์ต๋๋ค.
CRF Loss์ ๊ฒฝ์ฐ ์กฐ๊ธ ๋ ๋ง์ ํ์ต ์๊ฐ์ ์ค์ผ ์ ์ฑ๋ฅ์ ๋ด๋ ๊ฒ ๊ฐ์ต๋๋ค.
- ์ข : LSTM (d_model=512, layers=1) + CrossEntropy โ Epoch 300์ f1-score 90% ๋๋ฌ
- ์ฐ : LSTM (d_model=512, layers=1) + CRFLoss โ Epoch 450์ f1-score 90% ๋๋ฌ
Fallback Detector๋ sklearn ๋ชจ๋ธ๋ค์ ํ์ฉํ๋๋ฐ ๊ธฐ์กด sklearn๋ชจ๋ธ๋ค์
max_iter์ default๊ฐ์ด 100์ผ๋ก ์ค์ ๋์ด ์๋ ดํ๊ธฐ ์ ์ ํ์ต์ด ๋๋๋ฒ๋ฆฝ๋๋ค.
๋๋ฌธ์ Fallback Detector๋ฅผ config์ ์ ์ํ ๋ max_iter๋ฅผ ๋๊ฒ ์ค์ ํด์ผ
์ถฉ๋ถํ ํ์ต์๊ฐ์ ๋ณด์ฅ๋ฐ์ ์ ์์ต๋๋ค.
์ด ์ฑํฐ์์๋ Demo ์ ํ๋ฆฌ์ผ์ด์
์ ๋ํด ์๊ฐํฉ๋๋ค.
๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์
์ ์ฌํ์ ๋ณด๋ฅผ ์๊ฐํ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ผ๋ก,
๋ ์จ, ๋ฏธ์ธ๋จผ์ง, ๋ง์ง ์ฌํ์ง ์ ๋ณด๋ฅผ ์๋ ค์ฃผ๋ ๊ธฐ๋ฅ์ ๋ณด์ ํ๊ณ ์์ต๋๋ค.
Api๋ Kochat์ ๋ง๋ค๋ฉด์ ํจ๊ป ๋ง๋ Kocrawl
์ ์ฌ์ฉํ์ต๋๋ค.
Html๊ณผ CSS๋ฅผ ์ฌ์ฉํ์ฌ View๋ฅผ ๊ตฌํํ์์ต๋๋ค. ์ ๊ฐ ๋์์ธ ํ ๊ฒ์ ์๋๊ณ ์ฌ๊ธฐ ์์ ์ ๊ณต๋๋ ๋ถํธ์คํธ๋ฉ ํ ๋ง๋ฅผ ์ฌ์ฉํ์์ต๋๋ค.
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Kochat ๋ฐ๋ชจ</title>
<script src="{{ url_for('static', filename="js/jquery.js") }}" type="text/javascript"></script>
<script src="{{ url_for('static', filename="js/bootstrap.js") }}" type="text/javascript"></script>
<script src="{{ url_for('static', filename="js/main.js") }}" type="text/javascript"></script>
<link href="{{ url_for('static', filename="css/bootstrap.css") }}" rel="stylesheet" id="bootstrap-css">
<link href="{{ url_for('static', filename="css/main.css") }}" rel="stylesheet" id="main-css">
<script>
greet();
onClickAsEnter();
</script>
</head>
<body>
<div class="chat_window">
<div class="top_menu">
<div class="buttons">
<div class="button close_button"></div>
<div class="button minimize"></div>
<div class="button maximize"></div>
</div>
<div class="title">Kochat ๋ฐ๋ชจ</div>
</div>
<ul class="messages"></ul>
<div class="bottom_wrapper clearfix">
<div class="message_input_wrapper">
<input class="message_input"
onkeyup="return onClickAsEnter(event)"
placeholder="๋ด์ฉ์ ์
๋ ฅํ์ธ์."/>
</div>
<div class="send_message"
id="send_message"
onclick="onSendButtonClicked()">
<div class="icon"></div>
<div class="text">๋ณด๋ด๊ธฐ</div>
</div>
</div>
</div>
<div class="message_template">
<li class="message">
<div class="avatar"></div>
<div class="text_wrapper">
<div class="text"></div>
</div>
</li>
</div>
</body>
</html>
์๋์ ๊ฐ์ ๋ชจ๋ธ ๊ตฌ์ฑ์ ์ฌ์ฉํ์์ต๋๋ค.
dataset = Dataset(ood=True)
emb = GensimEmbedder(model=embed.FastText())
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
rcn = EntityRecognizer(
model=entity.LSTM(dataset.entity_dict),
loss=CRFLoss(dataset.entity_dict)
)
kochat = KochatApi(
dataset=dataset,
embed_processor=(emb, True),
intent_classifier=(clf, True),
entity_recognizer=(rcn, True),
scenarios=[
weather, dust, travel, restaurant
]
)
@kochat.app.route('/')
def index():
return render_template("index.html")
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
Kocrawl์ ์ด์ฉํด 4๊ฐ์ง ์๋์ ๋ง๋ ์๋๋ฆฌ์ค๋ฅผ ๊ตฌ์ฑํ์์ต๋๋ค.
weather = Scenario(
intent='weather',
api=WeatherCrawler().request,
scenario={
'LOCATION': [],
'DATE': ['์ค๋']
}
)
dust = Scenario(
intent='dust',
api=DustCrawler().request_debug,
scenario={
'LOCATION': [],
'DATE': ['์ค๋']
}
)
restaurant = Scenario(
intent='restaurant',
api=RestaurantCrawler().request,
scenario={
'LOCATION': [],
'RESTAURANT': ['์ ๋ช
ํ']
}
)
travel = Scenario(
intent='travel',
api=MapCrawler().request_debug,
scenario={
'LOCATION': [],
'PLACE': ['๊ด๊ด์ง']
}
)
๋ง์ง๋ง์ผ๋ก ๋ฒํผ์ ๋๋ฅด๋ฉด ๋ฉ์์ง๊ฐ ๋์์ง๋ ์ ๋๋ฉ์ด์ ๊ณผ Ajax๋ฅผ ํตํด Kochat ์๋ฒ์ ํต์ ํ๋ ์์ค์ฝ๋๋ฅผ ์์ฑํ์์ต๋๋ค. ๊ฐ๋จํ chit chat ๋ํ 3๊ฐ์ง (์๋ , ๊ณ ๋ง์, ์์ด)๋ ๊ท์น๊ธฐ๋ฐ์ผ๋ก ๊ตฌํํ์์ต๋๋ค. ์ถํ์ Seq2Seq ๊ธฐ๋ฅ์ ์ถ๊ฐํ์ฌ ์ด ๋ถ๋ถ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฐ์ผ๋ก ๋ณ๊ฒฝํ ์์ ์ ๋๋ค.
// variables
let userName = null;
let state = 'SUCCESS';
// functions
function Message(arg) {
this.text = arg.text;
this.message_side = arg.message_side;
this.draw = function (_this) {
return function () {
let $message;
$message = $($('.message_template').clone().html());
$message.addClass(_this.message_side).find('.text').html(_this.text);
$('.messages').append($message);
return setTimeout(function () {
return $message.addClass('appeared');
}, 0);
};
}(this);
return this;
}
function getMessageText() {
let $message_input;
$message_input = $('.message_input');
return $message_input.val();
}
function sendMessage(text, message_side) {
let $messages, message;
$('.message_input').val('');
$messages = $('.messages');
message = new Message({
text: text,
message_side: message_side
});
message.draw();
$messages.animate({scrollTop: $messages.prop('scrollHeight')}, 300);
}
function greet() {
setTimeout(function () {
return sendMessage("Kochat ๋ฐ๋ชจ์ ์ค์ ๊ฑธ ํ์ํฉ๋๋ค.", 'left');
}, 1000);
setTimeout(function () {
return sendMessage("์ฌ์ฉํ ๋๋ค์์ ์๋ ค์ฃผ์ธ์.", 'left');
}, 2000);
}
function onClickAsEnter(e) {
if (e.keyCode === 13) {
onSendButtonClicked()
}
}
function setUserName(username) {
if (username != null && username.replace(" ", "" !== "")) {
setTimeout(function () {
return sendMessage("๋ฐ๊ฐ์ต๋๋ค." + username + "๋. ๋๋ค์์ด ์ค์ ๋์์ต๋๋ค.", 'left');
}, 1000);
setTimeout(function () {
return sendMessage("์ ๋ ๊ฐ์ข
์ฌํ ์ ๋ณด๋ฅผ ์๋ ค์ฃผ๋ ์ฌํ๋ด์
๋๋ค.", 'left');
}, 2000);
setTimeout(function () {
return sendMessage("๋ ์จ, ๋ฏธ์ธ๋จผ์ง, ์ฌํ์ง, ๋ง์ง ์ ๋ณด์ ๋ํด ๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์!", 'left');
}, 3000);
return username;
} else {
setTimeout(function () {
return sendMessage("์ฌ๋ฐ๋ฅธ ๋๋ค์์ ์ด์ฉํด์ฃผ์ธ์.", 'left');
}, 1000);
return null;
}
}
function requestChat(messageText, url_pattern) {
$.ajax({
url: "http://your_server_address:8080/" + url_pattern + '/' + userName + '/' + messageText,
type: "GET",
dataType: "json",
success: function (data) {
state = data['state'];
if (state === 'SUCCESS') {
return sendMessage(data['answer'], 'left');
} else if (state === 'REQUIRE_LOCATION') {
return sendMessage('์ด๋ ์ง์ญ์ ์๋ ค๋๋ฆด๊น์?', 'left');
} else {
return sendMessage('์ฃ์กํฉ๋๋ค. ๋ฌด์จ๋ง์ธ์ง ์ ๋ชจ๋ฅด๊ฒ ์ด์.', 'left');
}
},
error: function (request, status, error) {
console.log(error);
return sendMessage('์ฃ์กํฉ๋๋ค. ์๋ฒ ์ฐ๊ฒฐ์ ์คํจํ์ต๋๋ค.', 'left');
}
});
}
function onSendButtonClicked() {
let messageText = getMessageText();
sendMessage(messageText, 'right');
if (userName == null) {
userName = setUserName(messageText);
} else {
if (messageText.includes('์๋
')) {
setTimeout(function () {
return sendMessage("์๋
ํ์ธ์. ์ ๋ Kochat ์ฌํ๋ด์
๋๋ค.", 'left');
}, 1000);
} else if (messageText.includes('๊ณ ๋ง์')) {
setTimeout(function () {
return sendMessage("์ฒ๋ง์์. ๋ ๋ฌผ์ด๋ณด์ค ๊ฑด ์๋์?", 'left');
}, 1000);
} else if (messageText.includes('์์ด')) {
setTimeout(function () {
return sendMessage("๊ทธ๋ ๊ตฐ์. ์๊ฒ ์ต๋๋ค!", 'left');
}, 1000);
} else if (state.includes('REQUIRE')) {
return requestChat(messageText, 'fill_slot');
} else {
return requestChat(messageText, 'request_chat');
}
}
}
์์ ๊ฐ์ด ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์
์ด ์ ์คํ๋ฉ๋๋ค.
๊ทธ๋ฌ๋ ๋ฐ๋ชจ ๋ฐ์ดํฐ์
์ ์์ด ์ ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ ์ง๋ช
์ด๋ ๋ชจ๋
์์, ๋ชจ๋ ์ฌํ์ง ๋ฑ์ ์์ ๋ฃ์ง ๋ชปํฉ๋๋ค.
์ค์ ๋ก ๋ชจ๋ ๋์๋ ๋ชจ๋ ์์ ๋ฑ์ ์์ ๋ค์ ์ ๋๋ก
๋ํ๋ฅผ ๋๋๋ ค๋ฉด ๋ฐ๋ชจ ๋ฐ์ดํฐ์
๋ณด๋ค ๋ง์
๋ฐ์ดํฐ๋ฅผ ์ฝ์
ํ์
์ผ ๋์ฑ ์ข์ ์ฑ๋ฅ์ ๊ธฐ๋ํ ์ ์์ ๊ฒ์
๋๋ค.
๋ชจ๋ ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์
์์ค์ฝ๋๋ ์ฌ๊ธฐ
๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์
๋ง์ฝ ๋ณธ์ธ์ด ์ํ๋ ๊ธฐ๋ฅ์ Kocchat์ ์ถ๊ฐํ๊ณ ์ถ์ผ์๋ค๋ฉด ์ธ์ ๋ ์ง ์ปจํธ๋ฆฌ๋ทฐ์
ํ ์ ์์ต๋๋ค.
- ver 1.0 : ์ํฐํฐ ํ์ต์ CRF ๋ฐ ๋ก์ค ๋ง์คํน ์ถ๊ฐํ๊ธฐ
- ver 1.0 : ์์ธํ README ๋ฌธ์ ์์ฑ ๋ฐ PyPI ๋ฐฐํฌํ๊ธฐ
- ver 1.0 : ๊ฐ๋จํ ์น ์ธํฐํ์ด์ค ๊ธฐ๋ฐ ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์ ์ ์ํ๊ธฐ
- ver 1.0 : Jupyter Note Example ์์ฑํ๊ธฐ + Colab ์คํ ํ๊ฒฝ
- ver 1.1 : ๋ฐ์ดํฐ์ ํฌ๋งท RASA์ฒ๋ผ markdown์ ๋๊ดํธ ํํ๋ก ๋ณ๊ฒฝ
- ver 1.2 : Pretrain Embedding ์ ์ฉ ๊ฐ๋ฅํ๊ฒ ๋ณ๊ฒฝ (Gensim)
- ver 1.3 : Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ ์ถ๊ฐ (Etri BERT, SK BERT)
- ver 1.3 : Pytorch Embedding ๋ชจ๋ธ ์ถ๊ฐ + Pretrain ์ ์ฉ ๊ฐ๋ฅํ๊ฒ
- ver 1.4 : Seq2Seq ์ถ๊ฐํด์ Fallback์ ๋์ฒํ ์ ์๊ฒ ๋ง๋ค๊ธฐ (LSTM, SK GPT2)
- ver 1.5 : ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ ์ ๊ฑฐํ๊ณ , ์์ฒด์ ์ธ ๋์ด์ฐ๊ธฐ ๊ฒ์ฌ๋ชจ๋ ์ถ๊ฐ
- ver 1.6 : BERT์ Markov ์ฒด์ธ์ ์ด์ฉํ ์๋ OOD ๋ฐ์ดํฐ ์์ฑ๊ธฐ๋ฅ ์ถ๊ฐ
- ver 1.7 : ๋ํ ํ๋ฆ๊ด๋ฆฌ๋ฅผ ์ํ Story ๊ด๋ฆฌ ๊ธฐ๋ฅ ๊ตฌํํด์ ์ถ๊ฐํ๊ธฐ
- ์ฑ๋ด ๋ถ๋ฅ ๊ทธ๋ฆผ
- seq2seq ๊ทธ๋ฆผ
- Fallback Detection ๊ทธ๋ฆผ
- ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์ ํ ํ๋ฆฟ
- ๊ทธ ์ธ์ ๊ทธ๋ฆผ ๋ฐ ์์ค์ฝ๋ : ๋ณธ์ธ ์ ์
Copyright 2020 Kochat.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.