## omd

JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"

## Summary

Model-based reinforcement learning typically trains the dynamics and reward functions by minimizing the error of their predictions.

However, prediction error is only a proxy for maximizing the sum of rewards, the ultimate goal of the agent, leading to an objective mismatch.

We propose an end-to-end algorithm called *Optimal Model Design* (OMD) that optimizes the returns directly for model learning.

OMD leverages the implicit function theorem to compute gradients of the returns with respect to the model parameters.
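The implicit-function-theorem gradient can be illustrated on a toy bilevel problem. This is only a sketch, not the paper's implementation: the inner loss, outer objective, and all names below are hypothetical stand-ins (the inner loss plays the role of policy/value optimization, the outer objective the role of the returns).

```python
import jax
import jax.numpy as jnp

def inner_loss(theta, omega):
    # Toy inner problem whose minimizer depends implicitly on omega:
    # d/dtheta = theta - omega = 0  =>  theta*(omega) = omega
    return 0.5 * theta ** 2 - omega * theta

def outer_objective(theta):
    # Stand-in for the returns we ultimately want to maximize
    return jnp.sin(theta)

def implicit_grad(theta_star, omega):
    # IFT: d theta*/d omega = -(d2f/dtheta2)^-1 * d2f/(dtheta domega)
    f_theta = jax.grad(inner_loss, argnums=0)
    H = jax.grad(f_theta, argnums=0)(theta_star, omega)  # d2f/dtheta2
    J = jax.grad(f_theta, argnums=1)(theta_star, omega)  # d2f/(dtheta domega)
    dtheta_domega = -J / H
    # Chain rule: d outer/d omega = d outer/d theta* * d theta*/d omega
    return jax.grad(outer_objective)(theta_star) * dtheta_domega

omega = 0.3
theta_star = omega  # closed-form inner solution for this toy loss
grad_omega = implicit_grad(theta_star, omega)
```

Since `theta*(omega) = omega` here, the implicit gradient agrees with differentiating `outer_objective(omega)` directly, which is how one can sanity-check an IFT implementation.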

Please cite our work if you find it useful in your research:

```
@article{nikishin2021control,
title={Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation},
author={Nikishin, Evgenii and Abachi, Romina and Agarwal, Rishabh and Bacon, Pierre-Luc},
journal={arXiv preprint arXiv:2106.03273},
year={2021}
}
```

## Installation

We assume that you use Python 3. To install the necessary dependencies, run the following commands:

```
virtualenv ~/env_omd
source ~/env_omd/bin/activate
pip install -r requirements.txt
```

To use JAX with GPU support, follow the official JAX installation instructions.

To install MuJoCo, follow the official MuJoCo installation instructions.

## Run

For historical reasons, the code is divided into 3 parts.

### Tabular

All results for the tabular experiments can be reproduced by running the `tabular.ipynb` notebook.

To open the notebook in Google Colab, use this link.

### CartPole

To train the OMD agent on CartPole, use the following commands:

```
cd cartpole
python train.py --agent_type omd
```

We also provide implementations of the corresponding MLE and VEP baselines. To train these agents, set the `--agent_type` flag to `mle` or `vep`.

### MuJoCo

To train the OMD agent on MuJoCo HalfCheetah-v2, use the following commands:

```
cd mujoco
python train.py --config.algo=omd
```

To train the MLE baseline, set the `--config.algo` flag to `mle`.