# jaxfg

Factor graph-based nonlinear least squares library for JAX.

The premise: we provide a high-level interface for describing nonlinear
optimization problems as probabilistic factor graphs. jaxfg then exploits the
graph structure to accelerate optimization: operations over repeated factor and
variable types are vectorized, and the sparsity of graph connections is
leveraged for sparse matrix operations.

Current features:
  • Autodiff-powered sparse Jacobians.
  • Automatic vectorization for repeated factor and variable types.
  • Manifold definition interface, with implementations for SO(2), SE(2), SO(3),
    and SE(3) Lie groups.
  • Support for standard JAX function transformations: jit, vmap, pmap,
    grad, etc.
  • Nonlinear optimizers: Gauss-Newton, Levenberg-Marquardt, Dogleg.
  • Sparse linear solvers: conjugate gradient (Jacobi-preconditioned), sparse
    Cholesky (via CHOLMOD).
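To give a sense of the math involved: optimizers like Gauss-Newton iteratively linearize the residuals and solve the resulting normal equations. The sketch below runs plain Gauss-Newton on a toy curve-fitting problem in NumPy; it does not use jaxfg's API, and the hand-written Jacobian stands in for the autodiff-computed Jacobians the library uses.

```python
# Toy sketch of the Gauss-Newton update (plain NumPy, not jaxfg's API).
# We fit y = a * exp(b * x) to data; the analytic Jacobian here stands in
# for the autodiff-computed Jacobians jaxfg uses.
import numpy as np

def residuals(params, x, y):
    a, b = params
    return a * np.exp(b * x) - y

def jacobian(params, x):
    a, b = params
    e = np.exp(b * x)
    return np.stack([e, a * x * e], axis=1)  # d(residual)/d(a, b)

def gauss_newton(params, x, y, iters=30):
    for _ in range(iters):
        r = residuals(params, x, y)
        J = jacobian(params, x)
        # Each step solves the normal equations J^T J delta = -J^T r.
        delta = np.linalg.solve(J.T @ J, -J.T @ r)
        params = params + delta
    return params

x = np.linspace(0.0, 1.0, 50)
y = 2.0 * np.exp(1.5 * x)  # noiseless data generated with a=2.0, b=1.5
est = gauss_newton(np.array([1.5, 1.3]), x, y)
print(est)  # converges to approximately [2.0, 1.5]
```

In jaxfg the Jacobian is assembled sparsely and vectorized across repeated factor types; the dense toy version above only illustrates the update rule.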

Borrows heavily from a wide set of existing libraries, including Ceres Solver,
g2o, and GTSAM.


### Installation

scikit-sparse, which backs our CHOLMOD solver, requires SuiteSparse:

```bash
sudo apt update
sudo apt install -y libsuitesparse-dev
```

Then, from your environment of choice:

```bash
git clone https://github.com/brentyi/jaxfg.git
cd jaxfg
pip install -e .
```

### Example scripts

Toy pose graph optimization:

```bash
python scripts/pose_graph_simple.py
```

Pose graph optimization from .g2o files:

```bash
python scripts/pose_graph_g2o.py  # Pass in a --help flag for options
```
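For context on the file format: .g2o files are plain text, with one vertex or edge per line. A minimal reader for SE(2) entries could look like the sketch below (`parse_g2o_se2` is a hypothetical helper written for illustration, not part of jaxfg or the script above).

```python
# Minimal reader for SE(2) pose-graph entries in the .g2o text format.
# A sketch only: real files also carry an upper-triangular information
# matrix on each edge, which we keep but do not interpret here.
def parse_g2o_se2(lines):
    vertices, edges = {}, []
    for line in lines:
        parts = line.split()
        if not parts:
            continue
        if parts[0] == "VERTEX_SE2":
            # VERTEX_SE2 <id> <x> <y> <theta>
            vertices[int(parts[1])] = tuple(map(float, parts[2:5]))
        elif parts[0] == "EDGE_SE2":
            # EDGE_SE2 <i> <j> <dx> <dy> <dtheta> <info matrix entries>
            i, j = int(parts[1]), int(parts[2])
            measurement = tuple(map(float, parts[3:6]))
            info = tuple(map(float, parts[6:]))
            edges.append((i, j, measurement, info))
    return vertices, edges

example = [
    "VERTEX_SE2 0 0.0 0.0 0.0",
    "VERTEX_SE2 1 1.0 0.0 0.1",
    "EDGE_SE2 0 1 1.0 0.0 0.1 1 0 0 1 0 1",
]
vertices, edges = parse_g2o_se2(example)
print(len(vertices), len(edges))  # 2 vertices, 1 edge
```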


### Engineering notes

We currently take a "make everything a dataclass" philosophy for software
engineering in this library. This is convenient for several reasons, most
notably because it makes it easy to register objects as pytree nodes in JAX.
See jax_dataclasses for details.
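For illustration, the pytree-node contract boils down to a flatten/unflatten pair like the one below. This is a plain-Python sketch of what jax_dataclasses generates automatically; the `PriorFactor` class and the function names are illustrative, not jaxfg's actual definitions.

```python
# Sketch of the flatten/unflatten pair that pytree registration needs.
# (jax_dataclasses generates these automatically; names are illustrative.)
import dataclasses

@dataclasses.dataclass
class PriorFactor:
    mu: float      # numeric leaf data, traced/transformed by JAX
    weight: float  # another leaf

def flatten(factor):
    # Children are the numeric leaves; aux_data holds any static structure.
    children = (factor.mu, factor.weight)
    aux_data = None
    return children, aux_data

def unflatten(aux_data, children):
    return PriorFactor(*children)

# Round trip: this is the contract jax.tree_util.register_pytree_node expects.
f = PriorFactor(mu=0.5, weight=2.0)
children, aux = flatten(f)
assert unflatten(aux, children) == f
```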

In XLA, JIT compilation needs to happen for each unique set of input shapes.
This is a core limitation: it can introduce significant re-compilation
overheads when graph structures are modified, which restricts dynamic and
online applications.
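As a rough analogy (a pure-Python sketch, not XLA's actual mechanism), the behavior resembles a cache of compiled executables keyed on input shape: any shape not seen before, e.g. a graph with one extra factor, pays the trace-and-compile cost again.

```python
# Pure-Python analogy for shape-specialized JIT: "compiled" functions are
# cached per input shape, so a new shape triggers a fresh compile.
compile_count = 0
_cache = {}

def fake_jit(f):
    def wrapped(x):
        global compile_count
        key = len(x)  # stand-in for the input's shape
        if key not in _cache:
            compile_count += 1  # the expensive trace/compile happens here
            _cache[key] = f     # pretend f was lowered for this shape
        return _cache[key](x)
    return wrapped

@fake_jit
def total_residual(residuals):
    return sum(r * r for r in residuals)

total_residual([1.0, 2.0])       # compiles for length-2 input
total_residual([3.0, 4.0])       # cache hit: same shape
total_residual([1.0, 2.0, 3.0])  # new shape: recompiles
assert compile_count == 2
```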


This library's still in development mode! Here's our TODO list:

  • [x] Preliminary graph, variable, factor interfaces
  • [x] Real vector variable types
  • [x] Refactor into package
  • [x] Nonlinear optimization for MAP inference
    • [x] Conjugate gradient linear solver
    • [x] CHOLMOD linear solver
      • [x] Basic implementation. JIT-able, but no vmap, pmap, or autodiff
      • [ ] Custom VJP rule? vmap support?
    • [x] Gauss-Newton implementation
    • [x] Termination criteria
    • [x] Damped least squares
    • [x] Dogleg
    • [x] Inexact Newton steps
    • [x] Revisit termination criteria
    • [x] Reduce redundant code
    • [ ] Robust losses
  • [x] Marginalization
    • [x] Working prototype using sksparse/CHOLMOD
    • [ ] JAX implementation?
  • [x] Validate g2o example
  • [x] Performance
    • [x] More intentional JIT compilation
    • [x] Re-implement parallel factor computation
    • [x] Vectorized linearization
    • [x] Basic (Jacobi) CGLS preconditioning
  • [x] Manifold optimization
    • [x] Basic interface
    • [x] Manifold optimization on SO2
    • [x] Manifold optimization on SE2
    • [x] Manifold optimization on SO3
    • [x] Manifold optimization on SE3
  • [ ] Usability + code health (low priority)
    • [x] Basic cleanup/refactor
      • [x] Better parallel factor interface
      • [x] Separate out utils, lie group helpers
      • [x] Put things in folders
    • [x] Resolve typing errors
    • [x] Cleanup/refactor (more)
    • [x] Package cleanup: dependencies, etc
    • [x] Add CI:
      • [x] mypy
      • [x] lint
      • [x] build
      • [ ] coverage
    • [ ] More comprehensive tests
    • [ ] Clean up docstrings
    • [ ] Come up with a better name than "jaxfg"