TextGAN is a PyTorch framework for text generation models based on Generative Adversarial Networks (GANs). TextGAN serves as a benchmarking platform to support research on GAN-based text generation models. Since most GAN-based text generation models are implemented in TensorFlow, TextGAN can help those who are used to PyTorch enter the text generation field faster.
For now, only a few GAN-based models are implemented, including SeqGAN (Yu et al., 2017), LeakGAN (Guo et al., 2018) and RelGAN (Nie et al., 2018). If you find any mistake in my implementation, please let me know! Also, please feel free to contribute to this repository if you want to add other models.
- PyTorch >= 1.0.0
- Python 3.6
- Numpy 1.14.5
- CUDA 7.5+ (For GPU)
- nltk 3.4
- tqdm 4.32.1
To install, run `pip install -r requirements.txt`. In case of CUDA problems, consult the official PyTorch Get Started guide.
- SeqGAN - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
- LeakGAN - Long Text Generation via Adversarial Training with Leaked Information
- RelGAN - RelGAN: Relational Generative Adversarial Networks for Text Generation
- Get Started
```bash
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorch
```
- For real data experiments, the Image COCO and EMNLP news datasets can be downloaded from here.

- Run with SeqGAN

  ```bash
  cd run
  python3 run_seqgan.py 0 0  # the first 0 is job_id, the second 0 is gpu_id
  ```
- Run with LeakGAN

  ```bash
  cd run
  python3 run_leakgan.py 0 0
  ```
- Run with RelGAN

  ```bash
  cd run
  python3 run_relgan.py 0 0
  ```
- Instructor

  For each model, the entire running process is defined in `instructor/oracle_data/seqgan_instructor.py` (taking SeqGAN in the synthetic data experiment as an example). Some basic functions like `init_model()` and `optimize()` are defined in the base class `BasicInstructor` in `instructor.py`. If you want to add a new GAN-based text generation model, please create a new instructor under `instructor/oracle_data` and define the training process for the model.
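The skeleton below illustrates the idea of adding a new instructor. It is a sketch only: `BasicInstructor` here is a minimal stand-in for the real base class in `instructor.py`, and the class, method names, and training phases of `MyGANInstructor` are hypothetical — the actual signatures in the repository may differ.

```python
class BasicInstructor:
    """Stand-in for the base class in instructor.py (simplified)."""
    def __init__(self, opt):
        self.opt = opt
        self.init_model()

    def init_model(self):
        # The real method builds the generator/discriminator and loads checkpoints.
        self.gen, self.dis = object(), object()

    def optimize(self, step):
        # The real method wraps loss.backward() and optimizer.step().
        pass


class MyGANInstructor(BasicInstructor):
    """A new model's instructor: define its whole training process here."""
    def _run(self):
        # Typical phases for a GAN text model:
        # 1. pre-train the generator (e.g. with MLE)
        # 2. pre-train the discriminator
        # 3. adversarial training loop
        return ["pretrain_gen", "pretrain_dis", "adv_train"]


phases = MyGANInstructor(opt={"max_epoch": 10})._run()
print(phases)  # → ['pretrain_gen', 'pretrain_dis', 'adv_train']
```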
- Visualization

  Use `utils/visualization.py` to visualize the log files, including model loss and metric scores. Specify your log files in `log_file_list`, with no more than `len(color_list)` files. The log filenames should exclude the `.txt` extension.
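The core of such visualization is pulling metric values out of the log text. A minimal sketch of that idea is below; it assumes log entries look roughly like `g_loss = 2.0740` (the exact format written by the project's logging may differ):

```python
import re

def parse_metric(log_lines, name):
    """Collect every value of a named metric (e.g. 'g_loss') from log lines.

    Assumes entries like 'g_loss = 2.0740'; the real log format may differ.
    """
    pattern = re.compile(rf"{re.escape(name)}\s*=\s*([-\d.]+)")
    return [float(m.group(1)) for line in log_lines for m in pattern.finditer(line)]

log_lines = [
    "epoch 0 : pre_loss = 7.1204",
    "epoch 1 : pre_loss = 6.0132",
    "ADV epoch 0 : g_loss = 2.0740, d_loss = 0.5113",
]
print(parse_metric(log_lines, "g_loss"))    # → [2.074]
print(parse_metric(log_lines, "pre_loss"))  # → [7.1204, 6.0132]
```

The extracted series can then be handed to matplotlib, which is what `utils/visualization.py` does with its `color_list`.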
- Logging

  TextGAN-PyTorch uses the `logging` module in Python to record the running process, like the generator's loss and metric scores. For the convenience of visualization, two identical log files are saved in `log/log_****_****.txt` and `save/**/log.txt` respectively. Furthermore, the code automatically saves the state dicts of the models and a batch of the generator's samples in `./save/**/models` and `./save/**/samples` at each log step, where `**` depends on your hyper-parameters.
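The dual log files can be reproduced with Python's standard `logging` module by attaching two `FileHandler`s to one logger. This is a sketch of the mechanism, not the project's exact setup (paths and format string here are illustrative):

```python
import logging
import os
import tempfile

def build_logger(log_dir, save_dir):
    """Write every record to two identical files, one per directory."""
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)
    logger = logging.getLogger("textgan_demo")
    logger.setLevel(logging.INFO)
    for path in (os.path.join(log_dir, "log_demo.txt"),
                 os.path.join(save_dir, "log.txt")):
        handler = logging.FileHandler(path, mode="w")
        handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
        logger.addHandler(handler)
    return logger

root = tempfile.mkdtemp()
logger = build_logger(os.path.join(root, "log"), os.path.join(root, "save"))
logger.info("ADV epoch 0 : g_loss = 2.0740")
for h in logger.handlers:
    h.flush()
```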
- Running Signal

  You can easily control the training process with the class `Signal` (please refer to `utils/helpers.py`), based on the dictionary file `run_signal.txt`. To use the `Signal`, just edit the local file `run_signal.txt`: set `pre_sig` to `False`, for example, and the program will stop the pre-training process and step into the next training phase. This is convenient for stopping the training early if you think the current training is enough.
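A minimal version of the idea behind `Signal` is shown below: re-read a small dictionary-literal file and expose its flags as attributes. The flag names mirror `pre_sig` from above; the real class in `utils/helpers.py` may read different keys and be checked inside the training loops:

```python
import ast
import os
import tempfile

class Signal:
    """Reload control flags from a dictionary-literal file (sketch)."""
    def __init__(self, signal_file):
        self.signal_file = signal_file
        self.update()

    def update(self):
        with open(self.signal_file) as f:
            # File holds a Python dict literal, e.g. {'pre_sig': True, 'adv_sig': True}
            signal = ast.literal_eval(f.read())
        self.pre_sig = signal["pre_sig"]
        self.adv_sig = signal["adv_sig"]

path = os.path.join(tempfile.mkdtemp(), "run_signal.txt")
with open(path, "w") as f:
    f.write("{'pre_sig': True, 'adv_sig': True}")

sig = Signal(path)
print(sig.pre_sig)  # → True

# Editing the file to flip pre_sig ends pre-training at the next check.
with open(path, "w") as f:
    f.write("{'pre_sig': False, 'adv_sig': True}")
sig.update()
print(sig.pre_sig)  # → False
```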
- Automatically select GPU

  In `config.py`, the program automatically selects the GPU device with the least `GPU-Util` in `nvidia-smi`. This feature is enabled by default. If you want to select a GPU device manually, please uncomment the `--device` arg in `run_[run_model].py` and specify a GPU device on the command line.

  Please note that the log step of each model is different. See `run_[run_model].py` for details of the log step.
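The auto-selection in `config.py` boils down to parsing `nvidia-smi` output and picking the device with the lowest utilization. A sketch of that idea, run here on canned output so it works without a GPU (the real code's parsing may differ):

```python
def least_busy_gpu(smi_output):
    """Return the index of the GPU with the lowest GPU-Util.

    Assumes `smi_output` is text in the shape produced by
    `nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader`,
    one line per device, e.g. '37 %'.
    """
    utils = [int(line.strip().rstrip(" %")) for line in smi_output.strip().splitlines()]
    return min(range(len(utils)), key=utils.__getitem__)

# Canned example output for a 3-GPU machine:
sample = "37 %\n5 %\n81 %\n"
print(least_busy_gpu(sample))  # → 1
```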
- NLL_oracle

  LeakGAN surprisingly outperforms RelGAN due to its temperature control, but LeakGAN's samples suffer from severe mode collapse. Though both LeakGAN and RelGAN suffer from mode collapse, the patterns of collapse differ: LeakGAN generates sentences with only a few words, while RelGAN generates repeated sentences with different words.
- NLL_gen
- Add experiment results
- Fix bugs in the LeakGAN model
- Add instructors for SeqGAN and LeakGAN in `instructor/real_data`
- Fix logging bugs for `save_root`
- Fix issues of $NLL_{oracle}$ in the SeqGAN model