# Re-implementation of the paper ‘Grokking: Generalization beyond overfitting on small algorithmic datasets’

## Paper

Original paper can be found here

## Datasets

I’m not super clear on how they defined their division. I am using integer division:

- $$x\circ y = (x // y) \bmod p$$, for some prime $$p$$ and $$0\leq x,y \leq p$$
- $$x\circ y = (x // y) \bmod p$$ if $$y$$ is odd, else $$(x - y) \bmod p$$, for some prime $$p$$ and $$0\leq x,y \leq p$$
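As a minimal sketch of the two definitions above (not the repository's actual code), the operations could be written as follows. The function names, the `build_table` helper, and the choice to iterate over residues `0..p-1` and skip `y = 0` by default are all assumptions made for illustration.

```python
def div_op(x: int, y: int, p: int) -> int:
    """Integer (floor) division, then reduction mod p."""
    return (x // y) % p


def div_odd_op(x: int, y: int, p: int) -> int:
    """Integer division if y is odd, subtraction otherwise; both mod p."""
    if y % 2 == 1:
        return (x // y) % p
    return (x - y) % p


def build_table(op, p: int, y_min: int = 1):
    """Enumerate (x, y, x∘y) triples.

    Assumption: y starts at y_min = 1 so that x // y never divides by zero.
    """
    return [(x, y, op(x, y, p)) for x in range(p) for y in range(y_min, p)]


table = build_table(div_op, 97)
```

Note that Python's `%` returns a non-negative result for a positive modulus, so `(x - y) % p` stays in the range `0..p-1` even when `x < y`.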

## Hyperparameters

The default hyperparameters are taken from the paper, but they can be overridden on the command line when running `train.py`.
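For illustration, command-line overrides of this kind are typically wired up with `argparse`. The sketch below is an assumption about how `train.py` might declare a few of its flags, not a copy of it; `build_parser` is a hypothetical helper, and only the optimizer and model arguments are shown.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser covering the optimizer and model arguments."""
    parser = argparse.ArgumentParser(description="sketch of a train.py-style CLI")
    # optimizer args
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    # model args
    parser.add_argument("--num_heads", type=int, default=4)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--width", type=int, default=128)
    return parser


# Equivalent to running: python train.py --lr 3e-4 --width 256
args = build_parser().parse_args(["--lr", "3e-4", "--width", "256"])
```

Any flag that is not passed keeps its default, so `args.beta2` is still `0.98` here.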

## Running experiments

To run with default settings, simply run `python train.py`.

The first time you train on any dataset you have to specify `--force_data`.

### Arguments:

### optimizer args

- `"--lr", type=float, default=1e-3`
- `"--weight_decay", type=float, default=1`
- `"--beta1", type=float, default=0.9`
- `"--beta2", type=float, default=0.98`

### model args

- `"--num_heads", type=int, default=4`
- `"--layers", type=int, default=2`
- `"--width", type=int, default=128`

### data args

- `"--data_name", type=str, default="perm"`, with choices:
  - `"perm_xy"`: permutation composition `x * y`
  - `"perm_xyx1"`: permutation composition `x * y * x^-1`
  - `"perm_xyx"`: permutation composition `x * y * x`
  - `"plus"`: `x + y`
  - `"minus"`: `x - y`
  - `"div"`: `x / y`
  - `"div_odd"`: `x / y` if `y` is odd, else `x - y`
  - `"x2y2"`: `x^2 + y^2`
  - `"x2xyy2"`: `x^2 + y^2 + xy`
  - `"x2xyy2x"`: `x^2 + y^2 + xy + x`
  - `"x3xy"`: `x^3 + y`
  - `"x3xy2y"`: `x^3 + xy^2 + y`
- `"--num_elements", type=int, default=5` (choose 5 for permutation data, 97 for arithmetic data)
- `"--data_dir", type=str, default="./data"`
- `"--force_data", action="store_true", help="Whether to force dataset creation."`
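To make the three permutation datasets concrete, here is a hedged sketch of the underlying operations for `--num_elements 5`. It is not taken from the repository: the tuple representation, the helper names `compose` and `inverse`, the `PERM_OPS` mapping, and the convention that `x * y` means "apply `y` first, then `x`" are all assumptions.

```python
from itertools import permutations


def compose(x, y):
    """Return x * y, read here as 'apply y, then x' (assumed convention)."""
    return tuple(x[i] for i in y)


def inverse(x):
    """Return the permutation that undoes x."""
    inv = [0] * len(x)
    for i, xi in enumerate(x):
        inv[xi] = i
    return tuple(inv)


# Hypothetical mapping from dataset name to binary operation.
PERM_OPS = {
    "perm_xy": lambda x, y: compose(x, y),
    "perm_xyx1": lambda x, y: compose(compose(x, y), inverse(x)),
    "perm_xyx": lambda x, y: compose(compose(x, y), x),
}

# All 5! = 120 permutations of five elements.
elements = list(permutations(range(5)))
```

With five elements there are 120 permutations, so each permutation dataset has 120 × 120 = 14,400 possible `(x, y)` pairs before the train/validation split.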

### training args

- `"--batch_size", type=int, default=512`
- `"--steps", type=int, default=10**5`
- `"--train_ratio", type=float, default=0.5`
- `"--seed", type=int, default=42`
- `"--verbose", action="store_true"`
- `"--log_freq", type=int, default=10`
- `"--num_workers", type=int, default=4`
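One plausible reading of `--train_ratio` and `--seed` is a seeded shuffle of all `(x, y)` pairs followed by a cut at the given fraction. The sketch below illustrates that reading only; `split_examples` is a hypothetical helper, and the repository may split its data differently.

```python
import random


def split_examples(examples, train_ratio=0.5, seed=42):
    """Shuffle with a fixed seed, then cut into train and validation parts."""
    rng = random.Random(seed)  # local generator, leaves global state untouched
    shuffled = list(examples)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    return shuffled[:n_train], shuffled[n_train:]


# All operand pairs for an arithmetic dataset with 97 elements.
pairs = [(x, y) for x in range(97) for y in range(97)]
train, val = split_examples(pairs)
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible across runs with the same seed, which matters when comparing training curves between hyperparameter settings.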