diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index cdfde92d50..5296153514 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -29,7 +29,6 @@ from maxtext.utils import exceptions from maxtext.utils import max_logging from maxtext.utils import gcs_utils -import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint import v1 as ocp_v1 from orbax.checkpoint._src.arrays import sharding as sharding_utils @@ -406,32 +405,6 @@ def print_save_message(step, async_checkpointing): max_logging.log(f"Saved a checkpoint at step {step}.") -def _find_idx(array: np.ndarray, replica_axis_idx: int): - """Returns the index along given dimension that the current host belongs to.""" - idx = None - for idx, val in np.ndenumerate(array): - if val.process_index == jax.process_index(): - break - return idx[replica_axis_idx] - - -def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): - """Returns the devices from the replica that current host belongs to. - - Replicas are assumed to be restricted to the first axis. - - Args: - device_array: devices of the mesh that can be obtained by mesh.devices() - replica_axis_idx: axis dimension along which replica is taken - - Returns: - devices inside the replica that current host is in - """ - idx = _find_idx(device_array, replica_axis_idx) - replica_result = np.take(device_array, idx, axis=replica_axis_idx) - return np.expand_dims(replica_result, axis=replica_axis_idx) - - def _prepare_scaled_down_grain_restore_args( data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path ) -> GrainCheckpointRestore: @@ -579,16 +552,8 @@ def load_state_if_possible( def map_to_pspec(data): if not enable_single_replica_ckpt_restoring: return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) - pspec = data.sharding.spec - mesh = data.sharding.mesh - replica_axis_index = 0 - replica_devices = _replica_devices(mesh.devices, replica_axis_index) - replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) - single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) - return ocp.type_handlers.SingleReplicaArrayRestoreArgs( - sharding=jax.sharding.NamedSharding(mesh, pspec), - single_replica_sharding=single_replica_sharding, + sharding=data.sharding, global_shape=data.shape, dtype=data.dtype, )