diff --git a/README.md b/README.md index 8d5df0e9..03f8206a 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ We provide Infinity models for you to play with, which are on " in content: + import re + content = re.sub(r"<think>.*?</think>\s*", "", content, flags=re.DOTALL).strip() + + return content diff --git a/tools/prompt_rewriter.py b/tools/prompt_rewriter.py index fe18a503..9402394f 100644 --- a/tools/prompt_rewriter.py +++ b/tools/prompt_rewriter.py @@ -15,17 +15,17 @@ from PIL import Image import openai -from conf import GPT_AK +from tools.llm_provider import chat_completion, get_llm_client def encode_image(image_path, size=(512, 512)): """ Resize an image and encode it as a Base64 string. - + Args: - image_path (str): Path to the image file. - size (tuple): New size as a tuple, (width, height). - + Returns: - str: Base64 encoded string of the resized image. """ @@ -42,8 +42,8 @@ def encode_image(image_path, size=(512, 512)): SYSTEM = """ -You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say. -For example, outputting the prompt and parameters like "" will trigger your partner bot to output an image of a forest morning, as described. +You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say. +For example, outputting the prompt and parameters like "" will trigger your partner bot to output an image of a forest morning, as described. You will be prompted by users looking to create detailed, amazing images. The way to accomplish this is to refine their short prompts and make them extremely detailed and descriptive. - You will only ever output a single image description sentence per user request. - Each image description sentence should be consist of "", where is the image description, is the parameter that control the image generation. 
@@ -77,48 +77,42 @@ def __init__(self, system, few_shot_history):
 
     def rewrite(self, prompt):
         messages = self.system + self.few_shot_history + [{"role": "user", "content": prompt}]
-        result, _ = get_gpt_result(model_name='gpt-4o-2024-08-06', messages=messages, retry=5, ak=GPT_AK, return_json=False)
+        result, _ = get_gpt_result(messages=messages, retry=5, return_json=False)
         assert result
         return result
 
 
-def get_gpt_result(model_name='gpt-4o-2024-05-13', messages=None, retry=5, ak=None, return_json=False):
+def get_gpt_result(model_name=None, messages=None, retry=5, ak=None, return_json=False):
     """
-    Retrieves a chat response using the GPT-4 model.
-    Args:
-    model_name (str, optional): The name of the GPT model to use. Defaults to 'gpt-4'. [gpt-3.5-turbo, gpt-4]
-    retry (int, optional): The number of times to retry the chat API if there is an error. Defaults to 5.
-    Returns:
-    tuple: A tuple containing the chat response content (str) and the API usage (dict).
-    Raises:
-    Exception: If there is an error retrieving the chat response.
+    Retrieves a chat response using the configured LLM provider.
+
+    The provider is determined by environment variables (see tools/llm_provider.py).
+    Supports OpenAI, Azure OpenAI, MiniMax, and any OpenAI-compatible API.
+
+    Args:
+        model_name (str, optional): The model to use. If None, uses provider default.
+        messages (list): Chat messages.
+        retry (int): Number of retry attempts on failure.
+        ak (str, optional): API key override (deprecated, use env vars instead).
+        return_json (bool): Whether to request JSON response format.
+
+    Returns:
+        tuple: (response_content, None) on success, (None, -1) on failure.
     """
-    openai_ak = ak
-    client = openai.AzureOpenAI(
-        azure_endpoint="https://search-va.byteintl.net/gpt/openapi/online/multimodal/crawl",
-        api_version="2023-07-01-preview",
-        api_key=openai_ak
-    )
     for i in range(retry):
         try:
-            if return_json:
-                completion = client.chat.completions.create(
-                    model=model_name,
-                    messages=messages,
-                    response_format={ "type": "json_object" },
-                )
-            else:
-                completion = client.chat.completions.create(
-                    model=model_name,
-                    messages=messages,
-                )
-            result = json.loads(completion.model_dump_json())['choices'][0]['message']['content']
-            return result,None
+            result = chat_completion(
+                messages=messages,
+                model=model_name,
+                api_key=ak if ak else None,
+                return_json=return_json,
+            )
+            return result, None
         except Exception as e:
             traceback.print_exc()
-            if isinstance(e,KeyboardInterrupt):
+            if isinstance(e, KeyboardInterrupt):
                 exit(0)
-            sleep_time = 10 + random.randint(2,5)**(i+1)
+            sleep_time = 10 + random.randint(2, 5) ** (i + 1)
             time.sleep(sleep_time)
     return None, -1