25 changes: 21 additions & 4 deletions src/maxtext/utils/max_utils.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 """Common Max Utils needed by multiple modules.
-All the functions that use MaxText modules, such as Pyconfig, should be moved to the MaxText utils file."""
+All the functions that use MaxText modules, such as Pyconfig, should be moved to the MaxText utils file.
+"""
 
 import collections
 from collections.abc import Sequence
@@ -250,12 +251,21 @@ def initialize_jax_for_gpu(raw_keys):
   if os.environ.get("JAX_COORDINATOR_IP") is not None:
     coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
     coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
-    devices = os.getenv("CUDA_VISIBLE_DEVICES")
+    env_var_list = ["CUDA_VISIBLE_DEVICES", "SLURM_STEP_GPUS"]
+    for env_var in env_var_list:
+      devices = os.getenv(env_var)
+      if devices is not None:
+        max_logging.log(f"Using {env_var} to initialize JAX distributed system: {devices}")
+        break
+    if devices is None:
+      jax.config.update("jax_cuda_visible_devices", "all")
+      jax.config.update("jax_rocm_visible_devices", "all")
+
     if devices is not None:
       try:
         devices = [int(x) for x in devices.split(",")]
       except (ValueError, TypeError) as e:
-        max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
+        max_logging.log(f"Error parsing {env_var}: {e}")
         devices = None
 
     jax.distributed.initialize(
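Read on its own, the env-var selection logic added in this hunk behaves as in the sketch below; the standalone framing and the function name resolve_local_device_ids are illustrative, not part of the diff. The parsed list (or None) is what reaches the local_device_ids argument of the jax.distributed.initialize(...) call truncated above, as the tests further down assert.

import os

# Illustrative sketch of the new selection logic (not part of the diff):
# the first env var in the list that is set wins, mirroring env_var_list above.
_GPU_ENV_VARS = ["CUDA_VISIBLE_DEVICES", "SLURM_STEP_GPUS"]

def resolve_local_device_ids():
  """Return device ids like [0, 2, 3], or None to let JAX auto-detect."""
  devices = None
  for env_var in _GPU_ENV_VARS:
    devices = os.getenv(env_var)
    if devices is not None:
      break
  if devices is None:
    return None  # the real code also sets jax_cuda/rocm_visible_devices to "all" here
  try:
    return [int(x) for x in devices.split(",")]
  except (ValueError, TypeError):
    return None  # e.g. UUID-style values such as "GPU-8f2e3072-..." are not parseable

# Example: SLURM_STEP_GPUS="1,3" with CUDA_VISIBLE_DEVICES unset yields [1, 3].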
@@ -848,7 +858,14 @@ def reorder_causal_load_balanced(batch, cp_size):
           cp_size=cp_size,
       )
       if key
-      in ["inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation"]
+      in [
+          "inputs",
+          "targets",
+          "inputs_position",
+          "targets_position",
+          "inputs_segmentation",
+          "targets_segmentation",
+      ]
       else value
       for key, value in batch.items()
   }
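For orientation, the comprehension in this hunk dispatches per batch key: sequence-shaped entries are reordered for context parallelism while everything else passes through unchanged. A simplified, self-contained sketch of that shape, where reorder_selected_keys and reorder_fn are stand-ins for the real function and for the reordering call collapsed out of this view:

_SEQUENCE_KEYS = {
    "inputs", "targets",
    "inputs_position", "targets_position",
    "inputs_segmentation", "targets_segmentation",
}

def reorder_selected_keys(batch, cp_size, reorder_fn):
  # Only sequence-like entries are reordered; other batch entries pass through unchanged.
  return {
      key: reorder_fn(value, cp_size=cp_size) if key in _SEQUENCE_KEYS else value
      for key, value in batch.items()
  }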
66 changes: 58 additions & 8 deletions tests/unit/max_utils_test.py
@@ -201,7 +201,7 @@ def test_unscan_train_state_params(self):


 class TestGpuDistributedInitialization(unittest.TestCase):
-  """Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
+  """Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize."""
 
   @mock.patch.dict(
       os.environ,
@@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase):
"JAX_COORDINATOR_PORT": "1234",
"NNODES": "1",
"NODE_RANK": "0",
"CUDA_VISIBLE_DEVICES": "0,2,3", # Simulating Slurm/orchestrator assignment
"CUDA_VISIBLE_DEVICES": "0,2,3",
},
)
@mock.patch("jax.distributed.initialize")
@@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
"""Verifies that a comma-separated string of IDs is correctly parsed."""
raw_keys = {"jax_distributed_initialization_timeout": 300}
max_utils.initialize_jax_for_gpu(raw_keys)
# Check that local_device_ids was passed correctly as a list of integers
_, kwargs = mock_init.call_args
self.assertEqual(kwargs["local_device_ids"], [0, 2, 3])
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
"JAX_COORDINATOR_PORT": "1234",
"NNODES": "1",
"NODE_RANK": "0",
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", # Invalid format for integer parsing
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...",
},
)
@mock.patch("jax.distributed.initialize")
@mock.patch("jax.devices")
@mock.patch("maxtext.utils.max_logging.log")
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, mock_init):
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, _mock_devices, mock_init):
"""Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
raw_keys = {"jax_distributed_initialization_timeout": 300}
max_utils.initialize_jax_for_gpu(raw_keys)
# Check that it falls back to None (JAX auto-detection default) on error
_, kwargs = mock_init.call_args
self.assertIsNone(kwargs.get("local_device_ids"))
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -255,17 +253,69 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m
"NNODES": "1",
"NODE_RANK": "0",
},
clear=True,
)
@mock.patch("jax.config.update")
@mock.patch("jax.distributed.initialize")
@mock.patch("jax.devices")
@mock.patch("maxtext.utils.max_logging.log")
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, mock_devices, mock_init):
"""Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, _mock_devices, mock_init, mock_config_update):
"""When coordinator env is set but neither CUDA_VISIBLE_DEVICES nor SLURM_STEP_GPUS is set, JAX uses all devices
(config) and init gets no local ids.
"""
raw_keys = {"jax_distributed_initialization_timeout": 300}
max_utils.initialize_jax_for_gpu(raw_keys)
_, kwargs = mock_init.call_args
self.assertIsNone(kwargs.get("local_device_ids"))
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
mock_config_update.assert_has_calls(
[
mock.call("jax_cuda_visible_devices", "all"),
mock.call("jax_rocm_visible_devices", "all"),
]
)

@mock.patch("jax.distributed.initialize")
@mock.patch("jax.devices")
@mock.patch("maxtext.utils.max_logging.log")
def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset(self, mock_log, _mock_devices, mock_init):
"""Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list)."""
env = {
"JAX_COORDINATOR_IP": "10.0.0.1",
"JAX_COORDINATOR_PORT": "1234",
"NNODES": "1",
"NODE_RANK": "0",
"SLURM_STEP_GPUS": "1,3",
}
with mock.patch.dict(os.environ, env, clear=False):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
raw_keys = {"jax_distributed_initialization_timeout": 300}
max_utils.initialize_jax_for_gpu(raw_keys)
_, kwargs = mock_init.call_args
self.assertEqual(kwargs["local_device_ids"], [1, 3])
mock_log.assert_any_call("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3")

@mock.patch.dict(
os.environ,
{
"JAX_COORDINATOR_IP": "10.0.0.1",
"JAX_COORDINATOR_PORT": "1234",
"NNODES": "1",
"NODE_RANK": "0",
"CUDA_VISIBLE_DEVICES": "0,2",
"SLURM_STEP_GPUS": "4,5,6",
},
)
@mock.patch("jax.distributed.initialize")
@mock.patch("jax.devices")
@mock.patch("maxtext.utils.max_logging.log")
def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop(self, mock_log, _mock_devices, mock_init):
"""First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS."""
raw_keys = {"jax_distributed_initialization_timeout": 300}
max_utils.initialize_jax_for_gpu(raw_keys)
_, kwargs = mock_init.call_args
self.assertEqual(kwargs["local_device_ids"], [0, 2])
mock_log.assert_any_call("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2")


if __name__ == "__main__":
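A closing note on the clear=True added to test_initialize_jax_for_gpu_no_devices above: mock.patch.dict(os.environ, ..., clear=True) empties the environment for the duration of the patch, so a CUDA_VISIBLE_DEVICES value exported by the CI runner cannot leak in and mask the no-devices path. A minimal standalone illustration (the ONLY_KEY name is hypothetical):

import os
from unittest import mock

with mock.patch.dict(os.environ, {"ONLY_KEY": "1"}, clear=True):
  # Whatever the runner exported is hidden for the duration of the patch.
  assert os.getenv("CUDA_VISIBLE_DEVICES") is None
  assert dict(os.environ) == {"ONLY_KEY": "1"}
# On exit the original environment is restored, including CUDA_VISIBLE_DEVICES.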