Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/datasets/test_jsonl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
JsonlDataset,
_apply_sample_ratio,
_filter_sampled_indices,
_get_local_concurrency,
get_local_world_size,
load_dict_from_npy_dir,
save_dict_to_npy_dir,
)
Expand Down Expand Up @@ -278,9 +278,9 @@ def test_mmap_fast_path_faster_than_slow_path(self):

self.create_pg(DEVICE)
rank = dist.get_rank()
is_local_master = dist.get_rank() % _get_local_concurrency() == 0
is_local_master = dist.get_rank() % get_local_world_size() == 0

if _get_local_concurrency() <= 1:
if get_local_world_size() <= 1:
self.skipTest("需要 LOCAL_WORLD_SIZE>1 才能走 enable_mmap_shared 的 mmap 多进程分支")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
Expand Down
37 changes: 8 additions & 29 deletions xtuner/v1/datasets/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@
from xtuner.v1.datasets.pt_tokenize_fn.long_text import LongTextPretrainTokenizeFunction
from xtuner.v1.datasets.rl_tokenize_fn.rl_tokenize_fn import RLTokenizeFn
from xtuner.v1.utils import SharedMemory, get_logger
from xtuner.v1.utils.device import get_torch_device_module
from xtuner.v1.utils.dist_utils import get_local_process_group, get_local_world_size, is_local_rank0

from .utils import CachableTokenizeFunction, calculate_xxhash


T = TypeVar("T")
logger = get_logger()
_lock = Lock()
DEVICE_MODULE = get_torch_device_module()

CACHE_META = ".xpuyu-cache-meta.json"
XTUNER_FILE_OPEN_CONCURRENCY = int(os.environ.get("XTUNER_FILE_OPEN_CONCURRENCY", "8"))
Expand Down Expand Up @@ -207,26 +206,6 @@ def chunk_data_to_queue(
data_queue.put(None)


def _get_local_concurrency():
"""Get the local concurrency level based on the environment variable."""
if dist.is_initialized():
local_rank_concurrency = os.getenv("LOCAL_WORLD_SIZE")
if local_rank_concurrency is None:
local_rank_concurrency = os.getenv("PROC_PER_NODE")
if local_rank_concurrency is None:
local_rank_concurrency = DEVICE_MODULE.device_count()
else:
local_rank_concurrency = 1
return int(local_rank_concurrency)


def _is_local_rank0() -> bool:
"""Return True if this process is local rank 0."""
if not dist.is_initialized():
return True
return dist.get_rank() % _get_local_concurrency() == 0


# NOTE: The `map` or `submit` function of `concurrent.futures.ProcessPoolExecutor` will cause frequent serialization
# and deserialization of the tokenizer, processing 1000 samples will serialize and deserialize 1000 times, thus
# affecting performance. Here we redefine `parallel_execute` to bind processes with `tokenize_fn`, so the tokenizer
Expand All @@ -240,7 +219,7 @@ def parallel_execute(
rank: int,
):
cpu_ids = list(os.sched_getaffinity(0))
local_rank_concurrency = _get_local_concurrency()
local_rank_concurrency = get_local_world_size()
local_cpu_ids = cpu_ids[rank::local_rank_concurrency]

processes: list[Process] = []
Expand Down Expand Up @@ -475,9 +454,9 @@ def __init__(
).hexdigest(),
)

