Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions examples/wave/configs/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import ml_collections

import jax.numpy as jnp


def get_config():
    """Build the default hyperparameter configuration for the 1D wave example."""
    config = ml_collections.ConfigDict()

    # Run mode: "train" or "eval".
    config.mode = "train"

    # Wave speed used in the PDE residual.
    config.c = 1.0

    # Weights & Biases experiment tracking.
    config.wandb = ml_collections.ConfigDict()
    config.wandb.project = "PINN-Wave"
    config.wandb.name = "default"
    config.wandb.tag = None

    # Network architecture.
    config.arch = ml_collections.ConfigDict()
    config.arch.arch_name = "Mlp"
    config.arch.num_layers = 4
    config.arch.hidden_dim = 256
    config.arch.out_dim = 1
    config.arch.activation = "tanh"
    config.arch.periodicity = None
    config.arch.fourier_emb = ml_collections.ConfigDict(
        {"embed_scale": 1, "embed_dim": 256}
    )
    config.arch.reparam = ml_collections.ConfigDict(
        {"type": "weight_fact", "mean": 0.5, "stddev": 0.1}
    )

    # Optimizer: Adam with exponential learning-rate decay.
    config.optim = ml_collections.ConfigDict()
    config.optim.grad_accum_steps = 0
    config.optim.optimizer = "Adam"
    config.optim.beta1 = 0.9
    config.optim.beta2 = 0.999
    config.optim.eps = 1e-8
    config.optim.learning_rate = 1e-3
    config.optim.decay_rate = 0.9
    config.optim.decay_steps = 2000

    # Training schedule.
    config.training = ml_collections.ConfigDict()
    config.training.max_steps = 200000
    config.training.batch_size_per_device = 4096

    # Loss weighting (gradient-norm balancing + causal residual weighting).
    config.weighting = ml_collections.ConfigDict()
    config.weighting.scheme = "grad_norm"
    config.weighting.init_weights = ml_collections.ConfigDict(
        {"ics": 1.0, "ics_vel": 1.0, "bcs": 1.0, "res": 1.0}
    )
    config.weighting.momentum = 0.9
    config.weighting.update_every_steps = 1000

    config.weighting.use_causal = True
    config.weighting.causal_tol = 1.0
    config.weighting.num_chunks = 32

    # Logging cadence and which quantities to record.
    config.logging = ml_collections.ConfigDict()
    config.logging.log_every_steps = 100
    config.logging.log_errors = True
    config.logging.log_losses = True
    config.logging.log_weights = True
    config.logging.log_preds = False
    config.logging.log_grads = False
    config.logging.log_ntk = False

    # Checkpointing.
    config.saving = ml_collections.ConfigDict()
    config.saving.save_every_steps = None
    config.saving.num_keep_ckpts = 10

    # Input shape for initializing Flax models: (t, x).
    config.input_dim = 2

    # Integer for PRNG random seed.
    config.seed = 42

    return config
Binary file added examples/wave/data/wave.mat
Binary file not shown.
66 changes: 66 additions & 0 deletions examples/wave/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os

import ml_collections

import jax.numpy as jnp

import matplotlib.pyplot as plt

from jaxpi.utils import restore_checkpoint

import models
from utils import get_dataset


def evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Restore a trained Wave model, report its relative L2 error, and save plots.

    Args:
        config: Experiment configuration (wave speed, wandb run name, etc.).
        workdir: Base directory that training wrote checkpoints under.
    """
    u_ref, t_star, x_star = get_dataset()
    u0 = u_ref[0, :]  # initial displacement u(x, t=0)

    # Rebuild the model with the same configuration used for training.
    model = models.Wave(config, u0, t_star, x_star, c=config.c)

    # Restore from <workdir>/<wandb.name>/ckpt to match where train.py saves;
    # the previous path (<workdir>/ckpt/<wandb.name>) never matched and failed
    # to restore immediately after training.
    ckpt_path = os.path.join(workdir, config.wandb.name, "ckpt")
    model.state = restore_checkpoint(model.state, ckpt_path)
    params = model.state.params

    # Relative L2 error against the reference solution.
    l2_error = model.compute_l2_error(params, u_ref)
    print("L2 error: {:.3e}".format(l2_error))

    # Dense prediction over the (t, x) grid for plotting.
    u_pred = model.u_pred_fn(params, model.t_star, model.x_star)
    TT, XX = jnp.meshgrid(t_star, x_star, indexing="ij")

    # Plot exact vs. predicted solution and the pointwise absolute error.
    fig = plt.figure(figsize=(18, 5))
    plt.subplot(1, 3, 1)
    plt.pcolor(TT, XX, u_ref, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Exact")
    plt.tight_layout()

    plt.subplot(1, 3, 2)
    plt.pcolor(TT, XX, u_pred, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Predicted")
    plt.tight_layout()

    plt.subplot(1, 3, 3)
    plt.pcolor(TT, XX, jnp.abs(u_ref - u_pred), cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Absolute error")
    plt.tight_layout()

    # Save the figure, creating the output directory if needed.
    save_dir = os.path.join(workdir, "figures", config.wandb.name)
    os.makedirs(save_dir, exist_ok=True)

    fig_path = os.path.join(save_dir, "wave.png")
    fig.savefig(fig_path, bbox_inches="tight", dpi=300)
    print(f"Figure saved to {fig_path}")
    plt.close()
29 changes: 29 additions & 0 deletions examples/wave/generate_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Generate reference solution for the 1D wave equation.

Solves: u_tt = c^2 * u_xx on [0, 1] x [0, 1]
BCs: u(0, t) = u(1, t) = 0 (fixed ends)
ICs: u(x, 0) = sin(pi * x), u_t(x, 0) = 0
Exact: u(x, t) = sin(pi * x) * cos(pi * c * t)
"""

import numpy as np
import scipy.io as sio
import os

# Physical and discretization parameters.
c = 1.0  # wave speed

nx = 256  # number of spatial grid points
nt = 201  # number of temporal grid points

x_star = np.linspace(0, 1, nx)
t_star = np.linspace(0, 1, nt)

# Grid ordered (t, x) so that usol[i, j] = u(t_i, x_j).
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
# Exact solution for the fixed-end string: u(x, t) = sin(pi x) cos(pi c t).
usol = np.sin(np.pi * XX) * np.cos(np.pi * c * TT)

# Ensure the output directory exists: a fresh checkout may lack data/, and
# scipy.io.savemat raises FileNotFoundError rather than creating it.
data_dir = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(data_dir, exist_ok=True)

save_path = os.path.join(data_dir, "wave.mat")
sio.savemat(save_path, {"usol": usol, "t": t_star.reshape(-1, 1), "x": x_star.reshape(-1, 1)})
print(f"Saved reference solution to {save_path}")
print(f"  usol shape: {usol.shape}")
print(f"  t shape: {t_star.shape}")
print(f"  x shape: {x_star.shape}")
41 changes: 41 additions & 0 deletions examples/wave/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os

# Force deterministic cuDNN kernels so runs are reproducible.
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

from absl import app
from absl import flags
from absl import logging

from ml_collections import config_flags

import jax

jax.config.update("jax_default_matmul_precision", "highest")

import train
import eval


FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", ".", "Directory to store model data.")

config_flags.DEFINE_config_file(
    "config",
    "./configs/default.py",
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)

# Match the CLI pattern used by the other examples' entrypoints.
flags.mark_flags_as_required(["config", "workdir"])


def main(argv):
    """Dispatch to training or evaluation according to config.mode."""
    if FLAGS.config.mode == "train":
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)

    elif FLAGS.config.mode == "eval":
        eval.evaluate(FLAGS.config, FLAGS.workdir)

    else:
        # Fail loudly on a typo'd mode instead of silently doing nothing.
        raise ValueError(f"Unknown config.mode: {FLAGS.config.mode!r}")


if __name__ == "__main__":
    app.run(main)
Comment on lines +38 to +41
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This main entrypoint is missing flags.mark_flags_as_required(["config", "workdir"]) which is present in other examples’ main.py files. Without it, the example can run with unintended defaults and diverges from the repo’s established example CLI pattern.

Copilot uses AI. Check for mistakes.
165 changes: 165 additions & 0 deletions examples/wave/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, vmap

