CQL-JAX

This repository implements Conservative Q-Learning (CQL) for offline reinforcement learning in JAX (Flax). The implementation is built on top of the SAC base from JAX-RL.

Usage

Install dependencies:

pip install -r requirements.txt
pip install "jax[cuda111]<=0.21.1" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run CQL:

python train_offline.py --env_name=hopper-expert-v0 --min_q_weight=5

Use the following values of min_q_weight on the MuJoCo tasks to reproduce the CQL results from the IQL paper:

Domain     medium    medium-replay    medium-expert
walker     10        1                10
hopper     5         5                1
cheetah    90        80               100

For the antmaze tasks, min_q_weight=10 was found to work best.
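
For reference, min_q_weight scales the conservative regularizer that CQL adds on top of the SAC critic loss. A minimal JAX sketch of that term (illustrative names and shapes, not the repository's exact code):

import jax
import jax.numpy as jnp

def cql_penalty(q_sampled, q_data, min_q_weight):
    # q_sampled: Q-values of actions drawn from a proposal distribution
    #            (e.g. uniform or policy actions), shape [num_samples, batch].
    # q_data:    Q-values of the dataset actions, shape [batch].
    # CQL(H) term: logsumexp over sampled actions minus Q on dataset actions,
    # scaled by min_q_weight and added to the usual Bellman error.
    logsumexp_q = jax.scipy.special.logsumexp(q_sampled, axis=0)
    return min_q_weight * jnp.mean(logsumexp_q - q_data)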

In case of out-of-memory errors in JAX, try running with the following environment variables:

XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python ...
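
These can also be set from Python, provided it happens before JAX is imported, since XLA reads them when the backend initializes. A minimal sketch (the values are just examples):

import os
# Must run before `import jax`; XLA reads these at backend initialization.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.80"
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
import jax  # imported only after the environment variables are set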

Performance & Runtime

Returns are roughly the same as those of the PyTorch implementation and comparable to IQL:

Task                         CQL (PyTorch)    CQL (JAX)    IQL
hopper-medium-v2             58.5             74.6         66.3
hopper-medium-replay-v2      95.0             92.1         94.7
hopper-medium-expert-v2      105.4            83.2         91.5
antmaze-umaze-v0             74.0             69.5         87.5
antmaze-umaze-diverse-v0     84.0             78.7         62.2
antmaze-medium-play-v0       61.2             14.2         71.2
antmaze-medium-diverse-v0    53.7             10.7         70.2
antmaze-large-play-v0        15.8             0.0          39.6
antmaze-large-diverse-v0     14.9             0.0          47.5

Wall-clock time averages ~50 minutes per run, improving on the 80 minutes reported for CQL in the IQL paper and narrowing the gap to IQL's 20 minutes:

Task                       CQL (JAX), min    IQL, min
hopper-medium-v2           52                27
hopper-medium-replay-v2    54                30
hopper-medium-expert-v2    57                29

This amounts to a speedup of more than 4x over the original PyTorch implementation.
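
Most of the gain comes from compiling the entire update step with jax.jit, as in JAX-RL. An illustrative sketch of the pattern with a hypothetical loss (not the repository's API):

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Hypothetical quadratic loss standing in for the actor/critic losses.
    preds = batch["obs"] @ params
    return jnp.mean((preds - batch["target"]) ** 2)

@jax.jit  # the full gradient step compiles to a single XLA program
def update(params, batch, lr=1e-3):
    grads = jax.grad(loss_fn)(params, batch)
    return params - lr * grads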

For more offline RL algorithm implementations, check out the JAX-RL, IQL, and rlkit repositories.

Citation

If you use CQL-JAX in your research, please cite the following:

@misc{cqljax,
  author = {Suri, Karush},
  title = {{Conservative Q Learning in JAX.}},
  url = {https://github.com/karush17/cql-jax},
  year = {2021}
}
