This repository implements the neural network and one dataset from "A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification": a one-layer CNN and the Twitter Polarity dataset, respectively.
- Python 3.5 or greater
- TensorFlow 1.0 or greater
- easydict
To run the polarity dataset, clone the repo onto your local machine, navigate into the repository using the command line/terminal, and type python train-eval.py {argparser flags}. Replace {argparser flags} with the appropriate choice of options as described in the section below.
The following command-line options are available for running the model:
- -n, run_num, default = 0: saves all model files under /save_directory/model_directory/Model[n]
- -r, restore, default = 0: binary value indicating whether to restore from a model
- -m, model_restore, default = 1: restores from /save_directory/model_directory/Model[n]
- -f, file_epoch, default = 1: restores the model from the checkpoint part_[f].ckpt.meta
- -t, train, default = 1: binary value indicating whether to train the model (0 = no training)
- -v, eval, default = 1: binary value indicating whether to evaluate the model (0 = no evaluation)
- -g, gpu, default = 0: accepts a single integer, or the string "all" to use all available GPUs
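
As a rough illustration of how the options above fit together, here is a minimal argparse sketch. It is not the repository's actual parser: the long option names (--run_num and so on), the help strings, and the parse_gpu helper are assumptions made for this example. Only the short flags and their defaults come from the list above.

```python
import argparse


def parse_gpu(value):
    """Accept either an integer GPU index or the literal string "all"."""
    if value == "all":
        return value
    return int(value)  # a non-integer raises ValueError, which argparse reports


def build_parser():
    """Declare the seven documented flags with their documented defaults."""
    parser = argparse.ArgumentParser(description="Train and/or evaluate the model")
    parser.add_argument("-n", "--run_num", type=int, default=0,
                        help="index n of the Model[n] save directory")
    parser.add_argument("-r", "--restore", type=int, default=0, choices=[0, 1],
                        help="1 = restore from a saved model")
    parser.add_argument("-m", "--model_restore", type=int, default=1,
                        help="index of the model directory to restore from")
    parser.add_argument("-f", "--file_epoch", type=int, default=1,
                        help="f in the checkpoint name part_[f].ckpt.meta")
    parser.add_argument("-t", "--train", type=int, default=1, choices=[0, 1],
                        help="0 = skip training")
    parser.add_argument("-v", "--eval", type=int, default=1, choices=[0, 1],
                        help="0 = skip evaluation")
    parser.add_argument("-g", "--gpu", type=parse_gpu, default=0,
                        help='a GPU index, or "all"')
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())
```

With a parser like this, an invocation such as python train-eval.py -n 3 -t 0 -g all would select run number 3, skip training, and request all GPUs.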
Not every parameter can be adjusted at the command line; the remaining parameters are defined in the lib/config.py file and can be edited with a YAML file found in the cfgs/ folder. Refer to the OneLayerConv.yml file for an example of adjusting the number of training epochs.
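
The override mechanism described above can be pictured as a recursive merge of file values into a set of defaults. The sketch below uses plain dictionaries so that it runs without extra packages; the repository itself uses easydict and a YAML loader, and the key names TRAIN, NUM_EPOCHS, BATCH_SIZE, and SEED are hypothetical, not taken from lib/config.py.

```python
import copy

# Hypothetical defaults standing in for the values defined in lib/config.py.
DEFAULTS = {"TRAIN": {"NUM_EPOCHS": 10, "BATCH_SIZE": 64}, "SEED": 0}


def merge_config(defaults, overrides):
    """Return a copy of defaults with the values in overrides applied recursively."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            # Rejecting unknown keys catches typos in the override file early.
            raise KeyError("unknown config key: {}".format(key))
        if isinstance(value, dict) and isinstance(merged[key], dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Stand-in for the dictionary a YAML loader would return for a file in cfgs/.
overrides = {"TRAIN": {"NUM_EPOCHS": 25}}
cfg = merge_config(DEFAULTS, overrides)
print(cfg["TRAIN"]["NUM_EPOCHS"], cfg["TRAIN"]["BATCH_SIZE"])
```

Only the keys named in the override are changed; everything else keeps its default, which is why a short YAML file is enough to adjust a single setting such as the number of epochs.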
You can edit the network method of the OneLayerConv class in the lib/networks.py file to change the neural network. The paper recommends 1-max pooling over other types of pooling but suggests that deeper networks will yield superior results. You can also add a .yaml file under the cfgs/ folder if you would like to edit the parameters defined in the lib/config.py file.
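
For readers unfamiliar with 1-max pooling, the following dependency-free sketch shows the idea: each filter slides over the sentence matrix to produce a feature map, and only the largest value of each map is kept, so the pooled vector has one entry per filter regardless of sentence length. This is an illustration in plain Python, not the TensorFlow code in lib/networks.py; the function names, the ReLU activation, and the toy numbers are assumptions.

```python
def conv_feature_map(sentence, filt):
    """Slide a filter spanning len(filt) words over the sentence matrix."""
    height = len(filt)
    feature_map = []
    for start in range(len(sentence) - height + 1):
        window = sentence[start:start + height]
        total = sum(w * x
                    for f_row, s_row in zip(filt, window)
                    for w, x in zip(f_row, s_row))
        feature_map.append(max(0.0, total))  # ReLU activation (assumed)
    return feature_map


def one_max_pool(feature_maps):
    """Keep only the largest activation of each feature map."""
    return [max(feature_map) for feature_map in feature_maps]


# Toy sentence: 4 words, each with a 2-dimensional embedding.
sentence = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [0.0, 0.0]]
filters = [
    [[1.0, 0.0], [0.0, 1.0]],  # spans 2 words
    [[1.0, 1.0]],              # spans 1 word
]
pooled = one_max_pool([conv_feature_map(sentence, f) for f in filters])
print(pooled)  # [2.0, 3.0]
```

The two filters yield feature maps of different lengths (3 and 4 values), yet pooling reduces each to a single number, which is what lets filters of several region sizes feed one fixed-size layer.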
To use a different dataset, you will have to add a new class in the lib/datasets.py file similar to the TwitterSentiment class currently implemented. The codebase expects only certain methods (batch_iter, transform_x, transform_y) to be implemented in the dataset class, but other methods will be necessary (e.g. load_data).
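
A skeleton of such a class might look like the sketch below. Only the method names batch_iter, transform_x, transform_y, and load_data come from the description above; the class name ToyDataset, every signature, the in-memory examples, and the padding and one-hot choices are hypothetical and may not match what the rest of the codebase expects.

```python
import random


class ToyDataset:
    """Hypothetical dataset class; compare with TwitterSentiment before reuse."""

    def __init__(self, seed=0):
        self.texts, self.labels = self.load_data()
        self.vocab = {"<pad>": 0}
        for text in self.texts:
            for token in text.split():
                self.vocab.setdefault(token, len(self.vocab))
        self.max_len = max(len(text.split()) for text in self.texts)
        self.rng = random.Random(seed)

    def load_data(self):
        """Return raw texts and labels; a real class would read files here."""
        texts = ["good movie", "bad movie", "really good", "really bad plot"]
        labels = [1, 0, 1, 0]
        return texts, labels

    def transform_x(self, text):
        """Map a sentence to a fixed-length list of token ids (0 = padding)."""
        ids = [self.vocab.get(token, 0) for token in text.split()][:self.max_len]
        return ids + [0] * (self.max_len - len(ids))

    def transform_y(self, label):
        """One-hot encode a binary label as [negative, positive]."""
        return [0, 1] if label == 1 else [1, 0]

    def batch_iter(self, batch_size, shuffle=True):
        """Yield (inputs, targets) batches covering the data once."""
        order = list(range(len(self.texts)))
        if shuffle:
            self.rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            chunk = order[start:start + batch_size]
            yield ([self.transform_x(self.texts[i]) for i in chunk],
                   [self.transform_y(self.labels[i]) for i in chunk])


dataset = ToyDataset()
for inputs, targets in dataset.batch_iter(batch_size=2, shuffle=False):
    print(inputs, targets)
```

Keeping load_data separate from the transform methods means the file-reading logic can change without touching how batches are built.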