Skip to content
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,18 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.executor.wakeup(tags)

def start_expert_distribution_record(self):
    """Begin expert-distribution recording on all workers via the executor."""
    executor = self.executor
    executor.start_expert_distribution_record()

def stop_expert_distribution_record(self):
    """Halt expert-distribution recording on all workers via the executor."""
    executor = self.executor
    executor.stop_expert_distribution_record()

def dump_expert_distribution_record(self):
    """Return the accumulated expert-distribution data gathered by the executor."""
    executor = self.executor
    return executor.dump_expert_distribution_record()

async def async_loop(self):
engine_loop = None
try:
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
raise NotImplementedError('Not Implemented.')

def start_expert_distribution_record(self):
    """Start recording expert distribution.

    Abstract hook; concrete executors must override it.
    """
    raise NotImplementedError('Not Implemented.')

def stop_expert_distribution_record(self):
    """Stop recording expert distribution.

    Abstract hook; concrete executors must override it.
    """
    raise NotImplementedError('Not Implemented.')

def dump_expert_distribution_record(self):
    """Dump accumulated expert distribution data.

    Abstract hook; concrete executors must override it.
    """
    raise NotImplementedError('Not Implemented.')

def update_params(self, request: Any):
    """Update params.

    Abstract hook; concrete executors must override it.
    """
    raise NotImplementedError('Not Implemented.')
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.model_agent.wakeup(tags)

def start_expert_distribution_record(self):
    """Tell the local model agent to begin recording expert distribution."""
    agent = self.model_agent
    agent.start_expert_distribution_record()

def stop_expert_distribution_record(self):
    """Tell the local model agent to stop recording expert distribution."""
    agent = self.model_agent
    agent.stop_expert_distribution_record()

def dump_expert_distribution_record(self):
    """Return the expert-distribution data accumulated by the model agent."""
    agent = self.model_agent
    return agent.dump_expert_distribution_record()

def get_input_processor(self):
    """Return the model agent's input processor.

    Fix: the previous docstring ("Build cache engine.") was copy-pasted
    from an unrelated method and did not describe this accessor.
    """
    return self.model_agent.get_input_processor()
Expand Down
9 changes: 9 additions & 0 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,15 @@ def wakeup(self, tags: list[str] | None = None):
self.update_configs()
self.collective_rpc('wakeup', (tags, ))

def start_expert_distribution_record(self):
    """Broadcast a start-recording request to every Ray worker."""
    self.collective_rpc('start_expert_distribution_record')

def stop_expert_distribution_record(self):
    """Broadcast a stop-recording request to every Ray worker."""
    self.collective_rpc('stop_expert_distribution_record')

def dump_expert_distribution_record(self):
    """Collect and return the dump results from every Ray worker."""
    results = self.collective_rpc('dump_expert_distribution_record')
    return results

def get_input_processor(self):
    """Return the input processor, fetched from the first Ray worker.

    Fix: the previous docstring ("Build cache engine.") was copy-pasted
    from an unrelated method and did not describe this accessor.
    """
    return ray.get(self.workers[0].get_input_processor.remote())
Expand Down
9 changes: 9 additions & 0 deletions lmdeploy/pytorch/engine/executor/uni_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def build_cache_engine(self):
def warmup(self):
    """Run the wrapped model agent's warmup routine."""
    agent = self.model_agent
    agent.warmup()

def start_expert_distribution_record(self):
    """Start expert-distribution recording on the single local model agent."""
    agent = self.model_agent
    agent.start_expert_distribution_record()

def stop_expert_distribution_record(self):
    """Stop expert-distribution recording on the single local model agent."""
    agent = self.model_agent
    agent.stop_expert_distribution_record()

def dump_expert_distribution_record(self):
    """Return the expert-distribution data from the single local model agent."""
    agent = self.model_agent
    return agent.dump_expert_distribution_record()

def start(self, forward_event: asyncio.Event):
"""Start engine loop."""
self.model_agent.start(forward_event)
Expand Down
21 changes: 21 additions & 0 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,27 @@ def wakeup(self, tags: list[str] | None = None):
if self.dist_config.dp > 1:
self.state.to_wakeup.set()

def start_expert_distribution_record(self):
    """Begin expert-distribution recording for this worker process."""
    # Imported lazily — presumably to avoid an import cycle at module load; confirm.
    from lmdeploy.pytorch.models.utils.expert_distribution_recorder import get_expert_distribution_recorder
    get_expert_distribution_recorder().start_record()

