Re-implementation of the paper ‘Grokking: Generalization beyond overfitting on small algorithmic datasets’

Paper

Original paper can be found here

Datasets

It is not clear to me how the paper defines division, so I am using integer division:

  • $$x \circ y = (x \mathbin{//} y) \bmod p$$, for some prime $$p$$ and $$0 \leq x, y \leq p$$
  • $$x \circ y = (x \mathbin{//} y) \bmod p$$ if $$y$$ is odd, else $$(x - y) \bmod p$$, for some prime $$p$$ and $$0 \leq x, y \leq p$$
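
The two definitions above can be sketched in a few lines of Python. This is only an illustration, not the code from this repository: the function names `div`, `div_odd`, and `build_table` are hypothetical, and starting `y` at 1 is an assumption made here so that the sketch never divides by zero.

```python
def div(x: int, y: int, p: int) -> int:
    """Integer (floor) division, then reduction mod p."""
    return (x // y) % p


def div_odd(x: int, y: int, p: int) -> int:
    """Integer division when y is odd, subtraction otherwise, both mod p."""
    if y % 2 == 1:
        return (x // y) % p
    return (x - y) % p


def build_table(op, p: int):
    """List every (x, y, x∘y) triple; y starts at 1 (assumption) to avoid y = 0."""
    return [(x, y, op(x, y, p)) for x in range(p) for y in range(1, p)]


# p = 97 matches the value suggested for arithmetic data in the arguments list.
table = build_table(div, 97)
```

Python's `%` always returns a non-negative result for a positive modulus, so `(x - y) % p` stays in the range 0 to p - 1 even when `x < y`.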

Hyperparameters

The default hyperparameters are taken from the paper and can be overridden on the command line when running train.py.

Running experiments

To run with default settings, simply run python train.py.
The first time you train on any dataset you have to specify --force_data.

Arguments:

optimizer args

  • "--lr", type=float, default=1e-3
  • "--weight_decay", type=float, default=1
  • "--beta1", type=float, default=0.9
  • "--beta2", type=float, default=0.98

model args

  • "--num_heads", type=int, default=4
  • "--layers", type=int, default=2
  • "--width", type=int, default=128

data args

  • "--data_name", type=str, default="perm", choices=[
    • "perm_xy", # permutation composition x * y
    • "perm_xyx1", # permutation composition x * y * x^-1
    • "perm_xyx", # permutation composition x * y * x
    • "plus", # x + y
    • "minus", # x - y
    • "div", # x / y
    • "div_odd", # x / y if y is odd else x - y
    • "x2y2", # x^2 + y^2
    • "x2xyy2", # x^2 + y^2 + xy
    • "x2xyy2x", # x^2 + y^2 + xy + x
    • "x3xy", # x^3 + y
    • "x3xy2y" # x^3 + xy^2 + y
      ]
  • "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
  • "--data_dir", type=str, default="./data"
  • "--force_data", action="store_true", help="Whether to force dataset creation."
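
To make the three permutation datasets concrete, here is a minimal sketch of how the operations could be computed for `num_elements=5`. It is not taken from this repository: the helper names `compose`, `inverse`, and `OPS` are hypothetical, and the composition order (apply the right-hand permutation first) is an assumption.

```python
from itertools import permutations


def compose(a, b):
    """Return a * b under the assumed convention (a * b)[i] = a[b[i]]."""
    return tuple(a[i] for i in b)


def inverse(a):
    """Return the permutation that undoes a."""
    inv = [0] * len(a)
    for i, ai in enumerate(a):
        inv[ai] = i
    return tuple(inv)


# Hypothetical mapping from dataset name to binary operation.
OPS = {
    "perm_xy": lambda x, y: compose(x, y),
    "perm_xyx1": lambda x, y: compose(compose(x, y), inverse(x)),
    "perm_xyx": lambda x, y: compose(compose(x, y), x),
}

# All permutations of 5 elements: 5! = 120 group elements.
elements = list(permutations(range(5)))

# Full operation table for "perm_xy": one row per ordered pair (x, y).
table = [(x, y, OPS["perm_xy"](x, y)) for x in elements for y in elements]
```

With 120 elements the full table has 120 × 120 = 14400 rows, of which a fraction set by `--train_ratio` would be used for training.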

training args

  • "--batch_size", type=int, default=512
  • "--steps", type=int, default=10**5
  • "--train_ratio", type=float, default=0.5
  • "--seed", type=int, default=42
  • "--verbose", action="store_true"
  • "--log_freq", type=int, default=10
  • "--num_workers", type=int, default=4
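
The flags above follow the standard `argparse` pattern. As a rough sketch of how the training arguments could be declared and overridden, assuming nothing about the actual contents of train.py (the function name `build_parser` is hypothetical):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Declare the training flags with the defaults listed above."""
    parser = argparse.ArgumentParser(description="Training arguments (sketch)")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--steps", type=int, default=10**5)
    parser.add_argument("--train_ratio", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--log_freq", type=int, default=10)
    parser.add_argument("--num_workers", type=int, default=4)
    return parser


# Passing an explicit list stands in for the command line here.
args = build_parser().parse_args(["--steps", "1000", "--verbose"])
print(args.steps, args.verbose, args.batch_size)
```

On the command line, the same override would be written as python train.py --steps 1000 --verbose; any flag that is not given keeps its default.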
