emcee-jax

An experiment.

A simple example:

import jax
from emcee_jax.sampler import build_sampler

def log_prob(theta, a1=100.0, a2=20.0):
    """Return the log probability of a Rosenbrock-type density at ``theta``.

    ``theta`` must unpack into exactly two components ``(x1, x2)``.
    ``a1`` weights the ``(x2 - x1**2)**2`` term and ``a2`` divides the
    whole (negated) sum.
    """
    x1, x2 = theta
    valley_term = a1 * (x2 - x1**2) ** 2
    offset_term = (1 - x1) ** 2
    return -(valley_term + offset_term) / a2

# Ensemble size and the number of steps requested from the sampler.
num_walkers = 100
num_steps = 1000

# Derive two keys from a single seed: key1 draws the starting coordinates,
# key2 is handed to the sampler.
key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)

# One row of two standard-normal draws per walker.
coords = jax.random.normal(key1, shape=(num_walkers, 2))

sample = build_sampler(log_prob)
trace = sample(key2, coords, steps=num_steps)

An example using PyTrees as input coordinates:

import jax
from emcee_jax.sampler import build_sampler

def log_prob(theta, a1=100.0, a2=20.0):
    """Return the log probability of a Rosenbrock-type density at ``theta``.

    ``theta`` is a mapping read through its ``"x"`` and ``"y"`` entries.
    ``a1`` weights the ``(y - x**2)**2`` term and ``a2`` divides the whole
    (negated) sum.
    """
    x1 = theta["x"]
    x2 = theta["y"]
    valley_term = a1 * (x2 - x1**2) ** 2
    offset_term = (1 - x1) ** 2
    return -(valley_term + offset_term) / a2

# Ensemble size and the number of steps requested from the sampler.
num_walkers = 100
num_steps = 1000

# Three keys from a single seed: one per coordinate entry, one for sampling.
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), num=3)

# Starting coordinates as a dict: one standard-normal draw per walker
# under each of the names that log_prob looks up.
coords = dict(
    x=jax.random.normal(key1, shape=(num_walkers,)),
    y=jax.random.normal(key2, shape=(num_walkers,)),
)

sample = build_sampler(log_prob)
trace = sample(key3, coords, steps=num_steps)

An example that includes deterministics:

import jax
from emcee_jax.sampler import build_sampler

def log_prob(theta, a1=100.0, a2=20.0):
    """Return ``(log_prob_value, extras)`` for a Rosenbrock-type density.

    ``theta`` must unpack into exactly two components ``(x1, x2)``. The
    second element of the returned tuple is a dict holding
    ``x1 + sin(x2)`` under the key ``"some_number"``.
    """
    x1, x2 = theta

    valley_term = a1 * (x2 - x1**2) ** 2
    offset_term = (1 - x1) ** 2
    log_prob_value = -(valley_term + offset_term) / a2

    # The extra output returned next to the log probability; it can be
    # any PyTree.
    extras = {"some_number": x1 + jax.numpy.sin(x2)}
    return log_prob_value, extras

# Ensemble size and the number of steps requested from the sampler.
num_walkers = 100
num_steps = 1000

# Split one seed into an initialisation key and a sampling key.
key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)

# One row of two standard-normal draws per walker.
coords = jax.random.normal(key1, shape=(num_walkers, 2))

# log_prob here returns a (value, extras) pair rather than a bare value.
sample = build_sampler(log_prob)
trace = sample(key2, coords, steps=num_steps)

You can even use pure-Python log probability functions:

import jax
import numpy as np
from emcee_jax.sampler import build_sampler
from emcee_jax.host_callback import wrap_python_log_prob_fn

# A log prob function that uses numpy, not jax.numpy inside
@wrap_python_log_prob_fn
def log_prob(theta, a1=100.0, a2=20.0):
    """Return the log probability of a Rosenbrock-type density at ``theta``.

    The body uses ``numpy`` (``np.square``) rather than ``jax.numpy``.
    ``theta`` must unpack into exactly two components ``(x1, x2)``.
    """
    x1, x2 = theta
    valley_term = a1 * np.square(x2 - x1**2)
    offset_term = np.square(1 - x1)
    return -(valley_term + offset_term) / a2

# Ensemble size and the number of steps requested from the sampler.
num_walkers = 100
num_steps = 1000

# Split one seed into an initialisation key and a sampling key.
key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)

# One row of two standard-normal draws per walker.
coords = jax.random.normal(key1, shape=(num_walkers, 2))

# The wrapped log_prob is passed to build_sampler exactly like the
# functions in the other examples.
sample = build_sampler(log_prob)
trace = sample(key2, coords, steps=num_steps)

GitHub

View GitHub