TextGAN is a PyTorch framework for Generative Adversarial Network (GAN)-based text generation models, including general text generation models and category text generation models. TextGAN serves as a benchmarking platform to support research on GAN-based text generation models. Since most GAN-based text generation models are implemented in TensorFlow, TextGAN can help those who are used to PyTorch enter the text generation field faster.
If you find any mistakes in my implementation, please let me know! Also, please feel free to contribute to this repository if you want to add other models.
- PyTorch >= 1.1.0
- Python 3.6
- Numpy 1.14.5
- CUDA 7.5+ (For GPU)
- nltk 3.4
- tqdm 4.32.1
- KenLM (https://github.com/kpu/kenlm)
To install, run pip install -r requirements.txt. In case of CUDA problems, consult the official PyTorch Get Started guide.
KenLM installation:
- Download the stable release and unzip it: http://kheafield.com/code/kenlm.tar.gz
- Boost >= 1.42.0 and bjam are required:
  - Ubuntu: sudo apt-get install libboost-all-dev
  - Mac: brew install boost; brew install bjam
- Build within the kenlm directory:
  mkdir -p build
  cd build
  cmake ..
  make -j 4
- Install the Python module: pip install https://github.com/kpu/kenlm/archive/master.zip
- For more information on KenLM, see https://github.com/kpu/kenlm and http://kheafield.com/code/kenlm/
- SeqGAN - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
- LeakGAN - Long Text Generation via Adversarial Training with Leaked Information
- MaliGAN - Maximum-Likelihood Augmented Discrete Generative Adversarial Networks
- JSDGAN - Adversarial Discrete Sequence Generation without Explicit Neural Networks as Discriminators
- RelGAN - RelGAN: Relational Generative Adversarial Networks for Text Generation
- DPGAN - DP-GAN: Diversity-Promoting Generative Adversarial Network for Generating Informative and Diversified Text
- DGSAN - DGSAN: Discrete Generative Self-Adversarial Network
- CoT - CoT: Cooperative Training for Generative Modeling of Discrete Data
- SentiGAN - SentiGAN: Generating Sentimental Texts via Mixture Adversarial Networks
- CatGAN (ours) - CatGAN: Category-aware Generative Adversarial Networks with Hierarchical Evolutionary Learning for Category Text Generation
- Get Started
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorch
- For real data experiments, all datasets (Image COCO, EMNLP News, Movie Review, Amazon Review) can be downloaded from here.
- Run with a specific model
cd run
python3 run_[model_name].py 0 0 # The first 0 is job_id, the second 0 is gpu_id
# For example
python3 run_seqgan.py 0 0
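A run file receives the two numbers above as positional arguments. The sketch below shows one way such a launcher might read them; it is hypothetical, not the repository's code. The helper parse_run_args, the fallback of 0 for missing values, and the PRESETS list (with its real_data and dataset keys) are assumptions, with job_id used here to pick one of several preset configurations.

```python
# Hypothetical presets indexed by job_id; the real run files define their own settings.
PRESETS = [
    {"real_data": False, "dataset": "oracle"},
    {"real_data": True, "dataset": "image_coco"},
]


def parse_run_args(argv):
    """Return (job_id, gpu_id) from an argv such as ['run_seqgan.py', '0', '0'].

    Both values fall back to 0 when omitted (an assumption of this sketch).
    """
    job_id = int(argv[1]) if len(argv) > 1 else 0
    gpu_id = int(argv[2]) if len(argv) > 2 else 0
    return job_id, gpu_id


# Equivalent to running: python3 run_seqgan.py 1 0
job_id, gpu_id = parse_run_args(["run_seqgan.py", "1", "0"])
config = PRESETS[job_id]  # job 1 selects the real-data preset in this sketch
```

In an actual script, sys.argv would be passed instead of the literal list.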
- Instructor
For each model, the entire running process is defined in an instructor, for example instructor/oracle_data/seqgan_instructor.py for SeqGAN in the synthetic data experiment. Some basic functions like init_model() and optimize() are defined in the base class BasicInstructor in instructor.py. If you want to add a new GAN-based text generation model, please create a new instructor under instructor/oracle_data and define the training process for the model there.
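The base-class/subclass split described above can be sketched as follows. This is a simplified, hypothetical stand-in: apart from the names BasicInstructor, init_model() and optimize(), which the README mentions, everything here (the run() method, the optimize signature, the dummy losses, MyGANInstructor) is an assumption, and the real classes build PyTorch models and data loaders. The point is that shared helpers live in the base class while each model's instructor only defines its own training schedule.

```python
from abc import ABC, abstractmethod


class BasicInstructor(ABC):
    """Hypothetical, simplified stand-in for the BasicInstructor in instructor.py."""

    def __init__(self):
        self.log = []
        self.init_model()

    def init_model(self):
        # The real framework would build or load the generator/discriminator here.
        self.log.append("init_model")

    def optimize(self, opt_name, loss):
        # Placeholder for an optimizer step: only records the (dummy) loss value.
        self.log.append(f"{opt_name}: {loss:.2f}")

    @abstractmethod
    def run(self):
        """Each model-specific instructor defines its own training process."""


class MyGANInstructor(BasicInstructor):
    """Hypothetical instructor for a new model, e.g. placed under instructor/oracle_data."""

    def __init__(self, pretrain_epochs=2, adv_epochs=2):
        self.pretrain_epochs = pretrain_epochs
        self.adv_epochs = adv_epochs
        super().__init__()

    def run(self):
        # Phase 1: MLE pre-training of the generator (dummy decreasing loss).
        for epoch in range(self.pretrain_epochs):
            self.optimize("gen_pretrain", 1.0 / (epoch + 1))
        # Phase 2: adversarial training, alternating generator and discriminator.
        for _ in range(self.adv_epochs):
            self.optimize("gen_adv", 0.5)
            self.optimize("dis_adv", 0.25)
        return self.log


inst = MyGANInstructor(pretrain_epochs=2, adv_epochs=1)
history = inst.run()
```

Because run() is abstract, forgetting to define the training process in a new instructor fails loudly at instantiation time rather than silently doing nothing.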
- Visualization
Use utils/visualization.py to visualize the log file, including model loss and metric scores. Customize your log files in log_file_list, with no more than len(color_list) entries. The log filenames should exclude the .txt extension.
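The two constraints above (one colour per log file, names given without .txt) can be illustrated with a small helper. This is a hedged sketch, not the code in utils/visualization.py: the colour values, the log names and the helper plan_curves are all hypothetical, and the log directory is assumed to be log/.

```python
import os

# Hypothetical colour palette; the real list in utils/visualization.py may differ.
color_list = ["#e74c3c", "#f1c40f", "#1abc9c", "#9b59b6"]

# Hypothetical log names, written WITHOUT the .txt extension.
log_file_list = ["log_0604_2233", "log_0605_0110"]


def plan_curves(log_files, colors, log_dir="log"):
    """Pair each log file with its own colour and resolve its path on disk."""
    if len(log_files) > len(colors):
        raise ValueError(
            f"At most {len(colors)} log files can be plotted, got {len(log_files)}"
        )
    plan = []
    for name, color in zip(log_files, colors):
        if name.endswith(".txt"):
            raise ValueError(f"Drop the .txt extension from {name!r}")
        # The extension is appended here, which is why the names must omit it.
        plan.append((os.path.join(log_dir, name + ".txt"), color))
    return plan


plan = plan_curves(log_file_list, color_list)
```

Capping the list at len(color_list) guarantees every curve gets a distinct colour instead of silently reusing one.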
- Logging
TextGAN-PyTorch uses Python's logging module to record the running process, such as the generator's loss and metric scores. For convenience of visualization, two identical log files are saved, in log/log_****_****.txt and save/**/log.txt respectively. Furthermore, the code automatically saves the models' state dicts and one batch of generator samples in ./save/**/models and ./save/**/samples at every log step, where ** depends on your hyper-parameters.
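Writing the same records to two files can be done with a single logger that has two FileHandlers from Python's standard logging module. The sketch below is not the repository's logging setup: it assumes a temporary directory, a made-up timestamped filename, a made-up run folder, and an illustrative log message.

```python
import logging
import os
import tempfile


def create_dual_logger(name, log_path, save_log_path):
    """Return a logger that writes every record to two files and to the console."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate output via the root logger
    fmt = logging.Formatter("%(message)s")
    for path in (log_path, save_log_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        handler = logging.FileHandler(path, mode="w")
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    return logger


root = tempfile.mkdtemp()
# Hypothetical names mirroring log/log_****_****.txt and save/**/log.txt.
log_path = os.path.join(root, "log", "log_0604_2233.txt")
save_log_path = os.path.join(root, "save", "example_run", "log.txt")

logger = create_dual_logger("textgan_demo", log_path, save_log_path)
logger.info("epoch 0: gen_loss = 1.2345")  # illustrative message only
```

Each handler flushes after every record, so both files stay identical even if training is interrupted.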
- Running Signal
You can control the training process with the Signal class (see utils/helpers.py), which reads the dictionary file run_signal.txt. To use the Signal, edit the local file run_signal.txt: for example, setting pre_sig to False makes the program stop the pre-training process and move on to the next training phase. This is a convenient way to stop training early when you think the current training is sufficient.
- Automatically select GPU
In config.py, the program automatically selects the GPU device with the lowest GPU-Util reported by nvidia-smi. This feature is enabled by default. If you want to select a GPU device manually, uncomment the --device argument in run_[model_name].py and specify the GPU device on the command line.
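A minimal sketch of least-busy-GPU selection, assuming nvidia-smi's CSV query mode (--query-gpu=index,utilization.gpu with --format=csv,noheader,nounits); the repository's config.py may query nvidia-smi differently. Parsing is kept separate from the subprocess call so the selection logic can be tried without a GPU, and the helper names are hypothetical.

```python
import subprocess


def query_nvidia_smi():
    """Run nvidia-smi (requires an NVIDIA driver) and return its CSV output."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,utilization.gpu",
         "--format=csv,noheader,nounits"],
        stdout=subprocess.PIPE,
        universal_newlines=True,  # Python 3.6-compatible alternative to text=True
        check=True,
    )
    return result.stdout


def parse_gpu_util(csv_text):
    """Parse lines like '0, 87' into {gpu_index: utilization_percent}."""
    utils = {}
    for line in csv_text.strip().splitlines():
        index, util = (field.strip() for field in line.split(","))
        utils[int(index)] = int(util)
    return utils


def pick_least_busy_gpu(csv_text):
    """Return the index with the lowest GPU-Util; ties go to the GPU listed first."""
    utils = parse_gpu_util(csv_text)
    return min(utils, key=utils.get)


# Sample output for three GPUs, so the example runs without nvidia-smi.
sample = "0, 87\n1, 3\n2, 45\n"
best = pick_least_busy_gpu(sample)
```

On a real machine, pick_least_busy_gpu(query_nvidia_smi()) would return the index, which could then be passed to torch.cuda.set_device().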
- SeqGAN
  - run file: run_seqgan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from SeqGAN)
- LeakGAN
  - run file: run_leakgan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from LeakGAN)
- MaliGAN
  - run file: run_maligan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from my understanding)
- JSDGAN
  - run file: run_jsdgan.py
  - Instructors: oracle_data, real_data
  - Models: generator (no discriminator)
  - Structure (from my understanding)
- RelGAN
  - run file: run_relgan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from my understanding)
- DPGAN
  - run file: run_dpgan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from DPGAN)
- DGSAN
  - run file: run_dgsan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
- CoT
  - run file: run_cot.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from CoT)
- SentiGAN
  - run file: run_sentigan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from SentiGAN)
- CatGAN
  - run file: run_catgan.py
  - Instructors: oracle_data, real_data
  - Models: generator, discriminator
  - Structure (from CatGAN)
MIT License