from jaxpi.models import ForwardIVP
from jaxpi.evaluator import BaseEvaluator
from jaxpi.utils import ntk_fn, flatten_pytree

from matplotlib import pyplot as plt


class Wave(ForwardIVP):
    """PINN for the 1D wave equation u_tt = c^2 * u_xx on the (t, x) grid.

    Losses enforce the initial displacement u(x, t0) = u0(x), zero initial
    velocity u_t(x, t0) = 0, homogeneous Dirichlet boundaries at x_star[0]
    and x_star[-1], and the PDE residual on collocation batches.

    NOTE(review): `self.num_chunks`, `self.tol`, and the causal matrix
    `self.M` are not set here — presumably initialized by ForwardIVP from
    config.weighting; confirm against the base class.
    """

    def __init__(self, config, u0, t_star, x_star, c=1.0):
        """Store the reference grid/IC and build vectorized prediction fns.

        Args:
            config: experiment ConfigDict (forwarded to ForwardIVP).
            u0: initial displacement sampled at x_star.
            t_star: 1D array of temporal grid points.
            x_star: 1D array of spatial grid points.
            c: wave speed in the residual u_tt - c^2 * u_xx.
        """
        super().__init__(config)

        self.u0 = u0
        self.t_star = t_star
        self.x_star = x_star
        self.c = c

        # Temporal domain endpoints; t0 anchors the initial conditions.
        self.t0 = t_star[0]
        self.t1 = t_star[-1]

        # Predictions over a grid: outer vmap over t, inner over x, so
        # u_pred_fn(params, t_star, x_star) has shape (len(t), len(x)).
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, params, t, x):
        """Scalar network prediction u(t, x); input packed as z = [t, x]."""
        z = jnp.stack([t, x])
        u = self.state.apply_fn(params, z)
        # Network output dimension is 1 (config.arch.out_dim); take the scalar.
        return u[0]

    def r_net(self, params, t, x):
        """Residual: u_tt - c^2 * u_xx = 0"""
        # Second derivatives via nested grad w.r.t. t (argnums=1) and x (argnums=2).
        u_tt = grad(grad(self.u_net, argnums=1), argnums=1)(params, t, x)
        u_xx = grad(grad(self.u_net, argnums=2), argnums=2)(params, t, x)
        return u_tt - self.c ** 2 * u_xx

    @partial(jit, static_argnums=(0,))
    def res_and_w(self, params, batch):
        """Compute residuals and weights for causal training."""
        # Sort collocation times; note x coordinates (batch[:, 1]) keep their
        # original order, so residuals are evaluated at re-paired (t, x) points.
        t_sorted = batch[:, 0].sort()
        r_pred = vmap(self.r_net, (None, 0, 0))(params, t_sorted, batch[:, 1])
        # Group residuals into num_chunks consecutive time windows.
        r_pred = r_pred.reshape(self.num_chunks, -1)
        l = jnp.mean(r_pred ** 2, axis=1)
        # Causal weights: each chunk is down-weighted by the accumulated loss
        # of earlier chunks (M @ l); stop_gradient keeps weights out of autodiff.
        w = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ l)))
        return l, w

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the per-term loss dict: ics, ics_vel, bcs, res.

        Args:
            params: network parameters.
            batch: (N, 2) array of residual collocation points (t, x).
        """
        # Initial displacement: u(x, 0) = u0(x)
        u_pred = vmap(self.u_net, (None, None, 0))(params, self.t0, self.x_star)
        ics_loss = jnp.mean((self.u0 - u_pred) ** 2)

        # Initial velocity: u_t(x, 0) = 0
        u_t_pred = vmap(
            grad(self.u_net, argnums=1), (None, None, 0)
        )(params, self.t0, self.x_star)
        ics_vel_loss = jnp.mean(u_t_pred ** 2)

        # Boundary conditions: u(0, t) = u(1, t) = 0
        u_left = vmap(self.u_net, (None, 0, None))(
            params, self.t_star, self.x_star[0]
        )
        u_right = vmap(self.u_net, (None, 0, None))(
            params, self.t_star, self.x_star[-1]
        )
        bcs_loss = jnp.mean(u_left ** 2) + jnp.mean(u_right ** 2)

        # Residual loss: causally weighted chunk means, or plain MSE.
        if self.config.weighting.use_causal:
            l, w = self.res_and_w(params, batch)
            res_loss = jnp.mean(l * w)
        else:
            r_pred = vmap(self.r_net, (None, 0, 0))(
                params, batch[:, 0], batch[:, 1]
            )
            res_loss = jnp.mean(r_pred ** 2)

        loss_dict = {
            "ics": ics_loss,
            "ics_vel": ics_vel_loss,
            "bcs": bcs_loss,
            "res": res_loss,
        }
        return loss_dict

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Diagonal NTK entries per loss term, mirroring the terms in losses().

        Used by the grad_norm/NTK weighting schemes in the base class.
        """
        # NTK of u_net at the initial-condition points (t0, x_star).
        ics_ntk = vmap(ntk_fn, (None, None, None, 0))(
            self.u_net, params, self.t0, self.x_star
        )

        # NTK of the initial-velocity operator du/dt at (t0, x_star).
        ics_vel_ntk = vmap(
            ntk_fn,
            (None, None, None, 0),
        )(grad(self.u_net, argnums=1), params, self.t0, self.x_star)

        # NTK along both spatial boundaries over all times, concatenated to
        # match the combined bcs loss term.
        bcs_left_ntk = vmap(ntk_fn, (None, None, 0, None))(
            self.u_net, params, self.t_star, self.x_star[0]
        )
        bcs_right_ntk = vmap(ntk_fn, (None, None, 0, None))(
            self.u_net, params, self.t_star, self.x_star[-1]
        )
        bcs_ntk = jnp.concatenate([bcs_left_ntk, bcs_right_ntk])

        if self.config.weighting.use_causal:
            # Sort times (x column unchanged) to align chunks with res_and_w.
            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
            res_ntk = vmap(ntk_fn, (None, None, 0, 0))(
                self.r_net, params, batch[:, 0], batch[:, 1]
            )
            # Chunk-average the residual NTK and apply the causal weights.
            res_ntk = res_ntk.reshape(self.num_chunks, -1)
            res_ntk = jnp.mean(res_ntk, axis=1)
            _, causal_weights = self.res_and_w(params, batch)
            res_ntk = res_ntk * causal_weights
        else:
            res_ntk = vmap(ntk_fn, (None, None, 0, 0))(
                self.r_net, params, batch[:, 0], batch[:, 1]
            )

        ntk_dict = {
            "ics": ics_ntk,
            "ics_vel": ics_vel_ntk,
            "bcs": bcs_ntk,
            "res": res_ntk,
        }
        return ntk_dict

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, params, u_ref):
        """Relative L2 error of the grid prediction against u_ref."""
        u_pred = self.u_pred_fn(params, self.t_star, self.x_star)
        error = jnp.linalg.norm(u_pred - u_ref) / jnp.linalg.norm(u_ref)
        return error


