sentiment_transfer_network
Generative Adversarial Network based:
This version is based on the paper "Style Transfer from Non-Parallel Text by Cross-Alignment".
Prerequisites:
- Python 3.6 or higher
- TensorFlow 1.3 or higher
- A GPU with 6 GB or more of memory (GeForce GTX 1060 or better)
Usage:
- Prepare dataset:
- Step 1 : Make a directory for your dataset.
mkdir -p data/[your_dataset_name]
- Step 2 : Prepare training data.
Put the positive and negative datasets into the directory and rename them pos_file.txt and neg_file.txt, respectively.
(Sentences in pos_file.txt and neg_file.txt are separated by \n, one sentence per line.)
- Step 3 : Prepare testing data.
Put the testing data into the directory and rename it seq2seq.txt.
(seq2seq.txt uses the same format as pos_file.txt and neg_file.txt.)
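The dataset steps above can be scripted. The following is a minimal sketch, not part of this repository: the helper names `write_sentences` and `prepare_gan_dataset` are hypothetical, and it assumes UTF-8 files and that whitespace inside a sentence is not significant (it is collapsed so no sentence spans two lines).

```python
import os
from typing import Iterable


def write_sentences(path: str, sentences: Iterable[str]) -> None:
    """Write one sentence per line, the layout described for pos_file.txt."""
    with open(path, "w", encoding="utf-8") as f:
        for sentence in sentences:
            # Collapse whitespace (including stray newlines) so a sentence
            # never spans two lines; assumption: whitespace is not significant.
            cleaned = " ".join(sentence.split())
            if cleaned:  # drop empty sentences instead of writing blank lines
                f.write(cleaned + "\n")


def prepare_gan_dataset(root: str, name: str,
                        pos: Iterable[str], neg: Iterable[str],
                        test: Iterable[str]) -> str:
    """Create <root>/data/<name>/ holding the three files described above."""
    data_dir = os.path.join(root, "data", name)
    os.makedirs(data_dir, exist_ok=True)  # equivalent to `mkdir -p`
    write_sentences(os.path.join(data_dir, "pos_file.txt"), pos)
    write_sentences(os.path.join(data_dir, "neg_file.txt"), neg)
    write_sentences(os.path.join(data_dir, "seq2seq.txt"), test)
    return data_dir
```

For example, `prepare_gan_dataset(".", "reviews", pos_sentences, neg_sentences, test_sentences)` run from the repository root would produce data/reviews/, which you would then pass as `-data_path reviews`.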
- Training:
- run
python main.py -train -attention -model [your_model_name] -data_path [your_dataset_name]
- Testing:
- run
python main.py -test -attention -model [your_model_name] -data_path [your_dataset_name]
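If you want to run training and then testing from a script, a small wrapper can assemble exactly the two commands shown above. This is a hypothetical helper (`build_gan_command` and `run_gan` are not in the repository); it always passes `-attention`, as both README commands do, and uses `sys.executable` instead of a bare `python` so the same interpreter that runs the script is used. It assumes the same `-model` name links the training run to the later test run.

```python
import subprocess
import sys
from typing import List


def build_gan_command(mode: str, model: str, dataset: str) -> List[str]:
    """Assemble the README's main.py command for mode 'train' or 'test'."""
    if mode not in ("train", "test"):
        raise ValueError("mode must be 'train' or 'test'")
    return [sys.executable, "main.py", "-" + mode, "-attention",
            "-model", model, "-data_path", dataset]


def run_gan(model: str, dataset: str) -> None:
    """Train, then test, with the same model name and dataset."""
    for mode in ("train", "test"):
        # check=True raises CalledProcessError if main.py exits with an error,
        # so testing is skipped when training fails.
        subprocess.run(build_gan_command(mode, model, dataset), check=True)
```

Passing the arguments as a list (rather than one shell string) avoids quoting problems if a model or dataset name contains spaces.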
Transfer Network based:
Prerequisites:
- Python 3.6 or higher
- TensorFlow 1.0
Usage:
- Prepare dataset:
- Step 1 : Make a directory for your model.
mkdir -p works/[your_model_name]
- Step 2 : Prepare sentiment data.
Put the positive and negative datasets into the directory and rename them pos_file.txt and neg_file.txt, respectively.
(Sentences in pos_file.txt and neg_file.txt are separated by \n, one sentence per line.)
- Step 3 : Prepare dialogue data.
Put the dialogue dataset into the directory and rename it chat.txt.
(Each dialogue pair, i.e. a question and its answer, is split by \n.)
- Step 4 : Prepare testing data.
Put the testing data into the root directory and rename it seq2seq.txt.
(seq2seq.txt uses the same format as pos_file.txt and neg_file.txt.)
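The four preparation steps above can be sketched in one helper. This is an illustration under stated assumptions, not the repository's own tooling: `prepare_transfer_data` is a hypothetical name, chat.txt is assumed to hold each question and its answer on consecutive lines (the README only says pairs are split by \n), and "root directory" is assumed to mean the repository root next to main.py.

```python
import os
from typing import Iterable, Tuple


def _clean(text: str) -> str:
    # Collapse whitespace so one sentence never spans two lines.
    return " ".join(text.split())


def _write_sentences(path: str, sentences: Iterable[str]) -> None:
    with open(path, "w", encoding="utf-8") as f:
        for s in sentences:
            if _clean(s):  # skip empty sentences
                f.write(_clean(s) + "\n")


def prepare_transfer_data(repo_root: str, model: str,
                          pos: Iterable[str], neg: Iterable[str],
                          pairs: Iterable[Tuple[str, str]],
                          test: Iterable[str]) -> str:
    """Lay out works/<model>/ and seq2seq.txt as described in Steps 1-4."""
    work_dir = os.path.join(repo_root, "works", model)
    os.makedirs(work_dir, exist_ok=True)  # Step 1, like `mkdir -p`
    _write_sentences(os.path.join(work_dir, "pos_file.txt"), pos)  # Step 2
    _write_sentences(os.path.join(work_dir, "neg_file.txt"), neg)
    # Step 3. ASSUMPTION: question and answer go on consecutive lines;
    # pairs with an empty side are skipped so the alternation never breaks.
    with open(os.path.join(work_dir, "chat.txt"), "w", encoding="utf-8") as f:
        for question, answer in pairs:
            q, a = _clean(question), _clean(answer)
            if q and a:
                f.write(q + "\n" + a + "\n")
    # Step 4. ASSUMPTION: "root directory" means the repository root.
    _write_sentences(os.path.join(repo_root, "seq2seq.txt"), test)
    return work_dir
```

If the actual chat.txt layout differs (for example, a tab between question and answer), only the Step 3 block needs to change.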
- Training:
- Step 1: Training the variational autoencoder.
-- run python main.py -step1 -model [your_model_name]
- Step 2: Training the sentiment classifier.
-- run python main.py -step2 -model [your_model_name]
- Step 3: Training the transfer network.
-- run python main.py -step3 -model [your_model_name]
- Testing:
- run
python main.py --test -model [your_model_name]
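Because the training steps are numbered, they presumably have to run in order, each building on the previous one. The following hypothetical wrapper (`transfer_commands` and `run_transfer_pipeline` are not part of the repository) runs Steps 1 to 3 and then testing, stopping at the first failure. The flags are copied verbatim from the commands above, including the double-dash `--test`.

```python
import subprocess
import sys
from typing import List


def transfer_commands(model: str) -> List[List[str]]:
    """Commands for the three training steps followed by testing, in order."""
    base = [sys.executable, "main.py"]
    cmds = [base + ["-step%d" % i, "-model", model] for i in (1, 2, 3)]
    # Testing flag copied verbatim from the README (double dash).
    cmds.append(base + ["--test", "-model", model])
    return cmds


def run_transfer_pipeline(model: str) -> None:
    for cmd in transfer_commands(model):
        # check=True raises CalledProcessError and stops the pipeline,
        # since a later step presumably needs the earlier step's output.
        subprocess.run(cmd, check=True)
```

Calling `run_transfer_pipeline("my_model")` from the repository root is then equivalent to typing the four commands by hand.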