A neural network library written in Python and built on top of JAX, an open-source library for high-performance numerical computing and automatic differentiation.

Packages used

  • JAX for automatic differentiation.
  • Mypy for static type checking of Python 3 code.
  • Matplotlib for plotting.
  • Pandas for data analysis and manipulation.
  • tqdm for displaying progress bars.

Features

  • Enables high-performance machine learning research.
  • Easy to use with high-level Keras-like APIs.
  • Runs seamlessly on GPU and even TPU!
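GPU and TPU support comes from JAX itself rather than from anything DNet-specific. A quick way to see which backend JAX picked up, as a small sketch using standard JAX calls only (DNet is not involved):

```python
import jax
import jax.numpy as jnp

# Devices JAX can see; without an accelerator this is a single CPU device,
# which is when JAX warns "No GPU/TPU found, falling back to CPU."
print(jax.devices())

# Platform used for new arrays by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print(backend)

# Arrays land on the default device automatically; no code changes needed.
x = jnp.ones((2, 2))
print(x.sum())
```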

Getting started

Create a Sequential model:

model = Sequential()

Add fully-connected (densely-connected) layers:

model.add(FC(units=500, activation="mish"))
model.add(FC(units=10, activation="relu"))
model.add(FC(units=1, activation="sigmoid"))
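Each FC layer is an affine map followed by an elementwise activation; "mish" is the Mish activation, x * tanh(softplus(x)). Below is a minimal JAX sketch of the forward pass through this three-layer stack, assuming the usual x @ W + b form and a 784-dimensional input (the MNIST size used in the toy example). The helpers `init_fc`, `fc_forward` and `ACTIVATIONS` are hypothetical illustrations, not DNet's internals.

```python
import jax
import jax.numpy as jnp


def mish(x):
    # Mish activation: x * tanh(softplus(x)).
    return x * jnp.tanh(jax.nn.softplus(x))


# Hypothetical name -> function lookup mirroring the activation strings.
ACTIVATIONS = {"mish": mish, "relu": jax.nn.relu, "sigmoid": jax.nn.sigmoid}


def init_fc(key, in_dim, units):
    # He-style scaled normal weights and zero biases for one dense layer.
    w = jax.random.normal(key, (in_dim, units)) * jnp.sqrt(2.0 / in_dim)
    b = jnp.zeros((units,))
    return w, b


def fc_forward(params, x, activation):
    # Dense layer: activation(x @ W + b), mapping (m, in_dim) -> (m, units).
    w, b = params
    return ACTIVATIONS[activation](x @ w + b)


# 784 -> 500 (mish) -> 10 (relu) -> 1 (sigmoid)
keys = jax.random.split(jax.random.PRNGKey(0), 3)
layers = [
    (init_fc(keys[0], 784, 500), "mish"),
    (init_fc(keys[1], 500, 10), "relu"),
    (init_fc(keys[2], 10, 1), "sigmoid"),
]

h = jnp.ones((4, 784))  # a batch of 4 dummy inputs
for params, activation in layers:
    h = fc_forward(params, h, activation)
print(h.shape)  # (4, 1)
```

The final sigmoid squashes each output into [0, 1], which is what the binary cross-entropy loss below expects.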

Compile the model with a loss function, an optimizer, and a learning rate:

model.compile(loss="binary_crossentropy", optimizer="sgd", lr=1e-02)
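Assuming DNet uses the textbook definitions (its exact reduction, clipping, or batching may differ), the strings passed to `compile` correspond to the mean binary cross-entropy over m examples, with predictions ŷ taken from the final sigmoid layer, and the plain SGD update applied to every parameter θ with learning rate η = `lr`:

$$
\mathcal{L}(\hat{y}, y) = -\frac{1}{m}\sum_{i=1}^{m}\left[\, y_i \log \hat{y}_i + (1 - y_i)\log\left(1 - \hat{y}_i\right) \right]
$$

$$
\theta \leftarrow \theta - \eta \, \nabla_{\theta}\, \mathcal{L}, \qquad \eta = 10^{-2}
$$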

Train the model, passing validation data:

model.fit(x_train, y_train, epochs=50, validation_data=(x_val, y_val))
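Conceptually, `fit` with `validation_data` repeats gradient updates for the given number of epochs and tracks the loss on held-out data alongside the training loss. A minimal JAX sketch of such a loop, assuming full-batch updates and a single sigmoid unit on synthetic 2-D data to keep it short; `fit`, `sgd_step`, `make_data` and the `history` dict are hypothetical stand-ins, not DNet's implementation.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # A single sigmoid unit keeps the sketch short; DNet stacks FC layers.
    w, b = params
    return jax.nn.sigmoid(x @ w + b)


def bce(params, x, y, eps=1e-7):
    # Mean binary cross-entropy, with predictions clipped to avoid log(0).
    p = jnp.clip(predict(params, x), eps, 1.0 - eps)
    return -jnp.mean(y * jnp.log(p) + (1.0 - y) * jnp.log(1.0 - p))


@jax.jit
def sgd_step(params, x, y, lr):
    # One full-batch gradient-descent update: params <- params - lr * grad.
    grads = jax.grad(bce)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def fit(params, x_train, y_train, epochs, lr, validation_data):
    # Hypothetical stand-in for Sequential.fit: one update per epoch, then
    # record the training and validation loss after that update.
    x_val, y_val = validation_data
    history = {"loss": [], "val_loss": []}
    for _ in range(epochs):
        params = sgd_step(params, x_train, y_train, lr)
        history["loss"].append(float(bce(params, x_train, y_train)))
        history["val_loss"].append(float(bce(params, x_val, y_val)))
    return params, history


def make_data(key, m):
    # Synthetic 2-D inputs with linearly separable 0/1 labels.
    x = jax.random.normal(key, (m, 2))
    y = (x @ jnp.array([[2.0], [-3.0]]) > 0).astype(jnp.float32)
    return x, y


k_train, k_val = jax.random.split(jax.random.PRNGKey(0))
x_train, y_train = make_data(k_train, 200)
x_val, y_val = make_data(k_val, 50)

params = (jnp.zeros((2, 1)), jnp.zeros((1,)))
params, history = fit(params, x_train, y_train, epochs=50, lr=1e-2,
                      validation_data=(x_val, y_val))
print(history["loss"][0], history["loss"][-1])
```

Recording both losses per epoch is what makes training and validation curves possible afterwards.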

Plot the loss curves:
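DNet's own plotting call isn't shown here. As a generic stand-in (not DNet's API), a Matplotlib sketch that draws training and validation curves, assuming a hypothetical per-epoch `history` dict with "loss" and "val_loss" lists:

```python
import os

import matplotlib

matplotlib.use("Agg")  # render to a file; no display needed
import matplotlib.pyplot as plt


def plot_losses(history, path="loss_curves.png"):
    # history: hypothetical dict of per-epoch "loss" and optional "val_loss" lists.
    epochs = range(1, len(history["loss"]) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, history["loss"], label="training loss")
    if history.get("val_loss"):
        ax.plot(epochs, history["val_loss"], label="validation loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Loss curves")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return os.path.abspath(path)


# Placeholder numbers purely to exercise the function; not real training results.
history = {"loss": [0.69, 0.52, 0.41, 0.35], "val_loss": [0.70, 0.55, 0.45, 0.40]}
out = plot_losses(history)
print(out)
```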


Toy Example


from pathlib import Path

import jax.numpy as tensor
import pandas as pd

from dnet.layers import FC
from dnet.nn import Sequential

dataset_path = Path("datasets")
train_path = dataset_path / "mnist_small" / "mnist_train_small.csv"
test_path = dataset_path / "mnist_small" / "mnist_test.csv"

training_data = pd.read_csv(train_path, header=None)
# Keep only the digits 0 and 1 so the task is binary classification.
training_data = training_data.loc[training_data[0].isin([0, 1])]

# Column 0 holds the label; the remaining 784 columns are pixel values in 0..255.
y_train = tensor.array(training_data[0].values.reshape(-1, 1))  # shape: (m, 1)
x_train = tensor.array(training_data.iloc[:, 1:].values) / 255.0  # shape: (m, n), scaled to [0, 1]

testing_data = pd.read_csv(test_path, header=None)
testing_data = testing_data.loc[testing_data[0].isin([0, 1])]

y_val = tensor.array(testing_data[0].values.reshape(-1, 1))  # shape: (m, 1)
x_val = tensor.array(testing_data.iloc[:, 1:].values) / 255.0  # shape: (m, n)

model = Sequential()
# First layer: input_dim=784 is the 28 x 28 pixel count of each flattened image.
model.add(FC(units=500, activation="mish", input_dim=784))
model.add(FC(units=10, activation="relu"))
model.add(FC(units=1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer="sgd", lr=1e-02)
model.fit(inputs=x_train, targets=y_train, epochs=50, validation_data=(x_val, y_val))



/usr/local/bin/python3.7 DNet/test.py
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/jax/lib/xla_bridge.py:120: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Training your model: 100%|██████████| 50/50 [00:02<00:00, 17.21it/s]

[Figure: Toy example loss curves]

Process finished with exit code 0
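To see what the data preparation in the toy example produces, here is the same filtering, reshaping and scaling applied to a tiny hand-made DataFrame. The frame is a made-up stand-in for the MNIST CSV, with 3 pixel columns instead of 784:

```python
import jax.numpy as jnp
import pandas as pd

# Tiny stand-in for the MNIST CSV: column 0 is the digit label, the rest are
# pixel intensities in 0..255 (3 pixels instead of 784, values made up).
frame = pd.DataFrame(
    [
        [0, 0, 128, 255],
        [1, 255, 0, 64],
        [7, 10, 20, 30],
        [1, 0, 0, 0],
    ]
)

# Keep only digits 0 and 1; the row labelled 7 is dropped.
binary = frame.loc[frame[0].isin([0, 1])]

y = jnp.array(binary[0].values.reshape(-1, 1))  # labels, shape (m, 1)
x = jnp.array(binary.iloc[:, 1:].values) / 255.0  # pixels in [0, 1], shape (m, n)

print(y.shape, x.shape)  # (3, 1) (3, 3)
```

Reshaping the labels to a column vector of shape (m, 1) matches the (m, 1) output of the final sigmoid layer.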