Re-implementation of the paper 'Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets'
The original paper can be found [here](https://arxiv.org/abs/2201.02177).
All datasets from the original paper's appendix are supported.
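To illustrate what these datasets look like, here is a minimal sketch (not the repo's actual data code) of how one of them, division modulo a prime ("div"), can be enumerated as (x, y, x / y) triples. The operation and the prime 97 come from the paper; the function name and everything else here are illustrative:

```python
# Illustrative sketch only; the actual dataset creation is handled via train.py and --force_data.
# Enumerates the full "div" table x / y (mod p) = x * y^-1 (mod p) for a prime p.
def make_div_table(p: int = 97):
    triples = []
    for x in range(p):
        for y in range(1, p):            # y must be non-zero to be invertible mod p
            y_inv = pow(y, p - 2, p)     # Fermat's little theorem: y^(p-2) == y^-1 (mod p)
            triples.append((x, y, (x * y_inv) % p))
    return triples                       # 97 * 96 (x, y, answer) examples for p = 97

print(len(make_div_table()))             # 9312
```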
The default hyperparameters are from the paper, but they can be adjusted via the command line when running `train.py`.
To run with default settings, simply run `python train.py`.

The first time you train on any dataset, pass `--force_data` so the dataset is created.
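For example (the flag values here are only illustrative; all flags are documented below):

```bash
# First run on modular division data: create the dataset, then train
python train.py --data_name div --num_elements 97 --force_data

# Later runs can reuse the data and override hyperparameters from the command line
python train.py --data_name div --num_elements 97 --lr 1e-3 --batch_size 512 --train_ratio 0.5
```

The available command-line arguments and their defaults: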
- "--lr", type=float, default=1e-3
- "--weight_decay", type=float, default=1
- "--beta1", type=float, default=0.9
- "--beta2", type=float, default=0.98
- "--num_heads", type=int, default=4
- "--layers", type=int, default=2
- "--width", type=int, default=128
- "--data_name", type=str, default="perm", choices=[
- "perm_xy", # permutation composition x * y
- "perm_xyx1", # permutation composition x * y * x^-1
- "perm_xyx", # permutation composition x * y * x
- "plus", # x + y
- "minus", # x - y
- "div", # x / y
- "div_odd", # x / y if y is odd else x - y
- "x2y2", # x^2 + y^2
- "x2xyy2", # x^2 + y^2 + xy
- "x2xyy2x", # x^2 + y^2 + xy + x
- "x3xy", # x^3 + y
- "x3xy2y" # x^3 + xy^2 + y ]
- "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
- "--data_dir", type=str, default="./data"
- "--force_data", action="store_true", help="Whether to force dataset creation."
- "--batch_size", type=int, default=512
- "--steps", type=int, default=10**5
- "--train_ratio", type=float, default=0.5
- "--seed", type=int, default=42
- "--verbose", action="store_true"
- "--log_freq", type=int, default=10
- "--num_workers", type=int, default=4
- "--disable_logging", action="store_true", help="Whether to use wandb logging"
- "--checkpoints", type=int, default=None, nargs="*", help="List of number of steps after which to save model."