Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from maxtext.utils import exceptions
from maxtext.utils import max_logging
from maxtext.utils import gcs_utils
import numpy as np
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp_v1
from orbax.checkpoint._src.arrays import sharding as sharding_utils
Expand Down Expand Up @@ -406,32 +405,6 @@ def print_save_message(step, async_checkpointing):
max_logging.log(f"Saved a checkpoint at step {step}.")


def _find_idx(array: np.ndarray, replica_axis_idx: int):
"""Returns the index along given dimension that the current host belongs to."""
idx = None
for idx, val in np.ndenumerate(array):
if val.process_index == jax.process_index():
break
return idx[replica_axis_idx]


def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
"""Returns the devices from the replica that current host belongs to.

Replicas are assumed to be restricted to the first axis.

Args:
device_array: devices of the mesh that can be obtained by mesh.devices()
replica_axis_idx: axis dimension along which replica is taken

Returns:
devices inside the replica that current host is in
"""
idx = _find_idx(device_array, replica_axis_idx)
replica_result = np.take(device_array, idx, axis=replica_axis_idx)
return np.expand_dims(replica_result, axis=replica_axis_idx)


def _prepare_scaled_down_grain_restore_args(
data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path
) -> GrainCheckpointRestore:
Expand Down Expand Up @@ -579,16 +552,8 @@ def load_state_if_possible(
def map_to_pspec(data):
if not enable_single_replica_ckpt_restoring:
return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
pspec = data.sharding.spec
mesh = data.sharding.mesh
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
sharding=data.sharding,
global_shape=data.shape,
dtype=data.dtype,
)
Expand Down
Loading