Bug report
The logic for distributed initialisation with `hardware=gpu` (maxtext/src/maxtext/utils/max_utils.py, lines 246 to 266 at 37ded59):
```python
def initialize_jax_for_gpu(raw_keys):
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if devices is not None:
      try:
        devices = [int(x) for x in devices.split(",")]
      except (ValueError, TypeError) as e:
        max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
        devices = None

    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
        local_device_ids=devices,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
```
is targeted at running in process-per-node mode (i.e. one process driving all GPUs in the machine). However, if `CUDA_VISIBLE_DEVICES` is not set explicitly, this falls through to auto-detection, which assumes process-per-GPU mode on Slurm. The result is that only the 0th GPU on each node is used.
It would be more user-friendly if, given that this MaxText code is quite explicitly targeted at process-per-node mode, it did not require the user to set `CUDA_VISIBLE_DEVICES` explicitly.
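One possible direction, sketched below under assumptions that are not in the original code: keep the explicit `CUDA_VISIBLE_DEVICES` parsing, but when the variable is unset, derive an explicit process-per-node device list instead of falling through to auto-detection. `GPUS_PER_NODE` here is a hypothetical stand-in for however the node's GPU count would actually be discovered (e.g. via NVML); it is not an existing MaxText setting.

```python
import os


def resolve_local_device_ids(env=os.environ):
    """Sketch: explicit device-id resolution for process-per-node mode.

    Parses CUDA_VISIBLE_DEVICES if set; otherwise claims every GPU on the
    node explicitly, rather than deferring to JAX's auto-detection (which,
    on Slurm, assumes process-per-GPU and picks a single device).
    """
    devices = env.get("CUDA_VISIBLE_DEVICES")
    if devices is not None:
        try:
            return [int(x) for x in devices.split(",")]
        except ValueError:
            return None  # unparseable: fall back to auto-detection

    # Hypothetical: GPUS_PER_NODE stands in for a real GPU-count probe.
    gpus_per_node = env.get("GPUS_PER_NODE")
    if gpus_per_node is not None:
        return list(range(int(gpus_per_node)))

    return None  # current behaviour: let jax.distributed.initialize() decide
```

The return value would then be passed as `local_device_ids` to `jax.distributed.initialize`, so the unset-variable case no longer silently degrades to one GPU per node.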
There is also `hardware=gpu_multiprocess`, which defers to JAX's default distributed initialisation and, on a Slurm cluster, correctly yields a process-per-GPU configuration.
Logs/Output
```shell
srun --container-image=... --container-remap-root sh -c 'CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'
```
will work, while
```shell
srun --container-image=... --container-remap-root sh -c 'NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'
```
will yield
```
AssertionError: Number of devices per slice 1 does not match the product of the ICI parallelism 8
```
Environment Information
No response
Additional Context
No response