diff --git a/src/crates/core/src/agentic/agents/agentic_mode.rs b/src/crates/core/src/agentic/agents/agentic_mode.rs
index dc82dbed..9510e811 100644
--- a/src/crates/core/src/agentic/agents/agentic_mode.rs
+++ b/src/crates/core/src/agentic/agents/agentic_mode.rs
@@ -55,6 +55,15 @@ impl Agent for AgenticMode {
         "agentic_mode"
     }
 
+    fn prompt_template_name_for_model(&self, model_name: Option<&str>) -> Option<&str> {
+        let model_name = model_name?.trim().to_ascii_lowercase();
+        if model_name.contains("gpt-5") {
+            Some("agentic_mode_gpt5")
+        } else {
+            None
+        }
+    }
+
     fn default_tools(&self) -> Vec<String> {
         self.default_tools.clone()
     }
@@ -63,3 +72,31 @@ impl Agent for AgenticMode {
         false
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::{Agent, AgenticMode};
+
+    #[test]
+    fn selects_gpt5_prompt_template() {
+        let agent = AgenticMode::new();
+        assert_eq!(
+            agent.prompt_template_name_for_model(Some("gpt-5.1")),
+            Some("agentic_mode_gpt5")
+        );
+        assert_eq!(
+            agent.prompt_template_name_for_model(Some("GPT-5-CODEX")),
+            Some("agentic_mode_gpt5")
+        );
+    }
+
+    #[test]
+    fn keeps_default_template_for_non_gpt5_models() {
+        let agent = AgenticMode::new();
+        assert_eq!(
+            agent.prompt_template_name_for_model(Some("claude-sonnet-4")),
+            None
+        );
+        assert_eq!(agent.prompt_template_name_for_model(None), None);
+    }
+}
diff --git a/src/crates/core/src/agentic/agents/mod.rs b/src/crates/core/src/agentic/agents/mod.rs
index c75626ae..3cecf41d 100644
--- a/src/crates/core/src/agentic/agents/mod.rs
+++ b/src/crates/core/src/agentic/agents/mod.rs
@@ -57,6 +57,11 @@ pub trait Agent: Send + Sync + 'static {
     /// Prompt template name for the agent
     fn prompt_template_name(&self) -> &str;
 
+    /// Prompt template name override for a specific model.
+    fn prompt_template_name_for_model(&self, _model_name: Option<&str>) -> Option<&str> {
+        None
+    }
+
     fn system_reminder_template_name(&self) -> Option<&str> {
         None // by default, no system reminder
     }
@@ -89,6 +94,30 @@ pub trait Agent: Send + Sync + 'static {
         }
     }
 
+    /// Get the system prompt for this agent with optional model-aware template selection.
+    async fn get_system_prompt_for_model(
+        &self,
+        workspace_path: Option<&str>,
+        model_name: Option<&str>,
+    ) -> BitFunResult<String> {
+        let Some(workspace_path) = workspace_path else {
+            return Err(BitFunError::Agent("Workspace path is required".to_string()));
+        };
+
+        let Some(template_name) = self.prompt_template_name_for_model(model_name) else {
+            return self.build_prompt(workspace_path).await;
+        };
+
+        let prompt_components = PromptBuilder::new(workspace_path);
+        let system_prompt_template = get_embedded_prompt(template_name).ok_or_else(|| {
+            BitFunError::Agent(format!("{} not found in embedded files", template_name))
+        })?;
+
+        prompt_components
+            .build_prompt_from_template(system_prompt_template)
+            .await
+    }
+
     /// Get the system reminder for this agent, only used for modes
     /// system_reminder will be appended to the user_query
     /// This is not necessary for all modes
diff --git a/src/crates/core/src/agentic/agents/prompts/agentic_mode_gpt5.md b/src/crates/core/src/agentic/agents/prompts/agentic_mode_gpt5.md
new file mode 100644
index 00000000..831fad2a
--- /dev/null
+++ b/src/crates/core/src/agentic/agents/prompts/agentic_mode_gpt5.md
@@ -0,0 +1,71 @@
+You are BitFun, an ADE (AI IDE) that helps users with software engineering tasks.
+
+You are pair programming with a USER. Each user message may include extra IDE context, such as open files, cursor position, recent files, edit history, or linter errors.
Use what is relevant and ignore what is not. + +Follow the USER's instructions in each message, denoted by the tag. + +Tool results and user messages may include tags. Follow them, but do not mention them to the user. + +IMPORTANT: Assist with defensive security tasks only. Refuse to create, modify, or improve code that may be used maliciously. Do not assist with credential discovery or harvesting, including bulk crawling for SSH keys, browser cookies, or cryptocurrency wallets. Allow security analysis, detection rules, vulnerability explanations, defensive tools, and security documentation. + +IMPORTANT: Never generate or guess URLs for the user unless you are confident they directly help with the programming task. You may use URLs provided by the user or found in local files. + +{LANGUAGE_PREFERENCE} +{VISUAL_MODE} + +# Behavior +- Be concise, direct, and action-oriented. +- Default to doing the work instead of discussing it. +- Read relevant code before editing it. +- Prioritize technical accuracy over agreement. +- Never give time estimates. + +# Editing +- Prefer editing existing files over creating new ones. +- Default to ASCII unless the file already uses non-ASCII and there is a clear reason. +- Add comments only when needed for non-obvious logic. +- Avoid unrelated refactors, speculative abstractions, and unnecessary compatibility shims. +- Do not add features or improvements beyond the request unless required to make the requested change work. +- Do not introduce security issues such as command injection, XSS, SQL injection, path traversal, or unsafe shell handling. + +# Tools +- Use TodoWrite for non-trivial or multi-step tasks, and keep it updated. +- Use AskUserQuestion only when a decision materially changes the result and cannot be inferred safely. +- Prefer Task with Explore or FileFinder for open-ended codebase exploration. +- Prefer Read, Grep, and Glob for targeted lookups. +- Prefer specialized file tools over Bash for reading and editing files. +- Use Bash for builds, tests, git, and scripts. +- Run independent tool calls in parallel when possible. +- Do not use tools to communicate with the user. + +# Questions +- Ask only when you are truly blocked and cannot safely choose a reasonable default. +- If you must ask, do all non-blocked work first, then ask exactly one targeted question with a recommended default. + +# Workspace +- Never revert user changes unless explicitly requested. +- Work with existing changes in touched files instead of discarding them. +- Do not amend commits unless explicitly requested. +- Never use destructive commands like git reset --hard or git checkout -- unless explicitly requested or approved. + +# Responses +- Keep responses short, useful, and technically precise. +- Avoid unnecessary praise, emotional validation, or emojis. +- Summarize meaningful command results instead of pasting raw output. +- Do not tell the user to save or copy files. + +# Code references +- Use clickable markdown links for files and code locations. +- Use bare filenames as link text. +- Use workspace-relative paths for workspace files and absolute paths otherwise. 
+ +Examples: +- [filename.ts](src/filename.ts) +- [filename.ts:42](src/filename.ts#L42) +- [filename.ts:42-51](src/filename.ts#L42-L51) + +{ENV_INFO} +{PROJECT_LAYOUT} +{RULES} +{MEMORIES} +{PROJECT_CONTEXT_FILES:exclude=review} \ No newline at end of file diff --git a/src/crates/core/src/agentic/execution/execution_engine.rs b/src/crates/core/src/agentic/execution/execution_engine.rs index 26d50b11..13d717d8 100644 --- a/src/crates/core/src/agentic/execution/execution_engine.rs +++ b/src/crates/core/src/agentic/execution/execution_engine.rs @@ -391,16 +391,51 @@ impl ExecutionEngine { current_agent.id() ); - // 2. Get System Prompt from current Agent + // 2. Get AI client + // Get model ID from AgentRegistry + let model_id = agent_registry + .get_model_id_for_agent(&agent_type) + .await + .map_err(|e| BitFunError::AIClient(format!("Failed to get model ID: {}", e)))?; + info!( + "Agent using model: agent={}, model_id={}", + current_agent.name(), + model_id + ); + + let ai_client_factory = get_global_ai_client_factory().await.map_err(|e| { + BitFunError::AIClient(format!("Failed to get AI client factory: {}", e)) + })?; + + // Get AI client by model ID + let ai_client = ai_client_factory + .get_client_resolved(&model_id) + .await + .map_err(|e| { + BitFunError::AIClient(format!( + "Failed to get AI client (model_id={}): {}", + model_id, e + )) + })?; + // Get configuration for whether to support preserving historical thinking content + let enable_thinking = ai_client.config.enable_thinking_process; + let support_preserved_thinking = ai_client.config.support_preserved_thinking; + let context_window = ai_client.config.context_window as usize; + + // 3. Get System Prompt from current Agent debug!( - "Building system prompt from agent: {}", - current_agent.name() + "Building system prompt from agent: {}, model={}", + current_agent.name(), + ai_client.config.model ); let system_prompt = { let workspace_path = get_workspace_path(); let workspace_str = workspace_path.as_ref().map(|p| p.display().to_string()); current_agent - .get_system_prompt(workspace_str.as_deref()) + .get_system_prompt_for_model( + workspace_str.as_deref(), + Some(ai_client.config.model.as_str()), + ) .await? }; debug!("System prompt built, length: {} bytes", system_prompt.len()); @@ -436,7 +471,7 @@ impl ExecutionEngine { .collect::>() ); - // 3. Get available tools list (read tool configuration for current mode from global config) + // 4. Get available tools list (read tool configuration for current mode from global config) let allowed_tools = agent_registry.get_agent_tools(&agent_type).await; let enable_tools = context .context @@ -465,37 +500,6 @@ impl ExecutionEngine { let enable_context_compression = session.config.enable_context_compression; let compression_threshold = session.config.compression_threshold; - // 4. 
Get AI client - // Get model ID from AgentRegistry - let model_id = agent_registry - .get_model_id_for_agent(&agent_type) - .await - .map_err(|e| BitFunError::AIClient(format!("Failed to get model ID: {}", e)))?; - info!( - "Agent using model: agent={}, model_id={}", - current_agent.name(), - model_id - ); - - let ai_client_factory = get_global_ai_client_factory().await.map_err(|e| { - BitFunError::AIClient(format!("Failed to get AI client factory: {}", e)) - })?; - - // Get AI client by model ID - let ai_client = ai_client_factory - .get_client_resolved(&model_id) - .await - .map_err(|e| { - BitFunError::AIClient(format!( - "Failed to get AI client (model_id={}): {}", - model_id, e - )) - })?; - // Get configuration for whether to support preserving historical thinking content - let enable_thinking = ai_client.config.enable_thinking_process; - let support_preserved_thinking = ai_client.config.support_preserved_thinking; - let context_window = ai_client.config.context_window as usize; - // Detect whether the primary model supports multimodal image inputs. // This is used by tools like `view_image` to decide between: // - attaching image content for the primary model to analyze directly, or diff --git a/src/crates/core/src/agentic/execution/round_executor.rs b/src/crates/core/src/agentic/execution/round_executor.rs index b136573a..09cd6dc9 100644 --- a/src/crates/core/src/agentic/execution/round_executor.rs +++ b/src/crates/core/src/agentic/execution/round_executor.rs @@ -308,6 +308,7 @@ impl RoundExecutor { has_more_rounds: false, finish_reason: FinishReason::Complete, usage: stream_result.usage.clone(), + provider_metadata: stream_result.provider_metadata.clone(), }); } @@ -525,6 +526,7 @@ impl RoundExecutor { FinishReason::Complete }, usage: stream_result.usage.clone(), + provider_metadata: stream_result.provider_metadata.clone(), }) } diff --git a/src/crates/core/src/agentic/execution/stream_processor.rs b/src/crates/core/src/agentic/execution/stream_processor.rs index 4acf7c2e..05c589cc 100644 --- a/src/crates/core/src/agentic/execution/stream_processor.rs +++ b/src/crates/core/src/agentic/execution/stream_processor.rs @@ -15,7 +15,7 @@ use crate::util::JsonChecker; use ai_stream_handlers::UnifiedResponse; use futures::StreamExt; use log::{debug, error, trace}; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashSet; use std::sync::Arc; use tokio::sync::mpsc; @@ -173,6 +173,8 @@ pub struct StreamResult { pub tool_calls: Vec, /// Token usage statistics (from model response) pub usage: Option, + /// Provider-specific metadata captured from the stream tail. 
+ pub provider_metadata: Option, /// Whether this stream produced any user-visible output (text/thinking/tool events) pub has_effective_output: bool, } @@ -208,6 +210,7 @@ struct StreamContext { full_text: String, tool_calls: Vec, usage: Option, + provider_metadata: Option, // Current tool call state tool_call_buffer: ToolCallBuffer, @@ -239,6 +242,7 @@ impl StreamContext { full_text: String::new(), tool_calls: Vec::new(), usage: None, + provider_metadata: None, tool_call_buffer: ToolCallBuffer::new(), text_chunks_count: 0, thinking_chunks_count: 0, @@ -255,6 +259,7 @@ impl StreamContext { full_text: self.full_text, tool_calls: self.tool_calls, usage: self.usage, + provider_metadata: self.provider_metadata, has_effective_output: self.has_effective_output, } } @@ -282,6 +287,20 @@ impl StreamProcessor { Self { event_queue } } + fn merge_json_value(target: &mut Value, overlay: Value) { + match (target, overlay) { + (Value::Object(target_map), Value::Object(overlay_map)) => { + for (key, value) in overlay_map { + let entry = target_map.entry(key).or_insert(Value::Null); + Self::merge_json_value(entry, value); + } + } + (target_slot, overlay_value) => { + *target_slot = overlay_value; + } + } + } + // ==================== Helper Methods ==================== /// Send thinking end marker (if needed) @@ -433,6 +452,7 @@ impl StreamProcessor { prompt_token_count: response_usage.prompt_token_count, candidates_token_count: response_usage.candidates_token_count, total_token_count: response_usage.total_token_count, + reasoning_token_count: response_usage.reasoning_token_count, cached_content_token_count: response_usage.cached_content_token_count, }); debug!( @@ -453,32 +473,39 @@ impl StreamProcessor { if let Some(tool_id) = tool_call.id { if !tool_id.is_empty() { ctx.has_effective_output = true; - // Clear previous tool_call state - ctx.force_finish_tool_call_buffer(); - - // Normally tool_name should not be empty - let tool_name = tool_call.name.unwrap_or_default(); - debug!("Tool detected: {}", tool_name); - ctx.tool_call_buffer.tool_id = tool_id.clone(); - ctx.tool_call_buffer.tool_name = tool_name.clone(); - ctx.tool_call_buffer.json_checker.reset(); - - // Send early detection event - let _ = self - .event_queue - .enqueue( - AgenticEvent::ToolEvent { - session_id: ctx.session_id.clone(), - turn_id: ctx.dialog_turn_id.clone(), - tool_event: ToolEventData::EarlyDetected { - tool_id: tool_id, - tool_name: tool_name, + // Some providers repeat the tool id on every delta; only treat a new id as a new tool call. 
+ let is_new_tool = ctx.tool_call_buffer.tool_id != tool_id; + if is_new_tool { + // Clear previous tool_call state + ctx.force_finish_tool_call_buffer(); + + // Normally tool_name should not be empty + let tool_name = tool_call.name.unwrap_or_default(); + debug!("Tool detected: {}", tool_name); + ctx.tool_call_buffer.tool_id = tool_id.clone(); + ctx.tool_call_buffer.tool_name = tool_name.clone(); + ctx.tool_call_buffer.json_checker.reset(); + + // Send early detection event + let _ = self + .event_queue + .enqueue( + AgenticEvent::ToolEvent { + session_id: ctx.session_id.clone(), + turn_id: ctx.dialog_turn_id.clone(), + tool_event: ToolEventData::EarlyDetected { + tool_id: tool_id, + tool_name: tool_name, + }, + subagent_parent_info: ctx.event_subagent_parent_info.clone(), }, - subagent_parent_info: ctx.event_subagent_parent_info.clone(), - }, - Some(EventPriority::Normal), - ) - .await; + Some(EventPriority::Normal), + ) + .await; + } else if ctx.tool_call_buffer.tool_name.is_empty() { + // Best-effort: keep name if provider repeats it. + ctx.tool_call_buffer.tool_name = tool_call.name.unwrap_or_default(); + } } } @@ -732,6 +759,13 @@ impl StreamProcessor { self.handle_usage(&mut ctx, response_usage); } + if let Some(provider_metadata) = response.provider_metadata { + match ctx.provider_metadata.as_mut() { + Some(existing) => Self::merge_json_value(existing, provider_metadata), + None => ctx.provider_metadata = Some(provider_metadata), + } + } + // Handle thinking_signature if let Some(signature) = response.thinking_signature { if !signature.is_empty() { diff --git a/src/crates/core/src/agentic/execution/types.rs b/src/crates/core/src/agentic/execution/types.rs index d4fd58ec..026e6612 100644 --- a/src/crates/core/src/agentic/execution/types.rs +++ b/src/crates/core/src/agentic/execution/types.rs @@ -2,6 +2,7 @@ use crate::agentic::core::Message; use crate::agentic::tools::pipeline::SubagentParentInfo; +use serde_json::Value; use std::collections::HashMap; use tokio_util::sync::CancellationToken; @@ -43,6 +44,8 @@ pub struct RoundResult { pub finish_reason: FinishReason, /// Token usage statistics (from model response) pub usage: Option, + /// Provider-specific metadata returned by the model. 
+ pub provider_metadata: Option, } /// Finish reason diff --git a/src/crates/core/src/agentic/image_analysis/types.rs b/src/crates/core/src/agentic/image_analysis/types.rs index 2037495d..f1d520c4 100644 --- a/src/crates/core/src/agentic/image_analysis/types.rs +++ b/src/crates/core/src/agentic/image_analysis/types.rs @@ -117,7 +117,7 @@ impl ImageLimits { /// Get limits based on model provider pub fn for_provider(provider: &str) -> Self { match provider.to_lowercase().as_str() { - "openai" => Self { + "openai" | "response" | "responses" => Self { max_size: 20 * 1024 * 1024, // 20MB max_width: 2048, max_height: 2048, diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs index aa8ed3b7..d10cee11 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs @@ -2,5 +2,7 @@ mod stream_handler; mod types; pub use stream_handler::handle_anthropic_stream; +pub use stream_handler::handle_gemini_stream; pub use stream_handler::handle_openai_stream; +pub use stream_handler::handle_responses_stream; pub use types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs new file mode 100644 index 00000000..395ea7d8 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs @@ -0,0 +1,248 @@ +use crate::types::gemini::GeminiSSEData; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use serde_json::Value; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +static GEMINI_STREAM_ID_SEQ: AtomicU64 = AtomicU64::new(1); + +#[derive(Debug)] +struct GeminiToolCallState { + active_name: Option, + active_id: Option, + stream_id: u64, + next_index: usize, +} + +impl GeminiToolCallState { + fn new() -> Self { + Self { + active_name: None, + active_id: None, + stream_id: GEMINI_STREAM_ID_SEQ.fetch_add(1, Ordering::Relaxed), + next_index: 0, + } + } + + fn on_non_tool_response(&mut self) { + self.active_name = None; + self.active_id = None; + } + + fn assign_id(&mut self, tool_call: &mut crate::types::unified::UnifiedToolCall) { + if let Some(existing_id) = tool_call.id.as_ref().filter(|value| !value.is_empty()) { + self.active_id = Some(existing_id.clone()); + self.active_name = tool_call.name.clone().filter(|value| !value.is_empty()); + return; + } + + let tool_name = tool_call.name.clone().filter(|value| !value.is_empty()); + let is_same_active_call = self.active_id.is_some() && self.active_name == tool_name; + + if is_same_active_call { + tool_call.id = None; + return; + } + + self.next_index += 1; + let generated_id = format!("gemini_call_{}_{}", self.stream_id, self.next_index); + tool_call.id = Some(generated_id.clone()); + self.active_id = Some(generated_id); + self.active_name = tool_name; + } +} + +fn extract_api_error_message(event_json: &Value) -> Option { + let error = event_json.get("error")?; + if let Some(message) = error.get("message").and_then(Value::as_str) { + return Some(message.to_string()); + } + if let Some(message) = error.as_str() { + return 
Some(message.to_string());
+    }
+    Some("Gemini streaming request failed".to_string())
+}
+
+pub async fn handle_gemini_stream(
+    response: Response,
+    tx_event: mpsc::UnboundedSender<Result<UnifiedResponse>>,
+    tx_raw_sse: Option<mpsc::UnboundedSender<String>>,
+) {
+    let mut stream = response.bytes_stream().eventsource();
+    let idle_timeout = Duration::from_secs(600);
+    let mut received_finish_reason = false;
+    let mut tool_call_state = GeminiToolCallState::new();
+
+    loop {
+        let sse_event = timeout(idle_timeout, stream.next()).await;
+        let sse = match sse_event {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(None) => {
+                if received_finish_reason {
+                    return;
+                }
+                let error_msg = "Gemini SSE stream closed before response completed";
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Ok(Some(Err(e))) => {
+                let error_msg = format!("Gemini SSE stream error: {}", e);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Err(_) => {
+                let error_msg = format!(
+                    "Gemini SSE stream timeout after {}s",
+                    idle_timeout.as_secs()
+                );
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        let raw = sse.data;
+        trace!("Gemini SSE: {:?}", raw);
+
+        if let Some(ref tx) = tx_raw_sse {
+            let _ = tx.send(raw.clone());
+        }
+
+        if raw == "[DONE]" {
+            return;
+        }
+
+        let event_json: Value = match serde_json::from_str(&raw) {
+            Ok(json) => json,
+            Err(e) => {
+                let error_msg = format!("Gemini SSE parsing error: {}, data: {}", e, raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        if let Some(message) = extract_api_error_message(&event_json) {
+            let error_msg = format!("Gemini SSE API error: {}, data: {}", message, raw);
+            error!("{}", error_msg);
+            let _ = tx_event.send(Err(anyhow!(error_msg)));
+            return;
+        }
+
+        let sse_data: GeminiSSEData = match serde_json::from_value(event_json) {
+            Ok(data) => data,
+            Err(e) => {
+                let error_msg = format!("Gemini SSE data schema error: {}, data: {}", e, raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        let mut unified_responses = sse_data.into_unified_responses();
+        for unified_response in &mut unified_responses {
+            if let Some(tool_call) = unified_response.tool_call.as_mut() {
+                tool_call_state.assign_id(tool_call);
+            } else {
+                tool_call_state.on_non_tool_response();
+            }
+
+            if unified_response.finish_reason.is_some() {
+                received_finish_reason = true;
+                tool_call_state.on_non_tool_response();
+            }
+        }
+
+        for unified_response in unified_responses {
+            let _ = tx_event.send(Ok(unified_response));
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::GeminiToolCallState;
+    use crate::types::unified::UnifiedToolCall;
+
+    #[test]
+    fn reuses_active_tool_id_by_omitting_follow_up_ids() {
+        let mut state = GeminiToolCallState::new();
+
+        let mut first = UnifiedToolCall {
+            id: None,
+            name: Some("get_weather".to_string()),
+            arguments: Some("{\"city\":".to_string()),
+        };
+        state.assign_id(&mut first);
+
+        let mut second = UnifiedToolCall {
+            id: None,
+            name: Some("get_weather".to_string()),
+            arguments: Some("\"Paris\"}".to_string()),
+        };
+        state.assign_id(&mut second);
+
+        assert!(first
+            .id
+            .as_deref()
+            .is_some_and(|id| id.starts_with("gemini_call_")));
+        assert!(second.id.is_none());
+    }
+
+    #[test]
+    fn clears_active_tool_after_non_tool_response() {
+        let mut state = GeminiToolCallState::new();
+
+        let mut first = UnifiedToolCall {
+            id: None,
+            name: Some("get_weather".to_string()),
+            arguments: Some("{}".to_string()),
+        };
+        state.assign_id(&mut first);
+        state.on_non_tool_response();
+
+        let mut second = UnifiedToolCall {
+            id: None,
+            name: Some("get_weather".to_string()),
+            arguments: Some("{}".to_string()),
+        };
+        state.assign_id(&mut second);
+
+        let first_id = first.id.expect("first id");
+        let second_id = second.id.expect("second id");
+        assert!(first_id.starts_with("gemini_call_"));
+        assert!(second_id.starts_with("gemini_call_"));
+        assert_ne!(first_id, second_id);
+    }
+
+    #[test]
+    fn generates_unique_prefixes_across_streams() {
+        let mut first_state = GeminiToolCallState::new();
+        let mut second_state = GeminiToolCallState::new();
+
+        let mut first = UnifiedToolCall {
+            id: None,
+            name: Some("grep".to_string()),
+            arguments: Some("{}".to_string()),
+        };
+        let mut second = UnifiedToolCall {
+            id: None,
+            name: Some("read".to_string()),
+            arguments: Some("{}".to_string()),
+        };
+
+        first_state.assign_id(&mut first);
+        second_state.assign_id(&mut second);
+
+        assert_ne!(first.id, second.id);
+    }
+}
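The handler above is transport-only: it owns no HTTP state and reports everything, including errors, through the channel. A minimal consumption sketch, assuming the caller has already issued the streaming request (the function and variable names here are illustrative, not part of this patch):

```rust
// Hypothetical caller: spawn the handler, then drain unified events.
use tokio::sync::mpsc;

async fn collect_gemini_text(response: reqwest::Response) -> anyhow::Result<String> {
    let (tx_event, mut rx_event) = mpsc::unbounded_channel();
    // No raw-SSE tap here; pass Some(sender) instead of None to mirror frames for logging.
    tokio::spawn(handle_gemini_stream(response, tx_event, None));

    let mut full_text = String::new();
    while let Some(event) = rx_event.recv().await {
        let unified = event?; // stream-level failures arrive as Err items
        if let Some(text) = unified.text {
            full_text.push_str(&text);
        }
        if unified.finish_reason.is_some() {
            break;
        }
    }
    Ok(full_text)
}
```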
Some("{}".to_string()), + }; + state.assign_id(&mut first); + state.on_non_tool_response(); + + let mut second = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{}".to_string()), + }; + state.assign_id(&mut second); + + let first_id = first.id.expect("first id"); + let second_id = second.id.expect("second id"); + assert!(first_id.starts_with("gemini_call_")); + assert!(second_id.starts_with("gemini_call_")); + assert_ne!(first_id, second_id); + } + + #[test] + fn generates_unique_prefixes_across_streams() { + let mut first_state = GeminiToolCallState::new(); + let mut second_state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("grep".to_string()), + arguments: Some("{}".to_string()), + }; + let mut second = UnifiedToolCall { + id: None, + name: Some("read".to_string()), + arguments: Some("{}".to_string()), + }; + + first_state.assign_id(&mut first); + second_state.assign_id(&mut second); + + assert_ne!(first.id, second.id); + } +} diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs index a3f2f220..24e2938a 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs @@ -1,5 +1,9 @@ mod openai; mod anthropic; +mod responses; +mod gemini; pub use openai::handle_openai_stream; -pub use anthropic::handle_anthropic_stream; \ No newline at end of file +pub use anthropic::handle_anthropic_stream; +pub use responses::handle_responses_stream; +pub use gemini::handle_gemini_stream; diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs new file mode 100644 index 00000000..ec2f28ce --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs @@ -0,0 +1,548 @@ +use crate::types::responses::{ + parse_responses_output_item, ResponsesCompleted, ResponsesDone, ResponsesStreamEvent, +}; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use serde_json::Value; +use std::collections::HashMap; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +#[derive(Debug, Default, Clone)] +struct InProgressToolCall { + call_id: Option, + name: Option, + args_so_far: String, + saw_any_delta: bool, + sent_header: bool, +} + +impl InProgressToolCall { + fn from_item_value(item: &Value) -> Option { + if item.get("type").and_then(Value::as_str) != Some("function_call") { + return None; + } + Some(Self { + call_id: item + .get("call_id") + .and_then(Value::as_str) + .map(ToString::to_string), + name: item + .get("name") + .and_then(Value::as_str) + .map(ToString::to_string), + args_so_far: String::new(), + saw_any_delta: false, + sent_header: false, + }) + } +} + +fn emit_tool_call_item( + tx_event: &mpsc::UnboundedSender>, + item_value: Value, +) { + if let Some(unified_response) = parse_responses_output_item(item_value) { + if unified_response.tool_call.is_some() { + let _ = tx_event.send(Ok(unified_response)); + } + } +} + +fn cleanup_tool_call_tracking( + output_index: usize, + tool_calls_by_output_index: &mut HashMap, + 
tool_call_index_by_id: &mut HashMap, +) { + if let Some(tc) = tool_calls_by_output_index.remove(&output_index) { + if let Some(call_id) = tc.call_id { + tool_call_index_by_id.remove(&call_id); + } + } +} + +fn handle_function_call_output_item_done( + tx_event: &mpsc::UnboundedSender>, + event_output_index: Option, + item_value: Value, + tool_calls_by_output_index: &mut HashMap, + tool_call_index_by_id: &mut HashMap, +) { + // Resolve output_index either directly or via call_id mapping. + let output_index = event_output_index.or_else(|| { + item_value + .get("call_id") + .and_then(Value::as_str) + .and_then(|id| tool_call_index_by_id.get(id).copied()) + }); + + let Some(output_index) = output_index else { + emit_tool_call_item(tx_event, item_value); + return; + }; + + let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else { + // The provider may send `output_item.done` with an output_index even when the + // earlier `output_item.added` event was omitted or missed. Fall back to the full item. + emit_tool_call_item(tx_event, item_value); + return; + }; + + let full_args = item_value + .get("arguments") + .and_then(Value::as_str) + .unwrap_or_default(); + let need_fallback_full = !tc.saw_any_delta; + let need_tail = + tc.saw_any_delta && tc.args_so_far.len() < full_args.len() && full_args.starts_with(&tc.args_so_far); + + if need_fallback_full || need_tail { + let delta = if need_fallback_full { + full_args.to_string() + } else { + full_args[tc.args_so_far.len()..].to_string() + }; + + if !delta.is_empty() { + tc.args_so_far.push_str(&delta); + let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + let _ = tx_event.send(Ok(UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + })); + } + } + + cleanup_tool_call_tracking( + output_index, + tool_calls_by_output_index, + tool_call_index_by_id, + ); +} + +fn extract_api_error_message(event_json: &Value) -> Option { + let response = event_json.get("response")?; + let error = response.get("error")?; + + if error.is_null() { + return None; + } + + if let Some(message) = error.get("message").and_then(Value::as_str) { + return Some(message.to_string()); + } + if let Some(message) = error.as_str() { + return Some(message.to_string()); + } + + Some("An error occurred during responses streaming".to_string()) +} + +pub async fn handle_responses_stream( + response: Response, + tx_event: mpsc::UnboundedSender>, + tx_raw_sse: Option>, +) { + let mut stream = response.bytes_stream().eventsource(); + let idle_timeout = Duration::from_secs(600); + // Some providers close the stream after emitting the terminal event and may not send `[DONE]`. 
+    let mut received_finish_reason = false;
+    let mut received_text_delta = false;
+    let mut tool_calls_by_output_index: HashMap<usize, InProgressToolCall> = HashMap::new();
+    let mut tool_call_index_by_id: HashMap<String, usize> = HashMap::new();
+
+    loop {
+        let sse_event = timeout(idle_timeout, stream.next()).await;
+        let sse = match sse_event {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(None) => {
+                if received_finish_reason {
+                    return;
+                }
+                let error_msg = "Responses SSE stream closed before response completed";
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Ok(Some(Err(e))) => {
+                let error_msg = format!("Responses SSE stream error: {}", e);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Err(_) => {
+                let error_msg = format!(
+                    "Responses SSE stream timeout after {}s",
+                    idle_timeout.as_secs()
+                );
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        let raw = sse.data;
+        trace!("Responses SSE: {:?}", raw);
+        if let Some(ref tx) = tx_raw_sse {
+            let _ = tx.send(raw.clone());
+        }
+        if raw == "[DONE]" {
+            return;
+        }
+
+        let event_json: Value = match serde_json::from_str(&raw) {
+            Ok(json) => json,
+            Err(e) => {
+                let error_msg = format!("Responses SSE parsing error: {}, data: {}", e, &raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        if let Some(api_error_message) = extract_api_error_message(&event_json) {
+            let error_msg = format!("Responses SSE API error: {}, data: {}", api_error_message, raw);
+            error!("{}", error_msg);
+            let _ = tx_event.send(Err(anyhow!(error_msg)));
+            return;
+        }
+
+        let event: ResponsesStreamEvent = match serde_json::from_value(event_json) {
+            Ok(event) => event,
+            Err(e) => {
+                let error_msg = format!("Responses SSE schema error: {}, data: {}", e, &raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        match event.kind.as_str() {
+            "response.output_item.added" => {
+                // Track tool calls so we can stream arguments via `response.function_call_arguments.delta`.
+                if let (Some(output_index), Some(item)) = (event.output_index, event.item.as_ref()) {
+                    if let Some(tc) = InProgressToolCall::from_item_value(item) {
+                        if let Some(ref call_id) = tc.call_id {
+                            tool_call_index_by_id.insert(call_id.clone(), output_index);
+                        }
+                        tool_calls_by_output_index.insert(output_index, tc);
+                    }
+                }
+            }
+            "response.output_text.delta" => {
+                if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) {
+                    received_text_delta = true;
+                    let _ = tx_event.send(Ok(UnifiedResponse {
+                        text: Some(delta),
+                        ..Default::default()
+                    }));
+                }
+            }
+            "response.reasoning_text.delta" | "response.reasoning_summary_text.delta" => {
+                if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) {
+                    let _ = tx_event.send(Ok(UnifiedResponse {
+                        reasoning_content: Some(delta),
+                        ..Default::default()
+                    }));
+                }
+            }
+            "response.function_call_arguments.delta" => {
+                let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) else {
+                    continue;
+                };
+                let Some(output_index) = event.output_index else {
+                    continue;
+                };
+                let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else {
+                    continue;
+                };
+
+                tc.saw_any_delta = true;
+                tc.args_so_far.push_str(&delta);
+
+                // Some consumers treat `id` as a "new tool call" marker and reset buffers when it repeats.
+                // Only send id/name once per tool call; deltas that follow carry arguments only.
+ let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + + let _ = tx_event.send(Ok(UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + })); + } + "response.output_item.done" => { + let Some(item_value) = event.item else { + continue; + }; + + // For tool calls, prefer streaming deltas and only use item.done as a tail-filler / fallback. + if item_value.get("type").and_then(Value::as_str) == Some("function_call") { + handle_function_call_output_item_done( + &tx_event, + event.output_index, + item_value, + &mut tool_calls_by_output_index, + &mut tool_call_index_by_id, + ); + continue; + } + + if let Some(mut unified_response) = parse_responses_output_item(item_value) { + if received_text_delta && unified_response.text.is_some() { + unified_response.text = None; + } + if unified_response.text.is_some() || unified_response.tool_call.is_some() { + let _ = tx_event.send(Ok(unified_response)); + } + } + } + "response.completed" => { + if received_finish_reason { + continue; + } + // Best-effort: use the final response object to fill any missing tool-call argument tail. + if let Some(response_val) = event.response.as_ref() { + if let Some(output) = response_val.get("output").and_then(Value::as_array) { + for (idx, item) in output.iter().enumerate() { + if item.get("type").and_then(Value::as_str) != Some("function_call") { + continue; + } + let Some(tc) = tool_calls_by_output_index.get_mut(&idx) else { + continue; + }; + let full_args = item + .get("arguments") + .and_then(Value::as_str) + .unwrap_or_default(); + if tc.args_so_far.len() < full_args.len() + && full_args.starts_with(&tc.args_so_far) + { + let delta = full_args[tc.args_so_far.len()..].to_string(); + if !delta.is_empty() { + tc.args_so_far.push_str(&delta); + let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + let _ = tx_event.send(Ok(UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + })); + } + } + } + } + } + match event.response.map(serde_json::from_value::) { + Some(Ok(response)) => { + received_finish_reason = true; + let _ = tx_event.send(Ok(UnifiedResponse { + usage: response.usage.map(Into::into), + finish_reason: Some("stop".to_string()), + ..Default::default() + })); + continue; + } + Some(Err(e)) => { + let error_msg = format!("Failed to parse response.completed payload: {}", e); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + None => { + received_finish_reason = true; + let _ = tx_event.send(Ok(UnifiedResponse { + finish_reason: Some("stop".to_string()), + ..Default::default() + })); + continue; + } + } + } + "response.done" => { + if received_finish_reason { + continue; + } + match event.response.map(serde_json::from_value::) { + Some(Ok(response)) => { + received_finish_reason = true; + let _ = tx_event.send(Ok(UnifiedResponse { + usage: response.usage.map(Into::into), + finish_reason: Some("stop".to_string()), + ..Default::default() + })); + continue; + } + Some(Err(e)) => { + let error_msg = format!("Failed to parse response.done payload: {}", e); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + None => { + received_finish_reason = true; + let _ = tx_event.send(Ok(UnifiedResponse { + 
finish_reason: Some("stop".to_string()), + ..Default::default() + })); + continue; + } + } + } + "response.failed" => { + let error_msg = event + .response + .as_ref() + .and_then(|response| response.get("error")) + .and_then(|error| error.get("message")) + .and_then(Value::as_str) + .unwrap_or("Responses API returned response.failed") + .to_string(); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + "response.incomplete" => { + // Prefer returning partial output (rust-genai behavior) instead of hard-failing the round. + // Still mark finish_reason so the caller can decide how to handle it. + if received_finish_reason { + continue; + } + let reason = event + .response + .as_ref() + .and_then(|response| response.get("incomplete_details")) + .and_then(|details| details.get("reason")) + .and_then(Value::as_str) + .map(|s| s.to_string()); + + let finish_reason = reason + .as_deref() + .map(|r| format!("incomplete:{r}")) + .unwrap_or_else(|| "incomplete".to_string()); + + let usage = event + .response + .clone() + .and_then(|v| serde_json::from_value::(v).ok()) + .and_then(|r| r.usage) + .map(Into::into); + + received_finish_reason = true; + let _ = tx_event.send(Ok(UnifiedResponse { + usage, + finish_reason: Some(finish_reason), + ..Default::default() + })); + continue; + } + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + extract_api_error_message, handle_function_call_output_item_done, InProgressToolCall, + }; + use serde_json::json; + use std::collections::HashMap; + use tokio::sync::mpsc; + + #[test] + fn extracts_api_error_message_from_response_error() { + let event = json!({ + "type": "response.failed", + "response": { + "error": { + "message": "provider error" + } + } + }); + + assert_eq!( + extract_api_error_message(&event).as_deref(), + Some("provider error") + ); + } + + #[test] + fn returns_none_when_no_response_error_exists() { + let event = json!({ + "type": "response.created", + "response": { + "id": "resp_1" + } + }); + + assert!(extract_api_error_message(&event).is_none()); + } + + #[test] + fn returns_none_when_response_error_is_null() { + let event = json!({ + "type": "response.created", + "response": { + "id": "resp_1", + "error": null + } + }); + + assert!(extract_api_error_message(&event).is_none()); + } + + #[test] + fn output_item_done_falls_back_when_output_index_is_untracked() { + let (tx_event, mut rx_event) = mpsc::unbounded_channel(); + let mut tool_calls_by_output_index: HashMap = HashMap::new(); + let mut tool_call_index_by_id: HashMap = HashMap::new(); + + handle_function_call_output_item_done( + &tx_event, + Some(3), + json!({ + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "{\"city\":\"Beijing\"}" + }), + &mut tool_calls_by_output_index, + &mut tool_call_index_by_id, + ); + + let response = rx_event.try_recv().expect("tool call event").expect("ok response"); + let tool_call = response.tool_call.expect("tool call"); + assert_eq!(tool_call.id.as_deref(), Some("call_1")); + assert_eq!(tool_call.name.as_deref(), Some("get_weather")); + assert_eq!(tool_call.arguments.as_deref(), Some("{\"city\":\"Beijing\"}")); + } +} diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/anthropic.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/anthropic.rs index 049c27bb..4f101ab1 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/anthropic.rs +++ 
b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/anthropic.rs
@@ -64,6 +64,7 @@ impl From for UnifiedTokenUsage {
             prompt_token_count,
             candidates_token_count,
             total_token_count: prompt_token_count + candidates_token_count,
+            reasoning_token_count: None,
             cached_content_token_count: match (
                 value.cache_read_input_tokens,
                 value.cache_creation_input_tokens,
@@ -96,6 +97,7 @@ impl From for UnifiedResponse {
             tool_call: None,
             usage: value.usage.map(UnifiedTokenUsage::from),
             finish_reason: value.delta.stop_reason,
+            provider_metadata: None,
         }
     }
 }
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs
new file mode 100644
index 00000000..3cb810f2
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs
@@ -0,0 +1,697 @@
+use crate::types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall};
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiSSEData {
+    #[serde(default)]
+    pub candidates: Vec<GeminiCandidate>,
+    #[serde(default)]
+    pub usage_metadata: Option<GeminiUsageMetadata>,
+    #[serde(default)]
+    pub prompt_feedback: Option<Value>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiCandidate {
+    #[serde(default)]
+    pub content: Option<GeminiContent>,
+    #[serde(default)]
+    pub finish_reason: Option<String>,
+    #[serde(default)]
+    pub grounding_metadata: Option<Value>,
+    #[serde(default)]
+    pub safety_ratings: Option<Value>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiContent {
+    #[serde(default)]
+    pub parts: Vec<GeminiPart>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiPart {
+    #[serde(default)]
+    pub text: Option<String>,
+    #[serde(default)]
+    pub thought: Option<bool>,
+    #[serde(default)]
+    pub thought_signature: Option<String>,
+    #[serde(default)]
+    pub function_call: Option<GeminiFunctionCall>,
+    #[serde(default)]
+    pub executable_code: Option<GeminiExecutableCode>,
+    #[serde(default)]
+    pub code_execution_result: Option<GeminiCodeExecutionResult>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiFunctionCall {
+    #[serde(default)]
+    pub name: Option<String>,
+    #[serde(default)]
+    pub args: Option<Value>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiExecutableCode {
+    #[serde(default)]
+    pub language: Option<String>,
+    #[serde(default)]
+    pub code: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiCodeExecutionResult {
+    #[serde(default)]
+    pub outcome: Option<String>,
+    #[serde(default)]
+    pub output: Option<String>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiUsageMetadata {
+    #[serde(default)]
+    pub prompt_token_count: u32,
+    #[serde(default)]
+    pub candidates_token_count: u32,
+    #[serde(default)]
+    pub total_token_count: u32,
+    #[serde(default)]
+    pub thoughts_token_count: Option<u32>,
+    #[serde(default)]
+    pub cached_content_token_count: Option<u32>,
+}
+
+impl From<GeminiUsageMetadata> for UnifiedTokenUsage {
+    fn from(usage: GeminiUsageMetadata) -> Self {
+        let reasoning_token_count = usage.thoughts_token_count;
+        let candidates_token_count = usage
+            .candidates_token_count
+            .saturating_add(reasoning_token_count.unwrap_or(0));
+        Self {
+            prompt_token_count: usage.prompt_token_count,
+            candidates_token_count,
+            total_token_count: usage.total_token_count,
+            reasoning_token_count,
+            cached_content_token_count: usage.cached_content_token_count,
+        }
+    }
+}
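A quick worked example of the conversion above, with illustrative numbers: Gemini reports thought tokens separately, so the conversion folds them into the completion count and passes the provider's total through unchanged.

```rust
// Illustrative values only; mirrors the From<GeminiUsageMetadata> impl above.
let gemini = GeminiUsageMetadata {
    prompt_token_count: 100,
    candidates_token_count: 40,
    total_token_count: 150,
    thoughts_token_count: Some(10),
    cached_content_token_count: None,
};
let unified = UnifiedTokenUsage::from(gemini);
assert_eq!(unified.candidates_token_count, 50); // 40 visible + 10 reasoning
assert_eq!(unified.reasoning_token_count, Some(10));
assert_eq!(unified.total_token_count, 150); // provider total untouched
```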
+
+impl GeminiSSEData {
+    fn render_executable_code(executable_code: &GeminiExecutableCode) -> Option<String> {
+        let code = executable_code.code.as_deref()?.trim();
+        if code.is_empty() {
+            return None;
+        }
+
+        let language = executable_code
+            .language
+            .as_deref()
+            .map(|language| language.to_ascii_lowercase())
+            .unwrap_or_else(|| "text".to_string());
+
+        Some(format!(
+            "Gemini code execution generated code:\n```{}\n{}\n```",
+            language, code
+        ))
+    }
+
+    fn render_code_execution_result(result: &GeminiCodeExecutionResult) -> Option<String> {
+        let output = result.output.as_deref()?.trim();
+        if output.is_empty() {
+            return None;
+        }
+
+        let outcome = result.outcome.as_deref().unwrap_or("OUTCOME_UNKNOWN");
+        Some(format!(
+            "Gemini code execution result ({}):\n{}",
+            outcome, output
+        ))
+    }
+
+    fn grounding_summary(metadata: &Value) -> Option<String> {
+        let mut lines = Vec::new();
+
+        let queries = metadata
+            .get("webSearchQueries")
+            .and_then(Value::as_array)
+            .map(|queries| {
+                queries
+                    .iter()
+                    .filter_map(Value::as_str)
+                    .filter(|query| !query.trim().is_empty())
+                    .collect::<Vec<_>>()
+            })
+            .unwrap_or_default();
+
+        if !queries.is_empty() {
+            lines.push(format!("Search queries: {}", queries.join(" | ")));
+        }
+
+        let sources = metadata
+            .get("groundingChunks")
+            .and_then(Value::as_array)
+            .map(|chunks| {
+                chunks
+                    .iter()
+                    .filter_map(|chunk| {
+                        let web = chunk.get("web")?;
+                        let uri = web.get("uri").and_then(Value::as_str)?.trim();
+                        if uri.is_empty() {
+                            return None;
+                        }
+                        let title = web
+                            .get("title")
+                            .and_then(Value::as_str)
+                            .map(str::trim)
+                            .filter(|title| !title.is_empty())
+                            .unwrap_or(uri);
+                        Some((title.to_string(), uri.to_string()))
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .unwrap_or_default();
+
+        if !sources.is_empty() {
+            lines.push("Sources:".to_string());
+            for (index, (title, uri)) in sources.into_iter().enumerate() {
+                lines.push(format!("{}. {} - {}", index + 1, title, uri));
+            }
+        }
+
+        let supports = metadata
+            .get("groundingSupports")
+            .and_then(Value::as_array)
+            .map(|supports| {
+                supports
+                    .iter()
+                    .filter_map(|support| {
+                        let segment_text = support
+                            .get("segment")
+                            .and_then(Value::as_object)
+                            .and_then(|segment| segment.get("text"))
+                            .and_then(Value::as_str)
+                            .map(str::trim)
+                            .filter(|text| !text.is_empty())?;
+
+                        let chunk_indices = support
+                            .get("groundingChunkIndices")
+                            .and_then(Value::as_array)
+                            .map(|indices| {
+                                indices
+                                    .iter()
+                                    .filter_map(Value::as_u64)
+                                    .map(|index| (index + 1).to_string())
+                                    .collect::<Vec<_>>()
+                            })
+                            .unwrap_or_default();
+
+                        if chunk_indices.is_empty() {
+                            None
+                        } else {
+                            Some((segment_text.to_string(), chunk_indices.join(", ")))
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .unwrap_or_default();
+
+        if !supports.is_empty() {
+            lines.push("Citations:".to_string());
+            for (segment, indices) in supports.into_iter().take(5) {
+                lines.push(format!("- \"{}\" -> [{}]", segment, indices));
+            }
+        }
+
+        if lines.is_empty() {
+            None
+        } else {
+            Some(lines.join("\n"))
+        }
+    }
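For reference, this is roughly what `grounding_summary` renders for a small payload (input values are made up; note that the 0-based `groundingChunkIndices` are shown 1-based to match the numbered source list):

```rust
let metadata = serde_json::json!({
    "webSearchQueries": ["rust 1.80 release date"],
    "groundingChunks": [
        { "web": { "uri": "https://blog.rust-lang.org", "title": "Rust Blog" } }
    ],
    "groundingSupports": [{
        "segment": { "text": "Rust 1.80 shipped in July 2024" },
        "groundingChunkIndices": [0]
    }]
});
let summary = GeminiSSEData::grounding_summary(&metadata).expect("non-empty summary");
// Search queries: rust 1.80 release date
// Sources:
// 1. Rust Blog - https://blog.rust-lang.org
// Citations:
// - "Rust 1.80 shipped in July 2024" -> [1]
assert!(summary.contains("Sources:") && summary.contains("[1]"));
```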
{} - {}", index + 1, title, uri)); + } + } + + let supports = metadata + .get("groundingSupports") + .and_then(Value::as_array) + .map(|supports| { + supports + .iter() + .filter_map(|support| { + let segment_text = support + .get("segment") + .and_then(Value::as_object) + .and_then(|segment| segment.get("text")) + .and_then(Value::as_str) + .map(str::trim) + .filter(|text| !text.is_empty())?; + + let chunk_indices = support + .get("groundingChunkIndices") + .and_then(Value::as_array) + .map(|indices| { + indices + .iter() + .filter_map(Value::as_u64) + .map(|index| (index + 1).to_string()) + .collect::>() + }) + .unwrap_or_default(); + + if chunk_indices.is_empty() { + None + } else { + Some((segment_text.to_string(), chunk_indices.join(", "))) + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if !supports.is_empty() { + lines.push("Citations:".to_string()); + for (segment, indices) in supports.into_iter().take(5) { + lines.push(format!("- \"{}\" -> [{}]", segment, indices)); + } + } + + if lines.is_empty() { + None + } else { + Some(lines.join("\n")) + } + } + + fn safety_summary(prompt_feedback: Option<&Value>, safety_ratings: Option<&Value>) -> Option { + let mut lines = Vec::new(); + + if let Some(prompt_feedback) = prompt_feedback { + if let Some(blocked_reason) = prompt_feedback + .get("blockReason") + .and_then(Value::as_str) + .filter(|reason| !reason.trim().is_empty()) + { + lines.push(format!("Prompt blocked reason: {}", blocked_reason)); + } + + if let Some(block_reason_message) = prompt_feedback + .get("blockReasonMessage") + .and_then(Value::as_str) + .filter(|message| !message.trim().is_empty()) + { + lines.push(format!("Prompt block message: {}", block_reason_message)); + } + } + + let ratings = safety_ratings + .and_then(Value::as_array) + .map(|ratings| { + ratings + .iter() + .filter_map(|rating| { + let category = rating.get("category").and_then(Value::as_str)?; + let probability = rating + .get("probability") + .and_then(Value::as_str) + .unwrap_or("UNKNOWN"); + let blocked = rating + .get("blocked") + .and_then(Value::as_bool) + .unwrap_or(false); + + if blocked || probability != "NEGLIGIBLE" { + Some(format!( + "{} (probability={}, blocked={})", + category, probability, blocked + )) + } else { + None + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if !ratings.is_empty() { + lines.push("Safety ratings:".to_string()); + lines.extend(ratings.into_iter().map(|rating| format!("- {}", rating))); + } + + if lines.is_empty() { + None + } else { + Some(lines.join("\n")) + } + } + + fn provider_metadata_summary(metadata: &Value) -> Option { + let prompt_feedback = metadata.get("promptFeedback"); + let grounding_metadata = metadata.get("groundingMetadata"); + let safety_ratings = metadata.get("safetyRatings"); + + let mut sections = Vec::new(); + if let Some(safety) = Self::safety_summary(prompt_feedback, safety_ratings) { + sections.push(safety); + } + if let Some(grounding) = grounding_metadata.and_then(Self::grounding_summary) { + sections.push(grounding); + } + + if sections.is_empty() { + None + } else { + Some(sections.join("\n\n")) + } + } + + pub fn into_unified_responses(self) -> Vec { + let mut usage = self.usage_metadata.map(Into::into); + let prompt_feedback = self.prompt_feedback; + let Some(candidate) = self.candidates.into_iter().next() else { + return usage + .take() + .map(|usage| { + vec![UnifiedResponse { + usage: Some(usage), + ..Default::default() + }] + }) + .unwrap_or_default(); + }; + + let mut responses = Vec::new(); + let mut 
finish_reason = candidate.finish_reason; + let grounding_metadata = candidate.grounding_metadata; + let safety_ratings = candidate.safety_ratings; + + if let Some(content) = candidate.content { + for part in content.parts { + let has_function_call = part.function_call.is_some(); + let text = part.text.filter(|text| !text.is_empty()); + let is_thought = part.thought.unwrap_or(false); + let thinking_signature = part.thought_signature.filter(|value| !value.is_empty()); + + if let Some(function_call) = part.function_call { + let arguments = function_call.args.unwrap_or_else(|| json!({})); + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature, + tool_call: Some(UnifiedToolCall { + id: None, + name: function_call.name, + arguments: serde_json::to_string(&arguments).ok(), + }), + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + + if let Some(executable_code) = part.executable_code.as_ref() { + if let Some(reasoning_content) = Self::render_executable_code(executable_code) { + responses.push(UnifiedResponse { + text: None, + reasoning_content: Some(reasoning_content), + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + } + + if let Some(code_execution_result) = part.code_execution_result.as_ref() { + if let Some(reasoning_content) = + Self::render_code_execution_result(code_execution_result) + { + responses.push(UnifiedResponse { + text: None, + reasoning_content: Some(reasoning_content), + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + } + + if let Some(text) = text { + responses.push(UnifiedResponse { + text: if is_thought { None } else { Some(text.clone()) }, + reasoning_content: if is_thought { Some(text) } else { None }, + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + + if thinking_signature.is_some() && !has_function_call { + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + } + } + } + + let provider_metadata = { + let mut metadata = serde_json::Map::new(); + if let Some(prompt_feedback) = prompt_feedback { + metadata.insert("promptFeedback".to_string(), prompt_feedback); + } + if let Some(grounding_metadata) = grounding_metadata { + metadata.insert("groundingMetadata".to_string(), grounding_metadata); + } + if let Some(safety_ratings) = safety_ratings { + metadata.insert("safetyRatings".to_string(), safety_ratings); + } + + if metadata.is_empty() { + None + } else { + Some(Value::Object(metadata)) + } + }; + + if let Some(provider_metadata) = provider_metadata { + let summary = Self::provider_metadata_summary(&provider_metadata); + responses.push(UnifiedResponse { + text: summary, + reasoning_content: None, + thinking_signature: None, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: Some(provider_metadata), + }); + } + + if responses.is_empty() { + responses.push(UnifiedResponse { + usage, + finish_reason, + ..Default::default() + }); + } + + responses + } +} + +#[cfg(test)] +mod tests { + use super::GeminiSSEData; + + #[test] + fn converts_text_thought_and_usage() { + 
let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "text": "thinking", "thought": true, "thoughtSignature": "sig_1" }, + { "text": "answer" } + ] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 4, + "thoughtsTokenCount": 2, + "totalTokenCount": 14 + } + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].reasoning_content.as_deref(), Some("thinking")); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_1")); + assert_eq!( + responses[0] + .usage + .as_ref() + .and_then(|usage| usage.reasoning_token_count), + Some(2) + ); + assert_eq!( + responses[0] + .usage + .as_ref() + .map(|usage| usage.candidates_token_count), + Some(6) + ); + assert_eq!( + responses[0] + .usage + .as_ref() + .map(|usage| usage.total_token_count), + Some(14) + ); + assert_eq!(responses[1].text.as_deref(), Some("answer")); + } + + #[test] + fn keeps_thought_signature_on_function_call_parts() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { + "thoughtSignature": "sig_tool", + "functionCall": { + "name": "get_weather", + "args": { "city": "Paris" } + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_tool")); + assert_eq!( + responses[0] + .tool_call + .as_ref() + .and_then(|tool_call| tool_call.name.as_deref()), + Some("get_weather") + ); + } + + #[test] + fn keeps_standalone_thought_signature_parts() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "thoughtSignature": "sig_only" } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_only")); + assert!(responses[0].tool_call.is_none()); + assert!(responses[0].text.is_none()); + assert!(responses[0].reasoning_content.is_none()); + } + + #[test] + fn converts_code_execution_parts_to_reasoning_chunks() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { + "executableCode": { + "language": "PYTHON", + "code": "print(1 + 1)" + } + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "2" + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert!(responses[0] + .reasoning_content + .as_deref() + .is_some_and(|text| text.contains("print(1 + 1)"))); + assert!(responses[1] + .reasoning_content + .as_deref() + .is_some_and(|text| text.contains("OUTCOME_OK") && text.contains("2"))); + } + + #[test] + fn emits_grounding_summary_and_provider_metadata() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "text": "answer" } + ] + }, + "groundingMetadata": { + "webSearchQueries": ["latest rust release"], + "groundingChunks": [ + { + "web": { + "uri": "https://www.rust-lang.org", + "title": "Rust" + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini 
payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].text.as_deref(), Some("answer")); + assert!(responses[1] + .text + .as_deref() + .is_some_and(|text| text.contains("Sources:") && text.contains("rust-lang.org"))); + assert!(responses[1] + .provider_metadata + .as_ref() + .and_then(|metadata| metadata.get("groundingMetadata")) + .is_some()); + } + + #[test] + fn emits_prompt_feedback_and_safety_summary() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { "parts": [] }, + "finishReason": "SAFETY", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "MEDIUM", + "blocked": true + } + ] + }], + "promptFeedback": { + "blockReason": "SAFETY", + "blockReasonMessage": "Blocked by safety system" + } + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].finish_reason.as_deref(), Some("SAFETY")); + assert!(responses[0] + .text + .as_deref() + .is_some_and(|text| text.contains("Prompt blocked reason: SAFETY"))); + assert!(responses[0] + .text + .as_deref() + .is_some_and(|text| text.contains("HARM_CATEGORY_DANGEROUS_CONTENT"))); + assert!(responses[0] + .provider_metadata + .as_ref() + .and_then(|metadata| metadata.get("promptFeedback")) + .is_some()); + } +} diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs index 0463a261..c266edbd 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs @@ -1,3 +1,5 @@ pub mod unified; pub mod openai; -pub mod anthropic; \ No newline at end of file +pub mod anthropic; +pub mod responses; +pub mod gemini; diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/openai.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/openai.rs index e584b074..ed2bdedc 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/openai.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/openai.rs @@ -23,6 +23,7 @@ impl From for UnifiedTokenUsage { prompt_token_count: usage.prompt_tokens, candidates_token_count: usage.completion_tokens, total_token_count: usage.total_tokens, + reasoning_token_count: None, cached_content_token_count: usage .prompt_tokens_details .and_then(|prompt_tokens_details| prompt_tokens_details.cached_tokens), @@ -176,6 +177,7 @@ impl OpenAISSEData { tool_call: None, usage: usage.take(), finish_reason: finish_reason.take(), + provider_metadata: None, }); } @@ -193,6 +195,7 @@ impl OpenAISSEData { } else { None }, + provider_metadata: None, }); } } @@ -205,6 +208,7 @@ impl OpenAISSEData { tool_call: None, usage, finish_reason, + provider_metadata: None, }); } diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs new file mode 100644 index 00000000..6e8a3e00 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs @@ -0,0 +1,208 @@ +use super::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; +use serde::Deserialize; +use serde_json::Value; + +#[derive(Debug, Deserialize)] +pub struct ResponsesStreamEvent { + 
+    #[serde(rename = "type")]
+    pub kind: String,
+    /// Output item index in the `response.output` array.
+    #[serde(default)]
+    pub output_index: Option<usize>,
+    /// Content part index within an output item (for content-part events).
+    #[allow(dead_code)]
+    #[serde(default)]
+    pub content_index: Option<usize>,
+    #[serde(default)]
+    pub response: Option<Value>,
+    #[serde(default)]
+    pub item: Option<Value>,
+    #[serde(default)]
+    pub delta: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesCompleted {
+    #[allow(dead_code)]
+    pub id: String,
+    #[serde(default)]
+    pub usage: Option<ResponsesUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesDone {
+    #[serde(default)]
+    #[allow(dead_code)]
+    pub id: Option<String>,
+    #[serde(default)]
+    pub usage: Option<ResponsesUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesUsage {
+    pub input_tokens: u32,
+    #[serde(default)]
+    pub input_tokens_details: Option<ResponsesInputTokensDetails>,
+    pub output_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesInputTokensDetails {
+    pub cached_tokens: u32,
+}
+
+impl From<ResponsesUsage> for UnifiedTokenUsage {
+    fn from(usage: ResponsesUsage) -> Self {
+        Self {
+            prompt_token_count: usage.input_tokens,
+            candidates_token_count: usage.output_tokens,
+            total_token_count: usage.total_tokens,
+            reasoning_token_count: None,
+            cached_content_token_count: usage
+                .input_tokens_details
+                .map(|details| details.cached_tokens),
+        }
+    }
+}
+
+pub fn parse_responses_output_item(item_value: Value) -> Option<UnifiedResponse> {
+    let item_type = item_value.get("type")?.as_str()?;
+
+    match item_type {
+        "function_call" => Some(UnifiedResponse {
+            text: None,
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_call: Some(UnifiedToolCall {
+                id: item_value
+                    .get("call_id")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+                name: item_value
+                    .get("name")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+                arguments: item_value
+                    .get("arguments")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+            }),
+            usage: None,
+            finish_reason: None,
+            provider_metadata: None,
+        }),
+        "message" => {
+            let text = item_value
+                .get("content")
+                .and_then(Value::as_array)
+                .map(|content| {
+                    content
+                        .iter()
+                        .filter(|item| {
+                            item.get("type").and_then(Value::as_str) == Some("output_text")
+                        })
+                        .filter_map(|item| item.get("text").and_then(Value::as_str))
+                        .collect::<String>()
+                })
+                .filter(|text| !text.is_empty());
+
+            text.map(|text| UnifiedResponse {
+                text: Some(text),
+                reasoning_content: None,
+                thinking_signature: None,
+                tool_call: None,
+                usage: None,
+                finish_reason: None,
+                provider_metadata: None,
+            })
+        }
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{parse_responses_output_item, ResponsesCompleted, ResponsesStreamEvent};
+    use serde_json::json;
+
+    #[test]
+    fn parses_output_text_message_item() {
+        let response = parse_responses_output_item(json!({
+            "type": "message",
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "output_text",
+                    "text": "hello"
+                }
+            ]
+        }))
+        .expect("message item");
+
+        assert_eq!(response.text.as_deref(), Some("hello"));
+    }
+
+    #[test]
+    fn parses_function_call_item() {
+        let response = parse_responses_output_item(json!({
+            "type": "function_call",
+            "call_id": "call_1",
+            "name": "get_weather",
+            "arguments": "{\"city\":\"Beijing\"}"
+        }))
+        .expect("function call item");
+
+        let tool_call = response.tool_call.expect("tool call");
+        assert_eq!(tool_call.id.as_deref(), Some("call_1"));
+        assert_eq!(tool_call.name.as_deref(), Some("get_weather"));
+    }
+
+    #[test]
+    fn parses_completed_payload_usage() {
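+        // `response.completed` nests its payload under `response`: parse the envelope
+        // first, then deserialize the inner value as `ResponsesCompleted`.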
+        let event: ResponsesStreamEvent = serde_json::from_value(json!({
+            "type": "response.completed",
+            "response": {
+                "id": "resp_1",
+                "usage": {
+                    "input_tokens": 10,
+                    "input_tokens_details": { "cached_tokens": 2 },
+                    "output_tokens": 4,
+                    "total_tokens": 14
+                }
+            }
+        }))
+        .expect("event");
+
+        let completed: ResponsesCompleted =
+            serde_json::from_value(event.response.expect("response")).expect("completed");
+        assert_eq!(completed.id, "resp_1");
+        assert_eq!(completed.usage.expect("usage").total_tokens, 14);
+    }
+
+    #[test]
+    fn parses_output_item_added_indices() {
+        let event: ResponsesStreamEvent = serde_json::from_value(json!({
+            "type": "response.output_item.added",
+            "output_index": 3,
+            "item": { "type": "function_call", "call_id": "call_1", "name": "tool", "arguments": "" }
+        }))
+        .expect("event");
+
+        assert_eq!(event.output_index, Some(3));
+        assert!(event.item.is_some());
+    }
+
+    #[test]
+    fn parses_function_call_arguments_delta_indices() {
+        let event: ResponsesStreamEvent = serde_json::from_value(json!({
+            "type": "response.function_call_arguments.delta",
+            "output_index": 1,
+            "delta": "{\"a\":"
+        }))
+        .expect("event");
+
+        assert_eq!(event.output_index, Some(1));
+        assert_eq!(event.delta.as_deref(), Some("{\"a\":"));
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/unified.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/unified.rs
index 601fd974..309a3501 100644
--- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/unified.rs
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/unified.rs
@@ -1,4 +1,5 @@
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct UnifiedToolCall {
@@ -18,6 +19,8 @@ pub struct UnifiedResponse {
     pub tool_call: Option<UnifiedToolCall>,
     pub usage: Option<UnifiedTokenUsage>,
     pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider_metadata: Option<Value>,
 }
 
 impl Default for UnifiedResponse {
@@ -29,6 +32,7 @@ impl Default for UnifiedResponse {
             tool_call: None,
             usage: None,
             finish_reason: None,
+            provider_metadata: None,
         }
     }
@@ -40,5 +44,7 @@ pub struct UnifiedTokenUsage {
     pub candidates_token_count: u32,
     pub total_token_count: u32,
     #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_token_count: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub cached_content_token_count: Option<u32>,
 }
diff --git a/src/crates/core/src/infrastructure/ai/client.rs b/src/crates/core/src/infrastructure/ai/client.rs
index c0c25d84..35d6bcf7 100644
--- a/src/crates/core/src/infrastructure/ai/client.rs
+++ b/src/crates/core/src/infrastructure/ai/client.rs
@@ -3,11 +3,14 @@
 //!
Uses a modular architecture to separate provider-specific logic into the providers module use crate::infrastructure::ai::providers::anthropic::AnthropicMessageConverter; +use crate::infrastructure::ai::providers::gemini::GeminiMessageConverter; use crate::infrastructure::ai::providers::openai::OpenAIMessageConverter; use crate::service::config::ProxyConfig; use crate::util::types::*; use crate::util::JsonChecker; -use ai_stream_handlers::{handle_anthropic_stream, handle_openai_stream, UnifiedResponse}; +use ai_stream_handlers::{ + handle_anthropic_stream, handle_gemini_stream, handle_openai_stream, handle_responses_stream, UnifiedResponse, +}; use anyhow::{anyhow, Result}; use futures::StreamExt; use log::{debug, error, info, warn}; @@ -32,7 +35,7 @@ pub struct AIClient { impl AIClient { const TEST_IMAGE_EXPECTED_CODE: &'static str = "BYGR"; const TEST_IMAGE_PNG_BASE64: &'static str = - "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAiklEQVR4nNXZwQkAQQzDQEX995wr4giLpgBj8NMDy6XdOc2XOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTOImTuDm+Bzi+B8gvIHESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3ESJ3G+LvDXB5LJBXz4d6CTAAAAAElFTkSuQmCC"; + "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAACBklEQVR42u3ZsREAIAwDMYf9dw4txwJupI7Wua+YZEPBfO91h4ZjAgQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABIAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAAAAAAAEDRZI3QGf7jDvEPAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAAAjABAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQALwuLkoG8OSfau4AAAAASUVORK5CYII="; fn image_test_response_matches_expected(response: &str) -> bool { let upper = response.to_ascii_uppercase(); @@ -95,6 +98,38 @@ impl AIClient { color_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) } + fn is_responses_api_format(api_format: &str) -> bool { + matches!(api_format.to_ascii_lowercase().as_str(), "response" | "responses") + } + + fn build_test_connection_extra_body(&self) -> Option { + let provider = self.config.format.to_ascii_lowercase(); + if !matches!(provider.as_str(), "openai" | "response" | "responses") { + return self.config.custom_request_body.clone(); + } + + let mut extra_body = self + .config + .custom_request_body + .clone() + .unwrap_or_else(|| serde_json::json!({})); + + if let Some(extra_obj) = extra_body.as_object_mut() { + extra_obj + .entry("temperature".to_string()) + .or_insert_with(|| serde_json::json!(0)); + extra_obj + .entry("tool_choice".to_string()) + .or_insert_with(|| serde_json::json!("required")); + } + + Some(extra_body) + } + + fn is_gemini_api_format(api_format: &str) -> bool { + matches!(api_format.to_ascii_lowercase().as_str(), "gemini" | "google") + } + /// Create an AIClient without proxy (backward compatible) pub fn new(config: AIConfig) -> Self { let skip_ssl_verify = config.skip_ssl_verify; @@ -368,6 +403,217 @@ impl AIClient { builder } + /// Apply Gemini-style request headers (merge/replace). 
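+    ///
+    /// In replace mode, user-supplied custom headers are sent as-is; otherwise the
+    /// defaults (`Content-Type: application/json` and `x-goog-api-key`) are applied
+    /// first, and in merge mode custom headers are layered on top so they win on
+    /// conflict.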
+ fn apply_gemini_headers( + &self, + mut builder: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + let has_custom_headers = self + .config + .custom_headers + .as_ref() + .map_or(false, |h| !h.is_empty()); + let is_merge_mode = self.is_merge_headers_mode(); + + if has_custom_headers && !is_merge_mode { + return self.apply_custom_headers(builder); + } + + builder = builder + .header("Content-Type", "application/json") + .header("x-goog-api-key", &self.config.api_key); + + if has_custom_headers && is_merge_mode { + builder = self.apply_custom_headers(builder); + } + + builder + } + + fn merge_json_value(target: &mut serde_json::Value, overlay: serde_json::Value) { + match (target, overlay) { + (serde_json::Value::Object(target_map), serde_json::Value::Object(overlay_map)) => { + for (key, value) in overlay_map { + let entry = target_map.entry(key).or_insert(serde_json::Value::Null); + Self::merge_json_value(entry, value); + } + } + (target_slot, overlay_value) => { + *target_slot = overlay_value; + } + } + } + + fn ensure_gemini_generation_config( + request_body: &mut serde_json::Value, + ) -> &mut serde_json::Map { + if !request_body + .get("generationConfig") + .is_some_and(serde_json::Value::is_object) + { + request_body["generationConfig"] = serde_json::json!({}); + } + + request_body["generationConfig"] + .as_object_mut() + .expect("generationConfig must be an object") + } + + fn insert_gemini_generation_field( + request_body: &mut serde_json::Value, + key: &str, + value: serde_json::Value, + ) { + Self::ensure_gemini_generation_config(request_body).insert(key.to_string(), value); + } + + fn normalize_gemini_stop_sequences(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::String(sequence) => Some(serde_json::Value::Array(vec![ + serde_json::Value::String(sequence.clone()), + ])), + serde_json::Value::Array(items) => { + let sequences = items + .iter() + .filter_map(|item| item.as_str().map(|sequence| sequence.to_string())) + .map(serde_json::Value::String) + .collect::>(); + + if sequences.is_empty() { + None + } else { + Some(serde_json::Value::Array(sequences)) + } + } + _ => None, + } + } + + fn apply_gemini_response_format_translation( + request_body: &mut serde_json::Value, + response_format: &serde_json::Value, + ) -> bool { + match response_format { + serde_json::Value::String(kind) if matches!(kind.as_str(), "json" | "json_object") => { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + serde_json::Value::Object(map) => { + let Some(kind) = map.get("type").and_then(serde_json::Value::as_str) else { + return false; + }; + + match kind { + "json" | "json_object" => { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + "json_schema" => { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + + if let Some(schema) = map + .get("json_schema") + .and_then(serde_json::Value::as_object) + .and_then(|json_schema| json_schema.get("schema")) + .or_else(|| map.get("schema")) + { + Self::insert_gemini_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(schema.clone()), + ); + } + + true + } + _ => false, + } + } + _ => false, + } + } + + fn translate_gemini_extra_body( + request_body: &mut serde_json::Value, + 
extra_obj: &mut serde_json::Map, + ) { + if let Some(max_tokens) = extra_obj.remove("max_tokens") { + Self::insert_gemini_generation_field(request_body, "maxOutputTokens", max_tokens); + } + + if let Some(temperature) = extra_obj.remove("temperature") { + Self::insert_gemini_generation_field(request_body, "temperature", temperature); + } + + let top_p = extra_obj.remove("top_p").or_else(|| extra_obj.remove("topP")); + if let Some(top_p) = top_p { + Self::insert_gemini_generation_field(request_body, "topP", top_p); + } + + if let Some(stop_sequences) = extra_obj + .get("stop") + .and_then(Self::normalize_gemini_stop_sequences) + { + extra_obj.remove("stop"); + Self::insert_gemini_generation_field( + request_body, + "stopSequences", + stop_sequences, + ); + } + + if let Some(response_mime_type) = extra_obj + .remove("responseMimeType") + .or_else(|| extra_obj.remove("response_mime_type")) + { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + response_mime_type, + ); + } + + if let Some(response_schema) = extra_obj + .remove("responseJsonSchema") + .or_else(|| extra_obj.remove("responseSchema")) + .or_else(|| extra_obj.remove("response_schema")) + { + Self::insert_gemini_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(response_schema), + ); + } + + if let Some(response_format) = extra_obj.get("response_format").cloned() { + if Self::apply_gemini_response_format_translation(request_body, &response_format) { + extra_obj.remove("response_format"); + } + } + } + + fn unified_usage_to_gemini_usage(usage: ai_stream_handlers::UnifiedTokenUsage) -> GeminiUsage { + GeminiUsage { + prompt_token_count: usage.prompt_token_count, + candidates_token_count: usage.candidates_token_count, + total_token_count: usage.total_token_count, + reasoning_token_count: usage.reasoning_token_count, + cached_content_token_count: usage.cached_content_token_count, + } + } + /// Build an OpenAI-format request body fn build_openai_request_body( &self, @@ -435,7 +681,83 @@ impl AIClient { debug!(target: "ai::openai_stream_request", "\ntools: {:?}", tool_names); if !tools.is_empty() { request_body["tools"] = serde_json::Value::Array(tools); - request_body["tool_choice"] = serde_json::Value::String("auto".to_string()); + // Respect `extra_body` overrides (e.g. tool_choice="required") when present. + let has_tool_choice = request_body + .get("tool_choice") + .is_some_and(|v| !v.is_null()); + if !has_tool_choice { + request_body["tool_choice"] = serde_json::Value::String("auto".to_string()); + } + } + } + + request_body + } + + /// Build a Responses API request body. 
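+    ///
+    /// Illustrative shape of the assembled body (a sketch, not the exact serialized
+    /// output; optional fields are only present when configured):
+    ///
+    /// ```ignore
+    /// serde_json::json!({
+    ///     "model": "gpt-5.1",                  // self.config.model (example value)
+    ///     "input": response_input,             // converted message items
+    ///     "stream": true,
+    ///     "instructions": "...",               // optional, joined system messages
+    ///     "max_output_tokens": 8192,           // optional, from config
+    ///     "reasoning": { "effort": "medium", "summary": "auto" } // optional
+    /// })
+    /// ```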
+ fn build_responses_request_body( + &self, + instructions: Option, + response_input: Vec, + openai_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "model": self.config.model, + "input": response_input, + "stream": true + }); + + if let Some(instructions) = instructions.filter(|value| !value.trim().is_empty()) { + request_body["instructions"] = serde_json::Value::String(instructions); + } + + if let Some(max_tokens) = self.config.max_tokens { + request_body["max_output_tokens"] = serde_json::json!(max_tokens); + } + + if let Some(ref effort) = self.config.reasoning_effort { + request_body["reasoning"] = serde_json::json!({ + "effort": effort, + "summary": "auto" + }); + } + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + for (key, value) in extra_obj { + request_body[key] = value.clone(); + } + debug!( + target: "ai::responses_stream_request", + "Applied extra_body overrides: {:?}", + extra_obj.keys().collect::>() + ); + } + } + + debug!( + target: "ai::responses_stream_request", + "Responses stream request body (excluding tools):\n{}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "serialization failed".to_string()) + ); + + if let Some(tools) = openai_tools { + let tool_names = tools + .iter() + .map(|tool| Self::extract_openai_tool_name(tool)) + .collect::>(); + debug!(target: "ai::responses_stream_request", "\ntools: {:?}", tool_names); + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + // Respect `extra_body` overrides (e.g. tool_choice="required") when present. + let has_tool_choice = request_body + .get("tool_choice") + .is_some_and(|v| !v.is_null()); + if !has_tool_choice { + request_body["tool_choice"] = serde_json::Value::String("auto".to_string()); + } } } @@ -508,6 +830,153 @@ impl AIClient { request_body } + /// Build a Gemini-format request body. 
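+    ///
+    /// Illustrative shape of the assembled body (a sketch; generation settings are
+    /// collected under `generationConfig`, and `extra_body` overrides are merged last):
+    ///
+    /// ```ignore
+    /// serde_json::json!({
+    ///     "contents": contents,
+    ///     "systemInstruction": { "parts": [{ "text": "..." }] },   // optional
+    ///     "generationConfig": {
+    ///         "maxOutputTokens": 4096,
+    ///         "temperature": 0.2,
+    ///         "topP": 0.8,
+    ///         "thinkingConfig": { "includeThoughts": true }
+    ///     },
+    ///     "tools": [ /* native tools or functionDeclarations */ ],
+    ///     "toolConfig": { "functionCallingConfig": { "mode": "AUTO" } }
+    /// })
+    /// ```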
+ fn build_gemini_request_body( + &self, + system_instruction: Option, + contents: Vec, + gemini_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "contents": contents, + }); + + if let Some(system_instruction) = system_instruction { + request_body["systemInstruction"] = system_instruction; + } + + if let Some(max_tokens) = self.config.max_tokens { + Self::insert_gemini_generation_field( + &mut request_body, + "maxOutputTokens", + serde_json::json!(max_tokens), + ); + } + + if let Some(temperature) = self.config.temperature { + Self::insert_gemini_generation_field( + &mut request_body, + "temperature", + serde_json::json!(temperature), + ); + } + + if let Some(top_p) = self.config.top_p { + Self::insert_gemini_generation_field(&mut request_body, "topP", serde_json::json!(top_p)); + } + + if self.config.enable_thinking_process { + Self::insert_gemini_generation_field(&mut request_body, "thinkingConfig", serde_json::json!({ + "includeThoughts": true, + })); + } + + if let Some(tools) = gemini_tools { + let tool_names = tools + .iter() + .flat_map(|tool| { + if let Some(declarations) = + tool.get("functionDeclarations").and_then(|value| value.as_array()) + { + declarations + .iter() + .filter_map(|declaration| { + declaration + .get("name") + .and_then(|value| value.as_str()) + .map(str::to_string) + }) + .collect::>() + } else { + tool.as_object() + .into_iter() + .flat_map(|map| map.keys().cloned()) + .collect::>() + } + }) + .collect::>(); + debug!(target: "ai::gemini_stream_request", "\ntools: {:?}", tool_names); + + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + let has_function_declarations = request_body["tools"] + .as_array() + .map(|tools| { + tools.iter() + .any(|tool| tool.get("functionDeclarations").is_some()) + }) + .unwrap_or(false); + + if has_function_declarations { + request_body["toolConfig"] = serde_json::json!({ + "functionCallingConfig": { + "mode": "AUTO" + } + }); + } + } + } + + if let Some(extra) = extra_body { + if let Some(mut extra_obj) = extra.as_object().cloned() { + Self::translate_gemini_extra_body(&mut request_body, &mut extra_obj); + let override_keys = extra_obj.keys().cloned().collect::>(); + + for (key, value) in extra_obj { + if let Some(request_obj) = request_body.as_object_mut() { + let target = request_obj + .entry(key) + .or_insert(serde_json::Value::Null); + Self::merge_json_value(target, value); + } + } + debug!( + target: "ai::gemini_stream_request", + "Applied extra_body overrides: {:?}", + override_keys + ); + } + } + + debug!( + target: "ai::gemini_stream_request", + "Gemini stream request body:\n{}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "serialization failed".to_string()) + ); + + request_body + } + + fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/'); + if trimmed.is_empty() { + return String::new(); + } + + let mut url = trimmed + .replace(":generateContent", ":streamGenerateContent") + .replace(":streamGenerateContent?alt=sse", ":streamGenerateContent"); + + if !url.contains(":streamGenerateContent") { + if url.contains("/models/") { + url = format!("{}:streamGenerateContent", url); + } else { + let encoded_model = urlencoding::encode(model_name); + url = format!("{}/models/{}:streamGenerateContent", url, encoded_model); + } + } + + if url.contains("alt=sse") { + url + } else if url.contains('?') { + format!("{}&alt=sse", url) + } else 
{ + format!("{}?alt=sse", url) + } + } + fn extract_openai_tool_name(tool: &serde_json::Value) -> String { tool.get("function") .and_then(|f| f.get("name")) @@ -555,6 +1024,14 @@ impl AIClient { self.send_openai_stream(messages, tools, extra_body, max_tries) .await } + format if Self::is_gemini_api_format(format) => { + self.send_gemini_stream(messages, tools, extra_body, max_tries) + .await + } + format if Self::is_responses_api_format(format) => { + self.send_responses_stream(messages, tools, extra_body, max_tries) + .await + } "anthropic" => { self.send_anthropic_stream(messages, tools, extra_body, max_tries) .await @@ -696,6 +1173,251 @@ impl AIClient { Err(anyhow!(error_msg)) } + /// Send a Gemini streaming request with retries. + async fn send_gemini_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = Self::resolve_gemini_request_url(&self.config.request_url, &self.config.model); + debug!( + "Gemini config: model={}, request_url={}, max_tries={}", + self.config.model, url, max_tries + ); + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, &self.config.model); + let gemini_tools = GeminiMessageConverter::convert_tools(tools); + let request_body = + self.build_gemini_request_body(system_instruction, contents, gemini_tools, extra_body); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + let request_builder = self.apply_gemini_headers(self.client.post(&url)); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!( + "Gemini Streaming API client error {}: {}", + status, error_text + ); + return Err(anyhow!( + "Gemini Streaming API client error {}: {}", + status, + error_text + )); + } + + if status.is_success() { + debug!( + "Gemini stream request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = + anyhow!("Gemini Streaming API error {}: {}", status, error_text); + warn!( + "Gemini stream request failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying Gemini after {}ms (attempt {})", + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Gemini stream request connection failed: {}", e); + warn!( + "Gemini stream request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying Gemini after {}ms (attempt {})", + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + 
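+                    // Backoff doubles per attempt (500ms * 2^min(attempt, 3)), capping
+                    // the wait at 4s before falling through to the next retry.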
continue; + } + }; + + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_gemini_stream(response, tx, Some(tx_raw))); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Gemini stream request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + + /// Send a Responses API streaming request with retries. + async fn send_responses_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = self.config.request_url.clone(); + debug!( + "Responses config: model={}, request_url={}, max_tries={}", + self.config.model, self.config.request_url, max_tries + ); + + let (instructions, response_input) = + OpenAIMessageConverter::convert_messages_to_responses_input(messages); + let openai_tools = OpenAIMessageConverter::convert_tools(tools); + let request_body = + self.build_responses_request_body(instructions, response_input, openai_tools, extra_body); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + let request_builder = self.apply_openai_headers(self.client.post(&url)); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!("Responses API client error {}: {}", status, error_text); + return Err(anyhow!("Responses API client error {}: {}", status, error_text)); + } + + if status.is_success() { + debug!( + "Responses request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = anyhow!("Responses API error {}: {}", status, error_text); + warn!( + "Responses request failed (attempt {}/{}): {}", + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Responses request connection failed: {}", e); + warn!( + "Responses request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_responses_stream(response, tx, Some(tx_raw))); + + return Ok(StreamResponse { + stream: 
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Responses request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + /// Send an Anthropic streaming request with retries /// /// # Parameters @@ -861,6 +1583,8 @@ impl AIClient { let mut full_text = String::new(); let mut full_reasoning = String::new(); let mut finish_reason = None; + let mut usage = None; + let mut provider_metadata: Option = None; let mut tool_calls: Vec = Vec::new(); let mut cur_tool_call_id = String::new(); @@ -882,13 +1606,36 @@ impl AIClient { finish_reason = Some(finish_reason_); } + if let Some(chunk_usage) = chunk.usage { + usage = Some(Self::unified_usage_to_gemini_usage(chunk_usage)); + } + + if let Some(chunk_provider_metadata) = chunk.provider_metadata { + match provider_metadata.as_mut() { + Some(existing) => { + Self::merge_json_value(existing, chunk_provider_metadata); + } + None => provider_metadata = Some(chunk_provider_metadata), + } + } + if let Some(tool_call) = chunk.tool_call { if let Some(tool_call_id) = tool_call.id { if !tool_call_id.is_empty() { - cur_tool_call_id = tool_call_id; - cur_tool_call_name = tool_call.name.unwrap_or_default(); - json_checker.reset(); - debug!("[send_message] Detected tool call: {}", cur_tool_call_name); + // Some providers repeat the tool id on every delta. Only reset when the id changes. + let is_new_tool = cur_tool_call_id != tool_call_id; + if is_new_tool { + cur_tool_call_id = tool_call_id; + cur_tool_call_name = tool_call.name.unwrap_or_default(); + json_checker.reset(); + debug!( + "[send_message] Detected tool call: {}", + cur_tool_call_name + ); + } else if cur_tool_call_name.is_empty() { + // Best-effort: keep name if provider repeats it. + cur_tool_call_name = tool_call.name.unwrap_or_default(); + } } } @@ -940,8 +1687,9 @@ impl AIClient { text: full_text, reasoning_content, tool_calls: tool_calls_result, - usage: None, + usage, finish_reason, + provider_metadata, }; Ok(response) @@ -950,7 +1698,12 @@ impl AIClient { pub async fn test_connection(&self) -> Result { let start_time = std::time::Instant::now(); - let test_messages = vec![Message::user("What's the weather in Beijing?".to_string())]; + // Force a tool call to avoid false negatives: some models may answer directly when + // `tool_choice=auto`, even if they support tool calls. + let test_messages = vec![Message::user( + "Call the get_weather tool for city=Beijing. Do not answer with plain text." 
+ .to_string(), + )]; let tools = Some(vec![ToolDefinition { name: "get_weather".to_string(), description: "Get the weather of a city".to_string(), @@ -964,7 +1717,16 @@ impl AIClient { }), }]); - match self.send_message(test_messages, tools).await { + let extra_body = self.build_test_connection_extra_body(); + + let result = if extra_body.is_some() { + self.send_message_with_extra_body(test_messages, tools, extra_body) + .await + } else { + self.send_message(test_messages, tools).await + }; + + match result { Ok(response) => { let response_time_ms = start_time.elapsed().as_millis() as u64; if response.tool_calls.is_some() { @@ -979,7 +1741,9 @@ impl AIClient { success: false, response_time_ms, model_response: Some(response.text), - error_details: Some("Model does not support tool calls".to_string()), + error_details: Some( + "Model did not return tool calls (tool_choice=required).".to_string(), + ), }) } } @@ -1080,3 +1844,194 @@ impl AIClient { } } } + +#[cfg(test)] +mod tests { + use super::AIClient; + use crate::infrastructure::ai::providers::gemini::GeminiMessageConverter; + use crate::util::types::{AIConfig, ToolDefinition}; + use serde_json::json; + + fn make_test_client(format: &str, custom_request_body: Option) -> AIClient { + AIClient::new(AIConfig { + name: "test".to_string(), + base_url: "https://example.com/v1".to_string(), + request_url: "https://example.com/v1/chat/completions".to_string(), + api_key: "test-key".to_string(), + model: "test-model".to_string(), + format: format.to_string(), + context_window: 128000, + max_tokens: Some(8192), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body, + }) + } + + #[test] + fn build_test_connection_extra_body_merges_custom_body_defaults() { + let client = make_test_client( + "responses", + Some(json!({ + "metadata": { + "source": "test" + } + })), + ); + + let extra_body = client + .build_test_connection_extra_body() + .expect("extra body"); + + assert_eq!(extra_body["metadata"]["source"], "test"); + assert_eq!(extra_body["temperature"], 0); + assert_eq!(extra_body["tool_choice"], "required"); + } + + #[test] + fn build_test_connection_extra_body_preserves_existing_tool_choice() { + let client = make_test_client( + "response", + Some(json!({ + "tool_choice": "auto", + "temperature": 0.3 + })), + ); + + let extra_body = client + .build_test_connection_extra_body() + .expect("extra body"); + + assert_eq!(extra_body["tool_choice"], "auto"); + assert_eq!(extra_body["temperature"], 0.3); + } + + #[test] + fn build_gemini_request_body_translates_response_format_and_merges_generation_config() { + let client = AIClient::new(AIConfig { + name: "gemini".to_string(), + base_url: "https://example.com".to_string(), + request_url: "https://example.com/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + .to_string(), + api_key: "test-key".to_string(), + model: "gemini-2.5-pro".to_string(), + format: "gemini".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: Some(0.2), + top_p: Some(0.8), + enable_thinking_process: true, + support_preserved_thinking: true, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + let request_body = client.build_gemini_request_body( + None, + vec![json!({ + "role": "user", + "parts": [{ "text": "hello" }] + })], + None, + 
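+            // extra_body mixes OpenAI-style knobs (response_format, stop) with a raw
+            // generationConfig override; all of them should end up folded into
+            // generationConfig, as asserted below.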
Some(json!({ + "response_format": { + "type": "json_schema", + "json_schema": { + "schema": { + "type": "object", + "properties": { + "answer": { "type": "string" } + }, + "required": ["answer"], + "additionalProperties": false + } + } + }, + "stop": ["END"], + "generationConfig": { + "candidateCount": 1 + } + })), + ); + + assert_eq!(request_body["generationConfig"]["maxOutputTokens"], 4096); + assert_eq!(request_body["generationConfig"]["temperature"], 0.2); + assert_eq!(request_body["generationConfig"]["topP"], 0.8); + assert_eq!( + request_body["generationConfig"]["thinkingConfig"]["includeThoughts"], + true + ); + assert_eq!( + request_body["generationConfig"]["responseMimeType"], + "application/json" + ); + assert_eq!(request_body["generationConfig"]["candidateCount"], 1); + assert_eq!(request_body["generationConfig"]["stopSequences"], json!(["END"])); + assert_eq!( + request_body["generationConfig"]["responseJsonSchema"]["required"], + json!(["answer"]) + ); + assert!(request_body["generationConfig"]["responseJsonSchema"] + .get("additionalProperties") + .is_none()); + assert!(request_body.get("response_format").is_none()); + assert!(request_body.get("stop").is_none()); + } + + #[test] + fn build_gemini_request_body_omits_function_calling_config_for_native_only_tools() { + let client = AIClient::new(AIConfig { + name: "gemini".to_string(), + base_url: "https://example.com".to_string(), + request_url: "https://example.com/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + .to_string(), + api_key: "test-key".to_string(), + model: "gemini-2.5-pro".to_string(), + format: "gemini".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: true, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + let gemini_tools = GeminiMessageConverter::convert_tools(Some(vec![ToolDefinition { + name: "WebSearch".to_string(), + description: "Search the web".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "query": { "type": "string" } + } + }), + }])); + + let request_body = client.build_gemini_request_body( + None, + vec![json!({ + "role": "user", + "parts": [{ "text": "hello" }] + })], + gemini_tools, + None, + ); + + assert_eq!(request_body["tools"][0]["googleSearch"], json!({})); + assert!(request_body.get("toolConfig").is_none()); + } +} diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs new file mode 100644 index 00000000..70000f97 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs @@ -0,0 +1,902 @@ +//! 
Gemini message format converter + +use crate::util::types::{Message, ToolDefinition}; +use log::warn; +use serde_json::{json, Map, Value}; + +pub struct GeminiMessageConverter; + +impl GeminiMessageConverter { + pub fn convert_messages(messages: Vec, model_name: &str) -> (Option, Vec) { + let mut system_texts = Vec::new(); + let mut contents = Vec::new(); + let is_gemini_3 = model_name.contains("gemini-3"); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if let Some(content) = msg.content.filter(|content| !content.trim().is_empty()) + { + system_texts.push(content); + } + } + "user" => { + let parts = Self::convert_content_parts(msg.content.as_deref(), false); + Self::push_content(&mut contents, "user", parts); + } + "assistant" => { + let mut parts = Vec::new(); + + let mut pending_thought_signature = msg + .thinking_signature + .filter(|value| !value.trim().is_empty()); + let has_tool_calls = msg + .tool_calls + .as_ref() + .map(|tool_calls| !tool_calls.is_empty()) + .unwrap_or(false); + + if let Some(content) = msg.content.as_deref().filter(|value| !value.trim().is_empty()) { + if !has_tool_calls { + if let Some(signature) = pending_thought_signature.take() { + parts.push(json!({ + "thoughtSignature": signature, + })); + } + } + parts.extend(Self::convert_content_parts(Some(content), true)); + } + + if let Some(tool_calls) = msg.tool_calls { + for (tool_call_index, tool_call) in tool_calls.into_iter().enumerate() { + let mut part = Map::new(); + part.insert( + "functionCall".to_string(), + json!({ + "name": tool_call.name, + "args": tool_call.arguments, + }), + ); + + match pending_thought_signature.take() { + Some(signature) => { + part.insert( + "thoughtSignature".to_string(), + Value::String(signature), + ); + } + None if is_gemini_3 && tool_call_index == 0 => { + part.insert( + "thoughtSignature".to_string(), + Value::String( + "skip_thought_signature_validator".to_string(), + ), + ); + } + None => {} + } + + parts.push(Value::Object(part)); + } + } + + if let Some(signature) = pending_thought_signature { + parts.push(json!({ + "thoughtSignature": signature, + })); + } + + Self::push_content(&mut contents, "model", parts); + } + "tool" => { + let tool_name = msg.name.unwrap_or_default(); + if tool_name.is_empty() { + warn!("Skipping Gemini tool response without tool name"); + continue; + } + + let response = Self::parse_tool_response(msg.content.as_deref()); + let parts = vec![json!({ + "functionResponse": { + "name": tool_name, + "response": response, + } + })]; + + Self::push_content(&mut contents, "user", parts); + } + _ => { + warn!("Unknown Gemini message role: {}", msg.role); + } + } + } + + let system_instruction = if system_texts.is_empty() { + None + } else { + Some(json!({ + "parts": [{ + "text": system_texts.join("\n\n") + }] + })) + }; + + (system_instruction, contents) + } + + pub fn convert_tools(tools: Option>) -> Option> { + tools.and_then(|tool_defs| { + let mut native_tools = Vec::new(); + let mut custom_tools = Vec::new(); + + for tool in tool_defs { + if let Some(native_tool) = Self::convert_native_tool(&tool) { + native_tools.push(native_tool); + } else { + custom_tools.push(tool); + } + } + + // Gemini providers such as AIHubMix reject requests that mix built-in tools + // with custom function declarations. When custom tools are present, keep all + // tools in function-calling mode so BitFun's local tool pipeline still works. 
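+            // For example, [WebSearch, get_weather] collapses into one
+            // functionDeclarations entry carrying both tools, with WebSearch
+            // re-expressed via its fallback schema instead of the native
+            // googleSearch tool (see the mixed-tools test below).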
+ let should_fallback_to_function_calling = + !native_tools.is_empty() && !custom_tools.is_empty(); + + let declarations: Vec = if should_fallback_to_function_calling { + custom_tools + .into_iter() + .chain( + native_tools + .iter() + .cloned() + .filter_map(Self::convert_native_tool_to_custom_definition), + ) + .map(Self::convert_custom_tool) + .collect() + } else { + custom_tools + .into_iter() + .map(Self::convert_custom_tool) + .collect() + }; + + let mut result_tools = if should_fallback_to_function_calling { + Vec::new() + } else { + native_tools + }; + + if !declarations.is_empty() { + result_tools.push(json!({ + "functionDeclarations": declarations, + })); + } + + if result_tools.is_empty() { + None + } else { + Some(result_tools) + } + }) + } + + pub fn sanitize_schema(value: Value) -> Value { + Self::strip_unsupported_schema_fields(value) + } + + fn convert_native_tool(tool: &ToolDefinition) -> Option { + let native_name = Self::native_tool_name(&tool.name)?; + let config = Self::native_tool_config(&tool.parameters); + Some(json!({ + native_name: config, + })) + } + + fn convert_native_tool_to_custom_definition(native_tool: Value) -> Option { + let map = native_tool.as_object()?; + let (name, _config) = map.iter().next()?; + + Some(ToolDefinition { + name: Self::native_tool_fallback_name(name).to_string(), + description: Self::native_tool_fallback_description(name).to_string(), + parameters: Self::native_tool_fallback_schema(name), + }) + } + + fn convert_custom_tool(tool: ToolDefinition) -> Value { + let parameters = Self::sanitize_schema(tool.parameters); + json!({ + "name": tool.name, + "description": tool.description, + "parameters": parameters, + }) + } + + fn native_tool_name(tool_name: &str) -> Option<&'static str> { + match tool_name { + "WebSearch" | "googleSearch" | "GoogleSearch" => Some("googleSearch"), + "WebFetch" | "urlContext" | "UrlContext" | "URLContext" => Some("urlContext"), + "googleSearchRetrieval" | "GoogleSearchRetrieval" => Some("googleSearchRetrieval"), + "codeExecution" | "CodeExecution" => Some("codeExecution"), + _ => None, + } + } + + fn native_tool_fallback_name(native_name: &str) -> &'static str { + match native_name { + "googleSearch" => "WebSearch", + "urlContext" => "WebFetch", + "googleSearchRetrieval" => "googleSearchRetrieval", + "codeExecution" => "codeExecution", + _ => "unknown_native_tool", + } + } + + fn native_tool_fallback_description(native_name: &str) -> &'static str { + match native_name { + "googleSearch" => "Search the web for up-to-date information.", + "urlContext" => "Fetch content from a URL for context.", + "googleSearchRetrieval" => "Retrieve grounded results from Google Search.", + "codeExecution" => "Execute model-generated code and return the result.", + _ => "Gemini native tool fallback.", + } + } + + fn native_tool_fallback_schema(native_name: &str) -> Value { + match native_name { + "googleSearch" | "googleSearchRetrieval" => json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + } + }, + "required": ["query"] + }), + "urlContext" => json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + } + }, + "required": ["url"] + }), + "codeExecution" => json!({ + "type": "object", + "properties": {} + }), + _ => json!({ + "type": "object", + "properties": {} + }), + } + } + + fn native_tool_config(parameters: &Value) -> Value { + if Self::looks_like_schema(parameters) { + json!({}) + } else { + match parameters { + Value::Object(map) if !map.is_empty() => parameters.clone(), + _ 
=> json!({}), + } + } + } + + fn looks_like_schema(parameters: &Value) -> bool { + let Some(map) = parameters.as_object() else { + return false; + }; + + map.contains_key("type") + || map.contains_key("properties") + || map.contains_key("required") + || map.contains_key("$schema") + || map.contains_key("items") + || map.contains_key("allOf") + || map.contains_key("anyOf") + || map.contains_key("oneOf") + || map.contains_key("enum") + || map.contains_key("nullable") + || map.contains_key("format") + } + + fn push_content(contents: &mut Vec, role: &str, parts: Vec) { + if parts.is_empty() { + return; + } + + if let Some(last) = contents.last_mut() { + let last_role = last.get("role").and_then(Value::as_str).unwrap_or_default(); + if last_role == role { + if let Some(existing_parts) = last.get_mut("parts").and_then(Value::as_array_mut) { + existing_parts.extend(parts); + return; + } + } + } + + contents.push(json!({ + "role": role, + "parts": parts, + })); + } + + fn convert_content_parts(content: Option<&str>, is_model_role: bool) -> Vec { + let Some(content) = content else { + return Vec::new(); + }; + + if content.trim().is_empty() { + return Vec::new(); + } + + let parsed = match serde_json::from_str::(content) { + Ok(parsed) if parsed.is_array() => parsed, + _ => return vec![json!({ "text": content })], + }; + + let mut parts = Vec::new(); + + if let Some(items) = parsed.as_array() { + for item in items { + let item_type = item.get("type").and_then(Value::as_str); + match item_type { + Some("text") | Some("input_text") | Some("output_text") => { + if let Some(text) = item.get("text").and_then(Value::as_str) { + if !text.is_empty() { + parts.push(json!({ "text": text })); + } + } + } + Some("image_url") if !is_model_role => { + if let Some(url) = item.get("image_url").and_then(|value| { + value + .get("url") + .and_then(Value::as_str) + .or_else(|| value.as_str()) + }) { + if let Some(part) = Self::convert_image_url_to_part(url) { + parts.push(part); + } + } + } + Some("image") if !is_model_role => { + let source = item.get("source"); + let mime_type = source + .and_then(|value| value.get("media_type")) + .and_then(Value::as_str); + let data = source + .and_then(|value| value.get("data")) + .and_then(Value::as_str); + + if let (Some(mime_type), Some(data)) = (mime_type, data) { + parts.push(json!({ + "inlineData": { + "mimeType": mime_type, + "data": data, + } + })); + } + } + _ => {} + } + } + } + + if parts.is_empty() { + vec![json!({ "text": content })] + } else { + parts + } + } + + fn convert_image_url_to_part(url: &str) -> Option { + let prefix = "data:"; + if !url.starts_with(prefix) { + warn!("Gemini currently supports inline data URLs for image parts; skipping unsupported image URL"); + return None; + } + + let rest = &url[prefix.len()..]; + let (mime_type, data) = rest.split_once(";base64,")?; + if mime_type.is_empty() || data.is_empty() { + return None; + } + + Some(json!({ + "inlineData": { + "mimeType": mime_type, + "data": data, + } + })) + } + + fn parse_tool_response(content: Option<&str>) -> Value { + let Some(content) = content.filter(|value| !value.trim().is_empty()) else { + return json!({ "content": "Tool execution completed" }); + }; + + match serde_json::from_str::(content) { + Ok(Value::Object(map)) => Value::Object(map), + Ok(value) => json!({ "content": value }), + Err(_) => json!({ "content": content }), + } + } + + fn strip_unsupported_schema_fields(value: Value) -> Value { + match value { + Value::Object(mut map) => { + let all_of = map.remove("allOf"); + let 
any_of = map.remove("anyOf"); + let one_of = map.remove("oneOf"); + let (normalized_type, nullable_from_type) = + Self::normalize_schema_type(map.remove("type")); + + let mut sanitized = Map::new(); + for (key, value) in map { + if key == "properties" { + if let Value::Object(properties) = value { + sanitized.insert( + key, + Value::Object( + properties + .into_iter() + .map(|(name, schema)| { + (name, Self::strip_unsupported_schema_fields(schema)) + }) + .collect(), + ), + ); + } + continue; + } + + if Self::is_supported_schema_key(&key) { + sanitized.insert(key, Self::strip_unsupported_schema_fields(value)); + } + } + + if let Some(all_of) = all_of { + Self::merge_schema_variants(&mut sanitized, all_of, true); + } + + let mut nullable = nullable_from_type; + if let Some(any_of) = any_of { + nullable |= Self::merge_union_variants(&mut sanitized, any_of); + } + if let Some(one_of) = one_of { + nullable |= Self::merge_union_variants(&mut sanitized, one_of); + } + + if let Some(schema_type) = normalized_type { + sanitized.insert("type".to_string(), Value::String(schema_type)); + } + if nullable { + sanitized.insert("nullable".to_string(), Value::Bool(true)); + } + + Value::Object(sanitized) + } + Value::Array(items) => Value::Array( + items + .into_iter() + .map(Self::strip_unsupported_schema_fields) + .collect(), + ), + other => other, + } + } + + fn is_supported_schema_key(key: &str) -> bool { + matches!( + key, + "type" + | "format" + | "description" + | "nullable" + | "enum" + | "items" + | "properties" + | "required" + | "minItems" + | "maxItems" + | "minimum" + | "maximum" + | "minLength" + | "maxLength" + | "pattern" + ) + } + + fn normalize_schema_type(type_value: Option) -> (Option, bool) { + match type_value { + Some(Value::String(value)) if value != "null" => (Some(value), false), + Some(Value::String(_)) => (None, true), + Some(Value::Array(values)) => { + let mut types = values.into_iter().filter_map(|value| value.as_str().map(str::to_string)); + let mut nullable = false; + let mut selected = None; + + for value in types.by_ref() { + if value == "null" { + nullable = true; + } else if selected.is_none() { + selected = Some(value); + } + } + + (selected, nullable) + } + _ => (None, false), + } + } + + fn merge_union_variants(target: &mut Map, variants: Value) -> bool { + let mut nullable = false; + + if let Value::Array(variants) = variants { + for variant in variants { + let sanitized = Self::strip_unsupported_schema_fields(variant); + match sanitized { + Value::Object(map) => { + let is_null_only = map + .get("type") + .and_then(Value::as_str) + .map(|value| value == "null") + .unwrap_or(false) + && map.len() == 1; + + if is_null_only { + nullable = true; + continue; + } + + Self::merge_schema_map(target, map, false); + } + Value::String(value) if value == "null" => nullable = true, + _ => {} + } + } + } + + nullable + } + + fn merge_schema_variants(target: &mut Map, variants: Value, preserve_required: bool) { + if let Value::Array(variants) = variants { + for variant in variants { + if let Value::Object(map) = Self::strip_unsupported_schema_fields(variant) { + Self::merge_schema_map(target, map, preserve_required); + } + } + } + } + + fn merge_schema_map( + target: &mut Map, + source: Map, + preserve_required: bool, + ) { + for (key, value) in source { + match key.as_str() { + "properties" => { + if let Value::Object(source_props) = value { + let target_props = target + .entry(key) + .or_insert_with(|| Value::Object(Map::new())); + if let Value::Object(target_props) = 
target_props { + for (prop_key, prop_value) in source_props { + target_props.entry(prop_key).or_insert(prop_value); + } + } + } + } + "required" if preserve_required => { + if let Value::Array(source_required) = value { + let target_required = target + .entry(key) + .or_insert_with(|| Value::Array(Vec::new())); + if let Value::Array(target_required) = target_required { + for item in source_required { + if !target_required.contains(&item) { + target_required.push(item); + } + } + } + } + } + "nullable" => { + if value.as_bool().unwrap_or(false) { + target.insert(key, Value::Bool(true)); + } + } + "type" => { + target.entry(key).or_insert(value); + } + _ => { + target.entry(key).or_insert(value); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::GeminiMessageConverter; + use crate::util::types::{Message, ToolCall, ToolDefinition}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn converts_messages_to_gemini_format() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Beijing")); + + let messages = vec![ + Message::system("You are helpful".to_string()), + Message::user("Hello".to_string()), + Message { + role: "assistant".to_string(), + content: Some("Working on it".to_string()), + reasoning_content: Some("Let me think".to_string()), + thinking_signature: Some("sig_1".to_string()), + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args.clone(), + }]), + tool_call_id: None, + name: None, + }, + Message { + role: "tool".to_string(), + content: Some("Sunny".to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + name: Some("get_weather".to_string()), + }, + ]; + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro"); + + assert_eq!( + system_instruction.unwrap()["parts"][0]["text"], + json!("You are helpful") + ); + assert_eq!(contents.len(), 3); + assert_eq!(contents[0]["role"], json!("user")); + assert_eq!(contents[1]["role"], json!("model")); + assert_eq!(contents[1]["parts"][0]["text"], json!("Working on it")); + assert_eq!( + contents[1]["parts"][1]["functionCall"]["name"], + json!("get_weather") + ); + assert_eq!(contents[1]["parts"][1]["thoughtSignature"], json!("sig_1")); + assert_eq!( + contents[2]["parts"][0]["functionResponse"]["name"], + json!("get_weather") + ); + } + + #[test] + fn injects_skip_signature_for_first_synthetic_gemini_3_tool_call() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Paris")); + + let messages = vec![Message { + role: "assistant".to_string(), + content: None, + reasoning_content: None, + thinking_signature: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args, + }]), + tool_call_id: None, + name: None, + }]; + + let (_, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-3-flash-preview"); + + assert_eq!(contents.len(), 1); + assert_eq!( + contents[0]["parts"][0]["thoughtSignature"], + json!("skip_thought_signature_validator") + ); + } + + #[test] + fn converts_data_url_images_to_inline_data() { + let messages = vec![Message { + role: "user".to_string(), + content: Some( + json!([ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,abc" + } + }, + { + "type": "text", + "text": "Describe this image" + } + ]) + .to_string(), + ), + reasoning_content: None, + 
+            thinking_signature: None,
+            tool_calls: None,
+            tool_call_id: None,
+            name: None,
+        }];
+
+        let (_, contents) = GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro");
+
+        assert_eq!(
+            contents[0]["parts"][0]["inlineData"]["mimeType"],
+            json!("image/png")
+        );
+        assert_eq!(
+            contents[0]["parts"][1]["text"],
+            json!("Describe this image")
+        );
+    }
+
+    #[test]
+    fn strips_unsupported_fields_from_tool_schema() {
+        let tools = Some(vec![ToolDefinition {
+            name: "get_weather".to_string(),
+            description: "Get weather".to_string(),
+            parameters: json!({
+                "$schema": "http://json-schema.org/draft-07/schema#",
+                "type": "object",
+                "properties": {
+                    "city": { "type": "string" },
+                    "timezone": {
+                        "type": ["string", "null"]
+                    },
+                    "link": {
+                        "anyOf": [
+                            {
+                                "type": "object",
+                                "properties": {
+                                    "url": { "type": "string" }
+                                },
+                                "required": ["url"]
+                            },
+                            { "type": "null" }
+                        ]
+                    },
+                    "items": {
+                        "allOf": [
+                            {
+                                "type": "object",
+                                "properties": {
+                                    "name": { "type": "string" }
+                                },
+                                "required": ["name"]
+                            },
+                            {
+                                "type": "object",
+                                "properties": {
+                                    "count": { "type": "integer" }
+                                },
+                                "required": ["count"]
+                            }
+                        ]
+                    }
+                },
+                "required": ["city"],
+                "additionalProperties": false,
+                "items": {
+                    "type": "object",
+                    "additionalProperties": false
+                }
+            }),
+        }]);
+
+        let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools");
+        let schema = &converted[0]["functionDeclarations"][0]["parameters"];
+
+        assert!(schema.get("$schema").is_none());
+        assert!(schema.get("additionalProperties").is_none());
+        assert!(schema["items"].get("additionalProperties").is_none());
+        assert_eq!(schema["properties"]["timezone"]["type"], json!("string"));
+        assert_eq!(schema["properties"]["timezone"]["nullable"], json!(true));
+        assert_eq!(schema["properties"]["link"]["type"], json!("object"));
+        assert_eq!(schema["properties"]["link"]["nullable"], json!(true));
+        assert_eq!(schema["properties"]["items"]["type"], json!("object"));
+        assert_eq!(
+            schema["properties"]["items"]["required"],
+            json!(["name", "count"])
+        );
+    }
+
+    #[test]
+    fn maps_web_search_to_native_google_search_tool() {
+        let tools = Some(vec![ToolDefinition {
+            name: "WebSearch".to_string(),
+            description: "Search the web".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "query": { "type": "string" }
+                },
+                "required": ["query"]
+            }),
+        }]);
+
+        let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools");
+        assert_eq!(converted.len(), 1);
+        assert_eq!(converted[0]["googleSearch"], json!({}));
+        assert!(converted[0].get("functionDeclarations").is_none());
+    }
+
+    #[test]
+    fn falls_back_to_function_declarations_when_native_and_custom_tools_mix() {
+        let tools = Some(vec![
+            ToolDefinition {
+                name: "WebSearch".to_string(),
+                description: "Search the web".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "query": { "type": "string" }
+                    }
+                }),
+            },
+            ToolDefinition {
+                name: "get_weather".to_string(),
+                description: "Get weather".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "city": { "type": "string" }
+                    },
+                    "required": ["city"]
+                }),
+            },
+        ]);
+
+        let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools");
+        assert_eq!(converted.len(), 1);
+        assert!(converted[0].get("googleSearch").is_none());
+        assert_eq!(
+            converted[0]["functionDeclarations"][0]["name"],
+            json!("get_weather")
+        );
+        assert_eq!(
+            converted[0]["functionDeclarations"][1]["name"],
json!("WebSearch") + ); + } + + #[test] + fn maps_web_fetch_to_native_url_context_tool() { + let tools = Some(vec![ToolDefinition { + name: "WebFetch".to_string(), + description: "Fetch a URL".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "url": { "type": "string" } + }, + "required": ["url"] + }), + }]); + + let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0]["urlContext"], json!({})); + } +} diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs new file mode 100644 index 00000000..ee6d89d2 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs @@ -0,0 +1,5 @@ +//! Gemini provider module + +pub mod message_converter; + +pub use message_converter::GeminiMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/mod.rs b/src/crates/core/src/infrastructure/ai/providers/mod.rs index 61ce45c6..d0e806ae 100644 --- a/src/crates/core/src/infrastructure/ai/providers/mod.rs +++ b/src/crates/core/src/infrastructure/ai/providers/mod.rs @@ -4,6 +4,7 @@ pub mod openai; pub mod anthropic; +pub mod gemini; pub use anthropic::AnthropicMessageConverter; - +pub use gemini::GeminiMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs index 7c04e443..0eb1de14 100644 --- a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs +++ b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs @@ -7,12 +7,156 @@ use serde_json::{json, Value}; pub struct OpenAIMessageConverter; impl OpenAIMessageConverter { + pub fn convert_messages_to_responses_input(messages: Vec) -> (Option, Vec) { + let mut instructions = Vec::new(); + let mut input = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if let Some(content) = msg.content.filter(|content| !content.trim().is_empty()) { + instructions.push(content); + } + } + "tool" => { + if let Some(tool_item) = Self::convert_tool_message_to_responses_item(msg) { + input.push(tool_item); + } + } + "assistant" => { + if let Some(content_items) = Self::convert_message_content_to_responses_items(&msg.role, msg.content.as_deref()) { + input.push(json!({ + "type": "message", + "role": "assistant", + "content": content_items, + })); + } + + if let Some(tool_calls) = msg.tool_calls { + for tool_call in tool_calls { + input.push(json!({ + "type": "function_call", + "call_id": tool_call.id, + "name": tool_call.name, + "arguments": serde_json::to_string(&tool_call.arguments) + .unwrap_or_else(|_| "{}".to_string()), + })); + } + } + } + role => { + if let Some(content_items) = Self::convert_message_content_to_responses_items(role, msg.content.as_deref()) { + input.push(json!({ + "type": "message", + "role": role, + "content": content_items, + })); + } + } + } + } + + let instructions = if instructions.is_empty() { + None + } else { + Some(instructions.join("\n\n")) + }; + + (instructions, input) + } + pub fn convert_messages(messages: Vec) -> Vec { messages.into_iter() .map(Self::convert_single_message) .collect() } + fn convert_tool_message_to_responses_item(msg: Message) -> Option { + let call_id = msg.tool_call_id?; + let output = msg.content.unwrap_or_else(|| "Tool execution completed".to_string()); + + Some(json!({ + "type": 
"function_call_output", + "call_id": call_id, + "output": output, + })) + } + + fn convert_message_content_to_responses_items(role: &str, content: Option<&str>) -> Option> { + let content = content?; + let text_item_type = Self::responses_text_item_type(role); + + if content.trim().is_empty() { + return Some(vec![json!({ + "type": text_item_type, + "text": " ", + })]); + } + + let parsed = match serde_json::from_str::(content) { + Ok(parsed) if parsed.is_array() => parsed, + _ => { + return Some(vec![json!({ + "type": text_item_type, + "text": content, + })]); + } + }; + + let mut content_items = Vec::new(); + + if let Some(items) = parsed.as_array() { + for item in items { + let item_type = item.get("type").and_then(Value::as_str); + match item_type { + Some("text") | Some("input_text") | Some("output_text") => { + if let Some(text) = item.get("text").and_then(Value::as_str) { + content_items.push(json!({ + "type": text_item_type, + "text": text, + })); + } + } + Some("image_url") if role != "assistant" => { + let image_url = item + .get("image_url") + .and_then(|value| { + value + .get("url") + .and_then(Value::as_str) + .or_else(|| value.as_str()) + }); + + if let Some(image_url) = image_url { + content_items.push(json!({ + "type": "input_image", + "image_url": image_url, + })); + } + } + _ => {} + } + } + } + + if content_items.is_empty() { + Some(vec![json!({ + "type": text_item_type, + "text": content, + })]) + } else { + Some(content_items) + } + } + + fn responses_text_item_type(role: &str) -> &'static str { + if role == "assistant" { + "output_text" + } else { + "input_text" + } + } + fn convert_single_message(msg: Message) -> Value { let mut openai_msg = json!({ "role": msg.role, @@ -125,3 +269,73 @@ impl OpenAIMessageConverter { } } +#[cfg(test)] +mod tests { + use super::OpenAIMessageConverter; + use crate::util::types::{Message, ToolCall}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn converts_messages_to_responses_input() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Beijing")); + + let messages = vec![ + Message::system("You are helpful".to_string()), + Message::user("Hello".to_string()), + Message::assistant_with_tools(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args.clone(), + }]), + Message { + role: "tool".to_string(), + content: Some("Sunny".to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + name: Some("get_weather".to_string()), + }, + ]; + + let (instructions, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages); + + assert_eq!(instructions.as_deref(), Some("You are helpful")); + assert_eq!(input.len(), 3); + assert_eq!(input[0]["type"], json!("message")); + assert_eq!(input[1]["type"], json!("function_call")); + assert_eq!(input[2]["type"], json!("function_call_output")); + } + + #[test] + fn converts_openai_style_image_content_to_responses_input() { + let messages = vec![Message { + role: "user".to_string(), + content: Some(json!([ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,abc" + } + }, + { + "type": "text", + "text": "Describe this image" + } + ]).to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: None, + name: None, + }]; + + let (_, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages); + let content = 
+        assert_eq!(content[0]["type"], json!("input_image"));
+        assert_eq!(content[1]["type"], json!("input_text"));
+    }
+}
diff --git a/src/crates/core/src/service/config/types.rs b/src/crates/core/src/service/config/types.rs
index 49783e2b..3da58ca6 100644
--- a/src/crates/core/src/service/config/types.rs
+++ b/src/crates/core/src/service/config/types.rs
@@ -745,6 +745,11 @@ pub struct AIModelConfig {
     #[serde(default)]
     pub skip_ssl_verify: bool,
 
+    /// Reasoning effort level for OpenAI Responses API (o-series / GPT-5+).
+    /// Valid values: "low", "medium", "high", "xhigh". None = use API default.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<String>,
+
     /// Custom request body (JSON string, used to override default request body fields).
     #[serde(default)]
     pub custom_request_body: Option<String>,
@@ -1133,6 +1138,7 @@ impl Default for AIModelConfig {
             custom_headers: None,
             custom_headers_mode: None,
             skip_ssl_verify: false,
+            reasoning_effort: None,
             custom_request_body: None,
         }
     }
diff --git a/src/crates/core/src/util/types/ai.rs b/src/crates/core/src/util/types/ai.rs
index 5def0f43..0cab0dbf 100644
--- a/src/crates/core/src/util/types/ai.rs
+++ b/src/crates/core/src/util/types/ai.rs
@@ -1,4 +1,5 @@
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 
 /// Gemini API response
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -12,6 +13,8 @@ pub struct GeminiResponse {
     pub usage: Option<GeminiUsage>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider_metadata: Option<Value>,
 }
 
 /// Gemini usage stats
@@ -23,6 +26,9 @@ pub struct GeminiUsage {
     pub candidates_token_count: u32,
     #[serde(rename = "totalTokenCount")]
     pub total_token_count: u32,
+    #[serde(rename = "reasoningTokenCount")]
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_token_count: Option<u32>,
     #[serde(rename = "cachedContentTokenCount")]
     #[serde(skip_serializing_if = "Option::is_none")]
     pub cached_content_token_count: Option<u32>,
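For context, a minimal round-trip sketch of the extended usage struct (not part of the patch; values are illustrative, and the prompt-count field is assumed to follow the same camelCase rename pattern as the fields shown above):

use serde_json::json;

// `reasoningTokenCount` is optional, so older responses without it still parse.
let usage: GeminiUsage = serde_json::from_value(json!({
    "promptTokenCount": 120,
    "candidatesTokenCount": 45,
    "reasoningTokenCount": 30,
    "totalTokenCount": 195
})).unwrap();
assert_eq!(usage.reasoning_token_count, Some(30));
assert_eq!(usage.cached_content_token_count, None);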
diff --git a/src/crates/core/src/util/types/config.rs b/src/crates/core/src/util/types/config.rs
index bc0212fd..9158e6e5 100644
--- a/src/crates/core/src/util/types/config.rs
+++ b/src/crates/core/src/util/types/config.rs
@@ -13,7 +13,42 @@ fn append_endpoint(base_url: &str, endpoint: &str) -> String {
     format!("{}/{}", base.trim_end_matches('/'), endpoint)
 }
 
-fn resolve_request_url(base_url: &str, provider: &str) -> String {
+fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String {
+    let trimmed = base_url.trim().trim_end_matches('/').to_string();
+    if trimmed.is_empty() {
+        return String::new();
+    }
+
+    if let Some(stripped) = trimmed.strip_suffix('#') {
+        return stripped.trim_end_matches('/').to_string();
+    }
+
+    let stream_endpoint = ":streamGenerateContent?alt=sse";
+    if trimmed.contains(":generateContent") {
+        return trimmed.replace(":generateContent", stream_endpoint);
+    }
+    if trimmed.contains(":streamGenerateContent") {
+        if trimmed.contains("alt=sse") {
+            return trimmed;
+        }
+        if trimmed.contains('?') {
+            return format!("{}&alt=sse", trimmed);
+        }
+        return format!("{}?alt=sse", trimmed);
+    }
+    if trimmed.contains("/models/") {
+        return format!("{}{}", trimmed, stream_endpoint);
+    }
+
+    let model = model_name.trim();
+    if model.is_empty() {
+        return trimmed;
+    }
+
+    append_endpoint(&trimmed, &format!("models/{}{}", model, stream_endpoint))
+}
+
+fn resolve_request_url(base_url: &str, provider: &str, model_name: &str) -> String {
     let trimmed = base_url.trim().trim_end_matches('/').to_string();
     if trimmed.is_empty() {
         return String::new();
@@ -25,7 +60,9 @@ fn resolve_request_url(base_url: &str, provider: &str) -> String {
 
     match provider.trim().to_ascii_lowercase().as_str() {
         "openai" => append_endpoint(&trimmed, "chat/completions"),
+        "response" | "responses" => append_endpoint(&trimmed, "responses"),
         "anthropic" => append_endpoint(&trimmed, "v1/messages"),
+        "gemini" | "google" => resolve_gemini_request_url(&trimmed, model_name),
         _ => trimmed,
     }
 }
@@ -43,16 +80,69 @@ pub struct AIConfig {
     pub format: String,
     pub context_window: u32,
     pub max_tokens: Option<u32>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
     pub enable_thinking_process: bool,
     pub support_preserved_thinking: bool,
     pub custom_headers: Option<HashMap<String, String>>,
     /// "replace" (default) or "merge" (defaults first, then custom)
     pub custom_headers_mode: Option<String>,
     pub skip_ssl_verify: bool,
+    /// Reasoning effort for OpenAI Responses API ("low", "medium", "high", "xhigh")
+    pub reasoning_effort: Option<String>,
     /// Custom JSON overriding default request body fields
     pub custom_request_body: Option<Value>,
 }
 
+#[cfg(test)]
+mod tests {
+    use super::resolve_request_url;
+
+    #[test]
+    fn resolves_openai_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "openai", ""),
+            "https://api.openai.com/v1/chat/completions"
+        );
+    }
+
+    #[test]
+    fn resolves_responses_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "responses", ""),
+            "https://api.openai.com/v1/responses"
+        );
+    }
+
+    #[test]
+    fn resolves_response_alias_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "response", ""),
+            "https://api.openai.com/v1/responses"
+        );
+    }
+
+    #[test]
+    fn keeps_forced_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1/responses#", "responses", ""),
+            "https://api.openai.com/v1/responses"
+        );
+    }
+
+    #[test]
+    fn resolves_gemini_request_url() {
+        assert_eq!(
+            resolve_request_url(
+                "https://generativelanguage.googleapis.com/v1beta",
+                "gemini",
+                "gemini-2.5-pro"
+            ),
+            "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse"
+        );
+    }
+}
+
 impl TryFrom<AIModelConfig> for AIConfig {
     type Error = String;
     fn try_from(other: AIModelConfig) -> Result<Self, <Self as TryFrom<AIModelConfig>>::Error> {
@@ -73,7 +163,7 @@ impl TryFrom<AIModelConfig> for AIConfig {
         let request_url = other
             .request_url
             .filter(|u| !u.is_empty())
-            .unwrap_or_else(|| resolve_request_url(&other.base_url, &other.provider));
+            .unwrap_or_else(|| resolve_request_url(&other.base_url, &other.provider, &other.model_name));
 
         Ok(AIConfig {
             name: other.name.clone(),
@@ -84,11 +174,14 @@ impl TryFrom<AIModelConfig> for AIConfig {
             format: other.provider.clone(),
             context_window: other.context_window.unwrap_or(128128),
             max_tokens: other.max_tokens,
+            temperature: other.temperature,
+            top_p: other.top_p,
             enable_thinking_process: other.enable_thinking_process,
             support_preserved_thinking: other.support_preserved_thinking,
             custom_headers: other.custom_headers,
             custom_headers_mode: other.custom_headers_mode,
             skip_ssl_verify: other.skip_ssl_verify,
+            reasoning_effort: other.reasoning_effort,
             custom_request_body,
         })
     }
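One more illustrative sketch (not from the patch) of how reasoning_effort and the new model-aware URL resolution travel through the AIModelConfig -> AIConfig conversion; it assumes the fields not shown here keep their Default values and that no additional validation in try_from rejects them:

let model = AIModelConfig {
    provider: "responses".to_string(),
    base_url: "https://api.openai.com/v1".to_string(),
    reasoning_effort: Some("high".to_string()),
    ..AIModelConfig::default()
};
let ai = AIConfig::try_from(model).expect("conversion should succeed");
assert_eq!(ai.request_url, "https://api.openai.com/v1/responses");
assert_eq!(ai.reasoning_effort.as_deref(), Some("high"));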
diff --git a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx
index 2e95a4d7..132cbe0f 100644
--- a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx
+++ b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx
@@ -19,7 +19,7 @@ interface ModelConfigStepProps {
 }
 
 /** Provider display order */
-const PROVIDER_ORDER = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'anthropic'];
+const PROVIDER_ORDER = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'gemini', 'anthropic'];
 
 type TestStatus = 'idle' | 'testing' | 'success' | 'error';
 
@@ -33,8 +33,8 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
   const [apiKey, setApiKey] = useState(modelConfig?.apiKey || '');
   const [baseUrl, setBaseUrl] = useState(modelConfig?.baseUrl || '');
   const [modelName, setModelName] = useState(modelConfig?.modelName || '');
-  const [customFormat, setCustomFormat] = useState<'openai' | 'anthropic'>(
-    (modelConfig?.format as 'openai' | 'anthropic') || 'openai'
+  const [customFormat, setCustomFormat] = useState<'openai' | 'responses' | 'anthropic' | 'gemini'>(
+    (modelConfig?.format as 'openai' | 'responses' | 'anthropic' | 'gemini') || 'openai'
   );
   const [testStatus, setTestStatus] = useState<TestStatus>('idle');
   const [testError, setTestError] = useState('');
@@ -120,7 +120,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
     const effectiveModelName = modelName || (template?.models[0] || '');
 
     // Derive format
-    let format: 'openai' | 'anthropic' = customFormat;
+    let format: 'openai' | 'responses' | 'anthropic' | 'gemini' = customFormat;
     if (template) {
       if (template.baseUrlOptions?.length) {
         const effectiveUrl = baseUrl || template.baseUrl;
@@ -501,10 +501,12 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
           label={t('model.format.label')}
           options={[
             { label: 'OpenAI', value: 'openai' },
-            { label: 'Anthropic', value: 'anthropic' }
+            { label: 'OpenAI Responses', value: 'responses' },
+            { label: 'Anthropic', value: 'anthropic' },
+            { label: 'Gemini', value: 'gemini' }
           ]}
           value={customFormat}
-          onChange={(val) => setCustomFormat(val as 'openai' | 'anthropic')}
+          onChange={(val) => setCustomFormat(val as 'openai' | 'responses' | 'anthropic' | 'gemini')}
           placeholder={t('model.format.placeholder')}
         />
diff --git a/src/web-ui/src/features/onboarding/store/onboardingStore.ts b/src/web-ui/src/features/onboarding/store/onboardingStore.ts
index be1f6796..1451b378 100644
--- a/src/web-ui/src/features/onboarding/store/onboardingStore.ts
+++ b/src/web-ui/src/features/onboarding/store/onboardingStore.ts
@@ -37,7 +37,7 @@ export interface OnboardingModelConfig {
   modelName?: string;
   testPassed?: boolean;
   // Fields needed for saving the model config on completion
-  format?: 'openai' | 'anthropic';
+  format?: 'openai' | 'responses' | 'anthropic' | 'gemini';
   configName?: string;
   customRequestBody?: string;
   skipSslVerify?: boolean;
diff --git a/src/web-ui/src/flow_chat/components/ModelSelector.scss b/src/web-ui/src/flow_chat/components/ModelSelector.scss
index ea21ce15..0607bd41 100644
--- a/src/web-ui/src/flow_chat/components/ModelSelector.scss
+++ b/src/web-ui/src/flow_chat/components/ModelSelector.scss
@@ -90,6 +90,21 @@
     letter-spacing: 0.5px;
     line-height: 1;
   }
+
+  &__effort-badge {
+    display: inline-flex;
+    align-items: center;
+    justify-content: center;
+    padding: 1px 4px;
+    border-radius: 3px;
+    background: rgba(100, 200, 160, 0.15);
+    color: rgba(100, 200, 160, 0.9);
+    font-size: 7px;
+    font-weight: 600;
+    letter-spacing: 0.3px;
+    line-height: 1;
+    text-transform: uppercase;
+  }
 
   &__chevron {
     color: var(--color-text-tertiary);
diff --git a/src/web-ui/src/flow_chat/components/ModelSelector.tsx b/src/web-ui/src/flow_chat/components/ModelSelector.tsx
index bae057c4..93ce497b 100644
--- a/src/web-ui/src/flow_chat/components/ModelSelector.tsx
+++ b/src/web-ui/src/flow_chat/components/ModelSelector.tsx
@@ -35,6 +35,7 @@ interface ModelInfo {
   provider: string;
   contextWindow?: number;
   enableThinking?: boolean;
+  reasoningEffort?: string;
 }
 
 // Helper: identify special model IDs.
@@ -137,7 +138,8 @@ export const ModelSelector: React.FC = ({
       modelName: model.model_name,
       provider: model.provider,
       contextWindow: model.context_window,
-      enableThinking: model.enable_thinking_process
+      enableThinking: model.enable_thinking_process,
+      reasoningEffort: model.reasoning_effort,
     };
   }
 
@@ -150,7 +152,8 @@ export const ModelSelector: React.FC = ({
       modelName: model.model_name,
       provider: model.provider,
      contextWindow: model.context_window,
-      enableThinking: model.enable_thinking_process
+      enableThinking: model.enable_thinking_process,
+      reasoningEffort: model.reasoning_effort,
     };
   }, [getCurrentModelId, allModels, defaultModels]);
 
@@ -168,7 +171,8 @@ export const ModelSelector: React.FC = ({
       modelName: m.model_name,
       provider: m.provider,
       contextWindow: m.context_window,
-      enableThinking: m.enable_thinking_process
+      enableThinking: m.enable_thinking_process,
+      reasoningEffort: m.reasoning_effort,
     }));
   }, [allModels]);
 
@@ -222,6 +226,11 @@ export const ModelSelector: React.FC = ({
         {currentModel?.enableThinking && (
           
         )}
+        {currentModel?.reasoningEffort && (
+          <span className="model-selector__effort-badge">
+            {currentModel.reasoningEffort}
+          </span>
+        )}
 
@@ -291,6 +300,7 @@ export const ModelSelector: React.FC = ({
             {model.modelName}
             {model.contextWindow && ` · ${Math.round(model.contextWindow / 1000)}k`}
+            {model.reasoningEffort && ` · ${model.reasoningEffort}`}
 
           {isSelected && (
diff --git a/src/web-ui/src/infrastructure/api/service-api/AgentAPI.ts b/src/web-ui/src/infrastructure/api/service-api/AgentAPI.ts
index aca0b2bc..fcf37805 100644
--- a/src/web-ui/src/infrastructure/api/service-api/AgentAPI.ts
+++ b/src/web-ui/src/infrastructure/api/service-api/AgentAPI.ts
@@ -185,6 +185,7 @@ export class AgentAPI {
     }
   }
 
+
   async listSessions(): Promise {
     try {
diff --git a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
index ed067090..63285157 100644
--- a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
+++ b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
@@ -20,15 +20,21 @@ import './AIModelConfig.scss';
 
 const log = createLogger('AIModelConfig');
 
+function isResponsesProvider(provider?: string): boolean {
+  return provider === 'response' || provider === 'responses';
+}
+
 /**
  * Compute the actual request URL from a base URL and provider format.
  * Rules:
  * - Ends with '#' → strip '#', use as-is (force override)
  * - openai → append '/chat/completions' unless already present
+ * - responses → append '/responses' unless already present
  * - anthropic → append '/v1/messages' unless already present
+ * - gemini → append '/models/{model}:streamGenerateContent?alt=sse'
  * - other → use base_url as-is
  */
-function resolveRequestUrl(baseUrl: string, provider: string): string {
+function resolveRequestUrl(baseUrl: string, provider: string, modelName = ''): string {
   const trimmed = baseUrl.trim().replace(/\/+$/, '');
   if (trimmed.endsWith('#')) {
     return trimmed.slice(0, -1).replace(/\/+$/, '');
@@ -36,9 +42,25 @@ function resolveRequestUrl(baseUrl: string, provider: string): string {
   if (provider === 'openai') {
     return trimmed.endsWith('chat/completions') ? trimmed : `${trimmed}/chat/completions`;
   }
+  if (isResponsesProvider(provider)) {
+    return trimmed.endsWith('responses') ? trimmed : `${trimmed}/responses`;
+  }
   if (provider === 'anthropic') {
     return trimmed.endsWith('v1/messages') ? trimmed : `${trimmed}/v1/messages`;
   }
+  if (provider === 'gemini') {
+    if (!modelName.trim()) return trimmed;
+    if (trimmed.includes(':generateContent')) {
+      return trimmed.replace(':generateContent', ':streamGenerateContent?alt=sse');
+    }
+    if (trimmed.includes(':streamGenerateContent')) {
+      return trimmed.includes('alt=sse') ? trimmed : `${trimmed}${trimmed.includes('?') ? '&' : '?'}alt=sse`;
+    }
+    if (trimmed.includes('/models/')) {
+      return `${trimmed}:streamGenerateContent?alt=sse`;
+    }
+    return `${trimmed}/models/${modelName}:streamGenerateContent?alt=sse`;
+  }
   return trimmed;
 }
 
@@ -69,6 +91,26 @@
   });
   const [isProxySaving, setIsProxySaving] = useState(false);
 
+  const requestFormatOptions = useMemo(
+    () => [
+      { label: 'OpenAI (chat/completions)', value: 'openai' },
+      { label: 'OpenAI (responses)', value: 'responses' },
+      { label: 'Anthropic (messages)', value: 'anthropic' },
+      { label: 'Gemini (generateContent)', value: 'gemini' },
+    ],
+    []
+  );
+
+  const reasoningEffortOptions = useMemo(
+    () => [
+      { label: 'Low', value: 'low' },
+      { label: 'Medium', value: 'medium' },
+      { label: 'High', value: 'high' },
+      { label: 'Extra High', value: 'xhigh' },
+    ],
+    []
+  );
+
   useEffect(() => {
     loadConfig();
@@ -88,7 +130,7 @@
   };
 
   // Provider options with translations (must be at top level, before any conditional returns)
-  const providerOrder = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'anthropic'];
+  const providerOrder = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'gemini', 'anthropic'];
   const providers = useMemo(() => {
     const sorted = Object.values(PROVIDER_TEMPLATES).sort((a, b) => {
       const indexA = providerOrder.indexOf(a.id);
@@ -141,7 +183,7 @@
     setEditingConfig({
       name: defaultModel ? `${providerName} - ${defaultModel}` : '',
       base_url: template.baseUrl,
-      request_url: resolveRequestUrl(template.baseUrl, template.format),
+      request_url: resolveRequestUrl(template.baseUrl, template.format, defaultModel),
       api_key: '',
       model_name: defaultModel,
       provider: template.format,
@@ -216,7 +258,7 @@
       id: editingConfig.id || `model_${Date.now()}`,
       name: editingConfig.name,
       base_url: editingConfig.base_url,
-      request_url: editingConfig.request_url || resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai'),
+      request_url: editingConfig.request_url || resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || ''),
       api_key: editingConfig.api_key || '',
       model_name: editingConfig.model_name || 'search-api',
       provider: editingConfig.provider || 'openai',
@@ -234,6 +276,8 @@
       enable_thinking_process: editingConfig.enable_thinking_process ?? false,
       support_preserved_thinking: editingConfig.support_preserved_thinking ?? false,
+
+      reasoning_effort: editingConfig.reasoning_effort,
       custom_headers: editingConfig.custom_headers,
@@ -542,17 +586,17 @@
       case 'general_chat':
         defaultCapabilities = ['text_chat', 'function_calling'];
         updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions';
-        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai');
+        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || '');
         break;
       case 'multimodal':
         defaultCapabilities = ['text_chat', 'image_understanding', 'function_calling'];
         updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions';
-        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai');
+        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || '');
         break;
       case 'image_generation':
         defaultCapabilities = ['image_generation'];
         updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/images/generations';
-        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai');
+        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || '');
         break;
       case 'search_enhanced':
         defaultCapabilities = ['search'];
@@ -566,7 +610,7 @@
       case 'speech_recognition':
         defaultCapabilities = ['speech_recognition'];
         updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions';
-        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai');
+        updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || '');
         break;
     }
     updates.capabilities = defaultCapabilities;
@@ -601,6 +645,7 @@
     return {
       ...prev,
       model_name: newModelName,
+      request_url: resolveRequestUrl(prev?.base_url || currentTemplate?.baseUrl || '', prev?.provider || currentTemplate?.format || 'openai', newModelName),
      name: isAutoGenerated && currentTemplate ? `${currentTemplate.name} - ${newModelName}` : prev?.name
     };
   });
@@ -629,7 +674,7 @@
           setEditingConfig(prev => ({
             ...prev,
             base_url: value as string,
-            request_url: resolveRequestUrl(value as string, newProvider),
+            request_url: resolveRequestUrl(value as string, newProvider, editingConfig.model_name || ''),
             provider: newProvider
           }));
         }}
@@ -644,7 +689,7 @@
         onChange={(e) => setEditingConfig(prev => ({
           ...prev,
           base_url: e.target.value,
-          request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai')
+          request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai', prev?.model_name || '')
         }))}
         onFocus={(e) => e.target.select()}
         placeholder={currentTemplate?.baseUrl}
@@ -654,7 +699,7 @@
             {t('form.resolvedUrlLabel')}
-            {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai')}
+            {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || '')}
             {t('form.forceUrlHint')}
@@ -662,16 +707,16 @@
           )}
-
+          {isResponsesProvider(editingConfig.provider) && (
+            <Select value={editingConfig.reasoning_effort || ''} onChange={(v) => setEditingConfig(prev => ({ ...prev, reasoning_effort: (v as string) || undefined }))} placeholder={t('reasoningEffort.placeholder')} options={reasoningEffortOptions} />
+          )}
         ) : (
           <>
@@ -706,7 +756,7 @@
         onChange={(e) => setEditingConfig(prev => ({
           ...prev,
           base_url: e.target.value,
-          request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai')
+          request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai', prev?.model_name || '')
         }))}
         onFocus={(e) => e.target.select()}
         placeholder={editingConfig.category === 'search_enhanced' ? 'https://open.bigmodel.cn/api/paas/v4/web_search' : 'https://open.bigmodel.cn/api/paas/v4/chat/completions'}
@@ -716,7 +766,7 @@
             {t('form.resolvedUrlLabel')}
-            {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai')}
+            {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || '')}
             {t('form.forceUrlHint')}
@@ -732,10 +782,18 @@
       {!isFromTemplate && editingConfig.category !== 'search_enhanced' && (
         <>
-          <Input value={editingConfig.model_name || ''} onChange={(e) => setEditingConfig(prev => ({ ...prev, model_name: e.target.value }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" />
+          <Input value={editingConfig.model_name || ''} onChange={(e) => setEditingConfig(prev => ({ ...prev, model_name: e.target.value, request_url: resolveRequestUrl(prev?.base_url || '', prev?.provider || 'openai', e.target.value) }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" />
 
-
-
+          <Select value={editingConfig.provider || 'openai'} onChange={(value) => {
+            const provider = value as string;
+            setEditingConfig(prev => ({
+              ...prev,
+              provider,
+              request_url: resolveRequestUrl(prev?.base_url || '', provider, prev?.model_name || ''),
+              reasoning_effort: isResponsesProvider(provider) ? (prev?.reasoning_effort || 'medium') : undefined,
+            }));
+          }} placeholder={t('form.providerPlaceholder')} options={requestFormatOptions} />
 
       {editingConfig.category !== 'speech_recognition' && (
         <>
@@ -752,6 +810,11 @@
           <Checkbox checked={editingConfig.enable_thinking_process ?? false} onChange={(e) => setEditingConfig(prev => ({ ...prev, enable_thinking_process: e.target.checked }))} size="small" />
+          {isResponsesProvider(editingConfig.provider) && (
+
+