def stop_expert_distribution_record(self):
    """Halt expert-distribution recording for this worker process."""
    # Imported lazily — presumably to avoid an import cycle at module load; confirm.
    from lmdeploy.pytorch.models.utils.expert_distribution_recorder import get_expert_distribution_recorder
    get_expert_distribution_recorder().stop_record()

def dump_expert_distribution_record(self):
    """Return this worker's accumulated expert-distribution dump result."""
    # Imported lazily — presumably to avoid an import cycle at module load; confirm.
    from lmdeploy.pytorch.models.utils.expert_distribution_recorder import get_expert_distribution_recorder
    return get_expert_distribution_recorder().dump_record()

def release(self):
"""release."""
self.reset_graph_runner()
Expand Down
9 changes: 9 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self._collective_rpc('wakeup', tags)

def start_expert_distribution_record(self):
    """Forward a start-recording request through the RPC helper."""
    return self._collective_rpc('start_expert_distribution_record')

def stop_expert_distribution_record(self):
    """Forward a stop-recording request through the RPC helper."""
    return self._collective_rpc('stop_expert_distribution_record')

def dump_expert_distribution_record(self):
    """Forward a dump request through the RPC helper and return the result."""
    return self._collective_rpc('dump_expert_distribution_record')

def update_params(self, request: Any):
    """Forward a parameter-update request through the RPC helper."""
    return self._collective_rpc('update_params', request)
Expand Down
9 changes: 9 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self.engine.wakeup(tags)

def start_expert_distribution_record(self):
    """Delegate the start-recording call to the wrapped engine."""
    engine = self.engine
    return engine.start_expert_distribution_record()

def stop_expert_distribution_record(self):
    """Delegate the stop-recording call to the wrapped engine."""
    engine = self.engine
    return engine.stop_expert_distribution_record()

def dump_expert_distribution_record(self):
    """Delegate the dump call to the wrapped engine and return its result."""
    engine = self.engine
    return engine.dump_expert_distribution_record()

def update_params(self, request: Any):
    """Delegate a parameter-update request to the wrapped engine."""
    engine = self.engine
    return engine.update_params(request)
Expand Down
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ def _patched_get_env(
# repetition check
repetition_window_size = env_to_int('LMDEPLOY_REPETITION_WINDOW_SIZE', 1024)

# dump expert distribution
# Master switch: when False, get_expert_distribution_recorder() hands out a no-op recorder.
dump_expert_distribution = env_to_bool('LMDEPLOY_DUMP_EXPERT_DISTRIBUTION', False)
# Directory where per-rank expert-count dump files (.pt) are written.
expert_dump_dir = os.getenv('LMDEPLOY_EXPERT_DUMP_DIR', '/tmp/lmdeploy/expert_distribution')
# Only this rank writes the dump file after the cross-rank reduction.
expert_dump_rank = env_to_int('LMDEPLOY_EXPERT_DUMP_RANK', 0)
# When True, the recorder runs the visualization tool on each dump file.
expert_dump_visualize = env_to_bool('LMDEPLOY_EXPERT_DUMP_VISUALIZE', False)


def get_all_envs():
"""Get all environment variables."""
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.expert_distribution_recorder import get_expert_distribution_recorder


# microbatch
Expand Down Expand Up @@ -678,6 +679,7 @@ class DeepseekV2MoE(nn.Module):
def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None):
super().__init__()
quantization_config = getattr(config, 'quantization_config', None)
self.layer_idx = layer_idx
self.hidden_dim = config.hidden_size
self.ffn_dim = config.moe_intermediate_size
self.num_experts = config.n_routed_experts
Expand Down Expand Up @@ -751,6 +753,7 @@ def forward(self, hidden_states: torch.Tensor):
if self._all_reduce:
dist.all_reduce(out_states)

get_expert_distribution_recorder().record(topk_ids, self.layer_idx, self.num_experts)
return out_states


Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/models/qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from .qwen3_5 import Qwen3_5VisionModel as Qwen3_5MoeVisionModel
from .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5MoeInputProcessor
from .utils.expert_distribution_recorder import get_expert_distribution_recorder