class WaveEvaluator(BaseEvaluator):
    """Adds wave-specific logging (L2 error, prediction heatmap) to BaseEvaluator."""

    def __init__(self, config, model):
        super().__init__(config, model)

    def log_errors(self, params, u_ref):
        """Record the model's relative L2 error against the reference solution."""
        self.log_dict["l2_error"] = self.model.compute_l2_error(params, u_ref)

    def log_preds(self, params):
        """Record a heatmap of the predicted solution over the (t, x) grid."""
        prediction = self.model.u_pred_fn(
            params, self.model.t_star, self.model.x_star
        )
        figure = plt.figure(figsize=(6, 5))
        # Transpose so x runs along the vertical axis.
        plt.imshow(prediction.T, aspect="auto", origin="lower", cmap="jet")
        plt.colorbar()
        self.log_dict["u_pred"] = figure
        plt.close()

    def __call__(self, state, batch, u_ref):
        """Build and return the log dict for the current training state."""
        self.log_dict = {}
        # Base evaluator populates the common entries (losses, weights, ...).
        self.log_dict = super().__call__(state, batch)

        if self.config.logging.log_errors:
            self.log_errors(state.params, u_ref)

        if self.config.logging.log_preds:
            self.log_preds(state.params)

        return self.log_dict
Loading
Loading