PIX is an image processing library in JAX, for JAX.
JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research. It provides NumPy- and SciPy-like APIs, automatic differentiation and first-class GPU/TPU support.
PIX is a library built on top of JAX with the goal of providing image processing functions and tools to JAX in a way that lets them be optimised and parallelised through `jax.jit`, `jax.vmap` and `jax.pmap`.
PIX is written in pure Python, but depends on C++ code via JAX.
Because JAX installation differs depending on your CUDA version, PIX does not list JAX as a dependency in `requirements.txt`. JAX is listed there for reference only, commented out.
First, follow the JAX installation instructions to install JAX with the relevant accelerator support.
Then, install PIX using:

```shell
$ pip install git+https://github.com/deepmind/dm_pix
```
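To confirm that JAX is importable and sees the devices you expect, a quick check like the following can help. It only uses standard JAX calls, and what it prints depends on your machine and on which JAX build you installed.

```python
import jax

# List the devices JAX can use. With a CUDA-enabled install this should
# include GPU devices; otherwise only a CPU device appears.
devices = jax.devices()
print(jax.__version__, devices)

# Name of the default backend, for example "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(backend)
```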
To use PIX, you just need to `import dm_pix as pix` and use it right away!
For example, let's assume we have loaded the JAX logo (available in `examples/assets/jax_logo.jpg`) into a variable called `image`, and we want to flip it left to right.
All that's needed is the following code:
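```python
import dm_pix as pix
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library (Pillow here).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

flip_left_right_image = pix.flip_left_right(image)
```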
And here is the result!
All the functions in PIX can be `jax.jit`ed, `jax.vmap`ed and `jax.pmap`ed, so they can all take advantage of optimization and parallelization.
```python
import dm_pix as pix
import jax
import numpy as np
from PIL import Image

# Load an image into a NumPy array with your preferred library (Pillow here).
image = np.asarray(Image.open("examples/assets/jax_logo.jpg"))

# Vanilla Python function.
flip_left_right_image = pix.flip_left_right(image)

# `jax.jit`ed function.
flip_left_right_image = jax.jit(pix.flip_left_right)(image)

# Assuming a single device, like a CPU or a single GPU, we add a single
# leading dimension so that `image` can be used with the vectorized
# (`jax.vmap`) and the multi-device parallelized (`jax.pmap`) versions of
# `pix.flip_left_right`. To know more, please refer to the JAX documentation
# of `jax.vmap` and `jax.pmap`.
image = image[np.newaxis, ...]

# `jax.vmap`ed function.
flip_left_right_image = jax.vmap(pix.flip_left_right)(image)

# `jax.pmap`ed function.
flip_left_right_image = jax.pmap(pix.flip_left_right)(image)
```
You can check for yourself that the results from the four versions of
`pix.flip_left_right` are the same (up to the accelerator floating point accuracy)!
We have a few examples in the `examples/` folder. They are not much
more involved than the previous example, but they may be a good starting point.
We provide a suite of tests to help you both test your development
environment and learn more about the library itself! All test files have a
`_test` suffix and can be executed using `pytest`.
If you already have PIX installed, you just need to install some extra
dependencies and run `pytest` as follows:

```shell
$ pip install -r requirements_tests.txt
$ python -m pytest [-n <NUMCPUS>] dm_pix
```
If you want an isolated virtual environment, you just need to run our utility
bash script as follows: