See results in arXiv:2010.15114. This repo contains code (written in JAX) for training networks on various text-classification tasks and analyzing the learned dynamical systems.
Note: This repo is in the process of being converted to open source, but it still has artifacts left over from its earlier, messier state, e.g. the name of the cloud storage bucket. These should be fixed soon.
Install dependencies using make or pip:
make install
will create a VirtualEnv in the project root with all the required packages, or
pip install -r pip_requirements.txt
can be used to install the dependencies in an existing VirtualEnv.
Currently, there are both a pip_requirements.txt and a requirements.txt file in the project root directory. The requirements.txt file was used with Docker for cloud job submission, but except for the version of JAX, the two files should be identical. Either should be fine for installation.
To train a model, run a command like the following:
python -m src.train --cell_type GRU --emb_size 128 --num_units 256 --dataset imdb
Ordered datasets in this study, Yelp and Amazon reviews, have examples divided into five classes (the number of stars the user left with their review). These datasets can be coarse-grained into three classes (keeping one-star, three-star, and five-star reviews), or two classes (grouping one-star and two-star reviews together, and four-star and five-star reviews together). Categorical datasets in this study, AG News and DBPedia, have 4 and 14 classes, respectively. These can also be reduced to 3 classes. In both cases this is done using the --num_classes flag.
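The coarse-graining of the ordered (star-rating) datasets described above can be illustrated with a minimal sketch. This is not the repo's implementation: the names STAR_TO_CLASS and coarse_grain are hypothetical, examples are assumed to be (text, stars) pairs with stars from 1 to 5, and the sketch assumes that three-star reviews are dropped in the two-class case.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical mapping from star rating (1-5) to class index for each
# supported number of classes. Ratings absent from a mapping are dropped.
STAR_TO_CLASS: Dict[int, Dict[int, int]] = {
    5: {1: 0, 2: 1, 3: 2, 4: 3, 5: 4},  # keep all five ratings
    3: {1: 0, 3: 1, 5: 2},              # keep one-, three-, five-star reviews
    2: {1: 0, 2: 0, 4: 1, 5: 1},        # group 1-2 and 4-5 (assumes 3 is dropped)
}


def coarse_grain(
    examples: Sequence[Tuple[str, int]], num_classes: int
) -> List[Tuple[str, int]]:
    """Relabel (text, stars) pairs, discarding ratings with no target class."""
    if num_classes not in STAR_TO_CLASS:
        raise ValueError(f"num_classes must be one of {sorted(STAR_TO_CLASS)}")
    mapping = STAR_TO_CLASS[num_classes]
    return [(text, mapping[stars]) for text, stars in examples if stars in mapping]


reviews = [("bad", 1), ("meh", 2), ("ok", 3), ("good", 4), ("great", 5)]
three_class = coarse_grain(reviews, 3)
two_class = coarse_grain(reviews, 2)
```

With the toy reviews list above, the three-class version keeps three of the five examples and the two-class version keeps four, so coarse-graining changes the dataset size as well as the labels.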