PIX
PIX is an image processing library in JAX, for JAX.
JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy- and SciPy-compatible APIs, automatic differentiation and first-class GPU/TPU support.
PIX is a library built on top of JAX that provides image processing functions and tools written so that they can be optimised and parallelised through `jax.jit`, `jax.vmap` and `jax.pmap`.
Installation
PIX is written in pure Python, but depends on C++ code via JAX.
Because JAX installation differs depending on your CUDA version, PIX does
not list JAX as a dependency in `requirements.txt`. JAX does appear in that
file, but only for reference and commented out.
First, follow the JAX installation instructions to install JAX with the relevant
accelerator support.
Then, install PIX using `pip`:

```shell
$ pip install git+https://github.com/deepmind/dm_pix
```
Quickstart
To use PIX
, you just need to import dm_pix as pix
and use it right away!
For example, let's assume we have loaded the JAX logo (available in
`examples/assets/jax_logo.jpg`) into a variable called `image` and we want to
flip it left to right. All that's needed is the following code:
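```python
import dm_pix as pix
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library (Pillow here).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))
flip_left_right_image = pix.flip_left_right(image)
```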
And here is the result! (Image: the JAX logo flipped left to right.)
All the functions in PIX can be `jax.jit`ed, `jax.vmap`ed and `jax.pmap`ed,
so they can all take advantage of optimization and parallelization.
```python
import dm_pix as pix
import jax
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library (Pillow here).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

# Vanilla Python function.
flip_left_right_image = pix.flip_left_right(image)

# `jax.jit`ed function.
flip_left_right_image = jax.jit(pix.flip_left_right)(image)

# `jax.vmap` and `jax.pmap` map over a leading batch dimension. Assuming a
# single device, like a CPU or a single GPU, we add a leading dimension of
# size 1 so that `image` can be used with the vectorized (`jax.vmap`) and the
# multi-device parallel (`jax.pmap`) versions of `pix.flip_left_right`.
# To know more, please refer to the JAX documentation of `jax.vmap` and `jax.pmap`.
image = image[np.newaxis, ...]

# `jax.vmap`ed function.
flip_left_right_image = jax.vmap(pix.flip_left_right)(image)

# `jax.pmap`ed function.
flip_left_right_image = jax.pmap(pix.flip_left_right)(image)
```
You can check for yourself that the four versions of `pix.flip_left_right`
produce the same result (apart from the extra leading dimension in the
`jax.vmap` and `jax.pmap` outputs, and up to the accelerator's floating-point
accuracy)!
Examples
We have a few examples in the [examples/
] folder. They are not much
more involved then the previous example, but they may be a good starting point
for you!
Testing
We provide a suite of tests to help you both test your development
environment and learn more about the library itself! All test files have the
`_test` suffix and can be executed using `pytest`.
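If you already have PIX installed, you just need to install some extra
dependencies and run `pytest` as follows:

```shell
$ pip install -r requirements_tests.txt
$ python -m pytest [-n <NUMCPUS>] dm_pix
```

The optional `-n <NUMCPUS>` flag comes from the `pytest-xdist` plugin and
spreads the tests over that many worker processes.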
If you want an isolated virtual environment, you just need to run our utility
`bash` script as follows:

```shell
$ ./test.sh
```