From f765297e871a59d6a47803504e79a80321674028 Mon Sep 17 00:00:00 2001 From: Gabe Weisz Date: Mon, 6 Apr 2026 12:32:05 +0000 Subject: [PATCH] Try to use SLURM_STEP_GPUS for device list if CUDA_VISIBLE_DEVICES is not set --- src/maxtext/utils/max_utils.py | 25 ++++++++++--- tests/unit/max_utils_test.py | 66 +++++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py index 9e45def210..325331ce6f 100644 --- a/src/maxtext/utils/max_utils.py +++ b/src/maxtext/utils/max_utils.py @@ -13,7 +13,8 @@ # limitations under the License. """Common Max Utils needed by multiple modules. -All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file.""" +All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file. +""" import collections from collections.abc import Sequence @@ -250,12 +251,21 @@ def initialize_jax_for_gpu(raw_keys): if os.environ.get("JAX_COORDINATOR_IP") is not None: coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) - devices = os.getenv("CUDA_VISIBLE_DEVICES") + env_var_list = ["CUDA_VISIBLE_DEVICES", "SLURM_STEP_GPUS"] + for env_var in env_var_list: + devices = os.getenv(env_var) + if devices is not None: + max_logging.log(f"Using {env_var} to initialize JAX distributed system: {devices}") + break + if devices is None: + jax.config.update("jax_cuda_visible_devices", "all") + jax.config.update("jax_rocm_visible_devices", "all") + if devices is not None: try: devices = [int(x) for x in devices.split(",")] except (ValueError, TypeError) as e: - max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}") + max_logging.log(f"Error parsing {env_var}: {e}") devices = None jax.distributed.initialize( @@ -848,7 +858,14 @@ def reorder_causal_load_balanced(batch, cp_size): cp_size=cp_size, ) if key - in ["inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation"] + in [ + "inputs", + "targets", + "inputs_position", + "targets_position", + "inputs_segmentation", + "targets_segmentation", + ] else value for key, value in batch.items() } diff --git a/tests/unit/max_utils_test.py b/tests/unit/max_utils_test.py index 5eba20f807..b027bb33be 100644 --- a/tests/unit/max_utils_test.py +++ b/tests/unit/max_utils_test.py @@ -201,7 +201,7 @@ def test_unscan_train_state_params(self): class TestGpuDistributedInitialization(unittest.TestCase): - """Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize.""" + """Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize.""" @mock.patch.dict( os.environ, @@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase): "JAX_COORDINATOR_PORT": "1234", "NNODES": "1", "NODE_RANK": "0", - "CUDA_VISIBLE_DEVICES": "0,2,3", # Simulating Slurm/orchestrator assignment + "CUDA_VISIBLE_DEVICES": "0,2,3", }, ) @mock.patch("jax.distributed.initialize") @@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo """Verifies that a comma-separated string of IDs is correctly parsed.""" raw_keys = {"jax_distributed_initialization_timeout": 300} max_utils.initialize_jax_for_gpu(raw_keys) - # Check that local_device_ids was passed correctly as a list of integers _, kwargs = mock_init.call_args self.assertEqual(kwargs["local_device_ids"], [0, 2, 3]) self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234") @@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo "JAX_COORDINATOR_PORT": "1234", "NNODES": "1", "NODE_RANK": "0", - "CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", # Invalid format for integer parsing + "CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", }, ) @mock.patch("jax.distributed.initialize") @mock.patch("jax.devices") @mock.patch("maxtext.utils.max_logging.log") - def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, mock_init): + def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, _mock_devices, mock_init): """Verifies fallback behavior when parsing fails (e.g., UUIDs).""" raw_keys = {"jax_distributed_initialization_timeout": 300} max_utils.initialize_jax_for_gpu(raw_keys) - # Check that it falls back to None (JAX auto-detection default) on error _, kwargs = mock_init.call_args self.assertIsNone(kwargs.get("local_device_ids")) self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234") @@ -255,17 +253,69 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m "NNODES": "1", "NODE_RANK": "0", }, + clear=True, ) + @mock.patch("jax.config.update") @mock.patch("jax.distributed.initialize") @mock.patch("jax.devices") @mock.patch("maxtext.utils.max_logging.log") - def test_initialize_jax_for_gpu_no_devices(self, _mock_log, mock_devices, mock_init): - """Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set""" + def test_initialize_jax_for_gpu_no_devices(self, _mock_log, _mock_devices, mock_init, mock_config_update): + """When coordinator env is set but neither CUDA_VISIBLE_DEVICES nor SLURM_STEP_GPUS is set, JAX uses all devices + (config) and init gets no local ids. + """ raw_keys = {"jax_distributed_initialization_timeout": 300} max_utils.initialize_jax_for_gpu(raw_keys) _, kwargs = mock_init.call_args self.assertIsNone(kwargs.get("local_device_ids")) self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234") + mock_config_update.assert_has_calls( + [ + mock.call("jax_cuda_visible_devices", "all"), + mock.call("jax_rocm_visible_devices", "all"), + ] + ) + + @mock.patch("jax.distributed.initialize") + @mock.patch("jax.devices") + @mock.patch("maxtext.utils.max_logging.log") + def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset(self, mock_log, _mock_devices, mock_init): + """Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list).""" + env = { + "JAX_COORDINATOR_IP": "10.0.0.1", + "JAX_COORDINATOR_PORT": "1234", + "NNODES": "1", + "NODE_RANK": "0", + "SLURM_STEP_GPUS": "1,3", + } + with mock.patch.dict(os.environ, env, clear=False): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + raw_keys = {"jax_distributed_initialization_timeout": 300} + max_utils.initialize_jax_for_gpu(raw_keys) + _, kwargs = mock_init.call_args + self.assertEqual(kwargs["local_device_ids"], [1, 3]) + mock_log.assert_any_call("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3") + + @mock.patch.dict( + os.environ, + { + "JAX_COORDINATOR_IP": "10.0.0.1", + "JAX_COORDINATOR_PORT": "1234", + "NNODES": "1", + "NODE_RANK": "0", + "CUDA_VISIBLE_DEVICES": "0,2", + "SLURM_STEP_GPUS": "4,5,6", + }, + ) + @mock.patch("jax.distributed.initialize") + @mock.patch("jax.devices") + @mock.patch("maxtext.utils.max_logging.log") + def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop(self, mock_log, _mock_devices, mock_init): + """First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS.""" + raw_keys = {"jax_distributed_initialization_timeout": 300} + max_utils.initialize_jax_for_gpu(raw_keys) + _, kwargs = mock_init.call_args + self.assertEqual(kwargs["local_device_ids"], [0, 2]) + mock_log.assert_any_call("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2") if __name__ == "__main__":