diff --git a/src-tauri/src/commands/preferences.rs b/src-tauri/src/commands/preferences.rs index fdffaee5..2e70eade 100644 --- a/src-tauri/src/commands/preferences.rs +++ b/src-tauri/src/commands/preferences.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; + use serde::{Deserialize, Serialize}; use crate::config_io::{read_json, write_json}; @@ -89,6 +92,19 @@ pub fn get_zeroclaw_usage_stats() -> Result }) } +#[tauri::command] +pub fn get_session_usage_stats(session_id: String) -> Result { + let stats = crate::runtime::zeroclaw::process::get_session_usage(&session_id); + Ok(ZeroclawUsageStatsResponse { + total_calls: stats.total_calls, + usage_calls: stats.usage_calls, + prompt_tokens: stats.prompt_tokens, + completion_tokens: stats.completion_tokens, + total_tokens: stats.total_tokens, + last_updated_ms: stats.last_updated_ms, + }) +} + #[tauri::command] pub fn get_zeroclaw_runtime_target() -> Result { let target = crate::runtime::zeroclaw::process::get_zeroclaw_runtime_target(); @@ -101,6 +117,52 @@ pub fn get_zeroclaw_runtime_target() -> Result &'static Mutex> { + static STORE: OnceLock>> = OnceLock::new(); + STORE.get_or_init(|| Mutex::new(HashMap::new())) +} + +/// Look up a session model override without going through Tauri command dispatch. +pub fn lookup_session_model_override(session_id: &str) -> Option { + session_model_overrides() + .lock() + .ok()? + .get(session_id) + .cloned() +} + +#[tauri::command] +pub fn set_session_model_override(session_id: String, model: String) -> Result<(), String> { + let trimmed = model.trim().to_string(); + if trimmed.is_empty() { + return Err("model must not be empty".into()); + } + if let Ok(mut map) = session_model_overrides().lock() { + map.insert(session_id, trimmed); + } + Ok(()) +} + +#[tauri::command] +pub fn get_session_model_override(session_id: String) -> Result, String> { + let map = session_model_overrides() + .lock() + .map_err(|e| e.to_string())?; + Ok(map.get(&session_id).cloned()) +} + +#[tauri::command] +pub fn clear_session_model_override(session_id: String) -> Result<(), String> { + if let Ok(mut map) = session_model_overrides().lock() { + map.remove(&session_id); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 15aa5b09..226fa5ca 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -8,12 +8,13 @@ use crate::cli_runner::{ }; use crate::commands::{ analyze_sessions, apply_config_patch, backup_before_upgrade, chat_via_openclaw, - check_openclaw_update, clear_all_sessions, connect_docker_instance, connect_local_instance, - connect_ssh_instance, create_agent, delete_agent, delete_backup, delete_cron_job, - delete_local_instance_home, delete_model_profile, delete_registered_instance, - delete_sessions_by_ids, delete_ssh_host, deploy_watchdog, diagnose_primary_via_rescue, - discover_local_instances, ensure_access_profile, extract_model_profiles_from_config, - fix_issues, get_app_preferences, get_cached_model_catalog, get_cron_runs, get_status_extra, + check_openclaw_update, clear_all_sessions, clear_session_model_override, + connect_docker_instance, connect_local_instance, connect_ssh_instance, create_agent, + delete_agent, delete_backup, delete_cron_job, delete_local_instance_home, delete_model_profile, + delete_registered_instance, delete_sessions_by_ids, delete_ssh_host, deploy_watchdog, + diagnose_primary_via_rescue, discover_local_instances, ensure_access_profile, + extract_model_profiles_from_config, fix_issues, get_app_preferences, get_cached_model_catalog, + get_cron_runs, get_session_model_override, get_session_usage_stats, get_status_extra, get_status_light, get_system_status, get_watchdog_status, get_zeroclaw_runtime_target, get_zeroclaw_usage_stats, list_agents_overview, list_backups, list_bindings, list_channels_minimal, list_cron_jobs, list_discord_guild_channels, list_history, @@ -42,11 +43,11 @@ use crate::commands::{ remote_upsert_model_profile, remote_write_raw_config, repair_primary_via_rescue, resolve_api_keys, resolve_provider_auth, restart_gateway, restore_from_backup, rollback, run_doctor_command, run_openclaw_upgrade, set_active_clawpal_data_dir, - set_active_openclaw_home, set_agent_model, set_global_model, set_zeroclaw_model_preference, - setup_agent_identity, sftp_list_dir, sftp_read_file, sftp_remove_file, sftp_write_file, - ssh_connect, ssh_connect_with_passphrase, ssh_disconnect, ssh_exec, ssh_status, start_watchdog, - stop_watchdog, test_model_profile, trigger_cron_job, uninstall_watchdog, upsert_model_profile, - upsert_ssh_host, + set_active_openclaw_home, set_agent_model, set_global_model, set_session_model_override, + set_zeroclaw_model_preference, setup_agent_identity, sftp_list_dir, sftp_read_file, + sftp_remove_file, sftp_write_file, ssh_connect, ssh_connect_with_passphrase, ssh_disconnect, + ssh_exec, ssh_status, start_watchdog, stop_watchdog, test_model_profile, trigger_cron_job, + uninstall_watchdog, upsert_model_profile, upsert_ssh_host, }; use crate::doctor_commands::{ collect_doctor_context, collect_doctor_context_remote, doctor_approve_invoke, doctor_connect, @@ -59,6 +60,7 @@ use crate::install::commands::{ use crate::install::session_store::InstallSessionStore; use crate::install_commands::{install_send_message, install_start_session}; use crate::node_client::NodeClient; +use crate::runtime::zeroclaw::cost::estimate_query_cost; use crate::ssh::SshConnectionPool; pub mod access_discovery; @@ -119,7 +121,12 @@ pub fn run() { get_status_extra, get_app_preferences, get_zeroclaw_usage_stats, + get_session_usage_stats, get_zeroclaw_runtime_target, + set_session_model_override, + get_session_model_override, + clear_session_model_override, + estimate_query_cost, list_recipes, list_model_profiles, get_cached_model_catalog, diff --git a/src-tauri/src/runtime/zeroclaw/cost.rs b/src-tauri/src/runtime/zeroclaw/cost.rs new file mode 100644 index 00000000..4f24d6c5 --- /dev/null +++ b/src-tauri/src/runtime/zeroclaw/cost.rs @@ -0,0 +1,78 @@ +use serde::{Deserialize, Serialize}; + +/// Model pricing: (prompt_price_per_1k_tokens, completion_price_per_1k_tokens) +fn model_pricing(model: &str) -> Option<(f64, f64)> { + let lower = model.trim().to_ascii_lowercase(); + // Match by substring to handle provider prefixes like "openrouter/anthropic/claude-3.7-sonnet" + if lower.contains("claude-3.7-sonnet") || lower.contains("claude-3-7-sonnet") { + Some((0.003, 0.015)) + } else if lower.contains("claude-3.5-haiku") || lower.contains("claude-3-5-haiku") { + Some((0.0008, 0.004)) + } else if lower.contains("gpt-4o-mini") { + Some((0.00015, 0.0006)) + } else if lower.contains("gpt-4o") { + Some((0.0025, 0.01)) + } else if lower.contains("gpt-4.1") { + Some((0.002, 0.008)) + } else if lower.contains("gemini-2.0-flash") { + Some((0.0001, 0.0004)) + } else if lower.contains("kimi-k2") { + Some((0.0006, 0.002)) + } else { + None + } +} + +/// Estimate cost in USD for a given model and token counts. +pub fn estimate_cost(model: &str, prompt_tokens: u64, completion_tokens: u64) -> Option { + let (prompt_price, completion_price) = model_pricing(model)?; + let cost = (prompt_tokens as f64 / 1000.0) * prompt_price + + (completion_tokens as f64 / 1000.0) * completion_price; + Some(cost) +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct CostEstimate { + pub model: String, + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub estimated_cost_usd: Option, +} + +#[tauri::command] +pub fn estimate_query_cost( + model: String, + prompt_tokens: u64, + completion_tokens: u64, +) -> Result { + let cost = estimate_cost(&model, prompt_tokens, completion_tokens); + Ok(CostEstimate { + model, + prompt_tokens, + completion_tokens, + estimated_cost_usd: cost, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn estimate_cost_for_known_model() { + let cost = estimate_cost("gpt-4o", 1000, 1000).unwrap(); + assert!((cost - 0.0125).abs() < 0.0001); + } + + #[test] + fn estimate_cost_for_unknown_model() { + assert!(estimate_cost("unknown-model", 1000, 1000).is_none()); + } + + #[test] + fn estimate_cost_with_provider_prefix() { + let cost = estimate_cost("openrouter/anthropic/claude-3.7-sonnet", 1000, 500).unwrap(); + assert!(cost > 0.0); + } +} diff --git a/src-tauri/src/runtime/zeroclaw/mod.rs b/src-tauri/src/runtime/zeroclaw/mod.rs index 5216aa98..06ba3c97 100644 --- a/src-tauri/src/runtime/zeroclaw/mod.rs +++ b/src-tauri/src/runtime/zeroclaw/mod.rs @@ -1,4 +1,5 @@ pub mod adapter; +pub mod cost; pub mod install_adapter; pub mod process; pub mod sanitize; diff --git a/src-tauri/src/runtime/zeroclaw/process.rs b/src-tauri/src/runtime/zeroclaw/process.rs index 9d903808..00030f3e 100644 --- a/src-tauri/src/runtime/zeroclaw/process.rs +++ b/src-tauri/src/runtime/zeroclaw/process.rs @@ -266,6 +266,43 @@ pub fn get_zeroclaw_usage_stats() -> ZeroclawUsageStats { usage_store().lock().map(|stats| *stats).unwrap_or_default() } +// --------------------------------------------------------------------------- +// Per-session usage tracking +// --------------------------------------------------------------------------- + +fn session_usage_store() -> &'static Mutex> { + static STORE: OnceLock>> = + OnceLock::new(); + STORE.get_or_init(|| Mutex::new(std::collections::HashMap::new())) +} + +pub fn record_session_usage(session_id: &str, prompt_tokens: u64, completion_tokens: u64) { + if session_id.is_empty() { + return; + } + if let Ok(mut map) = session_usage_store().lock() { + let stats = map + .entry(session_id.to_string()) + .or_insert_with(ZeroclawUsageStats::default); + stats.total_calls = stats.total_calls.saturating_add(1); + stats.usage_calls = stats.usage_calls.saturating_add(1); + stats.prompt_tokens = stats.prompt_tokens.saturating_add(prompt_tokens); + stats.completion_tokens = stats.completion_tokens.saturating_add(completion_tokens); + stats.total_tokens = stats + .total_tokens + .saturating_add(prompt_tokens.saturating_add(completion_tokens)); + stats.last_updated_ms = now_ms(); + } +} + +pub fn get_session_usage(session_id: &str) -> ZeroclawUsageStats { + session_usage_store() + .lock() + .ok() + .and_then(|map| map.get(session_id).copied()) + .unwrap_or_default() +} + fn sanitize_instance_namespace(raw: &str) -> String { let trimmed = raw.trim(); if trimmed.is_empty() { @@ -804,7 +841,9 @@ pub fn run_zeroclaw_message( "-m".to_string(), message, ]; - let preferred_model = crate::commands::load_zeroclaw_model_preference(); + // Per-session model override takes priority over global preference. + let preferred_model = crate::commands::preferences::lookup_session_model_override(instance_id) + .or_else(|| crate::commands::load_zeroclaw_model_preference()); let provider_order = provider_order_for_runtime(&env_pairs, preferred_model.as_deref()); if provider_order.is_empty() { return Err( @@ -821,7 +860,13 @@ pub fn run_zeroclaw_message( let stdout = sanitize_output(&String::from_utf8_lossy(&output.stdout)); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); record_zeroclaw_usage(&stdout, &stderr); - if parse_usage_from_text(&stdout).is_none() && parse_usage_from_text(&stderr).is_none() { + // Also record per-session usage. + let session_usage = + parse_usage_from_text(&stdout).or_else(|| parse_usage_from_text(&stderr)); + if let Some((prompt, completion, _total)) = session_usage { + record_session_usage(instance_id, prompt, completion); + } + if session_usage.is_none() { if let Ok(mut stats) = usage_store().lock() { if let Some((prompt, completion, total)) = read_usage_from_builtin_traces(&cmd, &cfg, &env_pairs) @@ -831,6 +876,8 @@ pub fn run_zeroclaw_message( stats.completion_tokens = stats.completion_tokens.saturating_add(completion); stats.total_tokens = stats.total_tokens.saturating_add(total); stats.last_updated_ms = now_ms(); + // Record per-session usage from traces as well. + record_session_usage(instance_id, prompt, completion); } } } diff --git a/src/components/ModelSwitcher.tsx b/src/components/ModelSwitcher.tsx new file mode 100644 index 00000000..8c212846 --- /dev/null +++ b/src/components/ModelSwitcher.tsx @@ -0,0 +1,107 @@ +import { useEffect, useState } from "react"; +import { invoke } from "@tauri-apps/api/core"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; + +const AVAILABLE_MODELS = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4.1", + "claude-3.7-sonnet", + "claude-3.5-haiku", + "gemini-2.0-flash", + "kimi-k2.5", +]; + +interface ModelSwitcherProps { + sessionId: string; + defaultModel?: string; + /** Notifies parent when the effective model changes (override set/cleared). */ + onModelChange?: (model: string | undefined) => void; +} + +export function ModelSwitcher({ sessionId, defaultModel, onModelChange }: ModelSwitcherProps) { + const [override, setOverride] = useState(null); + const [open, setOpen] = useState(false); + + useEffect(() => { + if (!sessionId) return; + invoke("get_session_model_override", { sessionId }) + .then(setOverride) + .catch(() => {}); + }, [sessionId]); + + const currentModel = override ?? defaultModel ?? "auto"; + + const handleSelect = async (model: string) => { + try { + await invoke("set_session_model_override", { sessionId, model }); + setOverride(model); + onModelChange?.(model); + } catch { + // silently ignore + } + setOpen(false); + }; + + const handleClear = async () => { + try { + await invoke("clear_session_model_override", { sessionId }); + setOverride(null); + onModelChange?.(undefined); + } catch { + // silently ignore + } + setOpen(false); + }; + + return ( +
+ + + + + +
+ Switch model for this session +
+
+ {AVAILABLE_MODELS.map((model) => ( + + ))} +
+ {override && ( +
+ +
+ )} +
+
+ {override && ( + + Session override + + )} +
+ ); +} diff --git a/src/components/TokenBadge.tsx b/src/components/TokenBadge.tsx new file mode 100644 index 00000000..4869efed --- /dev/null +++ b/src/components/TokenBadge.tsx @@ -0,0 +1,77 @@ +import { useEffect, useState } from "react"; +import { invoke } from "@tauri-apps/api/core"; +import { Badge } from "@/components/ui/badge"; + +interface SessionUsageStats { + totalCalls: number; + usageCalls: number; + promptTokens: number; + completionTokens: number; + totalTokens: number; + lastUpdatedMs: number; +} + +interface CostEstimate { + model: string; + promptTokens: number; + completionTokens: number; + estimatedCostUsd: number | null; +} + +interface TokenBadgeProps { + sessionId: string; + model?: string; +} + +export function TokenBadge({ sessionId, model }: TokenBadgeProps) { + const [stats, setStats] = useState(null); + const [cost, setCost] = useState(null); + + useEffect(() => { + if (!sessionId) return; + + const fetchStats = async () => { + try { + const usage = await invoke("get_session_usage_stats", { + sessionId, + }); + setStats(usage); + + if (model && (usage.promptTokens > 0 || usage.completionTokens > 0)) { + const estimate = await invoke("estimate_query_cost", { + model, + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + }); + setCost(estimate.estimatedCostUsd); + } + } catch { + // silently ignore + } + }; + + fetchStats(); + const interval = setInterval(fetchStats, 5000); + return () => clearInterval(interval); + }, [sessionId, model]); + + if (!stats || stats.totalTokens === 0) return null; + + const formatTokens = (n: number) => { + if (n >= 1000000) return `${(n / 1000000).toFixed(1)}M`; + if (n >= 1000) return `${(n / 1000).toFixed(1)}k`; + return String(n); + }; + + const formatCost = (c: number) => { + if (c < 0.01) return `$${c.toFixed(4)}`; + return `$${c.toFixed(2)}`; + }; + + return ( + + 🪙 {formatTokens(stats.totalTokens)} + {cost !== null && ({formatCost(cost)})} + + ); +} diff --git a/src/pages/Doctor.tsx b/src/pages/Doctor.tsx index 52703671..1dabf5f9 100644 --- a/src/pages/Doctor.tsx +++ b/src/pages/Doctor.tsx @@ -1,4 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react"; +import { invoke } from "@tauri-apps/api/core"; import { useTranslation } from "react-i18next"; import { useApi, hasGuidanceEmitted } from "@/lib/use-api"; import { useInstance } from "@/lib/instance-context"; @@ -35,6 +36,8 @@ import { } from "@/components/ui/alert-dialog"; import { Skeleton } from "@/components/ui/skeleton"; import { DoctorChat } from "@/components/DoctorChat"; +import { TokenBadge } from "@/components/TokenBadge"; +import { ModelSwitcher } from "@/components/ModelSwitcher"; import { SessionAnalysisPanel } from "@/components/SessionAnalysisPanel"; import type { BackupInfo } from "@/lib/types"; import { formatTime, formatBytes } from "@/lib/utils"; @@ -133,6 +136,8 @@ export function Doctor({ const ua = useApi(); const { instanceId, isDocker, isRemote, isConnected } = useInstance(); const doctor = useDoctorAgent(); + const [runtimeModel, setRuntimeModel] = useState(undefined); + const [sessionModelOverride, setSessionModelOverride] = useState(undefined); const [remoteConnState, setRemoteConnState] = useState<"checking" | "connected" | "disconnected">("checking"); const [diagnosing, setDiagnosing] = useState(false); @@ -199,6 +204,30 @@ export function Doctor({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [doctor.setTarget, instanceId, isRemote]); + // Fetch runtime target model for TokenBadge / ModelSwitcher. + useEffect(() => { + invoke<{ model?: string }>("get_zeroclaw_runtime_target") + .then((target) => { + if (target?.model) setRuntimeModel(target.model); + }) + .catch(() => {}); + }, []); + + // Use instanceId as the stable session key for model override / usage tracking. + // This matches the backend which looks up overrides by instance_id. + const doctorSessionId = instanceId || "local"; + + // Track session model override so TokenBadge uses the effective model for cost. + useEffect(() => { + if (!doctorSessionId) return; + invoke("get_session_model_override", { sessionId: doctorSessionId }) + .then((m) => setSessionModelOverride(m ?? undefined)) + .catch(() => {}); + }, [doctorSessionId]); + + // Effective model: session override takes priority over global runtime model. + const effectiveModel = sessionModelOverride ?? runtimeModel; + const handleStartDiagnosis = async (extraContext?: string) => { setStartError(null); setDiagnosing(true); @@ -1030,7 +1059,7 @@ export function Doctor({ )}
-
+
{t("doctor.engineZeroclaw")} @@ -1038,6 +1067,8 @@ export function Doctor({ {doctor.bridgeConnected ? t("doctor.bridgeConnected") : t("doctor.bridgeDisconnected")} + +