25 changes: 18 additions & 7 deletions questions/inference_server/inference_server.py
@@ -35,6 +35,7 @@
fast_inference,
fast_feature_extract_inference,
)
from questions.vllm_inference import fast_vllm_inference, VLLM_AVAILABLE
from questions.utils import log_time
from sellerinfo import session_secret
from .models import build_model
@@ -478,10 +479,11 @@ async def feature_extraction(
@app.get("/liveness_check")
async def liveness_check(request: Request):
# global daemon
inference_result = fast_inference(
generate_params=GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any'),
model_cache=MODEL_CACHE,
)
params = GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any')
if VLLM_AVAILABLE:
inference_result = fast_vllm_inference(params, MODEL_CACHE)
else:
inference_result = fast_inference(params, MODEL_CACHE)
return JSONResponse(inference_result)


@@ -856,7 +858,10 @@ async def generate_route(
# status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first"
# )
# todo validate api key and user
inference_result = fast_inference(generate_params, MODEL_CACHE)
if VLLM_AVAILABLE:
inference_result = fast_vllm_inference(generate_params, MODEL_CACHE)
else:
inference_result = fast_inference(generate_params, MODEL_CACHE)
# todo vuln
if request and background_tasks:
if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
@@ -913,7 +918,10 @@ async def generate_route_bulk(
return HTTPException(status_code=400, detail=validation_result)

for generate_params in bulk_params:
inference_result = fast_inference(generate_params)
if VLLM_AVAILABLE:
    # Pass MODEL_CACHE so the vLLM engine is loaded once and reused across the bulk loop.
    inference_result = fast_vllm_inference(generate_params, MODEL_CACHE)
else:
    inference_result = fast_inference(generate_params)
inference_results.append(inference_result)
# todo vuln
if request and background_tasks:
@@ -989,7 +997,10 @@ async def openai_route_named(
status_code=401,
detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file"
)
inference_result = fast_inference(generate_params, MODEL_CACHE)
if VLLM_AVAILABLE:
inference_result = fast_vllm_inference(generate_params, MODEL_CACHE)
else:
inference_result = fast_inference(generate_params, MODEL_CACHE)
if not openai_params.echo:
## remove all the inputs from the generated texts
for i in range(len(inference_result)):
88 changes: 88 additions & 0 deletions questions/vllm_inference.py
@@ -0,0 +1,88 @@
import math
from typing import List, Optional

from nltk import sent_tokenize
from loguru import logger

from questions.models import GenerateParams
from questions.fixtures import set_stop_reason, get_stop_reason
from questions.post_process_results import post_process_results
from questions.constants import weights_path_tgz
from questions.utils import log_time
from questions.inference_server.model_cache import ModelCache

try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except Exception: # pragma: no cover - vllm may not be installed
LLM = None
SamplingParams = None
VLLM_AVAILABLE = False


def load_vllm_model(model_path: str = weights_path_tgz):
if not VLLM_AVAILABLE:
raise RuntimeError("vllm is not installed")
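    # NOTE: vLLM's LLM() expects a model directory or Hugging Face model id, so
    # weights_path_tgz is assumed here to resolve to an already-extracted path.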
return LLM(model=model_path)


def _apply_custom_stop(text: str, logprobs: List[Optional[float]], generate_params: GenerateParams) -> str:
    """Apply custom stopping conditions (min_probability, max_sentences) to generated text."""
    if not logprobs:
        # Without per-token logprobs only the sentence limit can be enforced.
        if generate_params.max_sentences:
            sentences = sent_tokenize(text)
            if len(sentences) > generate_params.max_sentences:
                set_stop_reason("max_sentences")
                return " ".join(sentences[: generate_params.max_sentences]).strip()
        return text.strip()
    cumulative_prob = 1.0
    output = ""
    # Pairing whitespace-split words with per-token logprobs is an approximation;
    # model tokens rarely align one-to-one with words.
    for token_text, logprob in zip(text.split(), logprobs):
        output += token_text + " "
        if generate_params.min_probability and logprob is not None:
            cumulative_prob *= math.exp(logprob)
            if cumulative_prob < generate_params.min_probability:
                set_stop_reason("min_probability")
                return output.strip()
        if generate_params.max_sentences:
            if len(sent_tokenize(output)) > generate_params.max_sentences:
                set_stop_reason("max_sentences")
                return output.strip()
    return output.strip()


def fast_vllm_inference(generate_params: GenerateParams, model_cache: ModelCache = None):
"""Run inference with vLLM and apply custom stopping criteria."""
if not VLLM_AVAILABLE:
raise RuntimeError("vllm is not installed")

    # Reuse a cached engine when possible; constructing a vLLM LLM is expensive.
    if model_cache is not None:
        llm = model_cache.add_or_get("vllm_model", lambda: load_vllm_model())
    else:
        llm = load_vllm_model()

    sampling_params = SamplingParams(
        temperature=generate_params.temperature,
        top_p=generate_params.top_p,
        top_k=generate_params.top_k,
        max_tokens=generate_params.max_length,
        stop=generate_params.stop_sequences or None,
        repetition_penalty=generate_params.repetition_penalty,
        # Request per-token logprobs so the min_probability stop can be applied below.
        logprobs=1 if generate_params.min_probability else None,
    )

with log_time("vllm_generate"):
outputs = llm.generate([generate_params.text], sampling_params)

    results = []
    for output in outputs:
        candidate = output.outputs[0]
        # candidate.logprobs is one {token_id: Logprob} mapping per generated token
        # (None when logprobs were not requested); flatten to the sampled token's values.
        token_logprobs = []
        for tok_id, lp_map in zip(candidate.token_ids, candidate.logprobs or []):
            lp = lp_map.get(tok_id)
            token_logprobs.append(lp.logprob if hasattr(lp, "logprob") else lp)
        generated = _apply_custom_stop(candidate.text, token_logprobs, generate_params)
        full_text = generate_params.text + generated
        results.append({"generated_text": full_text, "stop_reason": get_stop_reason()})

processed = post_process_results(
[r["generated_text"] for r in results],
generate_params,
generate_params.text,
generate_params.text,
)

final = []
for result, processed_text in zip(results, processed):
result["generated_text"] = processed_text
final.append(result)
return final
18 changes: 11 additions & 7 deletions tests/conftest.py
@@ -1,10 +1,14 @@
import httpx
try:
import httpx
except ImportError: # pragma: no cover - optional dependency
httpx = None

_old_init = httpx.Client.__init__
if httpx:
_old_init = httpx.Client.__init__

def fixed_init(self, *args, **kwargs):
# Remove 'app' from kwargs if present
kwargs.pop('app', None)
_old_init(self, *args, **kwargs)
def fixed_init(self, *args, **kwargs):
# Remove 'app' from kwargs if present
kwargs.pop('app', None)
_old_init(self, *args, **kwargs)

httpx.Client.__init__ = fixed_init
httpx.Client.__init__ = fixed_init
12 changes: 12 additions & 0 deletions tests/unit/test_vllm_inference.py
@@ -0,0 +1,12 @@
import pytest

from questions.vllm_inference import VLLM_AVAILABLE, fast_vllm_inference
from questions.models import GenerateParams
from questions.inference_server.model_cache import ModelCache


@pytest.mark.skipif(VLLM_AVAILABLE, reason="vllm available - skip lightweight test")
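# Runs only when vllm is absent; verifies the guarded import path raises cleanly.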
def test_vllm_missing():
params = GenerateParams(text="hi")
with pytest.raises(RuntimeError):
fast_vllm_inference(params, ModelCache())