Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
49c5e15
Add per-rank disk checkpointing for adjoint tape
sghelichkhani Feb 15, 2026
f6b6704
Generalize per-rank checkpointing to checkpoint_comm API
sghelichkhani Feb 18, 2026
c231fb9
Address review: deprecation warning, isinstance check, sub-comm tests
sghelichkhani Feb 18, 2026
f3335bc
Reuse _generate_function_space_name for checkpoint_comm naming
sghelichkhani Feb 18, 2026
cd6678b
Merge branch 'firedrakeproject:main' into sghelichkhani/per-rank-disk…
sghelichkhani Feb 18, 2026
6b8ba4a
Fix non-deterministic multi-mesh test failures
sghelichkhani Feb 18, 2026
c97e4be
Merge branch 'firedrakeproject:main' into sghelichkhani/per-rank-disk…
sghelichkhani Feb 19, 2026
3d38613
Merge branch 'main' into sghelichkhani/per-rank-disk-checkpointing
sghelichkhani Feb 20, 2026
312ad59
Use `Index_CreateWithArray` to build mesh spatial index (#4865)
leo-collins Feb 19, 2026
ad56c87
Avoid creating a function space when calculating bounding box coordin…
leo-collins Feb 19, 2026
125ead2
Extract TemporaryFunctionCheckpointFile, address review feedback
sghelichkhani Feb 21, 2026
7bb3ab3
Remove dead checkpoint_comm attribute from CheckPointFileReference
sghelichkhani Feb 21, 2026
5b39e8b
Clean up test temp directories to avoid leaks in CI
sghelichkhani Feb 21, 2026
78135e4
Warn when checkpoint_comm is used without checkpoint_dir
sghelichkhani Feb 21, 2026
0ec581f
Add taylor_test to multi-mesh test, clarify per-rank tmpdir comments
sghelichkhani Feb 21, 2026
86afb94
Document MeshSequenceGeometry unwrapping in _generate_function_space_…
sghelichkhani Feb 21, 2026
1f45f04
Merge branch 'main' into sghelichkhani/per-rank-disk-checkpointing
sghelichkhani Feb 21, 2026
ee88af0
linting
sghelichkhani Feb 21, 2026
88e0b78
Address review feedback: cleanup logic, test fixtures, docstrings
sghelichkhani Feb 24, 2026
c6a58a7
Use taylor_test instead of allclose in multi-mesh tests
sghelichkhani Feb 24, 2026
0bfd1b3
Add boundary conditions to multi-mesh test solves
sghelichkhani Feb 25, 2026
1b43b35
Merge branch 'main' into sghelichkhani/per-rank-disk-checkpointing
sghelichkhani Mar 1, 2026
0568984
Refactor TemporaryFunctionCheckpointFile per review feedback
sghelichkhani Mar 1, 2026
a71927e
Remove unused tempfile import from checkpointing.py
sghelichkhani Mar 1, 2026
3da0cf6
Resolve cleanup comm in CheckPointFileReference.__init__ rather than …
sghelichkhani Mar 2, 2026
bccf166
Use TemporaryDirectory(delete=cleanup) for local checkpoint dir
sghelichkhani Mar 2, 2026
2c1e17a
Single unconditional bcast at each collective site
sghelichkhani Mar 2, 2026
acb5e71
Improve new_checkpoint_file deprecation warning with migration guidance
sghelichkhani Mar 2, 2026
15bd03f
Consistent _checkpoint suffix on all _save/_restore method names
sghelichkhani Mar 2, 2026
5056677
Factor rename/recount into restore(), leaf methods return plain Function
sghelichkhani Mar 2, 2026
2e3552a
Match CheckpointFile signature with name=None, idx=None defaults
sghelichkhani Mar 2, 2026
c120dfd
Use createWithArray in save_function to avoid unnecessary allocation
sghelichkhani Mar 2, 2026
0fd0e4b
Apply suggestions from code review
sghelichkhani Mar 3, 2026
b2e63d6
Merge branch 'main' into sghelichkhani/per-rank-disk-checkpointing
sghelichkhani Mar 3, 2026
f37a152
Merge branch 'sghelichkhani/per-rank-disk-checkpointing' of github.co…
sghelichkhani Mar 3, 2026
ecab97f
Fix CheckPointFileReference missing comm attribute
sghelichkhani Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 206 additions & 50 deletions firedrake/adjoint_utils/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import shutil
import atexit
import warnings
from abc import ABC, abstractmethod
from numbers import Number
_enable_disk_checkpoint = False
Expand Down Expand Up @@ -49,7 +50,8 @@ def __exit__(self, *args):
_checkpoint_init_data = self._init


def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True,
checkpoint_comm=None, checkpoint_dir=None):
"""Add a DiskCheckpointer to the current tape.

Disk checkpointing is fully enabled by calling::
Expand All @@ -68,23 +70,48 @@ def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
`checkpoint_schedules` provides other schedules for checkpointing to memory, disk,
or a combination of both.

For HPC systems with fast node-local storage, function data can be
checkpointed on a sub-communicator to avoid parallel HDF5 overhead::

enable_disk_checkpointing(checkpoint_comm=MPI.COMM_SELF,
checkpoint_dir="/local/scratch")

Parameters
----------
dirname : str
The directory in which the disk checkpoints should be stored. If not
specified then the current working directory is used. Checkpoints are
stored in a temporary subdirectory of this directory.
The directory in which the shared disk checkpoints should be stored.
If not specified then the current working directory is used.
Checkpoints are stored in a temporary subdirectory of this directory.
comm : mpi4py.MPI.Intracomm
The MPI communicator over which the computation to be disk checkpointed
is defined. This will usually match the communicator on which the
mesh(es) are defined.
cleanup : bool
If set to False, checkpoint files will not be deleted when no longer
required. This is usually only useful for debugging.
checkpoint_comm : mpi4py.MPI.Intracomm or None
If specified, function data is checkpointed using PETSc Vec I/O on
this communicator instead of using Firedrake's CheckpointFile. This
bypasses parallel HDF5 and is ideal for node-local storage on HPC
systems. Passing ``MPI.COMM_SELF`` gives each rank its own file,
while a shared node communicator groups ranks that share storage.
The mesh checkpoint (via ``checkpointable_mesh``) always uses shared
storage. Requires the same communicator layout on restore.
checkpoint_dir : str or None
The directory in which checkpoint_comm files are stored. Only used
when ``checkpoint_comm`` is not None. Each group of ranks sharing
a checkpoint_comm creates a temporary subdirectory here. This
directory must be accessible from all ranks within each
checkpoint_comm group. For example, using a node-local path like
/tmp is safe when checkpoint_comm groups ranks on the same node,
but would fail if checkpoint_comm spans nodes whose filesystems
are not shared.
"""
tape = get_working_tape()
if "firedrake" not in tape._package_data:
tape._package_data["firedrake"] = DiskCheckpointer(dirname, comm, cleanup)
tape._package_data["firedrake"] = DiskCheckpointer(
dirname, comm, cleanup, checkpoint_comm, checkpoint_dir
)


