Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import bisect
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
Expand Down Expand Up @@ -196,7 +197,11 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
Expand Down Expand Up @@ -251,7 +256,13 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
Expand Down
6 changes: 3 additions & 3 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):

cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
Expand Down Expand Up @@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

def _mutate_args_clone(self, args, kwargs):
origin_list = []
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/detokenization/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.info(
logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

Expand Down Expand Up @@ -160,7 +160,7 @@ def remove_finished_reqs(self):

for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ async def _wait_to_token_package(
(out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1
)
format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_start_time} "
f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms "
Expand Down Expand Up @@ -732,8 +732,8 @@ async def recycle_resource_loop(self):
if req_status is None:
continue

logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
logger.debug(
f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
Expand Down
8 changes: 2 additions & 6 deletions lightllm/server/router/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Dict, List, Optional, Tuple, Union
from lightllm.server.core.objs import ShmReqManager, Req
from lightllm.utils.log_utils import init_logger
from .stats import RouterStatics

logger = init_logger(__name__)

Expand Down Expand Up @@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]:
all_dp_req_num[req.sample_params.suggested_dp_index] += 1
return all_dp_req_num

def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
logger.info(f"router release req id {req.request_id}")
if not req.is_aborted:
router_statics.update(req.candetoken_out_len)

logger.debug(f"router release req id {req.request_id}")
shm_req_manager.put_back_req_obj(req)
req = None
else:
Expand Down
65 changes: 35 additions & 30 deletions lightllm/server/router/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from .stats import SystemStatusReporter
from lightllm.server.core.objs.io_objs import (
GroupReqIndexes,
AbortedReqCmd,
Expand All @@ -25,7 +26,7 @@
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.utils.log_utils import init_logger, log_time_ready
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, args: StartArgs):
self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
# 初始化 radix_cache_client 用于读取 prompt cache 的管理信息
self.radix_cache_client = None
self.status_reporter = None

# 共享变量,用于存储router端调度分析得到的机器负载信息
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
Expand Down Expand Up @@ -194,6 +196,11 @@ async def wait_to_model_ready(self):
)
self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
logger.info(f"use req queue {self.req_queue.__class__.__name__}")
self.status_reporter = SystemStatusReporter(
args=self.args,
max_total_token_num=self.max_total_token_num,
dp_size_in_node=self.dp_size_in_node,
)

if self.args.run_mode == "prefill":
# 启动 prefill kv move 管理进程
Expand Down Expand Up @@ -239,27 +246,10 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
# Count output tokens (each running req produces ~1 token per decode step)
self.status_reporter.count_output_tokens(len(self.running_batch.reqs))
if counter_count % 100 == 0:
for dp_index in range(self.dp_size_in_node):
token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num
token_ratio2 = (
self.max_total_token_num
- self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
) / self.max_total_token_num
d_i = dp_index
frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
logger.debug(
f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
f"dp_i {d_i} paused req num: {paused_req_num} \n"
f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
)
logger.debug(self.router_statics.log_str())
self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num)
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs))
Expand All @@ -278,13 +268,15 @@ async def loop_for_fwd(
self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0)
self.metric_client.gauge_set("lightllm_queue_size", 0.0)
self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0)
# 60s print once
if log_time_ready("frozen_info", 60):
for dp_i in range(self.dp_size_in_node):
frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i)
logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n")
logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n")

self.status_reporter.maybe_print(
running_batch=self.running_batch,
req_queue=self.req_queue,
read_only_statics_mem_manager=self.read_only_statics_mem_manager,
paused_req_num=self._get_paused_req_num(),
radix_cache_client=self.radix_cache_client,
disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache,
)

await asyncio.sleep(self._get_schedule_time_interval())

Expand Down Expand Up @@ -314,6 +306,7 @@ async def _step(self):

async def _add_batch(self, batch: Batch):
# 添加新请求
self.status_reporter.count_prompt_tokens(batch.input_tokens())
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
while not self.shm_reqs_io_buffer.is_empty():
await asyncio.sleep(0.02)
Expand Down Expand Up @@ -350,7 +343,19 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):

def _filter_reqs_from_running_batch(self):
    """Record stats for finished requests, then drop them from the running batch."""
    if self.running_batch is None:
        return
    # Harvest per-request statistics before the finished reqs are put back.
    for req in self.running_batch.reqs:
        if not req.shm_infer_released:
            continue
        self.status_reporter.on_request_completed(
            input_len=req.input_len,
            output_len=req.shm_cur_output_len,
            cache_len=req.prompt_cache_len,
            mtp_accepted=req.mtp_accepted_token_num,
        )
        # Feed the scheduler's output-length EMA; aborted reqs would skew it.
        if not req.is_aborted:
            self.router_statics.update(req.candetoken_out_len)
    self.running_batch.filter_out_finished_req(self.shm_req_manager)
    if self.running_batch.is_clear():
        self.running_batch = None
    return
Expand Down Expand Up @@ -422,7 +427,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
req._router_stop_str_matched = False
req_group.append(req)

logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
self.req_queue.extend(req_group)
self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
return
Expand Down
111 changes: 109 additions & 2 deletions lightllm/server/router/stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,114 @@
from lightllm.utils.log_utils import init_logger
import time
import logging
from lightllm.server.core.objs import StartArgs
from lightllm.utils.log_utils import init_system_status_logger

logger = init_logger(__name__)
logger = logging.getLogger(__name__)


class SystemStatusReporter:
def __init__(self, args, max_total_token_num, dp_size_in_node):
    """Collects router-side throughput/KV-memory stats and periodically emits a status line.

    Args:
        args: server start args; reads `disable_log_stats` and `log_stats_interval`.
        max_total_token_num: capacity of the KV token pool (per dp rank).
        dp_size_in_node: number of data-parallel ranks on this node.
    """
    requested_interval = args.log_stats_interval
    # Enforce a 5-second floor on the reporting interval.
    if requested_interval < 5:
        logger.warning(f"log_stats_interval={requested_interval}s is below minimum, using 5s")
    self.interval = max(5, requested_interval)
    self.enabled = not args.disable_log_stats
    self.dp_size_in_node = dp_size_in_node
    self.max_total_token_num = max_total_token_num
    self.status_logger = init_system_status_logger("router")

    # Windowed counters — cleared after each status line is printed.
    self.prompt_tokens = 0
    self.output_tokens = 0
    self.last_print_time = time.time()

    # Lifetime counters — never reset; back the global hit-rate / MTP figures.
    self.global_input_total = 0
    self.global_cache_total = 0
    self.global_mtp_output_total = 0
    self.global_mtp_accepted_total = 0

def count_prompt_tokens(self, num_tokens: int):
    """Accumulate prefill (prompt) tokens into the current reporting window."""
    if not self.enabled:
        return
    self.prompt_tokens += num_tokens

def count_output_tokens(self, num_tokens: int):
    """Accumulate decoded output tokens into the current reporting window."""
    if not self.enabled:
        return
    self.output_tokens += num_tokens

def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int):
    """Fold a finished request's token counts into the lifetime totals."""
    if not self.enabled:
        return
    self.global_input_total += input_len
    self.global_cache_total += cache_len
    self.global_mtp_output_total += output_len
    self.global_mtp_accepted_total += mtp_accepted

def maybe_print(
    self,
    running_batch,
    req_queue,
    read_only_statics_mem_manager,
    paused_req_num=0,
    radix_cache_client=None,
    disable_dynamic_prompt_cache=False,
):
    """Emit one aggregated status line if the reporting interval has elapsed.

    NOTE(review): req_queue, read_only_statics_mem_manager, radix_cache_client and
    disable_dynamic_prompt_cache are stable for the reporter's lifetime and could be
    captured in __init__ instead of being passed on every call.
    """
    if not self.enabled:
        return
    now = time.time()
    window = now - self.last_print_time
    if window < self.interval:
        return

    # Window throughput (tokens per second over the elapsed window).
    total_tps = (self.prompt_tokens + self.output_tokens) / window
    input_tps = self.prompt_tokens / window
    output_tps = self.output_tokens / window

    running = len(running_batch.reqs) if running_batch else 0
    queued = req_queue.get_wait_req_num()

    # KV memory utilization, averaged across dp ranks:
    #   kv_used_list          — physical usage (includes unrefed prefix-cache tree tokens)
    #   kv_used_no_cache_list — effective usage with unrefed prefix-cache tokens excluded
    use_radix = (not disable_dynamic_prompt_cache) and radix_cache_client is not None
    kv_used_list = []
    kv_used_no_cache_list = []
    for dp_i in range(self.dp_size_in_node):
        unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i)
        used = self.max_total_token_num - unrefed
        kv_used_list.append(used / self.max_total_token_num)
        if use_radix:
            active = used - radix_cache_client.get_unrefed_tokens_num(dp_i)
        else:
            active = used
        kv_used_no_cache_list.append(active / self.max_total_token_num)
    avg_kv_used = sum(kv_used_list) / len(kv_used_list)
    avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list)

    # Lifetime prefix-cache hit rate (cached prompt tokens / all prompt tokens).
    if self.global_input_total > 0:
        cache_hit_rate = self.global_cache_total / self.global_input_total * 100
    else:
        cache_hit_rate = 0.0

    kv_pct = avg_kv_used * 100
    kv_pct_no_cache = avg_kv_used_no_cache * 100

    # Average MTP accepted length — only appended once MTP has accepted tokens.
    mtp_suffix = ""
    if self.global_mtp_accepted_total > 0:
        decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
        avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
        mtp_suffix = f" | MTP {avg_mtp_len:.2f}"

    self.status_logger.info(
        f"Throughput {total_tps:>7.1f} tok/s (in {input_tps:.1f}, out {output_tps:.1f}) | "
        f"Reqs {running} run, {queued} wait, {paused_req_num} pause | "
        f"KV Cache {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%) | "
        f"Prefix Hit {cache_hit_rate:.1f}%"
        f"{mtp_suffix}"
    )

    # Start a fresh reporting window.
    self.prompt_tokens = 0
    self.output_tokens = 0
    self.last_print_time = now


class RouterStatics:
Expand Down
Loading
Loading