if enable_mmap_shared and dist.is_initialized() and _get_local_concurrency() > 1:
if enable_mmap_shared and dist.is_initialized() and get_local_world_size() > 1:
_meta_need_update = {}
if _is_local_rank0():
if is_local_rank0():
_meta_need_update = self._get_meta_need_update(
_meta,
sample_ratio=sample_ratio,
Expand All @@ -487,7 +466,7 @@ def __init__(
save_dict_to_npy_dir(_meta_need_update, tmp_dir)
atexit.register(shutil.rmtree, tmp_dir, True)

dist.barrier()
dist.barrier(group=get_local_process_group())
_meta_need_update = load_dict_from_npy_dir(tmp_dir, mmap=True)
else:
_meta_need_update = self._get_meta_need_update(
Expand Down Expand Up @@ -592,7 +571,7 @@ def _init_shared_memory(self, path: str) -> SharedMemory:
if dist.is_initialized():
rank = dist.get_rank()
output: list[None | str] = [None] * dist.get_world_size()
local_concurrency = _get_local_concurrency()
local_concurrency = get_local_world_size()
# Assuming that each node hosts the same number of ranks.
# This allgather ensures that every rank on a node shares the same shared memory.
# For example:
Expand Down Expand Up @@ -691,7 +670,7 @@ def count_tokens(self, offsets, cache_dir=None):
shm_name=shm_name,
nproc=self.tokenizer_workers,
chunksize=chunked_size,
rank=rank % _get_local_concurrency(),
rank=rank % get_local_world_size(),
)
else:
tokenized = []
Expand Down Expand Up @@ -800,7 +779,7 @@ def _release_shared_memory(self):
else:
self._shared_memory.close()

local_rank_concurrency = _get_local_concurrency()
local_rank_concurrency = get_local_world_size()
if not dist.is_initialized() or dist.get_rank() % local_rank_concurrency == 0:
self._shared_memory.unlink()
self._shared_memory = None
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .compile import maybe_compile
from .config import Config
from .device import get_device, get_torch_device_module
from .dist_utils import is_local_rank0
from .dtensor import is_evenly_distributed
from .enum_helper import StrEnum
from .exception_helper import ParallelConfigException
Expand All @@ -15,7 +16,6 @@
get_function_type,
get_padding_length,
is_hf_model_path,
is_local_rank0,
record_git_info,
)
from .pad import pad_to_max_length, pad_to_multiple_of
Expand Down
95 changes: 95 additions & 0 deletions xtuner/v1/utils/dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Utilities for distributed training: local world size, local-rank checks, and node process groups."""

import datetime
import os
from threading import Lock
from typing import cast

from torch import distributed as dist

from xtuner.v1.utils.device import get_torch_device_module


_LOCK = Lock()
_LOCAL_PROCESS_GROUP: dist.ProcessGroup | None = None


def get_local_world_size() -> int:
    """Report the number of processes assumed to share this node.

    If ``torch.distributed`` is not initialized the answer is always ``1``.
    Otherwise the first available source wins: the ``LOCAL_WORLD_SIZE``
    environment variable, then ``PROC_PER_NODE``, and finally the accelerator
    device count from :func:`~xtuner.v1.utils.device.get_torch_device_module`.

    Returns:
        int: The local (per-node) world size used to map global ranks to nodes.
    """
    if not dist.is_initialized():
        return 1
    # Environment variables take precedence over the probed device count.
    for var_name in ("LOCAL_WORLD_SIZE", "PROC_PER_NODE"):
        value = os.getenv(var_name)
        if value is not None:
            return int(value)
    return int(get_torch_device_module().device_count())


def is_local_rank0() -> bool:
    """Return whether this process is local rank 0 within its node.

    When ``torch.distributed`` is NOT initialized, every process is treated as
    local rank 0 and ``True`` is returned. When it IS initialized, the
    ``LOCAL_RANK`` environment variable takes precedence (``True`` iff it
    parses to ``0``); if it is unset, the global rank is mapped to a node
    position with the same contiguous stride as
    :func:`get_local_process_group`, i.e.
    ``dist.get_rank() % get_local_world_size() == 0``.

    Returns:
        bool: ``True`` if this process is the first rank on its node.
    """
    # Single-process / non-distributed runs: treat the process as rank 0.
    if not dist.is_initialized():
        return True
    # Launchers such as torchrun export LOCAL_RANK; trust it when present.
    local_rank = os.getenv("LOCAL_RANK")
    if local_rank is not None:
        return int(local_rank) == 0
    # Fall back to the contiguous-block rank-to-node mapping.
    return dist.get_rank() % get_local_world_size() == 0


def get_local_process_group() -> dist.ProcessGroup:
    """Return the process group spanning ranks that belong to this node only.

    Global ranks are split into contiguous blocks of length
    :func:`get_local_world_size`; each block is one node's subgroup. The subgroup
    is created once per interpreter and cached so callers (for example multiple
    dataset instances) reuse the same communicator. Prefer this over a global
    barrier when coordinating node-local filesystem paths such as ``/tmp``.

    Returns:
        dist.ProcessGroup: The cached node-local process group for the current rank.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    global _LOCAL_PROCESS_GROUP
    # The lock guards the lazily-created module-level cache against concurrent
    # first calls from multiple threads in the same process.
    with _LOCK:
        if _LOCAL_PROCESS_GROUP is None:
            if not dist.is_initialized():
                raise RuntimeError("torch.distributed is not initialized.")
            world_size = dist.get_world_size()
            local_ws = get_local_world_size()
            # Ceiling division: the last node may host fewer than local_ws ranks.
            num_nodes = (world_size + local_ws - 1) // local_ws
            timeout = datetime.timedelta(seconds=1800)
            # NOTE(review): when the (possibly composite) backend string contains
            # "gloo" it is forwarded explicitly to new_group — presumably to keep
            # the subgroup on a CPU-capable backend; confirm against the caller's
            # init_process_group configuration.
            if "gloo" in (backend := dist.get_backend()):
                group_kwargs: dict = {"backend": backend, "timeout": timeout}
            else:
                group_kwargs = {"timeout": timeout}
            node_id = dist.get_rank() // local_ws
            local_group: dist.ProcessGroup | None = None
            # dist.new_group is a collective call: EVERY rank must create EVERY
            # node's subgroup, in the same order, even though each rank keeps
            # only the group it belongs to.
            for i in range(num_nodes):
                start = i * local_ws
                end = min(start + local_ws, world_size)
                ranks = list(range(start, end))
                g = dist.new_group(ranks=ranks, **group_kwargs)
                if i == node_id:
                    local_group = g
            _LOCAL_PROCESS_GROUP = local_group
        return cast(dist.ProcessGroup, _LOCAL_PROCESS_GROUP)
14 changes: 0 additions & 14 deletions xtuner/v1/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,6 @@ def unlink(self) -> None:
_mprt.unregister(self._name, "shared_memory") # type: ignore[attr-defined]


def is_local_rank0() -> bool:
"""Return whether the current process is local rank 0 on its node.

In non-distributed settings (``LOCAL_RANK`` unset) every process is
considered local rank 0 and this function returns ``True``.

Returns:
bool: ``True`` if ``LOCAL_RANK`` is unset or equal to ``"0"``,
``False`` otherwise.
"""
local_rank = os.getenv("LOCAL_RANK")
return local_rank is None or local_rank == "0"


def get_padding_length(length: int, divisors: list[int]) -> int:
"""Calculate the padding length needed to make the input length divisible
by divisors.
Expand Down
Loading