def disk_checkpointing():
Expand Down Expand Up @@ -120,14 +147,29 @@ def __exit__(self, *args):

class CheckPointFileReference:
"""A filename which deletes the associated file when it is destroyed."""
def __init__(self, name, comm, cleanup=False):
def __init__(self, name, comm, cleanup=False, checkpoint_comm=None):
self.name = name
self.comm = comm
self.cleanup = cleanup
self.checkpoint_comm = checkpoint_comm

def __del__(self):
if self.cleanup and self.comm.rank == 0 and os.path.exists(self.name):
os.remove(self.name)
if self.cleanup and os.path.exists(self.name):
if self.comm.rank == 0:
os.remove(self.name)
Comment thread
JHopeCollins marked this conversation as resolved.
# Prune the index-tracking entry for this file from CheckpointFunction.
# This is safe for the following reasons:
# (1) CheckpointFunction holds self.file as a direct strong reference,
# so __del__ here can only fire after every CheckpointFunction that
# wrote to this filepath has already been garbage-collected.
# (2) restore() never reads _checkpoint_indices — it uses stored_name
# and stored_index baked into the CheckpointFunction at save time.
# (3) Under revolve schedules the tape checkpoint store holds the
# CheckPointFileReference alive until forward re-execution is done,
# so there is no risk of premature pruning.
# (4) pop is a no-op for init files where no CheckpointFunction ever
# wrote an entry (e.g. checkpointable_mesh files).
CheckpointFunction._checkpoint_indices.pop(self.name, None)


