Extracting minimal DFA's from well-trained RNN's

Primary language: Python. License: MIT.

RNN to DFA

This project aims to extract a minimal DFA from a well-trained RNN model.

The Tomita Grammars

The Tomita grammars are a set of benchmark grammars widely used in grammar inference. The set contains 7 regular grammars defined over the binary alphabet {0, 1}, as shown below.

Tomita 1    1*
Tomita 2    (10)*
Tomita 3    all strings that do not contain an odd number of consecutive 0's after an odd number of consecutive 1's
Tomita 4    all strings that do not contain 3 consecutive 0's (000)
Tomita 5    all strings with an even number of 0's and an even number of 1's
Tomita 6    all strings satisfying #(0) - #(1) = 3n for some integer n
Tomita 7    1*0*1*0*
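
The seven definitions above can be written as small membership tests. The following is a minimal sketch, not the code in ./tomita/tomita.py; the names `tomita3`, `TOMITA`, and `accepts` are hypothetical, and Tomita 3 is read here as "no maximal run of 1's of odd length is immediately followed by a maximal run of 0's of odd length".

```python
import re
from itertools import groupby


def tomita3(s: str) -> bool:
    """Reject if an odd-length run of 1's is directly followed by an odd-length run of 0's."""
    # Split the string into maximal runs, e.g. "1100" -> [("1", 2), ("0", 2)].
    runs = [(ch, len(list(group))) for ch, group in groupby(s)]
    for (a, n), (b, m) in zip(runs, runs[1:]):
        if a == "1" and b == "0" and n % 2 == 1 and m % 2 == 1:
            return False
    return True


# Hypothetical lookup table: grammar index -> membership predicate.
TOMITA = {
    1: lambda s: re.fullmatch(r"1*", s) is not None,
    2: lambda s: re.fullmatch(r"(10)*", s) is not None,
    3: tomita3,
    4: lambda s: "000" not in s,
    5: lambda s: s.count("0") % 2 == 0 and s.count("1") % 2 == 0,
    6: lambda s: (s.count("0") - s.count("1")) % 3 == 0,
    7: lambda s: re.fullmatch(r"1*0*1*0*", s) is not None,
}


def accepts(grammar: int, s: str) -> bool:
    """Return True if the binary string s belongs to the given Tomita grammar."""
    return TOMITA[grammar](s)
```

For example, `accepts(4, "0010")` is True, while `accepts(4, "10001")` is False because the string contains 000.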

The corresponding minimal DFA's are shown below. Each has at most 5 states. States with a thick border are accepting (ACC) states, while states with a thin border are rejecting (REJ) states. The state marked "S" is the start state of each DFA. From left to right, the DFA's correspond to Tomita 1 through Tomita 7.

[Figure: tomita grammars dfa]
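
The state count of a minimal DFA can be sanity-checked by partition refinement (Moore's algorithm). The sketch below is an illustration under assumptions: the function name `minimal_state_count` and the dictionary encoding of transitions are hypothetical, and the Tomita 4 automaton is written out by hand from its definition rather than taken from the repo.

```python
def minimal_state_count(transitions, start, accepting, alphabet="01"):
    """Count the states of the minimal DFA equivalent to the given complete DFA."""
    # Step 1: keep only the states reachable from the start state.
    reachable, stack = {start}, [start]
    while stack:
        q = stack.pop()
        for a in alphabet:
            nxt = transitions[q][a]
            if nxt not in reachable:
                reachable.add(nxt)
                stack.append(nxt)

    # Step 2: start from the accepting / rejecting split and refine it
    # until no block can be split any further.
    block = {q: int(q in accepting) for q in reachable}
    while True:
        signature = {
            q: (block[q],) + tuple(block[transitions[q][a]] for a in alphabet)
            for q in reachable
        }
        ids = {}
        new_block = {q: ids.setdefault(signature[q], len(ids)) for q in reachable}
        if len(ids) == len(set(block.values())):
            return len(ids)
        block = new_block


# Tomita 4 (no "000"): states 0..2 count trailing 0's, state 3 is the trap state.
tomita4 = {
    0: {"0": 1, "1": 0},
    1: {"0": 2, "1": 0},
    2: {"0": 3, "1": 0},
    3: {"0": 3, "1": 3},
}
```

Calling `minimal_state_count(tomita4, start=0, accepting={0, 1, 2})` returns 4, so this hand-written automaton is already minimal.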

The 7 DFA's are defined in ./tomita/tomita.py. They can classify a given sequence (ACC/REJ) and generate sequences with their corresponding ACC/REJ labels. Because the 7 datasets are fairly large and can be generated automatically, they are excluded from this repo. REMEMBER TO GENERATE THE DATASET BEFORE TRAINING THE RNN MODELS. Run python3 generator.py in the ./tomita/ directory to generate the 7 datasets of the Tomita grammars.
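
To make the classify-and-generate idea concrete, here is a minimal sketch of a DFA that labels sequences and samples a labeled dataset. It is not the code in ./tomita/tomita.py: the class name `DFA`, its methods, the uniform random sampling, and the `seed` parameter are all assumptions made for illustration.

```python
import random


class DFA:
    """A complete DFA over the binary alphabet {0, 1} (hypothetical helper)."""

    def __init__(self, transitions, start, accepting):
        self.transitions = transitions  # state -> {symbol -> next state}
        self.start = start
        self.accepting = set(accepting)

    def classify(self, seq: str) -> bool:
        """Return True (ACC) if the DFA accepts seq, else False (REJ)."""
        state = self.start
        for symbol in seq:
            state = self.transitions[state][symbol]
        return state in self.accepting

    def generate(self, n: int, max_len: int = 20, seed: int = 0):
        """Sample n random binary strings of length <= max_len with their labels."""
        rng = random.Random(seed)
        data = []
        for _ in range(n):
            length = rng.randint(0, max_len)
            seq = "".join(rng.choice("01") for _ in range(length))
            data.append((seq, self.classify(seq)))
        return data


# Tomita 4: reject any string containing "000" (state 3 is the trap state).
TOMITA4 = DFA(
    transitions={
        0: {"0": 1, "1": 0},
        1: {"0": 2, "1": 0},
        2: {"0": 3, "1": 0},
        3: {"0": 3, "1": 3},
    },
    start=0,
    accepting={0, 1, 2},
)
```

`TOMITA4.generate(1000)` would return 1000 `(sequence, label)` pairs. Uniform sampling can produce unbalanced ACC/REJ classes for some grammars, so a real generator may need a different sampling strategy.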

RNNs Trained on The Tomita Grammars

Sequence datasets are generated for all Tomita grammars, with a maximum sequence length of 20. RNN models (RNN/LSTM/GRU) with a hidden-state width of 128 are trained on the datasets. The accuracy results are shown in the table below.

Tomita      # States    RNN Type    Train Acc %    Test Acc %
Tomita 1    2           RNN         ?              ?
Tomita 1    2           GRU         99.9+          100
Tomita 1    2           LSTM        ?              ?
Tomita 2    3           RNN         ?              ?
Tomita 2    3           GRU         99.9+          100
Tomita 2    3           LSTM        ?              ?
Tomita 3    5           RNN         ?              ?
Tomita 3    5           GRU         99.9+          99.9+
Tomita 3    5           LSTM        ?              ?
Tomita 4    4           RNN         ?              ?
Tomita 4    4           GRU         99.9+          99.9+
Tomita 4    4           LSTM        ?              ?
Tomita 5    4           RNN         ?              ?
Tomita 5    4           GRU         83.3           50.1
Tomita 5    4           LSTM        ?              ?
Tomita 6    3           RNN         ?              ?
Tomita 6    3           GRU         66.7           66.7
Tomita 6    3           LSTM        ?              ?
Tomita 7    5           RNN         ?              ?
Tomita 7    5           GRU         99.9+          99.9+
Tomita 7    5           LSTM        ?              ?

Only models whose train and test accuracy are both high enough are regarded as having learned the grammar well. The RNN2DFA extraction is therefore applied only to these models.

Clustering of the Hidden State

Currently, our RNNs do not tend to form clusters in the hidden state as shown in previous papers, which is quite odd. Instead, the hidden states have a "cloud-like" distribution, as shown below.

[Figures: gru_cluster_1, gru_cluster_2, gru_cluster_3, gru_cluster_4]
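
One rough way to check whether a set of hidden-state vectors forms clusters is to run k-means on them and inspect the assignments. The sketch below is a minimal pure-Python k-means. The function name `kmeans`, the toy 2-D points, and the choice of k-means itself are assumptions for illustration, since the text above does not say how the clustering was examined.

```python
import random
from math import dist


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's k-means on a list of equal-length tuples (hypothetical helper)."""
    rng = random.Random(seed)
    centers = rng.sample(points, k)  # initialise centers from the data

    def nearest(p):
        # Index of the center closest to point p.
        return min(range(k), key=lambda c: dist(p, centers[c]))

    for _ in range(iters):
        labels = [nearest(p) for p in points]
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old center if a cluster becomes empty
                centers[c] = tuple(sum(xs) / len(members) for xs in zip(*members))

    return centers, [nearest(p) for p in points]


# Toy stand-in for hidden states: two well-separated groups of 2-D points.
points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
centers, labels = kmeans(points, k=2)
```

On this toy input the first two points share one label and the last two share the other. For real 128-dimensional hidden states, a "cloud-like" distribution would show up as cluster assignments that change noticeably with `k` or `seed`.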