Gradient-based learning drives robust representations in recurrent neural networks by balancing compression and expansion
Code developed by Matthew Farrell.
The original version can be found on Code Ocean.
This repository contains code to generate the figures for the manuscript "Gradient-based learning drives robust representations in recurrent neural networks by balancing compression and expansion", published in Nature Machine Intelligence.
The script figures_for_manuscript.py contains code to call plotting functions that generate the main figures in the paper, as well as Extended Data Figures 1-4, 6, and 10.
Many of the plots for Figures 5 and 6 take a long time to run, so the function calls that generate them are not made by default. To produce specific plots of interest, change which routines are called at the bottom of figures_for_manuscript.py. When running outside of Code Ocean, you can set n_cores in figures_for_manuscript.py to a value greater than 1 to parallelize the runs. An available GPU is used automatically; to change this, edit "device" in initialize_and_train.py.
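As an illustration of how an n_cores setting can fan independent training runs out to workers, here is a minimal sketch. The helper run_all and the train_fn callable are hypothetical stand-ins, not functions from figures_for_manuscript.py, and a thread pool is used only for simplicity; the repository may parallelize its runs differently.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence, TypeVar

P = TypeVar("P")
R = TypeVar("R")


def run_all(param_sets: Sequence[P], train_fn: Callable[[P], R], n_cores: int = 1) -> list[R]:
    """Hypothetical helper: train one model per parameter set.

    With n_cores <= 1 the runs execute serially in the calling thread;
    otherwise they are distributed over a pool of n_cores worker threads.
    Results are returned in the same order as param_sets.
    """
    if n_cores <= 1:
        return [train_fn(p) for p in param_sets]
    with ThreadPoolExecutor(max_workers=n_cores) as pool:
        # pool.map preserves input order regardless of completion order.
        return list(pool.map(train_fn, param_sets))
```

Keeping a serial path for n_cores=1 matters here, since some caching only behaves correctly when training runs in a single thread.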
To run a faster version of the simulations for the main manuscript figures, make sure that FAST_RUN is set to True at the top of figures_for_manuscript.py (it is True by default). This runs five realizations of the network for each set of hyperparameters. To run simulations closer to those used for the plots in the paper, set FAST_RUN to False. This trains 10 networks for every set of hyperparameters, each with a different random number generator seed, and the variation across these networks is used to plot error bars. While this is sufficient to reproduce the results of the paper, the error bars in the paper use 30 network realizations. To match this, change the line "seeds = list(range(10))" to "seeds = list(range(30))" in figures_for_manuscript.py.
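The relationship between FAST_RUN, the seed list, and the number of network realizations can be summarized in a small sketch. The function realization_seeds is hypothetical (the repository sets these values directly in figures_for_manuscript.py), and the fast-run seed values 0 to 4 are an assumption; only the count of five comes from the text above.

```python
FAST_RUN = True  # mirrors the flag at the top of figures_for_manuscript.py


def realization_seeds(fast_run: bool, n_full: int = 10) -> list[int]:
    """Hypothetical helper returning one RNG seed per network realization.

    fast_run=True gives five realizations for the quick run; otherwise
    n_full seeds are used (10 by default, 30 to match the paper's error bars).
    """
    return list(range(5)) if fast_run else list(range(n_full))


seeds = realization_seeds(FAST_RUN)
```

Each seed yields an independently initialized and trained network, so the number of seeds is the sample size behind every error bar.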
In the paper, the error bars for plots are generated by a custom modification made to the seaborn plotting library (see the paper for details about these error bars).
When the neural networks are trained, the weights over the course of training are saved in the directory data/output. The CSV file data/output/output_table.csv maps each model output directory to the set of parameters used to train the model (i.e. the parameters passed to initialize_and_train.py). This mapping of parameters to save locations is handled by model_output_manager.py. The next time initialize_and_train.py is called, it automatically checks for a previous run that matches the given set of parameters and loads that model if it finds one. To disable this behavior, either keep the directory "data/output/" empty or set rerun=True in the parameters passed to initialize_and_train.py. The variable CLEAR_PREVIOUS_RUNS at the top of figures_for_manuscript.py can also be set to True, which clears out the "data/output/" folder before training. Code Ocean automatically removes all saved data after a run finishes, but if the code is run outside of Code Ocean, the saved runs remain in data/output.
An additional level of memoization is implemented by applying the joblib Memory class to the plotting routines in plots.py. This caching only seems to work when a single thread is used for training. One strategy is to first train with multiple cores/threads (n_cores > 1) and then run the script again with n_cores=1, which loads the previous runs and also caches the outputs of the plotting routines.
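The parameter-to-directory lookup described above can be illustrated with a toy version. This is not the implementation or file format of model_output_manager.py: the function find_or_register_run, the column names run_dir and params, the JSON encoding of parameters, and the run_N folder naming are all assumptions made for the sketch.

```python
import csv
import json
from pathlib import Path


def find_or_register_run(params: dict, output_dir: str, rerun: bool = False) -> tuple[Path, bool]:
    """Hypothetical lookup: return (run_dir, found_previous_run).

    A previous run matches when its stored parameters equal `params`.
    With rerun=True a match is reported as not found, so the caller
    retrains and overwrites the same directory.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    table = out / "output_table.csv"
    table_exists = table.exists()
    # Canonical, order-independent encoding of the parameter set.
    key = json.dumps(params, sort_keys=True)
    rows = []
    if table_exists:
        with table.open(newline="") as f:
            rows = list(csv.DictReader(f))
    for row in rows:
        if row["params"] == key:
            return out / row["run_dir"], not rerun
    # No match: allocate a fresh directory and record it in the table.
    run_dir = out / f"run_{len(rows)}"
    run_dir.mkdir(exist_ok=True)
    with table.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["run_dir", "params"])
        if not table_exists:
            writer.writeheader()
        writer.writerow({"run_dir": run_dir.name, "params": key})
    return run_dir, False
```

Sorting keys before serializing means the same hyperparameters passed in a different order still map to the same saved run.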
A variety of hyperparameter options can be set, including using the layers of a feedforward network instead of the timesteps of an RNN, or using other RNN models (the parameter 'network' specifies this). These options have not been exhaustively tested, and exceptions may be thrown if they are changed.
The parameter 'g_radius' in the code controls the level of chaos of the initialized RNN, and corresponds with the parameter "α" in the manuscript.
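To see why a single scalar can set the level of chaos, assume the common convention in which the initial recurrent weights are drawn i.i.d. Gaussian with variance g_radius²/N (an assumption; check initialize_and_train.py for the exact scheme). Under that convention g_radius is approximately the spectral radius of the initial weight matrix, and in large random tanh networks a spectral radius above 1 is the usual marker of chaotic dynamics. The hypothetical helpers below check the spectral-radius relationship numerically.

```python
import numpy as np


def init_recurrent_weights(n: int, g_radius: float, seed: int = 0) -> np.ndarray:
    """Draw an n x n matrix with i.i.d. N(0, g_radius**2 / n) entries.

    Hypothetical initializer; by the circular law its eigenvalues fill a
    disk of radius about g_radius for large n.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, g_radius / np.sqrt(n), size=(n, n))


def spectral_radius(w: np.ndarray) -> float:
    """Largest eigenvalue magnitude of a square matrix."""
    return float(np.max(np.abs(np.linalg.eigvals(w))))


for g in (0.5, 1.0, 20.0):
    w = init_recurrent_weights(400, g)
    print(f"g_radius={g}: spectral radius ~ {spectral_radius(w):.2f}")
```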
The figures for the PCA snapshots are labelled by timepoint/layer, so "Xdim_200_snapshot_0_tanh_cce" is the snapshot at time t=0, and "Xdim_200_snapshot_-1_tanh_cce" is the PCA plot of the inputs.
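The snapshot index can also be read off these file names programmatically. The helper snapshot_label is hypothetical and relies only on the snapshot_<index> pattern visible in the two example names; it makes no assumption about the meaning of the other fields.

```python
import re

# Matches the signed integer that follows "_snapshot_" in a file name.
_SNAPSHOT = re.compile(r"_snapshot_(-?\d+)_")


def snapshot_label(name: str) -> str:
    """Map a PCA snapshot file name to a human-readable label.

    Index -1 denotes the inputs; any other index is the timestep t
    (or the layer, for feedforward networks).
    """
    match = _SNAPSHOT.search(name)
    if match is None:
        raise ValueError(f"no snapshot index in {name!r}")
    index = int(match.group(1))
    return "inputs" if index == -1 else f"t={index}"
```

For example, snapshot_label("Xdim_200_snapshot_-1_tanh_cce") returns "inputs".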
While efforts have been made to avoid collisions between cores/threads, errors can occasionally arise when parallelizing runs, and their source has not been identified. It may help to erase both the output folder and the entry in data/output/output_table.csv corresponding to the run that causes the error, and then run the parallelized simulations again.
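A hypothetical cleanup helper for this recovery step is sketched below. It assumes the table has a column named run_dir holding each run's folder name relative to data/output; that column name is an assumption, so inspect output_table.csv before adapting it.

```python
import csv
import shutil
from pathlib import Path


def remove_run(output_dir: str, run_dir: str, column: str = "run_dir") -> bool:
    """Hypothetical helper: delete one run's folder and its table row.

    Returns True if a matching row in output_table.csv was removed.
    """
    out = Path(output_dir)
    table = out / "output_table.csv"
    with table.open(newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        fieldnames = reader.fieldnames
    if fieldnames is None:  # empty table: nothing to remove
        return False
    kept = [row for row in rows if row[column] != run_dir]
    # Rewrite the table without the offending entry.
    with table.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(kept)
    # Remove the saved weights for that run, if the folder exists.
    shutil.rmtree(out / run_dir, ignore_errors=True)
    return len(kept) < len(rows)
```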