class DiskCheckpointer(TapePackageData):
Expand All @@ -136,52 +178,128 @@ class DiskCheckpointer(TapePackageData):
Parameters
----------
dirname : str
The directory in which the disk checkpoints should be stored. If not
specified then the current working directory is used. Checkpoints are
stored in a temporary subdirectory of this directory.
The directory in which the shared disk checkpoints should be stored.
If not specified then the current working directory is used.
Checkpoints are stored in a temporary subdirectory of this directory.
comm : mpi4py.MPI.Intracomm
The MPI communicator over which the computation to be disk checkpointed
is defined. This will usually match the communicator on which the
mesh(es) are defined.
cleanup : bool
If set to False, checkpoint files will not be deleted when no longer
required. This is usually only useful for debugging.
checkpoint_comm : mpi4py.MPI.Intracomm or None
If specified, function data is checkpointed on this communicator.
checkpoint_dir : str or None
Directory for checkpoint_comm files. This directory must be
accessible from all ranks within each checkpoint_comm group.
For example, using a node-local path like /tmp is safe when
checkpoint_comm groups ranks on the same node, but would fail
if checkpoint_comm spans nodes whose filesystems are not shared.
"""

def __init__(self, dirname=None, comm=COMM_WORLD, cleanup=True):

if comm.rank == 0:
self.dirname = comm.bcast(tempfile.mkdtemp(
prefix="firedrake_adjoint_checkpoint_", dir=dirname or os.getcwd()
))
else:
self.dirname = comm.bcast("")
def __init__(self, dirname=None, comm=COMM_WORLD, cleanup=True,
checkpoint_comm=None, checkpoint_dir=None):
self.checkpoint_comm = checkpoint_comm
self.comm = comm
self.cleanup = cleanup

# Shared directory (for mesh checkpoint and init data). The bcast
# uses comm (COMM_WORLD) so every rank knows the shared path.
path = tempfile.mkdtemp(
prefix="firedrake_adjoint_checkpoint_", dir=dirname or os.getcwd()
) if comm.rank == 0 else None
self.dirname = comm.bcast(path)
if self.cleanup and comm.rank == 0:
# Delete the checkpoint folder on process exit.
# Delete the shared checkpoint folder on process exit.
atexit.register(shutil.rmtree, self.dirname)
# # A checkpoint file holding the state of block variables set outside
# the tape.
self.init_checkpoint_file = self.new_checkpoint_file()
self.current_checkpoint_file = self.new_checkpoint_file()

def new_checkpoint_file(self):
"""Set up a disk checkpointing file."""
# Local directory (for function data on checkpoint_comm). The bcast
# uses checkpoint_comm, not comm: only ranks within the same
# checkpoint_comm group share a local filesystem, so we must not
# perform a COMM_WORLD collective here.
if self.checkpoint_comm is not None:
if checkpoint_dir is None:
warnings.warn(
"checkpoint_comm without checkpoint_dir defaults to cwd, "
"which is usually on the shared filesystem. Without a "
"node-local path the collective CheckpointFile is more "
"suitable. Consider setting checkpoint_dir.",
UserWarning
)
base_dir = checkpoint_dir or os.getcwd()
if checkpoint_comm.rank == 0:
# ignore_cleanup_errors avoids tracebacks if the finalizer fires
# during interpreter shutdown after MPI has already finalized.
self._local_tmpdir = tempfile.TemporaryDirectory(
prefix="firedrake_adjoint_checkpoint_cc_",
dir=base_dir,
delete=cleanup,
ignore_cleanup_errors=True,
)
local_path = self._local_tmpdir.name
else:
self._local_tmpdir = None
local_path = None
self._local_dirname = checkpoint_comm.bcast(local_path)
else:
self._local_tmpdir = None
self._local_dirname = None

# A checkpoint file holding the state of block variables set outside
# the tape (always shared, used by checkpointable_mesh).
self.init_checkpoint_file = self._new_shared_checkpoint_file()
self.current_checkpoint_file = self._new_checkpoint_file()

def _new_shared_checkpoint_file(self):
"""Set up a shared disk checkpointing file (all ranks use same file)."""
from firedrake.checkpointing import CheckpointFile
if self.comm.rank == 0:
_, checkpoint_file = tempfile.mkstemp(
dir=self.dirname, suffix=".h5"
)
checkpoint_file = self.comm.bcast(checkpoint_file)
_, checkpoint_file = tempfile.mkstemp(dir=self.dirname, suffix=".h5")
else:
checkpoint_file = self.comm.bcast("")
checkpoint_file = None
checkpoint_file = self.comm.bcast(checkpoint_file)
# Let h5py create a file at this location just to be sure.
with CheckpointFile(checkpoint_file, 'w'):
with CheckpointFile(checkpoint_file, 'w', comm=self.comm):
pass
return CheckPointFileReference(checkpoint_file, self.comm,
self.cleanup)

def _new_checkpoint_comm_file(self):
"""Set up a checkpoint file on the checkpoint communicator."""
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
if self.checkpoint_comm.rank == 0:
fd, filepath = tempfile.mkstemp(dir=self._local_dirname, suffix=".h5")
os.close(fd)
else:
filepath = None
filepath = self.checkpoint_comm.bcast(filepath)
# Initialise an empty HDF5 file. Opened in 'w' mode and immediately
# closed so that subsequent 'a' opens from save_function find a valid
# file.
with TemporaryFunctionCheckpointFile(self.checkpoint_comm, filepath, 'w'):
pass
return CheckPointFileReference(filepath, self.checkpoint_comm, self.cleanup,
checkpoint_comm=self.checkpoint_comm)

def _new_checkpoint_file(self):
"""Set up a checkpoint file for function data."""
if self.checkpoint_comm is not None:
return self._new_checkpoint_comm_file()
else:
return self._new_shared_checkpoint_file()

def new_checkpoint_file(self):
"""Set up a disk checkpointing file."""
warnings.warn(
"'new_checkpoint_file' is deprecated and will be removed in a "
"future release. Checkpoint file management is now handled "
"internally; to advance to a new checkpoint file call "
"'reset()' on the DiskCheckpointer instead.",
FutureWarning
Comment thread
sghelichkhani marked this conversation as resolved.
)
return self._new_checkpoint_file()

def clear(self, init=True):
"""Reset the DiskCheckPointer.

