From ee81afeecbcdf1178e5756f867e443e7b83d0333 Mon Sep 17 00:00:00 2001 From: Luke Garceau Date: Mon, 9 Mar 2026 16:04:20 -0400 Subject: [PATCH] feat: add model selector modal and enhance OAuth prompt Adds a model helper tooltip/modal with two paths for model discovery: - Query with an assisted prompt to get model suggestions based on use case - Browse by common use cases for model inspiration Also fixes security issue with dedented else statements in oauth.py and enhances the model suggestion prompt. Co-Authored-By: Claude Opus 4.6 --- backend/open_webui/constants.py | 1 + backend/open_webui/routers/tasks.py | 91 + backend/open_webui/utils/task.py | 9 + functions/pipes/maintained-gemini-pipe.py | 3331 +++++++++++++++++ src/lib/apis/index.ts | 65 + .../components/chat/ModelHelperModal.svelte | 423 +++ src/lib/components/chat/Navbar.svelte | 21 +- 7 files changed, 3939 insertions(+), 2 deletions(-) create mode 100644 functions/pipes/maintained-gemini-pipe.py create mode 100644 src/lib/components/chat/ModelHelperModal.svelte diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 4d39d16cdb9..ba82a34290f 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -124,3 +124,4 @@ def __str__(self) -> str: AUTOCOMPLETE_GENERATION = "autocomplete_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" + MODEL_RECOMMENDATION = "model_recommendation" diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index f6404da05cb..ed52fcaa99c 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -16,6 +16,7 @@ tags_generation_template, emoji_generation_template, moa_response_generation_template, + model_recommendation_template, ) from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.constants import TASKS @@ -762,3 +763,93 @@ async def generate_moa_response( 
status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(e)}, ) + + +DEFAULT_MODEL_RECOMMENDATION_PROMPT_TEMPLATE = """Task: Given a user's description and a list of available models, return JSON recommending 1-3 models. + +User wants to: {{TASK_DESCRIPTION}} + +Available models: +{{MODELS_LIST}} + +Rules: +1. Models with type "custom_model_or_pipe" are pipes to external providers (e.g. Google Gemini, Anthropic Claude). Check their name and base_model_id to identify the provider. +2. Match task to model strengths. Coding -> code models. Images -> image generation models. Video -> video models. Writing -> large general models. +3. Prefer models whose name, description, or capabilities explicitly match the task. +4. Use EXACT model IDs from the list. Do not invent IDs. +5. Return 1-3 recommendations, best first. + +Return ONLY this JSON, no markdown, no explanation, no extra text before or after: +{"recommendations":[{"model_id":"exact_id","reason":"one sentence"}]} + +Example valid response: +{"recommendations":[{"model_id":"gpt-4o","reason":"Strong general-purpose model with vision capabilities"}]} + +BEGIN JSON RESPONSE:""" + + +@router.post("/model_recommendation/completions") +async def generate_model_recommendation( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating model recommendation using model {task_model_id} for user {user.email}" + ) + + task_description = form_data.get("task_description", "") + available_models = form_data.get("available_models", []) + + models_info_lines = [] + for m in available_models: + parts = [f"ID: {m['id']}"] + for key in ["name", 
"description", "owned_by", "type", "base_model_id", "capabilities", "system_prompt_hint"]: + if m.get(key): + parts.append(f"{key}: {m[key]}") + models_info_lines.append("- " + ", ".join(parts)) + models_info = "\n".join(models_info_lines) + + template = DEFAULT_MODEL_RECOMMENDATION_PROMPT_TEMPLATE + content = model_recommendation_template( + template, task_description, models_info, user + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), + "task": str(TASKS.MODEL_RECOMMENDATION), + "task_body": form_data, + }, + } + + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 0ea525c93e5..ce7c535d237 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -426,6 +426,15 @@ def replacement_function(match): return template +def model_recommendation_template( + template: str, task_description: str, models_info: str, user: Optional[Any] = None +) -> str: + template = template.replace("{{TASK_DESCRIPTION}}", task_description) + template = template.replace("{{MODELS_LIST}}", models_info) + template = prompt_template(template, user) + return template + + def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) return template diff --git a/functions/pipes/maintained-gemini-pipe.py b/functions/pipes/maintained-gemini-pipe.py new file mode 100644 index 
00000000000..5400b5dd14d --- /dev/null +++ b/functions/pipes/maintained-gemini-pipe.py @@ -0,0 +1,3331 @@ +""" +title: Google Gemini Pipeline +author: owndev, olivier-lacroix +author_url: https://github.com/owndev/ +project_url: https://github.com/owndev/Open-WebUI-Functions +funding_url: https://github.com/sponsors/owndev +version: 1.13.0 +required_open_webui_version: 0.8.0 +license: Apache License 2.0 +description: Highly optimized Google Gemini pipeline with advanced image and video generation capabilities, intelligent compression, and streamlined processing workflows. +features: + - Optimized asynchronous API calls for maximum performance + - Intelligent model caching with configurable TTL + - Streamlined dynamic model specification with automatic prefix handling + - Smart streaming response handling with safety checks + - Advanced multimodal input support (text and images) + - Unified image generation and editing with Gemini 2.5 Flash Image Preview + - Intelligent image optimization with size-aware compression algorithms + - Automated image upload to Open WebUI with robust fallback support + - Optimized text-to-image and image-to-image workflows + - Non-streaming mode for image generation to prevent chunk overflow + - Progressive status updates for optimal user experience + - Consolidated error handling and comprehensive logging + - Seamless Google Generative AI and Vertex AI integration + - Advanced generation parameters (temperature, max tokens, etc.) 
+ - Configurable safety settings with environment variable support + - Military-grade encrypted storage of sensitive API keys + - Intelligent grounding with Google search integration + - Vertex AI Search grounding for RAG + - Native tool calling support with automatic signature management + - URL context grounding for specified web pages + - Unified image processing with consolidated helper methods + - Optimized payload creation for image generation models + - Configurable image processing parameters (size, quality, compression) + - Flexible upload fallback options and optimization controls + - Configurable thinking levels (low/high) for Gemini 3 models + - Configurable thinking budgets (0-32768 tokens) for Gemini 2.5 models + - Configurable image generation aspect ratio (1:1, 16:9, etc.) and resolution (1K, 2K, 4K) + - Model whitelist for filtering available models + - Additional model support for SDK-unsupported models + - Video generation with Google Veo models (Veo 3.1, 3, 2) + - Configurable video generation parameters (aspect ratio, resolution, duration) + - Asynchronous video generation with progressive polling status updates + - Automatic video upload to Open WebUI with embedded playback + - Image-to-video generation support for Veo models + - Negative prompt and person generation controls for video +""" + +import os +import re +import time +import asyncio +import base64 +import hashlib +import logging +import io +import uuid +import aiofiles +from PIL import Image +from google import genai +from google.genai import types +from google.genai.errors import ClientError, ServerError, APIError +from typing import List, Union, Optional, Dict, Any, Tuple, AsyncIterator, Callable +from pydantic_core import core_schema +from pydantic import BaseModel, Field, GetCoreSchemaHandler +from cryptography.fernet import Fernet, InvalidToken +from open_webui.env import SRC_LOG_LEVELS +from fastapi import Request, UploadFile, BackgroundTasks +from open_webui.routers.files 
import upload_file +from open_webui.models.users import UserModel, Users +from starlette.datastructures import Headers + +ASPECT_RATIO_OPTIONS: List[str] = [ + "default", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", +] + +RESOLUTION_OPTIONS: List[str] = [ + "default", + "1K", + "2K", + "4K", +] + +VIDEO_ASPECT_RATIO_OPTIONS: List[str] = [ + "default", + "16:9", + "9:16", +] + +VIDEO_RESOLUTION_OPTIONS: List[str] = [ + "default", + "720p", + "1080p", + "4k", +] + +VIDEO_DURATION_OPTIONS: List[str] = [ + "default", + "4", + "5", + "6", + "8", +] + +VIDEO_PERSON_GENERATION_OPTIONS: List[str] = [ + "default", + "allow_all", + "allow_adult", + "dont_allow", +] + + +# Simplified encryption implementation with automatic handling +class EncryptedStr(str): + """A string type that automatically handles encryption/decryption""" + + @classmethod + def _get_encryption_key(cls) -> Optional[bytes]: + """ + Generate encryption key from WEBUI_SECRET_KEY if available + Returns None if no key is configured + """ + secret = os.getenv("WEBUI_SECRET_KEY") + if not secret: + return None + + hashed_key = hashlib.sha256(secret.encode()).digest() + return base64.urlsafe_b64encode(hashed_key) + + @classmethod + def encrypt(cls, value: str) -> str: + """ + Encrypt a string value if a key is available + Returns the original value if no key is available + """ + if not value or value.startswith("encrypted:"): + return value + + key = cls._get_encryption_key() + if not key: # No encryption if no key + return value + + f = Fernet(key) + encrypted = f.encrypt(value.encode()) + return f"encrypted:{encrypted.decode()}" + + @classmethod + def decrypt(cls, value: str) -> str: + """ + Decrypt an encrypted string value if a key is available + Returns the original value if no key is available or decryption fails + """ + if not value or not value.startswith("encrypted:"): + return value + + key = cls._get_encryption_key() + if not key: # No decryption if no 
key + return value[len("encrypted:") :] # Return without prefix + + try: + encrypted_part = value[len("encrypted:") :] + f = Fernet(key) + decrypted = f.decrypt(encrypted_part.encode()) + return decrypted.decode() + except (InvalidToken, Exception): + return value + + # Pydantic integration + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.union_schema( + [ + core_schema.is_instance_schema(cls), + core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function( + lambda value: cls(cls.encrypt(value) if value else value) + ), + ] + ), + ], + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: str(instance) + ), + ) + + +class Pipe: + """ + Pipeline for interacting with Google Gemini models. + """ + + # User-overridable configuration valves + class UserValves(BaseModel): + IMAGE_GENERATION_ASPECT_RATIO: str = Field( + default=os.getenv("GOOGLE_IMAGE_GENERATION_ASPECT_RATIO", "default"), + description="Default aspect ratio for image generation.", + json_schema_extra={"enum": ASPECT_RATIO_OPTIONS}, + ) + IMAGE_GENERATION_RESOLUTION: str = Field( + default=os.getenv("GOOGLE_IMAGE_GENERATION_RESOLUTION", "default"), + description="Default resolution for image generation.", + json_schema_extra={"enum": RESOLUTION_OPTIONS}, + ) + VIDEO_GENERATION_ASPECT_RATIO: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_ASPECT_RATIO", "default"), + description="Default aspect ratio for video generation (16:9 landscape or 9:16 portrait).", + json_schema_extra={"enum": VIDEO_ASPECT_RATIO_OPTIONS}, + ) + VIDEO_GENERATION_RESOLUTION: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_RESOLUTION", "default"), + description="Default resolution for video generation (720p, 1080p, or 4k).", + json_schema_extra={"enum": VIDEO_RESOLUTION_OPTIONS}, + ) + VIDEO_GENERATION_DURATION: str = Field( + 
default=os.getenv("GOOGLE_VIDEO_GENERATION_DURATION", "default"), + description="Default duration in seconds for video generation (4, 5, 6, or 8 - availability varies by model).", + json_schema_extra={"enum": VIDEO_DURATION_OPTIONS}, + ) + + # Configuration valves for the pipeline + class Valves(BaseModel): + BASE_URL: str = Field( + default=os.getenv( + "GOOGLE_GENAI_BASE_URL", "https://generativelanguage.googleapis.com/" + ), + description="Base URL for the Google Generative AI API.", + ) + GOOGLE_API_KEY: EncryptedStr = Field( + default=os.getenv("GOOGLE_API_KEY", ""), + description="API key for Google Generative AI (used if USE_VERTEX_AI is false).", + json_schema_extra={"input": {"type": "password"}}, + ) + API_VERSION: str = Field( + default=os.getenv("GOOGLE_API_VERSION", "v1alpha"), + description="API version to use for Google Generative AI (e.g., v1alpha, v1beta, v1).", + ) + STREAMING_ENABLED: bool = Field( + default=os.getenv("GOOGLE_STREAMING_ENABLED", "true").lower() == "true", + description="Enable streaming responses (set false to force non-streaming mode).", + ) + INCLUDE_THOUGHTS: bool = Field( + default=os.getenv("GOOGLE_INCLUDE_THOUGHTS", "true").lower() == "true", + description="Enable Gemini thoughts outputs (set false to disable).", + ) + THINKING_BUDGET: int = Field( + default=int(os.getenv("GOOGLE_THINKING_BUDGET", "-1")), + description="Thinking budget for Gemini 2.5 models (0=disabled, -1=dynamic, 1-32768=fixed token limit). " + "Not used for Gemini 3 models which use THINKING_LEVEL instead.", + ) + THINKING_LEVEL: str = Field( + default=os.getenv("GOOGLE_THINKING_LEVEL", ""), + description="Thinking level for Gemini 3 models ('minimal', 'low', 'medium', or 'high'). " + "Ignored for other models. 
Empty string means use model default.", + ) + USE_VERTEX_AI: bool = Field( + default=os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "false").lower() == "true", + description="Whether to use Google Cloud Vertex AI instead of the Google Generative AI API.", + ) + VERTEX_PROJECT: str | None = Field( + default=os.getenv("GOOGLE_CLOUD_PROJECT"), + description="The Google Cloud project ID to use with Vertex AI.", + ) + VERTEX_LOCATION: str = Field( + default=os.getenv("GOOGLE_CLOUD_LOCATION", "global"), + description="The Google Cloud region to use with Vertex AI.", + ) + VERTEX_AI_RAG_STORE: str | None = Field( + default=os.getenv("GOOGLE_VERTEX_AI_RAG_STORE"), + description="Vertex AI RAG Store path for grounding (e.g., projects/PROJECT/locations/LOCATION/ragCorpora/DATA_STORE_ID). Only used when USE_VERTEX_AI is true.", + ) + USE_PERMISSIVE_SAFETY: bool = Field( + default=os.getenv("GOOGLE_USE_PERMISSIVE_SAFETY", "false").lower() + == "true", + description="Use permissive safety settings for content generation.", + ) + MODEL_CACHE_TTL: int = Field( + default=int(os.getenv("GOOGLE_MODEL_CACHE_TTL", "600")), + description="Time in seconds to cache the model list before refreshing", + ) + RETRY_COUNT: int = Field( + default=int(os.getenv("GOOGLE_RETRY_COUNT", "2")), + description="Number of times to retry API calls on temporary failures", + ) + DEFAULT_SYSTEM_PROMPT: str = Field( + default=os.getenv("GOOGLE_DEFAULT_SYSTEM_PROMPT", ""), + description="Default system prompt applied to all chats. If a user-defined system prompt exists, " + "this is prepended to it. 
Leave empty to disable.", + ) + ENABLE_FORWARD_USER_INFO_HEADERS: bool = Field( + default=os.getenv( + "GOOGLE_ENABLE_FORWARD_USER_INFO_HEADERS", "false" + ).lower() + == "true", + description="Whether to forward user information headers.", + ) + MODEL_ADDITIONAL: str = Field( + default=os.getenv("GOOGLE_MODEL_ADDITIONAL", ""), + description="A comma-separated list of model IDs to manually add to the list of available models. " + "These are models not returned by the SDK but that you want to make available. " + "Non-Gemini model IDs must be explicitly included in MODEL_WHITELIST to be available.", + ) + MODEL_WHITELIST: str = Field( + default=os.getenv("GOOGLE_MODEL_WHITELIST", ""), + description="A comma-separated list of model IDs to show in the models list. " + "If set, only these models will be available (after MODEL_ADDITIONAL is applied). " + "Leave empty to show all models.", + ) + + # Image Processing Configuration + IMAGE_GENERATION_ASPECT_RATIO: str = Field( + default=os.getenv("GOOGLE_IMAGE_GENERATION_ASPECT_RATIO", "default"), + description="Default aspect ratio for image generation.", + json_schema_extra={"enum": ASPECT_RATIO_OPTIONS}, + ) + IMAGE_GENERATION_RESOLUTION: str = Field( + default=os.getenv("GOOGLE_IMAGE_GENERATION_RESOLUTION", "default"), + description="Default resolution for image generation.", + json_schema_extra={"enum": RESOLUTION_OPTIONS}, + ) + IMAGE_MAX_SIZE_MB: float = Field( + default=float(os.getenv("GOOGLE_IMAGE_MAX_SIZE_MB", "15.0")), + description="Maximum image size in MB before compression is applied", + ) + IMAGE_MAX_DIMENSION: int = Field( + default=int(os.getenv("GOOGLE_IMAGE_MAX_DIMENSION", "2048")), + description="Maximum width or height in pixels before resizing", + ) + IMAGE_COMPRESSION_QUALITY: int = Field( + default=int(os.getenv("GOOGLE_IMAGE_COMPRESSION_QUALITY", "85")), + description="JPEG compression quality (1-100, higher = better quality but larger size)", + ) + IMAGE_ENABLE_OPTIMIZATION: bool = Field( + 
default=os.getenv("GOOGLE_IMAGE_ENABLE_OPTIMIZATION", "true").lower() + == "true", + description="Enable intelligent image optimization for API compatibility", + ) + IMAGE_PNG_COMPRESSION_THRESHOLD_MB: float = Field( + default=float(os.getenv("GOOGLE_IMAGE_PNG_THRESHOLD_MB", "0.5")), + description="PNG files above this size (MB) will be converted to JPEG for better compression", + ) + IMAGE_HISTORY_MAX_REFERENCES: int = Field( + default=int(os.getenv("GOOGLE_IMAGE_HISTORY_MAX_REFERENCES", "5")), + description="Maximum total number of images (history + current message) to include in a generation call", + ) + IMAGE_ADD_LABELS: bool = Field( + default=os.getenv("GOOGLE_IMAGE_ADD_LABELS", "true").lower() == "true", + description="If true, add small text labels like [Image 1] before each image part so the model can reference them.", + ) + IMAGE_DEDUP_HISTORY: bool = Field( + default=os.getenv("GOOGLE_IMAGE_DEDUP_HISTORY", "true").lower() == "true", + description="If true, deduplicate identical images (by hash) when constructing history context", + ) + IMAGE_HISTORY_FIRST: bool = Field( + default=os.getenv("GOOGLE_IMAGE_HISTORY_FIRST", "true").lower() == "true", + description="If true (default), history images precede current message images; if false, current images first.", + ) + + # Video Generation Configuration (Veo models) + VIDEO_GENERATION_ASPECT_RATIO: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_ASPECT_RATIO", "default"), + description="Default aspect ratio for video generation (16:9 landscape or 9:16 portrait).", + json_schema_extra={"enum": VIDEO_ASPECT_RATIO_OPTIONS}, + ) + VIDEO_GENERATION_RESOLUTION: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_RESOLUTION", "default"), + description="Default resolution for video generation (720p, 1080p, or 4k).", + json_schema_extra={"enum": VIDEO_RESOLUTION_OPTIONS}, + ) + VIDEO_GENERATION_DURATION: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_DURATION", "default"), + 
description="Default duration in seconds for video generation (4, 5, 6, or 8 - availability varies by model).", + json_schema_extra={"enum": VIDEO_DURATION_OPTIONS}, + ) + VIDEO_GENERATION_NEGATIVE_PROMPT: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_NEGATIVE_PROMPT", ""), + description="Default negative prompt for video generation (describes what not to include).", + ) + VIDEO_GENERATION_PERSON_GENERATION: str = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_PERSON_GENERATION", "default"), + description="Controls generation of people in videos (allow_all, allow_adult, dont_allow).", + json_schema_extra={"enum": VIDEO_PERSON_GENERATION_OPTIONS}, + ) + VIDEO_GENERATION_ENHANCE_PROMPT: bool = Field( + default=os.getenv("GOOGLE_VIDEO_GENERATION_ENHANCE_PROMPT", "true").lower() + == "true", + description="Enable prompt enhancement for video generation.", + ) + VIDEO_POLL_INTERVAL: int = Field( + default=int(os.getenv("GOOGLE_VIDEO_POLL_INTERVAL", "10")), + description="Polling interval in seconds when waiting for video generation to complete.", + ) + VIDEO_POLL_TIMEOUT: int = Field( + default=int(os.getenv("GOOGLE_VIDEO_POLL_TIMEOUT", "600")), + description="Maximum time in seconds to wait for video generation before timing out (0=no limit).", + ) + + # ---------------- Internal Helpers ---------------- # + async def _gather_history_images( + self, + messages: List[Dict[str, Any]], + last_user_msg: Dict[str, Any], + optimization_stats: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + history_images: List[Dict[str, Any]] = [] + for msg in messages: + if msg is last_user_msg: + continue + if msg.get("role") not in {"user", "assistant"}: + continue + _p, parts = await self._extract_images_from_message( + msg, stats_list=optimization_stats + ) + if parts: + history_images.extend(parts) + return history_images + + def _deduplicate_images(self, images: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not self.valves.IMAGE_DEDUP_HISTORY: + return 
images + seen: set[str] = set() + result: List[Dict[str, Any]] = [] + for part in images: + try: + data = part["inline_data"]["data"] + # Hash full base64 payload for stronger dedup reliability + h = hashlib.sha256(data.encode()).hexdigest() + if h in seen: + continue + seen.add(h) + except Exception as e: + # Skip images with malformed or missing data, but log for debugging. + self.log.debug(f"Skipping image in deduplication due to error: {e}") + result.append(part) + return result + + def _combine_system_prompts( + self, user_system_prompt: Optional[str] + ) -> Optional[str]: + """Combine default system prompt with user-defined system prompt. + + If DEFAULT_SYSTEM_PROMPT is set and user_system_prompt exists, + the default is prepended to the user's prompt. + If only DEFAULT_SYSTEM_PROMPT is set, it is used as the system prompt. + If only user_system_prompt is set, it is used as-is. + + Args: + user_system_prompt: The user-defined system prompt from messages (may be None) + + Returns: + Combined system prompt or None if neither is set + """ + default_prompt = self.valves.DEFAULT_SYSTEM_PROMPT.strip() + user_prompt = user_system_prompt.strip() if user_system_prompt else "" + + if default_prompt and user_prompt: + combined = f"{default_prompt}\n\n{user_prompt}" + self.log.debug( + f"Combined system prompts: default ({len(default_prompt)} chars) + " + f"user ({len(user_prompt)} chars) = {len(combined)} chars" + ) + return combined + elif default_prompt: + self.log.debug(f"Using default system prompt ({len(default_prompt)} chars)") + return default_prompt + elif user_prompt: + return user_prompt + return None + + def _apply_order_and_limit( + self, + history: List[Dict[str, Any]], + current: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[bool]]: + """Combine history & current image parts honoring order & global limit. 
+ + Returns: + (combined_parts, reused_flags) where reused_flags[i] == True indicates + the image originated from history, False if from current message. + """ + history_first = self.valves.IMAGE_HISTORY_FIRST + limit = max(1, self.valves.IMAGE_HISTORY_MAX_REFERENCES) + combined: List[Dict[str, Any]] = [] + reused_flags: List[bool] = [] + + def append(parts: List[Dict[str, Any]], reused: bool): + for p in parts: + if len(combined) >= limit: + break + combined.append(p) + reused_flags.append(reused) + + if history_first: + append(history, True) + append(current, False) + else: + append(current, False) + append(history, True) + return combined, reused_flags + + async def _emit_image_stats( + self, + ordered_stats: List[Dict[str, Any]], + reused_flags: List[bool], + total_limit: int, + __event_emitter__: Callable, + ) -> None: + """Emit per-image optimization stats aligned with final combined order. + + ordered_stats: stats list in the exact order images will be sent (same length as combined image list) + reused_flags: parallel list indicating whether image originated from history + """ + if not ordered_stats: + return + for idx, stat in enumerate(ordered_stats, start=1): + reused = reused_flags[idx - 1] if idx - 1 < len(reused_flags) else False + stat_copy = dict(stat) if stat else {} + stat_copy.update({"index": idx, "reused": reused}) + if stat and stat.get("original_size_mb") is not None: + desc = f"Image {idx}: {stat['original_size_mb']:.2f}MB -> {stat['final_size_mb']:.2f}MB" + if stat.get("quality") is not None: + desc += f" (Q{stat['quality']})" + else: + desc = f"Image {idx}: (no metrics)" + reasons = stat.get("reasons") if stat else None + if reasons: + desc += " | " + ", ".join(reasons[:3]) + await __event_emitter__( + { + "type": "status", + "data": { + "action": "image_optimization", + "description": desc, + "index": idx, + "done": False, + "details": stat_copy, + }, + } + ) + await __event_emitter__( + { + "type": "status", + "data": { + "action": 
"image_optimization", + "description": f"{len(ordered_stats)} image(s) processed (limit {total_limit}).", + "done": True, + }, + } + ) + + async def _build_image_generation_contents( + self, + messages: List[Dict[str, Any]], + __event_emitter__: Callable, + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + """Construct the contents payload for image-capable models. + + Returns tuple (contents, system_instruction) where system_instruction is extracted from system messages. + """ + # Extract user-defined system instruction first + user_system_instruction = next( + (msg["content"] for msg in messages if msg.get("role") == "system"), + None, + ) + + # Combine with default system prompt if configured + system_instruction = self._combine_system_prompts(user_system_instruction) + + last_user_msg = next( + (m for m in reversed(messages) if m.get("role") == "user"), None + ) + if not last_user_msg: + raise ValueError("No user message found") + + optimization_stats: List[Dict[str, Any]] = [] + history_images = await self._gather_history_images( + messages, last_user_msg, optimization_stats + ) + prompt, current_images = await self._extract_images_from_message( + last_user_msg, stats_list=optimization_stats + ) + + # Deduplicate + history_images = self._deduplicate_images(history_images) + current_images = self._deduplicate_images(current_images) + + combined, reused_flags = self._apply_order_and_limit( + history_images, current_images + ) + + if not prompt and not combined: + raise ValueError("No prompt or images provided") + if not prompt and combined: + prompt = "Analyze and describe the provided images." 
+ + # Build ordered stats aligned with combined list + ordered_stats: List[Dict[str, Any]] = [] + if optimization_stats: + # Build map from final_hash -> stat (first wins) + hash_map: Dict[str, Dict[str, Any]] = {} + for s in optimization_stats: + fh = s.get("final_hash") + if fh and fh not in hash_map: + hash_map[fh] = s + for part in combined: + try: + fh = hashlib.sha256( + part["inline_data"]["data"].encode() + ).hexdigest() + ordered_stats.append(hash_map.get(fh) or {}) + except Exception: + ordered_stats.append({}) + # Emit stats AFTER final ordering so labels match + await self._emit_image_stats( + ordered_stats, + reused_flags, + self.valves.IMAGE_HISTORY_MAX_REFERENCES, + __event_emitter__, + ) + + # Emit mapping + if combined: + mapping = [ + { + "index": i + 1, + "label": ( + f"Image {i + 1}" if self.valves.IMAGE_ADD_LABELS else str(i + 1) + ), + "reused": reused_flags[i], + "origin": "history" if reused_flags[i] else "current", + } + for i in range(len(combined)) + ] + await __event_emitter__( + { + "type": "status", + "data": { + "action": "image_reference_map", + "description": f"{len(combined)} image(s) included (limit {self.valves.IMAGE_HISTORY_MAX_REFERENCES}).", + "images": mapping, + "done": True, + }, + } + ) + + # Build parts + parts: List[Dict[str, Any]] = [] + + # For image generation models, prepend system instruction to the prompt + # since system_instruction parameter may not be supported + final_prompt = prompt + if system_instruction and prompt: + final_prompt = f"{system_instruction}\n\n{prompt}" + self.log.debug( + f"Prepended system instruction to prompt for image generation. 
" + f"System instruction length: {len(system_instruction)}, " + f"Original prompt length: {len(prompt)}, " + f"Final prompt length: {len(final_prompt)}" + ) + elif system_instruction and not prompt: + final_prompt = system_instruction + self.log.debug( + f"Using system instruction as prompt for image generation " + f"(length: {len(system_instruction)})" + ) + + if final_prompt: + parts.append({"text": final_prompt}) + if self.valves.IMAGE_ADD_LABELS: + for idx, part in enumerate(combined, start=1): + parts.append({"text": f"[Image {idx}]"}) + parts.append(part) + else: + parts.extend(combined) + + self.log.debug( + f"Image-capable payload: history={len(history_images)} current={len(current_images)} used={len(combined)} limit={self.valves.IMAGE_HISTORY_MAX_REFERENCES} history_first={self.valves.IMAGE_HISTORY_FIRST} prompt_len={len(final_prompt)}" + ) + # Return None for system_instruction since we've incorporated it into the prompt + return [{"role": "user", "parts": parts}], None + + def __init__(self): + """Initializes the Pipe instance and configures the genai library.""" + self.valves = self.Valves() + self.name: str = "Google Gemini: " + + # Setup logging + self.log = logging.getLogger("google_ai.pipe") + self.log.setLevel(SRC_LOG_LEVELS.get("OPENAI", logging.INFO)) + + # Model cache + self._model_cache: Optional[List[Dict[str, str]]] = None + self._model_cache_time: float = 0 + + def _get_client(self) -> genai.Client: + """ + Validates API credentials and returns a genai.Client instance. 
+ """ + self._validate_api_key() + + if self.valves.USE_VERTEX_AI: + self.log.debug( + f"Initializing Vertex AI client (Project: {self.valves.VERTEX_PROJECT}, Location: {self.valves.VERTEX_LOCATION})" + ) + return genai.Client( + vertexai=True, + project=self.valves.VERTEX_PROJECT, + location=self.valves.VERTEX_LOCATION, + ) + else: + self.log.debug("Initializing Google Generative AI client with API Key") + headers = {} + if ( + self.valves.ENABLE_FORWARD_USER_INFO_HEADERS + and hasattr(self, "user") + and self.user + ): + + def sanitize_header_value(value: Any, max_length: int = 255) -> str: + if value is None: + return "" + # Convert to string and remove all control characters + sanitized = re.sub(r"[\x00-\x1F\x7F]", "", str(value)) + sanitized = sanitized.strip() + return ( + sanitized[:max_length] + if len(sanitized) > max_length + else sanitized + ) + + user_attrs = { + "X-OpenWebUI-User-Name": sanitize_header_value( + getattr(self.user, "name", None) + ), + "X-OpenWebUI-User-Id": sanitize_header_value( + getattr(self.user, "id", None) + ), + "X-OpenWebUI-User-Email": sanitize_header_value( + getattr(self.user, "email", None) + ), + "X-OpenWebUI-User-Role": sanitize_header_value( + getattr(self.user, "role", None) + ), + } + headers = {k: v for k, v in user_attrs.items() if v not in (None, "")} + options = types.HttpOptions( + api_version=self.valves.API_VERSION, + base_url=self.valves.BASE_URL, + headers=headers, + ) + return genai.Client( + api_key=EncryptedStr.decrypt(self.valves.GOOGLE_API_KEY), + http_options=options, + ) + + def _validate_api_key(self) -> None: + """ + Validates that the necessary Google API credentials are set. + + Raises: + ValueError: If the required credentials are not set. + """ + if self.valves.USE_VERTEX_AI: + if not self.valves.VERTEX_PROJECT: + self.log.error("USE_VERTEX_AI is true, but VERTEX_PROJECT is not set.") + raise ValueError( + "VERTEX_PROJECT is not set. Please provide the Google Cloud project ID." 
+ ) + # For Vertex AI, location has a default, so project is the main thing to check. + # Actual authentication will be handled by ADC or environment. + self.log.debug( + "Using Vertex AI. Ensure ADC or service account is configured." + ) + else: + if not self.valves.GOOGLE_API_KEY: + self.log.error("GOOGLE_API_KEY is not set (and not using Vertex AI).") + raise ValueError( + "GOOGLE_API_KEY is not set. Please provide the API key in the environment variables or valves." + ) + self.log.debug("Using Google Generative AI API with API Key.") + + def strip_prefix(self, model_name: str) -> str: + """ + Extract the model identifier using regex, handling various naming conventions. + e.g., "google_gemini_pipeline.gemini-2.5-flash-preview-04-17" -> "gemini-2.5-flash-preview-04-17" + e.g., "models/gemini-1.5-flash-001" -> "gemini-1.5-flash-001" + e.g., "publishers/google/models/gemini-1.5-pro" -> "gemini-1.5-pro" + """ + # Use regex to remove everything up to and including the last '/' or the first '.' + stripped = re.sub(r"^(?:.*/|[^.]*\.)", "", model_name) + return stripped + + def get_google_models(self, force_refresh: bool = False) -> List[Dict[str, str]]: + """ + Retrieve available Google models suitable for content generation. + Uses caching to reduce API calls. + + Args: + force_refresh: Whether to force refreshing the model cache + + Returns: + List of dictionaries containing model id and name. 
+ """ + # Check cache first + current_time = time.time() + if ( + not force_refresh + and self._model_cache is not None + and (current_time - self._model_cache_time) < self.valves.MODEL_CACHE_TTL + ): + self.log.debug("Using cached model list") + return self._model_cache + + try: + client = self._get_client() + self.log.debug("Fetching models from Google API") + models = list(client.models.list()) + + # Process additional models (models not returned by SDK but that we want to add) + additional = self.valves.MODEL_ADDITIONAL + if additional: + self.log.debug(f"Processing additional models: {additional}") + existing_model_names = {self.strip_prefix(m.name) for m in models} + additional_ids = set(re.findall(r"[^,\s]+", additional)) + + for model_id in additional_ids.difference(existing_model_names): + self.log.debug(f"Adding additional model '{model_id}'.") + models.append(types.Model(name=f"models/{model_id}")) + + available_models = [] + for model in models: + actions = model.supported_actions + model_id_stripped = self.strip_prefix(model.name) + is_content_model = actions is None or "generateContent" in actions + is_video_model = ( + actions is not None and "generateVideos" in actions + ) or model_id_stripped.startswith("veo-") + if is_content_model or is_video_model: + model_id = model_id_stripped + model_name = model.display_name or model_id + + # Check if model supports image generation + supports_image_generation = self._check_image_generation_support( + model_id + ) + if supports_image_generation: + model_name += " 🎨" # Add image generation indicator + + # Check if model supports video generation + supports_video_generation = self._check_video_generation_support( + model_id + ) + if supports_video_generation: + model_name += " 🎬" # Add video generation indicator + + available_models.append( + { + "id": model_id, + "name": model_name, + "image_generation": supports_image_generation, + "video_generation": supports_video_generation, + } + ) + + model_map = 
{model["id"]: model for model in available_models} + + # Apply MODEL_WHITELIST filter if configured (takes priority) + whitelist = self.valves.MODEL_WHITELIST + if whitelist: + self.log.debug(f"Applying model whitelist: {whitelist}") + whitelisted_ids = set(re.findall(r"[^,\s]+", whitelist)) + # Filter to only include whitelisted models + filtered_models = { + k: v for k, v in model_map.items() if k in whitelisted_ids + } + self.log.debug(f"After whitelist filter: {len(filtered_models)} models") + else: + # If no whitelist, filter to only include models starting with 'gemini-' or 'veo-' + filtered_models = { + k: v + for k, v in model_map.items() + if k.startswith("gemini-") or k.startswith("veo-") + } + self.log.debug(f"After prefix filter: {len(filtered_models)} models") + + # Update cache + self._model_cache = list(filtered_models.values()) + self._model_cache_time = current_time + self.log.debug(f"Found {len(self._model_cache)} Gemini models") + return self._model_cache + + except Exception as e: + self.log.exception(f"Could not fetch models from Google: {str(e)}") + # Return a specific error entry for the UI + return [{"id": "error", "name": f"Could not fetch models: {str(e)}"}] + + def _check_image_generation_support(self, model_id: str) -> bool: + """ + Check if a model supports image generation. 
+ + Args: + model_id: The model ID to check + + Returns: + True if the model supports image generation, False otherwise + """ + # Known image generation models (both Gemini 2.5 and Gemini 3) + image_generation_models = [ + "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview", + "gemini-3-flash-image", + "gemini-3-flash-image-preview", + "gemini-3-pro-image", + "gemini-3-pro-image-preview", + ] + + # Check for exact matches or pattern matches + for pattern in image_generation_models: + if model_id == pattern or pattern in model_id: + return True + + # Additional pattern checking for future models + if "image" in model_id.lower() and ( + "generation" in model_id.lower() or "preview" in model_id.lower() + ): + return True + + return False + + def _check_image_config_support(self, model_id: str) -> bool: + """ + Check if a model supports ImageConfig (aspect_ratio and image_size parameters). + + ImageConfig is only supported by Gemini 3 image generation models. + Gemini 2.5 image models support image generation but not ImageConfig. + + Args: + model_id: The model ID to check + + Returns: + True if the model supports ImageConfig, False otherwise + """ + # ImageConfig is only supported by Gemini 3 models + model_lower = model_id.lower() + + # Check if it's a Gemini 3 model + if "gemini-3-" not in model_lower: + return False + + # Check if it's an image generation model + return self._check_image_generation_support(model_id) + + def _check_thinking_support(self, model_id: str) -> bool: + """ + Check if a model supports the thinking feature. 
+ + Args: + model_id: The model ID to check + + Returns: + True if the model supports thinking, False otherwise + """ + # Models that do NOT support thinking + non_thinking_models = [ + "gemini-2.5-flash-image-preview", + "gemini-2.5-flash-image", + ] + + # Check for exact matches + for pattern in non_thinking_models: + if model_id == pattern or pattern in model_id: + return False + + # Additional pattern checking - image generation models typically don't support thinking + if "image" in model_id.lower() and ( + "generation" in model_id.lower() or "preview" in model_id.lower() + ): + return False + + # By default, assume models support thinking + return True + + def _check_thinking_level_support(self, model_id: str) -> bool: + """ + Check if a model supports the thinking_level parameter. + + Gemini 3 models support thinking_level and should NOT use thinking_budget. + Other models (like Gemini 2.5) use thinking_budget instead. + + Args: + model_id: The model ID to check + + Returns: + True if the model supports thinking_level, False otherwise + """ + # Gemini 3 models support thinking_level (not thinking_budget) + gemini_3_patterns = [ + "gemini-3-", + ] + + model_lower = model_id.lower() + for pattern in gemini_3_patterns: + if pattern in model_lower: + return True + + return False + + def _validate_thinking_level(self, level: str) -> Optional[str]: + """ + Validate and normalize the thinking level value. + + Args: + level: The thinking level string to validate + + Returns: + Normalized level string ('minimal', 'low', 'medium', 'high') or None if invalid/empty + """ + if not level: + return None + + normalized = level.strip().lower() + valid_levels = ["minimal", "low", "medium", "high"] + + if normalized in valid_levels: + return normalized + + self.log.warning( + f"Invalid thinking level '{level}'. Valid values are: {', '.join(valid_levels)}. " + "Falling back to model default." 
+ ) + return None + + def _validate_thinking_budget(self, budget: int) -> int: + """ + Validate and normalize the thinking budget value. + + Args: + budget: The thinking budget integer to validate + + Returns: + Validated budget: -1 for dynamic, 0 to disable, or 1-32768 for fixed limit + """ + # -1 means dynamic thinking (let the model decide) + if budget == -1: + return -1 + + # 0 means disable thinking + if budget == 0: + return 0 + + # Validate positive range (1-32768) + if budget > 0: + if budget > 32768: + self.log.warning( + f"Thinking budget {budget} exceeds maximum of 32768. Clamping to 32768." + ) + return 32768 + return budget + + # Negative values (except -1) are invalid, treat as -1 (dynamic) + self.log.warning( + f"Invalid thinking budget {budget}. Only -1 (dynamic), 0 (disabled), or 1-32768 are valid. " + "Falling back to dynamic thinking." + ) + return -1 + + def _validate_aspect_ratio(self, aspect_ratio: str) -> Optional[str]: + """ + Validate and normalize the aspect ratio value. + + Args: + aspect_ratio: The aspect ratio string to validate + + Returns: + Validated aspect ratio string, None for "default", or "1:1" as fallback for invalid values + """ + if not aspect_ratio or aspect_ratio == "default": + self.log.debug("Using default aspect ratio (None)") + return None + + normalized = aspect_ratio.strip() + valid_ratios = [r for r in ASPECT_RATIO_OPTIONS if r != "default"] + + if normalized in valid_ratios: + return normalized + + self.log.warning( + f"Invalid aspect ratio '{aspect_ratio}'. Valid values are: {', '.join(valid_ratios)}. " + "Using default '1:1'." + ) + return "1:1" + + def _validate_resolution(self, resolution: str) -> Optional[str]: + """ + Validate and normalize the resolution value. 
+ + Args: + resolution: The resolution string to validate + + Returns: + Validated resolution string, None for "default", or "2K" as fallback for invalid values + """ + if not resolution or resolution.lower() == "default": + self.log.debug("Using default resolution (None)") + return None + + normalized = resolution.strip().upper() + valid_resolutions = [r for r in RESOLUTION_OPTIONS if r.lower() != "default"] + + if normalized in valid_resolutions: + return normalized + + self.log.warning( + f"Invalid resolution '{resolution}'. Valid values are: {', '.join(valid_resolutions)}. " + "Using default '2K'." + ) + return "2K" + + def _check_video_generation_support(self, model_id: str) -> bool: + model_lower = model_id.lower() + return model_lower.startswith("veo-") or ( + "veo" in model_lower and "generate" in model_lower + ) + + def _check_veo_3_1_support(self, model_id: str) -> bool: + """Check if a Veo model is version 3.1 (supports reference images, interpolation, 4k, extension).""" + return "veo-3.1" in model_id.lower() + + def _get_veo_model_capabilities(self, model_id: str) -> Dict[str, Any]: + """Return per-model feature support matrix based on official Google Veo documentation.""" + model_lower = model_id.lower() + is_fast = "fast" in model_lower + + if "veo-3.1" in model_lower: + return { + "version": "3.1", + "is_fast": is_fast, + "supports_enhance_prompt": not is_fast, + "supports_resolution": True, + "valid_resolutions": ["720p", "1080p", "4k"], + "valid_durations": [4, 6, 8], + "max_videos": 1, + "supports_reference_images": True, + "supports_last_frame": True, + "supports_extension": True, + } + if "veo-3" in model_lower: + return { + "version": "3", + "is_fast": is_fast, + "supports_enhance_prompt": not is_fast, + "supports_resolution": True, + "valid_resolutions": ["720p", "1080p"], + "valid_durations": [8], + "max_videos": 1, + "supports_reference_images": False, + "supports_last_frame": True, + "supports_extension": False, + } + if "veo-2" in 
model_lower: + return { + "version": "2", + "is_fast": False, + "supports_enhance_prompt": False, + "supports_resolution": False, + "valid_resolutions": [], + "valid_durations": [5, 6, 8], + "max_videos": 2, + "supports_reference_images": False, + "supports_last_frame": True, + "supports_extension": False, + } + return { + "version": "unknown", + "is_fast": is_fast, + "supports_enhance_prompt": False, + "supports_resolution": False, + "valid_resolutions": [], + "valid_durations": [8], + "max_videos": 1, + "supports_reference_images": False, + "supports_last_frame": False, + "supports_extension": False, + } + + def _validate_video_aspect_ratio(self, aspect_ratio: str) -> Optional[str]: + if not aspect_ratio or aspect_ratio == "default": + return None + normalized = aspect_ratio.strip() + valid = [r for r in VIDEO_ASPECT_RATIO_OPTIONS if r != "default"] + if normalized in valid: + return normalized + self.log.warning( + f"Invalid video aspect ratio '{aspect_ratio}'. Valid: {', '.join(valid)}. Using default." + ) + return None + + def _validate_video_resolution(self, resolution: str) -> Optional[str]: + if not resolution or resolution.lower() == "default": + return None + normalized = resolution.strip().lower() + valid = [r for r in VIDEO_RESOLUTION_OPTIONS if r.lower() != "default"] + if normalized in valid: + return normalized + self.log.warning( + f"Invalid video resolution '{resolution}'. Valid: {', '.join(valid)}. Using default." + ) + return None + + def _validate_video_duration(self, duration: str) -> Optional[int]: + if not duration or duration.lower() == "default": + return None + valid = {int(d) for d in VIDEO_DURATION_OPTIONS if d != "default"} + try: + val = int(duration) + if val in valid: + return val + except (ValueError, TypeError): + pass + self.log.warning( + f"Invalid video duration '{duration}'. Valid: {', '.join(str(v) for v in sorted(valid))}. Using default." 
+ ) + return None + + def _build_video_generation_config( + self, + body: Dict[str, Any], + __user__: Optional[dict] = None, + model_id: str = "", + ) -> types.GenerateVideosConfig: + """Build GenerateVideosConfig from valves, user overrides, and model capabilities.""" + caps = self._get_veo_model_capabilities(model_id) + + user_ar = self._get_user_valve_value(__user__, "VIDEO_GENERATION_ASPECT_RATIO") + aspect_ratio = self._validate_video_aspect_ratio( + body.get( + "aspect_ratio", user_ar or self.valves.VIDEO_GENERATION_ASPECT_RATIO + ) + ) + + user_res = self._get_user_valve_value(__user__, "VIDEO_GENERATION_RESOLUTION") + resolution = self._validate_video_resolution( + body.get("resolution", user_res or self.valves.VIDEO_GENERATION_RESOLUTION) + ) + + user_dur = self._get_user_valve_value(__user__, "VIDEO_GENERATION_DURATION") + duration_seconds = self._validate_video_duration( + body.get("duration", user_dur or self.valves.VIDEO_GENERATION_DURATION) + ) + + negative_prompt = ( + body.get("negative_prompt", self.valves.VIDEO_GENERATION_NEGATIVE_PROMPT) + or None + ) + + person_generation_raw = body.get( + "person_generation", self.valves.VIDEO_GENERATION_PERSON_GENERATION + ) + person_generation = None + if person_generation_raw and person_generation_raw != "default": + valid_person_values = [ + v for v in VIDEO_PERSON_GENERATION_OPTIONS if v != "default" + ] + if person_generation_raw in valid_person_values: + person_generation = person_generation_raw + else: + self.log.warning( + f"Invalid person_generation '{person_generation_raw}'. " + f"Valid: {', '.join(valid_person_values)}. Ignoring." 
+ ) + + enhance_prompt = body.get( + "enhance_prompt", self.valves.VIDEO_GENERATION_ENHANCE_PROMPT + ) + + number_of_videos_raw = body.get("number_of_videos", 1) + try: + number_of_videos = int(number_of_videos_raw) + except (ValueError, TypeError): + self.log.warning( + f"Invalid number_of_videos '{number_of_videos_raw}', defaulting to 1" + ) + number_of_videos = 1 + + config_params: Dict[str, Any] = { + "number_of_videos": min(max(number_of_videos, 1), caps["max_videos"]), + } + + # enhance_prompt: not supported by Fast models or Veo 2 + if caps["supports_enhance_prompt"] and enhance_prompt: + config_params["enhance_prompt"] = enhance_prompt + + if aspect_ratio: + config_params["aspect_ratio"] = aspect_ratio + + # Resolution: not supported by Veo 2; model-specific valid values + if resolution and caps["supports_resolution"]: + if resolution in caps["valid_resolutions"]: + config_params["resolution"] = resolution + else: + self.log.warning( + f"Resolution '{resolution}' not supported by {model_id}. " + f"Valid: {', '.join(caps['valid_resolutions'])}. Using default." + ) + + # Duration: model-specific valid values + if duration_seconds: + if duration_seconds in caps["valid_durations"]: + config_params["duration_seconds"] = duration_seconds + else: + self.log.warning( + f"Duration {duration_seconds}s not supported by {model_id}. " + f"Valid: {', '.join(str(d) for d in caps['valid_durations'])}. Using default." + ) + + if negative_prompt: + config_params["negative_prompt"] = negative_prompt + if person_generation: + config_params["person_generation"] = person_generation + + self.log.debug(f"Video generation config for {model_id}: {config_params}") + return types.GenerateVideosConfig(**config_params) + + def pipes(self) -> List[Dict[str, str]]: + """ + Returns a list of available Google Gemini models for the UI. + + Returns: + List of dictionaries containing model id and name. 
+ """ + try: + self.name = "Google Gemini: " + return self.get_google_models() + except ValueError as e: + # Handle the case where API key is missing during pipe listing + self.log.error(f"Error during pipes listing (validation): {e}") + return [{"id": "error", "name": str(e)}] + except Exception as e: + # Handle other potential errors during model fetching + self.log.exception( + f"An unexpected error occurred during pipes listing: {str(e)}" + ) + return [{"id": "error", "name": f"An unexpected error occurred: {str(e)}"}] + + def _prepare_model_id(self, model_id: str) -> str: + """ + Prepare and validate the model ID for use with the API. + + Args: + model_id: The original model ID from the user + + Returns: + Properly formatted model ID + + Raises: + ValueError: If the model ID is invalid or unsupported + """ + original_model_id = model_id + model_id = self.strip_prefix(model_id) + + valid_prefixes = ("gemini-", "veo-") + + # If the model ID doesn't match a known prefix, try to find it by name + if not model_id.startswith(valid_prefixes): + models_list = self.get_google_models() + found_model = next( + (m["id"] for m in models_list if m["name"] == original_model_id), None + ) + if found_model and found_model.startswith(valid_prefixes): + model_id = found_model + self.log.debug( + f"Mapped model name '{original_model_id}' to model ID '{model_id}'" + ) + else: + if not model_id.startswith(valid_prefixes): + self.log.error( + f"Invalid or unsupported model ID: '{original_model_id}'" + ) + raise ValueError( + f"Invalid or unsupported Google model ID or name: '{original_model_id}'" + ) + + return model_id + + def _prepare_content( + self, messages: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + """ + Prepare messages content for the API and extract system message if present. 
+ + Args: + messages: List of message objects from the request + + Returns: + Tuple of (prepared content list, system message string or None) + """ + # Extract user-defined system message + user_system_message = next( + (msg["content"] for msg in messages if msg.get("role") == "system"), + None, + ) + + # Combine with default system prompt if configured + system_message = self._combine_system_prompts(user_system_message) + + # Prepare contents for the API + contents = [] + for message in messages: + role = message.get("role") + if role == "system": + continue # Skip system messages, handled separately + + content = message.get("content", "") + parts = [] + + # Handle different content types + if isinstance(content, list): # Multimodal content + parts.extend(self._process_multimodal_content(content)) + elif isinstance(content, str): # Plain text content + parts.append({"text": content}) + else: + self.log.warning(f"Unsupported message content type: {type(content)}") + continue # Skip unsupported content + + # Map roles: 'assistant' -> 'model', 'user' -> 'user' + api_role = "model" if role == "assistant" else "user" + if parts: # Only add if there are parts + contents.append({"role": api_role, "parts": parts}) + + return contents, system_message + + def _process_multimodal_content( + self, content_list: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Process multimodal content (text and images). 
+ + Args: + content_list: List of content items + + Returns: + List of processed parts for the Gemini API + """ + parts = [] + + for item in content_list: + if item.get("type") == "text": + parts.append({"text": item.get("text", "")}) + elif item.get("type") == "image_url": + image_url = item.get("image_url", {}).get("url", "") + + if image_url.startswith("data:image"): + # Handle base64 encoded image data with optimization + try: + # Optimize the image before processing + optimized_image = self._optimize_image_for_api(image_url) + header, encoded = optimized_image.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + + # Basic validation for image types + if mime_type not in [ + "image/jpeg", + "image/png", + "image/webp", + "image/heic", + "image/heif", + ]: + self.log.warning( + f"Unsupported image mime type: {mime_type}" + ) + parts.append( + {"text": f"[Image type {mime_type} not supported]"} + ) + continue + + # Check if the encoded data is too large + if len(encoded) > 15 * 1024 * 1024: # 15MB limit for base64 + self.log.warning( + f"Image data too large: {len(encoded)} characters" + ) + parts.append( + { + "text": "[Image too large for processing - please use a smaller image]" + } + ) + continue + + parts.append( + { + "inline_data": { + "mime_type": mime_type, + "data": encoded, + } + } + ) + except Exception as img_ex: + self.log.exception(f"Could not parse image data URL: {img_ex}") + parts.append({"text": "[Image data could not be processed]"}) + else: + # Gemini API doesn't directly support image URLs + self.log.warning(f"Direct image URLs not supported: {image_url}") + parts.append({"text": f"[Image URL not processed: {image_url}]"}) + + return parts + + # _find_image removed (was single-image oriented and is superseded by multi-image logic) + + async def _extract_images_from_message( + self, + message: Dict[str, Any], + *, + stats_list: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Extract prompt text 
and ALL images from a single user message. + + This replaces the previous single-image _find_image logic for image-capable + models so that multi-image prompts are respected. + + Returns: + (prompt_text, image_parts) + prompt_text: concatenated text content (may be empty) + image_parts: list of {"inline_data": {mime_type, data}} dicts + """ + content = message.get("content", "") + text_segments: List[str] = [] + image_parts: List[Dict[str, Any]] = [] + + # Helper to process a data URL or fetched file and append inline_data + def _add_image(data_url: str): + try: + optimized = self._optimize_image_for_api(data_url, stats_list) + header, b64 = optimized.split(",", 1) + mime = header.split(":", 1)[1].split(";", 1)[0] + image_parts.append({"inline_data": {"mime_type": mime, "data": b64}}) + except Exception as e: # pragma: no cover - defensive + self.log.warning(f"Skipping image (parse failure): {e}") + + # Regex to extract markdown image references + md_pattern = re.compile( + r"!\[[^\]]*\]\((data:image[^)]+|/files/[^)]+|/api/v1/files/[^)]+)\)" + ) + + # Structured multimodal array + if isinstance(content, list): + for item in content: + if item.get("type") == "text": + txt = item.get("text", "") + text_segments.append(txt) + # Also parse any markdown images embedded in the text + for match in md_pattern.finditer(txt): + url = match.group(1) + if url.startswith("data:"): + _add_image(url) + else: + b64 = await self._fetch_file_as_base64(url) + if b64: + _add_image(b64) + elif item.get("type") == "image_url": + url = item.get("image_url", {}).get("url", "") + if url.startswith("data:"): + _add_image(url) + elif "/files/" in url or "/api/v1/files/" in url: + b64 = await self._fetch_file_as_base64(url) + if b64: + _add_image(b64) + # Plain string message (may include markdown images) + elif isinstance(content, str): + text_segments.append(content) + for match in md_pattern.finditer(content): + url = match.group(1) + if url.startswith("data:"): + _add_image(url) + else: + 
b64 = await self._fetch_file_as_base64(url) + if b64: + _add_image(b64) + else: + self.log.debug( + f"Unsupported content type for image extraction: {type(content)}" + ) + + prompt_text = " ".join(s.strip() for s in text_segments if s.strip()) + return prompt_text, image_parts + + def _optimize_image_for_api( + self, image_data: str, stats_list: Optional[List[Dict[str, Any]]] = None + ) -> str: + """ + Optimize image data for Gemini API using configurable parameters. + + Returns: + Optimized base64 data URL + """ + # Check if optimization is enabled + if not self.valves.IMAGE_ENABLE_OPTIMIZATION: + self.log.debug("Image optimization disabled via configuration") + return image_data + + max_size_mb = self.valves.IMAGE_MAX_SIZE_MB + max_dimension = self.valves.IMAGE_MAX_DIMENSION + base_quality = self.valves.IMAGE_COMPRESSION_QUALITY + png_threshold = self.valves.IMAGE_PNG_COMPRESSION_THRESHOLD_MB + + self.log.debug( + f"Image optimization config: max_size={max_size_mb}MB, max_dim={max_dimension}px, quality={base_quality}, png_threshold={png_threshold}MB" + ) + try: + # Parse the data URL + if image_data.startswith("data:"): + header, encoded = image_data.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + else: + encoded = image_data + mime_type = "image/png" + + # Decode and analyze the image + image_bytes = base64.b64decode(encoded) + original_size_mb = len(image_bytes) / (1024 * 1024) + base64_size_mb = len(encoded) / (1024 * 1024) + + self.log.debug( + f"Original image: {original_size_mb:.2f} MB (decoded), {base64_size_mb:.2f} MB (base64), type: {mime_type}" + ) + + # Determine optimization strategy + reasons: List[str] = [] + if original_size_mb > max_size_mb: + reasons.append(f"size > {max_size_mb} MB") + if base64_size_mb > max_size_mb * 1.4: + reasons.append("base64 overhead") + if mime_type == "image/png" and original_size_mb > png_threshold: + reasons.append(f"PNG > {png_threshold}MB") + + # Always check dimensions + with 
Image.open(io.BytesIO(image_bytes)) as img: + width, height = img.size + resized_flag = False + if width > max_dimension or height > max_dimension: + reasons.append(f"dimensions > {max_dimension}px") + + # Early exit: no optimization triggers -> keep original, record stats + if not reasons: + if stats_list is not None: + stats_list.append( + { + "original_size_mb": round(original_size_mb, 4), + "final_size_mb": round(original_size_mb, 4), + "quality": None, + "format": mime_type.split("/")[-1].upper(), + "resized": False, + "reasons": ["no_optimization_needed"], + "final_hash": hashlib.sha256( + encoded.encode() + ).hexdigest(), + } + ) + self.log.debug( + "Skipping optimization: image already within thresholds" + ) + return image_data + + self.log.debug(f"Optimization triggers: {', '.join(reasons)}") + + # Convert to RGB for JPEG compression + if img.mode in ("RGBA", "LA", "P"): + background = Image.new("RGB", img.size, (255, 255, 255)) + if img.mode == "P": + img = img.convert("RGBA") + background.paste( + img, + mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None, + ) + img = background + elif img.mode != "RGB": + img = img.convert("RGB") + + # Resize if needed + if width > max_dimension or height > max_dimension: + ratio = min(max_dimension / width, max_dimension / height) + new_size = (int(width * ratio), int(height * ratio)) + self.log.debug( + f"Resizing from {width}x{height} to {new_size[0]}x{new_size[1]}" + ) + img = img.resize(new_size, Image.Resampling.LANCZOS) + resized_flag = True + + # Determine quality levels based on original size and user configuration + if original_size_mb > 5.0: + quality_levels = [ + base_quality, + base_quality - 10, + base_quality - 20, + base_quality - 30, + base_quality - 40, + max(base_quality - 50, 25), + ] + elif original_size_mb > 2.0: + quality_levels = [ + base_quality, + base_quality - 5, + base_quality - 15, + base_quality - 25, + max(base_quality - 35, 35), + ] + else: + quality_levels = [ + 
min(base_quality + 5, 95), + base_quality, + base_quality - 10, + max(base_quality - 20, 50), + ] + + # Ensure quality levels are within valid range (1-100) + quality_levels = [max(1, min(100, q)) for q in quality_levels] + + # Try compression levels + for quality in quality_levels: + output_buffer = io.BytesIO() + format_type = ( + "JPEG" + if original_size_mb > png_threshold or "jpeg" in mime_type + else "PNG" + ) + output_mime = f"image/{format_type.lower()}" + + img.save( + output_buffer, + format=format_type, + quality=quality, + optimize=True, + ) + output_bytes = output_buffer.getvalue() + output_size_mb = len(output_bytes) / (1024 * 1024) + + if output_size_mb <= max_size_mb: + optimized_b64 = base64.b64encode(output_bytes).decode("utf-8") + self.log.debug( + f"Optimized: {original_size_mb:.2f} MB → {output_size_mb:.2f} MB (Q{quality})" + ) + if stats_list is not None: + stats_list.append( + { + "original_size_mb": round(original_size_mb, 4), + "final_size_mb": round(output_size_mb, 4), + "quality": quality, + "format": format_type, + "resized": resized_flag, + "reasons": reasons, + "final_hash": hashlib.sha256( + optimized_b64.encode() + ).hexdigest(), + } + ) + return f"data:{output_mime};base64,{optimized_b64}" + + # Fallback: minimum quality + output_buffer = io.BytesIO() + img.save(output_buffer, format="JPEG", quality=15, optimize=True) + output_bytes = output_buffer.getvalue() + output_size_mb = len(output_bytes) / (1024 * 1024) + optimized_b64 = base64.b64encode(output_bytes).decode("utf-8") + + self.log.warning( + f"Aggressive optimization: {output_size_mb:.2f} MB (Q15)" + ) + if stats_list is not None: + stats_list.append( + { + "original_size_mb": round(original_size_mb, 4), + "final_size_mb": round(output_size_mb, 4), + "quality": 15, + "format": "JPEG", + "resized": resized_flag, + "reasons": reasons + ["fallback_min_quality"], + "final_hash": hashlib.sha256( + optimized_b64.encode() + ).hexdigest(), + } + ) + return 
f"data:image/jpeg;base64,{optimized_b64}" + + except Exception as e: + self.log.error(f"Image optimization failed: {e}") + # Return original or safe fallback + if image_data.startswith("data:"): + if stats_list is not None: + stats_list.append( + { + "original_size_mb": None, + "final_size_mb": None, + "quality": None, + "format": None, + "resized": False, + "reasons": ["optimization_failed"], + "final_hash": ( + hashlib.sha256(encoded.encode()).hexdigest() + if "encoded" in locals() + else None + ), + } + ) + return image_data + return f"data:image/jpeg;base64,{encoded if 'encoded' in locals() else image_data}" + + async def _fetch_file_as_base64(self, file_url: str) -> Optional[str]: + """ + Fetch a file from Open WebUI's file system and convert to base64. + + Args: + file_url: File URL from Open WebUI + + Returns: + Base64 encoded file data or None if file not found + """ + try: + if "/api/v1/files/" in file_url: + fid = file_url.split("/api/v1/files/")[-1].split("/")[0].split("?")[0] + else: + fid = file_url.split("/files/")[-1].split("/")[0].split("?")[0] + + from pathlib import Path + from open_webui.models.files import Files + from open_webui.storage.provider import Storage + + file_obj = Files.get_file_by_id(fid) + if file_obj and file_obj.path: + file_path = Storage.get_file(file_obj.path) + file_path = Path(file_path) + if file_path.is_file(): + async with aiofiles.open(file_path, "rb") as fp: + raw = await fp.read() + enc = base64.b64encode(raw).decode() + mime = file_obj.meta.get("content_type", "image/png") + return f"data:{mime};base64,{enc}" + except Exception as e: + self.log.warning(f"Could not fetch file {file_url}: {e}") + return None + + async def _upload_image_with_status( + self, + image_data: Any, + mime_type: str, + __request__: Request, + __user__: dict, + __event_emitter__: Callable, + ) -> str: + """ + Unified image upload method with status updates and fallback handling. 
+ + Returns: + URL to uploaded image or data URL fallback + """ + try: + await __event_emitter__( + { + "type": "status", + "data": { + "action": "image_upload", + "description": "Uploading generated image to your library...", + "done": False, + }, + } + ) + + self.user = user = Users.get_user_by_id(__user__["id"]) + + # Convert image data to base64 string if needed + if isinstance(image_data, bytes): + image_data_b64 = base64.b64encode(image_data).decode("utf-8") + else: + image_data_b64 = str(image_data) + + image_url = self._upload_image( + __request__=__request__, + user=user, + image_data=image_data_b64, + mime_type=mime_type, + ) + + await __event_emitter__( + { + "type": "status", + "data": { + "action": "image_upload", + "description": "Image uploaded successfully!", + "done": True, + }, + } + ) + + return image_url + + except Exception as e: + self.log.warning(f"File upload failed, falling back to data URL: {e}") + + if isinstance(image_data, bytes): + image_data_b64 = base64.b64encode(image_data).decode("utf-8") + else: + image_data_b64 = str(image_data) + + await __event_emitter__( + { + "type": "status", + "data": { + "action": "image_upload", + "description": "Using inline image (upload failed)", + "done": True, + }, + } + ) + + return f"data:{mime_type};base64,{image_data_b64}" + + def _upload_image( + self, __request__: Request, user: UserModel, image_data: str, mime_type: str + ) -> str: + """ + Upload generated image to Open WebUI's file system. + Expects base64 encoded string input. 
+ + Args: + __request__: FastAPI request object + user: User model object + image_data: Base64 encoded image data string + mime_type: MIME type of the image + + Returns: + URL to the uploaded image or data URL fallback + """ + try: + self.log.debug( + f"Processing image data, type: {type(image_data)}, length: {len(image_data)}" + ) + + # Decode base64 string to bytes + try: + decoded_data = base64.b64decode(image_data) + self.log.debug( + f"Successfully decoded image data: {len(decoded_data)} bytes" + ) + except Exception as decode_error: + self.log.error(f"Failed to decode base64 data: {decode_error}") + # Try to add padding if missing + try: + missing_padding = len(image_data) % 4 + if missing_padding: + image_data += "=" * (4 - missing_padding) + decoded_data = base64.b64decode(image_data) + self.log.debug( + f"Successfully decoded with padding: {len(decoded_data)} bytes" + ) + except Exception as second_decode_error: + self.log.error(f"Still failed to decode: {second_decode_error}") + return f"data:{mime_type};base64,{image_data}" + + bio = io.BytesIO(decoded_data) + bio.seek(0) + + # Determine file extension + extension = "png" + if "jpeg" in mime_type or "jpg" in mime_type: + extension = "jpg" + elif "webp" in mime_type: + extension = "webp" + elif "gif" in mime_type: + extension = "gif" + + # Create filename + filename = f"gemini-generated-{uuid.uuid4().hex}.{extension}" + + # Upload with simple approach like reference + up_obj = upload_file( + request=__request__, + background_tasks=BackgroundTasks(), + file=UploadFile( + file=bio, + filename=filename, + headers=Headers({"content-type": mime_type}), + ), + process=False, # Matching reference - no heavy processing + user=user, + metadata={"mime_type": mime_type, "source": "gemini_image_generation"}, + ) + + self.log.debug( + f"Upload completed. 
File ID: {up_obj.id}, Decoded size: {len(decoded_data)} bytes" + ) + + # Generate URL using reference method + return __request__.app.url_path_for("get_file_content_by_id", id=up_obj.id) + + except Exception as e: + self.log.exception(f"Image upload failed, using data URL fallback: {e}") + # Fallback to data URL if upload fails + return f"data:{mime_type};base64,{image_data}" + + def _upload_video( + self, + __request__: Request, + user: UserModel, + video_data: bytes, + mime_type: str = "video/mp4", + ) -> Tuple[str, str]: + """Upload generated video to Open WebUI's file system. + + Returns: + Tuple of (content_url, file_id) + """ + bio = io.BytesIO(video_data) + bio.seek(0) + + extension = "mp4" + if "webm" in mime_type: + extension = "webm" + + filename = f"veo-generated-{uuid.uuid4().hex}.{extension}" + + up_obj = upload_file( + request=__request__, + background_tasks=BackgroundTasks(), + file=UploadFile( + file=bio, + filename=filename, + headers=Headers({"content-type": mime_type}), + ), + process=False, + user=user, + metadata={"mime_type": mime_type, "source": "veo_video_generation"}, + ) + + content_url = __request__.app.url_path_for( + "get_file_content_by_id", id=up_obj.id + ) + self.log.debug( + f"Video upload completed. File ID: {up_obj.id}, Size: {len(video_data)} bytes" + ) + return content_url, up_obj.id + + async def _upload_video_with_status( + self, + video_data: bytes, + mime_type: str, + __request__: Request, + __user__: dict, + __event_emitter__: Callable, + ) -> Tuple[str, Optional[str]]: + """Upload video with status updates and data-URL fallback. 
+ + Returns: + Tuple of (content_url_or_data_url, file_id_or_None) + """ + try: + await __event_emitter__( + { + "type": "status", + "data": { + "action": "video_upload", + "description": "Uploading generated video to your library...", + "done": False, + }, + } + ) + + self.user = user = Users.get_user_by_id(__user__["id"]) + video_url, file_id = self._upload_video( + __request__=__request__, + user=user, + video_data=video_data, + mime_type=mime_type, + ) + + await __event_emitter__( + { + "type": "status", + "data": { + "action": "video_upload", + "description": "Video uploaded successfully!", + "done": True, + }, + } + ) + return video_url, file_id + + except Exception as e: + self.log.warning(f"Video upload failed, falling back to data URL: {e}") + video_data_b64 = base64.b64encode(video_data).decode("utf-8") + await __event_emitter__( + { + "type": "status", + "data": { + "action": "video_upload", + "description": "Using inline video (upload failed)", + "done": True, + }, + } + ) + return f"data:{mime_type};base64,{video_data_b64}", None + + def _get_user_valve_value( + self, __user__: Optional[dict], valve_name: str + ) -> Optional[str]: + """Get a user valve value, returning None if not set or set to 'default'""" + if __user__ and "valves" in __user__: + value = getattr(__user__["valves"], valve_name, None) + if value and value != "default": + return value + return None + + def _configure_generation( + self, + body: Dict[str, Any], + system_instruction: Optional[str], + __metadata__: Dict[str, Any], + __tools__: dict[str, Any] | None = None, + __user__: Optional[dict] = None, + enable_image_generation: bool = False, + model_id: str = "", + ) -> types.GenerateContentConfig: + """ + Configure generation parameters and safety settings. 
    def _configure_generation(
        self,
        body: Dict[str, Any],
        system_instruction: Optional[str],
        __metadata__: Dict[str, Any],
        __tools__: dict[str, Any] | None = None,
        __user__: Optional[dict] = None,
        enable_image_generation: bool = False,
        model_id: str = "",
    ) -> types.GenerateContentConfig:
        """
        Configure generation parameters and safety settings.

        Args:
            body: The request body containing generation parameters
            system_instruction: Optional system instruction string
            __metadata__: Request metadata; ``features`` and ``params`` keys
                drive tool selection below
            __tools__: Optional Open WebUI tool registry for native function calling
            __user__: Optional user dict; per-user valves override system valves
            enable_image_generation: Whether to enable image generation
            model_id: The model ID being used (for feature support checks)

        Returns:
            types.GenerateContentConfig
        """
        # Base sampling parameters; None values are filtered out at the end
        # so SDK defaults apply.
        gen_config_params = {
            "temperature": body.get("temperature"),
            "top_p": body.get("top_p"),
            "top_k": body.get("top_k"),
            "max_output_tokens": body.get("max_tokens"),
            "stop_sequences": body.get("stop") or None,
            "system_instruction": system_instruction,
        }

        # Enable image generation if requested
        if enable_image_generation:
            gen_config_params["response_modalities"] = ["TEXT", "IMAGE"]

        # Configure image generation parameters (aspect ratio and resolution)
        # ImageConfig is only supported by Gemini 3 models
        if self._check_image_config_support(model_id):
            # Body parameters override valve defaults for per-request customization
            # Get aspect_ratio: body > user_valves (if not "default") > system valves
            user_aspect_ratio = self._get_user_valve_value(
                __user__, "IMAGE_GENERATION_ASPECT_RATIO"
            )
            aspect_ratio = body.get(
                "aspect_ratio",
                user_aspect_ratio or self.valves.IMAGE_GENERATION_ASPECT_RATIO,
            )

            # Get resolution: body > user_valves (if not "default") > system valves
            user_resolution = self._get_user_valve_value(
                __user__, "IMAGE_GENERATION_RESOLUTION"
            )
            resolution = body.get(
                "resolution",
                user_resolution or self.valves.IMAGE_GENERATION_RESOLUTION,
            )

            # Validate and normalize the values
            validated_aspect_ratio = self._validate_aspect_ratio(aspect_ratio)
            validated_resolution = self._validate_resolution(resolution)

            # Create image config if we have at least one valid value
            if validated_aspect_ratio or validated_resolution:
                try:
                    image_config_params = {}
                    if validated_aspect_ratio:
                        image_config_params["aspect_ratio"] = validated_aspect_ratio
                    if validated_resolution:
                        image_config_params["image_size"] = validated_resolution
                    gen_config_params["image_config"] = types.ImageConfig(
                        **image_config_params
                    )
                    self.log.debug(
                        f"Image generation config: aspect_ratio={validated_aspect_ratio}, resolution={validated_resolution}"
                    )
                except (AttributeError, TypeError) as e:
                    # Fall back if SDK does not support ImageConfig
                    self.log.warning(
                        f"ImageConfig not supported by SDK version: {e}. Image generation will use default settings."
                    )
                except Exception as e:
                    # Log unexpected errors but continue without image config
                    self.log.warning(
                        f"Unexpected error configuring ImageConfig: {e}"
                    )
        else:
            self.log.debug(
                f"Model {model_id} does not support ImageConfig (aspect_ratio/resolution). "
                "ImageConfig is only available for Gemini 3 image models."
            )

        # Configure Gemini thinking/reasoning for models that support it
        # This is independent of include_thoughts - thinking config controls HOW the model reasons,
        # while include_thoughts controls whether the reasoning is shown in the output
        if self._check_thinking_support(model_id):
            try:
                thinking_config_params: Dict[str, Any] = {}

                # Determine include_thoughts setting
                include_thoughts = body.get("include_thoughts", True)
                if not self.valves.INCLUDE_THOUGHTS:
                    include_thoughts = False
                    self.log.debug(
                        "Thoughts output disabled via GOOGLE_INCLUDE_THOUGHTS"
                    )
                thinking_config_params["include_thoughts"] = include_thoughts

                # Check if model supports thinking_level (Gemini 3 models)
                if self._check_thinking_level_support(model_id):
                    # For Gemini 3 models, use thinking_level (not thinking_budget)
                    # Per-chat reasoning_effort overrides environment-level THINKING_LEVEL
                    reasoning_effort = body.get("reasoning_effort")
                    validated_level = None
                    source = None

                    if reasoning_effort:
                        validated_level = self._validate_thinking_level(
                            reasoning_effort
                        )
                        if validated_level:
                            source = "per-chat reasoning_effort"
                        else:
                            self.log.debug(
                                f"Invalid reasoning_effort '{reasoning_effort}', falling back to THINKING_LEVEL"
                            )

                    # Fall back to environment-level THINKING_LEVEL if no valid reasoning_effort
                    if not validated_level:
                        validated_level = self._validate_thinking_level(
                            self.valves.THINKING_LEVEL
                        )
                        if validated_level:
                            source = "THINKING_LEVEL"

                    if validated_level:
                        thinking_config_params["thinking_level"] = validated_level
                        self.log.debug(
                            f"Using thinking_level='{validated_level}' from {source} for model {model_id}"
                        )
                    else:
                        self.log.debug(
                            f"Using default thinking level for model {model_id}"
                        )
                else:
                    # For non-Gemini 3 models (e.g., Gemini 2.5), use thinking_budget
                    # Body-level thinking_budget overrides environment-level THINKING_BUDGET
                    body_thinking_budget = body.get("thinking_budget")
                    validated_budget = None
                    source = None

                    if body_thinking_budget is not None:
                        validated_budget = self._validate_thinking_budget(
                            body_thinking_budget
                        )
                        if validated_budget is not None:
                            source = "body thinking_budget"
                        else:
                            self.log.debug(
                                f"Invalid body thinking_budget '{body_thinking_budget}', falling back to THINKING_BUDGET"
                            )

                    # Fall back to environment-level THINKING_BUDGET
                    if validated_budget is None:
                        validated_budget = self._validate_thinking_budget(
                            self.valves.THINKING_BUDGET
                        )
                        if validated_budget is not None:
                            source = "THINKING_BUDGET"

                    if validated_budget == 0:
                        # Disable thinking if budget is 0
                        thinking_config_params["thinking_budget"] = 0
                        self.log.debug(
                            f"Thinking disabled via thinking_budget=0 from {source} for model {model_id}"
                        )
                    elif validated_budget is not None and validated_budget > 0:
                        thinking_config_params["thinking_budget"] = validated_budget
                        self.log.debug(
                            f"Using thinking_budget={validated_budget} from {source} for model {model_id}"
                        )
                    else:
                        # -1 or None means dynamic thinking
                        thinking_config_params["thinking_budget"] = -1
                        self.log.debug(
                            f"Using dynamic thinking (model decides) for model {model_id}"
                        )

                gen_config_params["thinking_config"] = types.ThinkingConfig(
                    **thinking_config_params
                )
            except (AttributeError, TypeError) as e:
                # Fall back if SDK/model does not support ThinkingConfig
                self.log.debug(f"ThinkingConfig not supported: {e}")
            except Exception as e:
                # Log unexpected errors but continue without thinking config
                self.log.warning(f"Unexpected error configuring ThinkingConfig: {e}")

        # Configure safety settings
        if self.valves.USE_PERMISSIVE_SAFETY:
            safety_settings = [
                types.SafetySetting(
                    category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"
                ),
                types.SafetySetting(
                    category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"
                ),
                types.SafetySetting(
                    category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"
                ),
                types.SafetySetting(
                    category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"
                ),
            ]
            gen_config_params |= {"safety_settings": safety_settings}

        # Add various tools to Gemini as required
        features = __metadata__.get("features", {})
        params = __metadata__.get("params", {})
        tools = []

        # NOTE(review): URL-context grounding is enabled under the same
        # google_search_tool flag — confirm this coupling is intentional.
        if features.get("google_search_tool", False):
            self.log.debug("Enabling Google search grounding")
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            self.log.debug("Enabling URL context grounding")
            tools.append(types.Tool(url_context=types.UrlContext()))

        if features.get("vertex_ai_search", False) or (
            self.valves.USE_VERTEX_AI
            and (self.valves.VERTEX_AI_RAG_STORE or os.getenv("VERTEX_AI_RAG_STORE"))
        ):
            # Datastore resolution order: request params > valve > environment.
            vertex_rag_store = (
                params.get("vertex_rag_store")
                or self.valves.VERTEX_AI_RAG_STORE
                or os.getenv("VERTEX_AI_RAG_STORE")
            )
            if vertex_rag_store:
                self.log.debug(
                    f"Enabling Vertex AI Search grounding: {vertex_rag_store}"
                )
                tools.append(
                    types.Tool(
                        retrieval=types.Retrieval(
                            vertex_ai_search=types.VertexAISearch(
                                datastore=vertex_rag_store
                            )
                        )
                    )
                )
            else:
                self.log.warning(
                    "Vertex AI Search requested but vertex_rag_store not provided in params, valves, or env"
                )

        # Native function calling: pass Open WebUI tool callables straight to
        # the SDK (underscore-prefixed tool names are treated as private).
        if __tools__ is not None and params.get("function_calling") == "native":
            for name, tool_def in __tools__.items():
                if not name.startswith("_"):
                    tool = tool_def["callable"]
                    self.log.debug(
                        f"Adding tool '{name}' with signature {tool.__signature__}"
                    )
                    tools.append(tool)

        if tools:
            gen_config_params["tools"] = tools

        # Filter out None values for generation config
        filtered_params = {k: v for k, v in gen_config_params.items() if v is not None}
        return types.GenerateContentConfig(**filtered_params)
provided in params, valves, or env" + ) + + if __tools__ is not None and params.get("function_calling") == "native": + for name, tool_def in __tools__.items(): + if not name.startswith("_"): + tool = tool_def["callable"] + self.log.debug( + f"Adding tool '{name}' with signature {tool.__signature__}" + ) + tools.append(tool) + + if tools: + gen_config_params["tools"] = tools + + # Filter out None values for generation config + filtered_params = {k: v for k, v in gen_config_params.items() if v is not None} + return types.GenerateContentConfig(**filtered_params) + + @staticmethod + def _format_grounding_chunks_as_sources( + grounding_chunks: list[types.GroundingChunk], + ): + formatted_sources = [] + for chunk in grounding_chunks: + if hasattr(chunk, "retrieved_context") and chunk.retrieved_context: + context = chunk.retrieved_context + formatted_sources.append( + { + "source": { + "name": getattr(context, "title", None) or "Document", + "type": "vertex_ai_search", + "uri": getattr(context, "uri", None), + }, + "document": [getattr(context, "chunk_text", None) or ""], + "metadata": [ + {"source": getattr(context, "title", None) or "Document"} + ], + } + ) + elif hasattr(chunk, "web") and chunk.web: + context = chunk.web + uri = context.uri + title = context.title or "Source" + + formatted_sources.append( + { + "source": { + "name": title, + "type": "web_search_results", + "url": uri, + }, + "document": ["Click the link to view the content."], + "metadata": [{"source": title}], + } + ) + return formatted_sources + + async def _process_grounding_metadata( + self, + grounding_metadata_list: List[types.GroundingMetadata], + text: str, + __event_emitter__: Callable, + ): + """Process and emit grounding metadata events.""" + grounding_chunks = [] + web_search_queries = [] + grounding_supports = [] + + for metadata in grounding_metadata_list: + if metadata.grounding_chunks: + grounding_chunks.extend(metadata.grounding_chunks) + if metadata.web_search_queries: + 
    async def _process_grounding_metadata(
        self,
        grounding_metadata_list: List[types.GroundingMetadata],
        text: str,
        __event_emitter__: Callable,
    ):
        """Process and emit grounding metadata events.

        Emits source and web-search status events, then splices numeric
        citation markers (e.g. ``[1][2]``) into *text* at the byte offsets
        reported by the API.

        Returns:
            The citation-annotated text, or the original *text* when there
            are no grounding supports.
        """
        grounding_chunks = []
        web_search_queries = []
        grounding_supports = []

        # Flatten the per-chunk metadata into three aggregate lists.
        for metadata in grounding_metadata_list:
            if metadata.grounding_chunks:
                grounding_chunks.extend(metadata.grounding_chunks)
            if metadata.web_search_queries:
                web_search_queries.extend(metadata.web_search_queries)
            if metadata.grounding_supports:
                grounding_supports.extend(metadata.grounding_supports)

        # Add sources to the response
        if grounding_chunks:
            sources = self._format_grounding_chunks_as_sources(grounding_chunks)
            await __event_emitter__(
                {"type": "chat:completion", "data": {"sources": sources}}
            )

        # Add status specifying google queries used for grounding
        if web_search_queries:
            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "This response was grounded with Google Search",
                        "urls": [
                            f"https://www.google.com/search?q={query}"
                            for query in web_search_queries
                        ],
                    },
                }
            )

        # Add citations in the text body
        replaced_text: Optional[str] = None
        if grounding_supports:
            # Citation indexes are in bytes, so the text must be sliced as
            # UTF-8 bytes and decoded per segment — not sliced as str.
            ENCODING = "utf-8"
            text_bytes = text.encode(ENCODING)
            last_byte_index = 0
            cited_chunks = []

            for support in grounding_supports:
                # Segment of answer text up to the end of this support span.
                cited_chunks.append(
                    text_bytes[last_byte_index : support.segment.end_index].decode(
                        ENCODING
                    )
                )

                # Generate and append citations (e.g., "[1][2]")
                footnotes = "".join(
                    [f"[{i + 1}]" for i in support.grounding_chunk_indices]
                )
                cited_chunks.append(f" {footnotes}")

                # Update index for the next segment
                last_byte_index = support.segment.end_index

            # Append any remaining text after the last citation
            if last_byte_index < len(text_bytes):
                cited_chunks.append(text_bytes[last_byte_index:].decode(ENCODING))

            replaced_text = "".join(cited_chunks)

        return replaced_text if replaced_text is not None else text
    async def _handle_streaming_response(
        self,
        response_iterator: Any,
        __event_emitter__: Callable,
        __request__: Optional[Request] = None,
        __user__: Optional[dict] = None,
    ) -> AsyncIterator[Union[str, Dict[str, Any]]]:
        """
        Handle streaming response from Gemini API.

        Consumes the chunk iterator, splitting "thought" parts from answer
        parts, emitting live delta/status events, applying grounding
        citations, and finally yielding usage info and the complete content.

        Args:
            response_iterator: Iterator from generate_content
            __event_emitter__: Event emitter for status updates

        Returns:
            Generator yielding text chunks (and one {"usage": ...} dict)
        """

        async def emit_chat_event(event_type: str, data: Dict[str, Any]) -> None:
            # Defensive wrapper: a failing emitter must not kill the stream.
            if not __event_emitter__:
                return
            try:
                await __event_emitter__({"type": event_type, "data": data})
            except Exception as emit_error:  # pragma: no cover - defensive
                self.log.warning(f"Failed to emit {event_type} event: {emit_error}")

        await emit_chat_event("chat:start", {"role": "assistant"})

        grounding_metadata_list = []
        # Accumulate content separately for answer and thoughts
        answer_chunks: list[str] = []
        thought_chunks: list[str] = []
        thinking_started_at: Optional[float] = None
        stream_usage_metadata = None

        try:
            async for chunk in response_iterator:
                # Capture usage metadata (final chunk has complete data)
                if getattr(chunk, "usage_metadata", None):
                    stream_usage_metadata = chunk.usage_metadata

                # Check for safety feedback or empty chunks
                if not chunk.candidates:
                    # Check prompt feedback
                    if chunk.prompt_feedback and chunk.prompt_feedback.block_reason:
                        block_reason = chunk.prompt_feedback.block_reason.name
                        message = f"[Blocked due to Prompt Safety: {block_reason}]"
                        await emit_chat_event(
                            "chat:finish",
                            {
                                "role": "assistant",
                                "content": message,
                                "done": True,
                                "error": True,
                            },
                        )
                        yield message
                    else:
                        message = "[Blocked by safety settings]"
                        await emit_chat_event(
                            "chat:finish",
                            {
                                "role": "assistant",
                                "content": message,
                                "done": True,
                                "error": True,
                            },
                        )
                        yield message
                    return  # Stop generation

                if chunk.candidates[0].grounding_metadata:
                    grounding_metadata_list.append(
                        chunk.candidates[0].grounding_metadata
                    )

                # Prefer fine-grained parts to split thoughts vs. normal text
                parts = []
                try:
                    parts = chunk.candidates[0].content.parts or []
                except Exception as parts_error:
                    # Fallback: use aggregated text if parts aren't accessible
                    self.log.warning(f"Failed to access content parts: {parts_error}")
                    if hasattr(chunk, "text") and chunk.text:
                        answer_chunks.append(chunk.text)
                        await __event_emitter__(
                            {
                                "type": "chat:message:delta",
                                "data": {
                                    "role": "assistant",
                                    "content": chunk.text,
                                },
                            }
                        )
                    continue

                for part in parts:
                    try:
                        # Thought parts (internal reasoning)
                        if getattr(part, "thought", False) and getattr(
                            part, "text", None
                        ):
                            if thinking_started_at is None:
                                thinking_started_at = time.time()
                            thought_chunks.append(part.text)
                            # Emit a live preview of what is currently being thought
                            preview = part.text.replace("\n", " ").strip()
                            MAX_PREVIEW = 120
                            if len(preview) > MAX_PREVIEW:
                                preview = preview[:MAX_PREVIEW].rstrip() + "…"
                            await __event_emitter__(
                                {
                                    "type": "status",
                                    "data": {
                                        "action": "thinking",
                                        "description": f"Thinking… {preview}",
                                        "done": False,
                                        "hidden": False,
                                    },
                                }
                            )

                        # Regular answer text
                        elif getattr(part, "text", None):
                            answer_chunks.append(part.text)
                            await __event_emitter__(
                                {
                                    "type": "chat:message:delta",
                                    "data": {
                                        "role": "assistant",
                                        "content": part.text,
                                    },
                                }
                            )
                    except Exception as part_error:
                        # Log part processing errors but continue with the stream
                        self.log.warning(f"Error processing content part: {part_error}")
                        continue

            # After processing all chunks, handle grounding data
            final_answer_text = "".join(answer_chunks)
            if grounding_metadata_list and __event_emitter__:
                cited = await self._process_grounding_metadata(
                    grounding_metadata_list,
                    final_answer_text,
                    __event_emitter__,
                )
                final_answer_text = cited or final_answer_text

            final_content = final_answer_text
            details_block: Optional[str] = None

            if thought_chunks:
                duration_s = int(
                    max(0, time.time() - (thinking_started_at or time.time()))
                )
                # Format each line with > for blockquote while preserving formatting
                thought_content = "".join(thought_chunks).strip()
                quoted_lines = []
                for line in thought_content.split("\n"):
                    quoted_lines.append(f"> {line}")
                quoted_content = "\n".join(quoted_lines)

                # NOTE(review): the HTML tags in this template were lost in
                # extraction of this chunk; reconstructed as a standard
                # <details>/<summary> collapsible — confirm against the
                # original source file.
                details_block = f"""<details>
<summary>Thought ({duration_s}s)</summary>

{quoted_content}

</details>""".strip()
                final_content = f"{details_block}{final_answer_text}"

            if not final_content:
                final_content = ""

            # Ensure downstream consumers (UI, TTS) receive the complete response once streaming ends.
            await emit_chat_event(
                "replace", {"role": "assistant", "content": final_content}
            )
            await emit_chat_event(
                "chat:message",
                {"role": "assistant", "content": final_content, "done": True},
            )

            if thought_chunks:
                # Clear the thinking status without a summary in the status emitter
                await __event_emitter__(
                    {
                        "type": "status",
                        "data": {"action": "thinking", "done": True, "hidden": True},
                    }
                )

            # Yield usage data as dict so the middleware can extract and save it to DB
            usage = self._build_usage_dict(stream_usage_metadata)
            if usage:
                yield {"usage": usage}

            await emit_chat_event(
                "chat:finish",
                {"role": "assistant", "content": final_content, "done": True},
            )

            # Yield final content to ensure the async iterator completes properly.
            # This ensures the response is persisted even if the user navigates away.
            yield final_content

        except Exception as e:
            self.log.exception(f"Error during streaming: {e}")
            # Check if it's a chunk size error and provide specific guidance
            error_msg = str(e).lower()
            if "chunk too big" in error_msg or "chunk size" in error_msg:
                message = "Error: Image too large for processing. Please try with a smaller image (max 15 MB recommended) or reduce image quality."
            elif "quota" in error_msg or "rate limit" in error_msg:
                message = "Error: API quota exceeded. Please try again later."
            else:
                message = f"Error during streaming: {e}"
            await emit_chat_event(
                "chat:finish",
                {
                    "role": "assistant",
                    "content": message,
                    "done": True,
                    "error": True,
                },
            )
            yield message
+ else: + message = f"Error during streaming: {e}" + await emit_chat_event( + "chat:finish", + { + "role": "assistant", + "content": message, + "done": True, + "error": True, + }, + ) + yield message + + @staticmethod + def _build_usage_dict(usage_metadata: Any) -> Optional[Dict[str, int]]: + """Extract token usage from Gemini usage_metadata into a standardised dict.""" + if not usage_metadata: + return None + usage: Dict[str, int] = {} + if getattr(usage_metadata, "prompt_token_count", None) is not None: + usage["prompt_tokens"] = usage_metadata.prompt_token_count + if getattr(usage_metadata, "candidates_token_count", None) is not None: + usage["completion_tokens"] = usage_metadata.candidates_token_count + if usage: + usage["total_tokens"] = usage.get("prompt_tokens", 0) + usage.get( + "completion_tokens", 0 + ) + return usage + return None + + def _get_safety_block_message(self, response: Any) -> Optional[str]: + """Check for safety blocks and return appropriate message.""" + # Check prompt feedback + if response.prompt_feedback and response.prompt_feedback.block_reason: + return f"[Blocked due to Prompt Safety: {response.prompt_feedback.block_reason.name}]" + + # Check candidates + if not response.candidates: + return "[Blocked by safety settings or no candidates generated]" + + # Check candidate finish reason + candidate = response.candidates[0] + if candidate.finish_reason == types.FinishReason.SAFETY: + blocking_rating = next( + (r for r in candidate.safety_ratings if r.blocked), None + ) + reason = f" ({blocking_rating.category.name})" if blocking_rating else "" + return f"[Blocked by safety settings{reason}]" + elif candidate.finish_reason == types.FinishReason.PROHIBITED_CONTENT: + return "[Content blocked due to prohibited content policy violation]" + + return None + + async def _generate_video( + self, + body: Dict[str, Any], + model_id: str, + __event_emitter__: Callable, + __request__: Optional[Request] = None, + __user__: Optional[dict] = None, + ) -> 
Union[str, Dict[str, Any]]: + """Generate video using Google Veo models (long-running operation with polling).""" + + async def emit_status(description: str, done: bool) -> None: + if not __event_emitter__: + return + try: + await __event_emitter__( + { + "type": "status", + "data": { + "action": "video_generation", + "description": description, + "done": done, + }, + } + ) + except Exception as e: + self.log.warning(f"Failed to emit video status event: {e}") + + messages = body.get("messages", []) + last_user_msg = next( + (m for m in reversed(messages) if m.get("role") == "user"), None + ) + if not last_user_msg: + return "Error: No user message found for video generation" + + prompt, images = await self._extract_images_from_message(last_user_msg) + if not prompt: + return "Error: No prompt provided for video generation" + + # Convert first attached image to types.Image for image-to-video + reference_image = None + if images: + first_img = images[0] + try: + img_data = first_img.get("inline_data", {}) + raw_data = img_data.get("data", "") + img_bytes = base64.b64decode(raw_data) + reference_image = types.Image( + image_bytes=img_bytes, + mime_type=img_data.get("mime_type", "image/png"), + ) + self.log.debug("Using attached image for image-to-video generation") + except Exception as e: + self.log.warning(f"Failed to convert image for Veo: {e}") + + config = self._build_video_generation_config(body, __user__, model_id=model_id) + + await emit_status(f"Starting video generation with {model_id}...", False) + + client = self._get_client() + try: + generate_kwargs: Dict[str, Any] = { + "model": model_id, + "prompt": prompt, + "config": config, + } + if reference_image: + generate_kwargs["image"] = reference_image + operation = await client.aio.models.generate_videos(**generate_kwargs) + except Exception as e: + self.log.exception(f"Video generation request failed: {e}") + await emit_status(f"Video generation failed: {e}", True) + return f"Error starting video 
generation: {e}" + + poll_interval = max(self.valves.VIDEO_POLL_INTERVAL, 5) + poll_timeout = max(self.valves.VIDEO_POLL_TIMEOUT, 0) + elapsed = 0 + while not operation.done: + await asyncio.sleep(poll_interval) + elapsed += poll_interval + if poll_timeout > 0 and elapsed >= poll_timeout: + error_msg = ( + f"Video generation timed out after {elapsed}s " + f"(limit: {poll_timeout}s)" + ) + self.log.error(error_msg) + await emit_status(error_msg, True) + return f"Error: {error_msg}" + try: + operation = await client.aio.operations.get(operation) + except Exception as e: + self.log.warning(f"Polling error (will retry): {e}") + await emit_status(f"Generating video... ({elapsed}s elapsed)", False) + + if operation.error: + error_msg = str(operation.error) + self.log.error(f"Video generation failed: {error_msg}") + await emit_status(f"Video generation failed: {error_msg}", True) + return f"Video generation failed: {error_msg}" + + generated_videos = [] + response = operation.response + if not response or not response.generated_videos: + return "Error: No videos were generated" + + for idx, gen_video in enumerate(response.generated_videos): + video = gen_video.video + if not video: + self.log.warning(f"Video {idx}: no video object in response") + continue + + self.log.debug( + f"Video {idx}: uri={getattr(video, 'uri', None)}, " + f"name={getattr(video, 'name', None)}, " + f"has_bytes={bool(getattr(video, 'video_bytes', None))}" + ) + + video_bytes = None + if getattr(video, "video_bytes", None): + video_bytes = video.video_bytes + + # Download video bytes via SDK (sync version is more reliable) + if not video_bytes: + try: + await asyncio.to_thread(client.files.download, file=video) + video_bytes = getattr(video, "video_bytes", None) + self.log.debug( + f"Video {idx}: SDK download complete, " + f"has_bytes={bool(video_bytes)}" + ) + except Exception as dl_err: + self.log.warning(f"Video {idx} SDK download failed: {dl_err}") + + # Fallback: save to temp file via SDK + if 
not video_bytes: + tmp_path = None + try: + import tempfile + + with tempfile.NamedTemporaryFile( + suffix=".mp4", delete=False + ) as tmp: + tmp_path = tmp.name + await asyncio.to_thread(video.save, tmp_path) + async with aiofiles.open(tmp_path, "rb") as f: + video_bytes = await f.read() + self.log.debug( + f"Video {idx}: temp-file download complete, " + f"size={len(video_bytes)} bytes" + ) + except Exception as save_err: + self.log.warning(f"Video {idx} temp-file save failed: {save_err}") + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + + if not video_bytes: + self.log.warning(f"Video {idx}: could not obtain video bytes") + continue + + mime_type = getattr(video, "mime_type", "video/mp4") or "video/mp4" + + file_id = None + video_url = None + if __request__ and __user__: + video_url, file_id = await self._upload_video_with_status( + video_bytes, mime_type, __request__, __user__, __event_emitter__ + ) + else: + video_data_b64 = base64.b64encode(video_bytes).decode("utf-8") + video_url = f"data:{mime_type};base64,{video_data_b64}" + + # Wrap in
so marked.lexer recognizes it as block-level HTML; + # HTMLToken.svelte then detects the inner