- add a config hyperparameter to control the number of hidden nodes, allowing comparison of roughly equivalent-size models with different receptive field sizes
- make controllable metric target bounds depend on the actual map shape when randomizing map shape per-episode
- eval on larger maps (32x32)
  - eval on all possible map shapes within a larger square map?
- train on larger maps (32x32 if possible)
- sparse reward (to compare across action shapes; might also help with training on larger maps)
- optimize pathfinding (jax.lax.conv)
- new domains (treasure... more keys/doors?) and representations (turtle; re-implement the O.G. wide model and compare against NCA, FractalNet, ...)
- make enemies chase the agent (when the agent is in "line of sight", move toward the player by 1 tile every 2 timesteps); add combat mechanics
```
pip install -r requirements.txt
```
Then install jax:
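The appropriate command depends on your hardware, so consult the official JAX installation instructions for your platform. As a minimal sketch, a CPU-only install looks like:

```
# CPU-only JAX; for GPU/TPU support, follow the platform-specific
# instructions in the JAX docs (e.g., pip install -U "jax[cuda12]")
pip install -U jax
```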
To train a model, run:
```
python train.py
```
Arguments (pass these by running, e.g., `python train.py overwrite=True`):
- `overwrite`, bool, default=False: whether to overwrite the model if it already exists.
- `render_freq`, int, default=100: how often to render the environment.
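For example, to start a fresh training run that renders more frequently (the particular values here are illustrative):

```
# overwrite any existing checkpoint and render every 50 updates
python train.py overwrite=True render_freq=50
```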
During training, we render a few episodes (every `render_freq` updates) to see how the model is doing. We use the same random seeds when resetting the environment, so that initial level layouts are the same between rounds of rendering.
To train a sweep of models, run:
```
python sweep.py
```
This will perform grid searches over the groups of hyperparameters defined in `hypers`.
Arguments:
- `mode`, string, default=train: what type of jobs to launch while sweeping. Options are:
  - `train`: trains the model for each experiment, attempting to re-load existing checkpoints by default.
  - `eval`: evaluates each model in the sweep, given the same environment parameters as were seen during training.
  - `eval_cp`: evaluates each model over a range of permitted change percentages.
  - `plot`: plots the results of the sweep.
- `slurm`, bool, default=True: whether to submit each job in the sweep to a SLURM cluster (using the submitit package).
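For example, to evaluate every model in a sweep locally rather than submitting jobs to a cluster:

```
# run evaluation in the current process instead of going through SLURM
python sweep.py mode=eval slurm=False
```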
To save a `misc_stats.json` that records the number of timesteps for which a given model has trained, we hackishly run `python sweep.py mode=plot slurm=False` (we get this info from the last row of the `progress.csv` used for plotting). Other stats are recorded when running with `mode=eval` or the like.