
[Distributed] hardware=gpu does not correctly configure process-per-node mode with Slurm #3433

@olupton

Description


Bug report

The logic for distributed initialisation with hardware=gpu:

```python
def initialize_jax_for_gpu(raw_keys):
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if devices is not None:
      try:
        devices = [int(x) for x in devices.split(",")]
      except (ValueError, TypeError) as e:
        max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
        devices = None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
        local_device_ids=devices,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
```

is targeted at running in process-per-node mode (i.e. 1 process driving all GPUs in the machine). However, if CUDA_VISIBLE_DEVICES is not set explicitly then `local_device_ids=None` falls through to JAX's auto-detection, which assumes process-per-GPU mode on Slurm. The result is that only the 0th GPU on each node is used.

It would be more user-friendly if -- given that this MaxText code is quite explicitly targeted at process-per-node mode -- it did not require the user to explicitly set CUDA_VISIBLE_DEVICES.
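One possible shape of such a fallback (a sketch only, not existing MaxText code; it assumes Slurm exports `SLURM_GPUS_ON_NODE` for the allocation, and the function name `infer_local_device_ids` is hypothetical):

```python
import os

def infer_local_device_ids():
  """Derive local_device_ids for process-per-node mode.

  Prefers an explicit CUDA_VISIBLE_DEVICES; otherwise assumes Slurm's
  SLURM_GPUS_ON_NODE gives the per-node GPU count, so a single process
  should drive device IDs 0..N-1 rather than defaulting to one GPU.
  """
  devices = os.environ.get("CUDA_VISIBLE_DEVICES")
  if devices is not None:
    return [int(x) for x in devices.split(",")]
  gpus_on_node = os.environ.get("SLURM_GPUS_ON_NODE")
  if gpus_on_node is not None:
    return list(range(int(gpus_on_node)))
  return None  # fall back to JAX's own auto-detection
```

The returned list could then be passed as `local_device_ids` to `jax.distributed.initialize`, avoiding the silent fall-through to process-per-GPU auto-detection.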

There is also hardware=gpu_multiprocess that defers to JAX's default distributed initialisation and, on a Slurm cluster, correctly yields a process-per-GPU configuration.

Logs/Output

srun --container-image=... --container-remap-root sh -c 'CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'

will work, while

srun --container-image=... --container-remap-root sh -c 'NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'

will yield

AssertionError: Number of devices per slice 1 does not match the product of the ICI parallelism 8

Environment Information

No response

Additional Context

No response

Labels

bug: Something isn't working