class Qwen3_5MoeTopKRouter(nn.Module):
Expand Down Expand Up @@ -123,6 +124,8 @@ def forward(self, hidden_states: torch.Tensor, all_routed_experts: torch.Tensor

if self._all_reduce:
dist.all_reduce(out_states)

get_expert_distribution_recorder().record(topk_ids, self.layer_idx, self.num_experts)
return out_states


Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .patch import add_prefix, get_build_model_context
from .utils.cudagraph import CudaGraphMixin
from .utils.expert_distribution_recorder import get_expert_distribution_recorder
from .utils.model import DeployModelMixinV1, build_embedding


Expand Down Expand Up @@ -263,6 +264,8 @@ def forward(
)

out_states = out_states.reshape(batch_size, sequence_length, -1)

get_expert_distribution_recorder().record(topk_ids, self.layer_idx, self.num_experts)
return out_states


Expand Down
157 changes: 157 additions & 0 deletions lmdeploy/pytorch/models/utils/expert_distribution_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/DeepLink-org/dlBLAS/blob/main/dlblas/layers/moe/experts_distribution_recorder.py

import os

import torch
import torch.distributed as dist

from lmdeploy.pytorch import envs
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class _ExpertsDistributionRecorderNoOp:
"""No-op recorder used when expert distribution recording is disabled."""

def record(self, *args, **kwargs):
pass

def start_record(self):
pass

def stop_record(self):
pass

def dump_record(self):
pass


class _ExpertsDistributionRecorderImpl:
    """Records per-expert token dispatch counts across MoE layers.

    Counts are keyed by '<layer_index>_<num_experts>' and accumulated on the
    device of the recorded ``topk_ids`` tensor; they are moved to CPU (and,
    in distributed runs, summed across ranks) only at dump time.
    """

    def __init__(self):
        # Dump files are written here (LMDEPLOY_EXPERT_DUMP_DIR, see envs.py).
        self.output_dir = envs.expert_dump_dir
        # '<layer>_<num_experts>' -> number of record() calls for that layer.
        self.dispatch_count = {}
        # '<layer>_<num_experts>' -> int64 tensor of per-expert token counts.
        self.accum_token_counts = {}
        # Only this rank writes the dump file (LMDEPLOY_EXPERT_DUMP_RANK).
        self.dump_rank = envs.expert_dump_rank
        # Toggled by start_record()/stop_record(); record() is a no-op when False.
        self._recording = False

    def _reset_accumulators(self):
        """Drop all accumulated counts (called when recording starts)."""
        self.dispatch_count.clear()
        self.accum_token_counts.clear()

    def _build_counts_tensor(self, counts_dict=None):
        """Stack per-layer counts into a (num_layers, num_experts) tensor.

        Layers are ordered by the numeric layer index encoded in the dict
        keys ('<layer>_<num_experts>'); each tensor is moved to CPU first.
        """
        if counts_dict is None:
            counts_dict = self.accum_token_counts
        sorted_keys = sorted(counts_dict.keys(), key=lambda k: int(k.split('_')[0]))
        return torch.stack([counts_dict[k].cpu() for k in sorted_keys])

    @staticmethod
    def _compute_balancedness(counts: torch.Tensor) -> torch.Tensor:
        """Per-layer balancedness: mean / max load. Shape: (num_layers,). Range (0,1]; 1 = perfect balance."""
        counts_f = counts.float()
        # The epsilon keeps the ratio defined for layers with zero recorded load.
        return (counts_f.mean(dim=1) + 1e-5) / (counts_f.max(dim=1).values + 1e-5)

    def record(self, topk_ids, layer_index, num_experts):
        """Accumulate one forward step's expert assignments for a layer.

        Args:
            topk_ids: tensor of selected expert indices; it is flattened, so
                any shape works provided values lie in [0, num_experts).
            layer_index: index of the MoE layer being recorded.
            num_experts: number of routed experts in that layer.
        """
        if not self._recording:
            return

        key = f'{layer_index}_{num_experts}'
        if key not in self.dispatch_count:
            self.dispatch_count[key] = 0
        self.dispatch_count[key] += 1
        if key not in self.accum_token_counts:
            self.accum_token_counts[key] = torch.zeros(num_experts, dtype=torch.int64, device=topk_ids.device)
        topk_ids_flat = topk_ids.reshape(-1).long()
        # scatter_add_ is graph-capturable; torch.bincount is not
        self.accum_token_counts[key].scatter_add_(
            0, topk_ids_flat, torch.ones(topk_ids_flat.numel(), dtype=torch.int64, device=topk_ids_flat.device))

    def start_record(self):
        """Clear previous accumulators and begin recording."""
        logger.info('[Expert Statistics] Recording started.')
        self._reset_accumulators()
        self._recording = True

    def stop_record(self):
        """Stop recording; accumulated data is kept but no longer updated."""
        logger.info('[Expert Statistics] Recording stopped.')
        self._recording = False

    def dump_record(self):
        """Reduce and persist the accumulated counts; return the file path.

        Returns None (after logging why) when recording is inactive, nothing
        has been accumulated, a CUDA graph capture is in progress, or this
        rank is not the configured dump rank.

        NOTE(review): dumping requires recording to still be active, so a
        stop_record() -> dump_record() sequence yields None — confirm intended.
        """
        if not self._recording:
            logger.info('[Expert Statistics] dump_record called but recording is not active.')
            return None

        if not self.accum_token_counts:
            logger.info('[Expert Statistics] dump_record called but no data has been accumulated yet.')
            return None

        if torch.cuda.is_current_stream_capturing():
            logger.warning('[Expert Statistics] dump_record skipped during CUDA graph capture.')
            return None

        rank = dist.get_rank() if dist.is_initialized() else 0
        # 'step' is the largest number of record() calls seen for any layer key.
        step = max(self.dispatch_count.values()) if self.dispatch_count else 0
        return self._dump(rank, step)

    def _dump(self, rank: int, step: int):
        """All-reduce counts across ranks and save them on the dump rank.

        Every rank must enter this method because all_reduce is a collective;
        only ``self.dump_rank`` writes the file and returns its path, all
        other ranks return None.
        """
        logger.info(f'[Expert Statistics] Dumping expert distribution at step {step} from rank {rank}...')

        if dist.is_initialized():
            # clone before all_reduce to avoid corrupting the local accumulator.
            global_counts = {k: v.clone() for k, v in self.accum_token_counts.items()}
            for t in global_counts.values():
                dist.all_reduce(t, op=dist.ReduceOp.SUM)
        else:
            global_counts = self.accum_token_counts

        # The rank check must come AFTER the collective all_reduce above.
        if rank != self.dump_rank:
            return None

        counts_tensor = self._build_counts_tensor(global_counts)  # (num_layers, num_experts)
        balancedness = self._compute_balancedness(counts_tensor)  # (num_layers,)

        # log per-layer balancedness
        bal_list = balancedness.tolist()
        bottom3 = sorted(range(len(bal_list)), key=lambda i: bal_list[i])[:3]
        logger.info(
            f'[Expert Statistics] step={step} | avg_balancedness={sum(bal_list)/len(bal_list):.4f} | '
            f'most_imbalanced_layers={[(i, f"{bal_list[i]:.4f}") for i in bottom3]}'
        )

        os.makedirs(self.output_dir, exist_ok=True)
        filepath = os.path.join(self.output_dir, f'rank{rank}_step{step}_expert_counts.pt')
        torch.save(
            {
                'counts': counts_tensor,
                'balancedness': balancedness,
                'total_tokens': counts_tensor.sum(dim=1),
                'step': step,
                'rank': rank,
            },
            filepath,
        )
        logger.info(f'[Expert Statistics] Expert distribution dumped to {filepath}, shape={list(counts_tensor.shape)}')

        if envs.expert_dump_visualize:
            # Imported lazily so the tools module is only loaded when visualizing.
            from lmdeploy.pytorch.tools.utils import visualize_expert_distribution
            visualize_expert_distribution(filepath)

        return filepath


# Process-wide singleton, created lazily on first access.
_global_recorder = None


def get_expert_distribution_recorder():
    """Return the process-wide expert-distribution recorder, creating it on first use.

    The concrete class is chosen once from ``envs.dump_expert_distribution``:
    the real recorder when enabled, otherwise a no-op stand-in.
    """
    global _global_recorder

    if _global_recorder is None:
        recorder_cls = (_ExpertsDistributionRecorderImpl
                        if envs.dump_expert_distribution else _ExpertsDistributionRecorderNoOp)
        _global_recorder = recorder_cls()

    return _global_recorder
Loading
Loading