diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index 0fa95d417..3cd8d99aa 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -38,7 +38,7 @@ def __init__(
         top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
         image_max_patch_num: int = -1,
-        max_new_tokens: int = 16384,
+        max_new_tokens: int = None,
         min_new_tokens: int = 1,
         stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # stop sequence condition
         skip_special_tokens: bool = True,  # whether to skip special tokens when decoding
@@ -141,11 +141,11 @@ def verify(self):
             raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
-        if self.max_new_tokens < 1:
+        if self.max_new_tokens is not None and self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
         if self.min_new_tokens < 1:
             raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
-        if self.min_new_tokens > self.max_new_tokens:
+        if self.max_new_tokens is not None and self.min_new_tokens > self.max_new_tokens:
             raise ValueError(
                 f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
             )
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 49b21c38f..62df16be6 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
         self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
-        self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
+        self.max_new_tokens = kwargs.get("max_new_tokens", -1)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
         self.group_request_id = kwargs.get("group_request_id", -1)
@@ -439,11 +439,11 @@ def verify(self):
             raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
-        if self.max_new_tokens < 1:
+        if self.max_new_tokens != -1 and self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
         if self.min_new_tokens < 1:
             raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
-        if self.min_new_tokens > self.max_new_tokens:
+        if self.max_new_tokens != -1 and self.min_new_tokens > self.max_new_tokens:
             raise ValueError(
                 f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
             )
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index e28e4c93a..26dc91218 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -477,6 +477,24 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
         if not prompt_ids:
             raise ValueError("prompt_ids is empty")
         prompt_tokens = len(prompt_ids)
+
+        if sampling_params.max_new_tokens is None or sampling_params.max_new_tokens == -1:
+            sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
+            if sampling_params.max_new_tokens < 1:
+                raise ValueError(
+                    f"the input prompt token len {prompt_tokens} >= max_req_total_len {self.max_req_total_len}, "
+                    f"no space for output tokens"
+                )
+            if sampling_params.min_new_tokens > sampling_params.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens ({sampling_params.min_new_tokens}) > auto-calculated max_new_tokens "
+                    f"({sampling_params.max_new_tokens}), consider reducing input length or min_new_tokens"
+                )
+            logger.debug(
+                f"max_new_tokens is unset, auto-calculate to {sampling_params.max_new_tokens} "
+                f"(max_req_total_len {self.max_req_total_len} - prompt_tokens {prompt_tokens})"
+            )
+
         if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
             # use long_truncation_mode to truncate long input len req.
             if self.args.long_truncation_mode is None:
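
Taken together, these changes let a request omit `max_new_tokens` (surfacing as `None` in `py_sampling_params.py` and `-1` in `sampling_params.py`) and have the server grant all remaining room in the total-length budget. Below is a minimal standalone sketch of that auto-calculation, under stated assumptions: the helper name `resolve_max_new_tokens` and the budget of 8192 are illustrative and not part of the LightLLM codebase.

```python
# Sketch of the auto-calculation in _check_and_repair_length above.
# The helper name and the 8192 budget are illustrative assumptions, not LightLLM API.
from typing import Optional


def resolve_max_new_tokens(
    prompt_tokens: int,
    max_new_tokens: Optional[int],
    min_new_tokens: int = 1,
    max_req_total_len: int = 8192,  # assumed server-wide prompt+output budget
) -> int:
    # Both None (py_sampling_params) and -1 (sampling_params) mean "unset".
    if max_new_tokens is None or max_new_tokens == -1:
        # Unset: grant all remaining room in the total-length budget.
        max_new_tokens = max_req_total_len - prompt_tokens
        if max_new_tokens < 1:
            raise ValueError(
                f"the input prompt token len {prompt_tokens} >= "
                f"max_req_total_len {max_req_total_len}, no space for output tokens"
            )
        if min_new_tokens > max_new_tokens:
            raise ValueError(
                f"min_new_tokens ({min_new_tokens}) > auto-calculated "
                f"max_new_tokens ({max_new_tokens})"
            )
    return max_new_tokens


print(resolve_max_new_tokens(1000, None))  # 7192 (8192 - 1000), auto-calculated
print(resolve_max_new_tokens(1000, 256))   # 256, explicit values pass through unchanged
```

Explicit values pass through this step untouched and still hit the existing `prompt_tokens + max_new_tokens > max_req_total_len` check (with `long_truncation_mode` handling) that follows in `_check_and_repair_length`.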