import ml_collections

import jax.numpy as jnp


def get_config():
    """Build the default hyperparameter configuration for the 1D wave problem."""
    config = ml_collections.ConfigDict()

    # Run mode: "train" or "eval".
    config.mode = "train"

    # Wave speed in u_tt = c^2 * u_xx.
    config.c = 1.0

    # Weights & Biases run metadata.
    config.wandb = ml_collections.ConfigDict(
        {"project": "PINN-Wave", "name": "default", "tag": None}
    )

    # Network architecture.
    config.arch = ml_collections.ConfigDict(
        {
            "arch_name": "Mlp",
            "num_layers": 4,
            "hidden_dim": 256,
            "out_dim": 1,
            "activation": "tanh",
            "periodicity": None,
            "fourier_emb": ml_collections.ConfigDict(
                {"embed_scale": 1, "embed_dim": 256}
            ),
            "reparam": ml_collections.ConfigDict(
                {"type": "weight_fact", "mean": 0.5, "stddev": 0.1}
            ),
        }
    )

    # Optimizer.
    config.optim = ml_collections.ConfigDict(
        {
            "grad_accum_steps": 0,
            "optimizer": "Adam",
            "beta1": 0.9,
            "beta2": 0.999,
            "eps": 1e-8,
            "learning_rate": 1e-3,
            "decay_rate": 0.9,
            "decay_steps": 2000,
        }
    )

    # Training loop.
    config.training = ml_collections.ConfigDict(
        {"max_steps": 200000, "batch_size_per_device": 4096}
    )

    # Loss weighting (grad-norm balancing plus causal residual weighting).
    config.weighting = ml_collections.ConfigDict(
        {
            "scheme": "grad_norm",
            "init_weights": ml_collections.ConfigDict(
                {"ics": 1.0, "ics_vel": 1.0, "bcs": 1.0, "res": 1.0}
            ),
            "momentum": 0.9,
            "update_every_steps": 1000,
            "use_causal": True,
            "causal_tol": 1.0,
            "num_chunks": 32,
        }
    )

    # Logging.
    config.logging = ml_collections.ConfigDict(
        {
            "log_every_steps": 100,
            "log_errors": True,
            "log_losses": True,
            "log_weights": True,
            "log_preds": False,
            "log_grads": False,
            "log_ntk": False,
        }
    )

    # Checkpointing.
    config.saving = ml_collections.ConfigDict(
        {"save_every_steps": None, "num_keep_ckpts": 10}
    )

    # Input shape for initializing Flax models: (t, x).
    config.input_dim = 2

    # Integer for PRNG random seed.
    config.seed = 42

    return config
"""Generate reference solution for the 1D wave equation.

Solves: u_tt = c^2 * u_xx on [0, 1] x [0, 1]
BCs: u(0, t) = u(1, t) = 0 (fixed ends)
ICs: u(x, 0) = sin(pi * x), u_t(x, 0) = 0
Exact: u(x, t) = sin(pi * x) * cos(pi * c * t)
"""

import numpy as np
import scipy.io as sio
import os

c = 1.0  # wave speed

nx = 256  # number of spatial samples
nt = 201  # number of temporal samples

x_star = np.linspace(0, 1, nx)
t_star = np.linspace(0, 1, nt)

# The exact solution is separable, so the (nt, nx) grid of values is simply
# the outer product of the temporal and spatial factors:
# usol[i, j] = cos(pi * c * t_i) * sin(pi * x_j).
usol = np.outer(np.cos(np.pi * c * t_star), np.sin(np.pi * x_star))

