# Re-implementation of the paper ‘Grokking: Generalization beyond overfitting on small algorithmic datasets’

## Paper

The original paper can be found here.

## Datasets

I'm not entirely sure how the paper defines its division operation, so I am using integer division:

- $$x \circ y = (x \,//\, y) \bmod p$$, for some prime $$p$$ and $$0 \leq x, y < p$$
- $$x \circ y = (x \,//\, y) \bmod p$$ if $$y$$ is odd, else $$(x - y) \bmod p$$, for some prime $$p$$ and $$0 \leq x, y < p$$
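The two operations above can be sketched as follows (a minimal sketch; the function names are mine, and `y = 0` is excluded to avoid division by zero):

```python
def div_op(x: int, y: int, p: int) -> int:
    """(x // y) mod p, using integer (floor) division; assumes 1 <= y < p."""
    return (x // y) % p

def div_odd_op(x: int, y: int, p: int) -> int:
    """(x // y) mod p when y is odd, otherwise (x - y) mod p."""
    return (x // y) % p if y % 2 == 1 else (x - y) % p

if __name__ == "__main__":
    p = 5
    print(div_op(4, 3, p))      # 4 // 3 = 1, then 1 mod 5
    print(div_odd_op(2, 4, p))  # y is even, so (2 - 4) mod 5
```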

## Hyperparameters

The default hyperparameters are taken from the paper, but they can be adjusted via the command line when running `train.py`.

## Running experiments

To run with default settings, simply run `python train.py`.
The first time you train on any dataset, you must also pass `--force_data` to create the dataset files.

### optimizer args

- `--lr`, type=float, default=1e-3
- `--weight_decay`, type=float, default=1
- `--beta1`, type=float, default=0.9
- `--beta2`, type=float, default=0.98

### model args

- `--layers`, type=int, default=2
- `--width`, type=int, default=128

### data args

- `--data_name`, type=str, default="perm", choices:
  - `perm_xy`: permutation composition x * y
  - `perm_xyx1`: permutation composition x * y * x^-1
  - `perm_xyx`: permutation composition x * y * x
  - `plus`: x + y
  - `minus`: x - y
  - `div`: x / y
  - `div_odd`: x / y if y is odd, else x - y
  - `x2y2`: x^2 + y^2
  - `x2xyy2`: x^2 + y^2 + xy
  - `x2xyy2x`: x^2 + y^2 + xy + x
  - `x3xy`: x^3 + y
  - `x3xy2y`: x^3 + xy^2 + y
- `--num_elements`, type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
- `--data_dir`, type=str, default="./data"
- `--force_data`, action="store_true", help="Whether to force dataset creation."
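For reference, the arithmetic `--data_name` choices correspond to simple binary operations over the integers mod p. A sketch of that mapping (the dictionary and its name are mine, not the repo's actual code; the permutation datasets are omitted since they act on permutations of n elements rather than on integers mod p):

```python
# Hypothetical mapping from arithmetic --data_name values to the binary
# operation each denotes, all reduced mod p. "div" uses integer division
# (as described in the Datasets section) and assumes y != 0.
ARITHMETIC_OPS = {
    "plus":    lambda x, y, p: (x + y) % p,
    "minus":   lambda x, y, p: (x - y) % p,
    "div":     lambda x, y, p: (x // y) % p,
    "div_odd": lambda x, y, p: (x // y) % p if y % 2 == 1 else (x - y) % p,
    "x2y2":    lambda x, y, p: (x**2 + y**2) % p,
    "x2xyy2":  lambda x, y, p: (x**2 + y**2 + x * y) % p,
    "x2xyy2x": lambda x, y, p: (x**2 + y**2 + x * y + x) % p,
    "x3xy":    lambda x, y, p: (x**3 + y) % p,
    "x3xy2y":  lambda x, y, p: (x**3 + x * y**2 + y) % p,
}
```

With `--num_elements 97`, for example, `ARITHMETIC_OPS["plus"](96, 5, 97)` wraps around to 4.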

### training args

- `--batch_size`, type=int, default=512
- `--steps`, type=int, default=10**5
- `--train_ratio`, type=float, default=0.5
- `--seed`, type=int, default=42
- `--verbose`, action="store_true"
- `--log_freq`, type=int, default=10
- `--num_workers`, type=int, default=4
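Taken together, the flags above suggest an argument parser along these lines (a sketch of what `train.py`'s CLI might look like, assembled from the listed defaults, not the actual implementation):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Assemble the CLI described in the README; a sketch, not the repo's code."""
    parser = argparse.ArgumentParser(description="Grokking re-implementation")
    # optimizer args
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    # model args
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--width", type=int, default=128)
    # data args
    parser.add_argument("--data_name", type=str, default="perm",
                        choices=["perm_xy", "perm_xyx1", "perm_xyx",
                                 "plus", "minus", "div", "div_odd",
                                 "x2y2", "x2xyy2", "x2xyy2x",
                                 "x3xy", "x3xy2y"])
    parser.add_argument("--num_elements", type=int, default=5)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--force_data", action="store_true",
                        help="Whether to force dataset creation.")
    # training args
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--steps", type=int, default=10**5)
    parser.add_argument("--train_ratio", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--log_freq", type=int, default=10)
    parser.add_argument("--num_workers", type=int, default=4)
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
```

For example, `python train.py --data_name div --num_elements 97 --force_data` would select the integer-division dataset over the first 97 integers and force dataset creation.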
