Connex

Connex is a small JAX library, built on Equinox, that is inspired by two features of biological neural networks:

  • Complex Connectivity: Turn any directed acyclic graph (DAG) into a trainable neural network.
  • Plasticity: Add and remove both connections and neurons at the individual level.

Installation

pip install connex

Requires Python 3.7+, JAX 0.3.4+, and Equinox 0.5.2+.

Documentation

Available here.

Usage

Suppose we would like to create a trainable neural network from the following DAG, with input neuron 0, output neurons 3 and 11 (in that order), and a jax.nn.relu activation function for the hidden neurons:

[Figure: the example DAG]

import connex as cnx
import jax.nn as jnn

# Specify number of neurons
num_neurons = 12

# Build the adjacency dict
adjacency_dict = {
    0: [1, 2, 3],
    1: [4],
    2: [4, 5],
    4: [6],
    5: [7],
    6: [8, 9],
    7: [10],
    8: [11],
    9: [11],
    10: [11]
}

# Specify the input and output neurons
input_neurons = [0]
output_neurons = [3, 11]

# Create the network
network = cnx.NeuralNetwork(
    num_neurons,
    adjacency_dict, 
    input_neurons, 
    output_neurons,
    jnn.relu
)
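To make the role of the adjacency dict more concrete, here is a minimal pure-Python sketch of how a DAG like this one could be evaluated. Neurons are visited in topological order (Kahn's algorithm), so each neuron's inputs are computed before the neuron itself. This is not Connex's implementation: `topological_order` and `forward` are hypothetical helpers, and the sketch assumes one scalar weight per edge, one bias per neuron, the activation applied to hidden neurons only, and linear (identity) output neurons.

```python
from collections import deque


def topological_order(num_neurons, adjacency_dict):
    """Return neuron ids so that every neuron appears after all of its inputs.

    Uses Kahn's algorithm; raises ValueError if the graph contains a cycle.
    """
    in_degree = [0] * num_neurons
    for targets in adjacency_dict.values():
        for t in targets:
            in_degree[t] += 1
    queue = deque(n for n in range(num_neurons) if in_degree[n] == 0)
    order = []
    while queue:
        n = queue.popleft()
        order.append(n)
        for t in adjacency_dict.get(n, []):
            in_degree[t] -= 1
            if in_degree[t] == 0:
                queue.append(t)
    if len(order) != num_neurons:
        raise ValueError("adjacency_dict contains a cycle, so it is not a DAG")
    return order


def forward(x, num_neurons, adjacency_dict, input_neurons, output_neurons,
            weights, biases, activation):
    """Hypothetical forward pass: one scalar weight per edge, one bias per neuron."""
    # Invert the adjacency dict so each neuron knows which neurons feed into it.
    incoming = {n: [] for n in range(num_neurons)}
    for src, targets in adjacency_dict.items():
        for t in targets:
            incoming[t].append(src)

    # Input neurons simply pass their input value through.
    values = {n: x[i] for i, n in enumerate(input_neurons)}
    outputs = set(output_neurons)
    for n in topological_order(num_neurons, adjacency_dict):
        if n in values:
            continue
        pre = biases[n] + sum(weights[(src, n)] * values[src] for src in incoming[n])
        # Assumption: hidden neurons use `activation`, output neurons are linear.
        values[n] = pre if n in outputs else activation(pre)
    return [values[n] for n in output_neurons]


def relu(z):
    return max(z, 0.0)


# The same graph as in the example above.
num_neurons = 12
adjacency_dict = {
    0: [1, 2, 3], 1: [4], 2: [4, 5], 4: [6], 5: [7],
    6: [8, 9], 7: [10], 8: [11], 9: [11], 10: [11],
}
input_neurons, output_neurons = [0], [3, 11]

# Toy parameters: every edge weight is 1.0 and every bias is 0.0.
weights = {(s, t): 1.0 for s, ts in adjacency_dict.items() for t in ts}
biases = [0.0] * num_neurons

print(topological_order(num_neurons, adjacency_dict))  # [0, 1, 2, ..., 11]
print(forward([2.0], num_neurons, adjacency_dict, input_neurons, output_neurons,
              weights, biases, relu))  # [2.0, 10.0]
```

Because the adjacency dict alone determines a valid evaluation order, any DAG can be turned into a network this way; a cycle would make the order undefined, which is why the sketch rejects it.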

That’s it! A connex.NeuralNetwork is a subclass of equinox.Module, so it can be trained like any other Equinox model. For example, with Optax:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Initialize the optimizer
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(network, eqx.is_array))

# Define the loss function
def loss_fn(model, x, y):
    preds = jax.vmap(model)(x)
    return jnp.mean((preds - y) ** 2)

# Define a single training step
@eqx.filter_jit
def step(model, optim, opt_state, X_batch, y_batch):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, X_batch, y_batch)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# Toy data
X = jnp.expand_dims(jnp.linspace(0, 2 * jnp.pi, 250), 1)
y = jnp.hstack((jnp.cos(X), jnp.sin(X)))

# Training loop
n_epochs = 1000
for _ in range(n_epochs):
    network, opt_state, loss = step(network, optim, opt_state, X, y)

Now suppose we wish to add connections 1 → 6 and 2 → 11, and remove neuron 9:

# Add connections
network = cnx.add_connections(network, [(1, 6), (2, 11)])

# Remove neuron
network, _ = cnx.remove_neurons(network, [9])

That’s all there is to it. Neurons that were not removed keep their parameters from the original network. connex.remove_neurons also returns auxiliary information about neuron ids, because removing neurons renumbers the remaining ones.
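The graph bookkeeping behind these edits can be illustrated on the adjacency dict alone. The pure-Python sketch below applies the same changes (add 1 → 6 and 2 → 11, then remove neuron 9), rejects any added edge that would create a cycle, and builds an old-to-new id mapping. It assumes the remaining neurons are renumbered consecutively in their original order. `add_edges`, `remove_nodes`, and the `id_map` format are hypothetical and need not match what connex.add_connections or connex.remove_neurons actually do or return; the sketch also ignores trainable parameters entirely.

```python
from collections import deque


def is_acyclic(num_neurons, adjacency_dict):
    """Kahn's algorithm: the graph is a DAG iff every neuron can be popped."""
    in_degree = [0] * num_neurons
    for targets in adjacency_dict.values():
        for t in targets:
            in_degree[t] += 1
    queue = deque(n for n in range(num_neurons) if in_degree[n] == 0)
    visited = 0
    while queue:
        n = queue.popleft()
        visited += 1
        for t in adjacency_dict.get(n, []):
            in_degree[t] -= 1
            if in_degree[t] == 0:
                queue.append(t)
    return visited == num_neurons


def add_edges(num_neurons, adjacency_dict, edges):
    """Return a copy of adjacency_dict with `edges` added; reject cycles."""
    new_adj = {src: list(targets) for src, targets in adjacency_dict.items()}
    for src, dst in edges:
        targets = new_adj.setdefault(src, [])
        if dst not in targets:
            targets.append(dst)
    if not is_acyclic(num_neurons, new_adj):
        raise ValueError(f"adding {edges} would create a cycle")
    return new_adj


def remove_nodes(num_neurons, adjacency_dict, nodes):
    """Drop `nodes` and their edges, then renumber survivors 0..k-1 in order."""
    removed = set(nodes)
    kept = [n for n in range(num_neurons) if n not in removed]
    id_map = {old: new for new, old in enumerate(kept)}  # old id -> new id
    new_adj = {}
    for src, targets in adjacency_dict.items():
        if src in removed:
            continue
        new_targets = [id_map[t] for t in targets if t not in removed]
        if new_targets:
            new_adj[id_map[src]] = new_targets
    return len(kept), new_adj, id_map


adjacency_dict = {
    0: [1, 2, 3], 1: [4], 2: [4, 5], 4: [6], 5: [7],
    6: [8, 9], 7: [10], 8: [11], 9: [11], 10: [11],
}
adj = add_edges(12, adjacency_dict, [(1, 6), (2, 11)])
num_neurons, adj, id_map = remove_nodes(12, adj, [9])

# The original output neurons 3 and 11 must be looked up in the new numbering.
output_neurons = [id_map[n] for n in [3, 11]]
print(num_neurons, output_neurons)  # 11 [3, 10]
```

In this sketch, neuron 11 becomes neuron 10 after the removal, which illustrates why code that refers to neurons by id needs the renumbering information.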

For more information about manipulating connectivity structure and the NeuralNetwork base class, please see the API section of the documentation. For examples of subclassing NeuralNetwork, please see connex.nn.

Feedback is greatly appreciated!
