Training a Chinese chatbot with the Transformer model.
- Reference: PTT Chinese corpus (PTT 中文語料)
- Download the dataset:
sh download_data.sh
- Transformer architecture: Google's tensorflow/models/official/transformer
- Install the requirements:
pip3 install -r requires.txt
- Build the data files in .tfrecord format:
python3 build_data.py
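A .tfrecord file is a simple binary container. The sketch below illustrates that format only; it does not reproduce build_data.py. It frames byte strings the way TFRecord files do. Each record is an 8-byte little-endian length, then a masked CRC-32C of that length, then the payload, then a masked CRC-32C of the payload. The sketch uses only the standard library. `encode_record` and `decode_records` are hypothetical helper names. The raw UTF-8 payloads stand in for the serialized `tf.train.Example` protos that build_data.py most likely writes, which is an assumption.

```python
import struct

# Reflected CRC-32C (Castagnoli) polynomial used by the TFRecord format.
_CRC32C_POLY = 0x82F63B78


def _make_crc32c_table():
    """Precompute the 256-entry lookup table for byte-wise CRC-32C."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ _CRC32C_POLY if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    """Standard CRC-32C: init 0xFFFFFFFF, reflected, final XOR 0xFFFFFFFF."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """CRC-32C rotated right by 15 bits plus a constant, as TFRecord stores it."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(payload: bytes) -> bytes:
    """Frame one payload as a TFRecord entry (hypothetical helper)."""
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))


def decode_records(blob: bytes) -> list:
    """Split a TFRecord byte stream back into payloads, verifying both CRCs."""
    records, pos = [], 0
    while pos < len(blob):
        header = blob[pos:pos + 8]
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", blob[pos + 8:pos + 12])
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        payload = blob[pos + 12:pos + 12 + length]
        (payload_crc,) = struct.unpack("<I", blob[pos + 12 + length:pos + 16 + length])
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted payload")
        records.append(payload)
        pos += 16 + length
    return records
```

In practice you would write these files with TensorFlow's own TFRecord writer (`tf.io.TFRecordWriter` in recent versions) rather than by hand. The sketch only shows what ends up on disk.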
- Train your model:
python3 train.py config/test_config.yml
You can customize the model architecture by writing a new .yml file and passing it to train.py in place of config/test_config.yml. For more details, see config/test_config.yml.
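The architecture settings in a .yml file largely determine model size. The sketch below gives a rough parameter-count estimate to show how those settings trade off. The field names (`vocab_size`, `hidden_size`, `filter_size`, `num_heads`, `num_layers`) mirror common Transformer hyperparameters but are assumptions, so check config/test_config.yml for the actual keys. The estimate deliberately ignores biases and layer-norm weights. It also assumes one embedding matrix shared by the encoder, the decoder and the output softmax, and equal encoder and decoder depth.

```python
from dataclasses import dataclass


@dataclass
class TransformerConfig:
    """Hypothetical hyperparameters; the real .yml keys may be named differently."""
    vocab_size: int
    hidden_size: int
    filter_size: int
    num_heads: int
    num_layers: int  # assumed equal for encoder and decoder

    def __post_init__(self):
        # Multi-head attention splits hidden_size evenly across the heads.
        if self.hidden_size % self.num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")


def approx_param_count(cfg: TransformerConfig) -> int:
    """Rough weight count that ignores biases and layer-norm parameters."""
    d, f = cfg.hidden_size, cfg.filter_size
    attention = 4 * d * d       # query, key, value and output projections
    feed_forward = 2 * d * f    # d -> f -> d
    encoder_layer = attention + feed_forward
    decoder_layer = 2 * attention + feed_forward  # self- plus encoder-decoder attention
    embedding = cfg.vocab_size * d  # single shared embedding matrix (assumption)
    return cfg.num_layers * (encoder_layer + decoder_layer) + embedding
```

For example, `approx_param_count(TransformerConfig(vocab_size=1000, hidden_size=64, filter_size=256, num_heads=4, num_layers=2))` returns 293376. The attention blocks grow with the square of `hidden_size`, while the feed-forward blocks grow linearly in both `hidden_size` and `filter_size`.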
To change the learning rate, the total number of training steps, or other training strategies, modify the code in train.py.
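This README does not document the learning-rate schedule in train.py. One common choice is the warmup schedule from "Attention Is All You Need", which tensorflow/models/official/transformer also uses as far as we know. That schedule is lr(step) = scale · hidden_size^-0.5 · min(1, step / warmup_steps) / sqrt(max(step, warmup_steps)). The sketch below implements that formula. Whether train.py uses it is an assumption, and the function name and default `warmup_steps=4000` are illustrative.

```python
import math


def transformer_learning_rate(step: int, hidden_size: int,
                              warmup_steps: int = 4000, scale: float = 1.0) -> float:
    """Linear warmup followed by inverse-square-root decay (hypothetical helper)."""
    warmup = min(1.0, step / warmup_steps)            # ramps 0 -> 1 over warmup_steps
    decay = 1.0 / math.sqrt(max(step, warmup_steps))  # constant during warmup, then 1/sqrt(step)
    return scale * hidden_size ** -0.5 * warmup * decay
```

The rate peaks at `step == warmup_steps` and halves every time the step count quadruples after that. Increasing `warmup_steps` both delays and lowers the peak, and `scale` multiplies the whole curve.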
To monitor training with TensorBoard, run the following command, then open http://localhost:8080 in a browser:
tensorboard --logdir build --port 8080
train.py exports a TensorFlow SavedModel every 100000 training steps. The exported models are placed in the serve folder.
To run a simple demo, make sure a SavedModel exists, then run the following command:
python3 chat.py serve/[YOUR MODEL FOLDER]
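Exported models land in serve, so a small helper can pick the newest one to pass to chat.py. This is a hypothetical sketch. It assumes each export is a direct subfolder of serve, and it treats a folder as a SavedModel if it contains `saved_model.pb`, the graph file a SavedModel directory holds. "Newest" here means most recently modified. `latest_saved_model` is an illustrative name, not part of this repository.

```python
from pathlib import Path


def latest_saved_model(serve_dir: str = "serve") -> Path:
    """Return the most recently modified subfolder of serve_dir that holds a SavedModel."""
    candidates = [p for p in Path(serve_dir).iterdir()
                  if p.is_dir() and (p / "saved_model.pb").is_file()]
    if not candidates:
        raise FileNotFoundError(f"no SavedModel found under {serve_dir!r}")
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

You could then pass `str(latest_saved_model())` as the model-folder argument to chat.py. Folders without `saved_model.pb`, such as an interrupted export, are skipped.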