PIX is an image processing library in JAX, for JAX.

JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy, SciPy, automatic differentiation and first-class GPU/TPU support.

PIX is a library built on top of JAX with the goal of providing image processing functions and tools to JAX in a way that they can be optimised and parallelised through jax.jit, jax.vmap and jax.pmap.


PIX is written in pure Python, but depends on C++ code via JAX.

Because JAX installation depends on your CUDA version, PIX does not list JAX as
a dependency in [requirements.txt]; JAX appears there only as a commented-out
line, for reference.

First, follow the [JAX installation instructions] to install JAX with the
relevant accelerator support.
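Once JAX is installed, it can be worth confirming that it actually sees your accelerator before installing PIX. The snippet below is a minimal sketch of such a check. `describe_jax_backend` is a hypothetical helper, not part of PIX or JAX, and it relies only on `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Reports which backend and devices JAX sees, if JAX is installed."""
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed yet."
    try:
        # Imported lazily so this check also runs before JAX is installed.
        import jax
    except ImportError as err:
        return f"JAX is installed but failed to import: {err}"

    # `jax.default_backend()` returns e.g. "cpu", "gpu" or "tpu".
    devices = ", ".join(str(device) for device in jax.devices())
    return f"backend={jax.default_backend()}; devices=[{devices}]"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the reported backend is `cpu` on a machine with a GPU, the accelerator-enabled JAX build is most likely not the one installed.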

Then, install PIX using pip:

```bash
$ pip install git+https://github.com/deepmind/dm_pix
```


To use PIX, you just need to import dm_pix as pix and use it right away!

For example, let's assume we have loaded the JAX logo (available in
examples/assets/jax_logo.jpg) into a variable called `image` and want to flip
it left to right.


All that's needed is the following code:

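```python
import dm_pix as pix
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library; PIL (Pillow)
# is used here as one option.
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

flip_left_right_image = pix.flip_left_right(image)
```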

And here is the result!
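As a mental model, flipping a channels-last image (height, width, channels) left to right reverses its width axis. The NumPy sketch below illustrates that operation on a tiny array; `flip_left_right_reference` is a hypothetical helper written for illustration, not PIX's implementation, and it assumes the channels axis comes last.

```python
import numpy as np


def flip_left_right_reference(image: np.ndarray) -> np.ndarray:
    """Reverses the width axis of an image laid out as [..., height, width, channels]."""
    # With channels last, axis -2 is the width axis (HWC or NHWC).
    return np.flip(image, axis=-2)


# A tiny 2x3 single-channel "image" whose pixel values encode their column.
image = np.array([[[0], [1], [2]],
                  [[0], [1], [2]]])
flipped = flip_left_right_reference(image)

# Each row of the flipped image reads [2, 1, 0].
print(flipped[..., 0])
```

Using a negative axis keeps the sketch valid for both single images and batches with a leading dimension.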


All the functions in PIX can be [jax.jit][jit]ed, [jax.vmap][vmap]ed and
[jax.pmap][pmap]ed, so they can all take advantage of optimization and
parallelization.

```python
import dm_pix as pix
import jax
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library; PIL (Pillow)
# is used here as one option.
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

# Vanilla Python function.
flip_left_right_image = pix.flip_left_right(image)

# `jax.jit`ed function.
flip_left_right_image = jax.jit(pix.flip_left_right)(image)

# Assuming a single device, such as a CPU or a single GPU, we add a leading
# dimension of size 1 so that `image` can be used with the vectorized
# (`jax.vmap`) and multi-device (`jax.pmap`) versions of `pix.flip_left_right`.
# To know more, please refer to the JAX documentation of `jax.vmap` and `jax.pmap`.
image = image[np.newaxis, ...]

# `jax.vmap`ed function.
flip_left_right_image = jax.vmap(pix.flip_left_right)(image)

# `jax.pmap`ed function.
flip_left_right_image = jax.pmap(pix.flip_left_right)(image)
```

You can check for yourself that the results from the four versions of
`pix.flip_left_right` are the same (up to the accelerator floating point
precision), once the leading dimension added for `jax.vmap` and `jax.pmap` is
removed.
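One way to run that check is a small comparison helper. The sketch below assumes you keep each output in its own variable instead of overwriting `flip_left_right_image`; `results_match` is a hypothetical helper, not part of PIX. It converts inputs with `np.asarray` (which also accepts JAX arrays), drops a leading axis of size 1, and compares with `np.allclose`.

```python
import numpy as np


def results_match(reference, *candidates, atol: float = 1e-6) -> bool:
    """Checks that every candidate matches `reference` up to a small tolerance.

    Candidates with an extra leading axis of size 1 (e.g. the outputs of the
    `jax.vmap`ed or `jax.pmap`ed calls on `image[np.newaxis, ...]`) are
    squeezed before the comparison.
    """
    reference = np.asarray(reference, dtype=np.float64)
    for candidate in candidates:
        candidate = np.asarray(candidate, dtype=np.float64)
        # Drop the batch/device axis added for `jax.vmap` / `jax.pmap`.
        if candidate.ndim == reference.ndim + 1 and candidate.shape[0] == 1:
            candidate = candidate[0]
        if candidate.shape != reference.shape:
            return False
        if not np.allclose(candidate, reference, atol=atol):
            return False
    return True
```

For example, with hypothetical variables `vanilla`, `jitted`, `vmapped` and `pmapped` holding the four outputs, `results_match(vanilla, jitted, vmapped, pmapped)` should return `True`.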


We have a few examples in the [examples/] folder. They are not much
more involved than the previous example, but they may be a good starting point
for you!


We provide a suite of tests to help you both test your development
environment and learn more about the library itself! All test files have a
`_test` suffix and can be executed using pytest.
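pytest collects files matching `*_test.py` by default, so a new test module only needs to follow that naming convention. The sketch below shows the shape such a module could take; the file name `flip_test.py` and the NumPy stand-in `flip_left_right` are hypothetical, and in a real test you would call `pix.flip_left_right` instead.

```python
# flip_test.py: hypothetical test module following the `_test` suffix convention.
import numpy as np


def flip_left_right(image):
    """Stand-in for the function under test; swap in `pix.flip_left_right`."""
    # Assumes channels last, so axis -2 is the width axis.
    return np.flip(image, axis=-2)


def test_flip_left_right_is_an_involution():
    image = np.random.default_rng(0).random((4, 5, 3))
    # Flipping twice must give back the original image.
    np.testing.assert_array_equal(flip_left_right(flip_left_right(image)), image)


def test_flip_left_right_swaps_first_and_last_columns():
    image = np.random.default_rng(1).random((4, 5, 3))
    flipped = flip_left_right(image)
    np.testing.assert_array_equal(flipped[:, 0, :], image[:, -1, :])
    np.testing.assert_array_equal(flipped[:, -1, :], image[:, 0, :])
```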

If you already have PIX installed, you just need to install some extra
dependencies and run pytest as follows:

```bash
$ pip install -r requirements_tests.txt
$ python -m pytest [-n <NUMCPUS>] dm_pix
```

If you want an isolated virtual environment, you just need to run our utility
bash script as follows:

```bash
$ ./test.sh
```