CakeChat is a dialog system that is able to express emotions in a text conversation. Try it online!
It is written in Theano and Lasagne. It uses end-to-end trained embeddings of 5 different emotions to generate responses conditioned on a given emotion.
The code is flexible and allows you to condition a response on an arbitrary categorical variable defined for some samples in the training data. With CakeChat you can, for example, train your own persona-based neural conversational model[5] or create an emotional chatting machine without external memory[4].
- Network architecture and features
- Quick start
- Setup
- Getting the model
- Running the system
- Repository overview
- Example use cases
- References
- Credits & Support
- License
- Model:
- Hierarchical Recurrent Encoder-Decoder (HRED) architecture for handling deep dialog context[7]
- Multilayer RNN with GRU cells. The first layer of the utterance-level encoder is always bidirectional.
- The thought vector is fed into the decoder at each decoding step.
- The decoder can be conditioned on any string label, for example an emotion label or the ID of the person talking.
- Word embedding layer:
- May be initialized using a w2v model trained on your own corpus.
- The embedding layer may either stay fixed or be fine-tuned along with all other weights of the network.
- Decoding:
- 4 different response generation algorithms: "sampling", "beamsearch", "sampling-reranking" and "beamsearch-reranking". Reranking of the generated candidates is performed according to the log-likelihood or the MMI criterion[3]. See the configuration settings description for details.
- Metrics:
- Perplexity
- n-gram distinct metrics adjusted to the sample size[3].
- Lexical similarity between samples of the model and some fixed dataset, computed as the cosine between the TF-IDF vectors of the responses generated by the model and of the tokens in the dataset.
- Ranking metrics: mean average precision and mean recall@k[8].
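To make the distinct metrics more concrete, here is a minimal sketch of plain distinct-n: the number of unique n-grams divided by the total number of n-grams across all generated responses. It is only an illustration; the sample-size adjustment mentioned above is not reproduced, and the function names are hypothetical rather than taken from cakechat/dialog_model/quality/.

```python
from typing import Iterable, List, Sequence, Tuple


def ngrams(tokens: Sequence[str], n: int) -> List[Tuple[str, ...]]:
    """Return all contiguous n-grams of a token sequence (empty if too short)."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def distinct_n(responses: Iterable[Sequence[str]], n: int) -> float:
    """Plain distinct-n: unique n-grams / total n-grams over all responses.

    Hypothetical helper; no sample-size adjustment is applied here.
    """
    all_ngrams: List[Tuple[str, ...]] = []
    for tokens in responses:
        all_ngrams.extend(ngrams(tokens, n))
    if not all_ngrams:
        return 0.0
    return len(set(all_ngrams)) / len(all_ngrams)


if __name__ == "__main__":
    responses = [
        "i am fine".split(),
        "i am good".split(),
        "i do not know".split(),
    ]
    print(distinct_n(responses, 1))            # 0.7
    print(round(distinct_n(responses, 2), 3))  # 0.857
```

A model that keeps answering "i am fine" scores low on these metrics, which is why they are used to detect dull, repetitive responses.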
Run the CPU-only pre-built Docker image and start the CakeChat server, which serves the model on port 8080:
docker run --name cakechat-dev -p 127.0.0.1:8080:8080 -it lukalabs/cakechat:latest \
bash -c "python bin/cakechat_server.py"
Or use the GPU-enabled image:
nvidia-docker run --name cakechat-gpu-dev -p 127.0.0.1:8080:8080 -it lukalabs/cakechat-gpu:latest \
bash -c "USE_GPU=0 python bin/cakechat_server.py"
That's it! Now you can try it by running python tools/test_api.py -f localhost -p 8080 -c "Hi! How are you?" from the host command line.
Using Docker is the easiest way to set up the environment and install all the dependencies.
- Install Docker
- Build a Docker image
Build a CPU-only image:
docker build -t cakechat:latest -f dockerfiles/Dockerfile.cpu dockerfiles/
- Start the container
Run a Docker container in the CPU-only environment:
docker run --name <CONTAINER_NAME> -it cakechat:latest
- Install nvidia-docker for GPU support.
- Build a GPU-enabled Docker image:
nvidia-docker build -t cakechat-gpu:latest -f dockerfiles/Dockerfile.gpu dockerfiles/
- Start the container
Run a Docker container in the GPU-enabled environment:
nvidia-docker run --name <CONTAINER_NAME> -it cakechat-gpu:latest
That's it! Now you can train your model and chat with it.
If you don't want to deal with Docker images and containers, you can always simply run (with sudo, --user or inside your virtualenv):
pip install -r requirements.txt
Most likely this will do the job. NB: this method only provides a CPU-only environment. To get GPU support, you'll need to build and install libgpuarray yourself (see Dockerfile.gpu for an example).
Run python tools/download_model.py to download our pre-trained model.
The model is trained with context size 3, where the encoded sequence contains at most 30 tokens and the decoded sequence contains at most 32 tokens. Both the encoder and the decoder contain 2 GRU layers with 512 hidden units each.
The model was trained on preprocessed Twitter conversational data.
To clean up the data, we removed URLs, retweets and citations. We also removed mentions and hashtags that are not preceded by normal words or punctuation marks, and filtered out all messages that contain more than 30 tokens.
Then we labeled each utterance with our emotions classifier, which predicts one of 5 emotions: "neutral", "joy", "anger", "sadness" and "fear".
To label your own corpus with emotions you can use, for example, the DeepMoji tool or any other emotions classifier that you have.
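If you want to apply similar cleaning to your own tweets, the sketch below approximates the steps described above. It is not the preprocessing code used for the pre-trained model: the regular expressions, the "RT " retweet marker, the reading of "mentions and hashtags not preceded by normal words" as the run of @user/#tag tokens at the start of a message, and whitespace tokenization are all assumptions. Citations are not handled.

```python
import re
from typing import Optional

# Assumed approximations of the cleaning rules; not the original implementation.
URL_RE = re.compile(r"https?://\S+|www\.\S+")
LEADING_TAGS_RE = re.compile(r"^(?:[@#]\w+\s*)+")  # run of @user / #tag at message start
MAX_TOKENS = 30


def clean_tweet(text: str) -> Optional[str]:
    """Return a cleaned message, or None if the message should be dropped."""
    if text.startswith("RT "):  # assumption: retweets start with "RT "
        return None
    text = URL_RE.sub("", text)
    # Mentions/hashtags at the very start are not preceded by a word or punctuation.
    text = LEADING_TAGS_RE.sub("", text)
    text = " ".join(text.split())  # collapse leftover whitespace
    # Assumption: tokens are whitespace-separated words.
    if not text or len(text.split()) > MAX_TOKENS:
        return None
    return text


if __name__ == "__main__":
    print(clean_tweet("@bob @alice hi there http://t.co/x #fun"))  # hi there #fun
    print(clean_tweet("RT @bob: hello"))  # None
```

Note that the trailing "#fun" survives because it is preceded by a normal word.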
For some tools (for example tools/train.py) you can specify the path to the model's initialization weights via the --init_weights argument.
The weights may come from a trained CakeChat model or from a model with a different architecture. In the latter case some parameters of the CakeChat model may be left without initialization: a parameter is initialized with a saved value if its name and shape are identical to those of a saved parameter; otherwise the parameter keeps its default initialization weights.
See the load_weights function for details.
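The name-and-shape matching rule above can be sketched with plain NumPy dictionaries. This is an illustration of the rule only, not the actual load_weights function: the dict-of-arrays representation and the name load_matching_weights are hypothetical, whereas CakeChat itself works with Lasagne/Theano parameters.

```python
from typing import Dict

import numpy as np


def load_matching_weights(model_params: Dict[str, np.ndarray],
                          saved_params: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical sketch: take a saved value only if name AND shape match.

    Parameters with no saved counterpart, or with a different shape,
    keep their default initialization. Saved-only parameters are ignored.
    """
    result: Dict[str, np.ndarray] = {}
    for name, default_value in model_params.items():
        saved = saved_params.get(name)
        if saved is not None and saved.shape == default_value.shape:
            result[name] = saved.copy()
        else:
            result[name] = default_value
    return result


if __name__ == "__main__":
    model = {"enc_W": np.zeros((4, 8)), "dec_W": np.zeros((8, 10))}
    saved = {"enc_W": np.ones((4, 8)), "dec_W": np.ones((8, 12))}  # dec_W differs in shape
    params = load_matching_weights(model, saved)
    print(params["enc_W"].sum(), params["dec_W"].sum())  # 32.0 0.0
```

This is what makes it possible to warm-start, say, a model with a larger vocabulary from a smaller one: the layers whose shapes did not change are reused, and the rest start from scratch.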
- Put your training text corpus into data/corpora_processed/. Each line of the corpus file should be a JSON object containing a list of dialog messages sorted in chronological order. The code is fully language-agnostic: you can use any Unicode texts in datasets. Refer to our dummy corpus to see the input format: data/corpora_processed/train_processed_dialogs.txt.
- The following datasets are used for validation and early stopping:
- data/corpora_processed/val_processed_dialogs.txt (dummy example): the context-sensitive validation dataset
- data/quality/context_free_validation_set.txt: the context-free validation dataset
- data/quality/context_free_questions.txt: used for generating responses for logging and computing distinct-metrics
- data/quality/context_free_test_set.txt: used for computing metrics of the trained model, e.g. ranking metrics
- Set up training parameters in cakechat/config.py. See the configuration settings description for more details.
- Run python tools/prepare_index_files.py to build the index files with tokens and conditions from the training corpus.
- Run python tools/train.py. Don't forget to set the USE_GPU=<GPU_ID> environment variable (with GPU_ID as reported by nvidia-smi) if you want to use a GPU. Use SLICE_TRAINSET=N to train the model on a subset of the first N samples of your training data to speed up preprocessing for debugging.
- You can also set IS_DEV=1 to enable the "development mode". It uses a reduced number of model parameters (decreased hidden layer dimensions, input and output sizes of token sequences, etc.), performs verbose logging and disables Theano graph optimizations. Use this mode for debugging.
- Weights of your model will be saved in data/nn_models/.
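As a starting point for the first step, here is a minimal sketch that serializes dialogs into one JSON line each. It assumes each line is a JSON array of message objects with "text" and "condition" keys; the "condition" field is mentioned in the Example use cases section, but "text" and the array layout are assumptions, so check data/corpora_processed/train_processed_dialogs.txt before relying on it. The helper names are hypothetical.

```python
import json
from typing import List, Tuple

# Assumed message layout: {"text": ..., "condition": ...}; verify against the dummy corpus.
Dialog = List[Tuple[str, str]]  # chronological (text, condition) pairs


def dialog_to_line(messages: Dialog) -> str:
    """Serialize one dialog as a single corpus line (no embedded newlines)."""
    return json.dumps(
        [{"text": text, "condition": condition} for text, condition in messages],
        ensure_ascii=False,  # keep non-ASCII text readable; the code is language-agnostic
    )


def write_corpus(path: str, dialogs: List[Dialog]) -> None:
    """Write one dialog per line, UTF-8 encoded."""
    with open(path, "w", encoding="utf-8") as f:
        for dialog in dialogs:
            f.write(dialog_to_line(dialog) + "\n")


if __name__ == "__main__":
    print(dialog_to_line([("Hi!", "neutral"), ("Glad to see you!", "joy")]))
```

json.dumps escapes newlines inside message texts, so each dialog is guaranteed to stay on one line.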
You can train a dialog model on any text conversational dataset available to you. A great overview of existing conversational datasets can be found here: https://breakend.github.io/DialogDatasets/
Run a server that processes HTTP-requests with given input messages (contexts) and returns response messages of the model:
python bin/cakechat_server.py
Specify the USE_GPU=<GPU_ID> environment variable if you want to use a certain GPU.
Wait until the model is compiled.
Don't forget to run tools/download_model.py prior to running bin/cakechat_server.py if you want to start an API with our pre-trained model.
To make sure everything works fine, test the model on the following conversation:
– Hi, Eddie, what's up?
– Not much, what about you?
– Fine, thanks. Are you going to the movies tomorrow?
python tools/test_api.py -f 127.0.0.1 -p 8080 \
-c "Hi, Eddie, what's up?" \
-c "Not much, what about you?" \
-c "Fine, thanks. Are you going to the movies tomorrow?"
JSON parameters are:
Parameter | Type | Description |
---|---|---|
context | list of strings | List of previous messages from the dialogue history (at most 3 are used) |
emotion | string, one of enum | One of {'neutral', 'anger', 'joy', 'fear', 'sadness'}. The emotion to condition the response on. Optional; if not specified, 'neutral' is used |
POST /cakechat_api/v1/actions/get_response
data: {
'context': ['Hello', 'Hi!', 'How are you?'],
'emotion': 'joy'
}
200 OK
{
'response': 'I\'m fine!'
}
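If you prefer to call the API from your own code rather than tools/test_api.py, a minimal standard-library client might look like the sketch below. It assumes the server accepts a JSON request body with Content-Type: application/json and replies with a JSON object containing a "response" field, as in the example above; build_request and get_response are hypothetical helper names.

```python
import json
import urllib.request
from typing import List, Optional

API_PATH = "/cakechat_api/v1/actions/get_response"


def build_request(host: str, port: int, context: List[str],
                  emotion: Optional[str] = None) -> urllib.request.Request:
    """Build a POST request; 'emotion' is omitted so the server falls back to 'neutral'."""
    payload = {"context": context}
    if emotion is not None:
        payload["emotion"] = emotion
    return urllib.request.Request(
        f"http://{host}:{port}{API_PATH}",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},  # assumption: JSON body is accepted
        method="POST",
    )


def get_response(host: str, port: int, context: List[str],
                 emotion: Optional[str] = None, timeout: float = 60.0) -> str:
    """Send the request to a running server and return the model's response text."""
    req = build_request(host, port, context, emotion)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["response"]


if __name__ == "__main__":
    print(get_response("127.0.0.1", 8080, ["Hello", "Hi!", "How are you?"], "joy"))
```

The first request after startup may be slow while the model warms up, hence the generous default timeout.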
We recommend using Gunicorn to serve your model's API at production scale.
Run a server that processes HTTP requests with input messages and returns response messages of the model:
cd bin && gunicorn cakechat_server:app -w 1 -b 127.0.0.1:8080 --timeout 2000
You may need to install gunicorn from pip: pip install gunicorn.
You can also test your model in a Telegram bot: create a Telegram bot and run
python tools/telegram_bot.py --token <YOUR_BOT_TOKEN>
- cakechat/dialog_model/: contains the computational graph, training procedure and other model utilities
- cakechat/dialog_model/inference/: algorithms for response generation
- cakechat/dialog_model/quality/: code for metrics calculation and logging
- cakechat/utils/: utilities for text processing, w2v training, etc.
- cakechat/api/: functions to run the HTTP server (API configuration, error handling)
- tools/: scripts for training, testing and evaluating your model
- bin/cakechat_server.py: runs an HTTP server that returns response messages of the model given dialog contexts and an emotion. See the run section for details.
- tools/train.py: trains the model on your data. You can specify the path to the model's initialization weights via the --init_weights argument. Also use the --reverse flag to train the model used in the "*-reranking" response generation algorithms for more accurate predictions.
- tools/prepare_index_files.py: prepares the index of the most commonly used tokens and conditions. Use this script before training the model.
- tools/quality/ranking_quality.py: computes ranking metrics of a dialog model.
- tools/quality/prediction_distinctness.py: computes distinct-metrics of a dialog model. See the features section for details about the metrics.
- tools/quality/condition_quality.py: computes metrics on different subsets of the data according to the condition value.
- tools/generate_predictions.py: evaluates the model. It generates predictions of a dialog model on a set of given dialog contexts and then computes metrics. Note that you should have a reverse model in the data/nn_models directory if you want to use the "*-reranking" prediction modes.
- tools/generate_predictions_for_condition.py: generates predictions for a given condition value.
- tools/test_api.py: example code to send requests to a running HTTP server.
- tools/download_model.py: downloads the pre-trained model and the index files associated with it. It also compiles the whole model once to create the Theano cache.
- tools/telegram_bot.py: runs a Telegram bot that uses a trained model.
All the configuration parameters for the network architecture, training, predicting and logging steps are defined in cakechat/config.py.
Some inference parameters used in the HTTP server are defined in cakechat/api/config.py.
- Network architecture and size
- HIDDEN_LAYER_DIMENSION is the main parameter that defines the number of hidden units in recurrent layers.
- WORD_EMBEDDING_DIMENSION and CONDITION_EMBEDDING_DIMENSION define the number of hidden units that each token/condition is mapped into. Together they sum up to the dimension of the input vector passed to the encoder RNN.
- The number of units in the output layer of the decoder is defined by the number of tokens in the dictionary in the tokens_index directory.
- Decoding algorithm:
- PREDICTION_MODE_FOR_TESTS defines how the responses of the model are generated. The options are the following:
- sampling: the response is sampled from the output distribution token by token. For every token, the temperature transform is performed prior to sampling. You can control the temperature value by tuning the DEFAULT_TEMPERATURE parameter.
- sampling-reranking: multiple candidate responses are generated using the sampling procedure described above. After that, the candidates are ranked according to their MMI score[3]. You can tune this mode by picking the SAMPLES_NUM_FOR_RERANKING and MMI_REVERSE_MODEL_SCORE_WEIGHT parameters.
- beamsearch: candidates are generated using the beam search algorithm. The candidates are ordered according to their log-likelihood score computed by the beam search procedure.
- beamsearch-reranking: same as above, but the candidates are re-ordered after generation in the same way as in sampling-reranking mode.
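The two ingredients of these modes, temperature sampling and MMI reranking, can be sketched in NumPy as follows. This is an illustration, not CakeChat's inference code: the reranking score uses the MMI-bidi form from [3], log p(T|S) + λ·log p(S|T) without its length term, and treating λ as the counterpart of MMI_REVERSE_MODEL_SCORE_WEIGHT is an assumption, since CakeChat's exact formula may differ. Function names are hypothetical.

```python
from typing import List, Sequence

import numpy as np


def sample_with_temperature(logits: np.ndarray, temperature: float,
                            rng: np.random.Generator) -> int:
    """Sample one token id from softmax(logits / temperature)."""
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


def rerank_mmi(candidates: Sequence[str], forward_logprobs: Sequence[float],
               reverse_logprobs: Sequence[float], reverse_weight: float) -> List[str]:
    """Order candidates best-first by log p(T|S) + w * log p(S|T).

    forward_logprobs come from the main model, reverse_logprobs from the
    reverse model (trained with tools/train.py --reverse).
    """
    scores = [f + reverse_weight * r for f, r in zip(forward_logprobs, reverse_logprobs)]
    order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return [candidates[i] for i in order]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # A very low temperature makes sampling nearly greedy: picks the argmax, index 1.
    print(sample_with_temperature(np.array([1.0, 5.0, 2.0]), 0.01, rng))  # 1
    # The generic "i don't know" is likely forward but explains the context poorly.
    print(rerank_mmi(["i don't know", "yes, at 8pm"], [-1.0, -2.0], [-5.0, -1.0], 0.5))
```

Lower temperatures make responses safer and more repetitive; higher ones increase diversity at the cost of fluency. The reverse-model term penalizes generic answers that fit any context.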
Note that there are other parameters that affect the response generation process. See REPETITION_PENALIZE_COEFFICIENT, NON_PENALIZABLE_TOKENS and MAX_PREDICTIONS_LENGTH.
By providing additional condition labels within the dataset entries, you can build the following models:
- A Persona-Based Neural Conversation Model[5]: a model that allows you to condition responses on a persona ID to make them lexically similar to the given persona's linguistic style.
- An Emotional Chatting Machine-like model[4]: a model that allows you to condition responses on an emotion to provide emotional styles (anger, sadness, joy, etc.).
- A Topic Aware Neural Response Generation-like model[6]: a model that allows you to condition responses on a certain topic to keep the conversation topic-aware.
To make use of these extra conditions, please refer to the section Training your own model. Just set the "condition" field in the training set to one of the following: persona ID, emotion or topic label, then update the index files and start the training.
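Before rebuilding the index files, it can help to check which condition labels your corpus actually contains and how balanced they are. The sketch below is a hypothetical helper, not part of CakeChat; it assumes each corpus line is a JSON array of message objects carrying a "condition" field, so verify the layout against the dummy corpus first.

```python
import json
from collections import Counter
from typing import Iterable


def count_conditions(lines: Iterable[str]) -> Counter:
    """Count condition labels over all messages in a JSON-lines dialog corpus.

    Assumption: each non-empty line is a JSON array of message objects;
    messages without a "condition" field are counted under None.
    """
    counts: Counter = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        for message in json.loads(line):
            counts[message.get("condition")] += 1
    return counts


if __name__ == "__main__":
    with open("data/corpora_processed/train_processed_dialogs.txt", encoding="utf-8") as f:
        for label, n in count_conditions(f).most_common():
            print(label, n)
```

Very rare labels (for example, a persona ID with only a handful of messages) are unlikely to yield a useful condition embedding, so this is a quick way to spot them before training.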
- [1] A Neural Conversational Model
- [2] How NOT To Evaluate Your Dialogue System
- [3] A Diversity-Promoting Objective Function for Neural Conversation Models
- [4] Emotional Chatting Machine: Emotional Conversation Generation with Internal and External Memory
- [5] A Persona-Based Neural Conversation Model
- [6] Topic Aware Neural Response Generation
- [7] A Hierarchical Recurrent Encoder-Decoder For Generative Context-Aware Query Suggestion
- [8] Quantitative Evaluation of User Simulation Techniques for Spoken Dialogue Systems
CakeChat is developed and maintained by the Replika team: Michael Khalman, Nikita Smetanin, Artem Sobolev, Nicolas Ivanov, Artem Rodichev and Denis Fedorenko. Demo by Oleg Akbarov, Alexander Kuznetsov and Vladimir Chernosvitov.
All issues and feature requests can be tracked on GitHub Issues.
© 2018 Luka, Inc. Licensed under the Apache License, Version 2.0. See LICENSE file for more details.