Expand All @@ -198,8 +316,8 @@ def clear(self, init=True):
if not self.cleanup:
return
if init:
self.init_checkpoint_file = self.new_checkpoint_file()
self.current_checkpoint_file = self.new_checkpoint_file()
self.init_checkpoint_file = self._new_shared_checkpoint_file()
self.current_checkpoint_file = self._new_checkpoint_file()

def reset(self):
self.clear(init=False)
Expand Down Expand Up @@ -254,9 +372,9 @@ def checkpointable_mesh(mesh):
"No current checkpoint file. Call enable_disk_checkpointing()."
)

with CheckpointFile(checkpoint_file.name, 'a') as outfile:
with CheckpointFile(checkpoint_file.name, 'a', comm=checkpoint_file.comm) as outfile:
outfile.save_mesh(mesh)
with CheckpointFile(checkpoint_file.name, 'r') as outfile:
with CheckpointFile(checkpoint_file.name, 'r', comm=checkpoint_file.comm) as outfile:
return outfile.load_mesh(mesh.name)


Expand Down Expand Up @@ -290,7 +408,6 @@ class CheckpointFunction(CheckpointBase, OverloadedType):
_checkpoint_indices = {}

def __init__(self, function):
from firedrake.checkpointing import CheckpointFile
self.name = function.name()
self.mesh = function.function_space().mesh()
self.file = current_checkpoint_file()
Expand All @@ -300,31 +417,70 @@ def __init__(self, function):
"No current checkpoint file. Call enable_disk_checkpointing()."
)

