Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lightllm/server/core/objs/py_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
top_k: int = None, # -1 is for all
ignore_eos: bool = False,
image_max_patch_num: int = -1,
max_new_tokens: int = 16384,
max_new_tokens: int = None,
min_new_tokens: int = 1,
stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None, # 停止句子条件
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
Expand Down Expand Up @@ -141,11 +141,11 @@ def verify(self):
raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
if self.max_new_tokens is not None and self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
if self.min_new_tokens < 1:
raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
if self.min_new_tokens > self.max_new_tokens:
if self.max_new_tokens is not None and self.min_new_tokens > self.max_new_tokens:
raise ValueError(
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
)
Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/core/objs/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
self.ignore_eos = kwargs.get("ignore_eos", False)
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
self.max_new_tokens = kwargs.get("max_new_tokens", -1)
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
self.group_request_id = kwargs.get("group_request_id", -1)
Expand Down Expand Up @@ -439,11 +439,11 @@ def verify(self):
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
if self.max_new_tokens != -1 and self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
if self.min_new_tokens < 1:
raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
if self.min_new_tokens > self.max_new_tokens:
if self.max_new_tokens != -1 and self.min_new_tokens > self.max_new_tokens:
raise ValueError(
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
)
Expand Down
18 changes: 18 additions & 0 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,24 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
if not prompt_ids:
raise ValueError("prompt_ids is empty")
prompt_tokens = len(prompt_ids)

if sampling_params.max_new_tokens is None or sampling_params.max_new_tokens == -1:
sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
if sampling_params.max_new_tokens < 1:
raise ValueError(
f"the input prompt token len {prompt_tokens} >= max_req_total_len {self.max_req_total_len}, "
f"no space for output tokens"
)
if sampling_params.min_new_tokens > sampling_params.max_new_tokens:
raise ValueError(
f"min_new_tokens ({sampling_params.min_new_tokens}) > auto-calculated max_new_tokens "
f"({sampling_params.max_new_tokens}), consider reducing input length or min_new_tokens"
)
logger.debug(
f"max_new_tokens is unset, auto-calculate to {sampling_params.max_new_tokens} "
f"(max_req_total_len {self.max_req_total_len} - prompt_tokens {prompt_tokens})"
)
Comment on lines +481 to +496
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential issue: min_new_tokens can end up greater than the auto-calculated max_new_tokens. The verify() method in SamplingParams correctly skips the min_new_tokens <= max_new_tokens check when max_new_tokens is unset (None or -1). However, once max_new_tokens is calculated here, that check is never performed again, so an invalid parameter combination can slip through.

For example:

  • max_req_total_len = 2048
  • prompt_tokens = 2000
  • min_new_tokens = 100 (user-provided)
  • max_new_tokens is not provided.

In this case, max_new_tokens will be auto-calculated to 48 (2048 - 2000). This results in min_new_tokens (100) being greater than max_new_tokens (48).

You should add a check for this condition right after max_new_tokens is calculated, raising a ValueError when min_new_tokens exceeds it:

        if sampling_params.max_new_tokens is None or sampling_params.max_new_tokens == -1:
            sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
            if sampling_params.max_new_tokens < 1:
                raise ValueError(
                    f"the input prompt token len {prompt_tokens} >= max_req_total_len {self.max_req_total_len}, "
                    f"no space for output tokens"
                )
            if sampling_params.min_new_tokens > sampling_params.max_new_tokens:
                raise ValueError(
                    f"min_new_tokens ({sampling_params.min_new_tokens}) must be <= max_new_tokens, but max_new_tokens "
                    f"was auto-calculated to {sampling_params.max_new_tokens} based on prompt length."
                )
            logger.debug(
                f"max_new_tokens is unset, auto-calculate to {sampling_params.max_new_tokens} "
                f"(max_req_total_len {self.max_req_total_len} - prompt_tokens {prompt_tokens})"
            )


if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
# use long_truncation_mode to truncate long input len req.
if self.args.long_truncation_mode is None:
Expand Down
Loading