Composable Second-Order Optimization for JAX and Optax.
A small research-engineering library for curvature-aware training: modular, matrix-free, and explicit about the moving parts.
Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
- curvature operator
- solver
- damping policy
- optional preconditioner
- update transform
- optional telemetry and control signals
That decomposition is the point.
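As a plain-JAX illustration of the idea (this is not Somax's API, just the underlying pattern it factors out): a matrix-free Gauss-Newton curvature product, a CG solve, and a fixed damping constant, each replaceable on its own.

```python
import jax
from functools import partial
from jax.scipy.sparse.linalg import cg


def ggn_matvec(predict_fn, params, x, lam, v):
    # Curvature operator for a squared-error loss: v -> (J^T J + lam * I) v,
    # assembled from one JVP and one VJP so J is never materialized.
    f = lambda p: predict_fn(p, x)
    jv = jax.jvp(f, (params,), (v,))[1]
    _, vjp = jax.vjp(f, params)
    (jtjv,) = vjp(jv)
    return jax.tree_util.tree_map(lambda a, b: a + lam * b, jtjv, v)


def solve_direction(predict_fn, params, x, grads, lam, maxiter=20, tol=1e-4):
    # Solver: matrix-free conjugate gradients on the damped curvature system.
    matvec = partial(ggn_matvec, predict_fn, params, x, lam)
    direction, _ = cg(matvec, grads, maxiter=maxiter, tol=tol)
    return direction
```

Swapping any one piece (a different damping policy, a diagonal solve lane) is then a local change, not a new optimizer.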
Somax is built for users who want a clean second-order stack in JAX without hiding the execution model. It aims to make curvature-aware training easier to inspect, compare, and extend.
The catfish in the logo is a small nod to som, the Belarusian word for catfish. A quiet bottom-dweller, but not a first-order creature.
- Composable: build methods from curvature, solver, damping, preconditioner, and update components.
- Optax-native: computed directions are fed through Optax-style update transforms (see the sketch after this list).
- Planned execution: a method is assembled once, planned once, and then executed as a stable step pipeline.
- JAX-first: intended for `jit`-compiled training loops and explicit control over execution.
- Multiple solve lanes: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
- Research-friendly: easy to inspect, compare, ablate, and extend.
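For the Optax-native bullet, the hand-off follows the standard Optax contract. A minimal sketch, assuming `params` and a computed `direction` pytree (of the same structure) are already in hand:

```python
import optax

# Post-process a curvature-aware direction with ordinary Optax transforms.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # guard against occasional huge solves
    optax.scale(-0.1),               # step size; the negative sign descends
)
opt_state = tx.init(params)

updates, opt_state = tx.update(direction, opt_state, params)
params = optax.apply_updates(params, updates)
```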
Install JAX for your backend first:
- JAX installation guide: https://docs.jax.dev/en/latest/installation.html
Then install Somax:
```bash
pip install python-somax
```

For local development:
```bash
git clone https://github.com/cor3bit/somax.git
cd somax
pip install -e ".[dev]"
```

Optional:
- install `lineax` only if you want to use CG backends with `backend="lineax"`.
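For instance, a sketch of selecting that backend (whether the CG-based methods take `backend` directly at construction is an assumption here; check the method's actual signature):

```python
# Assumption: CG-backed Somax methods choose their linear-solver backend
# via a `backend` kwarg; requires `pip install lineax`.
method = somax.sgn_ce(
    predict_fn=predict_fn,
    backend="lineax",
)
```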
```python
import jax
import jax.numpy as jnp

import somax


# A small two-layer MLP; Somax methods take the prediction function explicitly.
def predict_fn(params, x):
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


rng = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(rng, 4)

params = {
    "W1": 0.1 * jax.random.normal(k1, (16, 32)),
    "b1": jnp.zeros((32,)),
    "W2": 0.1 * jax.random.normal(k2, (32, 10)),
    "b2": jnp.zeros((10,)),
}

batch = {
    "x": jax.random.normal(k3, (64, 16)),
    "y": jax.random.randint(k4, (64,), 0, 10),
}

# Gauss-Newton method with a cross-entropy loss.
method = somax.sgn_ce(
    predict_fn=predict_fn,
    lam0=1e-2,        # initial damping
    tol=1e-4,         # solver tolerance
    maxiter=20,       # solver iteration cap
    learning_rate=1e-1,
)

state = method.init(params)


@jax.jit
def train_step(params, state, rng):
    params, state, info = method.step(params, batch, state, rng)
    return params, state, info


for step in range(10):
    params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
```
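A quick sanity check in plain JAX (independent of Somax's API) to confirm the loop actually trains:

```python
# Training-batch accuracy after the ten steps above.
logits = predict_fn(params, batch["x"])
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch["y"])
print(f"train accuracy: {accuracy:.3f}")
```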
If Somax is useful in your academic work, please cite:

Second-Order, First-Class: A Composable Stack for Curvature-Aware Training
Mikalai Korbit and Mario Zanon
https://arxiv.org/abs/2603.25976
```bibtex
@article{korbit2026second,
  title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
  author={Korbit, Mikalai and Zanon, Mario},
  journal={arXiv preprint arXiv:2603.25976},
  year={2026}
}
```

Optimization in JAX
- Optax: first-order gradient-based optimizers (e.g., SGD, Adam).
- JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (e.g., PolyakSGD, ArmijoSGD).
Awesome Projects
- Awesome JAX: a longer list of various JAX projects.
- Awesome SOMs: a list of resources for second-order optimization methods in machine learning.
Apache-2.0
