JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"


Model-based reinforcement learning typically trains the dynamics and reward functions by minimizing prediction error. However, prediction error is only a proxy for the agent's ultimate goal of maximizing the sum of rewards, leading to an objective mismatch. We propose an end-to-end algorithm called Optimal Model Design (OMD) that optimizes the returns directly for model learning. OMD leverages the implicit function theorem to optimize the model parameters; the resulting computational graph is illustrated in the paper.
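In brief, OMD treats the solution of an inner problem as an implicit function of the model parameters and differentiates the outer objective through it. A minimal sketch of this implicit-gradient pattern in JAX, using a toy 1-D fixed point rather than the paper's full Q-learning pipeline (all function names here are illustrative):

```python
# Sketch of the implicit-function-theorem gradient underlying OMD,
# on a toy 1-D fixed point (not the paper's Q-iteration pipeline).
import jax
import jax.numpy as jnp

def inner_residual(x, theta):
    # Inner problem: find x*(theta) with F(x*, theta) = 0.
    # Toy choice: x* is the fixed point of x = tanh(theta * x) + 0.5.
    return x - jnp.tanh(theta * x) - 0.5

def solve_inner(theta, iters=100):
    # Fixed-point iteration; in OMD this role is played by computing
    # the action-value function under the learned model.
    x = 0.0
    for _ in range(iters):
        x = jnp.tanh(theta * x) + 0.5
    return x

def outer_loss(x):
    # Outer objective evaluated at the inner solution; in OMD this is
    # the (negative) return of the induced policy.
    return (x - 2.0) ** 2

def outer_grad(theta):
    x_star = solve_inner(theta)
    dF_dx = jax.grad(inner_residual, argnums=0)(x_star, theta)
    dF_dtheta = jax.grad(inner_residual, argnums=1)(x_star, theta)
    # Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} dF/dtheta,
    # so no backprop through the inner solver is needed.
    dx_dtheta = -dF_dtheta / dF_dx
    return jax.grad(outer_loss)(x_star) * dx_dtheta

theta = 0.3
g = outer_grad(theta)
# Sanity check against a finite-difference estimate.
eps = 1e-4
fd = (outer_loss(solve_inner(theta + eps))
      - outer_loss(solve_inner(theta - eps))) / (2 * eps)
print(g, fd)  # the two estimates should agree closely
```

The key point is that the gradient is obtained from the residual at the solution, so the inner solver can be any black-box iteration.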


Please cite our work if you find it useful in your research:

@article{nikishin2021control,
  title={Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation},
  author={Nikishin, Evgenii and Abachi, Romina and Agarwal, Rishabh and Bacon, Pierre-Luc},
  journal={arXiv preprint arXiv:2106.03273},
  year={2021}
}


We assume that you use Python 3. To install the necessary dependencies, run the following commands:

1. virtualenv ~/env_omd
2. source ~/env_omd/bin/activate
3. pip install -r requirements.txt

To use JAX with GPU, follow the official instructions.
To install MuJoCo, check the instructions.


For historical reasons, the code is divided into three parts: tabular, CartPole, and MuJoCo.


All results for the tabular experiments can be reproduced by running the tabular.ipynb notebook.

To open the notebook in Google Colab, use this link.


To train the OMD agent on CartPole, use the following commands:

cd cartpole
python train.py --agent_type omd

We also provide implementations of the corresponding maximum likelihood estimation (MLE) and value equivalence principle (VEP) baselines. To train these agents, change the --agent_type flag to mle or vep.


To train the OMD agent on MuJoCo HalfCheetah-v2, use the following commands:

cd mujoco
python train.py --config.algo=omd

To train the MLE baseline, change the --config.algo flag to mle.