# Connex

Connex is a small JAX library built on Equinox inspired by two features of biological neural networks:

- **Complex Connectivity**: Turn any directed acyclic graph (DAG) into a trainable neural network.
- **Plasticity**: Add and remove both connections and neurons at the individual level.

## Installation

`pip install connex`

Requires Python 3.7+, JAX 0.3.4+, and Equinox 0.5.2+.

## Documentation

Available here.

## Usage

Suppose we would like to create a trainable neural network from the DAG encoded by the adjacency dict below, with input neuron 0, output neurons 3 and 11 (in that order), and a `jax.nn.relu` activation function for the hidden neurons:

```python
import connex as cnx
import jax.nn as jnn

# Specify number of neurons
num_neurons = 12

# Build the adjacency dict: each key maps a neuron to the neurons it feeds into
adjacency_dict = {
    0: [1, 2, 3],
    1: [4],
    2: [4, 5],
    4: [6],
    5: [7],
    6: [8, 9],
    7: [10],
    8: [11],
    9: [11],
    10: [11]
}

# Specify the input and output neurons
input_neurons = [0]
output_neurons = [3, 11]

# Create the network
network = cnx.NeuralNetwork(
    num_neurons,
    adjacency_dict,
    input_neurons,
    output_neurons,
    jnn.relu
)
```
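Because the network is built from a DAG, its neurons can be evaluated in a topological order, where every neuron comes after all of the neurons that feed into it. The following is a minimal pure-Python sketch (Kahn's algorithm) that checks an adjacency dict in the format above for cycles and computes one such order. The function name `topological_order` is hypothetical, and this is not Connex's internal implementation.

```python
from collections import deque


def topological_order(num_neurons, adjacency_dict):
    """Return neuron ids in topological order, or raise ValueError on a cycle."""
    # Count incoming edges per neuron (neurons missing from the dict have no outgoing edges).
    in_degree = [0] * num_neurons
    for targets in adjacency_dict.values():
        for t in targets:
            in_degree[t] += 1

    # Start from neurons with no incoming edges (here, only the input neuron 0).
    queue = deque(n for n in range(num_neurons) if in_degree[n] == 0)
    order = []
    while queue:
        n = queue.popleft()
        order.append(n)
        # "Remove" n's outgoing edges; a target becomes ready once all its inputs are done.
        for t in adjacency_dict.get(n, []):
            in_degree[t] -= 1
            if in_degree[t] == 0:
                queue.append(t)

    # Any neuron never emitted lies on (or downstream of) a cycle.
    if len(order) != num_neurons:
        raise ValueError("adjacency_dict contains a cycle")
    return order


adjacency_dict = {
    0: [1, 2, 3], 1: [4], 2: [4, 5], 4: [6], 5: [7],
    6: [8, 9], 7: [10], 8: [11], 9: [11], 10: [11],
}
order = topological_order(12, adjacency_dict)
print(order)
```

For this particular DAG every edge points from a lower id to a higher id, so the ids in increasing order already form a valid topological order.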

That’s it! A `connex.NeuralNetwork` is a subclass of `equinox.Module`, so it can be trained in the same fashion:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Initialize the optimizer
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(network, eqx.is_array))

# Define the loss function (mean squared error over the batch)
def loss_fn(model, x, y):
    preds = jax.vmap(model)(x)
    return jnp.mean((preds - y) ** 2)

# Define a single training step
@eqx.filter_jit
def step(model, optim, opt_state, X_batch, y_batch):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, X_batch, y_batch)
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# Toy data: X has shape (250, 1) for the single input neuron,
# y has shape (250, 2) for the two output neurons (3 and 11)
X = jnp.expand_dims(jnp.linspace(0, 2 * jnp.pi, 250), 1)
y = jnp.hstack((jnp.cos(X), jnp.sin(X)))

# Training loop (full-batch)
n_epochs = 1000
for _ in range(n_epochs):
    network, opt_state, loss = step(network, optim, opt_state, X, y)
```
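The training loop above maps the network over the batch with `jax.vmap`, so each call to `model` evaluates a single sample. For intuition, here is a plain-Python sketch of one such forward pass over the example DAG. It assumes semantics that may differ from Connex's actual implementation: input neurons take the input values directly, every other neuron computes a bias plus a weighted sum of its incoming neurons, hidden neurons apply ReLU, and output neurons are left linear. The `forward` function and the `weights`/`biases` containers are hypothetical.

```python
def relu(v):
    return max(v, 0.0)


def forward(x, order, adjacency_dict, input_neurons, output_neurons, weights, biases):
    """Evaluate one sample by visiting neurons in topological order.

    Hypothetical semantics for illustration only: inputs are copied in, every
    other neuron computes bias + weighted sum of incoming values, hidden
    neurons apply ReLU, and output neurons stay linear.
    """
    # Invert the adjacency dict so each neuron knows which neurons feed into it.
    incoming = {}
    for src, targets in adjacency_dict.items():
        for dst in targets:
            incoming.setdefault(dst, []).append(src)

    # Input neurons take the input values directly.
    values = dict(zip(input_neurons, x))
    for n in order:
        if n in values:
            continue
        # Topological order guarantees every incoming value is already computed.
        z = biases[n] + sum(weights[(src, n)] * values[src] for src in incoming.get(n, []))
        values[n] = z if n in output_neurons else relu(z)
    return [values[n] for n in output_neurons]


adjacency_dict = {
    0: [1, 2, 3], 1: [4], 2: [4, 5], 4: [6], 5: [7],
    6: [8, 9], 7: [10], 8: [11], 9: [11], 10: [11],
}
# Every edge here goes from a lower id to a higher id, so 0..11 is a valid topological order.
order = list(range(12))
# Unit weights and zero biases keep the arithmetic easy to follow by hand.
weights = {(src, dst): 1.0 for src, targets in adjacency_dict.items() for dst in targets}
biases = [0.0] * 12

print(forward([1.0], order, adjacency_dict, [0], [3, 11], weights, biases))  # [1.0, 5.0]
```

With unit weights, output neuron 11 sums the three paths that reach it (through 8, 9, and 10), which is why it receives 5.0 while neuron 3, fed only by the input, receives 1.0.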

Now suppose we wish to add connections 1 → 6 and 2 → 11, and remove neuron 9:

```python
# Add connections
network = cnx.add_connections(network, [(1, 6), (2, 11)])

# Remove neuron (the second return value holds auxiliary info about neuron ids)
network, _ = cnx.remove_neurons(network, [9])
```

That’s all there is to it. Parameters are retained for every neuron of the original network that was not removed. `connex.remove_neurons` also returns auxiliary information about neuron ids, since removing neurons renumbers the remaining ones.
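To illustrate what renumbering means, here is a plain-Python sketch that applies the same edits as above (adding 1 → 6 and 2 → 11, then removing neuron 9) to the adjacency dict alone. The helpers `add_edges` and `remove_and_renumber` are hypothetical. In particular, relabelling survivors contiguously in their original order, and the `id_map` returned here, are assumptions for illustration and may not match the auxiliary information Connex actually returns.

```python
def add_edges(adjacency_dict, edges):
    """Return a copy of adjacency_dict with the given (src, dst) edges added.

    This sketch does not check that the result is still acyclic.
    """
    new = {src: list(targets) for src, targets in adjacency_dict.items()}
    for src, dst in edges:
        targets = new.setdefault(src, [])
        if dst not in targets:
            targets.append(dst)
    return new


def remove_and_renumber(num_neurons, adjacency_dict, to_remove):
    """Drop neurons (and their edges), then relabel survivors as 0..k-1."""
    removed = set(to_remove)
    survivors = [n for n in range(num_neurons) if n not in removed]
    # Old id -> new id, preserving the survivors' relative order.
    id_map = {old: new for new, old in enumerate(survivors)}
    new_adjacency = {}
    for src, targets in adjacency_dict.items():
        if src in removed:
            continue
        kept = [id_map[t] for t in targets if t not in removed]
        if kept:
            new_adjacency[id_map[src]] = kept
    return len(survivors), new_adjacency, id_map


adjacency_dict = {
    0: [1, 2, 3], 1: [4], 2: [4, 5], 4: [6], 5: [7],
    6: [8, 9], 7: [10], 8: [11], 9: [11], 10: [11],
}
adjacency_dict = add_edges(adjacency_dict, [(1, 6), (2, 11)])
num_neurons, adjacency_dict, id_map = remove_and_renumber(12, adjacency_dict, [9])

print(num_neurons)                   # 11
print([id_map[n] for n in [3, 11]])  # [3, 10]: output neuron 11 is now neuron 10
```

Under this scheme only neurons with ids above the removed one shift down, which is why any code that refers to neurons by id (such as a list of output neurons) needs the mapping after a removal.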

For more information about manipulating connectivity structure and the `NeuralNetwork` base class, please see the API section of the documentation. For examples of subclassing `NeuralNetwork`, please see `connex.nn`.

Feedback is greatly appreciated!