From 75675a040fdbcfeb616c7fcd4c72071fe5daf57b Mon Sep 17 00:00:00 2001 From: dev01lay2 Date: Mon, 2 Mar 2026 07:55:55 +0000 Subject: [PATCH 1/5] feat: in-session model switching and cost awareness (#32) - Add per-session model override commands (set/get/clear) in preferences.rs - Add per-session token tracking (record_session_usage, get_session_usage) in process.rs - Add cost estimation module (cost.rs) with hardcoded pricing for common models - Add TokenBadge component showing session token usage and estimated cost - Add ModelSwitcher component for switching models per session - Wire TokenBadge and ModelSwitcher into Doctor page zeroclaw view - Register all new Tauri commands in lib.rs Closes #32 --- src-tauri/src/commands/preferences.rs | 53 +++++++++++ src-tauri/src/lib.rs | 29 +++--- src-tauri/src/runtime/zeroclaw/cost.rs | 78 ++++++++++++++++ src-tauri/src/runtime/zeroclaw/mod.rs | 1 + src-tauri/src/runtime/zeroclaw/process.rs | 37 ++++++++ src/components/ModelSwitcher.tsx | 103 ++++++++++++++++++++++ src/components/TokenBadge.tsx | 77 ++++++++++++++++ src/pages/Doctor.tsx | 6 +- 8 files changed, 372 insertions(+), 12 deletions(-) create mode 100644 src-tauri/src/runtime/zeroclaw/cost.rs create mode 100644 src/components/ModelSwitcher.tsx create mode 100644 src/components/TokenBadge.tsx diff --git a/src-tauri/src/commands/preferences.rs b/src-tauri/src/commands/preferences.rs index fdffaee5..555f5d83 100644 --- a/src-tauri/src/commands/preferences.rs +++ b/src-tauri/src/commands/preferences.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; + use serde::{Deserialize, Serialize}; use crate::config_io::{read_json, write_json}; @@ -89,6 +92,19 @@ pub fn get_zeroclaw_usage_stats() -> Result }) } +#[tauri::command] +pub fn get_session_usage_stats(session_id: String) -> Result { + let stats = crate::runtime::zeroclaw::process::get_session_usage(&session_id); + Ok(ZeroclawUsageStatsResponse { + total_calls: stats.total_calls, + usage_calls: stats.usage_calls, + prompt_tokens: stats.prompt_tokens, + completion_tokens: stats.completion_tokens, + total_tokens: stats.total_tokens, + last_updated_ms: stats.last_updated_ms, + }) +} + #[tauri::command] pub fn get_zeroclaw_runtime_target() -> Result { let target = crate::runtime::zeroclaw::process::get_zeroclaw_runtime_target(); @@ -101,6 +117,43 @@ pub fn get_zeroclaw_runtime_target() -> Result &'static Mutex> { + static STORE: OnceLock>> = OnceLock::new(); + STORE.get_or_init(|| Mutex::new(HashMap::new())) +} + +#[tauri::command] +pub fn set_session_model_override(session_id: String, model: String) -> Result<(), String> { + let trimmed = model.trim().to_string(); + if trimmed.is_empty() { + return Err("model must not be empty".into()); + } + if let Ok(mut map) = session_model_overrides().lock() { + map.insert(session_id, trimmed); + } + Ok(()) +} + +#[tauri::command] +pub fn get_session_model_override(session_id: String) -> Result, String> { + let map = session_model_overrides() + .lock() + .map_err(|e| e.to_string())?; + Ok(map.get(&session_id).cloned()) +} + +#[tauri::command] +pub fn clear_session_model_override(session_id: String) -> Result<(), String> { + if let Ok(mut map) = session_model_overrides().lock() { + map.remove(&session_id); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 15aa5b09..226fa5ca 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -8,12 +8,13 @@ use crate::cli_runner::{ }; use crate::commands::{ analyze_sessions, apply_config_patch, backup_before_upgrade, chat_via_openclaw, - check_openclaw_update, clear_all_sessions, connect_docker_instance, connect_local_instance, - connect_ssh_instance, create_agent, delete_agent, delete_backup, delete_cron_job, - delete_local_instance_home, delete_model_profile, delete_registered_instance, - delete_sessions_by_ids, delete_ssh_host, deploy_watchdog, diagnose_primary_via_rescue, - discover_local_instances, ensure_access_profile, extract_model_profiles_from_config, - fix_issues, get_app_preferences, get_cached_model_catalog, get_cron_runs, get_status_extra, + check_openclaw_update, clear_all_sessions, clear_session_model_override, + connect_docker_instance, connect_local_instance, connect_ssh_instance, create_agent, + delete_agent, delete_backup, delete_cron_job, delete_local_instance_home, delete_model_profile, + delete_registered_instance, delete_sessions_by_ids, delete_ssh_host, deploy_watchdog, + diagnose_primary_via_rescue, discover_local_instances, ensure_access_profile, + extract_model_profiles_from_config, fix_issues, get_app_preferences, get_cached_model_catalog, + get_cron_runs, get_session_model_override, get_session_usage_stats, get_status_extra, get_status_light, get_system_status, get_watchdog_status, get_zeroclaw_runtime_target, get_zeroclaw_usage_stats, list_agents_overview, list_backups, list_bindings, list_channels_minimal, list_cron_jobs, list_discord_guild_channels, list_history, @@ -42,11 +43,11 @@ use crate::commands::{ remote_upsert_model_profile, remote_write_raw_config, repair_primary_via_rescue, resolve_api_keys, resolve_provider_auth, restart_gateway, restore_from_backup, rollback, run_doctor_command, run_openclaw_upgrade, set_active_clawpal_data_dir, - set_active_openclaw_home, set_agent_model, set_global_model, set_zeroclaw_model_preference, - setup_agent_identity, sftp_list_dir, sftp_read_file, sftp_remove_file, sftp_write_file, - ssh_connect, ssh_connect_with_passphrase, ssh_disconnect, ssh_exec, ssh_status, start_watchdog, - stop_watchdog, test_model_profile, trigger_cron_job, uninstall_watchdog, upsert_model_profile, - upsert_ssh_host, + set_active_openclaw_home, set_agent_model, set_global_model, set_session_model_override, + set_zeroclaw_model_preference, setup_agent_identity, sftp_list_dir, sftp_read_file, + sftp_remove_file, sftp_write_file, ssh_connect, ssh_connect_with_passphrase, ssh_disconnect, + ssh_exec, ssh_status, start_watchdog, stop_watchdog, test_model_profile, trigger_cron_job, + uninstall_watchdog, upsert_model_profile, upsert_ssh_host, }; use crate::doctor_commands::{ collect_doctor_context, collect_doctor_context_remote, doctor_approve_invoke, doctor_connect, @@ -59,6 +60,7 @@ use crate::install::commands::{ use crate::install::session_store::InstallSessionStore; use crate::install_commands::{install_send_message, install_start_session}; use crate::node_client::NodeClient; +use crate::runtime::zeroclaw::cost::estimate_query_cost; use crate::ssh::SshConnectionPool; pub mod access_discovery; @@ -119,7 +121,12 @@ pub fn run() { get_status_extra, get_app_preferences, get_zeroclaw_usage_stats, + get_session_usage_stats, get_zeroclaw_runtime_target, + set_session_model_override, + get_session_model_override, + clear_session_model_override, + estimate_query_cost, list_recipes, list_model_profiles, get_cached_model_catalog, diff --git a/src-tauri/src/runtime/zeroclaw/cost.rs b/src-tauri/src/runtime/zeroclaw/cost.rs new file mode 100644 index 00000000..4f24d6c5 --- /dev/null +++ b/src-tauri/src/runtime/zeroclaw/cost.rs @@ -0,0 +1,78 @@ +use serde::{Deserialize, Serialize}; + +/// Model pricing: (prompt_price_per_1k_tokens, completion_price_per_1k_tokens) +fn model_pricing(model: &str) -> Option<(f64, f64)> { + let lower = model.trim().to_ascii_lowercase(); + // Match by substring to handle provider prefixes like "openrouter/anthropic/claude-3.7-sonnet" + if lower.contains("claude-3.7-sonnet") || lower.contains("claude-3-7-sonnet") { + Some((0.003, 0.015)) + } else if lower.contains("claude-3.5-haiku") || lower.contains("claude-3-5-haiku") { + Some((0.0008, 0.004)) + } else if lower.contains("gpt-4o-mini") { + Some((0.00015, 0.0006)) + } else if lower.contains("gpt-4o") { + Some((0.0025, 0.01)) + } else if lower.contains("gpt-4.1") { + Some((0.002, 0.008)) + } else if lower.contains("gemini-2.0-flash") { + Some((0.0001, 0.0004)) + } else if lower.contains("kimi-k2") { + Some((0.0006, 0.002)) + } else { + None + } +} + +/// Estimate cost in USD for a given model and token counts. +pub fn estimate_cost(model: &str, prompt_tokens: u64, completion_tokens: u64) -> Option { + let (prompt_price, completion_price) = model_pricing(model)?; + let cost = (prompt_tokens as f64 / 1000.0) * prompt_price + + (completion_tokens as f64 / 1000.0) * completion_price; + Some(cost) +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct CostEstimate { + pub model: String, + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub estimated_cost_usd: Option, +} + +#[tauri::command] +pub fn estimate_query_cost( + model: String, + prompt_tokens: u64, + completion_tokens: u64, +) -> Result { + let cost = estimate_cost(&model, prompt_tokens, completion_tokens); + Ok(CostEstimate { + model, + prompt_tokens, + completion_tokens, + estimated_cost_usd: cost, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn estimate_cost_for_known_model() { + let cost = estimate_cost("gpt-4o", 1000, 1000).unwrap(); + assert!((cost - 0.0125).abs() < 0.0001); + } + + #[test] + fn estimate_cost_for_unknown_model() { + assert!(estimate_cost("unknown-model", 1000, 1000).is_none()); + } + + #[test] + fn estimate_cost_with_provider_prefix() { + let cost = estimate_cost("openrouter/anthropic/claude-3.7-sonnet", 1000, 500).unwrap(); + assert!(cost > 0.0); + } +} diff --git a/src-tauri/src/runtime/zeroclaw/mod.rs b/src-tauri/src/runtime/zeroclaw/mod.rs index 5216aa98..06ba3c97 100644 --- a/src-tauri/src/runtime/zeroclaw/mod.rs +++ b/src-tauri/src/runtime/zeroclaw/mod.rs @@ -1,4 +1,5 @@ pub mod adapter; +pub mod cost; pub mod install_adapter; pub mod process; pub mod sanitize; diff --git a/src-tauri/src/runtime/zeroclaw/process.rs b/src-tauri/src/runtime/zeroclaw/process.rs index 9d903808..8881d39c 100644 --- a/src-tauri/src/runtime/zeroclaw/process.rs +++ b/src-tauri/src/runtime/zeroclaw/process.rs @@ -266,6 +266,43 @@ pub fn get_zeroclaw_usage_stats() -> ZeroclawUsageStats { usage_store().lock().map(|stats| *stats).unwrap_or_default() } +// --------------------------------------------------------------------------- +// Per-session usage tracking +// --------------------------------------------------------------------------- + +fn session_usage_store() -> &'static Mutex> { + static STORE: OnceLock>> = + OnceLock::new(); + STORE.get_or_init(|| Mutex::new(std::collections::HashMap::new())) +} + +pub fn record_session_usage(session_id: &str, prompt_tokens: u64, completion_tokens: u64) { + if session_id.is_empty() { + return; + } + if let Ok(mut map) = session_usage_store().lock() { + let stats = map + .entry(session_id.to_string()) + .or_insert_with(ZeroclawUsageStats::default); + stats.total_calls = stats.total_calls.saturating_add(1); + stats.usage_calls = stats.usage_calls.saturating_add(1); + stats.prompt_tokens = stats.prompt_tokens.saturating_add(prompt_tokens); + stats.completion_tokens = stats.completion_tokens.saturating_add(completion_tokens); + stats.total_tokens = stats + .total_tokens + .saturating_add(prompt_tokens.saturating_add(completion_tokens)); + stats.last_updated_ms = now_ms(); + } +} + +pub fn get_session_usage(session_id: &str) -> ZeroclawUsageStats { + session_usage_store() + .lock() + .ok() + .and_then(|map| map.get(session_id).copied()) + .unwrap_or_default() +} + fn sanitize_instance_namespace(raw: &str) -> String { let trimmed = raw.trim(); if trimmed.is_empty() { diff --git a/src/components/ModelSwitcher.tsx b/src/components/ModelSwitcher.tsx new file mode 100644 index 00000000..53fe222a --- /dev/null +++ b/src/components/ModelSwitcher.tsx @@ -0,0 +1,103 @@ +import { useEffect, useState } from "react"; +import { invoke } from "@tauri-apps/api/core"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; + +const AVAILABLE_MODELS = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4.1", + "claude-3.7-sonnet", + "claude-3.5-haiku", + "gemini-2.0-flash", + "kimi-k2.5", +]; + +interface ModelSwitcherProps { + sessionId: string; + defaultModel?: string; +} + +export function ModelSwitcher({ sessionId, defaultModel }: ModelSwitcherProps) { + const [override, setOverride] = useState(null); + const [open, setOpen] = useState(false); + + useEffect(() => { + if (!sessionId) return; + invoke("get_session_model_override", { sessionId }) + .then(setOverride) + .catch(() => {}); + }, [sessionId]); + + const currentModel = override ?? defaultModel ?? "auto"; + + const handleSelect = async (model: string) => { + try { + await invoke("set_session_model_override", { sessionId, model }); + setOverride(model); + } catch { + // silently ignore + } + setOpen(false); + }; + + const handleClear = async () => { + try { + await invoke("clear_session_model_override", { sessionId }); + setOverride(null); + } catch { + // silently ignore + } + setOpen(false); + }; + + return ( +
+ + + + + +
+ Switch model for this session +
+
+ {AVAILABLE_MODELS.map((model) => ( + + ))} +
+ {override && ( +
+ +
+ )} +
+
+ {override && ( + + Session override + + )} +
+ ); +} diff --git a/src/components/TokenBadge.tsx b/src/components/TokenBadge.tsx new file mode 100644 index 00000000..4869efed --- /dev/null +++ b/src/components/TokenBadge.tsx @@ -0,0 +1,77 @@ +import { useEffect, useState } from "react"; +import { invoke } from "@tauri-apps/api/core"; +import { Badge } from "@/components/ui/badge"; + +interface SessionUsageStats { + totalCalls: number; + usageCalls: number; + promptTokens: number; + completionTokens: number; + totalTokens: number; + lastUpdatedMs: number; +} + +interface CostEstimate { + model: string; + promptTokens: number; + completionTokens: number; + estimatedCostUsd: number | null; +} + +interface TokenBadgeProps { + sessionId: string; + model?: string; +} + +export function TokenBadge({ sessionId, model }: TokenBadgeProps) { + const [stats, setStats] = useState(null); + const [cost, setCost] = useState(null); + + useEffect(() => { + if (!sessionId) return; + + const fetchStats = async () => { + try { + const usage = await invoke("get_session_usage_stats", { + sessionId, + }); + setStats(usage); + + if (model && (usage.promptTokens > 0 || usage.completionTokens > 0)) { + const estimate = await invoke("estimate_query_cost", { + model, + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + }); + setCost(estimate.estimatedCostUsd); + } + } catch { + // silently ignore + } + }; + + fetchStats(); + const interval = setInterval(fetchStats, 5000); + return () => clearInterval(interval); + }, [sessionId, model]); + + if (!stats || stats.totalTokens === 0) return null; + + const formatTokens = (n: number) => { + if (n >= 1000000) return `${(n / 1000000).toFixed(1)}M`; + if (n >= 1000) return `${(n / 1000).toFixed(1)}k`; + return String(n); + }; + + const formatCost = (c: number) => { + if (c < 0.01) return `$${c.toFixed(4)}`; + return `$${c.toFixed(2)}`; + }; + + return ( + + 🪙 {formatTokens(stats.totalTokens)} + {cost !== null && ({formatCost(cost)})} + + ); +} diff --git a/src/pages/Doctor.tsx b/src/pages/Doctor.tsx index 52703671..cc362f5b 100644 --- a/src/pages/Doctor.tsx +++ b/src/pages/Doctor.tsx @@ -35,6 +35,8 @@ import { } from "@/components/ui/alert-dialog"; import { Skeleton } from "@/components/ui/skeleton"; import { DoctorChat } from "@/components/DoctorChat"; +import { TokenBadge } from "@/components/TokenBadge"; +import { ModelSwitcher } from "@/components/ModelSwitcher"; import { SessionAnalysisPanel } from "@/components/SessionAnalysisPanel"; import type { BackupInfo } from "@/lib/types"; import { formatTime, formatBytes } from "@/lib/utils"; @@ -1030,7 +1032,7 @@ export function Doctor({ )}
-
+
{t("doctor.engineZeroclaw")} @@ -1038,6 +1040,8 @@ export function Doctor({ {doctor.bridgeConnected ? t("doctor.bridgeConnected") : t("doctor.bridgeDisconnected")} + +