self.count = function.count()

# Compute stored_name and stored_index once, shared by both checkpoint
# paths. stored_name encodes the function space (mesh name + element
# family/degree) so that functions on different meshes or spaces never
# collide. stored_index disambiguates successive saves of the same
# space to the same file.
from firedrake.checkpointing import _generate_function_space_name
stored_names = CheckpointFunction._checkpoint_indices
if self.file.name not in stored_names:
stored_names[self.file.name] = {}
self.stored_name = _generate_function_space_name(function.function_space())
indices = stored_names[self.file.name]
indices.setdefault(self.stored_name, 0)
indices[self.stored_name] += 1
self.stored_index = indices[self.stored_name]

if self.file.checkpoint_comm is not None:
self._function_space = function.function_space()
self._save_local_checkpoint(function)
else:
self._save_shared_checkpoint(function)

self.count = function.count()
with CheckpointFile(self.file.name, 'a') as outfile:
self.stored_name = outfile._generate_function_space_name(
function.function_space()
)
indices = stored_names[self.file.name]
indices.setdefault(self.stored_name, 0)
indices[self.stored_name] += 1
self.stored_index = indices[self.stored_name]
def _save_shared_checkpoint(self, function):
"""Save function data to a shared HDF5 file via CheckpointFile."""
from firedrake.checkpointing import CheckpointFile
with CheckpointFile(self.file.name, 'a', self.file.comm) as outfile:
outfile.save_function(function, name=self.stored_name,
idx=self.stored_index)

def _save_local_checkpoint(self, function):
"""Save function data to a local HDF5 file via PETSc Vec I/O."""
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
with TemporaryFunctionCheckpointFile(
self.file.checkpoint_comm, self.file.name, 'a'
) as outfile:
outfile.save_function(function, self.stored_name, self.stored_index)

def restore(self):
"""Read and return this Function from the checkpoint."""
from firedrake.checkpointing import CheckpointFile
with CheckpointFile(self.file.name, 'r') as infile:
function = infile.load_function(self.mesh, self.stored_name,
idx=self.stored_index)
if self.file.checkpoint_comm is not None:
function = self._restore_local_checkpoint()
else:
function = self._restore_shared_checkpoint()
return type(function)(function.function_space(),
function.dat, name=self.name, count=self.count)

def _restore_shared_checkpoint(self):
"""Load function data from a shared HDF5 file via :class:`.CheckpointFile`."""
from firedrake.checkpointing import CheckpointFile
with CheckpointFile(self.file.name, 'r', comm=self.file.comm) as infile:
return infile.load_function(self.mesh, self.stored_name,
idx=self.stored_index)

def _restore_local_checkpoint(self):
"""Load function data via :class:`TemporaryFunctionCheckpointFile`."""
from firedrake.checkpointing import TemporaryFunctionCheckpointFile
with TemporaryFunctionCheckpointFile(
self.file.checkpoint_comm, self.file.name, 'r'
) as infile:
return infile.load_function(
self._function_space, self.stored_name, self.stored_index
)

def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint.restore()

Expand Down
Loading