Add per-rank disk checkpointing for adjoint tape #4891
sghelichkhani wants to merge 20 commits into firedrakeproject:main
Conversation
Enable each MPI rank to write its adjoint checkpoint data to its own HDF5 file using PETSc Vec I/O on `COMM_SELF`. This avoids parallel HDF5 overhead and enables use of fast node-local storage (NVMe/SSD) on HPC systems, where shared filesystem I/O is a major bottleneck for large-scale time-dependent adjoint computations.

New parameter `per_rank_dirname` on `enable_disk_checkpointing()`. When set, function data is checkpointed per-rank while mesh data (via `checkpointable_mesh`) remains on shared storage. Restoring requires the same number of ranks as saving (inherent in adjoint workflows).
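The per-rank layout can be sketched with plain Python file I/O. This is a toy illustration of the idea only, not the Firedrake/PETSc implementation; `PerRankCheckpoint` and its methods are hypothetical names:

```python
import os
import pickle
import tempfile

class PerRankCheckpoint:
    """Toy stand-in for per-rank checkpointing: each rank owns one file
    under a shared directory, so no collective I/O is ever needed."""

    def __init__(self, dirname, rank):
        self.path = os.path.join(dirname, f"rank{rank}.h5")  # one file per rank
        self._store = {}

    def save(self, name, values):
        self._store[name] = list(values)
        with open(self.path, "wb") as fh:   # rank-local write, no MPI involved
            pickle.dump(self._store, fh)

    def restore(self, name):
        with open(self.path, "rb") as fh:   # rank-local read, no MPI involved
            return pickle.load(fh)[name]

# Two "ranks" writing independently to the same directory:
tmpdir = tempfile.mkdtemp()
for rank, data in [(0, [1.0, 2.0]), (1, [3.0, 4.0])]:
    ckpt = PerRankCheckpoint(tmpdir, rank)
    ckpt.save("f", data)
    assert ckpt.restore("f") == data
```

Because restore only ever touches the file the same rank wrote, this scheme works on node-local storage that other ranks cannot see, which is the point of the PR.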
In PyOP2 we have the notion of a 'compilation comm' which is a communicator defined over each node (https://github.com/firedrakeproject/firedrake/blob/main/pyop2/mpi.py#L450). Might something like this be appropriate/more general here?
Thanks Connor, that's a great idea; I hadn't considered using the compilation comm pattern here. I did look into what a node-local comm approach would involve. From my perspective on Gadi, I have 48 cores per node writing to node-local SSDs, and per-rank I/O is completely manageable. My thinking is that this is specifically an adjoint solution: disk checkpointing for the adjoint tape is extremely I/O heavy, so the less communicator overhead involved, the better. That said, this is ultimately a decision for the Firedrake folks on what works best for general users. If a node-local comm approach is preferred, it's doable; happy to refactor if that's the direction you'd like to go.
@connorjward Following up on your suggestion about using the compilation communicator. Angus and I had a discussion about this and tried to see if we can simply pass a sub-communicator to `CheckpointFile` directly. Saving works fine, but loading deadlocks. Here's a minimal reproducer (run with `mpiexec -n 4 python test_subcomm_checkpoint.py`):
```python
import os
import tempfile

from firedrake import *
from firedrake.checkpointing import CheckpointFile

comm = COMM_WORLD
mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
f = Function(V, name="f")
f.interpolate(SpatialCoordinate(mesh)[0])

# Pair up ranks as a stand-in for a node-local communicator.
node_comm = comm.Split(color=comm.rank // 2, key=comm.rank)

if comm.rank == 0:
    tmpdir = tempfile.mkdtemp()
else:
    tmpdir = None
tmpdir = comm.bcast(tmpdir, root=0)
fname = os.path.join(tmpdir, f"node{comm.rank // 2}.h5")

with CheckpointFile(fname, 'w', comm=node_comm) as out:
    out.save_mesh(mesh)
    out.save_function(f)

with CheckpointFile(fname, 'r', comm=node_comm) as inp:
    mesh2 = inp.load_mesh()
    f2 = inp.load_function(mesh2, "f")  # deadlocks here
```

The issue is that the mesh's load path is collective on `COMM_WORLD`, not on the communicator the file was written with. That said, I might be missing something. Is there a way to make this work that I'm not seeing?
I don't think so. I'd have been surprised had that worked.
This is what I'm suggesting. It doesn't seem like a lot of work to change your API from `per_rank_dirname` to accepting a communicator.
Sorry, that was […]. It kind of seems like a […].
Refactors the per_rank_dirname parameter into a more general checkpoint_comm + checkpoint_dir interface, following reviewer feedback. Instead of hardcoding COMM_SELF, users now pass any MPI communicator (COMM_SELF for per-rank files, a node-local comm for per-node files, etc.). The PETSc Vec I/O uses createMPI on the supplied communicator rather than createSeq on COMM_SELF, making the approach work for arbitrary communicator topologies. Removes three serial-only checkpoint_comm tests that are fully covered by their parallel counterparts and adds node_comm tests that exercise the multi-rank-per-file path using COMM_TYPE_SHARED.
Thanks Connor, done. I've refactored the API from `per_rank_dirname` to a general `checkpoint_comm` + `checkpoint_dir` interface. The main PETSc-level change this required was switching from `createSeq` on `COMM_SELF` to `createMPI` on the supplied communicator. I've also added tests with a node-local communicator (via `COMM_TYPE_SHARED`) that exercise the multi-rank-per-file path.
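The communicator choice boils down to how world ranks are grouped into checkpoint files. A hypothetical helper (not part of the PR) makes the grouping concrete:

```python
def checkpoint_group(world_rank, ranks_per_file):
    """Index of the checkpoint file a rank writes to.

    ranks_per_file == 1 mimics COMM_SELF (one file per rank);
    ranks_per_file == cores-per-node mimics a node-local communicator
    such as one obtained from Split_type(COMM_TYPE_SHARED).
    """
    if ranks_per_file < 1:
        raise ValueError("ranks_per_file must be >= 1")
    return world_rank // ranks_per_file

# 4 ranks, per-rank files (COMM_SELF-like): each rank has its own group.
assert [checkpoint_group(r, 1) for r in range(4)] == [0, 1, 2, 3]
# 4 ranks, 2 ranks per file (node-comm-like): ranks pair up.
assert [checkpoint_group(r, 2) for r in range(4)] == [0, 0, 1, 1]
```

With a real MPI communicator the same grouping would come from `comm.Split(color=checkpoint_group(comm.rank, n), key=comm.rank)`.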
connorjward
left a comment
Seems alright to me. It would definitely be good to get some feedback from @JHopeCollins, who has done similar comm wrangling for ensemble.
It's definitely an interesting question. Conceptually I think it should be possible to checkpoint a DMPlex to multiple files, but it's far from trivial. An added complication is that we would have to preserve the N-to-M checkpointing behaviour (i.e. reading and writing with different numbers of ranks).
Add FutureWarning to deprecated new_checkpoint_file method. Use isinstance(mesh, ufl.MeshSequence) instead of hasattr check. Replace COMM_TYPE_SHARED tests with comm.Split(rank // 2) to guarantee a communicator with 1 < size < COMM_WORLD in the 3-rank test.
Extract _generate_function_space_name from CheckpointFile into a module-level function in firedrake/checkpointing.py and reuse it for the checkpoint_comm Vec naming instead of maintaining a separate _generate_checkpoint_vec_name. The free function also handles MeshSequenceGeometry defensively. CheckpointFile method delegates to it.
The multi-mesh tests chained two PDE solves via assemble(u_a * dx), a global reduction whose floating-point result can vary across parallel runs due to reduction ordering. This made the J == Jnew assertion flaky at the np.allclose tolerance boundary. Make the mesh_b solve independent and drop the redundant memory baseline comparison.
Simplified the multi-mesh tests to fix intermittent CI failures. The original design chained two solves via a global reduction (assemble(u_a * dx)), which amplified parallel floating-point non-determinism across tape replays. The two solves are now independent while still doing multi-mesh checkpointing.
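The non-determinism here is just floating-point addition being non-associative, so a global reduction can change value with the order in which ranks contribute. A minimal single-process illustration:

```python
# Same four summands, two reduction orders.
vals = [1e16, 1.0, 1.0, -1e16]

left_to_right = ((vals[0] + vals[1]) + vals[2]) + vals[3]  # the 1.0s are absorbed into 1e16
paired = (vals[1] + vals[2]) + (vals[0] + vals[3])         # the 1.0s survive

assert left_to_right == 0.0
assert paired == 2.0
```

An MPI reduction over differently partitioned data is effectively choosing different parenthesisations, which is why the `J == Jnew` comparison was flaky at the tolerance boundary.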
JHopeCollins
left a comment
My gut feeling is that this should be part of CheckpointFile rather than hidden in the adjoint utils.
I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?
You can't split the DM into multiple files. That's why this only works for functions.
The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data. I can think of four potential ways to do this. Happy to hear arguments for/against each one. In each case we'd obviously need to be very explicit about the restrictions of saving/loading locally, i.e. you must have exactly the same partition for saving and viewing, so it's basically only for saving/loading during the same programme (and for option 3 you also can't save/load the mesh with this class).
This is a fair point.
I've read the implementation and am still confused about the difference between these. In PETSc the term 'local' applies to lots of different things, so this may not do quite what you expect. I wonder if this is basically reimplementing […]. We should discuss this in today's meeting.
@sghelichkhani from today's meeting we decided that we want this functionality exposed as a new class in checkpointing.py rather than hidden in the adjoint utils. I quite like something like `TemporaryFunctionCheckpointFile`.
How about […]?
@sghelichkhani is it possible to use […]?
Note on `_checkpoint_indices`: this class-level dict on `CheckpointFunction` is never pruned, so file entries persist after the HDF5 file is deleted. The local path (`TemporaryFunctionCheckpointFile._indices`) does not have this issue since `remove_file` handles cleanup. Not fixing this here, to avoid the risk of pruning entries before `restore()` reads them; the memory cost is negligible.
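The behaviour described, a class-level dict whose entries outlive the files they refer to, can be sketched in a few lines. `CheckpointFunctionSketch` and its attributes are hypothetical stand-ins, not the Firedrake classes:

```python
import os
import tempfile

class CheckpointFunctionSketch:
    # Class-level registry: filename -> number of datasets written.
    # Nothing ever prunes it, so entries outlive the files themselves.
    _checkpoint_indices = {}

    def __init__(self, filename):
        self.filename = filename

    def next_index(self):
        idx = self._checkpoint_indices.get(self.filename, 0)
        self._checkpoint_indices[self.filename] = idx + 1
        return idx

fd, fname = tempfile.mkstemp(suffix=".h5")
os.close(fd)

ckpt = CheckpointFunctionSketch(fname)
assert ckpt.next_index() == 0
assert ckpt.next_index() == 1

os.remove(fname)  # the file is gone...
assert fname in CheckpointFunctionSketch._checkpoint_indices  # ...but the entry remains
```

The stale entry is a few bytes per file, which is why the PR deliberately leaves it rather than risk pruning an entry a pending restore still needs.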
…ject#4865) * add log event markers * build spatial index using CreateWithArray
…ates (firedrakeproject#4900) use numpy directly for non-extruded
Move PETSc Vec I/O into TemporaryFunctionCheckpointFile in checkpointing.py. Rename save/restore methods, fix deprecation warning, remove redundant fixture and forwarding method, clean up imports.
All review comments from @connorjward and @JHopeCollins are addressed. The main change is extracting the PETSc Vec I/O into a `TemporaryFunctionCheckpointFile` class in checkpointing.py, so the adjoint code no longer handles concrete data directly. Tests are cleaned up: the redundant fixture and forwarding method are removed and imports tidied. Two pre-existing items I looked at but decided not to change: […]
connorjward
left a comment
I think I'm now being very nitpicky and this is basically fine.
Remove redundant self.cleanup guard from TemporaryFunctionCheckpointFile.remove_file so the cleanup decision lives solely in CheckPointFileReference.__del__. Remove manual tape setup from the four new parallel tests since the autouse_test_taping fixture handles it. Clarify TemporaryFunctionCheckpointFile.comm docstring per Connor's suggestion and document why _broadcast_tmpdir uses COMM_WORLD.
The multi-mesh tests have an independent solve on mesh_b that can give slightly different results after repartitioning by checkpointable_mesh. The taylor_test is the proper correctness check for the adjoint.
connorjward
left a comment
I'm happy with this. It's a big change that involves code I am not super familiar with so it would be good to have an approving review from @JHopeCollins too.
CI failures are unrelated: timeout in nprocs=6 I/O tests and a Gusto smoke test options issue. No files from this PR are involved.
Motivation
We run time-dependent adjoint Stokes simulations with close to a billion degrees of freedom per timestep. Recomputation-based checkpointing schedules (revolve/binomial) are infeasible due to the cost of recomputing the Stokes solve, so disk checkpointing (`SingleDiskStorageSchedule`) is the only viable option.

Currently, `CheckpointFile` writes all ranks to a single shared HDF5 file via parallel HDF5 (`PETSc.ViewerHDF5` on `COMM_WORLD`). On HPC systems, this means all checkpoint I/O goes through the shared parallel filesystem (Lustre/GPFS), which becomes a severe bottleneck. Under 24-hour job time limits, the disk I/O overhead makes simulations that comfortably fit in memory-checkpointed wall time infeasible when switching to disk checkpointing.

HPC nodes typically have fast node-local NVMe/SSD storage that is orders of magnitude faster than the shared filesystem. However, the current collective I/O approach in `CheckpointFile` cannot use node-local storage because all ranks must access the same file path.

Approach
Following @connorjward's suggestion in #4891 (comment), the implementation uses a general `checkpoint_comm` parameter rather than a hardcoded per-rank approach. Users pass any MPI communicator to control how function data is checkpointed.

The function data is written using `PETSc.Vec.createMPI` + `ViewerHDF5` on the supplied communicator, bypassing `CheckpointFile` and its collective `globalVectorView`/`globalVectorLoad` on `COMM_WORLD`. We tried using `CheckpointFile` directly with a sub-communicator (see #4891 (comment)), but loading deadlocks because the mesh DM's `sectionLoad`/`globalVectorLoad` are collective on `COMM_WORLD`.

The mesh checkpoint via `checkpointable_mesh` still uses shared storage through `CheckpointFile`, since that's a one-time operation and not performance-critical. Fully backwards compatible: without `checkpoint_comm`, behaviour is unchanged.

Multi-mesh considerations
Functions on different meshes with different partitioning work correctly because Vec dataset names include the mesh name and element info (`ckpt_mesh_a_CG2` vs `ckpt_mesh_b_DG1`), and `checkpointable_mesh` ensures deterministic partitioning per mesh independently.

The supermesh projection across two different meshes still fails in parallel, but that's a pre-existing limitation unrelated to this PR.
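The naming scheme can be sketched as a small helper. `checkpoint_vec_name` is a hypothetical illustration of the pattern described above (the PR actually reuses `_generate_function_space_name` for this):

```python
def checkpoint_vec_name(mesh_name, family, degree):
    """Dataset name combining mesh name and element info, so functions
    on differently partitioned meshes never collide in the same file."""
    return f"ckpt_{mesh_name}_{family}{degree}"

assert checkpoint_vec_name("mesh_a", "CG", 2) == "ckpt_mesh_a_CG2"
assert checkpoint_vec_name("mesh_b", "DG", 1) == "ckpt_mesh_b_DG1"
```

Because the mesh name is part of the dataset name, one HDF5 file can hold checkpoints for several meshes without any coordination between them.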
Testing
11 tests total covering three checkpointing modes:

- Existing shared-mode tests (5): serial and parallel basic checkpointing, successive writes, timestepper with taylor_test, and boundary conditions.
- `checkpoint_comm` with `COMM_SELF` (3): parallel basic checkpointing, successive writes (serial), and multi-mesh parallel. These exercise the per-rank file path where each rank writes independently.
- `checkpoint_comm` with a node communicator (3): parallel basic checkpointing, multi-mesh parallel, and timestepper with taylor_test. These exercise the multi-rank-per-file path using `COMM_TYPE_SHARED`.