FitVid Video Prediction Model

Implementation of the FitVid video prediction model in JAX/Flax.

If you find this code useful, please cite it in your paper:

  title={FitVid: Overfitting in Pixel-Level Video Prediction},
  author={Babaeizadeh, Mohammad and Saffar, Mohammad Taghi and Nair, Suraj and Levine, Sergey and Finn, Chelsea and Erhan, Dumitru},
  journal={arXiv preprint arXiv:2106.13195},


FitVid is a new architecture for conditional variational video prediction. It has ~300 million parameters and can be trained with minimal training tricks.
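"Conditional variational" means each predicted frame is conditioned on past frames and on a sampled latent variable, trained with a reconstruction loss plus a KL penalty. The JAX sketch below shows two generic building blocks of such models: drawing a latent with the reparameterization trick and computing the KL divergence between diagonal Gaussians. It is not FitVid's code; the function names, the MSE reconstruction term and the `beta` weight are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def reparameterize(rng, mean, logvar):
    """Samples z ~ N(mean, exp(logvar)) so gradients flow through mean and logvar."""
    eps = jax.random.normal(rng, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def kl_diag_gaussians(mean_q, logvar_q, mean_p, logvar_p):
    """KL(q || p) between diagonal Gaussians, summed over the last axis."""
    kl = 0.5 * (
        logvar_p
        - logvar_q
        + (jnp.exp(logvar_q) + (mean_q - mean_p) ** 2) / jnp.exp(logvar_p)
        - 1.0
    )
    return jnp.sum(kl, axis=-1)


def variational_loss(recon, target, mean_q, logvar_q, mean_p, logvar_p, beta=1.0):
    """Hypothetical objective: MSE reconstruction plus beta-weighted KL (posterior vs. prior)."""
    mse = jnp.mean((recon - target) ** 2)
    kl = jnp.mean(kl_diag_gaussians(mean_q, logvar_q, mean_p, logvar_p))
    return mse + beta * kl
```

In a training loop, the posterior network would produce `mean_q`/`logvar_q` for `reparameterize`, the decoder would turn the sampled latent into `recon`, and `beta` trades off reconstruction quality against how closely the posterior tracks the prior.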

Sample Videos

Human3.6M:
  https://i.imgur.com/y621cvE.gif
  https://i.imgur.com/yMHkqoh.gif

RoboNet:
  https://i.imgur.com/KsZDnh0.gif
  https://i.imgur.com/fOYPNMx.gif

For more samples, please visit the FitVid website:


Get dependencies:

pip3 install --user tensorflow
pip3 install --user tensorflow_addons
pip3 install --user flax
pip3 install --user ffmpeg
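After installing, a quick stdlib-only check can confirm that the Python packages are visible to the interpreter you will train with. The module names below are the import names of the packages above, plus `jax`, which `flax` installs as a dependency. The `ffmpeg` entry is omitted because this sketch does not assume how the code uses it, and the helper name `check_dependencies` is hypothetical.

```python
import importlib.util

# Import names of the packages installed above; jax comes in as a dependency of flax.
REQUIRED_MODULES = ("tensorflow", "tensorflow_addons", "flax", "jax")


def check_dependencies(modules=REQUIRED_MODULES):
    """Returns {module_name: bool}, locating top-level modules without executing them."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_dependencies().items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
```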

Train on RoboNet:

python -m fitvid.train --output_dir /tmp/output
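If you launch several runs, giving each its own output directory keeps them from writing into the same folder. The sketch below builds the same `python -m fitvid.train --output_dir ...` command from Python; only the `--output_dir` flag comes from this README, while the helper names, the timestamped `run_name` and the `/tmp/output` root are assumptions.

```python
import datetime
import pathlib
import subprocess
import sys


def build_train_command(output_root="/tmp/output", run_name=None):
    """Hypothetical helper: argv for `python -m fitvid.train` with a per-run output dir."""
    if run_name is None:
        # Default to a timestamp so repeated runs land in separate directories.
        run_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_dir = pathlib.Path(output_root) / run_name
    # sys.executable ensures the same interpreter (and installed packages) is used.
    return [sys.executable, "-m", "fitvid.train", "--output_dir", str(output_dir)]


def launch(cmd):
    """Runs the command, raising subprocess.CalledProcessError on a non-zero exit."""
    subprocess.run(cmd, check=True)
```

For example, `launch(build_train_command(run_name="robonet_baseline"))` would train with outputs under `/tmp/output/robonet_baseline`.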

Disclaimer: Not an official Google product.