diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py
index 0f49f2d..f87ed03 100644
--- a/questions/inference_server/inference_server.py
+++ b/questions/inference_server/inference_server.py
@@ -35,6 +35,7 @@
     fast_inference,
     fast_feature_extract_inference,
 )
+from questions.vllm_inference import fast_vllm_inference, VLLM_AVAILABLE
 from questions.utils import log_time
 from sellerinfo import session_secret
 from .models import build_model
@@ -478,10 +479,11 @@ async def feature_extraction(
 @app.get("/liveness_check")
 async def liveness_check(request: Request):
     # global daemon
-    inference_result = fast_inference(
-        generate_params=GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any'),
-        model_cache=MODEL_CACHE,
-    )
+    params = GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any')
+    if VLLM_AVAILABLE:
+        inference_result = fast_vllm_inference(params, MODEL_CACHE)
+    else:
+        inference_result = fast_inference(params, MODEL_CACHE)
     return JSONResponse(inference_result)
 
 
@@ -856,7 +858,10 @@ async def generate_route(
     # status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first"
     # )
     # todo validate api key and user
-    inference_result = fast_inference(generate_params, MODEL_CACHE)
+    if VLLM_AVAILABLE:
+        inference_result = fast_vllm_inference(generate_params, MODEL_CACHE)
+    else:
+        inference_result = fast_inference(generate_params, MODEL_CACHE)
     # todo vuln
     if request and background_tasks:
         if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
@@ -913,7 +918,10 @@ async def generate_route_bulk(
         return HTTPException(status_code=400, detail=validation_result)
 
     for generate_params in bulk_params:
-        inference_result = fast_inference(generate_params)
+        if VLLM_AVAILABLE:
+            inference_result = fast_vllm_inference(generate_params)
+        else:
+            inference_result = fast_inference(generate_params)
         inference_results.append(inference_result)
     # todo vuln
     if request and background_tasks:
@@ -989,7 +997,10 @@ async def openai_route_named(
             status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file"
         )
 
-    inference_result = fast_inference(generate_params, MODEL_CACHE)
+    if VLLM_AVAILABLE:
+        inference_result = fast_vllm_inference(generate_params, MODEL_CACHE)
+    else:
+        inference_result = fast_inference(generate_params, MODEL_CACHE)
     if not openai_params.echo:
         ## remove all the inputs from the generated texts
         for i in range(len(inference_result)):
diff --git a/questions/vllm_inference.py b/questions/vllm_inference.py
new file mode 100644
index 0000000..8c70d79
--- /dev/null
+++ b/questions/vllm_inference.py
@@ -0,0 +1,88 @@
+import math
+from typing import List
+
+from nltk import sent_tokenize
+from loguru import logger
+
+from questions.models import GenerateParams
+from questions.fixtures import set_stop_reason, get_stop_reason
+from questions.post_process_results import post_process_results
+from questions.constants import weights_path_tgz
+from questions.utils import log_time
+from questions.inference_server.model_cache import ModelCache
+
+try:
+    from vllm import LLM, SamplingParams
+    VLLM_AVAILABLE = True
+except Exception:  # pragma: no cover - vllm may not be installed
+    LLM = None
+    SamplingParams = None
+    VLLM_AVAILABLE = False
+
+
+def load_vllm_model(model_path: str = weights_path_tgz):
+    if not VLLM_AVAILABLE:
+        raise RuntimeError("vllm is not installed")
+    return LLM(model=model_path)
+
+
+def _apply_custom_stop(text: str, logprobs: List[float], generate_params: GenerateParams) -> str:
+    """Apply custom stopping conditions on generated text."""
+    cumulative_prob = 1.0
+    output = ""
+    for token_text, logprob in zip(text.split(), logprobs or [None] * len(text.split())):  # logprobs is None unless requested
+        output += token_text + " "
+        if generate_params.min_probability and logprob is not None:
+            cumulative_prob *= math.exp(logprob)
+            if cumulative_prob < generate_params.min_probability:
+                set_stop_reason("min_probability")
+                return output.strip()
+        if generate_params.max_sentences:
+            if len(sent_tokenize(output)) > generate_params.max_sentences:
+                set_stop_reason("max_sentences")
+                return output.strip()
+    return output.strip()
+
+
+def fast_vllm_inference(generate_params: GenerateParams, model_cache: ModelCache = None):
+    """Run inference with vLLM and apply custom stopping criteria."""
+    if not VLLM_AVAILABLE:
+        raise RuntimeError("vllm is not installed")
+
+    llm = None
+    if model_cache is not None:
+        llm = model_cache.add_or_get("vllm_model", lambda: load_vllm_model())
+    else:
+        llm = load_vllm_model()
+
+    sampling_params = SamplingParams(
+        temperature=generate_params.temperature,
+        top_p=generate_params.top_p,
+        top_k=generate_params.top_k,
+        max_tokens=generate_params.max_length,
+        stop=generate_params.stop_sequences or None,
+        repetition_penalty=generate_params.repetition_penalty,
+    )
+
+    with log_time("vllm_generate"):
+        outputs = llm.generate([generate_params.text], sampling_params)
+
+    results = []
+    for output in outputs:
+        candidate = output.outputs[0]
+        generated = _apply_custom_stop(candidate.text, candidate.logprobs, generate_params)
+        full_text = generate_params.text + generated
+        results.append({"generated_text": full_text, "stop_reason": get_stop_reason()})
+
+    processed = post_process_results(
+        [r["generated_text"] for r in results],
+        generate_params,
+        generate_params.text,
+        generate_params.text,
+    )
+
+    final = []
+    for result, processed_text in zip(results, processed):
+        result["generated_text"] = processed_text
+        final.append(result)
+    return final
diff --git a/tests/conftest.py b/tests/conftest.py
index 87963c8..c1b0bc6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,14 @@
-import httpx
+try:
+    import httpx
+except ImportError:  # pragma: no cover - optional dependency
+    httpx = None
 
-_old_init = httpx.Client.__init__
+if httpx:
+    _old_init = httpx.Client.__init__
 
-def fixed_init(self, *args, **kwargs):
-    # Remove 'app' from kwargs if present
-    kwargs.pop('app', None)
-    _old_init(self, *args, **kwargs)
+    def fixed_init(self, *args, **kwargs):
+        # Remove 'app' from kwargs if present
+        kwargs.pop('app', None)
+        _old_init(self, *args, **kwargs)
 
-httpx.Client.__init__ = fixed_init
\ No newline at end of file
+    httpx.Client.__init__ = fixed_init
diff --git a/tests/unit/test_vllm_inference.py b/tests/unit/test_vllm_inference.py
new file mode 100644
index 0000000..9419105
--- /dev/null
+++ b/tests/unit/test_vllm_inference.py
@@ -0,0 +1,12 @@
+import pytest
+
+from questions.vllm_inference import VLLM_AVAILABLE, fast_vllm_inference
+from questions.models import GenerateParams
+from questions.inference_server.model_cache import ModelCache
+
+
+@pytest.mark.skipif(VLLM_AVAILABLE, reason="vllm available - skip lightweight test")
+def test_vllm_missing():
+    params = GenerateParams(text="hi")
+    with pytest.raises(RuntimeError):
+        fast_vllm_inference(params, ModelCache())