From 8efd95f5766316e72206e94f9be8327e2d6e0cd9 Mon Sep 17 00:00:00 2001 From: PR Bot Date: Tue, 24 Mar 2026 15:36:08 +0800 Subject: [PATCH] Add configurable LLM provider for prompt rewriter with MiniMax support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The prompt rewriter previously used a hardcoded Azure OpenAI endpoint (ByteDance internal), making it unusable for external users. This adds a configurable LLM provider system (tools/llm_provider.py) that supports OpenAI, Azure OpenAI, MiniMax, and any OpenAI-compatible API via environment variables. Provider auto-detection, temperature clamping for MiniMax, and think-tag stripping are included. - New: tools/llm_provider.py — configurable LLM provider factory - Updated: tools/prompt_rewriter.py — uses new provider system - Updated: README.md — documents LLM provider configuration - New: 22 unit tests + 3 integration tests --- README.md | 28 ++- tests/conftest.py | 31 +++ tests/test_llm_provider.py | 270 +++++++++++++++++++++++++ tests/test_llm_provider_integration.py | 74 +++++++ tools/llm_provider.py | 169 ++++++++++++++++ tools/prompt_rewriter.py | 66 +++--- 6 files changed, 601 insertions(+), 37 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_llm_provider.py create mode 100644 tests/test_llm_provider_integration.py create mode 100644 tools/llm_provider.py diff --git a/README.md b/README.md index 8d5df0e9..03f8206a 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ We provide Infinity models for you to play with, which are on " in content: + import re + content = re.sub(r".*?\s*", "", content, flags=re.DOTALL).strip() + + return content diff --git a/tools/prompt_rewriter.py b/tools/prompt_rewriter.py index fe18a503..9402394f 100644 --- a/tools/prompt_rewriter.py +++ b/tools/prompt_rewriter.py @@ -15,17 +15,17 @@ from PIL import Image import openai -from conf import GPT_AK +from tools.llm_provider import chat_completion, 
get_llm_client def encode_image(image_path, size=(512, 512)): """ Resize an image and encode it as a Base64 string. - + Args: - image_path (str): Path to the image file. - size (tuple): New size as a tuple, (width, height). - + Returns: - str: Base64 encoded string of the resized image. """ @@ -42,8 +42,8 @@ def encode_image(image_path, size=(512, 512)): SYSTEM = """ -You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say. -For example, outputting the prompt and parameters like "" will trigger your partner bot to output an image of a forest morning, as described. +You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say. +For example, outputting the prompt and parameters like "" will trigger your partner bot to output an image of a forest morning, as described. You will be prompted by users looking to create detailed, amazing images. The way to accomplish this is to refine their short prompts and make them extremely detailed and descriptive. - You will only ever output a single image description sentence per user request. - Each image description sentence should be consist of "", where is the image description, is the parameter that control the image generation. @@ -77,48 +77,42 @@ def __init__(self, system, few_shot_history): def rewrite(self, prompt): messages = self.system + self.few_shot_history + [{"role": "user", "content": prompt}] - result, _ = get_gpt_result(model_name='gpt-4o-2024-08-06', messages=messages, retry=5, ak=GPT_AK, return_json=False) + result, _ = get_gpt_result(messages=messages, retry=5, return_json=False) assert result return result -def get_gpt_result(model_name='gpt-4o-2024-05-13', messages=None, retry=5, ak=None, return_json=False): +def get_gpt_result(model_name=None, messages=None, retry=5, ak=None, return_json=False): """ - Retrieves a chat response using the GPT-4 model. 
- Args: - model_name (str, optional): The name of the GPT model to use. Defaults to 'gpt-4'. [gpt-3.5-turbo, gpt-4] - retry (int, optional): The number of times to retry the chat API if there is an error. Defaults to 5. - Returns: - tuple: A tuple containing the chat response content (str) and the API usage (dict). - Raises: - Exception: If there is an error retrieving the chat response. + Retrieves a chat response using the configured LLM provider. + + The provider is determined by environment variables (see tools/llm_provider.py). + Supports OpenAI, Azure OpenAI, MiniMax, and any OpenAI-compatible API. + + Args: + model_name (str, optional): The model to use. If None, uses provider default. + messages (list): Chat messages. + retry (int): Maximum number of attempts (including the first) before giving up. + ak (str, optional): API key override (deprecated, use env vars instead). + return_json (bool): Whether to request JSON response format. + + Returns: + tuple: (response_content, None) on success, (None, -1) on failure. 
""" - openai_ak = ak - client = openai.AzureOpenAI( - azure_endpoint="https://search-va.byteintl.net/gpt/openapi/online/multimodal/crawl", - api_version="2023-07-01-preview", - api_key=openai_ak - ) for i in range(retry): try: - if return_json: - completion = client.chat.completions.create( - model=model_name, - messages=messages, - response_format={ "type": "json_object" }, - ) - else: - completion = client.chat.completions.create( - model=model_name, - messages=messages, - ) - result = json.loads(completion.model_dump_json())['choices'][0]['message']['content'] - return result,None + result = chat_completion( + messages=messages, + model=model_name, + api_key=ak if ak else None, + return_json=return_json, + ) + return result, None except Exception as e: traceback.print_exc() - if isinstance(e,KeyboardInterrupt): + if isinstance(e, KeyboardInterrupt): exit(0) - sleep_time = 10 + random.randint(2,5)**(i+1) + sleep_time = 10 + random.randint(2, 5) ** (i + 1) time.sleep(sleep_time) return None, -1