# Store vectors as column arrays, matching the MATLAB .mat convention.
save_path = os.path.join(os.path.dirname(__file__), "data", "wave.mat")
sio.savemat(
    save_path,
    {"usol": usol, "t": t_star.reshape(-1, 1), "x": x_star.reshape(-1, 1)},
)
print(f"Saved reference solution to {save_path}")
print(f" usol shape: {usol.shape}")
print(f" t shape: {t_star.shape}")
print(f" x shape: {x_star.shape}")
class Wave(ForwardIVP):
    """PINN model for the 1D wave equation u_tt = c^2 * u_xx.

    Losses enforced (see `losses`):
      - initial displacement u(x, t0) = u0(x),
      - initial velocity u_t(x, t0) = 0,
      - homogeneous Dirichlet BCs at x_star[0] and x_star[-1],
      - PDE residual on sampled collocation points, optionally with
        causal (time-ordered) weighting.
    """

    def __init__(self, config, u0, t_star, x_star, c=1.0):
        super().__init__(config)

        self.u0 = u0          # initial displacement sampled on x_star
        self.t_star = t_star  # temporal evaluation grid
        self.x_star = x_star  # spatial evaluation grid
        self.c = c            # wave speed

        self.t0 = t_star[0]
        self.t1 = t_star[-1]

        # Predictions over a grid: outer vmap over t, inner over x, so the
        # result has shape (len(t_star), len(x_star)).
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, params, t, x):
        # Scalar forward pass: network input is the stacked point (t, x),
        # output is the first (only) component of the network's prediction.
        z = jnp.stack([t, x])
        u = self.state.apply_fn(params, z)
        return u[0]

    def r_net(self, params, t, x):
        """Residual: u_tt - c^2 * u_xx = 0"""
        # Second derivatives via nested autodiff w.r.t. t (argnums=1) and
        # x (argnums=2) of the scalar network output.
        u_tt = grad(grad(self.u_net, argnums=1), argnums=1)(params, t, x)
        u_xx = grad(grad(self.u_net, argnums=2), argnums=2)(params, t, x)
        return u_tt - self.c ** 2 * u_xx

    @partial(jit, static_argnums=(0,))
    def res_and_w(self, params, batch):
        """Compute residuals and weights for causal training."""
        # Sort collocation times so chunked residuals are time-ordered; the
        # x coordinates keep their original order (points are random pairs).
        t_sorted = batch[:, 0].sort()
        r_pred = vmap(self.r_net, (None, 0, 0))(params, t_sorted, batch[:, 1])
        # Per-chunk mean-squared residual; earlier chunks gate later ones.
        r_pred = r_pred.reshape(self.num_chunks, -1)
        l = jnp.mean(r_pred ** 2, axis=1)
        # NOTE(review): self.M, self.tol and self.num_chunks are not set in
        # this class — presumably initialized by ForwardIVP from
        # config.weighting (causal_tol, num_chunks), with M accumulating
        # losses of earlier chunks — TODO confirm against jaxpi.models.
        w = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ l)))
        return l, w

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Return the dict of loss terms keyed to match weighting.init_weights."""
        # Initial displacement: u(x, 0) = u0(x)
        u_pred = vmap(self.u_net, (None, None, 0))(params, self.t0, self.x_star)
        ics_loss = jnp.mean((self.u0 - u_pred) ** 2)

        # Initial velocity: u_t(x, 0) = 0
        u_t_pred = vmap(
            grad(self.u_net, argnums=1), (None, None, 0)
        )(params, self.t0, self.x_star)
        ics_vel_loss = jnp.mean(u_t_pred ** 2)

        # Boundary conditions: u(0, t) = u(1, t) = 0
        u_left = vmap(self.u_net, (None, 0, None))(
            params, self.t_star, self.x_star[0]
        )
        u_right = vmap(self.u_net, (None, 0, None))(
            params, self.t_star, self.x_star[-1]
        )
        bcs_loss = jnp.mean(u_left ** 2) + jnp.mean(u_right ** 2)

        # Residual loss: causally weighted chunk means, or plain MSE over
        # the sampled batch of (t, x) collocation points.
        if self.config.weighting.use_causal:
            l, w = self.res_and_w(params, batch)
            res_loss = jnp.mean(l * w)
        else:
            r_pred = vmap(self.r_net, (None, 0, 0))(
                params, batch[:, 0], batch[:, 1]
            )
            res_loss = jnp.mean(r_pred ** 2)

        loss_dict = {
            "ics": ics_loss,
            "ics_vel": ics_vel_loss,
            "bcs": bcs_loss,
            "res": res_loss,
        }
        return loss_dict

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Diagonal NTK entries per loss term, used by the 'ntk' weighting scheme."""
        ics_ntk = vmap(ntk_fn, (None, None, None, 0))(
            self.u_net, params, self.t0, self.x_star
        )

        ics_vel_ntk = vmap(
            ntk_fn,
            (None, None, None, 0),
        )(grad(self.u_net, argnums=1), params, self.t0, self.x_star)

        bcs_left_ntk = vmap(ntk_fn, (None, None, 0, None))(
            self.u_net, params, self.t_star, self.x_star[0]
        )
        bcs_right_ntk = vmap(ntk_fn, (None, None, 0, None))(
            self.u_net, params, self.t_star, self.x_star[-1]
        )
        bcs_ntk = jnp.concatenate([bcs_left_ntk, bcs_right_ntk])

        if self.config.weighting.use_causal:
            # Re-sort times to mirror res_and_w, then average NTK values per
            # chunk and scale them by the causal weights.
            batch = jnp.array([batch[:, 0].sort(), batch[:, 1]]).T
            res_ntk = vmap(ntk_fn, (None, None, 0, 0))(
                self.r_net, params, batch[:, 0], batch[:, 1]
            )
            res_ntk = res_ntk.reshape(self.num_chunks, -1)
            res_ntk = jnp.mean(res_ntk, axis=1)
            _, causal_weights = self.res_and_w(params, batch)
            res_ntk = res_ntk * causal_weights
        else:
            res_ntk = vmap(ntk_fn, (None, None, 0, 0))(
                self.r_net, params, batch[:, 0], batch[:, 1]
            )

        ntk_dict = {
            "ics": ics_ntk,
            "ics_vel": ics_vel_ntk,
            "bcs": bcs_ntk,
            "res": res_ntk,
        }
        return ntk_dict

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, params, u_ref):
        """Relative L2 error of the grid prediction against the reference field."""
        u_pred = self.u_pred_fn(params, self.t_star, self.x_star)
        error = jnp.linalg.norm(u_pred - u_ref) / jnp.linalg.norm(u_ref)
        return error
def get_dataset(data_path=None):
    """Load the reference wave solution from a ``.mat`` file.

    Generalized: the previously hard-coded path is now an optional
    parameter, keeping the zero-argument call backward compatible.

    Args:
        data_path: Optional path to a ``.mat`` file containing keys
            ``"usol"``, ``"t"`` and ``"x"``. Defaults to ``data/wave.mat``
            next to this module (the original behavior).

    Returns:
        Tuple ``(u_ref, t_star, x_star)``: the solution array and the 1-D
        temporal/spatial coordinate vectors. As written by
        ``generate_data.py``, ``u_ref`` has shape
        ``(len(t_star), len(x_star))``.
    """
    if data_path is None:
        data_path = os.path.join(os.path.dirname(__file__), "data", "wave.mat")
    data = scipy.io.loadmat(data_path)
    u_ref = data["usol"]
    # .mat files store vectors as 2-D (column) arrays; flatten to 1-D.
    t_star = data["t"].flatten()
    x_star = data["x"].flatten()

    return u_ref, t_star, x_star